Build a ViT from Scratch - Exercises

Fill in the TODO sections to build a Vision Transformer from first principles.

Resources

Essential Paper: - "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al.) - https://arxiv.org/abs/2010.11929

Reference Implementations: - https://github.com/lucidrains/vit-pytorch - Clean, minimal implementation - https://github.com/google-research/vision_transformer - Official JAX implementation - https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py

Tutorials & Explainers: - https://jalammar.github.io/illustrated-transformer/ - Jay Alammar - https://nlp.seas.harvard.edu/2018/04/03/attention.html


Building Blocks

Component Task
PatchEmbedding Patch Embedding class
Attention Single and multi-head attention
FeedForward Define the MLP block
TransformerBlock Combine attention + FFN + residuals + LayerNorm
ViT (full) Add classification head, add MHA stacking
Training loop Train on FashionMNIST

Intuition for steps:

  1. Image to patches
    1. Think of these as the tokens in NLP. Words into tokens and images into patches.
  2. Patch embedding
    1. Converts patch into dense vectors. Token embeddings in NLP will be our patch embeddings in ViT.
  3. Flattened patches (needs to be 1d data vectors)
  4. Add positional embeddings
    1. Same as in NLP. Give transformer ability to learn location/spatial information.
  5. Add a cls_token for classification. Learnable token or vector to represent the whole image.
  6. Feed into Transformer Encoder
    1. LayerNorm -> MHA -> Add (residual) -> LayerNorm -> FeedForward -> Add (residual). This is the pre-norm layout that ViT uses.
  7. Pass into classification head.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

Step 1: Patch Embedding

  • This takes an image and breaks it up into patches.
  • We use a Conv2d where kernel_size = stride = patch_size.
  • For a 224x224 image with patch_size=16, we get a 14x14 grid = 196 patches.
  • Each patch gets projected to a vector of size num_hiddens.
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A Conv2d whose kernel_size and stride both equal ``patch_size`` visits
    every patch exactly once and projects it to ``num_hiddens`` channels, so
    patch extraction and the linear projection happen in a single op.

    Args:
        img_size: Side length of the (square) input image in pixels.
        patch_size: Side length of each (square) patch in pixels.
        num_hiddens: Size of the embedding vector produced for each patch.
        in_channels: Number of input image channels (3 for RGB).
    """

    def __init__(self, img_size, patch_size, num_hiddens, in_channels=3):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_hiddens = num_hiddens
        # Patches per side, squared. Floor division matches what the strided
        # convolution produces when img_size is not a multiple of patch_size.
        self.num_patches = (img_size // patch_size) ** 2

        # kernel_size == stride means the windows tile the image without
        # overlapping: one output location per patch.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=num_hiddens,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed the patches of ``x``.

        Args:
            x: Image batch, expected as [B, in_channels, H, W].

        Returns:
            Tensor of shape [B, num_hiddens, H // patch_size, W // patch_size].
            The caller flattens and transposes this into a token sequence.
        """
        return self.conv(x)

Test the PatchEmbedding

Let’s load an image and see what shapes we get.

import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
import requests
from io import BytesIO

url = "https://ultralytics.com/images/bus.jpg"
# Bound the wait and surface HTTP errors here; otherwise a stalled connection
# hangs the cell and a 4xx/5xx error page is handed to PIL as "image" bytes.
response = requests.get(url, timeout=30)
response.raise_for_status()
img = Image.open(BytesIO(response.content)).convert('RGB')

# Resize to the resolution the patch embedding below is configured for and
# convert to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
img_tensor = transform(img).unsqueeze(0)  # [1, 3, 224, 224]

