Molmo2 Fine-tuning with QLoRA (bf16)

This notebook follows a proven approach for fine-tuning Molmo2 on our SLEAP mouse VQA dataset.

Key design choices:

- QLoRA: 4-bit quantized base model + LoRA adapters in bf16
- prepare_model_for_kbit_training: handles gradient flow, LayerNorm casting, and checkpointing
- Prompt masking: only train on the assistant’s response, not the user’s question or image tokens
- use_reentrant=False: fixes gradient checkpointing + LoRA compatibility
- Skip vision backbone quantization: prevents uint8 crashes in LayerNorm

Run in Google Colab with a GPU runtime.

Cell 1: Install dependencies

# Install the training stack. transformers is pinned to an exact version; the
# other packages only carry lower bounds, and --upgrade lets pip move packages
# that are already installed (torch/torchvision included) to newer releases.
%pip install --upgrade \
  torch \
  torchvision \
  "transformers==4.57.1" \
  "datasets>=3.0.1" \
  "accelerate>=0.34.2" \
  "bitsandbytes>=0.44.0" \
  "peft>=0.13.0" \
  pillow tensorboard

# Fix CUDA library paths (Colab-specific)
# The paths below are hard-coded to Python 3.12 and CUDA 13 (cu13). stderr is
# discarded and `|| true` keeps the cell from failing, so a mismatch with the
# runtime's actual layout will not be reported here.
!ln -sf /usr/local/lib/python3.12/dist-packages/nvidia/cu13/lib/libnvrtc-builtins.so.13.0 /usr/lib/libnvrtc-builtins.so.13.0 2>/dev/null || true

Cell 2: Authenticate and download dataset

import os
import zipfile

from google.colab import userdata
from huggingface_hub import login, snapshot_download

# Authenticate with the token stored under the Colab secret named HF_TOKEN.
login(token=userdata.get('HF_TOKEN'), add_to_git_credential=True)

# Mirror the dataset repository into ./data
snapshot_download("jpoberhauser/sleap-mice-vqa", repo_type="dataset", local_dir="data")

frames_zip = "data/frames.zip"
first_frame = "data/frames/frame_0000.png"

# Three distinct situations: frames already extracted, archive present but not
# yet extracted, or neither present. The last case used to be reported as
# "already available", which hid a failed or incomplete download until the
# first image was opened in a later cell.
if os.path.exists(first_frame):
    print("Frames already available.")
elif os.path.exists(frames_zip):
    # stdlib extraction: no dependency on an `unzip` binary and no
    # interactive overwrite prompt if some files already exist.
    with zipfile.ZipFile(frames_zip) as archive:
        archive.extractall("data/")
    print("Frames unzipped.")
else:
    print(
        f"WARNING: found neither {first_frame} nor {frames_zip} — "
        "check that the dataset download completed."
    )

Cell 3: Load model with QLoRA (4-bit quantized, bf16 compute)

The base model weights are stored in 4-bit (NF4 quantization) to save VRAM. LoRA adapters and compute happen in bf16. The vision backbone is excluded from quantization because its LayerNorm layers would crash in uint8.

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

MODEL_ID = "allenai/Molmo2-4B"

# 4-bit NF4 storage with double quantization; compute runs in bf16.
quantization_options = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
    "bnb_4bit_use_double_quant": True,
    # don't quantize the vision encoder
    "llm_int8_skip_modules": ["vision_backbone"],
}
bnb_config = BitsAndBytesConfig(**quantization_options)

# Keyword arguments shared by the model load; the processor only needs
# trust_remote_code.
model_load_options = {
    "trust_remote_code": True,
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "quantization_config": bnb_config,
}
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_load_options)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

print(f"Model loaded. VRAM: {torch.cuda.memory_allocated() / 1e9:.1f} GB")

Cell 4: Test base model inference

from PIL import Image

# Smoke-test the un-tuned model on the first frame.
img = Image.open("data/frames/frame_0000.png").convert("RGB")

