CoAtNet

Review:

  • Brings together the best of both worlds. CNNs are good at local features and have helpful inductive biases (e.g. locality). Transformers are good at modelling global relationships. Stacking convolutional blocks followed by transformer blocks gives strong results.
  • Relative attention. Drop the absolute positional embeddings used in the vanilla ViT and instead add relative position biases to the attention scores. This generalizes nicely to different input sizes. (The simplified model built below still uses learned absolute position embeddings.)
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F # This gives us the softmax()
from torch.utils.data import DataLoader

# Per-channel mean / std used to normalise CIFAR-100 images.
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)

# Upsample 32x32 CIFAR images to 56x56 so the 8x-downsampling stem yields a 7x7 grid.
transform = transforms.Compose(
    [
        transforms.Resize((56, 56)),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]
)

# Build the train split first, then the test split, both with the same transform.
train_dataset, test_dataset = (
    CIFAR100(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
from vit import ViT, PatchEmbedding, MultiHeadAttention, TransformerEncoder ,ClassificationHead , FeedForward                                                               

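We basically need to replace the PatchEmbedding with a deeper CNN stem: three stride-2 convolutions turn a 56×56 image into a 7×7 grid of num_hiddens-dimensional features, i.e. 49 tokens.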

class CNNStem(nn.Module):
    """Convolutional stem that replaces ViT's patch embedding.

    Three 3x3 / stride-2 / padding-1 convolutions, each followed by
    BatchNorm and GELU, downsample the input by 8x per spatial side
    (e.g. 56x56 -> 28x28 -> 14x14 -> 7x7) and project it to
    ``num_hiddens`` channels, so every output pixel can serve as a token.

    Args:
        num_hiddens: Number of output channels (the transformer width).
        in_channels: Number of input image channels (3 for RGB).
    """

    def __init__(self, num_hiddens, in_channels=3):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.Conv2d(128, num_hiddens, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_hiddens),
            nn.GELU(),
        )

    def forward(self, x):
        """Map ``(B, in_channels, H, W)`` to ``(B, num_hiddens, ceil(H/8), ceil(W/8))``."""
        return self.stem(x)
     
# Smoke-test the stem on a single dummy 56x56 RGB image.
img = torch.randn(1, 3, 56, 56)
stem = CNNStem(num_hiddens=512)
outs = stem(img)
print(outs.shape)
# expected: torch.Size([1, 512, 7, 7])  (56 -> 28 -> 14 -> 7 after three stride-2 convs)
class Vit(nn.Module):
    """Baseline Vision Transformer assembled from the helpers in ``vit``.

    Patch embedding -> prepend a learned cls token -> add learned absolute
    position embeddings -> transformer encoder -> classification head.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch; the model expects
            ``(img_size // patch_size) ** 2`` patches.
        num_hiddens: Token embedding width.
        num_heads: Number of attention heads in the encoder.
        num_classes: Number of output classes.
        depth: Number of transformer blocks.
        dropout: Accepted for API parity but not used by this class.
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_classes, depth=6, dropout=0.1):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_hiddens))
        # One position embedding per patch plus one for the cls token.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, num_hiddens))
        self.encoder = TransformerEncoder(num_hiddens, depth, num_heads, mlp_dim=num_hiddens * 4)
        self.classification_head = ClassificationHead(num_hiddens, num_classes)

    def forward(self, x):
        batch = x.shape[0]
        # assumes PatchEmbedding returns a (B, C, h, w) feature map — the flatten below relies on it
        patches = self.patch_embedding(x)
        # (B, C, h, w) -> (B, h*w, C): one token per patch.
        patches = patches.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        x = torch.cat((cls_tokens, patches), dim=1)
        x = x + self.pos_embedding
        x = self.encoder(x)
        return self.classification_head(x)
    
class CoAtNEt(nn.Module):
    """Simplified CoAtNet: a convolutional stem followed by a transformer encoder.

    ``CNNStem`` replaces ViT's patch embedding. Its 8x-downsampled feature map
    is flattened into tokens, a learned cls token is prepended, learned
    absolute position embeddings are added, and the sequence is passed through
    the transformer encoder and classification head.

    Args:
        img_size: Side length of the (square) input image. Sets the number of
            tokens to ``ceil(img_size / 8) ** 2``; inputs must have this size.
        patch_size: Stored for API parity with ``Vit``; the stem's fixed 8x
            downsampling determines the token grid instead.
        num_hiddens: Token embedding width (stem output channels).
        num_heads: Number of attention heads in the encoder.
        num_classes: Number of output classes.
        depth: Number of transformer blocks.
        dropout: Accepted for API parity but not used by this class.
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_classes, depth=6, dropout=0.1):
        super().__init__()
        # Derive the token count from the stem geometry instead of hard-coding 7*7,
        # so img_size values other than 56 produce a matching position embedding.
        side = self._stem_output_size(img_size)
        self.num_patches = side * side
        self.stem = CNNStem(num_hiddens)
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_hiddens))
        # One position embedding per stem pixel plus one for the cls token.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, num_hiddens))
        self.encoder = TransformerEncoder(num_hiddens, depth, num_heads, mlp_dim=num_hiddens * 4)
        self.classification_head = ClassificationHead(num_hiddens, num_classes)
        self.patch_size = patch_size

    @staticmethod
    def _stem_output_size(img_size):
        """Return the spatial side length CNNStem produces for an ``img_size`` input.

        Each 3x3 / stride-2 / padding-1 conv maps n -> ceil(n / 2); three of
        them compose to ceil(img_size / 8).
        """
        return (img_size + 7) // 8

    def forward(self, x):
        batch = x.shape[0]
        # (B, 3, H, W) -> (B, num_hiddens, ceil(H/8), ceil(W/8))
        patches_equivalent = self.stem(x)
        # (B, C, h, w) -> (B, h*w, C): each stem pixel becomes a token.
        patches_equivalent = patches_equivalent.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        x = torch.cat((cls_tokens, patches_equivalent), dim=1)
        x = x + self.pos_embedding
        x = self.encoder(x)
        return self.classification_head(x)
# Both models share identical hyper-parameters so their outputs are directly comparable.
model_kwargs = dict(img_size=56, patch_size=7, num_hiddens=512, num_heads=8, num_classes=100)
net = Vit(**model_kwargs)
coatnet = CoAtNEt(**model_kwargs)
x = net(img)
torch.Size([1, 512, 8, 8])
torch.Size([1, 64, 512])
# Smoke-test CoAtNEt on the same dummy image (this overwrites the Vit output held in `x`).
# NOTE(review): no .eval()/torch.no_grad() here, so this runs in the module's default
# training mode and the stem's BatchNorm running statistics are updated by this random input.
x = coatnet(img)
torch.Size([1, 512, 7, 7])
torch.Size([1, 49, 512])
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using: {device}")
Using: mps
# Batches of 32; only the training split is shuffled, test order stays fixed.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=32, shuffle=is_train)
    for dataset, is_train in ((train_dataset, True), (test_dataset, False))
)
coatnet.to(device)
CoAtNEt(
  (stem): CNNStem(
    (stem): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
      (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): GELU(approximate='none')
      (6): Conv2d(128, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): GELU(approximate='none')
    )
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerBlock(
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadAttention(
          (heads): ModuleList(
            (0-7): 8 x AttentionHead(
              (q): Linear(in_features=512, out_features=64, bias=False)
              (k): Linear(in_features=512, out_features=64, bias=False)
              (v): Linear(in_features=512, out_features=64, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (ffn): FeedForward(
          (net): Sequential(
            (0): Linear(in_features=512, out_features=2048, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.1, inplace=False)
            (3): Linear(in_features=2048, out_features=512, bias=True)
            (4): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (classification_head): ClassificationHead(
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (linear): Linear(in_features=512, out_features=100, bias=True)
  )
)
import time

optimizer = torch.optim.Adam(coatnet.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Trains (train mode, gradient steps) when ``optimizer`` is given; otherwise
    evaluates (eval mode, autograd disabled). The loss is averaged per sample,
    assuming ``criterion`` uses the default ``reduction='mean'``, so a short
    final batch is not over-weighted.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            n = labels.size(0)
            # Convert the batch-mean loss back to a sum so the epoch mean is per sample.
            total_loss += loss.item() * n
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            total += n
    # Guard against division by zero on an empty loader.
    total = max(total, 1)
    return total_loss / total, correct / total


for epoch in range(50):
    start = time.perf_counter()
    train_loss, train_acc = _run_epoch(coatnet, train_loader, criterion, device, optimizer)
    elapsed = time.perf_counter() - start  # time the training pass only
    # Held-out evaluation: training accuracy alone cannot reveal overfitting.
    test_loss, test_acc = _run_epoch(coatnet, test_loader, criterion, device)
    print(
        f"Epoch {epoch+1} | Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | "
        f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f} | Time: {elapsed:.1f}s"
    )
Epoch 1 | Loss: 3.6707 | Acc: 0.1298 | Time: 234.9s
Epoch 2 | Loss: 3.1504 | Acc: 0.2177 | Time: 224.4s
Epoch 3 | Loss: 2.9129 | Acc: 0.2650 | Time: 223.2s
Epoch 4 | Loss: 2.7719 | Acc: 0.2941 | Time: 220.1s
Epoch 5 | Loss: 2.6528 | Acc: 0.3191 | Time: 220.2s
Epoch 6 | Loss: 2.5491 | Acc: 0.3434 | Time: 220.4s
Epoch 7 | Loss: 2.4791 | Acc: 0.3539 | Time: 220.2s
Epoch 8 | Loss: 2.3921 | Acc: 0.3714 | Time: 220.2s
Epoch 9 | Loss: 2.3086 | Acc: 0.3898 | Time: 221.1s
Epoch 10 | Loss: 2.2702 | Acc: 0.3980 | Time: 221.0s
Epoch 11 | Loss: 2.2146 | Acc: 0.4096 | Time: 220.9s
Epoch 12 | Loss: 2.1748 | Acc: 0.4175 | Time: 222.1s
Epoch 13 | Loss: 2.1202 | Acc: 0.4317 | Time: 220.7s
Epoch 14 | Loss: 2.0260 | Acc: 0.4503 | Time: 221.0s
Epoch 15 | Loss: 1.9948 | Acc: 0.4582 | Time: 220.7s
Epoch 16 | Loss: 1.9251 | Acc: 0.4743 | Time: 221.8s
Epoch 17 | Loss: 1.8899 | Acc: 0.4811 | Time: 221.2s
Epoch 18 | Loss: 1.7947 | Acc: 0.5026 | Time: 220.7s
Epoch 19 | Loss: 1.7059 | Acc: 0.5215 | Time: 220.5s
Epoch 20 | Loss: 1.6568 | Acc: 0.5325 | Time: 220.6s
Epoch 21 | Loss: 1.6600 | Acc: 0.5282 | Time: 220.9s
Epoch 22 | Loss: 1.6555 | Acc: 0.5297 | Time: 221.1s
Epoch 23 | Loss: 1.5758 | Acc: 0.5522 | Time: 220.7s
Epoch 24 | Loss: 1.4820 | Acc: 0.5721 | Time: 220.8s
Epoch 25 | Loss: 1.3498 | Acc: 0.6058 | Time: 221.4s
Epoch 26 | Loss: 1.2785 | Acc: 0.6217 | Time: 222.6s
Epoch 27 | Loss: 1.1744 | Acc: 0.6486 | Time: 222.6s
Epoch 28 | Loss: 1.1640 | Acc: 0.6528 | Time: 222.8s
Epoch 29 | Loss: 1.1275 | Acc: 0.6616 | Time: 222.4s
Epoch 30 | Loss: 1.0543 | Acc: 0.6809 | Time: 222.8s
Epoch 31 | Loss: 0.9848 | Acc: 0.6999 | Time: 222.8s
Epoch 32 | Loss: 0.8899 | Acc: 0.7274 | Time: 223.6s
Epoch 33 | Loss: 0.9022 | Acc: 0.7201 | Time: 222.6s
Epoch 34 | Loss: 0.7973 | Acc: 0.7506 | Time: 221.1s
Epoch 35 | Loss: 0.7528 | Acc: 0.7635 | Time: 220.5s
Epoch 36 | Loss: 0.7421 | Acc: 0.7658 | Time: 220.6s
Epoch 37 | Loss: 0.6831 | Acc: 0.7841 | Time: 220.7s
Epoch 38 | Loss: 0.6923 | Acc: 0.7841 | Time: 220.9s
Epoch 39 | Loss: 0.6746 | Acc: 0.7881 | Time: 220.5s
Epoch 40 | Loss: 0.6440 | Acc: 0.7968 | Time: 220.5s
Epoch 41 | Loss: 0.5860 | Acc: 0.8143 | Time: 220.5s
Epoch 42 | Loss: 0.5642 | Acc: 0.8200 | Time: 221.5s
Epoch 43 | Loss: 0.5572 | Acc: 0.8241 | Time: 221.9s
Epoch 44 | Loss: 0.6289 | Acc: 0.8013 | Time: 221.9s
Epoch 45 | Loss: 0.5685 | Acc: 0.8197 | Time: 221.9s
Epoch 46 | Loss: 0.6295 | Acc: 0.8023 | Time: 221.9s
Epoch 47 | Loss: 0.5720 | Acc: 0.8190 | Time: 221.8s
Epoch 48 | Loss: 0.5371 | Acc: 0.8305 | Time: 221.7s
Epoch 49 | Loss: 0.4337 | Acc: 0.8604 | Time: 222.4s
Epoch 50 | Loss: 0.4385 | Acc: 0.8594 | Time: 222.7s

Next Steps

Above we have: CNN Stem → Transformer Encoder (all attention blocks)

The actual CoAtNet paper stacks them in stages:

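  • Stage 0 (S0): conv (MBConv) blocks
  • Stage 1 (S1): conv (MBConv) blocks
  • Stage 2 (S2): transformer blocks
  • Stage 3 (S3): transformer blocks

Each stage starts by downsampling 2x. This numbering matches timm's `stages` below; the paper counts the conv stem as S0, so it calls these stages S1–S4.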

# !pip install timm
import timm

# Enumerate the CoAtNet variants that ship with timm.
coatnet_variants = timm.list_models('coatnet*')
print(coatnet_variants)

# Instantiate one with random weights (pretrained=False) and print its module tree.
model = timm.create_model('coatnet_0_224', pretrained=False)
print(model)
['coatnet_0_224', 'coatnet_0_rw_224', 'coatnet_1_224', 'coatnet_1_rw_224', 'coatnet_2_224', 'coatnet_2_rw_224', 'coatnet_3_224', 'coatnet_3_rw_224', 'coatnet_4_224', 'coatnet_5_224', 'coatnet_bn_0_rw_224', 'coatnet_nano_cc_224', 'coatnet_nano_rw_224', 'coatnet_pico_rw_224', 'coatnet_rmlp_0_rw_224', 'coatnet_rmlp_1_rw2_224', 'coatnet_rmlp_1_rw_224', 'coatnet_rmlp_2_rw_224', 'coatnet_rmlp_2_rw_384', 'coatnet_rmlp_3_rw_224', 'coatnet_rmlp_nano_rw_224']
MaxxVit(
  (stem): Stem(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (norm1): BatchNormAct2d(
      64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): GELU()
    )
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (stages): Sequential(
    (0): MaxxVitStage(
      (blocks): Sequential(
        (0): MbConvBlock(
          (shortcut): Downsample2d(
            (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
            (expand): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))
          )
          (pre_norm): BatchNormAct2d(
            64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): Identity()
          )
          (down): Identity()
          (conv1_1x1): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm1): BatchNormAct2d(
            384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELU()
          )
          (conv2_kxk): Conv2d(384, 384, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=384, bias=False)
          (norm2): BatchNormAct2d(
            384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELU()
          )
          (se): SEModule(
            (fc1): Conv2d(384, 24, kernel_size=(1, 1), stride=(1, 1))
            (bn): Identity()
            (act): SiLU(inplace=True)
            (fc2): Conv2d(24, 384, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv3_1x1): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1))
          (drop_path): Identity()
        )
        (1): MbConvBlock(
          (shortcut): Identity()
          (pre_norm): BatchNormAct2d(
            96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): Identity()
          )
          (down): Identity()
          (conv1_1x1): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm1): BatchNormAct2d(
            384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELU()
          )
          (conv2_kxk): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (norm2): BatchNormAct2d(
            384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELU()
          )
          (se): SEModule(
            (fc1): Conv2d(384, 24, kernel_size=(1, 1), stride=(1, 1))
            (bn): Identity()
            (act): SiLU(inplace=True)
            (fc2): Conv2d(24, 384, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv3_1x1): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1))
          (drop_path): Identity()
        )
      )
    )
    (1): MaxxVitStage(
      (blocks): Sequential(
        (0): MbConvBlock(
          (shortcut): Downsample2d(
            (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
            (expand): Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))
          )
          (pre_norm): BatchNormAct2d(
            96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): Identity()
          )
          (down): Identity()
          (conv1_1x1): Conv2d(96, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm1): BatchNormAct2d(
            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELU()
          )
          (conv2_kxk): Conv2d(768, 768, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=768, bias=False)
          (norm2): BatchNormAct2d(
            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELU()
          )
          (se): SEModule(
            (fc1): Conv2d(768, 48, kernel_size=(1, 1), stride=(1, 1))
            (bn): Identity()
            (act): SiLU(inplace=True)
            (fc2): Conv2d(48, 768, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv3_1x1): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1))
          (drop_path): Identity()
        )
        (1): MbConvBlock(
          (shortcut): Identity()
          (pre_norm): BatchNormAct2d(
            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): Identity()
          )
          (down): Identity()
          (conv1_1x1): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm1): BatchNormAct2d(
            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELU()
          )
          (conv2_kxk): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)
          (norm2): BatchNormAct2d(
            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELU()
          )
          (se): SEModule(
            (fc1): Conv2d(768, 48, kernel_size=(1, 1), stride=(1, 1))
            (bn): Identity()
            (act): SiLU(inplace=True)
            (fc2): Conv2d(48, 768, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv3_1x1): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1))
          (drop_path): Identity()
        )
        (2): MbConvBlock(
          (shortcut): Identity()
          (pre_norm): BatchNormAct2d(
            192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): Identity()
          )
          (down): Identity()
          (conv1_1x1): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm1): BatchNormAct2d(
            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELU()
          )
          (conv2_kxk): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)
          (norm2): BatchNormAct2d(
            768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): GELU()
          )
          (se): SEModule(
            (fc1): Conv2d(768, 48, kernel_size=(1, 1), stride=(1, 1))
            (bn): Identity()
            (act): SiLU(inplace=True)
            (fc2): Conv2d(48, 768, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv3_1x1): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1))
          (drop_path): Identity()
        )
      )
    )
    (2): MaxxVitStage(
      (blocks): Sequential(
        (0): TransformerBlock2d(
          (shortcut): Downsample2d(
            (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
            (expand): Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))
          )
          (norm1): Sequential(
            (norm): LayerNorm2d((192,), eps=1e-06, elementwise_affine=True)
            (down): Downsample2d(
              (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
              (expand): Identity()
            )
          )
          (attn): Attention2d(
            (qkv): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1))
            (rel_pos): RelPosBias()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
          (mlp): ConvMlp(
            (fc1): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1))
            (norm): Identity()
            (act): GELU()
            (drop): Dropout(p=0.0, inplace=False)
            (fc2): Conv2d(1536, 384, kernel_size=(1, 1), stride=(1, 1))
          )
          (ls2): Identity()
          (drop_path2): Identity()
        )
        (1): TransformerBlock2d(
          (shortcut): Identity()
          (norm1): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
          (attn): Attention2d(
            (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1))
            (rel_pos): RelPosBias()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
          (mlp): ConvMlp(
            (fc1): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1))
            (norm): Identity()
            (act): GELU()
            (drop): Dropout(p=0.0, inplace=False)
            (fc2): Conv2d(1536, 384, kernel_size=(1, 1), stride=(1, 1))
          )
          (ls2): Identity()
          (drop_path2): Identity()
        )
        (2): TransformerBlock2d(
          (shortcut): Identity()
          (norm1): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
          (attn): Attention2d(
            (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1))
            (rel_pos): RelPosBias()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
          (mlp): ConvMlp(
            (fc1): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1))
            (norm): Identity()
            (act): GELU()
            (drop): Dropout(p=0.0, inplace=False)
            (fc2): Conv2d(1536, 384, kernel_size=(1, 1), stride=(1, 1))
          )
          (ls2): Identity()
          (drop_path2): Identity()
        )
        (3): TransformerBlock2d(
          (shortcut): Identity()
          (norm1): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
          (attn): Attention2d(
            (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1))
            (rel_pos): RelPosBias()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
          (mlp): ConvMlp(
            (fc1): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1))
            (norm): Identity()
            (act): GELU()
            (drop): Dropout(p=0.0, inplace=False)
            (fc2): Conv2d(1536, 384, kernel_size=(1, 1), stride=(1, 1))
          )
          (ls2): Identity()
          (drop_path2): Identity()
        )
        (4): TransformerBlock2d(
          (shortcut): Identity()
          (norm1): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
          (attn): Attention2d(
            (qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1))
            (rel_pos): RelPosBias()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
          (mlp): ConvMlp(
            (fc1): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1))
            (norm): Identity()
            (act): GELU()
            (drop): Dropout(p=0.0, inplace=False)
            (fc2): Conv2d(1536, 384, kernel_size=(1, 1), stride=(1, 1))
          )
          (ls2): Identity()
          (drop_path2): Identity()
        )
      )
    )
    (3): MaxxVitStage(
      (blocks): Sequential(
        (0): TransformerBlock2d(
          (shortcut): Downsample2d(
            (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
            (expand): Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))
          )
          (norm1): Sequential(
            (norm): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
            (down): Downsample2d(
              (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
              (expand): Identity()
            )
          )
          (attn): Attention2d(
            (qkv): Conv2d(384, 2304, kernel_size=(1, 1), stride=(1, 1))
            (rel_pos): RelPosBias()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1))
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
          (mlp): ConvMlp(
            (fc1): Conv2d(768, 3072, kernel_size=(1, 1), stride=(1, 1))
            (norm): Identity()
            (act): GELU()
            (drop): Dropout(p=0.0, inplace=False)
            (fc2): Conv2d(3072, 768, kernel_size=(1, 1), stride=(1, 1))
          )
          (ls2): Identity()
          (drop_path2): Identity()
        )
        (1): TransformerBlock2d(
          (shortcut): Identity()
          (norm1): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention2d(
            (qkv): Conv2d(768, 2304, kernel_size=(1, 1), stride=(1, 1))
            (rel_pos): RelPosBias()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1))
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
          (mlp): ConvMlp(
            (fc1): Conv2d(768, 3072, kernel_size=(1, 1), stride=(1, 1))
            (norm): Identity()
            (act): GELU()
            (drop): Dropout(p=0.0, inplace=False)
            (fc2): Conv2d(3072, 768, kernel_size=(1, 1), stride=(1, 1))
          )
          (ls2): Identity()
          (drop_path2): Identity()
        )
      )
    )
  )
  (norm): Identity()
  (head): NormMlpClassifierHead(
    (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Identity())
    (norm): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (pre_logits): Sequential(
      (fc): Linear(in_features=768, out_features=768, bias=True)
      (act): Tanh()
    )
    (drop): Dropout(p=0.0, inplace=False)
    (fc): Linear(in_features=768, out_features=1000, bias=True)
  )
)

Understanding the timm implementation

Stem

  • Conv2d(3→64, stride=2) → BatchNorm → GELU → Conv2d(64→64)

Stage 0 + Stage 1: MbConvBlocks (CNN stages)

  • Here, timm uses MBConv (inverted-residual) blocks built around depthwise convolutions, like the ones in the MobileNet family (MobileNetV2/V3).

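  • What is a depthwise conv? Each channel is convolved with its own filter, independently of the other channels (groups == channels, e.g. groups=384 above). This is much cheaper than a full convolution.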

Stage 2 + Stage 3: TransformerBlock2d (Attention stages)

  • This is very similar to our ViT implementation, with a few tweaks worth looking into.

  • timm uses relative position biases (RelPosBias), whereas we were using learned absolute position embeddings.

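  • No cls_token: the head global-average-pools the final feature map instead (SelectAdaptivePool2d with pool_type=avg).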

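  • This one is interesting: instead of separate nn.Linear projections in the attention, they use a single 1×1 nn.Conv2d (qkv) that computes Q, K and V in one go, directly on the 2D feature map: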

class AttentionHead(nn.Module):
    """A single scaled dot-product self-attention head.

    Projects the input with three separate bias-free linear layers (query,
    key, value) and returns ``dropout(softmax(Q K^T * head_dim**-0.5)) V``.

    Args:
        dim: Size of the last dimension of the input.
        head_dim: Size of this head's query/key/value projections.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dim, head_dim, dropout=0.1):
        super().__init__()
        self.scale = head_dim**-0.5
        # Created in q, k, v order so parameter registration matches three explicit assignments.
        self.q, self.k, self.v = (nn.Linear(dim, head_dim, bias=False) for _ in range(3))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        queries = self.q(x)
        keys = self.k(x)
        values = self.v(x)
        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ values

So instead of doing:

# Three separate projections, as in AttentionHead.__init__ (one Linear each for Q, K, V):
self.q = nn.Linear(dim, head_dim, bias=False)
self.k = nn.Linear(dim, head_dim, bias=False)
self.v = nn.Linear(dim, head_dim, bias=False)

they do:

# A single 1x1 convolution with 3 * dim output channels produces Q, K and V together,
# operating directly on a (B, dim, H, W) feature map instead of a token sequence.
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)  # Q, K, V all in one