from text_encoder import TextEncoder
CLIP (Contrastive Language-Image Pre-Training)
[CLIP paper](https://arxiv.org/abs/2103.00020)
[CLIP code](https://github.com/openai/CLIP/blob/main/clip/model.py)
[CLIP in transformers](https://github.com/huggingface/transformers/tree/main/src/transformers/models/clip)
SigLIP (Sigmoid Loss for Language–Image Pre-training)
- [SigLIP paper](https://arxiv.org/abs/2303.15343)
Intuition
- This is a dual encoder model. It is trained with contrastive learning to align the embedding pairs of text and images.
- There is no decoder and no generation so this isnt a generative model nor a multimodal LLM yet.
Architecture
- Image Encoder: this is a ViT. We can literally take what we built in module 1 and import it. Its task is to take an image and output its embedding vector.
- Text Encoder: this is a transformer with a causal mask. We can import from what we built in our text encoder notebook. Its task is to take a sentence and output its embedding vector.
- Contrastive loss: this will be our training objective. The idea is to build a matrix from a batch of caption-image pairs, get their embeddings, and make the diagonal (the matching pairs) high in cosine similarity and the off-diagonal entries low in cosine similarity.
- We will need both encoders to have the same `hidden_dim` so we can compute the cosine similarity of pairs, and we will need some normalization of their respective outputs.
Training
- Given a batch of N image-text pairs:
  - encode all images → `[N, 512]`
  - encode all text descriptions of the images → `[N, 512]`
  - compute an `N x N` similarity matrix, which will be `logits = image_embeddings @ text_embeddings.T`, scaled by the temperature
  - the diagonal of the `N x N` similarity matrix holds the correct pairs
  - calculate the cross-entropy loss along the rows and the columns.
Imports
We can import our vanilla ViT from ../module1_vision_transformer_foundations/vit.py and our text encoder from text_encoder.py
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
sys.path.append('../module1_vision_transformer_foundations')
from vit import PatchEmbedding, TransformerEncoder


class VisionEncoder(nn.Module):
    """ViT-style image encoder that returns the full encoded token sequence.

    A learnable CLS token is prepended to the patch tokens, a learnable
    positional embedding is added, and the result is passed through the
    imported ``TransformerEncoder``. No pooling or classification head is
    applied here; callers pick the token(s) they need from the output.

    Args:
        img_size: Side length of the (square) input image in pixels.
        patch_size: Side length of each (square) patch in pixels.
        num_hiddens: Width of the token embeddings.
        num_heads: Number of attention heads handed to the encoder.
        num_classes: Accepted for signature compatibility; not used here.
        depth: Number of encoder layers handed to the encoder.
        dropout: Accepted for signature compatibility; not used here.
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_classes, depth=6, dropout=0.1):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        # One extra position for the CLS token that forward() prepends.
        sequence_length = self.num_patches + 1
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_hiddens))
        self.pos_embedding = nn.Parameter(torch.randn(1, sequence_length, num_hiddens))
        self.encoder = TransformerEncoder(
            num_hiddens,
            depth,
            num_heads,
            mlp_dim=num_hiddens * 4,
        )

    def forward(self, x):
        """Encode a batch of images into a sequence of tokens with CLS first."""
        batch_size = x.shape[0]
        # NOTE(review): flatten(2) + transpose assumes PatchEmbedding returns a
        # (batch, hidden, h, w) feature map - confirm against vit.py.
        patch_tokens = self.patch_embedding(x).flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)
        tokens = tokens + self.pos_embedding
        return self.encoder(tokens)


# Putting it all together
- Now it's just as simple as getting our vision encoder and the text encoder, projecting them to the same dimension, generating a similarity matrix from the normalized embeddings, and taking the cross-entropy across rows and columns as the contrastive loss.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class CLIP(nn.Module):
    """Dual-encoder model trained with a symmetric contrastive loss.

    Images go through ``vision_encoder``; the first output token is projected
    to the text embedding width. Captions go through ``text_encoder``. Both
    embeddings are L2-normalized, so their product is a cosine-similarity
    matrix whose diagonal holds the matching image/caption pairs.

    Args:
        vision_encoder: Module whose output is indexed as ``out[:, 0]`` to get
            one vector of width ``vision_encoder_dim`` per image.
        text_encoder: Module returning one embedding per caption (2-D output),
            of width ``text_encoder_dim``.
        vision_encoder_dim: Width of the vision encoder's tokens.
        text_encoder_dim: Width of the text embeddings (the shared space).
        temperature: Positive softmax temperature. Cosine similarities are
            DIVIDED by it, so smaller values give sharper distributions.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, vision_encoder, text_encoder, vision_encoder_dim, text_encoder_dim, temperature):
        super().__init__()
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder
        self.vision_projection = nn.Linear(vision_encoder_dim, text_encoder_dim)
        self.temperature = temperature

    def forward(self, pair):
        """Return ``(logits, loss)`` for a batch of aligned image/caption pairs.

        Args:
            pair: Indexable with ``pair[0]`` the image batch and ``pair[1]``
                the token-id batch; row ``i`` of each belongs together.

        Returns:
            logits: ``(N, N)`` temperature-scaled cosine similarities, with
                images along the rows and captions along the columns.
            loss: Mean of the image-to-text and text-to-image cross-entropies.
        """
        vision_tokens = self.vision_encoder(pair[0])
        image_embs = self.vision_projection(vision_tokens[:, 0])  # CLS token
        text_embs = self.text_encoder(pair[1])

        # Unit-normalize both sides so the matmul yields cosine similarities.
        # F.normalize clamps the norm with an eps, which avoids NaNs on an
        # all-zero embedding (plain division by .norm() would produce them).
        image_embs = F.normalize(image_embs, dim=-1)
        text_embs = F.normalize(text_embs, dim=-1)

        # Divide by the temperature. Multiplying cosine similarities (bounded
        # in [-1, 1]) by a small temperature such as 0.1 squeezes the logits
        # into [-0.1, 0.1]; the softmax is then almost uniform and the loss
        # stays pinned near ln(batch_size), giving a very weak training signal.
        logits = (image_embs @ text_embs.T) / self.temperature

        # The i-th image matches the i-th caption, so the targets are 0..N-1.
        labels = torch.arange(logits.shape[0], device=logits.device)
        loss_i = F.cross_entropy(logits, labels)    # rows: for each image, which text?
        loss_t = F.cross_entropy(logits.T, labels)  # cols: for each text, which image?
        loss = (loss_i + loss_t) / 2
        return logits, loss


vision_encoder = VisionEncoder(img_size=56, patch_size=7, num_hiddens=768, num_heads=8, num_classes=100, depth=2, dropout=0.1)
# Reuse the tokenizer that has already been trained for the OpenAI CLIP checkpoint.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = TextEncoder(
    vocab_size=49408,
    embed_dim=512,
    num_heads=8,
    num_layers=12,
    max_seq_len=77,
    projection_dim=512,
)
# Recorded notebook output:
# /Users/jpoberhauser/mambaforge3/envs/sam2_env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
#   from .autonotebook import tqdm as notebook_tqdm

# Vision tokens are 768 wide and get projected to the 512-wide text space.
clip = CLIP(vision_encoder, text_encoder, 768, 512, temperature=0.1)

# Smoke test: four random "images" paired with four captions.
img = torch.randn(4, 3, 56, 56)
text_input = [
    "This is the pitcure of a cow in a pasture. The cow is very close to the camera and its a sunny day.",
    "This is a dog",
    "This is a cat",
    "This is an airplane about to take off",
]
batch = tokenizer(
    text_input,
    padding=True,
    max_length=77,
    truncation=True,
    return_tensors="pt",
)
pair = (img, batch["input_ids"])
logits, loss = clip(pair)
logits.shape, loss
# Recorded notebook output:
# (torch.Size([4, 4]), tensor(1.3864, grad_fn=<DivBackward0>))
from torchvision import datasets, transforms

# Upsample the CIFAR-10 images to 56x56 and map pixel values to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((56, 56)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Use Apple's MPS backend when it is available, otherwise stay on the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
clip = clip.to(device)

from torch.utils.data import DataLoader

optimizer = torch.optim.AdamW(clip.parameters(), lr=1e-4, weight_decay=0.01)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)

# Class names listed so that the integer dataset label indexes its name.
cifar10_classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
num_epochs = 20
# Contrastive training: every image is paired with a caption built from its label.
for epoch in range(num_epochs):
    clip.train()
    total_loss = 0
    num_batches = 0
    for images, labels in dataloader:
        # Build captions from labels
        captions = [f"a photo of a {cifar10_classes[label]}" for label in labels]
        tokens = tokenizer(
            captions,
            padding=True,
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )
        token_ids = tokens["input_ids"].to(device)
        images = images.to(device)

        # Clear stale gradients, then run forward, backward and the update.
        optimizer.zero_grad()
        logits, loss = clip((images, token_ids))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs} — loss: {avg_loss:.4f}")
