import torch
import torch.nn as nn
import torch.nn.functional as F
import math

Build a ViT from Scratch - Exercises
Fill in the TODO sections to build a Vision Transformer from first principles.
Resources
Essential Paper:
- https://arxiv.org/abs/2010.11929 - Dosovitskiy et al., "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"

Reference Implementations:
- https://github.com/lucidrains/vit-pytorch - Clean, minimal implementation
- https://github.com/google-research/vision_transformer - Official JAX implementation
- https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py - Hugging Face implementation

Tutorials & Explainers:
- https://jalammar.github.io/illustrated-transformer/ - Jay Alammar, "The Illustrated Transformer"
- https://nlp.seas.harvard.edu/2018/04/03/attention.html - Harvard NLP, "The Annotated Transformer"
Building Blocks
| Component | Task |
|---|---|
| PatchEmbedding | Patch Embedding class |
| Attention | Single and multi-head attention |
| FeedForward | Define the MLP block |
| TransformerBlock | Combine attention + FFN + residuals + LayerNorm |
| ViT (full) | Add classification head, add MHA stacking |
| Training loop | Train on FashionMNIST |
Intuition for steps:
- Image to patches
  - Think of these as the tokens in NLP: text is split into tokens, images are split into patches.
- Patch embedding
  - Converts each patch into a dense vector. Token embeddings in NLP become our patch embeddings in ViT.
  - Patches are flattened (they need to be 1-D data vectors).
- Add positional embeddings
  - Same as in NLP. Gives the transformer the ability to learn location/spatial information.
- Add a `cls_token` for classification
  - A learnable token (vector) that represents the whole image.
- Feed into the Transformer Encoder
  - Pre-norm (as in ViT): LayerNorm -> MHA -> Add (residual) -> LayerNorm -> FeedForward -> Add (residual).
