import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
import requests
from io import BytesIO
Part 1: Attention Map Visualization
- Let's load a pre-trained DINOv2 model, run images through it, and extract its self-attention maps.
- Take the CLS-token row of the attention maps and overlay it on the input images as a visualization.
- Compare against a pre-trained CLIP ViT.
- Bonus: use the [LiFT model](https://github.com/bpiyush/LiFT/) to build chirality-aware embeddings from DINOv2 features in videos, giving more temporally coherent video embeddings.
Part 2: Frozen Feature Extraction + Linear Probing
- We load the CIFAR-100 dataset we have been working with so far, extract CLS embeddings from DINOv2, and train a simple
`nn.Linear` classifier on top of the frozen model's embeddings. We then compare its accuracy against our vanilla implementations.
Part 3: Survey Table
- Finally, we put together a quick survey of the self-supervised vision models out there.
For each method (DINOv1, DINOv2, DINOv3, BEiT, JEPA, V-JEPA 2) we record:
  - What does it predict? (distillation targets / pixels / visual tokens / representations)
  - Architecture (pure ViT vs. hybrid)
  - Does it need fine-tuning, or do frozen features work?
  - Insights
Part 2
# Load the small DINOv2 ViT from torch.hub and switch it to inference mode:
# eval() puts any train-only layers (e.g. dropout) into evaluation behaviour,
# so the extracted embeddings are deterministic.
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dinov2.eval()

# Pick the fastest available backend: CUDA, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
dinov2 = dinov2.to(device)

# Output from the original run (informational only):
#   Using cache found in /Users/jpoberhauser/.cache/torch/hub/facebookresearch_dinov2_main
#   .../dinov2/layers/swiglu_ffn.py:51: UserWarning: xFormers is not available (SwiGLU)
#   .../dinov2/layers/attention.py:33: UserWarning: xFormers is not available (Attention)
#   .../dinov2/layers/block.py:40: UserWarning: xFormers is not available (Block)
# These warnings did not prevent the rest of the notebook from running.
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader

# Preprocessing for the frozen DINOv2 backbone:
# - upsample the small CIFAR images to 224x224, a multiple of the 14-pixel
#   patch size implied by `dinov2_vits14` (224 / 14 = 16 patches per side);
# - normalize with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

train_dataset = CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR100(root='./data', train=False, download=True, transform=transform)

# Both loaders keep a fixed order. They are only used to extract features
# once; shuffling for probe training happens later on the cached features.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# this part takes a bit on my local MPS apple ~8 minutes
def extract_features(model, loader, device):
    """Embed every batch from ``loader`` with ``model`` and stack the results.

    The whole pass runs under ``torch.no_grad()``. Each input batch is moved
    to ``device`` before the forward call, and each output is copied back to
    the CPU before it is stored.

    Args:
        model: callable mapping an input batch to a ``[B, ...]`` tensor; with
            ``dinov2_vits14`` this is the ``[B, 384]`` CLS embedding.
        loader: iterable yielding ``(inputs, labels)`` pairs.
        device: device the inputs are moved to before calling ``model``.

    Returns:
        ``(features, labels)``: every model output and every label,
        each concatenated along dim 0 in loader order.
    """
    outputs, targets = [], []
    with torch.no_grad():
        for inputs, batch_labels in loader:
            outputs.append(model(inputs.to(device)).cpu())
            targets.append(batch_labels)
    return torch.cat(outputs), torch.cat(targets)
# Embed both splits once with the frozen backbone (slow step). The linear
# probe below trains on these cached tensors, never on raw images.
# NOTE: an earlier run of this cell was interrupted by hand (KeyboardInterrupt)
# during training-set extraction; the shapes below are from a completed run.
train_features, train_labels = extract_features(dinov2, train_loader, device)
test_features, test_labels = extract_features(dinov2, test_loader, device)

print(train_features.shape)  # torch.Size([50000, 384])
print(test_features.shape)   # torch.Size([10000, 384])
from torch.utils.data import TensorDataset

# Wrap the cached embeddings so the probe trains without touching the
# backbone again. Shuffle only the training split.
train_feat_dataset = TensorDataset(train_features, train_labels)
test_feat_dataset = TensorDataset(test_features, test_labels)
train_feat_loader = DataLoader(train_feat_dataset, batch_size=256, shuffle=True)
test_feat_loader = DataLoader(test_feat_dataset, batch_size=256, shuffle=False)

# Build the simplest linear probe.
# It is the only thing updating parameters; everything else is completely
# frozen (the features were extracted once, under no_grad).
# Its sizes come from the data instead of hard-coded 384 / 100, so another
# backbone width or label set works unchanged.
embed_dim = train_features.shape[1]
num_classes = len(train_dataset.classes)
linear_probe = nn.Linear(embed_dim, num_classes).to(device)

optimizer = torch.optim.Adam(linear_probe.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(50):
    linear_probe.train()
    total_loss, correct, total = 0.0, 0, 0
    for features, labels in train_feat_loader:
        features, labels = features.to(device), labels.to(device)
        logits = linear_probe(features)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss returns the batch *mean*. Re-weight it by batch
        # size so the epoch loss is a true per-sample mean; the last batch
        # can be smaller than 256.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += batch_size

    # Test accuracy. NOTE(review): reported every epoch for monitoring only.
    # Choosing the epoch based on this number would leak the test set.
    linear_probe.eval()
    test_correct, test_total = 0, 0
    with torch.no_grad():
        for features, labels in test_feat_loader:
            features, labels = features.to(device), labels.to(device)
            logits = linear_probe(features)
            test_correct += (logits.argmax(dim=-1) == labels).sum().item()
            test_total += labels.size(0)

    print(f"Epoch {epoch+1} | Loss: {total_loss/total:.4f} | Train Acc: {correct/total:.4f} | Test Acc: {test_correct/test_total:.4f}")

# Last line of the original run (loss then averaged per batch, not per sample):
#   Epoch 50 | Loss: 0.1266 | Train Acc: 0.9635 | Test Acc: 0.8085