from vit import ViT, PatchEmbedding, MultiHeadAttention, TransformerEncoder, ClassificationHead, FeedForward
Compare ViT with Swin, CoAtNet, DeiT
- Swin Transformer: https://arxiv.org/abs/2103.14030
- DeiT (Data-efficient Image Transformers): https://arxiv.org/abs/2012.12877
- CoAtNet: https://arxiv.org/abs/2106.04803
- Attention Is All You Need: https://arxiv.org/abs/1706.03762
- timm repo (has all these models implemented): https://github.com/huggingface/pytorch-image-models
- The Annotated Transformer (Harvard NLP walkthrough): https://nlp.seas.harvard.edu/annotated-transformer/
Key Ideas Oversimplified:
Swin Transformer
- Shifted windows. Instead of global attention (every patch attends to every patch), Swin computes attention only within local windows (for example, 7x7 patches). It then shifts the window grid by half a window between consecutive layers so information can flow across window boundaries.
- Hierarchical feature maps. Unlike ViT, which keeps the same resolution throughout, Swin progressively merges neighboring patches (think of this as pooling in CNNs). This gives us multi-scale features, which will become very important for detection and segmentation later on.
- Overall, Swin gives efficiency and versatility: local attention plus cleverly shifted windows gets comparable accuracy at a cost that grows linearly, rather than quadratically, with the number of patches.
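The patch-merging step behind the hierarchical feature maps can be sketched in a few lines. This is a minimal sketch following the Swin paper's description (concatenate each 2x2 neighborhood of patches, apply LayerNorm, then a linear layer from 4C to 2C channels); the class name PatchMerging and the [B, H, W, C] layout are our own choices, not a copy of any library:

```python
import torch
import torch.nn as nn


class PatchMerging(nn.Module):
    """Downsample a [B, H, W, C] grid of patch features to [B, H/2, W/2, 2C]."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):
        B, H, W, C = x.shape
        assert H % 2 == 0 and W % 2 == 0, "grid height and width must be even"
        # Gather the four patches of every 2x2 neighborhood
        x0 = x[:, 0::2, 0::2, :]  # top-left
        x1 = x[:, 1::2, 0::2, :]  # bottom-left
        x2 = x[:, 0::2, 1::2, :]  # top-right
        x3 = x[:, 1::2, 1::2, :]  # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # [B, H/2, W/2, 4C]
        return self.reduction(self.norm(x))  # [B, H/2, W/2, 2C]


merge = PatchMerging(dim=256)
feats = torch.randn(2, 8, 8, 256)  # e.g. our 8x8 grid of 256-dim patches
out = merge(feats)
print(out.shape)  # torch.Size([2, 4, 4, 512])
```

Like a stride-2 pooling layer in a CNN, each stage halves the spatial resolution and (here) doubles the channel count.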
DeiT
- Showed you can get competitive results with much less data (ImageNet-1k only, instead of a huge pretraining set such as JFT-300M), as long as you have the right training recipe: strong augmentation, good regularization, and careful hyperparameter tuning.
- Distillation token. DeiT adds a second special token alongside the cls_token. Its output is trained to match the predictions of a CNN teacher, so the CNN can pass some of its inductive bias on to the ViT during training.
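A minimal sketch of the two extra tokens and the hard-label distillation objective described in the DeiT paper: half cross-entropy against the true label on the CLS head, half cross-entropy against the teacher's argmax on the distillation head. DistilledTokens and hard_distillation_loss are hypothetical names; the transformer encoder is skipped entirely, and random tensors stand in for patch embeddings and teacher logits:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistilledTokens(nn.Module):
    """Prepend a CLS token and a distillation token, and read a prediction from each."""

    def __init__(self, dim, num_classes):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.head = nn.Linear(dim, num_classes)       # reads the CLS output
        self.head_dist = nn.Linear(dim, num_classes)  # reads the distillation output

    def add_tokens(self, patches):
        B = patches.shape[0]
        cls = self.cls_token.expand(B, -1, -1)
        dist = self.dist_token.expand(B, -1, -1)
        return torch.cat((cls, dist, patches), dim=1)  # [B, 2 + N, dim]

    def heads(self, encoded):
        # Position 0 is the CLS token, position 1 the distillation token
        return self.head(encoded[:, 0]), self.head_dist(encoded[:, 1])


def hard_distillation_loss(cls_logits, dist_logits, labels, teacher_logits):
    # CLS head learns the true labels; distillation head learns the teacher's argmax
    teacher_labels = teacher_logits.argmax(dim=-1)
    return 0.5 * F.cross_entropy(cls_logits, labels) + 0.5 * F.cross_entropy(dist_logits, teacher_labels)


# Toy usage: no encoder, random "patches" and a random stand-in for the CNN teacher
tokens = DistilledTokens(dim=256, num_classes=100)
x = tokens.add_tokens(torch.randn(4, 64, 256))
cls_logits, dist_logits = tokens.heads(x)
loss = hard_distillation_loss(cls_logits, dist_logits,
                              torch.randint(0, 100, (4,)), torch.randn(4, 100))
# At test time, combine both heads by averaging their softmax outputs
pred = ((F.softmax(cls_logits, dim=-1) + F.softmax(dist_logits, dim=-1)) / 2).argmax(dim=-1)
```

In a real setup the teacher logits would come from a frozen, pretrained CNN run on the same augmented batch.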
CoAtNet
- Brings in the best of both worlds. CNNs are good at local features and have useful inductive biases (locality, translation equivariance); transformers are good at global relationships. Stacking convolution stages first and transformer stages after them gives strong results.
- Relative attention. Drop the absolute positional embeddings we had in the vanilla ViT and instead add relative position biases to the attention scores. This generalizes nicely to different input sizes.
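Relative attention adds a learned, input-independent scalar for the offset between positions i and j to their attention logit. A minimal 2D sketch in the style of Swin's relative position bias table: one learnable value per head for each of the (2M-1)^2 possible (row, column) offsets inside an M x M window. RelativePositionBias is a hypothetical name, and CoAtNet's own implementation may lay out or index the table differently:

```python
import torch
import torch.nn as nn


class RelativePositionBias(nn.Module):
    """Learned bias B[h, i, j] for an M x M window, indexed by the 2D offset between tokens i and j."""

    def __init__(self, window_size, num_heads):
        super().__init__()
        M = window_size
        # One learnable scalar per head for each of the (2M-1)^2 possible (dy, dx) offsets
        self.table = nn.Parameter(torch.zeros((2 * M - 1) ** 2, num_heads))
        coords = torch.stack(torch.meshgrid(torch.arange(M), torch.arange(M), indexing="ij"))  # [2, M, M]
        coords = coords.flatten(1)                         # [2, M*M] (row, col) of each token
        rel = coords[:, :, None] - coords[:, None, :]      # [2, M*M, M*M], offsets in [-(M-1), M-1]
        rel = rel + (M - 1)                                # shift offsets to [0, 2M-2]
        index = rel[0] * (2 * M - 1) + rel[1]              # flatten (dy, dx) into one table index
        self.register_buffer("index", index)

    def forward(self):
        # [M*M, M*M, heads] -> [heads, M*M, M*M], ready to add to the attention logits
        return self.table[self.index].permute(2, 0, 1)


bias = RelativePositionBias(window_size=4, num_heads=8)()
print(bias.shape)  # torch.Size([8, 16, 16])
# Inside attention: scores = q @ k.transpose(-1, -2) * scale + bias[head]
```

Because the bias depends only on the offset, token pairs with the same relative position share a parameter, which is what lets it transfer across positions and input sizes.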
To build on top of our ViT, we exported everything we wrote in the first module to vit.py, so we can import that work here and extend it.
Baseline Vanilla ViT on CIFAR-100
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F # This gives us the softmax()
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.Resize((56, 56)),
transforms.ToTensor(),
transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])
train_dataset = CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR100(root='./data', train=False, download=True, transform=transform)
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using: {device}")
Using: mps
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
model = ViT(
    img_size=56, patch_size=7, num_hiddens=256,
    num_heads=8, num_classes=100, depth=2, dropout=0.1
).to(device)
import time
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(50):
    model.train()
    total_loss, correct, total = 0, 0, 0
    start = time.time()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    elapsed = time.time() - start
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Acc: {correct/total:.4f} | Time: {elapsed:.1f}s")
Epoch 1 | Loss: 3.6996 | Acc: 0.1320 | Time: 58.0s
Epoch 2 | Loss: 3.1074 | Acc: 0.2328 | Time: 54.4s
Epoch 3 | Loss: 2.8711 | Acc: 0.2776 | Time: 55.0s
Epoch 4 | Loss: 2.6972 | Acc: 0.3132 | Time: 54.8s
Epoch 5 | Loss: 2.5538 | Acc: 0.3411 | Time: 56.6s
Epoch 6 | Loss: 2.4167 | Acc: 0.3706 | Time: 54.4s
Epoch 7 | Loss: 2.2940 | Acc: 0.3952 | Time: 56.0s
Epoch 8 | Loss: 2.1746 | Acc: 0.4251 | Time: 55.3s
Epoch 9 | Loss: 2.0614 | Acc: 0.4436 | Time: 60.4s
Epoch 10 | Loss: 1.9506 | Acc: 0.4709 | Time: 72.7s
Epoch 11 | Loss: 1.8455 | Acc: 0.4942 | Time: 66.1s
Epoch 12 | Loss: 1.7399 | Acc: 0.5175 | Time: 63.0s
Epoch 13 | Loss: 1.6423 | Acc: 0.5414 | Time: 79.2s
Epoch 14 | Loss: 1.5446 | Acc: 0.5616 | Time: 70.1s
Epoch 15 | Loss: 1.4461 | Acc: 0.5862 | Time: 63.4s
Epoch 16 | Loss: 1.3662 | Acc: 0.6036 | Time: 55.9s
Epoch 17 | Loss: 1.2817 | Acc: 0.6240 | Time: 51.1s
Epoch 18 | Loss: 1.1982 | Acc: 0.6467 | Time: 51.2s
Epoch 19 | Loss: 1.1358 | Acc: 0.6619 | Time: 51.3s
Epoch 20 | Loss: 1.0614 | Acc: 0.6818 | Time: 50.5s
Epoch 21 | Loss: 1.0048 | Acc: 0.6975 | Time: 51.6s
Epoch 22 | Loss: 0.9419 | Acc: 0.7121 | Time: 50.7s
Epoch 23 | Loss: 0.9048 | Acc: 0.7216 | Time: 52.0s
Epoch 24 | Loss: 0.8512 | Acc: 0.7360 | Time: 50.8s
Epoch 25 | Loss: 0.8122 | Acc: 0.7465 | Time: 50.9s
Epoch 26 | Loss: 0.7718 | Acc: 0.7590 | Time: 50.9s
Epoch 27 | Loss: 0.7411 | Acc: 0.7651 | Time: 50.8s
Epoch 28 | Loss: 0.7087 | Acc: 0.7761 | Time: 50.8s
Epoch 29 | Loss: 0.6745 | Acc: 0.7835 | Time: 50.9s
Epoch 30 | Loss: 0.6510 | Acc: 0.7943 | Time: 51.1s
Epoch 31 | Loss: 0.6234 | Acc: 0.8004 | Time: 50.6s
Epoch 32 | Loss: 0.6049 | Acc: 0.8071 | Time: 51.0s
Epoch 33 | Loss: 0.5770 | Acc: 0.8146 | Time: 51.4s
Epoch 34 | Loss: 0.5646 | Acc: 0.8198 | Time: 51.9s
Epoch 35 | Loss: 0.5501 | Acc: 0.8249 | Time: 52.0s
Epoch 36 | Loss: 0.5282 | Acc: 0.8304 | Time: 50.9s
Epoch 37 | Loss: 0.5201 | Acc: 0.8324 | Time: 51.1s
Epoch 38 | Loss: 0.5080 | Acc: 0.8364 | Time: 51.1s
Epoch 39 | Loss: 0.4854 | Acc: 0.8449 | Time: 51.1s
Epoch 40 | Loss: 0.4739 | Acc: 0.8482 | Time: 51.7s
Epoch 41 | Loss: 0.4668 | Acc: 0.8489 | Time: 51.4s
Epoch 42 | Loss: 0.4563 | Acc: 0.8538 | Time: 51.5s
Epoch 43 | Loss: 0.4431 | Acc: 0.8559 | Time: 51.1s
Epoch 44 | Loss: 0.4399 | Acc: 0.8583 | Time: 50.2s
Epoch 45 | Loss: 0.4152 | Acc: 0.8663 | Time: 54.1s
Epoch 46 | Loss: 0.4126 | Acc: 0.8676 | Time: 55.8s
Epoch 47 | Loss: 0.4040 | Acc: 0.8684 | Time: 54.5s
Epoch 48 | Loss: 0.4008 | Acc: 0.8712 | Time: 54.8s
Epoch 49 | Loss: 0.3975 | Acc: 0.8701 | Time: 52.8s
Epoch 50 | Loss: 0.3782 | Acc: 0.8770 | Time: 52.7s
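The Acc column in these logs is training accuracy, measured on the same batches the optimizer is fitting, so on its own it cannot tell learning apart from memorization. A minimal evaluation helper for test_loader could look like the sketch below; evaluate is a hypothetical name (it is not part of vit.py), and the toy model and data only exist to make the block runnable:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate(model, loader, device):
    """Return (mean loss, accuracy) of `model` on `loader` without updating weights."""
    model.eval()  # disables dropout
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, correct / total


# Toy check with a linear "model" on random data; in the notebook you would call
# evaluate(model, test_loader, device) after (or during) training instead.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 10))
toy_data = TensorDataset(torch.randn(20, 3, 4, 4), torch.randint(0, 10, (20,)))
test_loss, test_acc = evaluate(toy_model, DataLoader(toy_data, batch_size=8), torch.device("cpu"))
print(f"Test loss: {test_loss:.4f} | Test acc: {test_acc:.4f}")
```

Since the training loop calls model.train() at the start of every epoch, evaluating between epochs does not leave dropout switched off.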
Adding windowed attention like Swin
- Step 1: Windowed attention
- Let's run a 56x56 image through each step to see what needs to change and where.
class AttentionHead(nn.Module):
    def __init__(self, dim, head_dim, dropout=0.1):
        super().__init__()
        self.scale = head_dim**-0.5
        self.q = nn.Linear(dim, head_dim, bias=False)
        self.k = nn.Linear(dim, head_dim, bias=False)
        self.v = nn.Linear(dim, head_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        q, k, v = self.q(x), self.k(x), self.v(x)
        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        return out
class MultiHeadAttention(nn.Module):
    # Plain (global) multi-head attention, same as in vit.py.
    # For windowed attention we will need to reshape x into windows before calling the heads.
    def __init__(self, dim, num_heads, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads
        self.num_heads = num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(dim, self.head_dim, dropout) for _ in range(num_heads)]
        )
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]
        out = torch.cat(head_outputs, dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out
x = torch.randn(2, 3, 56, 56)
patch_emb = PatchEmbedding(img_size=56, patch_size=7, num_hiddens=256)
patches = patch_emb(x)
print("After conv:", patches.shape)  # [2, 256, 8, 8]
# The hidden dim is the size of the vector that represents each patch. Since we chose 256, each patch becomes a vector of length 256.
patches_flat = patches.flatten(2).transpose(1, 2)
print("Flattened:", patches_flat.shape)  # [2, 64, 256]: 8x8 = 64 patches, each now represented by a 256-dim vector
After conv: torch.Size([2, 256, 8, 8])
Flattened: torch.Size([2, 64, 256])
head = AttentionHead(dim=256, head_dim=32)
q = head.q(patches_flat)
k = head.k(patches_flat)
v = head.v(patches_flat)
print("Q:", q.shape)  # [2, 64, 32]
print("K:", k.shape)  # [2, 64, 32]. Why 32? Each head projects the full 256-dim input down to 256 / 8 = 32 dims; the 8 heads' outputs are concatenated back to 256.
attn_scores = torch.matmul(q, k.transpose(-1, -2))
print("Attention matrix:", attn_scores.shape)  # [2, 64, 64]: every patch attends to every patch
Q: torch.Size([2, 64, 32])
K: torch.Size([2, 64, 32])
Attention matrix: torch.Size([2, 64, 64])
- These are the query and key vectors that go into attention: 64 patches, each projected to a 32-dimensional vector. It is 32 because each head projects to 256 / num_heads = 256 / 8 = 32 dimensions.
# The 64 patches are an 8x8 grid. Window size 4 means 2x2 windows of 4x4 patches each.
B, num_patches, dim = patches_flat.shape
grid_size = 8 # sqrt(64)
window_size = 4
grid = patches_flat.reshape(B, grid_size, grid_size, dim)
print("Grid:", grid.shape)  # [2, 8, 8, 256]
Grid: torch.Size([2, 8, 8, 256])
# Split into windows
# [B, 8, 8, dim] -> [B, 2, 4, 2, 4, dim] -> [B, 2, 2, 4, 4, dim]
windows = grid.reshape(B, grid_size // window_size, window_size, grid_size // window_size, window_size, dim)
windows = windows.permute(0, 1, 3, 2, 4, 5)
print("Windows rearranged:", windows.shape)  # [2, 2, 2, 4, 4, 256]
Windows rearranged: torch.Size([2, 2, 2, 4, 4, 256])
- The reshaping above takes the 8x8 grid and, with a window size of 4, cuts it into 4 regions, each a 4x4 grid of the original patches. Instead of attending to all 64 patches in the 8x8 grid, each patch now attends only to the 16 patches in its own region.
┌───────────┬───────────┐
│ │ │
│ window │ window │
│ (0,0) │ (0,1) │
│ 4×4 │ 4×4 │
│ │ │
├───────────┼───────────┤
│ │ │
│ window │ window │
│ (1,0) │ (1,1) │
│ 4×4 │ 4×4 │
│ │ │
└───────────┴───────────┘
- B=2 — batch
- 2, 2 — which window (row, col) — 4 windows total
- 4, 4 — which patch within that window — 16 patches per window
- 256 — hidden dim
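To put numbers on the savings, here is a quick count of query-key scores per image and per head (ignoring the CLS token). Global attention grows with the square of the number of patches, while windowed attention grows linearly, as num_patches * M^2 for window size M. attention_score_counts is a hypothetical helper written for this comparison:

```python
def attention_score_counts(grid_size, window_size):
    """Number of query-key scores per image, per head, for global vs windowed attention."""
    num_patches = grid_size * grid_size
    global_scores = num_patches ** 2                          # every patch vs every patch
    num_windows = (grid_size // window_size) ** 2
    windowed_scores = num_windows * (window_size ** 2) ** 2   # each window: (M*M)^2 scores
    return global_scores, windowed_scores


print(attention_score_counts(8, 4))   # (4096, 1024): our 56x56 image, 7x7 patches, 4x4 windows
print(attention_score_counts(56, 7))  # (9834496, 153664): a 56x56 patch grid with 7x7 windows
```

At our 8x8 grid the reduction is only 4x on an already tiny 64x64 matrix, so it is unlikely to show up in wall-clock time; at the scale used in the Swin paper (224x224 input, 4x4 patches, 7x7 windows) it is 64x.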
# Merge batch and window dims so we can run attention as-is
num_windows = (grid_size // window_size) ** 2 # 4
windows = windows.reshape(B * num_windows, window_size * window_size, dim)
print("Ready for attention:", windows.shape)  # [8, 16, 256]: 8 = 2 images * 4 windows, 16 patches each
# Step 4: Run attention on windowed input
attn_out = head(windows)
print("Attention output:", attn_out.shape) # [8, 16, 32]
# Step 5: Reverse the reshape to get back to [B, 64, dim]
# [8, 16, 32] -> [2, 2, 2, 4, 4, 32] -> [2, 8, 8, 32] -> [2, 64, 32]
attn_out = attn_out.reshape(B, grid_size // window_size, grid_size // window_size, window_size, window_size, -1)
attn_out = attn_out.permute(0, 1, 3, 2, 4, 5)
attn_out = attn_out.reshape(B, num_patches, -1)
print("Back to original shape:", attn_out.shape)  # [2, 64, 32]
Ready for attention: torch.Size([8, 16, 256])
Attention output: torch.Size([8, 16, 32])
Back to original shape: torch.Size([2, 64, 32])
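Before wiring this into attention, it is worth confirming that the reverse reshape is an exact inverse and that each window really is a contiguous 4x4 block of patches. The helper names partition and unpartition are ours; they simply package the reshapes above:

```python
import torch


def partition(grid, window_size):
    """[B, H, W, C] -> [B * num_windows, window_size * window_size, C], batch-major."""
    B, H, W, C = grid.shape
    x = grid.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5)  # bring the two window-index dims together
    return x.reshape(-1, window_size * window_size, C)


def unpartition(windows, window_size, B, H, W):
    """Inverse of partition: [B * num_windows, window_size * window_size, C] -> [B, H, W, C]."""
    C = windows.shape[-1]
    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5)  # the same axis swap undoes itself
    return x.reshape(B, H, W, C)


grid = torch.randn(2, 8, 8, 256)
windows = partition(grid, 4)
print(windows.shape)                                                # torch.Size([8, 16, 256])
print(torch.equal(unpartition(windows, 4, 2, 8, 8), grid))          # True
# Window 0 of image 0 is the top-left 4x4 block of patches
print(torch.equal(windows[0].reshape(4, 4, 256), grid[0, :4, :4]))  # True
```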
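So all we have to change is the reshaping inside multi-head attention. The class below does that; it is named MultiHeadAttentionShift, but it does not shift anything yet. The CLS token does not belong to any window, so we prepend a copy of it to every window and average the resulting CLS outputs.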
class MultiHeadAttentionShift(nn.Module):
    # For windowed attention, we reshape the patches into windows before attention
    # and run attention independently inside each window.
    def __init__(self, dim, num_heads, grid_size,
                 window_size, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads
        self.grid_size = grid_size
        self.window_size = window_size
        self.num_heads = num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(dim, self.head_dim, dropout) for _ in range(num_heads)]
        )
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        B, num_patches_with_cls, dim = x.shape
        cls_token = x[:, :1, :]
        x = x[:, 1:, :]
        num_patches = x.shape[1]
        num_windows = (self.grid_size // self.window_size) ** 2
        # Window the patches: [B, N, dim] -> [B * num_windows, window_size**2, dim], batch-major
        grid = x.reshape(B, self.grid_size, self.grid_size, dim)
        windows = grid.reshape(B, self.grid_size // self.window_size, self.window_size,
                               self.grid_size // self.window_size, self.window_size, dim)
        windows = windows.permute(0, 1, 3, 2, 4, 5)
        windows = windows.reshape(B * num_windows, self.window_size * self.window_size, dim)
        # Prepend CLS to every window so it can attend to all patches.
        # repeat_interleave keeps each image's CLS copies next to that image's windows.
        cls_expanded = cls_token.repeat_interleave(num_windows, dim=0)  # [B*num_windows, 1, dim]
        windows = torch.cat((cls_expanded, windows), dim=1)  # [B*num_windows, 17, dim]
        # Run attention
        head_outputs = [head(windows) for head in self.heads]
        out = torch.cat(head_outputs, dim=-1)
        # Separate CLS and patches
        cls_out = out[:, :1, :]  # [B*num_windows, 1, dim]
        out = out[:, 1:, :]  # [B*num_windows, 16, dim]
        # Average CLS across all windows (each window produced a CLS update)
        cls_out = cls_out.reshape(B, num_windows, 1, dim).mean(dim=1)  # [B, 1, dim]
        # Reverse windowing for patches
        nw = self.grid_size // self.window_size
        out = out.reshape(B, nw, nw, self.window_size, self.window_size, dim)
        out = out.permute(0, 1, 3, 2, 4, 5)
        out = out.reshape(B, num_patches, dim)
        # Recombine
        out = torch.cat((cls_out, out), dim=1)
        out = self.proj(out)
        out = self.dropout(out)
        return out
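The windows come out of reshape(B * num_windows, ...) in batch-major order: all windows of image 0, then all windows of image 1, and so on. Any per-image tensor copied into the windows, such as the CLS token, has to follow the same order. A small check on a toy batch of three single-value CLS tokens shows that Tensor.repeat tiles the whole batch, while Tensor.repeat_interleave keeps each image's copies adjacent:

```python
import torch

B, num_windows, dim = 3, 4, 1
# One CLS token per image; image b's token is filled with the value b
cls_token = torch.arange(B, dtype=torch.float32).reshape(B, 1, dim)

tiled = cls_token.repeat(num_windows, 1, 1)                     # [B*num_windows, 1, dim]
interleaved = cls_token.repeat_interleave(num_windows, dim=0)  # [B*num_windows, 1, dim]

# Batch-major windows: rows 0-3 belong to image 0, rows 4-7 to image 1, rows 8-11 to image 2
print(tiled.flatten().tolist())        # [0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0]
print(interleaved.flatten().tolist())  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]
```

With the tiled version, most windows would receive another image's CLS token, so information would leak across the batch.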
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_dim,
                 grid_size, window_size=None, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttentionShift(dim, num_heads, grid_size, window_size, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, mlp_dim, dropout)
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
class TransformerEncoder(nn.Module):
    def __init__(self, dim, depth, num_heads, mlp_dim, grid_size, window_size=None, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerBlock(dim, num_heads, mlp_dim, grid_size,
                              window_size, dropout) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(dim)
    def forward(self, x):
        for block in self.layers:
            x = block(x)
        return self.norm(x)
class ViTWindowed(nn.Module):
    def __init__(self, img_size, patch_size, num_hiddens, num_heads,
                 num_classes, window_size, depth=6, dropout=0.1):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.grid_size = img_size // patch_size
        # --- Patch Embedding ---
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        # --- CLS token ---
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_hiddens))
        # --- Positional Embedding (+1 for CLS) ---
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, num_hiddens))
        # --- Transformer Encoder ---
        self.encoder = TransformerEncoder(num_hiddens, depth, num_heads,
                                          mlp_dim=num_hiddens*4,
                                          grid_size=self.grid_size,
                                          window_size=window_size,
                                          dropout=dropout)
        # --- Classification Head ---
        self.classification_head = ClassificationHead(num_hiddens, num_classes)
    def forward(self, x):
        batch = x.shape[0]
        # 1. Patch embed: [B, 3, H, W] -> [B, num_patches, num_hiddens]
        patches = self.patch_embedding(x)
        patches = patches.flatten(2).transpose(1, 2)
        # 2. Prepend CLS token: [B, num_patches, D] -> [B, num_patches+1, D]
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        x = torch.cat((cls_tokens, patches), dim=1)
        # 3. Add positional embeddings
        x = x + self.pos_embedding
        # 4. Transformer encoder
        x = self.encoder(x)
        # 5. Classification: extract CLS token -> head
        x = self.classification_head(x)
        return x  # logits: [B, num_classes]
model = ViTWindowed(
    img_size=56, patch_size=7, num_hiddens=256,
    num_heads=8, num_classes=100, window_size=4, depth=2, dropout=0.1
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(50):
    model.train()
    total_loss, correct, total = 0, 0, 0
    start = time.time()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    elapsed = time.time() - start
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Acc: {correct/total:.4f} | Time: {elapsed:.1f}s")
Epoch 1 | Loss: 3.7496 | Acc: 0.1264 | Time: 57.0s
Epoch 2 | Loss: 3.1659 | Acc: 0.2223 | Time: 58.7s
Epoch 3 | Loss: 2.9425 | Acc: 0.2640 | Time: 53.9s
Epoch 4 | Loss: 2.7823 | Acc: 0.2941 | Time: 59.9s
Epoch 5 | Loss: 2.6435 | Acc: 0.3220 | Time: 55.8s
Epoch 6 | Loss: 2.5224 | Acc: 0.3474 | Time: 56.3s
Epoch 7 | Loss: 2.4097 | Acc: 0.3710 | Time: 57.2s
Epoch 8 | Loss: 2.2967 | Acc: 0.3939 | Time: 58.2s
Epoch 9 | Loss: 2.1961 | Acc: 0.4162 | Time: 57.0s
Epoch 10 | Loss: 2.0823 | Acc: 0.4409 | Time: 56.4s
Epoch 11 | Loss: 1.9976 | Acc: 0.4574 | Time: 56.8s
Epoch 12 | Loss: 1.8917 | Acc: 0.4810 | Time: 57.9s
Epoch 13 | Loss: 1.8016 | Acc: 0.5026 | Time: 55.1s
Epoch 14 | Loss: 1.7087 | Acc: 0.5207 | Time: 54.8s
Epoch 15 | Loss: 1.6214 | Acc: 0.5426 | Time: 55.6s
Epoch 16 | Loss: 1.5370 | Acc: 0.5631 | Time: 56.3s
Epoch 17 | Loss: 1.4640 | Acc: 0.5773 | Time: 56.0s
Epoch 18 | Loss: 1.3876 | Acc: 0.5964 | Time: 56.6s
Epoch 19 | Loss: 1.3182 | Acc: 0.6122 | Time: 60.0s
Epoch 20 | Loss: 1.2485 | Acc: 0.6307 | Time: 57.9s
Epoch 21 | Loss: 1.1762 | Acc: 0.6527 | Time: 55.6s
Epoch 22 | Loss: 1.1289 | Acc: 0.6598 | Time: 56.7s
Epoch 23 | Loss: 1.0762 | Acc: 0.6748 | Time: 56.7s
Epoch 24 | Loss: 1.0145 | Acc: 0.6890 | Time: 57.4s
Epoch 25 | Loss: 0.9755 | Acc: 0.6991 | Time: 56.8s
Epoch 26 | Loss: 0.9296 | Acc: 0.7122 | Time: 53.8s
Epoch 27 | Loss: 0.8896 | Acc: 0.7235 | Time: 54.2s
Epoch 28 | Loss: 0.8536 | Acc: 0.7329 | Time: 53.8s
Epoch 29 | Loss: 0.8169 | Acc: 0.7424 | Time: 54.1s
Epoch 30 | Loss: 0.7873 | Acc: 0.7511 | Time: 54.2s
Epoch 31 | Loss: 0.7482 | Acc: 0.7644 | Time: 54.0s
Epoch 32 | Loss: 0.7343 | Acc: 0.7696 | Time: 53.8s
Epoch 33 | Loss: 0.7042 | Acc: 0.7775 | Time: 54.0s
Epoch 34 | Loss: 0.6793 | Acc: 0.7827 | Time: 53.9s
Epoch 35 | Loss: 0.6627 | Acc: 0.7883 | Time: 54.2s
Epoch 36 | Loss: 0.6359 | Acc: 0.7970 | Time: 53.4s
Epoch 37 | Loss: 0.6173 | Acc: 0.8026 | Time: 53.3s
Epoch 38 | Loss: 0.5972 | Acc: 0.8071 | Time: 53.5s
Epoch 39 | Loss: 0.5896 | Acc: 0.8101 | Time: 54.1s
Epoch 40 | Loss: 0.5706 | Acc: 0.8173 | Time: 53.7s
Epoch 41 | Loss: 0.5563 | Acc: 0.8210 | Time: 53.6s
Epoch 42 | Loss: 0.5368 | Acc: 0.8265 | Time: 54.5s
Epoch 43 | Loss: 0.5265 | Acc: 0.8311 | Time: 54.2s
Epoch 44 | Loss: 0.5198 | Acc: 0.8320 | Time: 54.1s
Epoch 45 | Loss: 0.5039 | Acc: 0.8367 | Time: 53.1s
Epoch 46 | Loss: 0.4898 | Acc: 0.8416 | Time: 53.7s
Epoch 47 | Loss: 0.4767 | Acc: 0.8455 | Time: 53.9s
Epoch 48 | Loss: 0.4665 | Acc: 0.8490 | Time: 53.8s
Epoch 49 | Loss: 0.4664 | Acc: 0.8468 | Time: 55.4s
Epoch 50 | Loss: 0.4482 | Acc: 0.8544 | Time: 53.7s
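Windowed attention is cheaper in principle (each patch scores against the 16 patches in its window plus the CLS token, instead of all 65 tokens), although at this small 8x8 grid the epoch times barely change. Training accuracy after 50 epochs drops from 0.877 to 0.854, which suggests that removing global context hurts; both numbers are training accuracy, so a test-set comparison is needed to confirm it.
According to the Swin paper, that is why we need shifted windows: they create connections between neighboring windows, so patches near a window boundary can interact with patches on the other side.
Let's add shifted windows now.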
In our TransformerBlock we will alternate between regular and shifted attention. A shifted block adds a torch.roll call that cyclically shifts the patch grid by half a window (2 patches here): up by 2 rows and left by 2 columns.
With plain windowed attention we had 4 non-overlapping windows of 4x4 patches. A regular block still uses those original windows. A shifted block rolls the grid first and then applies the same window partitioning as before, so each window now contains patches that sat in different windows in the previous layer. After attention, the grid is rolled back so every patch returns to its original position.
The Swin paper also masks attention inside the shifted windows, so that patches wrapped around by the cyclic roll (from the top or left edge to the bottom or right) do not attend to patches they were never adjacent to. Our version skips that mask.
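Block 0 (regular):          Block 1 (shifted, windows after the roll):
┌──────┬──────┐             ┌──────┬──────┐
│  A   │  B   │             │ A B  │ B A  │
│      │      │    roll     │ C D  │ D C  │
├──────┼──────┤   ------>   ├──────┼──────┤
│  C   │  D   │             │ C D  │ D C  │
│      │      │             │ A B  │ B A  │
└──────┴──────┘             └──────┴──────┘
In the shifted diagram each letter stands for a 2x2 quarter of the original window with that letter. The top-left shifted window is the center of the image and mixes quarters of all four original windows; the other three windows contain patches that wrapped around the image edges.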
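The cyclic shift is a single torch.roll. On an 8x8 grid of patch ids, rolling by half a window (2 patches) moves the center of the original grid into the top-left window, so that window now mixes patches from all four original windows, and rolling by +2 restores the original layout. This is a standalone illustration on integer ids, not the notebook's model code:

```python
import torch

grid_size, window_size = 8, 4
shift = window_size // 2
grid = torch.arange(grid_size * grid_size).reshape(grid_size, grid_size)  # patch ids 0..63

# Shift the grid up and left by half a window (the cyclic shift used in shifted blocks)
shifted = torch.roll(grid, shifts=(-shift, -shift), dims=(0, 1))

# After the roll, the top-left window holds the center of the original grid (rows 2-5, cols 2-5)
print(torch.equal(shifted[:window_size, :window_size], grid[2:6, 2:6]))  # True

# Rolling by +shift undoes it exactly, which is how patches are put back after attention
print(torch.equal(torch.roll(shifted, shifts=(shift, shift), dims=(0, 1)), grid))  # True
```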
class MultiHeadAttentionShift(nn.Module):
    # Windowed attention with an optional cyclic shift (Swin-style, but without the attention mask).
    # We reshape into windows and run attention independently for each window.
    def __init__(self, dim, num_heads, grid_size,
                 window_size, shift, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads
        self.grid_size = grid_size
        self.window_size = window_size
        self.num_heads = num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(dim, self.head_dim, dropout) for _ in range(num_heads)]
        )
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
        self.shift = shift
        self.shift_amount = window_size // 2 if shift else 0
    def forward(self, x):
        B, num_patches_with_cls, dim = x.shape
        cls_token = x[:, :1, :]
        x = x[:, 1:, :]
        num_patches = x.shape[1]
        num_windows = (self.grid_size // self.window_size) ** 2
        # Window the patches
        grid = x.reshape(B, self.grid_size, self.grid_size, dim)
        if self.shift:
            # Cyclic shift: move the grid up and left by half a window
            grid = torch.roll(grid, shifts=(-self.shift_amount, -self.shift_amount), dims=(1, 2))
        windows = grid.reshape(B, self.grid_size // self.window_size, self.window_size,
                               self.grid_size // self.window_size, self.window_size, dim)
        windows = windows.permute(0, 1, 3, 2, 4, 5)
        windows = windows.reshape(B * num_windows, self.window_size * self.window_size, dim)
        # Prepend CLS to every window so it can attend to all patches.
        # repeat_interleave keeps each image's CLS copies next to that image's windows.
        cls_expanded = cls_token.repeat_interleave(num_windows, dim=0)  # [B*num_windows, 1, dim]
        windows = torch.cat((cls_expanded, windows), dim=1)  # [B*num_windows, 17, dim]
        # Run attention
        head_outputs = [head(windows) for head in self.heads]
        out = torch.cat(head_outputs, dim=-1)
        # Separate CLS and patches
        cls_out = out[:, :1, :]  # [B*num_windows, 1, dim]
        out = out[:, 1:, :]  # [B*num_windows, 16, dim]
        # Average CLS across all windows (each window produced a CLS update)
        cls_out = cls_out.reshape(B, num_windows, 1, dim).mean(dim=1)  # [B, 1, dim]
        # Reverse windowing for patches
        nw = self.grid_size // self.window_size
        out = out.reshape(B, nw, nw, self.window_size, self.window_size, dim)
        out = out.permute(0, 1, 3, 2, 4, 5)
        out = out.reshape(B, self.grid_size, self.grid_size, dim)
        if self.shift:
            # Undo the cyclic shift so every patch returns to its original position
            out = torch.roll(out, shifts=(self.shift_amount, self.shift_amount), dims=(1, 2))
        out = out.reshape(B, num_patches, dim)
        # Recombine
        out = torch.cat((cls_out, out), dim=1)
        out = self.proj(out)
        out = self.dropout(out)
        return out
class TransformerBlockRolled(nn.Module):
    def __init__(self, dim, num_heads, mlp_dim,
                 grid_size, shift, window_size=None, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttentionShift(dim, num_heads, grid_size, window_size, shift, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, mlp_dim, dropout)
        self.shift = shift
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
class TransformerEncoderRolled(nn.Module):
    def __init__(self, dim, depth, num_heads, mlp_dim, grid_size, window_size=None, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(depth):
            shift = (i % 2 == 1)  # alternate: even blocks regular, odd blocks shifted
            self.layers.append(TransformerBlockRolled(dim, num_heads, mlp_dim, grid_size,
                                                      shift, window_size, dropout))
        self.norm = nn.LayerNorm(dim)
    def forward(self, x):
        for block in self.layers:
            x = block(x)
        return self.norm(x)
class ViTWindowedShifted(nn.Module):
    def __init__(self, img_size, patch_size, num_hiddens, num_heads,
                 num_classes, window_size, depth=6, dropout=0.1):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.grid_size = img_size // patch_size
        # --- Patch Embedding ---
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        # --- CLS token ---
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_hiddens))
        # --- Positional Embedding (+1 for CLS) ---
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, num_hiddens))
        # --- Transformer Encoder ---
        self.encoder = TransformerEncoderRolled(num_hiddens, depth, num_heads,
                                                mlp_dim=num_hiddens*4,
                                                grid_size=self.grid_size,
                                                window_size=window_size,
                                                dropout=dropout)
        # --- Classification Head ---
        self.classification_head = ClassificationHead(num_hiddens, num_classes)
    def forward(self, x):
        batch = x.shape[0]
        # 1. Patch embed: [B, 3, H, W] -> [B, num_patches, num_hiddens]
        patches = self.patch_embedding(x)
        patches = patches.flatten(2).transpose(1, 2)
        # 2. Prepend CLS token: [B, num_patches, D] -> [B, num_patches+1, D]
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        x = torch.cat((cls_tokens, patches), dim=1)
        # 3. Add positional embeddings
        x = x + self.pos_embedding
        # 4. Transformer encoder
        x = self.encoder(x)
        # 5. Classification: extract CLS token -> head
        x = self.classification_head(x)
        return x  # logits: [B, num_classes]
model = ViTWindowedShifted(
    img_size=56, patch_size=7, num_hiddens=256,
    num_heads=8, num_classes=100, window_size=4, depth=2, dropout=0.1
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(50):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Acc: {correct/total:.4f}")
Epoch 1 | Loss: 3.7394 | Acc: 0.1248
Epoch 2 | Loss: 3.1830 | Acc: 0.2174
Epoch 3 | Loss: 2.9413 | Acc: 0.2651
Epoch 4 | Loss: 2.7716 | Acc: 0.2981
Epoch 5 | Loss: 2.6348 | Acc: 0.3237
Epoch 6 | Loss: 2.5124 | Acc: 0.3502
Epoch 7 | Loss: 2.4017 | Acc: 0.3705
Epoch 8 | Loss: 2.2843 | Acc: 0.3942
Epoch 9 | Loss: 2.1829 | Acc: 0.4169
Epoch 10 | Loss: 2.0772 | Acc: 0.4395
Epoch 11 | Loss: 1.9787 | Acc: 0.4642
Epoch 12 | Loss: 1.8810 | Acc: 0.4828
Epoch 13 | Loss: 1.7851 | Acc: 0.5055
Epoch 14 | Loss: 1.6865 | Acc: 0.5281
Epoch 15 | Loss: 1.5974 | Acc: 0.5486
Epoch 16 | Loss: 1.5154 | Acc: 0.5672
Epoch 17 | Loss: 1.4368 | Acc: 0.5840
Epoch 18 | Loss: 1.3535 | Acc: 0.6067
Epoch 19 | Loss: 1.2870 | Acc: 0.6219
Epoch 20 | Loss: 1.2167 | Acc: 0.6377
Epoch 21 | Loss: 1.1499 | Acc: 0.6550
Epoch 22 | Loss: 1.1002 | Acc: 0.6680
Epoch 23 | Loss: 1.0403 | Acc: 0.6844
Epoch 24 | Loss: 0.9908 | Acc: 0.6974
Epoch 25 | Loss: 0.9418 | Acc: 0.7084
Epoch 26 | Loss: 0.8953 | Acc: 0.7232
Epoch 27 | Loss: 0.8451 | Acc: 0.7359
Epoch 28 | Loss: 0.8149 | Acc: 0.7467
Epoch 29 | Loss: 0.7849 | Acc: 0.7520
Epoch 30 | Loss: 0.7448 | Acc: 0.7640
Epoch 31 | Loss: 0.7195 | Acc: 0.7692
Epoch 32 | Loss: 0.6944 | Acc: 0.7786
Epoch 33 | Loss: 0.6670 | Acc: 0.7872
Epoch 34 | Loss: 0.6464 | Acc: 0.7931
Epoch 35 | Loss: 0.6207 | Acc: 0.7999
Epoch 36 | Loss: 0.6093 | Acc: 0.8052
Epoch 37 | Loss: 0.5885 | Acc: 0.8096
Epoch 38 | Loss: 0.5661 | Acc: 0.8179
Epoch 39 | Loss: 0.5576 | Acc: 0.8208
Epoch 40 | Loss: 0.5319 | Acc: 0.8264
Epoch 41 | Loss: 0.5228 | Acc: 0.8307
Epoch 42 | Loss: 0.5141 | Acc: 0.8318
Epoch 43 | Loss: 0.4983 | Acc: 0.8377
Epoch 44 | Loss: 0.4794 | Acc: 0.8457
Epoch 45 | Loss: 0.4794 | Acc: 0.8446
Epoch 46 | Loss: 0.4662 | Acc: 0.8480
Epoch 47 | Loss: 0.4619 | Acc: 0.8507
Epoch 48 | Loss: 0.4512 | Acc: 0.8529
Epoch 49 | Loss: 0.4297 | Acc: 0.8591
Epoch 50 | Loss: 0.4296 | Acc: 0.8604
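The shifted model above leaves out one ingredient from the Swin paper. After the cyclic roll, the windows along the bottom and right edges contain patches that wrapped around from the opposite side of the image, and Swin masks attention between patches that were not contiguous before the roll. The sketch below builds such an additive mask by labelling regions of the rolled grid, modelled on the region-labelling approach in the official Swin code; shifted_window_mask is a hypothetical name and this is not a drop-in copy. Our AttentionHead takes no mask, and every window also carries a CLS token, so using it would mean adding an optional additive mask to AttentionHead.forward and padding the mask with an unmasked row and column for the CLS position (not shown):

```python
import torch


def shifted_window_mask(grid_size, window_size, shift):
    """Additive attention mask [num_windows, M*M, M*M] for cyclically shifted windows.

    0 where two patches come from the same contiguous region of the rolled grid,
    -100 where they only became neighbors because torch.roll wrapped one of them around.
    """
    M = window_size
    region = torch.zeros(grid_size, grid_size)
    # In the rolled grid, the last `shift` rows/cols are the ones that wrapped around
    slices = (slice(0, -M), slice(-M, -shift), slice(-shift, None))
    label = 0
    for hs in slices:
        for ws in slices:
            region[hs, ws] = label
            label += 1
    # Partition the label map exactly like the features: [num_windows, M*M]
    n = grid_size // M
    windows = region.reshape(n, M, n, M).permute(0, 2, 1, 3).reshape(-1, M * M)
    same = windows[:, :, None] == windows[:, None, :]  # [num_windows, M*M, M*M]
    return torch.where(same, torch.tensor(0.0), torch.tensor(-100.0))


mask = shifted_window_mask(grid_size=8, window_size=4, shift=2)
print(mask.shape)                   # torch.Size([4, 16, 16])
print((mask[0] == 0).all().item())  # True: the top-left shifted window needs no masking
# Inside attention (per head): scores = q @ k.transpose(-1, -2) * scale + mask[window_idx]
```

Adding -100 before the softmax drives those attention weights to effectively zero while keeping the computation fully batched.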
Check how the pros do it
timm's Swin Transformer implementation and its SwinTransformerBlock are a good reference.
timm partitions windows like this. Note that it supports non-square windows, which is a nice idea:
import torch
from typing import Tuple
def window_partition(
    x: torch.Tensor,
    window_size: Tuple[int, int],
) -> torch.Tensor:
    """Partition into non-overlapping windows.
    Args:
        x: Input tokens with shape [B, H, W, C].
        window_size: Window size.
    Returns:
        Windows after partition with shape [B * num_windows, window_size, window_size, C].
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
    return windows
And this is how they reverse the partition:
def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], H: int, W: int) -> torch.Tensor:
    """Reverse window partition.
    Args:
        windows: Windows with shape (num_windows*B, window_size, window_size, C).
        window_size: Window size.
        H: Height of image.
        W: Width of image.
    Returns:
        Tensor with shape (B, H, W, C).
    """
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x
import timm
model = timm.create_model('vit_tiny_patch16_224', pretrained=False)
print(model)
VisionTransformer(
(patch_embed): PatchEmbed(
(proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
(norm): Identity()
)
(pos_drop): Dropout(p=0.0, inplace=False)
(patch_drop): Identity()
(norm_pre): Identity()
(blocks): Sequential(
(0): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(1): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(2): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(3): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(4): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(5): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(6): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(7): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(8): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(9): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(10): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
(11): Block(
(norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=192, out_features=576, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=192, out_features=192, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=192, out_features=768, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=768, out_features=192, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
)
(norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
(fc_norm): Identity()
(head_drop): Dropout(p=0.0, inplace=False)
(head): Linear(in_features=192, out_features=1000, bias=True)
)
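Every Block in the printout above has the same shapes:

- a fused `qkv` Linear from 192 to 576 (3 × 192, one slice each for query, key and value);
- an attention output `proj` from 192 to 192;
- an MLP that expands 192 to 768 (4×) and projects back;
- two LayerNorms with affine weights.

Everything shown as `Identity` or `Dropout` (including `ls1`/`ls2` and `drop_path1`/`drop_path2`) holds no parameters. As a rough sanity check on where the weights live, here is a small pure-Python sketch that tallies parameters by hand from those printed shapes.

It only counts what appears in this part of the printout: the 12 Blocks, the final `norm` and the 1000-class `head`. The patch embedding, cls token and position embedding fall outside it, so they are deliberately left out. The constant names are my own labels for the printed numbers, not timm attributes.

```python
# Hand count of parameters for the layers printed above.
# Shapes are read off the repr: embed dim 192, fused qkv 192 -> 576,
# MLP 192 -> 768 -> 192, Blocks (0)..(11), final LayerNorm, head 192 -> 1000.
# Patch embedding, cls token and position embedding are NOT counted here.

EMBED_DIM = 192
QKV_OUT = 576      # 3 * 192: query, key and value projections fused into one Linear
MLP_HIDDEN = 768   # 4 * 192: MLP expansion ratio of 4
NUM_BLOCKS = 12    # Blocks (0) through (11)
NUM_CLASSES = 1000


def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Weight matrix (out x in) plus an optional bias vector, as in nn.Linear."""
    return in_features * out_features + (out_features if bias else 0)


def layernorm_params(dim: int) -> int:
    """elementwise_affine=True means one weight and one bias per feature."""
    return 2 * dim


def block_params(dim: int, qkv_out: int, hidden: int) -> dict:
    """Per-layer parameter counts for one Block as printed (Identity/Dropout add 0)."""
    return {
        "norm1": layernorm_params(dim),
        "attn.qkv": linear_params(dim, qkv_out),
        "attn.proj": linear_params(dim, dim),
        "norm2": layernorm_params(dim),
        "mlp.fc1": linear_params(dim, hidden),
        "mlp.fc2": linear_params(hidden, dim),
    }


per_block = block_params(EMBED_DIM, QKV_OUT, MLP_HIDDEN)
block_total = sum(per_block.values())
encoder_total = NUM_BLOCKS * block_total
final_norm = layernorm_params(EMBED_DIM)
head = linear_params(EMBED_DIM, NUM_CLASSES)

for name, count in per_block.items():
    print(f"{name:10s} {count:>9,}")
print(f"{'per block':10s} {block_total:>9,}")
print(f"{'12 blocks':10s} {encoder_total:>9,}")
print(f"{'norm+head':10s} {final_norm + head:>9,}")
print(f"{'total':10s} {encoder_total + final_norm + head:>9,}")
```

This works out to 444,864 parameters per Block, so 5,338,368 across the 12 Blocks. About two thirds of each Block sits in the MLP (`fc1` + `fc2`) rather than in attention.

If the printed object is bound to a variable called `model`, you can cross-check against PyTorch's own count. The idiom `sum(p.numel() for p in model.blocks.parameters())` should match the 12-block figure; this assumes the `Sequential` holding the Blocks is named `blocks`, as it is in timm's `VisionTransformer`.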