- Pass the CLS token's output into the classification head.
Step 1: Patch Embedding
- This takes an image and breaks it up into patches.
- We use a Conv2d where kernel_size = stride = patch_size.
- For a 224x224 image with patch_size=16, we get a 14x14 grid = 196 patches.
- Each patch gets projected to a vector of size `num_hiddens`.
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and project each one to a vector.

    A Conv2d whose kernel_size and stride both equal ``patch_size`` visits every
    patch exactly once, which is equivalent to flattening each patch and applying
    one shared linear projection.

    Args:
        img_size: Height/width of the (square) input image, in pixels.
        patch_size: Height/width of each (square) patch, in pixels.
        num_hiddens: Size of the embedding vector produced for each patch.
        in_channels: Number of input channels (3 for RGB, the default).

    Shape:
        input:  [B, in_channels, img_size, img_size]
        output: [B, num_hiddens, img_size // patch_size, img_size // patch_size]
    """

    def __init__(self, img_size, patch_size, num_hiddens, in_channels=3):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_hiddens = num_hiddens
        # Patches per side, squared. When img_size is not a multiple of
        # patch_size, the strided conv drops the leftover border pixels, and
        # floor division matches that.
        self.num_patches = (img_size // patch_size) ** 2
        self.conv = nn.Conv2d(
            in_channels,
            num_hiddens,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return the per-patch embeddings laid out on the patch grid."""
        return self.conv(x)

# Test the PatchEmbedding
Let’s load an image and see what shapes we get.
from io import BytesIO

import matplotlib.pyplot as plt
import requests
import torchvision.transforms as transforms
from PIL import Image

url = "https://ultralytics.com/images/bus.jpg"
# Bound the download time so the cell cannot hang forever. Fail loudly on an
# HTTP error instead of handing an error page to PIL, which would raise a
# confusing UnidentifiedImageError.
response = requests.get(url, timeout=30)
response.raise_for_status()
img = Image.open(BytesIO(response.content)).convert('RGB')

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
img_tensor = transform(img).unsqueeze(0)  # [1, 3, 224, 224]

# Test your PatchEmbedding
patch_emb = PatchEmbedding(img_size=224, patch_size=16, num_hiddens=512)
output = patch_emb(img_tensor)
print(output.size())  # Expected: [1, 512, 14, 14]

# Visualize patches
# Derive the grid geometry from the tensors instead of hard-coding 224/16/14/56,
# so the plots stay correct if img_size, patch_size or num_hiddens change.
img_h, img_w = img_tensor.shape[-2:]
patch = patch_emb.patch_size
grid_h, grid_w = output.shape[-2:]
image_hwc = img_tensor.squeeze(0).permute(1, 2, 0)  # [H, W, C] for imshow

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(image_hwc)
axes[0].set_title(f'Input: {tuple(img_tensor.shape)}')
axes[0].axis('off')

# Overlay the patch boundaries on the input image.
axes[1].imshow(image_hwc)
for y in range(0, img_h, patch):
    axes[1].axhline(y=y, color='red', linewidth=0.5)
for x in range(0, img_w, patch):
    axes[1].axvline(x=x, color='red', linewidth=0.5)
axes[1].set_title(f'{grid_h}x{grid_w} = {grid_h * grid_w} patches ({patch}x{patch} each)')
axes[1].axis('off')

# Tile the first 16 channel maps into a 4x4 mosaic:
# [16, gh, gw] -> [4, 4, gh, gw] -> [4, gh, 4, gw] -> [4*gh, 4*gw]
out_viz = output.squeeze(0)[:16].detach()
grid = (
    out_viz.reshape(4, 4, grid_h, grid_w)
    .permute(0, 2, 1, 3)
    .reshape(4 * grid_h, 4 * grid_w)
)
axes[2].imshow(grid, cmap='viridis')
axes[2].set_title(f'Output: {tuple(output.shape)}\n(showing 16 of {output.shape[1]} channels)')
axes[2].axis('off')

plt.tight_layout()
plt.show()

# Flatten the patches for the transformer
# Conv2d output: [B, hidden_dim, H, W] -> [B, num_patches, hidden_dim]
# Step 1: merge the spatial grid (H, W) into a single patch axis.
output_flat = torch.flatten(output, start_dim=2)  # [1, 512, 196]
print(output_flat.size())
# Step 2: swap the channel and patch axes so that each row is one patch token.
output_flat_transposed = output_flat.permute(0, 2, 1)
print(output_flat_transposed.size())  # [1, 196, 512] - ready for the transformer

# Positional Embeddings
Self-attention is permutation equivariant: if you shuffle the patches, the outputs are the same, just shuffled in the same way. The model has no idea which patch is where.
Positional embeddings fix this by adding a unique signal to each patch position.
Two approaches:
1. Learned (what ViT uses): `nn.Parameter(torch.randn(1, num_patches + 1, hidden_dim))`
2. Sinusoidal (from “Attention Is All You Need”): fixed sin/cos waves at different frequencies
def sinusoidal_positional_encoding(num_positions, dim):
    """Return the fixed sin/cos positional encoding from "Attention Is All You Need".

    PE(pos, 2i)   = sin(pos / 10000^(2i/dim))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/dim))

    Args:
        num_positions: Number of positions (rows) to encode.
        dim: Embedding size (columns). Odd sizes are supported: the last
            column is then a sine with no matching cosine column.

    Returns:
        Float tensor of shape [num_positions, dim].
    """
    position = torch.arange(num_positions).unsqueeze(1).float()  # [num_positions, 1]
    # 1 / 10000^(2i/dim), computed in log space: exp(2i * -ln(10000) / dim).
    div_term = torch.exp(
        torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)
    )  # [ceil(dim / 2)]
    angles = position * div_term  # [num_positions, ceil(dim / 2)]
    pe = torch.zeros(num_positions, dim)
    pe[:, 0::2] = torch.sin(angles)
    # With an odd dim there is one fewer odd column than even columns, so only
    # the first dim // 2 frequencies get a cosine partner.
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
    return pe

# Visualize sinusoidal positional encoding
num_positions, dim = 197, 512  # 196 patches + 1 CLS, hidden size
pe = sinusoidal_positional_encoding(num_positions, dim)

fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Left: full encoding matrix as a heatmap.
im = axes[0].imshow(pe.numpy(), aspect='auto', cmap='RdBu')
axes[0].set_xlabel('Embedding dimension')
axes[0].set_ylabel('Patch position')
axes[0].set_title(f'Sinusoidal Positional Encoding ({num_positions} positions x {dim} dims)')
plt.colorbar(im, ax=axes[0])

# Right: a few individual columns, to show the frequency falling with dimension.
for d in [0, 1, 10, 50, 100, 255]:
    axes[1].plot(pe[:, d].numpy(), label=f'dim {d}')
axes[1].set_xlabel('Patch position')
axes[1].set_ylabel('Encoding value')
axes[1].set_title('Individual dimensions - low dims = high freq, high dims = low freq')
axes[1].legend()

plt.tight_layout()
plt.show()

