import torch
import torch.nn as nn
# Toy input clip: batch of 1, RGB, 16 frames at 224x224 resolution.
# Layout is (B, C, T, H, W), the channels-first 5-D layout nn.Conv3d takes.
video = torch.randn(1, 3, 16, 224, 224)

# Tubelet embedding: each kernel covers 2 consecutive frames and a 16x16
# spatial window. Using the same tuple for stride and kernel_size makes the
# tubelets non-overlapping along time, height and width.
tubelet = (2, 16, 16)  # (temporal, height, width)
patch_embed = nn.Conv3d(
    in_channels=3,
    out_channels=768,  # embedding dimension D
    kernel_size=tubelet,
    stride=tubelet,
)
# For comparison, a per-frame ViT would embed each frame on its own with
# nn.Conv2d(in_channels=3, out_channels=768, kernel_size=(16, 16),
# stride=(16, 16)), i.e. no grouping of frames along the time axis.

patches = patch_embed(video)
print(f"Input video: {video.shape}")  # (1, 3, 16, 224, 224)
print(f"After Conv3D: {patches.shape}")  # (1, 768, 8, 14, 14)

# Collapse the (T, H, W) grid into one sequence axis, then move the
# embedding axis last to obtain a (B, N, D) token sequence.
B, D, T, H, W = patches.shape
tokens = patches.flatten(start_dim=2).transpose(1, 2)
print(f"Token sequence: {tokens.shape}")  # (1, 1568, 768)
# N = T * H * W = 8 (time) * 14 (height) * 14 (width) = 1568 tokens