from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F # This gives us the softmax()
from torch.utils.data import DataLoader
# Preprocessing: resize every image to 56x56 so the CNN stem below (three
# stride-2 convs) produces a 7x7 token grid, convert to a tensor, and normalize
# each channel (values presumably the CIFAR-100 per-channel mean/std).
transform = transforms.Compose([
    transforms.Resize((56, 56)),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

# Downloads CIFAR-100 into ./data on first run; both splits share the same transform.
train_dataset = CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR100(root='./data', train=False, download=True, transform=transform)

# CoAtNet
Review:
- Brings in the best of both worlds. CNNs are good at local features and have useful inductive biases; Transformers are good at modelling global relationships. Stacking CNN blocks and Transformer blocks gives good results.
- Relative attention. Drop the absolute positional embeddings used in the vanilla ViT, and instead add relative position biases to the attention scores. This generalizes nicely to different input sizes.
from vit import ViT, PatchEmbedding, MultiHeadAttention, TransformerEncoder, ClassificationHead, FeedForward

# We basically need to replace the PatchEmbedding with a deeper CNN.
class CNNStem(nn.Module):
    """Convolutional stem that replaces ViT's PatchEmbedding.

    Three ``Conv2d -> BatchNorm2d -> GELU`` stages with kernel_size=3,
    stride=2, padding=1. Each stage maps a side length n to ceil(n / 2), so a
    56x56 input becomes 28x28 -> 14x14 -> 7x7.

    Args:
        num_hiddens: Output channels of the last stage (the token width).
        in_channels: Channels of the input image (3 for RGB, the default).
    """

    def __init__(self, num_hiddens, in_channels=3):
        super().__init__()
        widths = (in_channels, 64, 128, num_hiddens)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                # No conv bias: the BatchNorm right after it subtracts the
                # per-channel mean, which would cancel a bias anyway.
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            ]
        self.stem = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``[B, in_channels, H, W]`` to ``[B, num_hiddens, ceil(H/8), ceil(W/8)]``."""
        return self.stem(x)


# Test your CNNStem
img = torch.randn(1, 3, 56, 56)
stem = CNNStem(num_hiddens=512)
outs = stem(img)
print(outs.shape)  # Expected: torch.Size([1, 512, 7, 7])


class CoAtNEt(nn.Module):
    """Simplified CoAtNet: CNN stem -> [CLS] + learned positions -> Transformer encoder.

    The CNN stem replaces ViT's PatchEmbedding. Its feature map is flattened
    into a token sequence (49 tokens for a 56x56 input), and a learned CLS
    token is prepended. Classification reads the CLS token after the encoder.

    Args:
        img_size: Input height/width (images are assumed square).
        patch_size: Unused; kept so existing call sites that pass it still work.
        num_hiddens: Token embedding width (stem output channels).
        num_heads: Attention heads per encoder layer.
        num_classes: Number of output logits.
        depth: Number of Transformer encoder layers.
        dropout: Dropout applied to the token sequence before the encoder.
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_classes, depth=6, dropout=0.1):
        super().__init__()
        # Each of the stem's three stride-2 (k=3, p=1) convs maps n -> ceil(n / 2).
        # For 56x56 this gives a 7x7 grid = 49 patches.
        grid = img_size
        for _ in range(3):
            grid = (grid + 1) // 2
        num_patches = grid * grid

        self.stem = CNNStem(num_hiddens)
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_hiddens))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))
        self.dropout = nn.Dropout(dropout)
        # NOTE(review): assumes vit.TransformerEncoder maps [B, N, D] -> [B, N, D]; confirm in vit.py.
        self.encoder = TransformerEncoder(num_hiddens, depth, num_heads, mlp_dim=num_hiddens * 4)
        # vit.ClassificationHead could be reused here, but its constructor signature
        # is not visible from this notebook, so a LayerNorm + Linear head is used.
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens), nn.Linear(num_hiddens, num_classes))

    def forward(self, x):
        batch = x.shape[0]
        # 1-2. Stem, then [B, D, H', W'] -> [B, H'*W', D].
        tokens = self.stem(x).flatten(2).transpose(1, 2)
        # 3. Prepend one CLS token per sample.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        # 4. Add positional embeddings.
        tokens = self.dropout(tokens + self.pos_embedding)
        # 5. Transformer encoder.
        tokens = self.encoder(tokens)
        # 6. Classify from the CLS token -> [B, num_classes].
        return self.head(tokens[:, 0])


# Test your CoAtNEt
coatnet = CoAtNEt(patch_size=7, num_hiddens=512, num_heads=8, num_classes=100, img_size=56)
x = coatnet(img)
print(x.shape)  # Expected: torch.Size([1, 100])

# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f"Using: {device}")

# Shuffle only the training split; evaluation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Fresh model for training (the smoke-test instance above is discarded).
coatnet = CoAtNEt(patch_size=7, num_hiddens=512, num_heads=8, num_classes=100, img_size=56)
coatnet.to(device)

import time
optimizer = torch.optim.Adam(coatnet.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(50):
    # ---- Train one epoch ----
    coatnet.train()
    total_loss, correct, total = 0, 0, 0
    start = time.perf_counter()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits = coatnet(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    elapsed = time.perf_counter() - start

    # ---- Evaluate on the held-out test split ----
    # eval() switches BatchNorm/Dropout to inference behaviour; no_grad() skips
    # building the autograd graph.
    coatnet.eval()
    test_correct, test_total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            test_correct += (coatnet(images).argmax(dim=-1) == labels).sum().item()
            test_total += labels.size(0)

    print(
        f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Acc: {correct/total:.4f} "
        f"| Test Acc: {test_correct/test_total:.4f} | Time: {elapsed:.1f}s"
    )

# Next Steps
Above we have: CNN Stem → Transformer Encoder (all attention blocks)
The actual CoAtNet paper stacks them in stages:
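- Stage 0: Conv blocks (S0)
- Stage 1: Conv blocks (S1)
- Stage 2: Transformer blocks (S2)
- Stage 3: Transformer blocks (S3)

(This numbering follows timm's stage indices. The paper itself calls the conv stem S0, so its layout is S0 stem → S1, S2 MBConv → S3, S4 Transformer, the "C-C-T-T" design.)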
# !pip install timm
import timm

# List the CoAtNet variants timm provides
print(timm.list_models('coatnet*'))

# Build one (random weights, no download) and inspect its layers
model = timm.create_model('coatnet_0_224', pretrained=False)
print(model)

# Understanding the timm implementation
Stem
- Conv2d(3→64, stride=2) → BatchNorm → GELU → Conv2d(64→64)
Stage 0 + Stage 1: MbConvBlocks (CNN stages)
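Here, timm uses depthwise convolutions (MBConv blocks), like the ones used in MobileNetV2/V3.
What is a depthwise convolution? Each channel is convolved independently with its own filter, which is much cheaper than a standard convolution that mixes all channels.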
Stage 2 + Stage 3: TransformerBlock2d (Attention stages)
This is very similar to our ViT implementation, with some tweaks worth looking into.
Timm uses a relative position bias, whereas we were using absolute learned embeddings.
No cls_token: timm pools the final feature map for classification instead.
This one is interesting: instead of using nn.Linear for the Q/K/V projections in attention, they use a 1×1 nn.Conv2d:
class AttentionHead(nn.Module):
    """A single scaled dot-product self-attention head.

    Projects the input to queries, keys and values with bias-free linear
    layers, computes softmax(q @ k^T * head_dim**-0.5), applies dropout to the
    attention weights, and returns the attention-weighted sum of the values.

    Args:
        dim: Feature size of the input tokens (last dimension of ``x``).
        head_dim: Size of the query/key/value projections for this head.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dim, head_dim, dropout=0.1):
        super().__init__()
        # 1/sqrt(head_dim) scaling, as in standard scaled dot-product attention.
        self.scale = head_dim**-0.5
        self.q = nn.Linear(dim, head_dim, bias=False)
        self.k = nn.Linear(dim, head_dim, bias=False)
        self.v = nn.Linear(dim, head_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``[..., seq_len, dim]`` to ``[..., seq_len, head_dim]``."""
        q, k, v = self.q(x), self.k(x), self.v(x)
        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # [..., seq, seq]
        attn = F.softmax(attn, dim=-1)  # each query's weights sum to 1
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        return out


# So instead of doing:
self.q = nn.Linear(dim, head_dim, bias=False)
self.k = nn.Linear(dim, head_dim, bias=False)
self.v = nn.Linear(dim, head_dim, bias=False)
they do:
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1) # Q, K, V all in one