# Reading the heatmap:
# * Each row is one patch position (0-196). Each column is one dimension.
# * Left columns (low dims): rapid alternating stripes = high-frequency waves.
#   These help distinguish neighboring patches.
# * Right columns (high dims): wide, slowly changing bands = low-frequency waves.
#   These help distinguish distant patches.
# * Every position gets a unique combination of frequencies, like a barcode.
Transformer Blocks
FeedForward (MLP)
Two linear layers with a nonlinearity in between. Expands then compresses.
input (512) -> Linear -> GELU -> Linear -> output (512)
512 -> Linear -> 2048 -> GELU -> 2048 -> Linear -> 512
- Attention tells the model which patches to look at (communication)
- FeedForward tells the model what to do with that info (computation)
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``hidden_dim``, apply GELU, project back to ``dim``.

    Applied to every token independently. Attention mixes information across
    tokens; this block transforms each token's own features.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the intermediate (expanded) layer.
        dropout: Dropout probability, applied after the activation and after
            the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Map [..., dim] -> [..., dim]."""
        return self.net(x)

# Single Attention Head
Every patch produces three things:
- Query (Q): “What am I looking for?”
- Key (K): “What do I contain?”
- Value (V): “Here’s my actual content”
Attention output: softmax(Q @ K^T / sqrt(head_dim)) @ V
- Q @ K^T = dot product of every query with every key -> N x N matrix of attention scores
- Scale by 1/sqrt(head_dim) to keep softmax in a useful range (large dot products would saturate it)
- Softmax (over each row) turns the scores into probability weights
- Weighted sum with V gives the output
class AttentionHead(nn.Module):
    """A single scaled dot-product self-attention head.

    Args:
        dim: Feature size of the incoming tokens.
        head_dim: Size of this head's query/key/value projections.
        dropout: Dropout probability applied to the attention weights.

    Shape:
        input:  [B, N, dim]
        output: [B, N, head_dim]
    """

    def __init__(self, dim, head_dim, dropout=0.1):
        super().__init__()
        # 1/sqrt(head_dim) keeps the logits in a range where softmax is not
        # saturated as head_dim grows.
        self.scale = head_dim ** -0.5
        self.q = nn.Linear(dim, head_dim, bias=False)
        self.k = nn.Linear(dim, head_dim, bias=False)
        self.v = nn.Linear(dim, head_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        q, k, v = self.q(x), self.k(x), self.v(x)  # each [B, N, head_dim]
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # [B, N, N]
        # Row i holds how much token i attends to every token; each row sums to
        # 1 after the softmax (dropout is applied afterwards).
        attn = self.dropout(scores.softmax(dim=-1))
        return torch.matmul(attn, v)  # [B, N, head_dim]

# Multi-Head Attention
Instead of one big attention head, we run multiple smaller heads in parallel. With dim=512 and num_heads=8, each head works in 64 dimensions. Each head can learn a different “reason” to attend.
After all heads run, we concatenate their outputs and apply an output projection to mix across heads.
class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` AttentionHeads, concatenate their outputs, and mix them.

    The heads are conceptually parallel; here they are evaluated one after
    another over the same input.

    Args:
        dim: Feature size of the incoming tokens (also the output size).
        num_heads: Number of heads; must divide ``dim``.
        dropout: Dropout probability inside each head and after the output
            projection.

    Shape:
        input:  [B, N, dim]
        output: [B, N, dim]
    """

    def __init__(self, dim, num_heads, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads
        self.num_heads = num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(dim, self.head_dim, dropout) for _ in range(num_heads)]
        )
        # Mixes information across heads after concatenation.
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Each head returns [B, N, head_dim]; concatenating num_heads of them
        # along the feature axis restores [B, N, dim].
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.dropout(self.proj(concatenated))

# Transformer Block
The core repeating unit. Pre-norm style (what ViT uses):
```
Input
  |
  +---- LayerNorm -> MultiHeadAttention ----+
  |                                         |
  +------------- + (residual) <-------------+
  |
  +---- LayerNorm -> FeedForward -----------+
  |                                         |
  +------------- + (residual) <-------------+
  |
Output
```
The residual connections (x + ...) let gradients flow and make deep stacking possible.
class TransformerBlock(nn.Module):
    """Pre-norm encoder block: x + MHA(LN(x)), then x + FFN(LN(x)).

    Args:
        dim: Token feature size (input and output).
        num_heads: Number of attention heads; must divide ``dim``.
        mlp_dim: Hidden width of the FeedForward sub-layer.
        dropout: Dropout probability passed to both sub-layers.

    Shape:
        input/output: [B, N, dim]
    """

    def __init__(self, dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, mlp_dim, dropout)

    def forward(self, x):
        # Normalize only the sub-layer input; the residual path stays an
        # identity, which keeps gradients flowing through deep stacks.
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

