V-JEPA 2.1

PatchEmbeddings for Video (?)

  • So far, we did a 2D patch embedding; now we need to build the intuition for adding a temporal dimension.

Before we had:

# 2D: Image → patches
# Input:  (B, 3, 224, 224)  — one frame
# Conv2d: kernel=16x16, stride=16
# Output: (B, embed_dim, 14, 14) → flatten → (B, 196, embed_dim)
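
The 2D case summarized in the comments above can be run directly. This is a minimal sketch, assuming an embedding dimension of 768 (the value used for the video example in these notes); the variable names are only illustrative.

```python
import torch
import torch.nn as nn

# One RGB frame: (B, C, H, W)
image = torch.randn(1, 3, 224, 224)

# 2D patch embedding: non-overlapping 16x16 patches
patch_embed_2d = nn.Conv2d(
    in_channels=3,
    out_channels=768,   # embed_dim (assumed)
    kernel_size=16,
    stride=16,
)

feature_map = patch_embed_2d(image)                    # (1, 768, 14, 14)
image_tokens = feature_map.flatten(2).transpose(1, 2)  # (1, 196, 768)
print(feature_map.shape, image_tokens.shape)
```
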
  • Now each token is a 3D cube of the video instead of a 2D patch of an image.
  • If we have 16 frames as input and use a temporal kernel of 2 (each patch spans 2 frames), we get 16 / 2 = 8 timesteps.
  • The total is now 8 (temporal chunks) * 14 (patches down) * 14 (patches across) = 1,568 tokens, i.e. unique 3D patches.
  • Each token/patch covers 2 consecutive frames, 16x16 pixels (a 224x224 frame gives 224 / 16 = 14 patches per side), and 3 color channels.
  • The Transformer then uses self-attention to let all 1,568 tokens attend to each other.
  • We go from 196 patches/tokens to 1,568!
import torch
import torch.nn as nn

# A tiny video: 16 frames, 224x224, RGB
video = torch.randn(1, 3, 16, 224, 224)  # (B, C, T, H, W)

# 3D patch embedding
# kernel = (2, 16, 16) → 2 frames in time, 16x16 in space
# stride = (2, 16, 16) → non-overlapping in all dimensions
patch_embed = nn.Conv3d(
    in_channels=3,
    out_channels=768,         # embed_dim
    kernel_size=(2, 16, 16),  # (temporal, height, width)
    stride=(2, 16, 16),       # no overlap
)
# For comparison, the per-frame (image) ViT patch embedding:
# nn.Conv2d(
#     in_channels=3,
#     out_channels=768,
#     kernel_size=(16, 16),
#     stride=(16, 16),
# )

patches = patch_embed(video)
print(f"Input video:  {video.shape}")       # (1, 3, 16, 224, 224)
print(f"After Conv3D: {patches.shape}")      # (1, 768, 8, 14, 14)

# Flatten spatial+temporal into a sequence
B, D, T, H, W = patches.shape
tokens = patches.reshape(B, D, -1).transpose(1, 2)
print(f"Token sequence: {tokens.shape}")     # (1, 1568, 768)
# 1568 = 8 (time) × 14 (height) × 14 (width) patches
Input video:  torch.Size([1, 3, 16, 224, 224])
After Conv3D: torch.Size([1, 768, 8, 14, 14])
Token sequence: torch.Size([1, 1568, 768])
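
The jump from 196 to 1,568 tokens matters mostly because self-attention compares every token with every other token. The arithmetic can be sketched as follows; `num_tokens` is a hypothetical helper, not part of any library, and it assumes non-overlapping patches whose sizes divide the input evenly.

```python
def num_tokens(frames, height, width, tubelet=2, patch=16):
    """Count non-overlapping patches for a (tubelet, patch, patch) kernel."""
    return (frames // tubelet) * (height // patch) * (width // patch)

image_tokens = num_tokens(1, 224, 224, tubelet=1)  # 14 * 14 = 196
video_tokens = num_tokens(16, 224, 224)            # 8 * 14 * 14 = 1568

# The attention score matrix has N * N entries per head.
image_pairs = image_tokens ** 2
video_pairs = video_tokens ** 2

print(video_tokens // image_tokens)  # 8
print(video_pairs // image_pairs)    # 64
```

So 8x more tokens means roughly 64x more pairwise attention scores per head, which is one reason dropping most tokens before the encoder is attractive.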

Life of an input to understand architecture

  1. Tokenization specific to modality.
    1. The input can be an image or a video clip. It goes into a multi-modal tokenizer.
    2. Videos are processed with a 3D convolution, as in the sample above (2x16x16 kernel: 2 frames, 16x16 pixels).
    3. Images are processed with the traditional ViT patch embedding, a 2D convolution (16x16).
  2. Add positional and modality embeddings
    1. Once we have our patches (2D or 3D), the model applies rotary positional embeddings (RoPE), which tell the model where each patch sits in time and space.
    2. It also adds a learnable modality embedding: the shared encoder needs to know whether the input is an image or a video.
  3. Masking
    1. The model applies mask corruption: it randomly drops a large portion of the tokens, leaving only a small set of context tokens.
  4. X-Encoder
    1. The visible (non-masked) tokens are sent through the X-Encoder.
    2. The encoder uses deep self-supervision, so representations are taken from multiple intermediate encoder levels, not just the last one.
  5. Multi-level fusion
    1. Representations from different levels of the encoder (remember, these are visible patches only) are concatenated along the channel axis.
    2. A small MLP fuses them together, reducing dimensionality before moving on.
  6. Concat with mask tokens
    1. Before we send anything to the predictor, the system re-assembles the sequence.
    2. The processed context tokens are concatenated with a set of learnable mask tokens.
    3. These mask tokens act as placeholders that carry the spatio-temporal information of the patches that were originally dropped.
    4. Why? For a clip of a person jumping, the context tokens might only include the head and feet. To infer where the arms are or how they are moving, the model needs to be told where the missing information is located in the original scene.
    5. Instead of just leaving empty spaces, the model uses these mask tokens:
      1. These placeholders are essentially blank vectors that represent missing data.
      2. They are learnable: they are parameters that the model refines during training to represent the concept of a missing patch.
      3. They carry useful spatio-temporal information. Each mask token has its own positional information specifying its exact coordinate in the 3D grid of the input.
    6. Now we have a single, complete sequence that matches the grid structure of the original unmasked input. Even though the predictor doesn't receive the content of the masked patches, it knows what it saw (the unmasked tokens) and what needs to be filled in (and where those pieces belong).
    7. The predictor can now produce an output token for every input position, whether masked or not.
  7. Multi-Level Predictor
    1. Processes the whole combined sequence.
    2. Tasked with predicting what the representations of the masked patches should be, based on the context.
    3. Produces four separate outputs, one for each of the 4 supervised encoder levels (this is the deep self-supervision).
  8. Target Generation and Loss
    1. The clean (unmasked) input is sent through a Y-Encoder.
    2. The Y-Encoder is a teacher model that is updated with an exponential moving average (EMA) of the X-Encoder's weights.
    3. This creates the ‘ground truth’ targets for both masked and unmasked tokens.
    4. The dense predictive loss is a sum of:
      1. Prediction loss: applied to the masked tokens so the model learns to ‘fill in the blanks’.
      2. Context loss: applied to the visible, unmasked tokens. This is a weighted loss that keeps visible tokens from acting as generic global aggregators and instead forces them to retain high-quality local information.
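
Steps 3 through 8 above can be tied together in a toy sketch. This is not the actual V-JEPA implementation, and everything below is an assumption made for illustration: the tiny sizes, a learned positional table standing in for RoPE, a single Transformer block with linear heads standing in for the multi-level predictor, the L1 loss, the context-loss weight, and the EMA momentum. The modality embedding is omitted.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, N, D, LEVELS, KEEP = 2, 196, 32, 4, 40  # toy sizes, all hypothetical

def make_encoder():
    # One small Transformer block per supervised level (stand-in for a ViT).
    return nn.ModuleList([
        nn.TransformerEncoderLayer(D, nhead=4, dim_feedforward=64, batch_first=True)
        for _ in range(LEVELS)
    ])

def run_levels(blocks, x):
    # Return the output of every block so intermediate levels can be supervised.
    outs = []
    for blk in blocks:
        x = blk(x)
        outs.append(x)
    return outs

def take(x, idx):
    # Gather tokens at the given indices: (B, N, C) -> (B, K, C)
    return torch.gather(x, 1, idx.unsqueeze(-1).expand(-1, -1, x.size(-1)))

x_encoder = make_encoder()
y_encoder = copy.deepcopy(x_encoder).requires_grad_(False).eval()  # EMA teacher
fuse = nn.Sequential(nn.Linear(LEVELS * D, D), nn.GELU(), nn.Linear(D, D))
predictor = nn.TransformerEncoderLayer(D, nhead=4, dim_feedforward=64, batch_first=True)
heads = nn.ModuleList([nn.Linear(D, D) for _ in range(LEVELS)])  # one output per level
mask_token = nn.Parameter(torch.zeros(1, 1, D))
pos = nn.Parameter(torch.randn(1, N, D) * 0.02)  # learned positions instead of RoPE

tokens = torch.randn(B, N, D)  # stand-in for patch embeddings
pos_b = pos.expand(B, -1, -1)

# Step 3: keep a small random subset of tokens as context, drop the rest.
perm = torch.rand(B, N).argsort(dim=1)
ctx_idx, tgt_idx = perm[:, :KEEP], perm[:, KEEP:]

# Steps 4-5: encode visible tokens only, then fuse levels along the channel axis.
ctx_levels = run_levels(x_encoder, take(tokens + pos, ctx_idx))
ctx_fused = fuse(torch.cat(ctx_levels, dim=-1))  # (B, KEEP, D)

# Step 6: append one mask token per dropped patch, each tagged with its position.
masks = mask_token.expand(B, N - KEEP, -1) + take(pos_b, tgt_idx)
sequence = torch.cat([ctx_fused + take(pos_b, ctx_idx), masks], dim=1)  # (B, N, D)

# Step 7: the predictor sees the full sequence and emits one prediction per level.
hidden = predictor(sequence)
preds = [head(hidden) for head in heads]

# Step 8: the teacher encodes the clean, unmasked input to produce targets.
with torch.no_grad():
    order = torch.cat([ctx_idx, tgt_idx], dim=1)  # match the predictor's token order
    targets = [take(t, order) for t in run_levels(y_encoder, tokens + pos)]

ctx_weight = 0.5  # hypothetical weight for the context loss
pred_loss = sum(F.l1_loss(p[:, KEEP:], t[:, KEEP:]) for p, t in zip(preds, targets))
ctx_loss = sum(F.l1_loss(p[:, :KEEP], t[:, :KEEP]) for p, t in zip(preds, targets))
loss = pred_loss + ctx_weight * ctx_loss
loss.backward()

# EMA update: the teacher slowly tracks the student and receives no gradients.
momentum = 0.99  # hypothetical value
with torch.no_grad():
    for p_y, p_x in zip(y_encoder.parameters(), x_encoder.parameters()):
        p_y.mul_(momentum).add_(p_x, alpha=1 - momentum)

print(sequence.shape, len(preds))
```

Note that the mask tokens all share one learned vector and differ only in the positional term added to them, which is how the predictor knows where each missing patch belongs.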