ruCLIP Base [vit-base-patch32-384]Бесплатно
Russian Contrastive Language–Image Pre-training. Модель-ранжировщик текстов и изображений, 150 млн параметров.
ruCLIP - это мультимодальная модель для ранжирования изображений и подписей к ним, а также получения семантической близости изображений и текстов. Архитектура впервые представлена OpenAI.
Apache 2.0
0.6 GB
0.1
ruCLIP (Russian Contrastive Language – Image Pre-training) обучена для русского языка на открытых данных, собранных из Рунета.
240 миллионов уникальных пар картинка-текст в обучающей выборке.
Информация об использовании модели:
ruCLIP — это модель, состоящая из двух частей (или нейронных сетей):
- Image Encoder — часть для кодирования изображений и перевода их в общее векторное пространство. В качестве архитектуры в оригинальной работе берутся ResNet разных размеров и Visual Transformer — тоже разных размеров. В ruCLIP Base в качестве image encoder используется ViT-B/32.
- Text Encoder — часть для кодирования текстов и перевода их в общее векторное пространство. В качестве архитектуры используется текстовый Transformer.
Similarity
[{"собачка": 0.9324831366539001}, {"кошка": 0.02790665067732334}, {"мышка": 0.029953204095363617}, {"машина": 0.0034359991550445557}, {"стол": 0.0031091528944671154}, {"дом": 0.0018060706788673997}, {"жидкость": 0.0013057001633569598}]
KFServing
Класс KFServingRuClipModel
представлен ниже.
Вы можете подавать на вход модели:
- ссылки на изображения
- картинки в формате base64
На выходе модель покажет близость между текстами и картинками. Чем ближе значение к 1, тем ближе семантическое сходство картинки и текста.
import os
import json
from collections import OrderedDict
from typing import Dict
import kfserving
import requests
import torch
import numpy as np
import io
from io import BytesIO
import base64
from PIL import Image
import re
from ruclip import CLIP, RuCLIPProcessor
def open_images_base64(img_strs):
return [Image.open(BytesIO(base64.b64decode(img_str))) for img_str in img_strs]
def open_image_link(links):
imgs = []
for img_link in links:
response = requests.get(img_link)
imgs.append(Image.open(BytesIO(response.content)))
return imgs
def create_image(sim_plt):
my_stringIObytes = io.BytesIO()
sim_plt.savefig(my_stringIObytes, format="jpg")
my_stringIObytes.seek(0)
my_base64 = base64.b64encode(my_stringIObytes.read())
return my_base64
class KFServingRuClipModel(kfserving.KFModel):
def __init__(self, title: str, model_path="./ruclip-vit-base-patch32-384"):
super().__init__(name)
self.name = name
self.ready = False
self.model_path = model_path
def load(self):
self.device = "cuda"
self.clip = CLIP.from_pretrained(self.model_path).eval().to(self.device)
self.clip_processor = RuCLIPProcessor.from_pretrained(self.model_path)
self.ready = True
def get_text_latents(self, texts):
with torch.no_grad():
inputs = self.clip_processor(text=texts, images=None)
text_latents = self.clip.encode_text(
input_ids=inputs["input_ids"].to(self.device),
)
text_latents = text_latents / text_latents.norm(dim=-1, keepdim=True)
return text_latents
def get_logits(self, text_latents, pil_images):
with torch.no_grad():
inputs = self.clip_processor(text=None, images=pil_images)
image_latents = self.clip.encode_image(
pixel_values=inputs["pixel_values"].to(self.device)
)
image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
logits_per_text = torch.matmul(text_latents, image_latents.t())
logits_per_image = logits_per_text.t()
return logits_per_text, logits_per_image
def get_similarity_scores(self, texts, images):
"""
Find the most similar image to text.
`texts`: array of texts or one text ["some_desc"]
`images`: array of images.
"""
text_latents = self.get_text_latents(texts)
results = []
for pil_image in images:
_, logits_per_image = self.get_logits(text_latents, [pil_image])
probs_raw = (logits_per_image * self.clip.logit_scale.exp().detach()).softmax(dim=-1)[0]
label_id = probs_raw.argmax().item()
confidence = probs_raw.max().item()
probs = []
for i in range(len(texts)):
probs.append({texts[i]: probs_raw[i].item()})
buffered = BytesIO()
pil_image.save(buffered, format="JPEG")
img = base64.b64encode(buffered.getvalue()).decode("utf-8")
results.append({"image": img, "text": texts[label_id], "confidence" : confidence, "all_res": probs})
return results
def predict(self, request: Dict) -> Dict:
texts = request["instances"][0]["texts"]
img_strs = request["instances"][0].get("images", None)
if img_strs is not None:
images = open_images_base64(img_strs)
images_links = request["instances"][0].get("image_links", None)
if images_links is not None:
images = open_image_link(images_links)
error_msg = None
predictions = []
try:
predictions = self.get_similarity_scores(texts, images)
except Exception as ex:
print(ex)
error_msg = ex
if error_msg is not None:
return {"predictions": predictions, "error_message": str(error_msg)}
else:
return {"predictions": predictions}
Функция predict возвращает массив со словарями для каждой картинки вида:
predictions: [
{
"image": "base64 image",
"text": "наиболее подходящий текст для картинки" ,
"confidence": "мера близости для лучшей картинки",
"all_res": [{"менее вероятный текст": 0.31}, {"другой текст": 0.33},} ...]
}
]
Пример работы с моделью
!pip install -r requirements.txt
from kfserving_ru_CLIP import KFServingRuClipModel
model = KFServingRuClipModel("kfserving-clip")
model.load()
# zero-shot and links
url_cat = "https://cs11.livemaster.ru/storage/topic/NxN/2c/9b/9cf0a41d13ecb11439e6145dff576315df83op.jpg?h=3KvOPndE06tlraLLSmkHPQ"
url_cat2 = "https://pbs.twimg.com/profile_images/560798448962633728/rDEdUfV_.jpeg"
url_dog = "https://ichef.bbci.co.uk/news/640/cpsprodpb/475B/production/_98776281_gettyimages-521697453.jpg"
result = model.predict({"instances": [{
"texts": ["собачка", "кошка", "мышка", "машина", "стол", "дом", "жидкость"],
"image_links": [url_dog, url_cat, url_cat2]
}]
})
"Res: ", result
`{"predictions": "[{"image": b"/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQ...", "text": "собачка", "confidence": 0.9324831366539001, "all_res": [{"собачка": 0.9324831366539001}, {"кошка": 0.02790665067732334}, {"мышка": 0.029953204095363617}, {"машина": 0.0034359991550445557}, {"стол": 0.0031091528944671154}, {"дом": 0.0018060706788673997}, {"жидкость": 0.0013057001633569598}]}, ... ]", "error_message": None}`
Оценки модели на популярных датасетах
Косинусная близость между текстами и картинками для модели ruCLIP
Предсказания топ 5 классов для изображений с помощью ruCLIP
Сравнение моделей на задаче zero-shot классификации для разных датасетов. Жирным выделена лучшая метрика для каждого из датасетов без учета оригинального CLIP без переводчика.
ruCLIP Small [rugpt3-small] | ruCLIP Base [vit-base-patch32-384] | CLIP [vit-base-patch16-224] original + MT | CLIP [vit-base-patch16-224] original | |
---|---|---|---|---|
Food101 | 0.138 | 0.642 | 0.663 | 0.883 |
CIFAR10 | 0.808 | 0.862 | 0.859 | 0.893 |
CIFAR100 | 0.440 | 0.529 | 0.603 | 0.647 |
Birdsnap | 0.0360 | 0.161 | 0.126 | 0.396 |
SUN397 | 0.258 | 0.510 | 0.447 | 0.631 |
Stanford Cars | 0.023 | 0.572 | 0.567 | 0.637 |
DTD | 0.169 | 0.390 | 0.243 | 0.432 |
MNIST | 0.137 | 0.404 | 0.559 | 0.559 |
STL10 | 0.910 | 0.946 | 0.967 | 0.970 |
PCam | 0.484 | 0.506 | 0.603 | 0.573 |
CLEVR | 0.104 | 0.188 | 0.240 | 0.240 |
Rendered SST2 | 0.483 | 0.508 | 0.484 | 0.484 |
FGVC Aircraft | 0.020 | 0.053 | 0.220 | 0.244 |
Oxford Pets | 0.462 | 0.587 | 0.507 | 0.874 |
Caltech101 | 0.59 | 0.834 | 0.791 | 0.883 |
HatefulMemes | 0.527 | 0.537 | 0.579 | 0.589 |
ImageNet | 0.538 | 0.451 | 0.392 | 0.638 |
Flowers102 | 0.063 | 0.449 | 0.357 | 0.697 |
Звездочками показана средняя zero-shot оценка моделей на 16 датасетах. Также, как и в статье, на признаках, которые достает CLIP для изображений были обучены логистические регрессии с использованием 1-2-4-8-16 изображений для каждого класса. Поскольку признаки, которые извлекаются у openai и openai_mt одинаковые — для openai_mt нет отдельного графика few-shot классификации. Также мы посчитали усредненный few-shot график для модели ruCLIP Base без учета трех датасетов - PCam, Oxford Pets и FGVC Aircraft, на которых модель проигрывает заметнее остальных можно видеть (пунктирная линия), что среднее качество становится лучше в сравнении с ruCLIP Small.
То же самое, но отдельно по каждому датасету.
Сравнение linear-prob метрики для трех моделей на разных датасетах.
ruCLIP Small [rugpt3-small] | ruCLIP Base [vit-base-patch32-384] | CLIP [vit-base-patch16-224] original | |
---|---|---|---|
Food101 | 0.874 | 0.851 | 0.901 |
CIFAR10 | 0.948 | 0.934 | 0.953 |
CIFAR100 | 0.794 | 0.745 | 0.808 |
Birdsnap | 0.584 | 0.434 | 0.664 |
SUN397 | 0.753 | 0.721 | 0.777 |
Stanford Cars | 0.806 | 0.766 | 0.866 |
DTD | 0.738 | 0.703 | 0.770 |
MNIST | 0.985 | 0.965 | 0.989 |
STL10 | 0.977 | 0.968 | 0.982 |
PCam | 0.833 | 0.835 | 0.830 |
CLEVR | 0.524 | 0.308 | 0.604 |
Rendered SST2 | 0.568 | 0.651 | 0.606 |
FGVC Aircraft | 0.500 | 0.283 | 0.604 |
Oxford Pets | 0.895 | 0.730 | 0.931 |
Caltech101 | 0.937 | 0.922 | 0.956 |
HatefulMemes | 0.638 | 0.581 | 0.645 |
Здесь указаны графики корреляции zero-shot и linear-prob результатов для разных моделей.
ruCLIP Base [vit-base-patch32-384]
ruCLIP Small [rugpt3-small]
CLIP [vit-base-patch16-224] original + MT
CLIP [vit-base-patch16-224] original
Сравнение разных моделей на ImageNet датасетах
resnet101 | CLIP [vit-base-patch16-224] original | CLIP [vit-base-patch16-224] original + MT | ruCLIP Base [vit-base-patch32-384] | ruCLIP Small [rugpt3-small] | |
---|---|---|---|---|---|
ImageNet | 0.739 | 0.638 | 0.392 | 0.451 | 0.538 |
ImageNetV2 | 0.618 | 0.582 | 0.353 | 0.389 | 0.458 |
ImageNet-R | 0.272 | 0.490 | 0.353 | 0.473 | 0.241 |
ImageNet-A | 0.022 | 0.265 | 0.157 | 0.114 | 0.080 |
ImageNet-Sketch | 0.265 | 0.448 | 0.291 | 0.374 | 0.251 |
Zero-shot классификация для разных датасетов на моделях ruCLIP.
ruCLIP Base [vit-base-patch32-384] | ruCLIP Large [vit-large-patch14-224] | ruCLIP Large [vit-large-patch14-336] exclusive | ruCLIP Base [vit-base-patch16-384] exclusive | |
---|---|---|---|---|
Food101, acc | 0.642 | 0.597 | 0.712 💥 | 0.689 |
CIFAR10, acc | 0.862 | 0.878 | 0.906 💥 | 0.845 |
CIFAR100, acc | 0.529 | 0.511 | 0.591 💥 | 0.569 |
Birdsnap, acc | 0.161 | 0.172 | 0.213 💥 | 0.195 |
SUN397, acc | 0.510 | 0.484 | 0.523 💥 | 0.521 |
Stanford Cars, acc | 0.572 | 0.559 | 0.659 💥 | 0.626 |
DTD, acc | 0.390 | 0.370 | 0.408 | 0.421 💥 |
MNIST, acc | 0.404 | 0.337 | 0.242 | 0.478 💥 |
STL10, acc | 0.946 | 0.934 | 0.956 | 0.964 💥 |
PCam, acc | 0.506 | 0.520 | 0.554 💥 | 0.501 |
CLEVR, acc | 0.188 💥 | 0.152 | 0.142 | 0.132 |
Rendered SST2, acc | 0.508 | 0.529 | 0.539 💥 | 0.525 |
ImageNet, acc | 0.451 | 0.426 | 0.488 💥 | 0.482 |
FGVC Aircraft, mean-per-class | 0.053 | 0.046 | 0.075 💥 | 0.046 |
Oxford Pets, mean-per-class | 0.587 | 0.604 | 0.546 | 0.635 💥 |
Caltech101, mean-per-class | 0.834 | 0.777 | 0.835 💥 | 0.835 💥 |
Flowers102, mean-per-class | 0.449 | 0.455 | 0.517 💥 | 0.452 |
Hateful Memes, roc-auc | 0.537 | 0.530 | 0.519 | 0.543💥 |
Few-shot классификация для разных датасетов на моделях ruCLIP.
ruCLIP Base [vit-base-patch32-384] | ruCLIP Large [vit-large-patch14-224] | ruCLIP Large [vit-large-patch14-336] exclusive | ruCLIP Base [vit-base-patch16-384] exclusive | |
---|---|---|---|---|
Food101 | 0.851 | 0.840 | 0.896 💥 | 0.890 |
CIFAR10 | 0.934 | 0.927 | 0.943 💥 | 0.942 |
CIFAR100 | 0.745 | 0.734 | 0.770 | 0.773 💥 |
Birdsnap | 0.434 | 0.567 | 0.609 | 0.612 💥 |
SUN397 | 0.721 | 0.731 | 0.759 💥 | 0.758 |
Stanford Cars | 0.766 | 0.797 | 0.831 | 0.840 💥 |
DTD | 0.703 | 0.711 | 0.731 | 0.749 💥 |
MNIST | 0.965 | 0.949 | 0.949 | 0.971 💥 |
STL10 | 0.968 | 0.973 | 0.981 💥 | 0.974 |
PCam | 0.835 | 0.791 | 0.807 | 0.846 💥 |
CLEVR | 0.308 | 0.358 | 0.318 | 0.378 💥 |
Rendered SST2 | 0.651 | 0.651 | 0.637 | 0.661 💥 |
FGVC Aircraft | 0.283 | 0.290 | 0.341 | 0.362 💥 |
Oxford Pets | 0.730 | 0.819 | 0.753 | 0.856 💥 |
Caltech101 | 0.922 | 0.914 | 0.937 💥 | 0.932 |
HatefulMemes | 0.581 | 0.563 | 0.585 💥 | 0.578 |
Полезные ссылки
Обратная связь
Круглосуточная поддержка по телефону 8 800 444-24-99, почте support@cloud.ru и в Telegram