# Test your PatchEmbedding
patch_emb = PatchEmbedding(img_size=224, patch_size=16, num_hiddens=512)
output = patch_emb(img_tensor)
print(output.size())  # Expected: [1, 512, 14, 14]
# Visualize patches
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Derive every size from the tensors/modules created above instead of
# repeating 224 / 16 / 14 / 512 by hand, so the figure stays correct if the
# image size, patch size or embedding width is changed.
image_hwc = img_tensor.squeeze(0).permute(1, 2, 0)  # channels-last for imshow
img_h, img_w = image_hwc.shape[:2]
patch = patch_emb.patch_size
grid_h, grid_w = output.shape[-2:]

# Panel 1: the raw input image.
axes[0].imshow(image_hwc)
axes[0].set_title(f'Input: {tuple(img_tensor.shape)}')
axes[0].axis('off')

# Panel 2: the same image with the patch grid drawn on top.
axes[1].imshow(image_hwc)
for row in range(0, img_h, patch):
    axes[1].axhline(y=row, color='red', linewidth=0.5)
for col in range(0, img_w, patch):
    axes[1].axvline(x=col, color='red', linewidth=0.5)
axes[1].set_title(f'{grid_h}x{grid_w} = {grid_h * grid_w} patches ({patch}x{patch} each)')
axes[1].axis('off')

# Panel 3: tile the first 16 output channels into a 4x4 mosaic of feature maps.
# [16, h, w] -> [4, 4, h, w] -> [4, h, 4, w] -> [4*h, 4*w]
out_viz = output.squeeze(0)[:16].detach()
grid = out_viz.reshape(4, 4, grid_h, grid_w).permute(0, 2, 1, 3).reshape(4 * grid_h, 4 * grid_w)
axes[2].imshow(grid, cmap='viridis')
axes[2].set_title(f'Output: {tuple(output.shape)}\n(showing 16 of {output.shape[1]} channels)')
axes[2].axis('off')

plt.tight_layout()
plt.show()
# Turn the conv feature map into a token sequence for the transformer:
# [B, hidden_dim, H, W] -> [B, hidden_dim, H*W] -> [B, H*W, hidden_dim]
output_flat = torch.flatten(output, start_dim=2)  # merge the two spatial axes
print(output_flat.size())                         # [1, 512, 196]

# Swap so that patches are the sequence axis and channels the feature axis.
output_flat_transposed = torch.transpose(output_flat, 1, 2)
print(output_flat_transposed.size())              # [1, 196, 512] - ready for transformer

Positional Embeddings

Self-attention is permutation equivariant. If you shuffle the patches, the output is the same, just shuffled in the same way. On its own, the model has no idea which patch is where.

Positional embeddings fix this by adding a unique signal to each patch position.

Two approaches: 1. Learned (what ViT uses): nn.Parameter(torch.randn(1, num_patches+1, hidden_dim)) 2. Sinusoidal (from “Attention Is All You Need”): fixed sin/cos waves at different frequencies

def sinusoidal_positional_encoding(num_positions, dim):
    """Build the fixed sin/cos positional encoding table.

    PE(pos, 2i)   = sin(pos / 10000^(2i/dim))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/dim))

    Args:
        num_positions: Number of sequence positions (rows of the table).
        dim: Embedding dimension (columns of the table).

    Returns:
        Float tensor of shape [num_positions, dim].
    """
    position = torch.arange(num_positions).unsqueeze(1).float()  # [num_positions, 1]
    # 1 / 10000^(2i/dim), computed as exp(-2i * ln(10000) / dim): one
    # frequency per (sin, cos) column pair.
    div_term = torch.exp(
        torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)
    )  # [ceil(dim / 2)]
    angles = position * div_term  # broadcasts to [num_positions, ceil(dim / 2)]

    pe = torch.zeros(num_positions, dim)
    pe[:, 0::2] = torch.sin(angles)
    # For odd dim there is one fewer odd column than even column, so drop the
    # last frequency to keep the shapes aligned.
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])

    return pe
# Plot the sinusoidal positional encoding two ways: as a full heatmap and as
# a handful of individual dimensions traced across positions.
pe = sinusoidal_positional_encoding(197, 512)  # 196 patches + 1 CLS