user_turn = {
    "role": "user",
    "content": [
        {"type": "image", "image": img},
        {
            "type": "text",
            "text": "How many mice are in this image? Describe their positions and postures.",
        },
    ],
}
messages = [user_turn]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)
# Move every tensor the processor produced onto the model's device.
inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=256)

# Keep only what comes after the prompt positions of the first sequence.
prompt_length = inputs["input_ids"].size(1)
generated_tokens = output[0, prompt_length:]
print("Base model response:")
print(processor.tokenizer.decode(generated_tokens, skip_special_tokens=True))

Cell 5: Attach LoRA adapters

Key differences from our previous attempt:

- prepare_model_for_kbit_training replaces our manual gradient hook. It handles:
  - Casting LayerNorm to float32 for stability
  - Enabling gradient flow through frozen/quantized layers
  - Setting up gradient checkpointing with use_reentrant=False
- We target 4 layers: att_proj, attn_out, ff_proj, and ff_out
- Only the ViT is frozen (not the full vision backbone) — connectors stay trainable

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

# Freeze only the ViT, not the full vision backbone (connectors stay trainable)
# NOTE(review): per the PEFT documentation, prepare_model_for_kbit_training
# and get_peft_model (both called below) freeze the base model's parameters.
# If so, this loop is redundant and the connectors end up frozen as well —
# confirm with the print_trainable_parameters() output. Keeping the
# connectors trainable would presumably require listing them in
# LoraConfig(modules_to_save=...); verify module names before changing this.
for param in model.model.vision_backbone.image_vit.parameters():
    param.requires_grad = False

# prepare_model_for_kbit_training handles:
# - casting LayerNorm to float32
# - enabling input gradients (replaces our manual hook)
# - gradient checkpointing setup
# use_reentrant=False is passed through to the checkpointing setup; the same
# value is repeated in TrainingArguments later in the notebook.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

# LoRA is attached to every module whose name matches one of target_modules.
# NOTE(review): matching is by module name, so identically named modules
# outside the language model (if any exist) would also receive adapters —
# TODO confirm against the model's module list.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,           # alpha = r means scaling factor of 1
    lora_dropout=0.0,
    target_modules=["att_proj", "attn_out", "ff_proj", "ff_out"],  # includes MLP down proj
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Wrap the prepared model with the adapters and report the trainable-parameter
# count and current VRAM usage.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print(f"VRAM after LoRA: {torch.cuda.memory_allocated() / 1e9:.1f} GB")

Cell 6: Prepare VQA dataset

We build a custom Dataset class that:

1. Loads images lazily (only when accessed)
2. Processes through apply_chat_template to get tokenized input
3. Masks the prompt portion of labels with -100, so the model only trains on the assistant’s answer
4. Also masks any tokens >= vocab_size (image patch tokens) as a safety measure

import json
from torch.utils.data import Dataset as TorchDataset
from PIL import Image

