Part 1: Attention Map Visualization

Part 2: Frozen Feature Extraction + Linear Probing

Part 3: Survey Table

For each method (DINOv1, DINOv2, DINOv3, BEiT, JEPA, V-JEPA 2) we record:
- What does it predict? (distillation targets / pixels / visual tokens / representations)
- Architecture (pure ViT vs. hybrid)
- Does it need fine-tuning, or do frozen features work?
- Key insights

Part 1

import torch                                                                                                                                                   
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
import requests
from io import BytesIO
# Use the MPS (Metal) backend when PyTorch reports it available, otherwise the CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# DINOv2 ViT-S/14 from torch.hub, put in inference mode and moved to `device`
# (nn.Module.eval() and .to() both return the module itself, so they chain).
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').eval().to(device)
Using cache found in /Users/jpoberhauser/.cache/torch/hub/facebookresearch_dinov2_main
/Users/jpoberhauser/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/swiglu_ffn.py:51: UserWarning: xFormers is not available (SwiGLU)
  warnings.warn("xFormers is not available (SwiGLU)")
/Users/jpoberhauser/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/attention.py:33: UserWarning: xFormers is not available (Attention)
  warnings.warn("xFormers is not available (Attention)")
/Users/jpoberhauser/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/block.py:40: UserWarning: xFormers is not available (Block)
  warnings.warn("xFormers is not available (Block)")
  • The loaded model is a ViT-S/14, i.e. it splits the image into 14×14-pixel patches. On a 224×224 image that gives a 16×16 grid of 256 patch tokens, plus the CLS token (257 tokens in total).
from torchvision.datasets import CIFAR100

# Per-channel mean/std (the commonly used ImageNet statistics). Defined once so the
# normalization and its inverse (for display) cannot drift apart.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 224 / 14 = 16 patches per side for ViT-S/14
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
train_dataset_raw = CIFAR100(root='./data', train=True, download=True, transform=transform)
img, label = train_dataset_raw[-1]  # img: normalized tensor of shape [3, 224, 224]

# Undo the normalization so the image displays with its real colors.
mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
img_display = (img * std + mean).clamp(0, 1)

plt.imshow(img_display.permute(1, 2, 0))  # [3, 224, 224] -> [224, 224, 3] (HWC for matplotlib)
plt.title(f"CIFAR-100 resized to 224×224 | Label: {label}")
plt.axis('off')
class_names = train_dataset_raw.classes  # list of 100 class names
# Separator restored: the previous f-string glued the index and name together ("73shark").
print(f"Label: {label} → {class_names[label]}")
plt.show()
             
Label: 73 → shark

# Re-asserts the device placement (the model was already moved when loaded); as the
# cell's last expression, the returned module's architecture is displayed.
dinov2.to(device=device)
DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x NestedTensorBlock(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (head): Identity()
)
img = img.to(device)
# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    outs = dinov2(img.unsqueeze(0))  # one vector per image ([1, 384]): the CLS-token embedding, not the patch tokens
outs.shape
torch.Size([1, 384])
# Storage filled by `hook_fn`; key 'last' holds the most recent captured output.
attention_maps = {}


def hook_fn(module, input, output):
    """Forward-hook callback that records the hooked module's output.

    Stores ``output`` under ``attention_maps['last']``, overwriting any earlier
    value. ``module`` and ``input`` are accepted to match the forward-hook
    call signature but are not used.

    NOTE(review): this hook is never registered in this section (there is no
    ``register_forward_hook`` call), and a forward hook on the attention module
    would presumably capture that module's final output rather than the softmax
    attention weights. The weights are instead recomputed by hand from q/k.
    """
    attention_maps.update(last=output)
final_block = dinov2.blocks[-1]
print(final_block.attn)  # inspect the last block's attention submodule (its qkv / proj layers)
MemEffAttention(
  (qkv): Linear(in_features=384, out_features=1152, bias=True)
  (attn_drop): Dropout(p=0.0, inplace=False)
  (proj): Linear(in_features=384, out_features=384, bias=True)
  (proj_drop): Dropout(p=0.0, inplace=False)
)
# Alternative access path: the model's own helper for intermediate-layer outputs.
# NOTE(review): `out` is not used later in this section.
with torch.no_grad():
    batch = img.unsqueeze(0).to(device)
    out = dinov2.get_intermediate_layers(batch, n=1, return_class_token=True)
with torch.no_grad():
    # Token sequence built the same way the model's forward pass builds it.
    x = dinov2.prepare_tokens_with_masks(img.unsqueeze(0).to(device))

    # Every block except the last runs normally; the last one is handled by hand.
    *leading_blocks, block = dinov2.blocks
    for blk in leading_blocks:
        x = blk(x)

    # Recompute the final block's self-attention explicitly so the post-softmax
    # weights (shape [batch, heads, tokens, tokens]) can be inspected.
    attn = block.attn
    B, N, C = x.shape
    head_dim = C // attn.num_heads
    qkv = (
        attn.qkv(block.norm1(x))
        .reshape(B, N, 3, attn.num_heads, head_dim)
        .permute(2, 0, 3, 1, 4)
    )
    q, k, v = qkv.unbind(0)
    attn_weights = ((q @ k.transpose(-2, -1)) * attn.scale).softmax(dim=-1)
print(attn_weights.shape)
torch.Size([1, 6, 257, 257])
# The CLS token sits at position 0: take its query row, and drop key column 0
# (CLS attending to itself). That leaves how strongly each of the 6 heads
# attends to each of the 256 image patches.
cls_attn = attn_weights[0][:, 0, 1:]  # [heads, patches]
cls_attn.shape
torch.Size([6, 256])
# Lay each head's per-patch CLS attention back out on the square patch grid
# (256 patches -> 16x16 for a 224x224 input with 14-px patches). Derived from the
# tensor so the cell also works for other model sizes / resolutions.
num_heads = cls_attn.shape[0]
grid = int(cls_attn[0].numel() ** 0.5)
cls_attn = cls_attn.reshape(num_heads, grid, grid)

image_hwc = img_display.permute(1, 2, 0)  # [3, H, W] -> [H, W, 3] for matplotlib
fig, axes = plt.subplots(1, num_heads + 1, figsize=(20, 4))

# Original image
axes[0].imshow(image_hwc)
axes[0].set_title("Original")

# Each head's attention, overlaid on the image
for i in range(num_heads):
    ax = axes[i + 1]
    base = ax.imshow(image_hwc)
    # By default imshow places an (h, w) array on pixel coordinates [-0.5, w - 0.5],
    # so the 16x16 map would cover only the photo's top-left 16x16 pixels (and
    # reset the view to them). Reusing the photo's extent stretches the map over
    # the whole image so every patch lines up with the pixels it came from.
    ax.imshow(cls_attn[i].cpu(), alpha=0.6, cmap='viridis', interpolation='bilinear',
              extent=base.get_extent())
    ax.set_title(f"Head {i}")

for ax in axes:
    ax.axis('off')
plt.suptitle("DINOv2 — CLS token attention per head")
plt.tight_layout()
plt.show()

  • Now we can visualize the attention map of each head in DINOv2.

  • Below, we pick a specific patch and see which other patches it attends to most strongly, in each head.

patch_idx = 64  # index among the patch tokens; row-major on the 16x16 grid -> row 4, col 0
# +1 because position 0 of the token sequence is the CLS token.
patch_attn = attn_weights[0, :, patch_idx + 1, 1:]  # [heads, 256]
num_heads = patch_attn.shape[0]
grid = int(patch_attn.shape[-1] ** 0.5)
patch_attn = patch_attn.reshape(num_heads, grid, grid)

image_hwc = img_display.permute(1, 2, 0)  # [3, H, W] -> [H, W, 3] for matplotlib
fig, axes = plt.subplots(1, num_heads + 1, figsize=(20, 4))

# Original image
axes[0].imshow(image_hwc)
axes[0].set_title("Original")

# Each head's attention from the chosen patch, overlaid on the image
for i in range(num_heads):
    ax = axes[i + 1]
    base = ax.imshow(image_hwc)
    # Stretch the coarse grid over the full photo; imshow's default extent would
    # otherwise put the map on (and zoom to) the top-left grid-sized pixel block.
    ax.imshow(patch_attn[i].cpu(), alpha=0.6, cmap='viridis', interpolation='bilinear',
              extent=base.get_extent())
    ax.set_title(f"Head {i}")

for ax in axes:
    ax.axis('off')
plt.suptitle("DINOv2 — patch attention per head")
plt.tight_layout()
plt.show()

patch_idx = 120  # index among the patch tokens; row-major on the 16x16 grid -> row 7, col 8
# +1 because position 0 of the token sequence is the CLS token.
patch_attn = attn_weights[0, :, patch_idx + 1, 1:]  # [heads, 256]
num_heads = patch_attn.shape[0]
grid = int(patch_attn.shape[-1] ** 0.5)
patch_attn = patch_attn.reshape(num_heads, grid, grid)

image_hwc = img_display.permute(1, 2, 0)  # [3, H, W] -> [H, W, 3] for matplotlib
fig, axes = plt.subplots(1, num_heads + 1, figsize=(20, 4))

# Original image
axes[0].imshow(image_hwc)
axes[0].set_title("Original")

# Each head's attention from the chosen patch, overlaid on the image
for i in range(num_heads):
    ax = axes[i + 1]
    base = ax.imshow(image_hwc)
    # Stretch the coarse grid over the full photo; imshow's default extent would
    # otherwise put the map on (and zoom to) the top-left grid-sized pixel block.
    ax.imshow(patch_attn[i].cpu(), alpha=0.6, cmap='viridis', interpolation='bilinear',
              extent=base.get_extent())
    ax.set_title(f"Head {i}")

for ax in axes:
    ax.axis('off')
plt.suptitle("DINOv2 — patch attention per head")
plt.tight_layout()
plt.show()

CLIP

  • Let's try the same with CLIP. We download a ViT-B/16, which uses 16×16-pixel patches, so a 224×224 input gives a 14×14 grid of 196 patches plus the CLS token.
from transformers import CLIPVisionModel, CLIPProcessor
from torchvision.datasets import CIFAR100

# Untransformed CIFAR-100 (PIL images), so CLIP's own processor can preprocess them.
raw_dataset = CIFAR100(root='./data', train=True, download=True)
img_pil, label = raw_dataset[-1]

# "eager" attention, presumably so that output_attentions=True returns the
# per-layer attention weights; model put in eval mode and moved to `device`.
clip_vision = (
    CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16", attn_implementation="eager")
    .eval()
    .to(device)
)
Loading weights: 100%|██████████| 199/199 [00:00<00:00, 41886.21it/s]

CLIPVisionModel LOAD REPORT from: openai/clip-vit-base-patch16

Key                                                          | Status     |  | 

-------------------------------------------------------------+------------+--+-

text_model.encoder.layers.{0...11}.mlp.fc2.bias              | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.mlp.fc2.weight            | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.mlp.fc1.weight            | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.layer_norm2.weight        | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.layer_norm2.bias          | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.self_attn.k_proj.bias     | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.layer_norm1.weight        | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.self_attn.out_proj.weight | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.self_attn.v_proj.weight   | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.mlp.fc1.bias              | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.self_attn.v_proj.bias     | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.self_attn.k_proj.weight   | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.self_attn.q_proj.weight   | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.layer_norm1.bias          | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.self_attn.q_proj.bias     | UNEXPECTED |  | 

text_model.encoder.layers.{0...11}.self_attn.out_proj.bias   | UNEXPECTED |  | 

logit_scale                                                  | UNEXPECTED |  | 

text_model.final_layer_norm.weight                           | UNEXPECTED |  | 

text_model.embeddings.token_embedding.weight                 | UNEXPECTED |  | 

text_projection.weight                                       | UNEXPECTED |  | 

text_model.embeddings.position_embedding.weight              | UNEXPECTED |  | 

text_model.embeddings.position_ids                           | UNEXPECTED |  | 

vision_model.embeddings.position_ids                         | UNEXPECTED |  | 

text_model.final_layer_norm.bias                             | UNEXPECTED |  | 

visual_projection.weight                                     | UNEXPECTED |  | 



Notes:

- UNEXPECTED :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
# Same last CIFAR-100 sample as above, preprocessed by CLIP's processor.
raw_dataset = CIFAR100(root='./data', train=True, download=True)
img_pil, label = raw_dataset[len(raw_dataset) - 1]
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
inputs = processor(images=img_pil, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = clip_vision(**inputs, output_attentions=True)

# One attention tensor per layer; keep the final one ([batch, heads, tokens, tokens]).
*_, last_attn = outputs.attentions
print(last_attn.shape)
torch.Size([1, 12, 197, 197])
# CLS query row of the last CLIP layer, patch keys only: [12, 196] -> 12 heads on a 14x14 grid.
clip_cls_attn = last_attn[0, :, 0, 1:].cpu().reshape(12, 14, 14)

# Plain (unnormalized) 224x224 tensor of the same image, drawn under both models' maps.
to_display = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
img_tensor = to_display(img_pil)
img_display = img_tensor.cpu()

# Keep the DINOv2 maps on the CPU as well for plotting.
cls_attn = cls_attn.cpu()
image_hwc = img_display.permute(1, 2, 0)  # [3, 224, 224] -> [224, 224, 3] for matplotlib
fig, axes = plt.subplots(2, 7, figsize=(20, 7))

# Top row: DINOv2 (6 heads, 16x16 grid). Bottom row: CLIP (first 6 of 12 heads, 14x14 grid).
for row, (model_name, head_maps) in enumerate([("DINO", cls_attn), ("CLIP", clip_cls_attn)]):
    axes[row, 0].imshow(image_hwc)
    axes[row, 0].set_title("Original")
    for i in range(6):
        ax = axes[row, i + 1]
        base = ax.imshow(image_hwc)
        # Stretch each model's coarse patch grid over the whole photo; with the
        # default extent the map would land on (and zoom to) the top-left pixels only.
        ax.imshow(head_maps[i].cpu(), alpha=0.6, cmap='viridis', interpolation='bilinear',
                  extent=base.get_extent())
        ax.set_title(f"{model_name} Head {i}")

for ax in axes.flat:
    ax.axis('off')
plt.suptitle("DINOv2 vs CLIP — CLS attention per head (last layer)")
plt.tight_layout()
plt.show()

mid_layer = 5  # 0-indexed, so index 5 = layer 6 (the titles below use the 1-indexed number)
mid_attn = outputs.attentions[mid_layer]
mid_cls_row = mid_attn[0, :, 0, 1:]  # CLS query row, patch keys only: [heads, patches]
grid = int(mid_cls_row.shape[-1] ** 0.5)  # 196 patches -> 14x14
mid_cls_attn = mid_cls_row.reshape(mid_cls_row.shape[0], grid, grid).cpu()

image_hwc = img_display.permute(1, 2, 0)  # [3, H, W] -> [H, W, 3] for matplotlib
fig, axes = plt.subplots(1, 7, figsize=(20, 4))
axes[0].imshow(image_hwc)
axes[0].set_title("Original")
for i in range(6):
    ax = axes[i + 1]
    base = ax.imshow(image_hwc)
    # Stretch the 14x14 grid over the full photo (imshow's default extent would
    # cover and zoom to only the top-left 14x14 pixels).
    ax.imshow(mid_cls_attn[i], alpha=0.6, cmap='viridis', interpolation='bilinear',
              extent=base.get_extent())
    ax.set_title(f"CLIP L{mid_layer + 1} Head {i}")
for ax in axes:
    ax.axis('off')
plt.suptitle("CLIP — middle layer attention")
plt.tight_layout()
plt.show()

# Head-0 CLS attention for every DINOv2 layer, recomputed by hand from each
# block's q/k projections before running the block itself.
dino_layers = []
with torch.no_grad():
    x = dinov2.prepare_tokens_with_masks(img.unsqueeze(0).to(device))
    for block in dinov2.blocks:
        attn = block.attn
        B, N, C = x.shape
        qkv = attn.qkv(block.norm1(x)).reshape(B, N, 3, attn.num_heads, C // attn.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        weights = (q @ k.transpose(-2, -1)) * attn.scale
        weights = weights.softmax(dim=-1)
        cls_row = weights[0, 0, 0, 1:].cpu()  # head 0, CLS query row, patch keys
        side = int(cls_row.numel() ** 0.5)  # 256 patches -> 16x16
        dino_layers.append(cls_row.reshape(side, side))
        x = block(x)  # advance to the next block's input

image_hwc = img_display.permute(1, 2, 0)  # [3, H, W] -> [H, W, 3] for matplotlib
fig, axes = plt.subplots(2, len(dino_layers), figsize=(24, 6))
for i, attn_map in enumerate(dino_layers):
    base = axes[0, i].imshow(image_hwc)
    # Stretch the patch grid over the whole photo (default extent would cover
    # and zoom to only the top-left grid-sized pixel block).
    axes[0, i].imshow(attn_map, alpha=0.6, cmap='viridis', interpolation='bilinear',
                      extent=base.get_extent())
    axes[0, i].set_title(f"DINO L{i}")

for i, layer_attn in enumerate(outputs.attentions):
    clip_row = layer_attn[0, 0, 0, 1:].cpu()  # head 0, CLS query row, patch keys
    side = int(clip_row.numel() ** 0.5)  # 196 patches -> 14x14
    base = axes[1, i].imshow(image_hwc)
    axes[1, i].imshow(clip_row.reshape(side, side), alpha=0.6, cmap='viridis', interpolation='bilinear',
                      extent=base.get_extent())
    axes[1, i].set_title(f"CLIP L{i}")

for ax in axes.flat:
    ax.axis('off')
plt.suptitle("Head 0 attention across layers — DINO (top) vs CLIP (bottom)")
plt.tight_layout()
plt.show()

Notice how the last layers in CLIP don't make much sense? This is one of the differences between self-supervised and language-supervised pre-training. CLIP is forced to compress the image into a single CLS-token embedding so that it can be compared against the embedding of its paired text.

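DINO can keep its attention sharper because its objective is self-distillation: the student is trained to match the output of the teacher. DINO feeds the student small local crops and asks it to match the teacher's output on large global crops. This pushes the model to retain information about the whole image, so that its predictions agree with the teacher's. (Note that this is a distillation objective, not a contrastive one.)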

The first layers in CLIP are very sharp, though, since that's where it keeps intermediate representations of the image. The last layers become more of a “summary”.

Using these models to classify CIFAR100