# Transformer Encoder
Just stacking N TransformerBlocks in sequence, with a final LayerNorm.
Early blocks learn low-level features (edges, textures). Later blocks learn high-level features (semantic meaning).
class TransformerEncoder(nn.Module):
    """A stack of ``depth`` TransformerBlocks followed by a final LayerNorm.

    Args:
        dim: Token feature size (input and output).
        depth: Number of TransformerBlocks.
        num_heads: Attention heads per block.
        mlp_dim: Hidden width of each block's FeedForward.
        dropout: Dropout probability passed to every block.

    Shape:
        input/output: [B, N, dim]
    """

    def __init__(self, dim, depth, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerBlock(dim, num_heads, mlp_dim, dropout) for _ in range(depth)]
        )
        # Pre-norm blocks add their outputs to an un-normalized residual
        # stream, so the stack's output is normalized once at the end.
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)

# Sanity check - test your encoder
encoder = TransformerEncoder(dim=512, depth=6, num_heads=8, mlp_dim=2048, dropout=0.1)
dummy = torch.randn(16, 197, 512)  # (batch, num_patches + 1, hidden_dim)
out = encoder(dummy)
print(out.shape)  # should be [16, 197, 512]
# Fail loudly instead of relying on someone eyeballing the printed shape.
assert out.shape == dummy.shape, f"expected {tuple(dummy.shape)}, got {tuple(out.shape)}"
print(f"Parameters: {sum(p.numel() for p in encoder.parameters()):,}")

# Classification Head
The CLS token (position 0) has attended to all patches across all layers. Extract it and map to class logits.
class ClassificationHead(nn.Module):
    """Map the CLS token (sequence position 0) to class logits.

    Args:
        hidden_dim: Token feature size.
        num_classes: Number of output classes.

    Shape:
        input:  [B, N, hidden_dim]
        output: [B, num_classes]
    """

    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        self.linear = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        cls = x[:, 0]  # [B, hidden_dim]; all other tokens are ignored here
        return self.linear(self.norm(cls))

# Full ViT
Put it all together: patch embed -> prepend CLS -> add pos embeddings -> encoder -> classify
class ViT(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: patch embed -> prepend CLS token -> add learned positional
    embeddings -> transformer encoder -> classify from the CLS token.

    Args:
        img_size: Height/width of the (square) input image.
        patch_size: Height/width of each (square) patch.
        num_hiddens: Token embedding size.
        num_heads: Attention heads per block; must divide ``num_hiddens``.
        num_classes: Number of output classes.
        depth: Number of TransformerBlocks.
        dropout: Dropout probability used throughout the encoder.
        mlp_dim: Hidden width of each block's FeedForward. TransformerEncoder
            requires it; it defaults to ``4 * num_hiddens`` (the same ratio as
            the 512 -> 2048 example) so existing callers keep working.

    Shape:
        input:  [B, 3, img_size, img_size]
        output: [B, num_classes]
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_classes,
                 depth=6, dropout=0.1, mlp_dim=None):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        if mlp_dim is None:
            mlp_dim = 4 * num_hiddens
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_hiddens))
        # One learned position vector per patch, plus one for the CLS token.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, num_hiddens))
        self.encoder = TransformerEncoder(num_hiddens, depth, num_heads, mlp_dim, dropout)
        self.classification_head = ClassificationHead(num_hiddens, num_classes)

    def forward(self, x):
        batch = x.shape[0]
        # [B, C, H, W] -> [B, D, H', W'] -> [B, D, N] -> [B, N, D]
        patches = self.patch_embedding(x).flatten(2).transpose(1, 2)
        # The single learned CLS vector is shared across the batch; expand
        # creates a view rather than copying it per sample.
        cls = self.cls_token.expand(batch, -1, -1)  # [B, 1, D]
        tokens = torch.cat([cls, patches], dim=1)  # [B, N + 1, D]
        tokens = tokens + self.pos_embedding  # broadcast over the batch
        encoded = self.encoder(tokens)
        return self.classification_head(encoded)  # [B, num_classes]

# Test on FashionMNIST
Note: FashionMNIST is grayscale 28x28. We resize to 56x56 and repeat to 3 channels.
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST

# FashionMNIST is 1x28x28. Upsample to 56x56 and replicate to 3 channels so the
# images match the 3-channel (RGB) patch embedding.
transform = transforms.Compose([
    transforms.Resize((56, 56)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Prefer CUDA, then Apple-silicon MPS, then fall back to the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f"Using: {device}")

# 56 / 7 = 8 patches per side -> 8 x 8 = 64 patches per image.
model = ViT(
    img_size=56, patch_size=7, num_hiddens=256,
    num_heads=8, num_classes=10, depth=2, dropout=0.1,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

# --- Training ---
for epoch in range(5):
    model.train()  # enable dropout
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    # Loss is the mean over batches; accuracy is over all training samples.
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Acc: {correct/total:.4f}")

# --- Evaluation ---
model.eval()  # disable dropout
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
print(f"Test Accuracy: {correct/total:.4f}")