fig, axes = plt.subplots(1, 2, figsize=(16, 5))
heatmap_ax, curves_ax = axes

# Left panel: rows are positions, columns are embedding dimensions.
im = heatmap_ax.imshow(pe.numpy(), aspect='auto', cmap='RdBu')
heatmap_ax.set_title('Sinusoidal Positional Encoding (197 positions x 512 dims)')
heatmap_ax.set_xlabel('Embedding dimension')
heatmap_ax.set_ylabel('Patch position')
plt.colorbar(im, ax=heatmap_ax)

# Right panel: one curve per selected dimension.
selected_dims = [0, 1, 10, 50, 100, 255]
for d in selected_dims:
    column = pe[:, d].numpy()
    curves_ax.plot(column, label=f'dim {d}')
curves_ax.set_title('Individual dimensions - low dims = high freq, high dims = low freq')
curves_ax.set_xlabel('Patch position')
curves_ax.set_ylabel('Encoding value')
curves_ax.legend()

plt.tight_layout()
plt.show()

Reading the heatmap: * Each row is one patch position (0-196). Each column is one dimension. * Left columns (low dims): rapid alternating stripes = high-frequency waves. Help distinguish neighboring patches. * Right columns (high dims): wide slow-changing bands = low-frequency waves. Help distinguish distant patches. * Every position gets a unique combination of frequencies, like a barcode.

Transformer Blocks

FeedForward (MLP)

Two linear layers with a nonlinearity in between. Expands then compresses.