# Recorded notebook output:
# Epoch 1/20 — loss: 4.8112
# Epoch 2/20 — loss: 4.7977
# Epoch 3/20 — loss: 4.7979
# Epoch 4/20 — loss: 4.7967
# Epoch 5/20 — loss: 4.7955
# Epoch 6/20 — loss: 4.7965
# Epoch 7/20 — loss: 4.7950
# Epoch 8/20 — loss: 4.7934
# Epoch 9/20 — loss: 4.7959
# Epoch 10/20 — loss: 4.7948
# Epoch 11/20 — loss: 4.7939
# Epoch 12/20 — loss: 4.7937
# Epoch 13/20 — loss: 4.7937
# Epoch 14/20 — loss: 4.7948
# Epoch 15/20 — loss: 4.7954
# Epoch 16/20 — loss: 4.7937
# Epoch 17/20 — loss: 4.7925
# Epoch 18/20 — loss: 4.7915
# Epoch 19/20 — loss: 4.7928
# Epoch 20/20 — loss: 4.7921
# Zero-shot classification: pick, for each test image, the class prompt whose
# embedding has the highest cosine similarity with the image embedding.
clip.eval()

# Encode all class prompts once
class_prompts = [f"a photo of a {c}" for c in cifar10_classes]
class_tokens = tokenizer(
    class_prompts,
    padding=True,
    max_length=77,
    truncation=True,
    return_tensors="pt",
)["input_ids"].to(device)
with torch.no_grad():
    text_embs = clip.text_encoder(class_tokens)
    text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)  # (10, 512)

