Starting Masked Autoencoders (MAE)

Reading:
- MAE: https://arxiv.org/abs/2111.06377
- SimMIM: https://arxiv.org/abs/2111.09886

from vit import PatchEmbedding, TransformerEncoder

From the MAE paper:

Steps:

This is a recipe for pretraining. For fine-tuning, they keep only the encoder and add a classification head on top of it.

Decoder

  • How does a decoder that predicts pixels actually work?
import matplotlib.pyplot as plt                                                                                                           
import torch    
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100


import torch.nn.functional as F # This gives us the softmax()
import math                                                                                                                          
                                                                                                                                        
# Load one real CIFAR-100 image to see what the model will be trained on.
# Images are resized to 56x56 so they split evenly into an 8x8 grid of
# 7x7 patches (56 = 8 * 7).
transform = transforms.Compose([
    transforms.Resize((56, 56)),
    transforms.ToTensor(),  # values in [0, 1]
])

train_dataset_raw = CIFAR100(root='./data', train=True, download=True, transform=transform)
img, label = train_dataset_raw[0]  # [3, 56, 56]
class_names = train_dataset_raw.classes  # list of 100 class names

plt.imshow(img.permute(1, 2, 0))  # [3, 56, 56] -> [56, 56, 3]
plt.title(f"CIFAR-100 resized to 56×56 | Label: {label}")
plt.axis('off')
# Separate the numeric label from its class name; previously the two were
# printed back to back with no separator.
print(f"Label: {label} ({class_names[label]})")
plt.show()
             
# Inspect a single patch (the bottom-right one) to see exactly what the
# decoder has to predict for each masked position.
demo_patch_size = 7
patch = img[:, -demo_patch_size:, -demo_patch_size:]  # [3, 7, 7]
print("Patch shape:", patch.shape)
print("Pixel value range:", patch.min().item(), "to", patch.max().item())
patch_flat = patch.reshape(-1)  # [147] = 3 * 7 * 7
patch_dim = patch_flat.numel()  # number of values per patch, derived rather than hard-coded
print("Flattened patch (decoder target):", patch_flat.shape)
# so the decoder must predict these patch_dim values
patch_flat

random_prediction = torch.rand(patch_dim)  # untrained decoder output
# nearly trained decoder simulation: the target patch plus a little noise
trained_prediction = patch_flat + torch.randn(patch_dim) * 0.05

fig, axes = plt.subplots(1, 3, figsize=(10, 3))

axes[0].imshow(patch.permute(1, 2, 0))  # [H, W, C]
axes[0].set_title("Ground truth patch\n(decoder target)")

# Fold each flat prediction back into the patch's [C, H, W] shape for display.
axes[1].imshow(random_prediction.reshape(patch.shape).permute(1, 2, 0).clamp(0, 1))
axes[1].set_title("Random prediction\n(untrained decoder)")

axes[2].imshow(trained_prediction.reshape(patch.shape).permute(1, 2, 0).clamp(0, 1))
axes[2].set_title("Good prediction\n(trained decoder)")

for ax in axes:
    ax.axis('off')

plt.suptitle(
    f"Single {demo_patch_size}\u00d7{demo_patch_size} patch = {patch_dim} "
    "pixel values the decoder must predict"
)
plt.tight_layout()
print(f"Loss (random): {F.mse_loss(random_prediction, patch_flat):.4f}")
print(f"Loss (trained): {F.mse_loss(trained_prediction, patch_flat):.4f}")
plt.show()

OK, let's build it out.

from vit import TransformerBlock
import numpy as np

class PredictPixel(nn.Module):
    """Linear head mapping each decoder output vector to one patch's raw pixels.

    Args:
        decoder_hidden_dim: width of the decoder's per-position output vectors.
        patch_size: side length of a square patch; each patch has
            ``patch_size * patch_size * 3`` pixel values (3 colour channels).
        mask_ratio: not used by this head; kept so existing call sites that
            pass it keep working.
    """

    def __init__(self, decoder_hidden_dim, patch_size, mask_ratio=.75):
        super().__init__()
        self.decoder_hidden_dim = decoder_hidden_dim
        self.patch_size = patch_size
        self.linear = nn.Linear(decoder_hidden_dim, patch_size * patch_size * 3)

    def forward(self, x):
        """Predict pixel values for every patch position.

        The decoder gives one vector per patch position; each is projected to
        ``patch_size * patch_size * 3`` values (147 for 7x7 RGB patches).

        Args:
            x: tensor whose last dimension is ``decoder_hidden_dim``
                (e.g. ``[B, num_patches, decoder_hidden_dim]``).

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``patch_size * patch_size * 3``.
        """
        return self.linear(x)


class Mask(nn.Module):
    """Randomly split a sequence of patch tokens into hidden and visible sets.

    A single random permutation is drawn per call and shared by every sample
    in the batch, so the returned location tensors are 1-D.

    Args:
        mask_ratio: fraction of patches to hide from the encoder.
    """

    def __init__(self, mask_ratio=0.75):
        super().__init__()
        self.mask_ratio = mask_ratio

    def forward(self, x):
        """Mask ``int(num_patches * mask_ratio)`` randomly chosen patches.

        Args:
            x: patch tokens indexed as ``[B, num_patches, ...]``.

        Returns:
            un_masked_patches: ``x`` restricted to the visible positions.
            masked_patches: ``x`` restricted to the hidden positions.
            masked_locations: 1-D LongTensor of hidden patch indices.
            un_masked_locations: 1-D LongTensor of visible patch indices.
        """
        num_patches = x.shape[1]
        # Use the configured ratio rather than a hard-coded 75%.
        num_masked = int(num_patches * self.mask_ratio)
        perm = torch.randperm(num_patches)
        masked_locations = perm[:num_masked]
        un_masked_locations = perm[num_masked:]
        masked_patches = x[:, masked_locations]
        un_masked_patches = x[:, un_masked_locations]
        return un_masked_patches, masked_patches, masked_locations, un_masked_locations


class Decoder(nn.Module):
    """Rebuild the full patch sequence from the encoded visible patches.

    The encoder output is projected to the decoder width, and every hidden
    position is filled with a shared learnable mask token. Positional
    embeddings are then added, and the full sequence is run through ``depth``
    transformer blocks.
    """

    def __init__(self, decoder_hidden_dim, hidden_dim, num_heads, depth, num_patches):
        super().__init__()
        self.decoder_hidden_dim = decoder_hidden_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_patches = num_patches
        self.project_down = nn.Linear(hidden_dim, decoder_hidden_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, decoder_hidden_dim) * 0.02)
        self.mask_token = nn.Parameter(torch.randn(1, 1, decoder_hidden_dim) * 0.02)
        # ModuleList (not Sequential) so forward can loop over the blocks explicitly.
        self.layers = nn.ModuleList([
            TransformerBlock(decoder_hidden_dim, num_heads, mlp_dim=decoder_hidden_dim * 4)
            for _ in range(depth)
        ])

    def forward(self, x, masked_locations, un_masked_locations):
        """Decode a full sequence of ``num_patches`` tokens.

        Args:
            x: encoder output for the visible patches,
                ``[B, len(un_masked_locations), hidden_dim]``, in the same order
                as ``un_masked_locations``.
            masked_locations: 1-D indices of the hidden patches.
            un_masked_locations: 1-D indices of the visible patches.

        Returns:
            Tensor of shape ``[B, num_patches, decoder_hidden_dim]``.
        """
        B = x.shape[0]
        x = self.project_down(x)
        mask_tokens = self.mask_token.expand(B, len(masked_locations), self.decoder_hidden_dim)
        # Size the sequence from num_patches (not a hard-coded 64) so other
        # image/patch sizes work; match x's device and dtype.
        full_sequence = torch.zeros(
            B, self.num_patches, self.decoder_hidden_dim, device=x.device, dtype=x.dtype
        )
        full_sequence[:, un_masked_locations] = x
        full_sequence[:, masked_locations] = mask_tokens
        full_sequence = full_sequence + self.pos_embedding
        for layer in self.layers:
            full_sequence = layer(full_sequence)
        return full_sequence
  • One thing I didn't know: the difference between nn.Sequential and nn.ModuleList.

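  • nn.ModuleList is just a container — it registers the modules it holds (so their parameters belong to the parent model, unlike a plain Python list) but doesn’t have a forward method. You can’t call self.layers(x); you loop over it instead. nn.Sequential, by contrast, has a forward that calls each module in order.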

  • You can use nn.Sequential for the encoder/decoder, but ModuleList gives you more control (you can skip layers, add conditions, etc.), which is why most transformer implementations use it with a loop.

class MAE(nn.Module):
    """Masked autoencoder: encode only the visible patches, decode all of them.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        num_hiddens: encoder width.
        num_heads: attention heads (used by both encoder and decoder).
        mask_ratio: fraction of patches hidden from the encoder.
        decoder_hidden_dim: decoder width.
        depth: number of transformer blocks in the encoder and in the decoder.
        dropout: accepted for API compatibility but not used here.
            NOTE(review): confirm whether vit.TransformerEncoder accepts a
            dropout argument before wiring it through.
    """

    def __init__(self, img_size, patch_size, num_hiddens, num_heads, mask_ratio, decoder_hidden_dim, depth=6, dropout=0.1):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        self.depth = depth
        self.patch_size = patch_size
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        # Pass the configured ratio; Mask() alone would silently use its default.
        self.mask = Mask(mask_ratio)
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, num_hiddens) * 0.02)
        self.encoder = TransformerEncoder(num_hiddens, depth, num_heads, mlp_dim=num_hiddens * 4)
        self.decoder = Decoder(decoder_hidden_dim, num_hiddens, num_heads, depth, num_patches=self.num_patches)
        self.pixel_head = PredictPixel(decoder_hidden_dim, self.patch_size)

    def forward(self, x):
        """Run one masked-reconstruction pass.

        Args:
            x: image batch indexed as ``[B, 3, H, W]`` (3 channels is assumed by
                the target reshape below).

        Returns:
            (loss, predictions): the MSE between predicted and true pixels over
            the masked patches only, and the per-patch pixel predictions for
            every position.
        """
        B = x.shape[0]
        p = self.patch_size

        # Reconstruction targets: [B, 3, H, W] -> [B, 3, H/p, W/p, p, p]
        # -> [B, H/p, W/p, 3, p, p] -> [B, num_patches, 3*p*p] (row-major patch order).
        raw = x.unfold(2, p, p).unfold(3, p, p)
        raw = raw.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, p * p * 3)

        patches = self.patch_embedding(x)
        # presumably [B, num_hiddens, H/p, W/p] -> [B, num_patches, num_hiddens]
        patches = patches.flatten(2).transpose(1, 2)
        patches = patches + self.pos_embedding
        un_masked_patches, _, masked_locations, un_masked_locations = self.mask(patches)

        encoded = self.encoder(un_masked_patches)  # only visible patches are encoded
        decoded = self.decoder(encoded, masked_locations, un_masked_locations)
        predictions = self.pixel_head(decoded)

        # As in MAE, the loss is computed only on the patches the encoder never saw.
        loss = F.mse_loss(predictions[:, masked_locations], raw[:, masked_locations])
        return loss, predictions
# Smoke test on a random batch. Use a distinct name so the CIFAR `img`
# loaded earlier is not overwritten (re-running the patch cell would
# otherwise slice this random 4-D batch instead).
dummy_batch = torch.randn(2, 3, 56, 56)
mae_model = MAE(patch_size=7, num_hiddens=256, num_heads=8, mask_ratio=.75, decoder_hidden_dim=128, img_size=56)
loss, out = mae_model(dummy_batch)
loss
def visualize_mae(model, dataset, device, patch_size=7, img_size=56, index=0):
    """Plot one image, its visible patches, and the model's reconstruction.

    Runs the model's components step by step (patch embedding, masking,
    encoder, decoder, pixel head) on ``dataset[index]``. It then shows four
    panels: the original image, the visible patches only, the full
    reconstruction, and the original with only the masked patches replaced
    by predictions. A fresh random mask is drawn on every call.

    Args:
        model: object exposing ``patch_embedding``, ``pos_embedding``, ``mask``,
            ``encoder``, ``decoder`` and ``pixel_head`` as used below.
        dataset: indexable dataset returning ``(image_tensor, label)``.
        device: device to run the model on.
        patch_size: patch side length used by the model.
        img_size: side length of the square input image.
        index: which dataset sample to visualize (default: the first).

    Side effects: moves ``model`` to ``device`` and leaves it in eval mode.
    """
    model.to(device)
    model.eval()
    img, _ = dataset[index]
    x = img.unsqueeze(0).to(device)  # [1, 3, img_size, img_size]

    with torch.no_grad():
        # Get patches and mask
        patches = model.patch_embedding(x)
        patches = patches.flatten(2).transpose(1, 2)
        patches = patches + model.pos_embedding
        un_masked_patches, _, masked_locations, un_masked_locations = model.mask(patches)

        # Encode and decode
        encoded = model.encoder(un_masked_patches)
        decoded = model.decoder(encoded, masked_locations, un_masked_locations)
        predicted = model.pixel_head(decoded)  # [1, num_patches, 3 * patch_size**2]

    grid_size = img_size // patch_size
    num_patches = grid_size * grid_size
    # Plain Python ints index CPU tensors regardless of which device the mask
    # produced its index tensors on.
    masked_idx = masked_locations.tolist()
    visible_idx = un_masked_locations.tolist()

    # [num_patches, 3*p*p] -> [num_patches, 3, p, p]
    pred_patches = predicted[0].cpu().reshape(-1, 3, patch_size, patch_size).clamp(0, 1)

    def patch_slices(i):
        """Return the (row, col) pixel slices of patch ``i`` in row-major grid order."""
        row, col = divmod(i, grid_size)
        return (slice(row * patch_size, (row + 1) * patch_size),
                slice(col * patch_size, (col + 1) * patch_size))

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    # 1. Original image
    axes[0].imshow(img.permute(1, 2, 0))
    axes[0].set_title("Original")

    # 2. Masked image: only the visible patches, everything else black
    masked_img = torch.zeros_like(img)
    for i in visible_idx:
        rows, cols = patch_slices(i)
        masked_img[:, rows, cols] = img[:, rows, cols]
    axes[1].imshow(masked_img.permute(1, 2, 0))
    axes[1].set_title(f"Visible patches ({len(visible_idx)}/{num_patches})")

    # 3. Full reconstruction: every patch comes from the model
    recon_img = torch.zeros(3, img_size, img_size)
    for i in range(num_patches):
        rows, cols = patch_slices(i)
        recon_img[:, rows, cols] = pred_patches[i]
    axes[2].imshow(recon_img.permute(1, 2, 0).clamp(0, 1))
    axes[2].set_title("Full reconstruction")

    # 4. Visible patches from the original, masked patches from the model
    mixed_img = img.clone()
    for i in masked_idx:
        rows, cols = patch_slices(i)
        mixed_img[:, rows, cols] = pred_patches[i]
    axes[3].imshow(mixed_img.permute(1, 2, 0).clamp(0, 1))
    axes[3].set_title("Masked patches filled in")

    for ax in axes:
        ax.axis('off')
    plt.tight_layout()
    plt.show()
# Pick the fastest available backend: CUDA first, then Apple's MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using: {device}")

# Reconstruction from the untrained model, as a baseline.
visualize_mae(mae_model, train_dataset_raw, device)

Training loop

from torch.utils.data import DataLoader
# Training uses exactly the same Resize((56, 56)) + ToTensor transform as
# `train_dataset_raw`. Reuse that dataset instead of constructing a second,
# identical CIFAR100 instance.
train_dataset = train_dataset_raw
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

model = MAE(
    img_size=56, patch_size=7, num_hiddens=256, num_heads=8,
    mask_ratio=0.75, decoder_hidden_dim=128, depth=2, dropout=0.1,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
num_epochs = 50
for epoch in range(num_epochs):
    model.train()  # re-enable training mode each epoch
    total_loss = 0.0

    for images, _ in train_loader:  # labels are not used for self-supervised pretraining
        images = images.to(device)
        loss, _ = model(images)  # predictions are not needed for the update

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f}")

visualize_mae(model, train_dataset_raw, device)

Now let's explore the MAE implementation in Hugging Face `transformers`.

# !pip install transformers
from transformers import ViTMAEModel, ViTMAEForPreTraining

# Load the pretrained ViT-MAE base checkpoint and print its module tree to
# compare with the from-scratch MAE above.
# NOTE: this rebinds `model`, so the MAE trained above is no longer
# reachable under that name.
checkpoint_name = "facebook/vit-mae-base"
model = ViTMAEForPreTraining.from_pretrained(checkpoint_name)
print(model)