from vit import PatchEmbedding, TransformerEncoder

# Starting Masked Auto Encoders
Now we try to understand self-supervised learning through a simple masked image model:
- Take an image and mask out a large percentage of its patches (75% for MAE)
- Feed only the visible patches into the encoder
- Train a decoder to reconstruct the masked patches
- Compute the reconstruction loss on the masked patches only
Reading:
- MAE: https://arxiv.org/abs/2111.06377
- SimMIM: https://arxiv.org/abs/2111.09886
From the MAE paper:
Main idea: mask a high percentage of the patches in an image and train a model to reconstruct the missing patches in pixel space.
If it can do that well, you have a scalable vision learner.
Steps:
- We use the same patching approach as our vanilla ViT, projecting each patch into an embedding of size hidden_dim.
- Mask patches. We randomly choose 75% of the patches and remove them entirely from the sequence, so the encoder only sees the remaining 25% of the patches (this makes training much faster).
- Run the visible patches through a transformer encoder. This can be the same as our vanilla ViT encoder.
- Take the encoder outputs and put them back in their original positions. The masked positions now hold a shared, learnable mask token. We then add positional embeddings to the full grid of patches.
- Decoder: run the full grid of masked and unmasked tokens through a decoder. This is a few transformer blocks, and its task is to predict the original pixels of the masked patches.
- Calculate the loss: the MSE between the decoder's outputs and the actual pixel values of the masked patches.
This is a recipe for pretraining. For fine-tuning, they keep only the encoder and add a classification head.
Decoder
- How does a decoder that predicts pixels actually work?
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
import torch.nn.functional as F # This gives us the softmax()
import math
# Take a real image from the data: CIFAR-100 resized to 56x56 so that 7x7 patches tile it exactly.
transform = transforms.Compose([
    transforms.Resize((56, 56)),
    transforms.ToTensor(),  # values in [0, 1]
])
train_dataset_raw = CIFAR100(root='./data', train=True, download=True, transform=transform)
img, label = train_dataset_raw[0]  # [3, 56, 56]
class_names = train_dataset_raw.classes  # list of 100 class names

# This call was previously swallowed by the comment on the line above; it must be its own statement.
plt.imshow(img.permute(1, 2, 0))  # [3, 56, 56] -> [56, 56, 3]
plt.title(f"CIFAR-100 resized to 56×56 | Label: {label}")
plt.axis('off')
print(f"Label: {label} → {class_names[label]}")
plt.show()
# Take the bottom-right 7x7 patch of the sample image: this is what the decoder must predict.
patch = img[:, 49:56, 49:56]  # [3, 7, 7]
print("Patch shape:", patch.shape)
print("Pixel value range:", patch.min().item(), "to", patch.max().item())

patch_flat = patch.reshape(-1)  # [147]
print("Flattened patch (decoder target):", patch_flat.shape)
# So the decoder must predict these 147 values
print(patch_flat)

# Untrained decoder output: uniform noise with the same shape as the target.
random_prediction = torch.rand_like(patch_flat)
# Nearly trained decoder simulation: just the patch with a little Gaussian noise.
trained_prediction = patch_flat + torch.randn_like(patch_flat) * 0.05

fig, axes = plt.subplots(1, 3, figsize=(10, 3))
axes[0].imshow(patch.permute(1, 2, 0))  # [H, W, C]
axes[0].set_title("Ground truth patch\n(decoder target)")
# Predictions are flat vectors; reshape back to the patch layout [3, 7, 7] for display.
axes[1].imshow(random_prediction.reshape(patch.shape).permute(1, 2, 0).clamp(0, 1))
axes[1].set_title("Random prediction\n(untrained decoder)")
axes[2].imshow(trained_prediction.reshape(patch.shape).permute(1, 2, 0).clamp(0, 1))
axes[2].set_title("Good prediction\n(trained decoder)")
for ax in axes:
    ax.axis('off')
plt.suptitle(f"Single 7\u00d77 patch = {patch_flat.numel()} pixel values the decoder must predict")
plt.tight_layout()
print(f"Loss (random): {F.mse_loss(random_prediction, patch_flat):.4f}")
print(f"Loss (trained): {F.mse_loss(trained_prediction, patch_flat):.4f}")
plt.show()