# Test on a batch
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        x = clip.vision_encoder(images)
        # Project the CLS token into the shared space and unit-normalize it.
        image_embs = clip.vision_projection(x[:, 0])
        image_embs = image_embs / image_embs.norm(dim=-1, keepdim=True)
        # Each image vs all 10 class prompts
        sims = image_embs @ text_embs.T  # (B, 10)
        preds = sims.argmax(dim=-1)
        hits = preds == labels.to(device)
        correct += hits.sum().item()
        total += labels.shape[0]

print(f"accuracy: {correct/total:.2%}")
# Recorded notebook output:
# accuracy: 17.56%
import matplotlib.pyplot as plt

# Encode a bunch of test images
all_image_embs = []
all_labels = []
retrieval_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
with torch.no_grad():
    for images, labels in retrieval_loader:
        x = clip.vision_encoder(images.to(device))
        embs = clip.vision_projection(x[:, 0])
        embs = embs / embs.norm(dim=-1, keepdim=True)
        # Keep the gallery on the CPU; it is only needed for the final matmul.
        all_image_embs.append(embs.cpu())
        all_labels.append(labels)
all_image_embs = torch.cat(all_image_embs)
all_labels = torch.cat(all_labels)

# Query with a text prompt
query = "a photo of a horse"
query_tokens = tokenizer(
    [query],
    padding=True,
    max_length=77,
    truncation=True,
    return_tensors="pt",
)["input_ids"].to(device)
with torch.no_grad():
    query_emb = clip.text_encoder(query_tokens)
    query_emb = query_emb / query_emb.norm(dim=-1, keepdim=True)

# Cosine similarity of the query against every gallery image, then the best five.
sims = (all_image_embs @ query_emb.cpu().T).squeeze()
top5 = sims.topk(5).indices

fig, axes = plt.subplots(1, 5, figsize=(12, 3))
fig.suptitle(f'Top 5 retrieved images for: "{query}"')
for ax, idx in zip(axes, top5):
    img = test_dataset[idx][0] * 0.5 + 0.5  # undo normalization
    ax.imshow(img.permute(1, 2, 0).clamp(0, 1))
    ax.set_title(f"{cifar10_classes[all_labels[idx]]} ({sims[idx]:.2f})")
    ax.axis('off')