class MiceVQADataset(TorchDataset):
    """Map-style dataset that turns VQA records into Molmo2 chat-format tensors.

    The JSON file at ``data_path`` is loaded eagerly; images are opened lazily
    in ``__getitem__``. Each record must provide the keys ``"image"`` (a path
    PIL can open), ``"question"`` and ``"answer"``.

    Every item is the processor output for the full user+assistant
    conversation with an added ``"labels"`` entry in which the prompt
    positions, padding tokens and ids ``>= vocab_size`` are set to -100.
    """

    def __init__(self, data_path, processor):
        """Load the records and cache tokenizer metadata used for label masking.

        Args:
            data_path: Path to a JSON file containing the list of records.
            processor: Object exposing ``apply_chat_template`` and a
                ``tokenizer`` with ``vocab_size`` and ``pad_token_id``.
        """
        # JSON is UTF-8 by specification; don't depend on the locale default.
        with open(data_path, encoding="utf-8") as f:
            self.data = json.load(f)
        self.processor = processor
        self.vocab_size = processor.tokenizer.vocab_size
        self.pad_id = processor.tokenizer.pad_token_id

    def __len__(self):
        """Return the number of VQA records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and attach prompt-masked labels.

        Returns:
            The processor output for the full conversation (tensors keep the
            leading batch dimension produced by ``return_tensors="pt"``) with
            a ``"labels"`` tensor added.
        """
        ex = self.data[idx]
        # convert() yields a new in-memory image, so the file handle can be
        # closed deterministically instead of waiting for garbage collection.
        with Image.open(ex["image"]) as raw_image:
            img = raw_image.convert("RGB")

        # Build user-only messages to measure prompt length
        user_messages = [
            {"role": "user", "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": ex["question"]},
            ]}
        ]

        # Full messages (user + assistant)
        full_messages = user_messages + [
            {"role": "assistant", "content": [
                {"type": "text", "text": ex["answer"]},
            ]}
        ]

        # Tokenize prompt-only (to find where the assistant response starts).
        # NOTE(review): this assumes the prompt tokenization is an exact
        # prefix of the full-conversation tokenization — verify against the
        # chat template.
        prompt_inputs = self.processor.apply_chat_template(
            user_messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )
        prompt_len = prompt_inputs["input_ids"].shape[1]

        # Tokenize full conversation
        full_inputs = self.processor.apply_chat_template(
            full_messages,
            add_generation_prompt=False,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )

        # Create labels: mask prompt with -100, only train on assistant response
        labels = full_inputs["input_ids"].clone()
        labels[0, :prompt_len] = -100                     # mask the prompt
        if self.pad_id is not None:
            labels[labels == self.pad_id] = -100          # mask padding
        # NOTE(review): tokenizer.vocab_size may exclude added special tokens;
        # if the end-of-turn/EOS id is >= vocab_size (or equals pad_id) it is
        # masked here too and the model is never trained to stop — TODO confirm.
        labels[labels >= self.vocab_size] = -100          # mask image tokens (safety)
        full_inputs["labels"] = labels

        return full_inputs

# Build both splits with the shared processor.
train_dataset = MiceVQADataset("data/vqa/train.json", processor)
val_dataset = MiceVQADataset("data/vqa/val.json", processor)

print(f"Train: {len(train_dataset)} examples")
print(f"Val:   {len(val_dataset)} examples")

# Inspect one example: count ignored vs. supervised label positions and
# verify that no supervised label falls outside the tokenizer's vocab_size.
sample = train_dataset[0]
labels = sample["labels"]
ignored = labels.eq(-100)
valid = labels[~ignored]
print(f"\nSample input shape: {sample['input_ids'].shape}")
print(f"Prompt tokens (masked): {ignored.sum().item()}")
print(f"Answer tokens (trained): {(~ignored).sum().item()}")
print(f"Any labels >= vocab_size? {valid.ge(processor.tokenizer.vocab_size).any()}")

Cell 7: Collator and training config

The collator is a simple passthrough since each dataset item is already fully processed with batch dimension. batch_size=1 is required for Molmo2’s multi-crop image handling.

from transformers import TrainingArguments, Trainer