# Ok, let's build it out
from vit import TransformerBlock
import numpy as np
class PredictPixel(nn.Module):
    """Linear head that maps each decoder token to the flattened pixels of one patch.

    Args:
        decoder_hidden_dim: width of the decoder tokens passed to ``forward``.
        patch_size: side length of a square patch; the output width is
            ``patch_size * patch_size * 3``.
        mask_ratio: unused; kept so existing call sites stay valid.
    """

    def __init__(self, decoder_hidden_dim, patch_size, mask_ratio=.75):
        super().__init__()
        self.decoder_hidden_dim = decoder_hidden_dim
        self.patch_size = patch_size
        self.linear = nn.Linear(decoder_hidden_dim, patch_size * patch_size * 3)

    def forward(self, x):
        """
        Decoder gives us one vector per patch position.
        Each patch is patch_size x patch_size x 3 (7x7x3 = 147 pixel values for
        patch_size=7), so the last dimension goes decoder_hidden_dim -> patch_size**2 * 3.
        """
        return self.linear(x)
class Mask(nn.Module):
    """Randomly split a patch sequence into masked and visible subsets.

    One random permutation is drawn per call and applied to every item in the
    batch, so the returned location tensors are 1-D and shared across the batch.

    Args:
        mask_ratio: fraction of patches to mask, in [0, 1). Values below 1
            guarantee at least one visible patch for the encoder.
    """

    def __init__(self, mask_ratio=0.75):
        super().__init__()
        if not 0 <= mask_ratio < 1:
            raise ValueError(f"mask_ratio must be in [0, 1), got {mask_ratio}")
        self.mask_ratio = mask_ratio

    def forward(self, x):
        """Split ``x`` along dim 1 (the patch dimension).

        Returns:
            (un_masked_patches, masked_patches, masked_locations, un_masked_locations)
            where the location tensors index dim 1 of ``x`` and live on x's device.
        """
        num_patches = x.shape[1]
        # int() floors, so num_masked <= num_patches - 1 whenever mask_ratio < 1.
        num_masked = int(num_patches * self.mask_ratio)
        # Draw on CPU then move, so this works on any device backend.
        shuffled = torch.randperm(num_patches).to(x.device)
        masked_locations = shuffled[:num_masked]
        un_masked_locations = shuffled[num_masked:]
        masked_patches = x[:, masked_locations]
        un_masked_patches = x[:, un_masked_locations]
        return un_masked_patches, masked_patches, masked_locations, un_masked_locations
class Decoder(nn.Module):
    """
    Take the encoded patches that have not been masked and re-construct the
    full patch grid: visible tokens go back to their original positions, every
    masked position gets a shared learnable mask token, positional embeddings
    are added, and the grid is run through ``depth`` transformer blocks.

    Args:
        decoder_hidden_dim: width of the decoder tokens.
        hidden_dim: width of the encoder output fed to ``forward``.
        num_heads: attention heads per decoder block.
        depth: number of decoder transformer blocks.
        num_patches: total number of patch positions in the grid.
    """

    def __init__(self, decoder_hidden_dim, hidden_dim, num_heads, depth, num_patches):
        super().__init__()
        self.decoder_hidden_dim = decoder_hidden_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_patches = num_patches
        self.project_down = nn.Linear(hidden_dim, decoder_hidden_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, decoder_hidden_dim) * 0.02)
        self.mask_token = nn.Parameter(torch.randn(1, 1, decoder_hidden_dim) * 0.02)
        # ModuleList (not Sequential) so forward can loop over the blocks explicitly.
        self.layers = nn.ModuleList([
            TransformerBlock(decoder_hidden_dim, num_heads, mlp_dim=decoder_hidden_dim * 4)
            for _ in range(depth)
        ])

    def forward(self, x, masked_locations, un_masked_locations):
        """Rebuild the full grid from encoded visible tokens and decode it.

        Args:
            x: encoder output for the visible patches, [B, num_visible, hidden_dim].
            masked_locations: 1-D indices of masked grid positions.
            un_masked_locations: 1-D indices of visible grid positions, in the
                same order as the tokens in ``x``.

        Returns:
            Decoded tokens for every position, [B, num_patches, decoder_hidden_dim].
        """
        B = x.shape[0]
        x = self.project_down(x)
        mask_tokens = self.mask_token.expand(B, len(masked_locations), -1)
        # Size the grid from num_patches instead of a hard-coded 64 so other image/patch sizes work.
        full_sequence = x.new_zeros(B, self.num_patches, self.decoder_hidden_dim)
        full_sequence[:, un_masked_locations] = x
        full_sequence[:, masked_locations] = mask_tokens.to(full_sequence.dtype)
        full_sequence = full_sequence + self.pos_embedding
        for layer in self.layers:
            full_sequence = layer(full_sequence)
        return full_sequence