input (512) -> Linear -> GELU -> Linear -> output (512)
       512  ->  2048           ->  2048  ->  512
  • Attention tells the model which patches to look at (communication)
  • FeedForward tells the model what to do with that info (computation)
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``hidden_dim``, apply GELU, project back.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the intermediate (expanded) layer.
        dropout: Dropout probability applied after the activation and after
            the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP over the last axis of ``x``; the shape is unchanged."""
        return self.net(x)

Single Attention Head

Every patch produces three things: * Query (Q): “What am I looking for?” * Key (K): “What do I contain?” * Value (V): “Here’s my actual content”

Attention output: softmax(Q @ K^T / sqrt(head_dim)) @ V

  • Q @ K^T = dot product of every query with every key -> NxN attention matrix
  • Scale by 1/sqrt(head_dim) to keep softmax in a useful range
  • Softmax turns scores into probability weights
  • Weighted sum with V gives the output
class AttentionHead(nn.Module):
    """A single scaled dot-product self-attention head.

    Args:
        dim: Feature size of the incoming tokens.
        head_dim: Feature size of this head's queries, keys and values.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dim, head_dim, dropout=0.1):
        super().__init__()
        # 1/sqrt(head_dim) keeps the dot products from growing with head_dim
        # and saturating the softmax.
        self.scale = head_dim ** -0.5
        self.q = nn.Linear(dim, head_dim, bias=False)
        self.k = nn.Linear(dim, head_dim, bias=False)
        self.v = nn.Linear(dim, head_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend every token of ``x`` to every other token.

        Args:
            x: Token batch, expected as [B, num_patches, dim].

        Returns:
            Tensor of shape [B, num_patches, head_dim].
        """
        q, k, v = self.q(x), self.k(x), self.v(x)
        # [B, N, head_dim] @ [B, head_dim, N] -> [B, N, N] pairwise scores.
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # Normalise over keys so each query's weights sum to 1.
        weights = self.dropout(scores.softmax(dim=-1))
        return torch.matmul(weights, v)

Multi-Head Attention

Instead of one big attention head, we run multiple smaller heads in parallel. With dim=512 and num_heads=8, each head works in 64 dimensions. Each head can learn a different “reason” to attend.

After all heads run, we concatenate their outputs and apply an output projection to mix across heads.

class MultiHeadAttention(nn.Module):
    """Run several attention heads in parallel and mix their outputs.

    Each head works in ``dim // num_heads`` dimensions; the head outputs are
    concatenated back to ``dim`` and passed through an output projection.

    Args:
        dim: Feature size of the incoming and outgoing tokens.
        num_heads: Number of parallel heads; must divide ``dim``.
        dropout: Dropout probability, used inside each head and after the
            output projection.
    """

    def __init__(self, dim, num_heads, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads
        self.num_heads = num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(dim, self.head_dim, dropout) for _ in range(num_heads)]
        )
        # Lets information from different heads interact after concatenation.
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: Token batch, expected as [B, num_patches, dim].

        Returns:
            Tensor of shape [B, num_patches, dim].
        """
        head_outputs = [head(x) for head in self.heads]
        # num_heads tensors of width head_dim -> one tensor of width dim.
        concatenated = torch.cat(head_outputs, dim=-1)
        return self.dropout(self.proj(concatenated))

Transformer Block

The core repeating unit. Pre-norm style (what ViT uses):

Input
  |
  +---- LayerNorm -> MultiHeadAttention ----+
  |                                          |
  +------------- + (residual) <-------------+
  |
  +---- LayerNorm -> FeedForward -----------+
  |                                          |
  +------------- + (residual) <-------------+
  |
Output

The residual connections (x + ...) let gradients flow and make deep stacking possible.

class TransformerBlock(nn.Module):
    """One pre-norm transformer encoder block.

    Computes ``x + attn(norm1(x))`` followed by ``x + ffn(norm2(x))``.

    Args:
        dim: Token feature size.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(self, dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, mlp_dim, dropout)

    def forward(self, x):
        """Apply attention then feed-forward, each with a residual connection.

        The normalisation sits inside each residual branch (pre-norm), so the
        skip path carries ``x`` through unmodified.
        """
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

Transformer Encoder

Just stacking N TransformerBlocks in sequence, with a final LayerNorm.

Early blocks learn low-level features (edges, textures). Later blocks learn high-level features (semantic meaning).

class TransformerEncoder(nn.Module):
    """A stack of ``depth`` TransformerBlocks followed by a final LayerNorm.

    Args:
        dim: Token feature size.
        depth: Number of stacked blocks.
        num_heads: Number of attention heads per block.
        mlp_dim: Hidden width of each block's feed-forward sub-layer.
        dropout: Dropout probability passed to every block.
    """

    def __init__(self, dim, depth, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerBlock(dim, num_heads, mlp_dim, dropout) for _ in range(depth)]
        )
        # Pre-norm blocks leave their residual output un-normalised, so
        # normalise once after the last block.
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Run ``x`` through every block in order, then the final norm."""
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
# Smoke-test the encoder: push a random batch through it and confirm the
# output keeps the input shape, then report the model size.
encoder_config = dict(dim=512, depth=6, num_heads=8, mlp_dim=2048, dropout=0.1)
encoder = TransformerEncoder(**encoder_config)

dummy = torch.randn(16, 197, 512)  # (batch, num_patches+1, hidden_dim)
out = encoder(dummy)
print(out.shape)  # should be [16, 197, 512]

param_count = sum(p.numel() for p in encoder.parameters())
print(f"Parameters: {param_count:,}")

Classification Head

The CLS token (position 0) has attended to all patches across all layers. Extract it and map to class logits.

class ClassificationHead(nn.Module):
    """Map the CLS token to class logits.

    Args:
        hidden_dim: Token feature size.
        num_classes: Number of output classes.
    """

    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        self.linear = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Classify from the first token of each sequence.

        Args:
            x: Encoded tokens, expected as [B, num_patches + 1, hidden_dim]
                with the CLS token at position 0.

        Returns:
            Logits of shape [B, num_classes].
        """
        cls_token = x[:, 0]  # [B, hidden_dim]; all other tokens are ignored
        return self.linear(self.norm(cls_token))

Full ViT

Put it all together: patch embed -> prepend CLS -> add pos embeddings -> encoder -> classify

class ViT(nn.Module):
    """Vision Transformer: patch embed -> CLS + positions -> encoder -> logits.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each (square) patch.
        num_hiddens: Token embedding size used throughout the model.
        num_heads: Number of attention heads per encoder block.
        num_classes: Number of output classes.
        depth: Number of encoder blocks.
        dropout: Dropout probability used inside the encoder.
        mlp_dim: Hidden width of the feed-forward sub-layers. Defaults to
            ``4 * num_hiddens`` when left as None.
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_classes,
                 depth=6, dropout=0.1, mlp_dim=None):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        if mlp_dim is None:
            mlp_dim = 4 * num_hiddens

        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        # Learnable token prepended to every sequence; its final state is what
        # the classification head reads.
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_hiddens))
        # One learned position vector per patch, +1 for the CLS token.
        self.pos_embedding = nn.Parameter(
            torch.randn(1, self.num_patches + 1, num_hiddens)
        )
        self.encoder = TransformerEncoder(num_hiddens, depth, num_heads, mlp_dim, dropout)
        self.classification_head = ClassificationHead(num_hiddens, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch, expected as [B, 3, img_size, img_size]. Other
                spatial sizes yield a different number of patches and will
                not match the positional embedding.

        Returns:
            Logits of shape [B, num_classes].
        """
        batch = x.shape[0]
        # 1. Patch embed: [B, D, H', W'] -> [B, D, N] -> [B, N, D]
        patches = self.patch_embedding(x)
        tokens = patches.flatten(2).transpose(1, 2)
        # 2. Prepend CLS: expand shares the single parameter across the batch
        #    without copying it.
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)  # [B, N + 1, D]
        # 3. Add positional embeddings (broadcast over the batch axis).
        tokens = tokens + self.pos_embedding
        # 4. Encode, 5. classify from the CLS token.
        encoded = self.encoder(tokens)
        return self.classification_head(encoded)

Test on FashionMNIST

Note: FashionMNIST is grayscale 28x28. We resize to 56x56 and repeat to 3 channels.

from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader

# Preprocessing for FashionMNIST: upscale, replicate the single grayscale
# channel to three channels, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((56, 56)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

# Both splits share the same storage location and preprocessing.
dataset_args = dict(root='./data', download=True, transform=transform)
train_dataset = FashionMNIST(train=True, **dataset_args)
test_dataset = FashionMNIST(train=False, **dataset_args)

# Shuffle only the training split; evaluation order does not matter.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32)
# Pick the best available backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device_name = 'cuda'
elif torch.backends.mps.is_available():
    device_name = 'mps'
else:
    device_name = 'cpu'
device = torch.device(device_name)
print(f"Using: {device}")

# A small ViT sized for 56x56 inputs: 7x7 patches give an 8x8 grid of tokens.
vit_config = dict(
    img_size=56,
    patch_size=7,
    num_hiddens=256,
    num_heads=8,
    num_classes=10,
    depth=2,
    dropout=0.1,
)
model = ViT(**vit_config).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

# Train for five epochs, reporting the mean batch loss and the training
# accuracy accumulated over each epoch.
for epoch in range(5):
    model.train()
    running_loss = 0
    num_correct = 0
    num_seen = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step before computing
        # new ones.
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        predictions = logits.argmax(dim=-1)
        running_loss += loss.item()
        num_correct += (predictions == labels).sum().item()
        num_seen += labels.size(0)

    mean_loss = running_loss / len(train_loader)
    accuracy = num_correct / num_seen
    print(f"Epoch {epoch+1} | Loss: {mean_loss:.4f} | Acc: {accuracy:.4f}")
# Measure accuracy on the held-out test split. eval() switches dropout off
# and no_grad() skips gradient bookkeeping.
model.eval()
num_correct = 0
num_seen = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        predictions = model(images).argmax(dim=-1)
        num_correct += (predictions == labels).sum().item()
        num_seen += labels.size(0)

test_accuracy = num_correct / num_seen
print(f"Test Accuracy: {test_accuracy:.4f}")