def passthrough_collator(batch):
    """Return the single pre-processed example of a size-1 batch unchanged.

    Each dataset item already carries its own batch dimension, so no stacking
    or padding is performed here.

    Args:
        batch: Sequence of dataset items handed over by the DataLoader.

    Returns:
        The only element of ``batch``.

    Raises:
        ValueError: If ``batch`` does not contain exactly one item.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and would then silently drop every item but the first.
    if len(batch) != 1:
        raise ValueError(
            f"passthrough_collator requires a batch size of 1, got {len(batch)}"
        )
    return batch[0]

# Detect hardware capabilities once so the precision flags stay consistent.
use_bf16 = torch.cuda.is_bf16_supported()
# TF32 matmuls need an Ampere-or-newer GPU (compute capability >= 8.0).
# TrainingArguments raises a ValueError for tf32=True on older hardware, which
# would defeat the fp16 fallback below (e.g. on a Colab T4).
use_tf32 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

training_args = TrainingArguments(
    output_dir="./molmo2-mice-qlora",
    num_train_epochs=3,
    per_device_train_batch_size=1,      # must stay 1: the collator passes a single item through
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,      # effective batch size = 4
    learning_rate=1e-4,                 # lower LR for QLoRA stability
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    max_grad_norm=1.0,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    bf16=use_bf16,
    fp16=not use_bf16,
    tf32=use_tf32,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataloader_pin_memory=False,
    dataloader_num_workers=0,
    remove_unused_columns=False,        # keep every key the dataset returns
    report_to="tensorboard",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=passthrough_collator,
)

print(f"Training on {len(train_dataset)} examples")
print(f"Evaluating on {len(val_dataset)} examples")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"bf16: {training_args.bf16}, fp16: {training_args.fp16}")
print(f"VRAM before training: {torch.cuda.memory_allocated() / 1e9:.1f} GB")

Cell 8: Sanity check before training

# Quick forward+backward test
# The model has not been switched to training mode by any earlier cell, and
# gradient checkpointing is typically only active in training mode. Switch
# explicitly so this check exercises the same path trainer.train() will use.
model.train()

test_batch = train_dataset[0]
# Move tensors to the model's device; leave any non-tensor entries untouched.
test_batch = {k: v.to(model.device) if hasattr(v, 'to') else v for k, v in test_batch.items()}

output = model(**test_batch)
print(f"Forward OK! Loss: {output.loss.item():.4f}")

output.loss.backward()
print("Backward OK!")

# Every trainable parameter should have received a gradient; a missing one
# means it is disconnected from the loss.
trainable_params = [p for _, p in model.named_parameters() if p.requires_grad]
grads_none = sum(1 for p in trainable_params if p.grad is None)
grads_ok = len(trainable_params) - grads_none
print(f"Params with gradients: {grads_ok}")
print(f"Params with NO gradients: {grads_none}")

# Discard the test gradients so they don't leak into the first training step.
model.zero_grad()

if grads_none == 0:
    print("\nAll good — ready to train!")
else:
    print(f"\nWARNING: {grads_none} trainable params got no gradients!")

Cell 9: Train!

# Run the Trainer built in Cell 7: 3 epochs, with evaluation and a checkpoint
# every 200 steps written under ./molmo2-mice-qlora.
trainer.train()

Cell 10: Save adapter and test

# Persist the model via save_pretrained (adapter directory).
model.save_pretrained("./molmo2-mice-qlora/final_adapter")
print("Adapter saved to ./molmo2-mice-qlora/final_adapter")

# Test on validation samples
import random

with open("data/vqa/val.json") as f:
    val_raw = json.load(f)

# trainer.train() leaves the model in training mode; switch to eval so
# generation is not affected by training-only behaviour such as dropout or
# gradient checkpointing.
model.eval()

random.seed(42)
# Cap the sample size: random.sample raises ValueError when asked for more
# items than the population holds.
test_samples = random.sample(val_raw, min(5, len(val_raw)))

for sample in test_samples:
    # convert() returns an in-memory copy, so the file can be closed right away.
    with Image.open(sample["image"]) as raw_image:
        img = raw_image.convert("RGB")

    messages = [
        {"role": "user", "content": [
            dict(type="image", image=img),
            dict(type="text", text=sample["question"]),
        ]}
    ]

    inputs = processor.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True,
        return_tensors="pt", return_dict=True,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=256)

    # Decode only the tokens that follow the prompt positions.
    pred = processor.tokenizer.decode(output[0, inputs["input_ids"].size(1):], skip_special_tokens=True)

    print(f"{'='*60}")
    print(f"Q: {sample['question']}")
    print(f"Ground truth: {sample['answer']}")
    print(f"Prediction:   {pred}")
    print()

Cell 11 (optional): Push adapter to HuggingFace Hub

# Upload the model first, then the processor, to the same Hub repository.
hub_repo_id = "jpoberhauser/molmo2-4b-mice-qlora"
for artifact in (model, processor):
    artifact.push_to_hub(hub_repo_id)