from vit import ViT, PatchEmbedding, MultiHeadAttention, TransformerEncoder, ClassificationHead, FeedForward

Compare ViT with Swin, CoAtNet, DeiT
- Swin Transformer: https://arxiv.org/abs/2103.14030
- DeiT (Data-efficient Image Transformers): https://arxiv.org/abs/2012.12877
- CoAtNet: https://arxiv.org/abs/2106.04803
- Attention Is All You Need: https://arxiv.org/abs/1706.03762
- timm repo (has all these models implemented): https://github.com/huggingface/pytorch-image-models
- The Annotated Transformer (Harvard NLP walkthrough): https://nlp.seas.harvard.edu/annotated-transformer/
Key Ideas Oversimplified:
Swin Transformer
- Shifted windows. Instead of global attention (every patch attends to every patch), Swin computes attention within local windows (for example, 7x7 patches). It then shifts the window grid by half a window between layers so information can flow across window boundaries.
- Hierarchical feature maps. Unlike ViT, which keeps the same resolution throughout, Swin progressively merges patches (think of this as pooling in CNNs). This gives us multi-scale features, which become very important in detection and segmentation later on.
- Overall, Swin gives efficiency and versatility. You can get comparable accuracy with local attention if you shift the windows in a clever way.
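To get a feel for the efficiency argument, here is a rough back-of-the-envelope sketch that counts query-key pairs per attention layer. The helper name `attention_pairs` is ours, and it ignores heads, projections and the MLP, so treat it as an illustration rather than a real cost model.

```python
from typing import Optional


def attention_pairs(grid_size: int, window_size: Optional[int] = None) -> int:
    """Count query-key pairs for one attention layer on a square patch grid."""
    num_patches = grid_size * grid_size
    if window_size is None:
        # Global attention: every patch attends to every patch.
        return num_patches * num_patches
    # Windowed attention: patches only attend within their own window.
    num_windows = (grid_size // window_size) ** 2
    patches_per_window = window_size * window_size
    return num_windows * patches_per_window * patches_per_window


global_pairs = attention_pairs(8)
windowed_pairs = attention_pairs(8, window_size=4)
print(global_pairs, windowed_pairs)  # 4096 1024
```

On the 8x8 grid used later in this notebook the saving is only 4x, but the gap grows quickly with resolution because the global count is quadratic in the number of patches.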
DeiT
- Showed you can get competitive results with less data, as long as you have the right training recipe: strong augmentation, good regularization, and some careful hyperparameter tuning.
- Distillation token. DeiT adds a second special token next to the cls_token, and that token is trained to match the predictions of a teacher model (a CNN). In this way the CNN guides the ViT during training.
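As a rough illustration of the distillation token idea, the sketch below combines the usual label loss on the cls_token output with a second loss that pushes the distillation token's output towards the teacher's predicted class (the hard-label variant). The function name and the equal 0.5 weighting are our assumptions for illustration, and the random tensors stand in for real model outputs.

```python
import torch
import torch.nn.functional as F


def hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels):
    """Average the true-label loss and the loss against the teacher's hard predictions."""
    teacher_labels = teacher_logits.argmax(dim=-1)  # class the teacher would predict
    loss_cls = F.cross_entropy(cls_logits, labels)            # cls_token head vs. ground truth
    loss_dist = F.cross_entropy(dist_logits, teacher_labels)  # distillation head vs. teacher
    return 0.5 * loss_cls + 0.5 * loss_dist


torch.manual_seed(0)
cls_logits = torch.randn(4, 100)      # stand-in for the cls_token head output
dist_logits = torch.randn(4, 100)     # stand-in for the distillation token head output
teacher_logits = torch.randn(4, 100)  # stand-in for a frozen CNN teacher
labels = torch.randint(0, 100, (4,))
loss = hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels)
print(loss.item())
```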
CoAtNet
- Brings in the best of both worlds. CNNs are good at local features and have nice inductive biases. Transformers are good at global relationships. If you stack CNN and transformer blocks, you get good results.
- Relative attention. Take away the absolute positional embeddings like we had in the vanilla ViT, and instead use relative position biases in the attention scores. This generalizes nicely to different input sizes.
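To make the relative position bias idea concrete, here is a minimal sketch of a learnable bias table indexed by the (row, column) offset between two patches in a window. It follows the indexing trick commonly used in Swin-style code; the class name `RelativePositionBias` and its attributes are ours, and this is not CoAtNet's exact formulation.

```python
import torch
import torch.nn as nn


class RelativePositionBias(nn.Module):
    def __init__(self, window_size: int, num_heads: int):
        super().__init__()
        ws = window_size
        # One learnable bias per head for every possible (d_row, d_col) offset.
        self.table = nn.Parameter(torch.zeros((2 * ws - 1) ** 2, num_heads))
        coords = torch.stack(
            torch.meshgrid(torch.arange(ws), torch.arange(ws), indexing="ij")
        ).flatten(1)                                    # [2, N] with N = ws * ws
        rel = coords[:, :, None] - coords[:, None, :]   # [2, N, N], offsets in [-(ws-1), ws-1]
        rel = rel + (ws - 1)                            # shift offsets to start at 0
        index = rel[0] * (2 * ws - 1) + rel[1]          # [N, N] flat index into the table
        self.register_buffer("index", index)

    def forward(self) -> torch.Tensor:
        n = self.index.shape[0]
        bias = self.table[self.index.reshape(-1)].reshape(n, n, -1)  # [N, N, heads]
        return bias.permute(2, 0, 1)                                 # [heads, N, N]


rel_bias = RelativePositionBias(window_size=4, num_heads=8)
bias = rel_bias()
print(bias.shape)  # torch.Size([8, 16, 16])
```

Each head's [N, N] slice would be added to that head's attention scores before the softmax. Because the table depends only on offsets, two patch pairs with the same offset share one parameter.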
Building on top of our ViT: we exported everything we wrote in the first module to vit.py, so we can import that work and build on top of it.
Baseline Vanilla ViT on CIFAR-100
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F # This gives us the softmax()
from torch.utils.data import DataLoader
transform = transforms.Compose([
    transforms.Resize((56, 56)),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])
train_dataset = CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR100(root='./data', train=False, download=True, transform=transform)

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using: {device}")

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

model = ViT(
    img_size=56, patch_size=7, num_hiddens=256,
    num_heads=8, num_classes=100, depth=2, dropout=0.1
).to(device)

import time

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(50):
    model.train()
    total_loss, correct, total = 0, 0, 0
    start = time.time()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    elapsed = time.time() - start
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Acc: {correct/total:.4f} | Time: {elapsed:.1f}s")

Adding windowed attention like Swin
- Step 1: Windowed attention
- Let's run a 56x56 image through the steps to understand which changes we need to make, and where.
class AttentionHead(nn.Module):
    def __init__(self, dim, head_dim, dropout=0.1):
        super().__init__()
        # TODO: Store scale factor (head_dim**-0.5)
        # TODO: Create Q, K, V linear projections (dim -> head_dim, no bias)
        # TODO: Create dropout layer
        pass

    def forward(self, x):
        # TODO: Compute Q, K, V from x
        # TODO: Compute attention scores (Q @ K^T * scale)
        # TODO: Apply softmax on dim=-1
        # TODO: Apply dropout
        # TODO: Multiply attention weights by V and return
        pass

class MultiHeadAttention(nn.Module):
    # For windowed attention, we need to reshape before we pass to attention.
    # We need to run attention independently for each window.
    def __init__(self, dim, num_heads, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        # TODO: Compute head_dim = dim // num_heads
        # TODO: Store num_heads
        # TODO: Create ModuleList of AttentionHead instances (one per head)
        # TODO: Create output projection linear layer (dim -> dim)
        # TODO: Create dropout layer
        pass

    def forward(self, x):
        # TODO: Run each head on x and collect outputs
        # TODO: Concatenate head outputs along last dim
        # TODO: Apply projection and dropout
        pass

x = torch.randn(2, 3, 56, 56)

patch_emb = PatchEmbedding(img_size=56, patch_size=7, num_hiddens=256)
patches = patch_emb(x)
print("After conv:", patches.shape)  # [2, 256, 8, 8]
# The hidden dim is the size of the vector that represents each patch. Since we chose 256, each patch is a vector of length 256.
patches_flat = patches.flatten(2).transpose(1, 2)
print("Flattened:", patches_flat.shape)  # [2, 64, 256]: 8x8 = 64 patches, each represented by a 256-dim vector

head = AttentionHead(dim=256, head_dim=32)
q = head.q(patches_flat)
k = head.k(patches_flat)
v = head.v(patches_flat)
print("Q:", q.shape)  # [2, 64, 32]
print("K:", k.shape)  # [2, 64, 32]. Why 32? In MHA each head works in its own 32-dim subspace: head_dim = 256 / 8 = 32, and the 8 head outputs concatenate back to 256.
attn_scores = torch.matmul(q, k.transpose(-1, -2))
print("Attention matrix:", attn_scores.shape)  # [2, 64, 64]: every patch attends to every patch

- These are the inputs to attention: 64 patches, each represented by a 32-dimensional vector. It's 32 because each head gets head_dim = 256 / num_heads = 256 / 8 = 32.
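The walkthrough above calls head.q, head.k and head.v, so the AttentionHead TODOs need to be filled in with attributes of those names. One possible way to do that is sketched below; it simply follows the TODO comments and is not the only valid solution.

```python
import torch
import torch.nn as nn


class AttentionHead(nn.Module):
    def __init__(self, dim, head_dim, dropout=0.1):
        super().__init__()
        self.scale = head_dim ** -0.5                  # 1 / sqrt(head_dim)
        self.q = nn.Linear(dim, head_dim, bias=False)  # dim -> head_dim projections
        self.k = nn.Linear(dim, head_dim, bias=False)
        self.v = nn.Linear(dim, head_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        q, k, v = self.q(x), self.k(x), self.v(x)        # each [B, N, head_dim]
        scores = q @ k.transpose(-1, -2) * self.scale    # [B, N, N]
        weights = self.dropout(scores.softmax(dim=-1))   # rows sum to 1 before dropout
        return weights @ v                               # [B, N, head_dim]


head = AttentionHead(dim=256, head_dim=32)
out = head(torch.randn(2, 64, 256))
print(out.shape)  # torch.Size([2, 64, 32])
```

Note that nothing in the head assumes a fixed number of tokens, which is why the same module can later be reused on 16-patch windows without changes.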
# The 64 patches are an 8x8 grid. Window size 4 means 2x2 windows of 4x4 patches each.
B, num_patches, dim = patches_flat.shape
grid_size = 8 # sqrt(64)
window_size = 4
grid = patches_flat.reshape(B, grid_size, grid_size, dim)
print("Grid:", grid.shape)  # [2, 8, 8, 256]

# Split into windows
# [B, 8, 8, dim] -> [B, 2, 4, 2, 4, dim] -> [B, 2, 2, 4, 4, dim]
windows = grid.reshape(B, grid_size // window_size, window_size, grid_size // window_size, window_size, dim)
windows = windows.permute(0, 1, 3, 2, 4, 5)
print("Windows rearranged:", windows.shape)  # [2, 2, 2, 4, 4, 256]

- About the reshaping above: we have an 8x8 grid of patches and we specify a window size of 4, so the grid splits into 4 regions, each a 4x4 block of the original patches. Instead of every patch attending to all others in the 8x8 grid, each patch only attends to the 16 patches in its own 4x4 window.
┌───────────┬───────────┐
│           │           │
│  window   │  window   │
│   (0,0)   │   (0,1)   │
│    4×4    │    4×4    │
│           │           │
├───────────┼───────────┤
│           │           │
│  window   │  window   │
│   (1,0)   │   (1,1)   │
│    4×4    │    4×4    │
│           │           │
└───────────┴───────────┘
- B=2 — batch
- 2, 2 — which window (row, col) — 4 windows total
- 4, 4 — which patch within that window — 16 patches per window
- 256 — hidden dim
# Merge batch and window dims so we can run attention as-is
num_windows = (grid_size // window_size) ** 2 # 4
windows = windows.reshape(B * num_windows, window_size * window_size, dim)
print("Ready for attention:", windows.shape) # [8, 16, 256] — 8 = 2 batches * 4 windows, 16 patches each
# Step 4: Run attention on windowed input
attn_out = head(windows)
print("Attention output:", attn_out.shape) # [8, 16, 32]
# Step 5: Reverse the reshape to get back to [B, 64, dim]
# [8, 16, 32] -> [2, 2, 2, 4, 4, 32] -> [2, 8, 8, 32] -> [2, 64, 32]
attn_out = attn_out.reshape(B, grid_size // window_size, grid_size // window_size, window_size, window_size, -1)
attn_out = attn_out.permute(0, 1, 3, 2, 4, 5)
attn_out = attn_out.reshape(B, num_patches, -1)
print("Back to original shape:", attn_out.shape)  # [2, 64, 32]

So all we have to do is add this reshape inside MHA.
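The partition and reverse steps above can be wrapped into two small helpers, which makes it easy to check that they really are inverses of each other. The function names `partition` and `reverse` are ours; the reshapes are the same ones used in the walkthrough.

```python
import torch


def partition(tokens, grid_size, window_size):
    """[B, grid*grid, dim] -> [B * num_windows, window*window, dim]"""
    B, _, dim = tokens.shape
    nw = grid_size // window_size  # windows per side
    x = tokens.reshape(B, nw, window_size, nw, window_size, dim)
    x = x.permute(0, 1, 3, 2, 4, 5)  # [B, nw, nw, ws, ws, dim]
    return x.reshape(B * nw * nw, window_size * window_size, dim)


def reverse(windows, grid_size, window_size, B):
    """[B * num_windows, window*window, dim] -> [B, grid*grid, dim]"""
    nw = grid_size // window_size
    dim = windows.shape[-1]
    x = windows.reshape(B, nw, nw, window_size, window_size, dim)
    x = x.permute(0, 1, 3, 2, 4, 5)  # [B, nw, ws, nw, ws, dim]
    return x.reshape(B, grid_size * grid_size, dim)


tokens = torch.randn(2, 64, 256)
windows = partition(tokens, grid_size=8, window_size=4)
restored = reverse(windows, grid_size=8, window_size=4, B=2)
print(windows.shape)                   # torch.Size([8, 16, 256])
print(torch.equal(tokens, restored))   # True
```

The round trip only moves data around, so the restored tensor matches the input exactly, not just approximately.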
class MultiHeadAttentionShift(nn.Module):
    # For windowed attention, we need to reshape before we pass to attention.
    # We need to run attention independently for each window.
    def __init__(self, dim, num_heads, grid_size,
                 window_size, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        # TODO: Compute and store head_dim = dim // num_heads
        # TODO: Store grid_size, window_size, num_heads
        # TODO: Create ModuleList of AttentionHead instances (one per head)
        # TODO: Create output projection linear layer (dim -> dim)
        # TODO: Create dropout layer
        pass

    def forward(self, x):
        # x shape: [B, num_patches_with_cls, dim]
        # TODO: Separate CLS token (first token) from patch tokens
        #   cls_token = x[:, :1, :]
        #   x = x[:, 1:, :]
        # TODO: Compute num_windows = (grid_size // window_size) ** 2
        # TODO: Reshape patches into grid: [B, grid_size, grid_size, dim]
        # TODO: Partition grid into windows:
        #   reshape to [B, grid_size//window_size, window_size, grid_size//window_size, window_size, dim]
        #   permute to [B, nw, nw, ws, ws, dim]
        #   reshape to [B*num_windows, window_size*window_size, dim]
        # TODO: Prepend CLS token to every window so it can attend to all patches
        #   Expand cls_token to [B*num_windows, 1, dim]; windows are ordered image by image,
        #   so use repeat_interleave(num_windows, dim=0) (a plain repeat(num_windows, 1, 1)
        #   would pair windows with the wrong image's CLS when B > 1)
        #   Concatenate: [B*num_windows, window_size*window_size + 1, dim]
        # TODO: Run each attention head and concatenate outputs
        # TODO: Separate CLS output (first token) from patch outputs
        # TODO: Average CLS across all windows:
        #   reshape to [B, num_windows, 1, dim] then .mean(dim=1) -> [B, 1, dim]
        # TODO: Reverse windowing for patches:
        #   reshape to [B, nw, nw, ws, ws, dim]
        #   permute back to [B, grid_size, grid_size, dim]
        #   reshape to [B, num_patches, dim]
        # TODO: Recombine CLS and patches: cat along dim=1
        # TODO: Apply projection and dropout
        pass
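The CLS handling is the easiest part of this forward pass to get subtly wrong, so here is a small sketch of just that step. The helper names are ours. It uses repeat_interleave so that each image's CLS token is copied next to that image's own windows, which matters because the windows are laid out image by image.

```python
import torch


def add_cls_to_windows(cls_token, windows, num_windows):
    """cls_token: [B, 1, dim], windows: [B * num_windows, ws*ws, dim]."""
    # [b0, b0, b0, b0, b1, b1, ...] matches the image-by-image window order.
    cls_rep = cls_token.repeat_interleave(num_windows, dim=0)
    return torch.cat([cls_rep, windows], dim=1)  # [B * num_windows, ws*ws + 1, dim]


def split_and_average_cls(out, B, num_windows):
    """Split off the per-window CLS outputs and average them per image."""
    cls_out, patch_out = out[:, :1, :], out[:, 1:, :]
    cls_out = cls_out.reshape(B, num_windows, 1, -1).mean(dim=1)  # [B, 1, dim]
    return cls_out, patch_out


B, num_windows, dim = 2, 4, 3
cls_token = torch.tensor([[[1.0, 1.0, 1.0]], [[2.0, 2.0, 2.0]]])  # image 0 -> 1s, image 1 -> 2s
windows = torch.zeros(B * num_windows, 16, dim)
with_cls = add_cls_to_windows(cls_token, windows, num_windows)
cls_avg, patch_out = split_and_average_cls(with_cls, B, num_windows)
print(with_cls.shape, cls_avg.shape, patch_out.shape)
```

In the real module the attention heads would run between the two helpers; here they are skipped so the bookkeeping can be checked in isolation.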
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_dim,
                 grid_size, window_size=None, dropout=0.1):
        super().__init__()
        # TODO: Create LayerNorm for pre-attention normalization
        # TODO: Create MultiHeadAttentionShift (with grid_size, window_size)
        # TODO: Create LayerNorm for pre-FFN normalization
        # TODO: Create FeedForward layer
        pass

    def forward(self, x):
        # TODO: Apply pre-norm -> attention -> residual connection
        # TODO: Apply pre-norm -> FFN -> residual connection
        pass

class TransformerEncoder(nn.Module):
    def __init__(self, dim, depth, num_heads, mlp_dim, grid_size, window_size=None, dropout=0.1):
        super().__init__()
        # TODO: Create ModuleList of TransformerBlock instances (one per depth)
        # TODO: Create final LayerNorm
        pass

    def forward(self, x):
        # TODO: Loop through all blocks
        # TODO: Apply final LayerNorm
        pass

class ViTWindowed(nn.Module):
    def __init__(self, img_size, patch_size, num_hiddens, num_heads,
                 num_classes, window_size, depth=6, dropout=0.1):
        super().__init__()
        # TODO: Compute num_patches = (img_size // patch_size) ** 2
        # TODO: Compute grid_size = img_size // patch_size
        # TODO: Create PatchEmbedding(img_size, patch_size, num_hiddens)
        # TODO: Create CLS token as nn.Parameter(torch.randn(1, 1, num_hiddens))
        # TODO: Create positional embedding as nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))
        # TODO: Create TransformerEncoder with grid_size and window_size
        #   mlp_dim should be num_hiddens * 4
        # TODO: Create ClassificationHead(num_hiddens, num_classes)
        pass

    def forward(self, x):
        # TODO: Get batch size
        # TODO: 1. Patch embed: [B, 3, H, W] -> flatten and transpose to [B, num_patches, num_hiddens]
        # TODO: 2. Prepend CLS token (expand to batch size, then cat)
        # TODO: 3. Add positional embeddings
        # TODO: 4. Pass through transformer encoder
        # TODO: 5. Pass through classification head and return
        pass

model = ViTWindowed(
    img_size=56, patch_size=7, num_hiddens=256,
    num_heads=8, num_classes=100, window_size=4, depth=2, dropout=0.1
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(50):
    model.train()
    total_loss, correct, total = 0, 0, 0
    start = time.time()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    elapsed = time.time() - start
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Acc: {correct/total:.4f} | Time: {elapsed:.1f}s")

So we have the computational advantage of windowed attention, but by removing the global context we have clearly hurt accuracy.
According to the Swin paper, that's why we need the shifted windows: they let information flow between neighbouring windows.
Let's add shifted windows now.
In our TransformerBlock we will alternate between regular and shifted attention. A shifted block will have a torch.roll call that moves every patch up and to the left by 2 positions (half a window).
Before, with plain windowed attention, we had 4 non-overlapping regions, each with 4x4 patches. A regular block still uses that original partition. A shifted block first rolls the entire grid up by 2 rows and to the left by 2 columns, and patches that fall off one edge wrap around to the opposite edge.
We then apply the same window partitioning as before, but in the shifted blocks each window receives the rolled patches, so it mixes patches that used to sit in different windows.
Block 0 (regular):          Block 1 (shifted):
┌──────┬──────┐             ┌──────┬──────┐
│  A   │  B   │             │ A  B │ B  A │
│      │      │    roll     │ C  D │ D  C │
├──────┼──────┤   ------>   ├──────┼──────┤
│  C   │  D   │             │ C  D │ D  C │
│      │      │             │ A  B │ B  A │
└──────┴──────┘             └──────┴──────┘
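To see what the roll does to window membership, the sketch below labels every patch with its original window (A=0, B=1, C=2, D=3), rolls the label grid, and prints which original windows end up inside each shifted window. It is a toy check on integer labels, not part of the model.

```python
import torch

grid_size, window_size = 8, 4
shift = window_size // 2

# Label each patch with the window it belongs to in the regular layout.
band = torch.arange(grid_size) // window_size  # 0 for rows/cols 0-3, 1 for 4-7
labels = band[:, None] * 2 + band[None, :]     # [8, 8], A=0, B=1, C=2, D=3

# Same roll as in the shifted block: up by 2 rows, left by 2 columns, with wrap-around.
rolled = torch.roll(labels, shifts=(-shift, -shift), dims=(0, 1))

names = "ABCD"
for wr in range(2):
    for wc in range(2):
        win = rolled[wr * 4:(wr + 1) * 4, wc * 4:(wc + 1) * 4]
        present = sorted({names[i] for i in win.flatten().tolist()})
        print(f"shifted window ({wr},{wc}) mixes {present}")

# Rolling back by the same amount restores the original layout.
restored = torch.roll(rolled, shifts=(shift, shift), dims=(0, 1))
print(torch.equal(restored, labels))  # True
```

Every shifted window contains patches from all four original windows. One caveat: because of the wrap-around, some patches that end up in the same window were not neighbours in the image. The Swin paper handles this with an attention mask; the version in this notebook does not.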
class MultiHeadAttentionShift(nn.Module):
    # For windowed attention, we need to reshape before we pass to attention.
    # We need to run attention independently for each window.
    def __init__(self, dim, num_heads, grid_size,
                 window_size, shift, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        # TODO: Compute and store head_dim = dim // num_heads
        # TODO: Store grid_size, window_size, num_heads, shift
        # TODO: Create ModuleList of AttentionHead instances (one per head)
        # TODO: Create output projection linear layer (dim -> dim)
        # TODO: Create dropout layer
        # TODO: Compute shift_amount = window_size // 2 if shift else 0
        pass

    def forward(self, x):
        # x shape: [B, num_patches_with_cls, dim]
        # TODO: Separate CLS token (first token) from patch tokens
        # TODO: Compute num_windows = (grid_size // window_size) ** 2
        # TODO: Reshape patches into grid: [B, grid_size, grid_size, dim]
        # TODO: If shift is True, apply torch.roll on the grid:
        #   grid = torch.roll(grid, shifts=(-shift_amount, -shift_amount), dims=(1, 2))
        # TODO: Partition grid into windows (same as before):
        #   reshape -> permute -> reshape to [B*num_windows, ws*ws, dim]
        # TODO: Prepend CLS token to every window (each image's CLS goes with its own windows)
        # TODO: Run each attention head and concatenate outputs
        # TODO: Separate CLS output from patch outputs
        # TODO: Average CLS across all windows -> [B, 1, dim]
        # TODO: Reverse windowing for patches -> [B, grid_size, grid_size, dim]
        # TODO: If shift is True, reverse the roll:
        #   torch.roll(out, shifts=(shift_amount, shift_amount), dims=(1, 2))
        # TODO: Reshape back to [B, num_patches, dim]
        # TODO: Recombine CLS and patches, apply projection and dropout
        pass
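For reference, here is one possible way to fill in the TODOs above, written as a self-contained sketch. The class name `WindowedMHASketch` is ours, and the inline AttentionHead is a guess at what vit.py contains, so names and details may differ from the first module. It has no attention mask for the wrapped-around patches.

```python
import torch
import torch.nn as nn


class AttentionHead(nn.Module):
    def __init__(self, dim, head_dim, dropout=0.1):
        super().__init__()
        self.scale = head_dim ** -0.5
        self.q = nn.Linear(dim, head_dim, bias=False)
        self.k = nn.Linear(dim, head_dim, bias=False)
        self.v = nn.Linear(dim, head_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        scores = self.q(x) @ self.k(x).transpose(-1, -2) * self.scale
        return self.dropout(scores.softmax(dim=-1)) @ self.v(x)


class WindowedMHASketch(nn.Module):
    def __init__(self, dim, num_heads, grid_size, window_size, shift=False, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        assert grid_size % window_size == 0, "grid must split evenly into windows"
        self.grid_size, self.window_size = grid_size, window_size
        self.shift_amount = window_size // 2 if shift else 0
        head_dim = dim // num_heads
        self.heads = nn.ModuleList([AttentionHead(dim, head_dim, dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        cls_token, patches = x[:, :1, :], x[:, 1:, :]
        B, num_patches, dim = patches.shape
        gs, ws = self.grid_size, self.window_size
        nw = gs // ws  # windows per side
        grid = patches.reshape(B, gs, gs, dim)
        if self.shift_amount:
            grid = torch.roll(grid, shifts=(-self.shift_amount, -self.shift_amount), dims=(1, 2))
        # Partition into windows: [B * nw * nw, ws * ws, dim]
        windows = grid.reshape(B, nw, ws, nw, ws, dim).permute(0, 1, 3, 2, 4, 5)
        windows = windows.reshape(B * nw * nw, ws * ws, dim)
        # Give every window a copy of its own image's CLS token.
        cls_rep = cls_token.repeat_interleave(nw * nw, dim=0)
        tokens = torch.cat([cls_rep, windows], dim=1)
        out = torch.cat([head(tokens) for head in self.heads], dim=-1)
        # Average the per-window CLS outputs back to one CLS per image.
        cls_out = out[:, :1, :].reshape(B, nw * nw, 1, dim).mean(dim=1)
        # Undo the window partition, then undo the roll.
        patch_out = out[:, 1:, :].reshape(B, nw, nw, ws, ws, dim).permute(0, 1, 3, 2, 4, 5)
        patch_out = patch_out.reshape(B, gs, gs, dim)
        if self.shift_amount:
            patch_out = torch.roll(patch_out, shifts=(self.shift_amount, self.shift_amount), dims=(1, 2))
        patch_out = patch_out.reshape(B, num_patches, dim)
        return self.dropout(self.proj(torch.cat([cls_out, patch_out], dim=1)))


torch.manual_seed(0)
x = torch.randn(2, 65, 256)  # 64 patches + 1 CLS token
mha = WindowedMHASketch(dim=256, num_heads=8, grid_size=8, window_size=4, shift=True).eval()
print(mha(x).shape)  # torch.Size([2, 65, 256])
```

With shift=False this behaves like the plain windowed version from the previous section, so the same class can serve both the regular and the shifted blocks.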
class TransformerBlockRolled(nn.Module):
    def __init__(self, dim, num_heads, mlp_dim,
                 grid_size, shift, window_size=None, dropout=0.1):
        super().__init__()
        # TODO: Create LayerNorm for pre-attention normalization
        # TODO: Create MultiHeadAttentionShift (pass shift parameter)
        # TODO: Create LayerNorm for pre-FFN normalization
        # TODO: Create FeedForward layer
        # TODO: Store shift flag
        pass

    def forward(self, x):
        # TODO: Apply pre-norm -> attention -> residual connection
        # TODO: Apply pre-norm -> FFN -> residual connection
        pass

class TransformerEncoderRolled(nn.Module):
    def __init__(self, dim, depth, num_heads, mlp_dim, grid_size, window_size=None, dropout=0.1):
        super().__init__()
        # TODO: Create ModuleList of TransformerBlockRolled instances
        #   For each layer i, set shift = (i % 2 == 1)
        #   This alternates between regular (even) and shifted (odd) blocks
        # TODO: Create final LayerNorm
        pass

    def forward(self, x):
        # TODO: Loop through all blocks
        # TODO: Apply final LayerNorm
        pass

class ViTWindowedShifted(nn.Module):
    def __init__(self, img_size, patch_size, num_hiddens, num_heads,
                 num_classes, window_size, depth=6, dropout=0.1):
        super().__init__()
        # TODO: Compute num_patches = (img_size // patch_size) ** 2
        # TODO: Compute grid_size = img_size // patch_size
        # TODO: Create PatchEmbedding(img_size, patch_size, num_hiddens)
        # TODO: Create CLS token as nn.Parameter(torch.randn(1, 1, num_hiddens))
        # TODO: Create positional embedding as nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))
        # TODO: Create TransformerEncoderRolled (note: use Rolled version here!)
        #   mlp_dim should be num_hiddens * 4
        # TODO: Create ClassificationHead(num_hiddens, num_classes)
        pass

    def forward(self, x):
        # TODO: Get batch size
        # TODO: 1. Patch embed: [B, 3, H, W] -> flatten and transpose to [B, num_patches, num_hiddens]
        # TODO: 2. Prepend CLS token (expand to batch size, then cat)
        # TODO: 3. Add positional embeddings
        # TODO: 4. Pass through transformer encoder
        # TODO: 5. Pass through classification head and return
        pass

model = ViTWindowedShifted(
    img_size=56, patch_size=7, num_hiddens=256,
    num_heads=8, num_classes=100, window_size=4, depth=2, dropout=0.1
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(50):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Acc: {correct/total:.4f}")

Check how the pros do it
and their SwinTransformer Block
timm partitions windows like this.
It actually supports non-square windows, which is a nice idea.
from typing import Tuple

def window_partition(
        x: torch.Tensor,
        window_size: Tuple[int, int],
) -> torch.Tensor:
    """Partition into non-overlapping windows.
    Args:
        x: Input tokens with shape [B, H, W, C].
        window_size: Window size.
    Returns:
        Windows after partition with shape [B * num_windows, window_size, window_size, C].
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
    return windows

And this is how they do window reverse:
def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], H: int, W: int) -> torch.Tensor:
    """Reverse window partition.
    Args:
        windows: Windows with shape (num_windows*B, window_size, window_size, C).
        window_size: Window size.
        H: Height of image.
        W: Width of image.
    Returns:
        Tensor with shape (B, H, W, C).
    """
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x

import timm
model = timm.create_model('vit_tiny_patch16_224', pretrained=False)
print(model)