CoAtNet

Review:

  • Brings in the best of both worlds. CNNs are good at local features and have useful inductive biases such as locality and translation equivariance. Transformers are good at modeling global relationships. Stacking convolution blocks in the early stages and transformer blocks in the later stages combines both strengths.
  • Relative attention. Drop the absolute positional embeddings we had in the vanilla ViT and instead add relative position biases to the attention scores. This generalizes nicely to different input sizes.
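To make relative attention concrete, here is a minimal sketch. It assumes a single attention head over an H×W grid of tokens. A learnable table holds one scalar bias per relative offset (dy, dx), and each query/key pair looks up the entry for its offset. That entry is added to the score before the softmax. `RelativePositionBias2d` and `attention_with_relative_bias` are hypothetical names for illustration, not timm's API.

```python
import torch
import torch.nn as nn


class RelativePositionBias2d(nn.Module):
    """Hypothetical helper: one learnable bias per (dy, dx) offset on an H x W grid."""

    def __init__(self, height, width):
        super().__init__()
        # Offsets range over [-(H-1), H-1] x [-(W-1), W-1] -> (2H-1)(2W-1) entries
        self.table = nn.Parameter(torch.zeros((2 * height - 1) * (2 * width - 1)))
        ys, xs = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
        coords = torch.stack([ys.flatten(), xs.flatten()])   # [2, N], N = H*W
        rel = coords[:, :, None] - coords[:, None, :]         # [2, N, N] pairwise offsets
        rel[0] += height - 1                                  # shift dy into [0, 2H-2]
        rel[1] += width - 1                                   # shift dx into [0, 2W-2]
        index = rel[0] * (2 * width - 1) + rel[1]             # [N, N] flat table index
        self.register_buffer("index", index, persistent=False)

    def forward(self):
        # [N, N] bias matrix; all token pairs with the same offset share one parameter
        return self.table[self.index]


def attention_with_relative_bias(q, k, v, bias):
    # q, k, v: [B, N, d]; bias: [N, N], added to the scores before the softmax
    scores = q @ k.transpose(-1, -2) * q.shape[-1] ** -0.5
    return torch.softmax(scores + bias, dim=-1) @ v


rpb = RelativePositionBias2d(7, 7)
q = k = v = torch.randn(2, 49, 64)
print(attention_with_relative_bias(q, k, v, rpb()).shape)  # torch.Size([2, 49, 64])
```

The table size depends only on the grid shape, not on which tokens are compared. Two pairs with the same offset always get the same bias, so the bias encodes "how far apart" rather than "where".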
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F # This gives us the softmax()
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((56, 56)),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

train_dataset = CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR100(root='./data', train=False, download=True, transform=transform)
from vit import ViT, PatchEmbedding, MultiHeadAttention, TransformerEncoder, ClassificationHead, FeedForward

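We basically need to replace the PatchEmbedding with a deeper CNN stem. Three stride-2 convolutions turn the 56×56 image into a 7×7 feature map, and each of its 49 positions becomes a token, just like a patch in the ViT.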

class CNNStem(nn.Module):
    def __init__(self, num_hiddens):
        super().__init__()
        # TODO: Build a nn.Sequential stem with 3 conv layers
        # Each layer: Conv2d -> BatchNorm2d -> GELU
        #
        # Layer 1: 3 -> 64 channels, kernel_size=3, stride=2, padding=1
        # Layer 2: 64 -> 128 channels, kernel_size=3, stride=2, padding=1
        # Layer 3: 128 -> num_hiddens channels, kernel_size=3, stride=2, padding=1
        #
        # Hint: stride=2 halves the spatial dimensions each time
        # 56x56 -> 28x28 -> 14x14 -> 7x7
        pass

    def forward(self, x):
        # TODO: Pass x through the stem
        pass
img = torch.randn(1, 3, 56, 56)
# Test your CNNStem
stem = CNNStem(num_hiddens = 512)
outs = stem(img)
print(outs.shape)  # Expected: torch.Size([1, 512, 7, 7])
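Below is one way to complete the CNNStem exercise. It follows the layer spec in the TODO comments exactly: three Conv2d → BatchNorm2d → GELU layers, each with stride 2. This is a sketch, and it is named `CNNStemSolution` so that it does not shadow your own attempt.

```python
import torch
import torch.nn as nn


class CNNStemSolution(nn.Module):
    def __init__(self, num_hiddens):
        super().__init__()

        def conv_bn_gelu(in_ch, out_ch):
            # kernel 3, stride 2, padding 1 halves H and W (for even sizes)
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.GELU(),
            )

        self.stem = nn.Sequential(
            conv_bn_gelu(3, 64),             # 56x56 -> 28x28
            conv_bn_gelu(64, 128),           # 28x28 -> 14x14
            conv_bn_gelu(128, num_hiddens),  # 14x14 -> 7x7
        )

    def forward(self, x):
        return self.stem(x)


solution_stem = CNNStemSolution(num_hiddens=512)
print(solution_stem(torch.randn(1, 3, 56, 56)).shape)  # torch.Size([1, 512, 7, 7])
```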
class CoAtNEt(nn.Module):
    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_classes, depth=6, dropout=0.1):
        super().__init__()
        # TODO: Calculate num_patches
        # Hint: With 3 stride-2 convs on a 56x56 image, we get a 7x7 feature map = 49 patches
        
        # TODO: Create the CNN stem instead of PatchEmbedding
        # Hint: Use CNNStem(num_hiddens)
        
        # TODO: Create CLS token
        # Hint: nn.Parameter(torch.randn(1, 1, num_hiddens))
        
        # TODO: Create positional embedding
        # Hint: nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))
        
        # TODO: Create transformer encoder (reuse from vit.py)
        # Hint: TransformerEncoder(num_hiddens, depth, num_heads, mlp_dim=num_hiddens * 4)
        
        # TODO: Create classification head (reuse from vit.py)
        pass

    def forward(self, x):
        batch = x.shape[0]
        
        # TODO: 1. Pass through CNN stem (instead of PatchEmbedding)
        
        # TODO: 2. Flatten spatial dims and transpose to get [B, num_patches, num_hiddens]
        # Hint: .flatten(2).transpose(1, 2)
        
        # TODO: 3. Prepend CLS token
        # Hint: expand cls_token to batch size, then torch.cat
        
        # TODO: 4. Add positional embeddings
        
        # TODO: 5. Pass through transformer encoder
        
        # TODO: 6. Pass through classification head
        
        pass
# Test your CoAtNEt
coatnet = CoAtNEt(patch_size=7, num_hiddens=512, num_heads=8, num_classes=100, img_size=56)
x = coatnet(img)
print(x.shape)  # Expected: torch.Size([1, 100])
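For reference, here is a self-contained sketch of the same CNN-stem-plus-transformer model that does not depend on vit.py. It makes a few substitutions, so the class name and defaults are assumptions rather than the notebook's own classes:

- PyTorch's built-in `nn.TransformerEncoder` (pre-norm, GELU) replaces the `TransformerEncoder` from vit.py.
- A LayerNorm + Linear head reads the CLS token, standing in for `ClassificationHead`.
- The hypothetical class is named `HybridStemViT`.

```python
import torch
import torch.nn as nn


class HybridStemViT(nn.Module):
    """Hypothetical sketch: conv stem -> tokens -> CLS + pos emb -> transformer -> head."""

    def __init__(self, img_size=56, num_hiddens=512, num_heads=8, num_classes=100,
                 depth=6, dropout=0.1):
        super().__init__()
        stem_layers, in_ch = [], 3
        for out_ch in (64, 128, num_hiddens):          # three stride-2 conv layers
            stem_layers += [nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
                            nn.BatchNorm2d(out_ch), nn.GELU()]
            in_ch = out_ch
        self.stem = nn.Sequential(*stem_layers)
        num_patches = (img_size // 8) ** 2              # assumes img_size divisible by 8
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_hiddens))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))
        layer = nn.TransformerEncoderLayer(
            d_model=num_hiddens, nhead=num_heads, dim_feedforward=num_hiddens * 4,
            dropout=dropout, activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens), nn.Linear(num_hiddens, num_classes))

    def forward(self, x):
        batch = x.shape[0]
        x = self.stem(x)                                 # [B, C, H/8, W/8]
        x = x.flatten(2).transpose(1, 2)                 # [B, num_patches, C]
        cls = self.cls_token.expand(batch, -1, -1)       # [B, 1, C]
        x = torch.cat([cls, x], dim=1) + self.pos_embedding
        x = self.encoder(x)
        return self.head(x[:, 0])                        # classify from the CLS token


model = HybridStemViT()
print(model(torch.randn(1, 3, 56, 56)).shape)  # torch.Size([1, 100])
```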
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using: {device}")
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
coatnet = CoAtNEt(patch_size=7, num_hiddens=512, num_heads=8, num_classes=100, img_size=56)
coatnet.to(device)
import time
optimizer = torch.optim.Adam(coatnet.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(50):
    coatnet.train()
    total_loss, correct, total = 0, 0, 0
    start = time.time()

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        logits = coatnet(images)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    elapsed = time.time() - start
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Acc: {correct/total:.4f} | Time: {elapsed:.1f}s")
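The loop above only reports training accuracy, and `test_loader` is never used. Here is a minimal evaluation helper, assuming the same `(images, labels)` batches; `evaluate` is a hypothetical name. It switches the model to eval mode, so BatchNorm uses its running statistics and dropout is off. It also disables gradient tracking.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Return (average loss per batch, accuracy) of `model` over `loader`."""
    model.eval()                      # BatchNorm uses running stats, dropout is off
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / len(loader), correct / total
```

One option is to call it at the end of each epoch, e.g. `test_loss, test_acc = evaluate(coatnet, test_loader, criterion, device)`. Because the training loop calls `coatnet.train()` at the start of every epoch, switching to eval mode here does not leak into training.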

Next Steps

Above we have: CNN Stem → Transformer Encoder (all attention blocks)

The actual CoAtNet paper stacks them in stages:

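  • Stage 0: Conv blocks
  • Stage 1: Conv blocks
  • Stage 2: Transformer blocks
  • Stage 3: Transformer blocks

These stages come after a small convolutional stem, and each stage halves the spatial resolution. The stage indices here follow timm; the paper calls the stem S0 and numbers these four stages S1 to S4.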
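Here is a toy sketch of that staged conv-conv-transformer-transformer layout, built under simplifying assumptions:

- Plain 3×3 conv blocks stand in for MBConv.
- `nn.TransformerEncoderLayer` with no positional information stands in for relative attention.
- Each stage starts with a stride-2 conv for downsampling.

`ToyCoAtNet`, `conv_stage`, `AttentionStage` and the channel widths are made up for illustration. They are not the paper's or timm's configuration.

```python
import torch
import torch.nn as nn


def conv_stage(in_ch, out_ch):
    # Downsample by 2, then one extra 3x3 conv (stand-in for MBConv blocks)
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1), nn.BatchNorm2d(out_ch), nn.GELU(),
        nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.GELU(),
    )


class AttentionStage(nn.Module):
    # Downsample by 2 with a conv, then self-attention over all H*W positions
    def __init__(self, in_ch, out_ch, num_heads, depth):
        super().__init__()
        self.down = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
        layer = nn.TransformerEncoderLayer(out_ch, num_heads, out_ch * 4,
                                           batch_first=True, norm_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x):
        x = self.down(x)
        b, c, h, w = x.shape
        tokens = self.blocks(x.flatten(2).transpose(1, 2))   # [B, H*W, C]
        return tokens.transpose(1, 2).reshape(b, c, h, w)    # back to a feature map


class ToyCoAtNet(nn.Module):
    def __init__(self, num_classes=100, widths=(64, 96, 192, 384, 768)):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, widths[0], 3, stride=2, padding=1),
                                  nn.BatchNorm2d(widths[0]), nn.GELU())
        self.stages = nn.Sequential(
            conv_stage(widths[0], widths[1]),                            # Stage 0: conv
            conv_stage(widths[1], widths[2]),                            # Stage 1: conv
            AttentionStage(widths[2], widths[3], num_heads=4, depth=1),  # Stage 2: attention
            AttentionStage(widths[3], widths[4], num_heads=8, depth=1),  # Stage 3: attention
        )
        self.head = nn.Linear(widths[4], num_classes)

    def forward(self, x):
        x = self.stages(self.stem(x))            # overall stride 32: 64x64 -> 2x2
        return self.head(x.mean(dim=(2, 3)))     # global average pool, no CLS token
```

The convolutions run while the feature maps are large, and attention only runs on the small late-stage maps. There, the quadratic cost in H*W is affordable.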

# !pip install timm
import timm

# List available CoAtNet models
print(timm.list_models('coatnet*'))

# Load one and inspect
model = timm.create_model('coatnet_0_224', pretrained=False)
print(model)

Understanding the timm implementation

Stem

  • Conv2d(3→64, stride=2) → BatchNorm → GELU → Conv2d(64→64)
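Here is the stem above written as a small `nn.Sequential`, as a shape check. Two settings are assumptions chosen to match the printed layer list: the 3×3 kernels with padding 1, and stride 1 on the second conv. Printing the timm model shows its exact settings.

```python
import torch
import torch.nn as nn

# Sketch of a CoAtNet-style stem; kernel_size/padding/stride of conv 2 are assumptions
coatnet_stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),   # 224x224 -> 112x112
    nn.BatchNorm2d(64),
    nn.GELU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),  # keeps 112x112
)
print(coatnet_stem(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 112, 112])
```

Unlike our CNNStem, which downsamples by 8 in one go, this stem only downsamples by 2. The remaining downsampling happens at the start of each stage.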

Stage 0 + Stage 1: MbConvBlocks (CNN stages)

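  • Here, timm uses MBConv (inverted-residual) blocks built around depthwise convolutions, the block design introduced in MobileNetV2 and reused in MobileNetV3 and EfficientNet.

  • What was a depthwise conv? Each input channel is convolved with its own filter (groups equal to the number of channels) instead of mixing all channels, which needs far fewer weights and FLOPs. A 1×1 (pointwise) convolution then mixes information across channels.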
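To make the depthwise idea concrete: with `groups=in_channels`, `nn.Conv2d` gives each channel its own 3×3 filter. The weight count drops from C_in·C_out·9 to C·9. The `MBConvSketch` class below follows the usual inverted-residual recipe: 1×1 expand, 3×3 depthwise, 1×1 project, then a residual add. It omits squeeze-and-excitation and other details of timm's MbConvBlock, and the expansion ratio of 4 is an assumption.

```python
import torch
import torch.nn as nn

channels = 64
standard = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
depthwise = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels, bias=False)
print(sum(p.numel() for p in standard.parameters()))   # 36864 = 64 * 64 * 3 * 3
print(sum(p.numel() for p in depthwise.parameters()))  # 576 = 64 * 1 * 3 * 3


class MBConvSketch(nn.Module):
    """Simplified inverted-residual block (no squeeze-excitation, no downsampling)."""

    def __init__(self, dim, expand_ratio=4):
        super().__init__()
        hidden = dim * expand_ratio
        self.block = nn.Sequential(
            nn.Conv2d(dim, hidden, 1, bias=False), nn.BatchNorm2d(hidden), nn.GELU(),  # expand
            nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden, bias=False),      # depthwise
            nn.BatchNorm2d(hidden), nn.GELU(),
            nn.Conv2d(hidden, dim, 1, bias=False), nn.BatchNorm2d(dim),              # project
        )

    def forward(self, x):
        return x + self.block(x)      # residual connection; shape is unchanged
```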

Stage 2 + Stage 3: TransformerBlock2d (Attention stages)

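  • This is very much like our ViT implementation, with a few tweaks worth looking into.

  • timm uses a relative position bias, whereas we were using absolute learned positional embeddings.

  • No cls_token: the classifier pools the final feature map instead.

  • This one is interesting: instead of using nn.Linear for the Q, K and V projections in the attention, they use a 1×1 nn.Conv2d. For comparison, here is our attention head: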

class AttentionHead(nn.Module):
    def __init__(self, dim, head_dim, dropout=0.1):
        super().__init__()
        self.scale = head_dim**-0.5
        self.q = nn.Linear(dim, head_dim, bias=False)
        self.k = nn.Linear(dim, head_dim, bias=False)
        self.v = nn.Linear(dim, head_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        q, k, v = self.q(x), self.k(x), self.v(x)
        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        return out

So instead of doing:

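self.q = nn.Linear(dim, head_dim, bias=False)
self.k = nn.Linear(dim, head_dim, bias=False)
self.v = nn.Linear(dim, head_dim, bias=False)

they compute Q, K and V for all heads at once with a single projection applied directly to the [B, C, H, W] feature map:

self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)  # Q, K, V all in one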
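Here is a quick check of why this works. A 1×1 `nn.Conv2d` on a [B, C, H, W] map computes the same thing as an `nn.Linear` applied to every position of the [B, H*W, C] token sequence, so the two can share weights. The snippet below copies the conv weights into a linear layer and compares the outputs. It then splits the fused output into per-head Q, K and V. That head-splitting layout is an assumption for illustration, not necessarily timm's exact reshape.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, num_heads = 32, 4
head_dim = dim // num_heads
x = torch.randn(2, dim, 7, 7)                                    # [B, C, H, W] feature map

qkv_conv = nn.Conv2d(dim, dim * 3, kernel_size=1)
qkv_linear = nn.Linear(dim, dim * 3)
with torch.no_grad():                                            # reuse the conv weights
    qkv_linear.weight.copy_(qkv_conv.weight.view(dim * 3, dim))  # [3C, C, 1, 1] -> [3C, C]
    qkv_linear.bias.copy_(qkv_conv.bias)

out_conv = qkv_conv(x).flatten(2).transpose(1, 2)                # [B, H*W, 3C]
out_linear = qkv_linear(x.flatten(2).transpose(1, 2))            # [B, H*W, 3C]
print(torch.allclose(out_conv, out_linear, atol=1e-5))           # True

# Split the fused projection into per-head Q, K, V: each [B, heads, H*W, head_dim]
b, n = out_conv.shape[:2]
q, k, v = out_conv.reshape(b, n, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
print(q.shape)                                                   # torch.Size([2, 4, 49, 8])
```

Working on the 2D map directly means the attention block never has to flatten and unflatten around its projections. The conv stages and attention stages can then pass the same [B, C, H, W] tensors to each other.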