# One thing I didn't know the difference between:
`nn.Sequential` and `nn.ModuleList`: `nn.ModuleList` is just a container — it stores (and registers) modules but doesn't have a `forward` method, so you can't call `self.layers(x)`.
You can use `nn.Sequential` for the encoder/decoder, but `nn.ModuleList` gives you more control (you can skip layers, add conditions, etc.), which is why most transformer implementations use it with a loop.
class MAE(nn.Module):
    """Masked autoencoder: an encoder over visible patches plus a light pixel decoder.

    Args:
        img_size: side length of the square input image, in pixels.
        patch_size: side length of each square patch, in pixels.
        num_hiddens: encoder embedding width.
        num_heads: attention heads, shared by encoder and decoder.
        mask_ratio: fraction of patches hidden from the encoder, in [0, 1).
        decoder_hidden_dim: decoder embedding width.
        depth: number of transformer blocks in both encoder and decoder.
        dropout: accepted for interface compatibility; not currently used.
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, mask_ratio,
                 decoder_hidden_dim, depth=6, dropout=0.1):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        self.depth = depth
        self.patch_size = patch_size
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        # Pass the configured ratio through; a bare Mask() would silently ignore it.
        self.mask = Mask(mask_ratio)
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, num_hiddens) * 0.02)
        self.encoder = TransformerEncoder(num_hiddens, depth, num_heads, mlp_dim=num_hiddens * 4)
        self.decoder = Decoder(decoder_hidden_dim, num_hiddens, num_heads, depth,
                               num_patches=self.num_patches)
        self.pixel_head = PredictPixel(decoder_hidden_dim, self.patch_size)

    def forward(self, x):
        """Mask, encode, decode and score a batch of images.

        Args:
            x: images, [B, 3, img_size, img_size].

        Returns:
            (loss, predictions): MSE over the masked patches only, and the pixel
            predictions for every patch, [B, num_patches, patch_size**2 * 3].
        """
        B = x.shape[0]
        p = self.patch_size
        # Reconstruction targets, laid out per patch as (C, p, p) flattened.
        raw = x.unfold(2, p, p).unfold(3, p, p)  # [B, C, H/p, W/p, p, p]
        raw = raw.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, p * p * 3)

        # NOTE(review): assumes PatchEmbedding returns [B, num_hiddens, H/p, W/p].
        patches = self.patch_embedding(x)
        patches = patches.flatten(2).transpose(1, 2)  # [B, num_patches, num_hiddens]
        patches = patches + self.pos_embedding

        un_masked_patches, _, masked_locations, un_masked_locations = self.mask(patches)
        encoded = self.encoder(un_masked_patches)  # encoder only ever sees visible patches
        decoded = self.decoder(encoded, masked_locations, un_masked_locations)
        predictions = self.pixel_head(decoded)

        loss = F.mse_loss(predictions[:, masked_locations], raw[:, masked_locations])
        return loss, predictions

# Quick smoke test on a random batch: the forward pass should return a scalar loss.
dummy_batch = torch.randn(2, 3, 56, 56)
mae_model = MAE(patch_size=7, num_hiddens=256, num_heads=8, mask_ratio=.75,
                decoder_hidden_dim=128, img_size=56)
loss, out = mae_model(dummy_batch)
print(loss)


def visualize_mae(model, dataset, device, patch_size=7, img_size=56):
    """Plot the original, visible patches, full reconstruction and masked-filled image for dataset[0].

    Puts ``model`` on ``device`` and into eval mode.
    """
    model.to(device)
    model.eval()
    img, _ = dataset[0]
    x = img.unsqueeze(0).to(device)  # [1, 3, img_size, img_size]
    grid_size = img_size // patch_size
    num_patches = grid_size * grid_size

    with torch.no_grad():
        patches = model.patch_embedding(x).flatten(2).transpose(1, 2)
        patches = patches + model.pos_embedding
        un_masked_patches, _, masked_locations, un_masked_locations = model.mask(patches)
        encoded = model.encoder(un_masked_patches)
        decoded = model.decoder(encoded, masked_locations, un_masked_locations)
        predicted = model.pixel_head(decoded)  # [1, num_patches, 3 * patch_size**2]

    # The location tensors live on `device` while the images below are on CPU;
    # use plain ints so indexing works on any device.
    masked_idx = masked_locations.cpu().tolist()
    visible_idx = un_masked_locations.cpu().tolist()
    pred_patches = predicted[0].cpu().reshape(-1, 3, patch_size, patch_size).clamp(0, 1)

    def cell(i):
        """Return the (row, col) pixel slices covered by patch i of the grid."""
        row, col = divmod(i, grid_size)
        return (slice(row * patch_size, (row + 1) * patch_size),
                slice(col * patch_size, (col + 1) * patch_size))

    # 2. Masked image: only the visible patches are copied in.
    masked_img = torch.zeros_like(img)
    for i in visible_idx:
        rs, cs = cell(i)
        masked_img[:, rs, cs] = img[:, rs, cs]

    # 3. Full reconstruction from the predictions alone.
    recon_img = torch.zeros(3, img_size, img_size)
    for i in range(num_patches):
        rs, cs = cell(i)
        recon_img[:, rs, cs] = pred_patches[i]

    # 4. Visible patches from the original, masked patches from the predictions.
    mixed_img = img.clone()
    for i in masked_idx:
        rs, cs = cell(i)
        mixed_img[:, rs, cs] = pred_patches[i]

    _, axes = plt.subplots(1, 4, figsize=(20, 5))
    axes[0].imshow(img.permute(1, 2, 0))
    axes[0].set_title("Original")
    axes[1].imshow(masked_img.permute(1, 2, 0))
    axes[1].set_title(f"Visible patches ({len(visible_idx)}/{num_patches})")
    axes[2].imshow(recon_img.permute(1, 2, 0).clamp(0, 1))
    axes[2].set_title("Full reconstruction")
    axes[3].imshow(mixed_img.permute(1, 2, 0).clamp(0, 1))
    axes[3].set_title("Masked patches filled in")
    for ax in axes:
        ax.axis('off')
    plt.tight_layout()
    plt.show()


device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using: {device}")
visualize_mae(mae_model, train_dataset_raw, device)  # untrained model

# Training loop
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((56, 56)),
    transforms.ToTensor(),
])
train_dataset = CIFAR100(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

model = MAE(
    img_size=56, patch_size=7, num_hiddens=256, num_heads=8,
    mask_ratio=0.75, decoder_hidden_dim=128, depth=2, dropout=0.1,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for epoch in range(50):
    # visualize_mae() switches the model to eval mode, so re-enable training mode every epoch.
    model.train()
    total_loss = 0.0
    for images, _ in train_loader:
        images = images.to(device)
        loss, _ = model(images)  # per-patch predictions are not needed here
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f}")

visualize_mae(model, train_dataset_raw, device)

# Now let's explore a transformers implementation from HuggingFace
# Needs the HuggingFace `transformers` package — uncomment to install:
# !pip install transformers
from transformers import ViTMAEForPreTraining, ViTMAEModel

# Load the pretrained "facebook/vit-mae-base" checkpoint into the pre-training
# (encoder + reconstruction decoder) model class and print its module tree.
model = ViTMAEForPreTraining.from_pretrained(
    "facebook/vit-mae-base"
)
print(model)