plt.tight_layout()
plt.show()
Results
- clearly, this model needs some work so lets understand why.
- All the captions for 'dog' are the same. That's not a lot of textual information that we are adding, and pictures of different dogs will always be compared against the same text token embeddings, so the model won't be able to learn any more language-visual relationships from different images.
- The data is 32x32, which is not a lot of visual information.
- Our ViT model is very tiny; the real CLIP uses a ViT-B/32 with 12 layers and 224x224 images. Ours has 2 layers instead of 12, and the images are 56x56.
- The fact that it's learning above 10% accuracy means it's learning something, so the training loop and losses are probably correctly configured.
- Our batch size is small. The Real CLIP uses batch size of 32,768!!
- Data is small, real CLIP uses 400M image-text pairs.
Lets compare to a true pre-trained CLIP
from transformers import CLIPModel, CLIPProcessor

# Load pretrained CLIP
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
pretrained_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
pretrained_clip.eval()

# --- Zero-shot classification ---
correct = 0
total = 0

# Raw (unnormalized) test dataset for the processor
raw_test = datasets.CIFAR10(root='./data', train=False, download=False)
# NOTE(review): raw_test has no transform, so it presumably yields PIL images;
# confirm the default collate can batch them before iterating this loader.
test_loader_raw = DataLoader(raw_test, batch_size=256, shuffle=False)
class_prompts = [f"a photo of a {c}" for c in cifar10_classes]
# Recorded notebook output:
# Loading weights: 100%|██████████| 398/398 [00:00<00:00, 48068.79it/s]
# CLIPModel LOAD REPORT from: openai/clip-vit-base-patch32
# Key                                  | Status     |  |
# -------------------------------------+------------+--+-
# text_model.embeddings.position_ids   | UNEXPECTED |  |
# vision_model.embeddings.position_ids | UNEXPECTED |  |
# Notes:
# - UNEXPECTED :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- For retrieval, let's extract the embeddings for all the images.
# --- Retrieval ---
# Embed every raw test image with the pretrained CLIP vision tower, batch by batch.
all_image_embs_pt = []
all_labels_pt = []
batch_size = 256
with torch.no_grad():
    for i in range(0, len(raw_test), batch_size):
        # Slice the dataset by hand so each item is handed to the processor as-is.
        batch = [raw_test[j] for j in range(i, min(i + batch_size, len(raw_test)))]
        images = [b[0] for b in batch]
        labels = torch.tensor([b[1] for b in batch])
        image_inputs = processor(images=images, return_tensors="pt").to(device)
        outputs = pretrained_clip.vision_model(**image_inputs)
        embs = pretrained_clip.visual_projection(outputs.pooler_output)
        # Unit-normalize once; the original repeated this line, which is a
        # no-op on vectors that already have unit length.
        embs = embs / embs.norm(dim=-1, keepdim=True)
        all_image_embs_pt.append(embs.cpu())
        all_labels_pt.append(labels)
all_image_embs_pt = torch.cat(all_image_embs_pt)
all_labels_pt = torch.cat(all_labels_pt)

# - then we extract the query inputs from the text processor and find the most similar
# Embed the text query with the pretrained text tower and rank the gallery by it.
query = "a photo of a horse"
with torch.no_grad():
    query_inputs = tokenizer([query], return_tensors="pt", padding=True).to(device)
    outputs = pretrained_clip.text_model(**query_inputs)
    query_emb = pretrained_clip.text_projection(outputs.pooler_output)
    query_emb = query_emb / query_emb.norm(dim=-1, keepdim=True)

# Both sides are unit-normalized, so the dot product is the cosine similarity.
sims = (all_image_embs_pt @ query_emb.cpu().T).squeeze()
top5 = sims.topk(5).indices

fig, axes = plt.subplots(1, 5, figsize=(12, 3))
fig.suptitle(f'Pretrained CLIP — Top 5 for: "{query}"')
for ax, idx in zip(axes, top5):
    img = raw_test[idx][0]
    ax.imshow(img)
    ax.set_title(f"{cifar10_classes[all_labels_pt[idx]]} ({sims[idx]:.2f})")
    ax.axis('off')
plt.tight_layout()
plt.show()