from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F # This gives us the softmax()
from torch.utils.data import DataLoader
# Preprocessing shared by the train and test splits: resize every image to
# 56x56 (the input size the models below are built for), convert it to a
# tensor, then normalise each channel.
transform = transforms.Compose([
    transforms.Resize((56, 56)),
    transforms.ToTensor(),
    # Per-channel (mean, std). NOTE(review): these look like the commonly
    # quoted CIFAR-100 statistics -- confirm before reusing on other data.
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

# Both splits live under ./data and are downloaded on first use.
train_dataset = CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR100(root='./data', train=False, download=True, transform=transform)

# CoAtNet
Review:
- Brings in the best of both worlds. CNNs are good at local features and have nice inductive biases. Transformers are good at global relationships. If you stack CNN and transformer blocks, you get good results.
- Relative attention. Take away the absolute positional embeddings we had in the vanilla ViT, and instead use relative position biases in the attention scores. This generalizes nicely to different input sizes.
from vit import ViT, PatchEmbedding, MultiHeadAttention, TransformerEncoder, ClassificationHead, FeedForward
# We basically need to replace the PatchEmbedding with a deeper CNN.
class CNNStem(nn.Module):
    """Convolutional stem used in place of the ViT patch embedding.

    Three ``Conv2d -> BatchNorm2d -> GELU`` stages. Every convolution uses a
    3x3 kernel, stride 2 and padding 1, so each stage halves the spatial size
    (rounding up); a 56x56 input comes out as a 7x7 feature map.

    Args:
        num_hiddens: Number of output channels of the last stage.
        in_channels: Number of channels of the input image. Defaults to 3,
            which reproduces the original hard-coded behaviour.
    """

    def __init__(self, num_hiddens, in_channels=3):
        super().__init__()
        # Kept as a single ``nn.Sequential`` named ``stem`` with the same
        # layer order, so existing state-dict keys remain valid.
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.Conv2d(128, num_hiddens, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_hiddens),
            nn.GELU(),
        )

    def forward(self, x):
        """Return the feature map ``(batch, num_hiddens, H', W')`` for ``x``."""
        return self.stem(x)
# Smoke test: push one random 56x56 RGB image through the stem.
img = torch.randn(1, 3, 56, 56)
stem = CNNStem(num_hiddens=512)
# The notebook called stem(x), but `x` is not defined at this point in the
# file; `img` is the tensor created for this check.
outs = stem(img)
print(outs.shape)
# Recorded notebook output (apparently from a run whose input had batch size 2):
# torch.Size([2, 512, 7, 7])
class Vit(nn.Module):
    """Baseline Vision Transformer built from the blocks in ``vit``.

    Pipeline: patch embedding -> prepend a learned class token -> add a
    learned absolute position embedding -> transformer encoder ->
    classification head.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each patch; ``img_size // patch_size``
            patches per side are assumed.
        num_hiddens: Embedding width used throughout the model.
        num_heads: Number of attention heads passed to the encoder.
        num_classes: Number of output classes.
        depth: Number of encoder blocks.
        dropout: Accepted for interface compatibility; this class does not
            forward it to any sub-module.
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_classes, depth=6, dropout=0.1):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_hiddens))
        # One position per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, num_hiddens))
        self.encoder = TransformerEncoder(num_hiddens, depth, num_heads, mlp_dim=num_hiddens * 4)
        self.classification_head = ClassificationHead(num_hiddens, num_classes)

    def forward(self, x):
        """Return the classification head's output for the image batch ``x``."""
        batch = x.shape[0]
        patches = self.patch_embedding(x)
        # Collapse the spatial grid into a token axis:
        # (batch, channels, H', W') -> (batch, H' * W', channels).
        # The per-call debug prints of these shapes were removed; they fired
        # on every forward pass.
        patches = patches.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        x = torch.cat((cls_tokens, patches), dim=1)
        x = x + self.pos_embedding
        x = self.encoder(x)
        x = self.classification_head(x)
        return x
class CoAtNEt(nn.Module):
    """ViT variant whose patch embedding is replaced by a convolutional stem.

    Pipeline: ``CNNStem`` -> flatten the feature map into tokens -> prepend a
    learned class token -> add a learned absolute position embedding ->
    transformer encoder -> classification head.

    Args:
        img_size: Side length of the (square) input image. The number of
            tokens is derived from it and the stem's convolutions.
        patch_size: Stored on the instance; not otherwise used here.
        num_hiddens: Embedding width (also the stem's output channels).
        num_heads: Number of attention heads passed to the encoder.
        num_classes: Number of output classes.
        depth: Number of encoder blocks.
        dropout: Accepted for interface compatibility; this class does not
            forward it to any sub-module.
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_classes, depth=6, dropout=0.1):
        super().__init__()
        self.stem = CNNStem(num_hiddens)
        # Derive the token grid from the stem instead of hard-coding 7 * 7,
        # which only matched img_size=56. Standard Conv2d output-size formula,
        # assuming square inputs and square kernels.
        side = img_size
        for layer in self.stem.modules():
            if isinstance(layer, nn.Conv2d):
                kernel, stride = layer.kernel_size[0], layer.stride[0]
                padding, dilation = layer.padding[0], layer.dilation[0]
                side = (side + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
        self.num_patches = side * side
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_hiddens))
        # One position per stem output location plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, num_hiddens))
        self.encoder = TransformerEncoder(num_hiddens, depth, num_heads, mlp_dim=num_hiddens * 4)
        self.classification_head = ClassificationHead(num_hiddens, num_classes)
        self.patch_size = patch_size

    def forward(self, x):
        """Return the classification head's output for the image batch ``x``."""
        batch = x.shape[0]
        patches_equivalent = self.stem(x)
        # (batch, channels, H', W') -> (batch, H' * W', channels)
        patches_equivalent = patches_equivalent.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        x = torch.cat((cls_tokens, patches_equivalent), dim=1)
        x = x + self.pos_embedding
        x = self.encoder(x)
        x = self.classification_head(x)
        return x


net = Vit(patch_size=7, num_hiddens=512, num_heads=8, num_classes=100, img_size=56)
coatnet = CoAtNEt(patch_size=7, num_hiddens=512, num_heads=8, num_classes=100, img_size=56)

# Forward one image through the plain ViT.
x = net(img)
# Recorded notebook output:
# torch.Size([1, 512, 8, 8])
# torch.Size([1, 64, 512])

# Forward the same image through the CNN-stem variant.
x = coatnet(img)
# Recorded notebook output:
# torch.Size([1, 512, 7, 7])
# torch.Size([1, 49, 512])
# Pick the compute device: CUDA first, then MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using: {device}")
# Recorded notebook output:
# Using: mps
# Mini-batches of 32; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Move the model's parameters and buffers to the selected device.
coatnet.to(device)
# Recorded notebook output (the module repr continues on the following lines):
# CoAtNEt(
(stem): CNNStem(
(stem): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): GELU(approximate='none')
(3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): GELU(approximate='none')
(6): Conv2d(128, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): GELU(approximate='none')
)
)
(encoder): TransformerEncoder(
(layers): ModuleList(
(0-5): 6 x TransformerBlock(
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(attn): MultiHeadAttention(
(heads): ModuleList(
(0-7): 8 x AttentionHead(
(q): Linear(in_features=512, out_features=64, bias=False)
(k): Linear(in_features=512, out_features=64, bias=False)
(v): Linear(in_features=512, out_features=64, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(proj): Linear(in_features=512, out_features=512, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(ffn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
)
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
(classification_head): ClassificationHead(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(linear): Linear(in_features=512, out_features=100, bias=True)
)
)
import time
# Train the CNN-stem model for 50 epochs with Adam and cross-entropy loss.
# Only training-set loss and accuracy are reported; test_loader is not
# evaluated in this loop.
optimizer = torch.optim.Adam(coatnet.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(50):
    coatnet.train()
    total_loss, correct, total = 0, 0, 0
    start = time.time()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits = coatnet(images)
        loss = criterion(logits, labels)
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Running totals for the per-epoch summary.
        total_loss += loss.item()
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    elapsed = time.time() - start
    # Loss is averaged over batches, accuracy over samples.
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | Acc: {correct/total:.4f} | Time: {elapsed:.1f}s")
# Recorded notebook output (continues on the following lines):
# Epoch 1 | Loss: 3.6707 | Acc: 0.1298 | Time: 234.9s
Epoch 2 | Loss: 3.1504 | Acc: 0.2177 | Time: 224.4s
Epoch 3 | Loss: 2.9129 | Acc: 0.2650 | Time: 223.2s
Epoch 4 | Loss: 2.7719 | Acc: 0.2941 | Time: 220.1s
Epoch 5 | Loss: 2.6528 | Acc: 0.3191 | Time: 220.2s
Epoch 6 | Loss: 2.5491 | Acc: 0.3434 | Time: 220.4s
Epoch 7 | Loss: 2.4791 | Acc: 0.3539 | Time: 220.2s
Epoch 8 | Loss: 2.3921 | Acc: 0.3714 | Time: 220.2s
Epoch 9 | Loss: 2.3086 | Acc: 0.3898 | Time: 221.1s
Epoch 10 | Loss: 2.2702 | Acc: 0.3980 | Time: 221.0s
Epoch 11 | Loss: 2.2146 | Acc: 0.4096 | Time: 220.9s
Epoch 12 | Loss: 2.1748 | Acc: 0.4175 | Time: 222.1s
Epoch 13 | Loss: 2.1202 | Acc: 0.4317 | Time: 220.7s
Epoch 14 | Loss: 2.0260 | Acc: 0.4503 | Time: 221.0s
Epoch 15 | Loss: 1.9948 | Acc: 0.4582 | Time: 220.7s
Epoch 16 | Loss: 1.9251 | Acc: 0.4743 | Time: 221.8s
Epoch 17 | Loss: 1.8899 | Acc: 0.4811 | Time: 221.2s
Epoch 18 | Loss: 1.7947 | Acc: 0.5026 | Time: 220.7s
Epoch 19 | Loss: 1.7059 | Acc: 0.5215 | Time: 220.5s
Epoch 20 | Loss: 1.6568 | Acc: 0.5325 | Time: 220.6s
Epoch 21 | Loss: 1.6600 | Acc: 0.5282 | Time: 220.9s
Epoch 22 | Loss: 1.6555 | Acc: 0.5297 | Time: 221.1s
Epoch 23 | Loss: 1.5758 | Acc: 0.5522 | Time: 220.7s
Epoch 24 | Loss: 1.4820 | Acc: 0.5721 | Time: 220.8s
Epoch 25 | Loss: 1.3498 | Acc: 0.6058 | Time: 221.4s
Epoch 26 | Loss: 1.2785 | Acc: 0.6217 | Time: 222.6s
Epoch 27 | Loss: 1.1744 | Acc: 0.6486 | Time: 222.6s
Epoch 28 | Loss: 1.1640 | Acc: 0.6528 | Time: 222.8s
Epoch 29 | Loss: 1.1275 | Acc: 0.6616 | Time: 222.4s
Epoch 30 | Loss: 1.0543 | Acc: 0.6809 | Time: 222.8s
Epoch 31 | Loss: 0.9848 | Acc: 0.6999 | Time: 222.8s
Epoch 32 | Loss: 0.8899 | Acc: 0.7274 | Time: 223.6s
Epoch 33 | Loss: 0.9022 | Acc: 0.7201 | Time: 222.6s
Epoch 34 | Loss: 0.7973 | Acc: 0.7506 | Time: 221.1s
Epoch 35 | Loss: 0.7528 | Acc: 0.7635 | Time: 220.5s
Epoch 36 | Loss: 0.7421 | Acc: 0.7658 | Time: 220.6s
Epoch 37 | Loss: 0.6831 | Acc: 0.7841 | Time: 220.7s
Epoch 38 | Loss: 0.6923 | Acc: 0.7841 | Time: 220.9s
Epoch 39 | Loss: 0.6746 | Acc: 0.7881 | Time: 220.5s
Epoch 40 | Loss: 0.6440 | Acc: 0.7968 | Time: 220.5s
Epoch 41 | Loss: 0.5860 | Acc: 0.8143 | Time: 220.5s
Epoch 42 | Loss: 0.5642 | Acc: 0.8200 | Time: 221.5s
Epoch 43 | Loss: 0.5572 | Acc: 0.8241 | Time: 221.9s
Epoch 44 | Loss: 0.6289 | Acc: 0.8013 | Time: 221.9s
Epoch 45 | Loss: 0.5685 | Acc: 0.8197 | Time: 221.9s
Epoch 46 | Loss: 0.6295 | Acc: 0.8023 | Time: 221.9s
Epoch 47 | Loss: 0.5720 | Acc: 0.8190 | Time: 221.8s
Epoch 48 | Loss: 0.5371 | Acc: 0.8305 | Time: 221.7s
Epoch 49 | Loss: 0.4337 | Acc: 0.8604 | Time: 222.4s
Epoch 50 | Loss: 0.4385 | Acc: 0.8594 | Time: 222.7s
Next Steps
Above we have: CNN Stem → Transformer Encoder (all attention blocks)
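The actual CoAtNet paper stacks them in stages, after a convolutional stem (S0 in the paper):
- Stage 0: Conv (MBConv) blocks (S1 in the paper)
- Stage 1: Conv (MBConv) blocks (S2 in the paper)
- Stage 2: Transformer blocks (S3 in the paper)
- Stage 3: Transformer blocks (S4 in the paper)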
# !pip install timm
# The import had been swallowed by the comment above, leaving `timm` undefined.
import timm

# List available coat models
print(timm.list_models('coatnet*'))

# Load one and inspect
model = timm.create_model('coatnet_0_224', pretrained=False)
print(model)
# Recorded notebook output (the model repr continues on the following lines):
# ['coatnet_0_224', 'coatnet_0_rw_224', 'coatnet_1_224', 'coatnet_1_rw_224', 'coatnet_2_224', 'coatnet_2_rw_224', 'coatnet_3_224', 'coatnet_3_rw_224', 'coatnet_4_224', 'coatnet_5_224', 'coatnet_bn_0_rw_224', 'coatnet_nano_cc_224', 'coatnet_nano_rw_224', 'coatnet_pico_rw_224', 'coatnet_rmlp_0_rw_224', 'coatnet_rmlp_1_rw2_224', 'coatnet_rmlp_1_rw_224', 'coatnet_rmlp_2_rw_224', 'coatnet_rmlp_2_rw_384', 'coatnet_rmlp_3_rw_224', 'coatnet_rmlp_nano_rw_224']
MaxxVit(
(stem): Stem(
(conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(norm1): BatchNormAct2d(
64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): GELU()
)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
(stages): Sequential(
(0): MaxxVitStage(
(blocks): Sequential(
(0): MbConvBlock(
(shortcut): Downsample2d(
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
(expand): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))
)
(pre_norm): BatchNormAct2d(
64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(down): Identity()
(conv1_1x1): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(norm1): BatchNormAct2d(
384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): GELU()
)
(conv2_kxk): Conv2d(384, 384, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=384, bias=False)
(norm2): BatchNormAct2d(
384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): GELU()
)
(se): SEModule(
(fc1): Conv2d(384, 24, kernel_size=(1, 1), stride=(1, 1))
(bn): Identity()
(act): SiLU(inplace=True)
(fc2): Conv2d(24, 384, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv3_1x1): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1))
(drop_path): Identity()
)
(1): MbConvBlock(
(shortcut): Identity()
(pre_norm): BatchNormAct2d(
96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(down): Identity()
(conv1_1x1): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(norm1): BatchNormAct2d(
384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): GELU()
)
(conv2_kxk): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
(norm2): BatchNormAct2d(
384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): GELU()
)
(se): SEModule(
(fc1): Conv2d(384, 24, kernel_size=(1, 1), stride=(1, 1))
(bn): Identity()
(act): SiLU(inplace=True)
(fc2): Conv2d(24, 384, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv3_1x1): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1))
(drop_path): Identity()
)
)
)
(1): MaxxVitStage(
(blocks): Sequential(
(0): MbConvBlock(
(shortcut): Downsample2d(
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
(expand): Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))
)
(pre_norm): BatchNormAct2d(
96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(down): Identity()
(conv1_1x1): Conv2d(96, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
(norm1): BatchNormAct2d(
768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): GELU()
)
(conv2_kxk): Conv2d(768, 768, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=768, bias=False)
(norm2): BatchNormAct2d(
768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): GELU()
)
(se): SEModule(
(fc1): Conv2d(768, 48, kernel_size=(1, 1), stride=(1, 1))
(bn): Identity()
(act): SiLU(inplace=True)
(fc2): Conv2d(48, 768, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv3_1x1): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1))
(drop_path): Identity()
)
(1): MbConvBlock(
(shortcut): Identity()
(pre_norm): BatchNormAct2d(
192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(down): Identity()
(conv1_1x1): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
(norm1): BatchNormAct2d(
768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): GELU()
)
(conv2_kxk): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)
(norm2): BatchNormAct2d(
768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): GELU()
)
(se): SEModule(
(fc1): Conv2d(768, 48, kernel_size=(1, 1), stride=(1, 1))
(bn): Identity()
(act): SiLU(inplace=True)
(fc2): Conv2d(48, 768, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv3_1x1): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1))
(drop_path): Identity()
)
(2): MbConvBlock(
(shortcut): Identity()
(pre_norm): BatchNormAct2d(
192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): Identity()
)
(down): Identity()
(conv1_1x1): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
(norm1): BatchNormAct2d(
768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): GELU()
)
(conv2_kxk): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=768, bias=False)
(norm2): BatchNormAct2d(
768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): GELU()
)
(se): SEModule(
(fc1): Conv2d(768, 48, kernel_size=(1, 1), stride=(1, 1))
(bn): Identity()
(act): SiLU(inplace=True)
(fc2): Conv2d(48, 768, kernel_size=(1, 1), stride=(1, 1))
(gate): Sigmoid()
)
(conv3_1x1): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1))
(drop_path): Identity()
)
)
)
(2): MaxxVitStage(
(blocks): Sequential(
(0): TransformerBlock2d(
(shortcut): Downsample2d(
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
(expand): Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))
)
(norm1): Sequential(
(norm): LayerNorm2d((192,), eps=1e-06, elementwise_affine=True)
(down): Downsample2d(
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
(expand): Identity()
)
)
(attn): Attention2d(
(qkv): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1))
(rel_pos): RelPosBias()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
(mlp): ConvMlp(
(fc1): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1))
(norm): Identity()
(act): GELU()
(drop): Dropout(p=0.0, inplace=False)
(fc2): Conv2d(1536, 384, kernel_size=(1, 1), stride=(1, 1))
)
(ls2): Identity()
(drop_path2): Identity()
)
(1): TransformerBlock2d(
(shortcut): Identity()
(norm1): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
(attn): Attention2d(
(qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1))
(rel_pos): RelPosBias()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
(mlp): ConvMlp(
(fc1): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1))
(norm): Identity()
(act): GELU()
(drop): Dropout(p=0.0, inplace=False)
(fc2): Conv2d(1536, 384, kernel_size=(1, 1), stride=(1, 1))
)
(ls2): Identity()
(drop_path2): Identity()
)
(2): TransformerBlock2d(
(shortcut): Identity()
(norm1): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
(attn): Attention2d(
(qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1))
(rel_pos): RelPosBias()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
(mlp): ConvMlp(
(fc1): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1))
(norm): Identity()
(act): GELU()
(drop): Dropout(p=0.0, inplace=False)
(fc2): Conv2d(1536, 384, kernel_size=(1, 1), stride=(1, 1))
)
(ls2): Identity()
(drop_path2): Identity()
)
(3): TransformerBlock2d(
(shortcut): Identity()
(norm1): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
(attn): Attention2d(
(qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1))
(rel_pos): RelPosBias()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
(mlp): ConvMlp(
(fc1): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1))
(norm): Identity()
(act): GELU()
(drop): Dropout(p=0.0, inplace=False)
(fc2): Conv2d(1536, 384, kernel_size=(1, 1), stride=(1, 1))
)
(ls2): Identity()
(drop_path2): Identity()
)
(4): TransformerBlock2d(
(shortcut): Identity()
(norm1): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
(attn): Attention2d(
(qkv): Conv2d(384, 1152, kernel_size=(1, 1), stride=(1, 1))
(rel_pos): RelPosBias()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
(mlp): ConvMlp(
(fc1): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1))
(norm): Identity()
(act): GELU()
(drop): Dropout(p=0.0, inplace=False)
(fc2): Conv2d(1536, 384, kernel_size=(1, 1), stride=(1, 1))
)
(ls2): Identity()
(drop_path2): Identity()
)
)
)
(3): MaxxVitStage(
(blocks): Sequential(
(0): TransformerBlock2d(
(shortcut): Downsample2d(
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
(expand): Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))
)
(norm1): Sequential(
(norm): LayerNorm2d((384,), eps=1e-06, elementwise_affine=True)
(down): Downsample2d(
(pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
(expand): Identity()
)
)
(attn): Attention2d(
(qkv): Conv2d(384, 2304, kernel_size=(1, 1), stride=(1, 1))
(rel_pos): RelPosBias()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1))
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
(mlp): ConvMlp(
(fc1): Conv2d(768, 3072, kernel_size=(1, 1), stride=(1, 1))
(norm): Identity()
(act): GELU()
(drop): Dropout(p=0.0, inplace=False)
(fc2): Conv2d(3072, 768, kernel_size=(1, 1), stride=(1, 1))
)
(ls2): Identity()
(drop_path2): Identity()
)
(1): TransformerBlock2d(
(shortcut): Identity()
(norm1): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention2d(
(qkv): Conv2d(768, 2304, kernel_size=(1, 1), stride=(1, 1))
(rel_pos): RelPosBias()
(attn_drop): Dropout(p=0.0, inplace=False)
(proj): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1))
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
(mlp): ConvMlp(
(fc1): Conv2d(768, 3072, kernel_size=(1, 1), stride=(1, 1))
(norm): Identity()
(act): GELU()
(drop): Dropout(p=0.0, inplace=False)
(fc2): Conv2d(3072, 768, kernel_size=(1, 1), stride=(1, 1))
)
(ls2): Identity()
(drop_path2): Identity()
)
)
)
)
(norm): Identity()
(head): NormMlpClassifierHead(
(global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Identity())
(norm): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
(flatten): Flatten(start_dim=1, end_dim=-1)
(pre_logits): Sequential(
(fc): Linear(in_features=768, out_features=768, bias=True)
(act): Tanh()
)
(drop): Dropout(p=0.0, inplace=False)
(fc): Linear(in_features=768, out_features=1000, bias=True)
)
)
Understanding the timm implementation
Stem
- Conv2d(3→64, stride=2) → BatchNorm → GELU → Conv2d(64→64)
Stage 0 + Stage 1: MbConvBlocks (CNN stages)
Here, timm uses depthwise convolutions like the ones used in MobileNetV3.
What was a depthwise conv? Each channel is convolved independently, for speed.
Stage 2 + Stage 3: TransformerBlock2d (Attention stages)
This is very much like our ViT implementation, with some tweaks that we should look into.
Timm uses a relative position bias, whereas we were using absolute learned embeddings.
No cls_token
This one is interesting: instead of using nn.Linear in the attention, they use a 1x1 Conv2d:
class AttentionHead(nn.Module):
    """Single scaled dot-product self-attention head.

    Queries, keys and values are separate bias-free linear projections of the
    same input.

    Args:
        dim: Size of the input's last dimension.
        head_dim: Size of each projection, and of the output's last dimension.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dim, head_dim, dropout=0.1):
        super().__init__()
        # 1 / sqrt(head_dim), applied to the scores before the softmax.
        self.scale = head_dim ** -0.5
        self.q = nn.Linear(dim, head_dim, bias=False)
        self.k = nn.Linear(dim, head_dim, bias=False)
        self.v = nn.Linear(dim, head_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over the second-to-last axis of ``x``.

        ``x`` has shape ``(..., seq, dim)``; the result has shape
        ``(..., seq, head_dim)``.
        """
        q, k, v = self.q(x), self.k(x), self.v(x)
        # Pairwise query/key scores, scaled, then normalised per query row.
        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        return out


# So instead of doing:
self.q = nn.Linear(dim, head_dim, bias=False)
self.k = nn.Linear(dim, head_dim, bias=False)
self.v = nn.Linear(dim, head_dim, bias=False)
they do:
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1) # Q, K, V all in one