Molmo2 Fine-tune Notebook

import numpy as np

# Look at a single labeled frame in detail
# (`labels` and `skeleton` are defined in an earlier cell)
lf = labels.labeled_frames[0]
print(f"Frame index: {lf.frame_idx}")
print(f"Number of instances (mice): {len(lf.instances)}")

for i, instance in enumerate(lf.instances):                                                                               
    print(f"\n  Mouse {i} (track: {instance.track})")
    pts = instance.numpy()  # (n_nodes, 2) array of x, y
    for node, pt in zip(skeleton.nodes, pts):
        if not np.isnan(pt[0]):
            print(f"    {node.name:>15s}: x={pt[0]:6.1f}, y={pt[1]:6.1f}")
        else:
            print(f"    {node.name:>15s}: not labeled")
Frame index: 5259
Number of instances (mice): 2

  Mouse 0 (track: None)
              snout: x= 554.0, y= 328.0
               earL: x= 586.0, y= 332.0
               earR: x= 590.0, y= 306.0
                 tb: x= 740.0, y= 352.0
                 tt: x= 874.9, y= 423.2

  Mouse 1 (track: None)
              snout: x= 392.0, y= 372.0
               earL: x= 444.1, y= 367.4
               earR: x= 433.1, y= 347.6
                 tb: x= 541.8, y= 324.9
                 tt: x= 705.1, y= 262.4
┌──────────────────────────┬────────────────────────────┬────────────────────────────────────────────────────────────────────┐   
│         Question         │       Generic answer       │                       Domain-precise answer                        │
├──────────────────────────┼────────────────────────────┼────────────────────────────────────────────────────────────────────┤   
│ "How many mice and       │ "Two mice, one top one     │ "2 mice. Mouse A at (640, 320), Mouse B at (400, 800)"             │
│ where?"                  │ bottom"                    │                                                                    │
├──────────────────────────┼────────────────────────────┼────────────────────────────────────────────────────────────────────┤   
│ "Is this mouse           │ "The mouse appears to be   │ "Yes — forepaws are elevated near the snout, head is flexed        │   
│ grooming?"               │ cleaning itself"           │ ventrally, consistent with face grooming behavior"                 │   
├──────────────────────────┼────────────────────────────┼────────────────────────────────────────────────────────────────────┤   
│ "What body parts are     │ "Part of the mouse is      │ "Left ear and left hindpaw of mouse closest to the left are        │
│ occluded?"               │ hidden"                    │ occluded"                                                          │   
├──────────────────────────┼────────────────────────────┼────────────────────────────────────────────────────────────────────┤
│ "Estimate the            │ Can't do this              │ "Nose-to-nose distance is approximately 150px (~4cm at this        │   
│ inter-mouse distance"    │                            │ scale)"                                                            │   
└──────────────────────────┴────────────────────────────┴────────────────────────────────────────────────────────────────────┘
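The domain-precise answers in the last row can be computed directly from the keypoints. Below is a minimal sketch of the nose-to-nose distance, using the snout coordinates printed above for frame 5259 (pixel units). Converting to centimeters would need a pixels-per-cm calibration that this notebook does not provide, so the sketch stops at pixels.

```python
import math

# Snout keypoints (x, y) in pixels for frame 5259, copied from the printout above
snouts = {
    "Mouse 0": (554.0, 328.0),
    "Mouse 1": (392.0, 372.0),
}


def nose_to_nose_px(a, b):
    """Euclidean distance in pixels between two (x, y) keypoints."""
    return math.dist(a, b)


d_px = nose_to_nose_px(snouts["Mouse 0"], snouts["Mouse 1"])
print(f"Nose-to-nose distance: {d_px:.1f}px")
```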


# Make CUDA kernel launches synchronous so errors are reported at the failing call
# (a debugging aid; it slows execution and must be set before CUDA is initialized)
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

%pip install "transformers==4.57.1"
Collecting transformers==4.57.1
  Using cached transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers==4.57.1)
  Using cached huggingface_hub-0.36.2-py3-none-any.whl.metadata (15 kB)
Using cached transformers-4.57.1-py3-none-any.whl (12.0 MB)
Using cached huggingface_hub-0.36.2-py3-none-any.whl (566 kB)
Installing collected packages: huggingface-hub, transformers
  Attempting uninstall: huggingface-hub
    Found existing installation: huggingface_hub 1.8.0
    Uninstalling huggingface_hub-1.8.0:
      Successfully uninstalled huggingface_hub-1.8.0
  Attempting uninstall: transformers
    Found existing installation: transformers 5.4.0
    Uninstalling transformers-5.4.0:
      Successfully uninstalled transformers-5.4.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gradio 5.50.0 requires pillow<12.0,>=8.0, but you have pillow 12.1.1 which is incompatible.
Successfully installed huggingface-hub-0.36.2 transformers-4.57.1
import transformers
print(transformers.__version__)
4.57.1
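# Cell 1a: Install dependencies
# NOTE: with --upgrade, "transformers>=4.57.1" resolves to the newest release (5.4.0 in the log below),
# which undoes the transformers==4.57.1 pin installed above; use "transformers==4.57.1" here to keep it.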
%pip install --upgrade \
  torch \
  torchvision \
  "transformers>=4.57.1" \
  "datasets>=3.0.1" \
  "accelerate>=0.34.2" \
  "bitsandbytes>=0.44.0" \
  "trl>=0.15.0" \
  "peft>=0.13.0" \
  pillow tensorboard

# Fix CUDA library paths (Colab-specific)
!ln -sf /usr/local/lib/python3.12/dist-packages/nvidia/cu13/lib/libnvrtc-builtins.so.13.0 /usr/lib/libnvrtc-builtins.so.13.0
Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (2.11.0)
Requirement already satisfied: torchvision in /usr/local/lib/python3.12/dist-packages (0.26.0)
Requirement already satisfied: transformers>=4.57.1 in /usr/local/lib/python3.12/dist-packages (4.57.1)
Collecting transformers>=4.57.1
  Downloading transformers-5.4.0-py3-none-any.whl.metadata (32 kB)
Requirement already satisfied: datasets>=3.0.1 in /usr/local/lib/python3.12/dist-packages (4.8.4)
Requirement already satisfied: accelerate>=0.34.2 in /usr/local/lib/python3.12/dist-packages (1.13.0)
Requirement already satisfied: bitsandbytes>=0.44.0 in /usr/local/lib/python3.12/dist-packages (0.49.2)
Requirement already satisfied: trl>=0.15.0 in /usr/local/lib/python3.12/dist-packages (1.0.0)
Requirement already satisfied: peft>=0.13.0 in /usr/local/lib/python3.12/dist-packages (0.18.1)
Requirement already satisfied: pillow in /usr/local/lib/python3.12/dist-packages (12.1.1)
Requirement already satisfied: tensorboard in /usr/local/lib/python3.12/dist-packages (2.20.0)
Collecting huggingface-hub<2.0,>=1.5.0 (from transformers>=4.57.1)
  Downloading huggingface_hub-1.8.0-py3-none-any.whl.metadata (13 kB)
Downloading transformers-5.4.0-py3-none-any.whl (10.1 MB)
Downloading huggingface_hub-1.8.0-py3-none-any.whl (625 kB)
Installing collected packages: huggingface-hub, transformers
  Attempting uninstall: huggingface-hub
    Found existing installation: huggingface_hub 0.36.2
    Uninstalling huggingface_hub-0.36.2:
      Successfully uninstalled huggingface_hub-0.36.2
  Attempting uninstall: transformers
    Found existing installation: transformers 4.57.1
    Uninstalling transformers-4.57.1:
      Successfully uninstalled transformers-4.57.1
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gradio 5.50.0 requires pillow<12.0,>=8.0, but you have pillow 12.1.1 which is incompatible.
Successfully installed huggingface-hub-1.8.0 transformers-5.4.0
from huggingface_hub import login, snapshot_download
from google.colab import userdata

login(token=userdata.get('HF_TOKEN'), add_to_git_credential=True)

snapshot_download("jpoberhauser/sleap-mice-vqa", repo_type="dataset", local_dir="data")

# Unzip frames if they were uploaded as a zip
import os
if os.path.exists("data/frames.zip") and not os.path.exists("data/frames/frame_0000.png"):
    !unzip -q data/frames.zip -d data/
    print("Frames unzipped.")
else:
    print("Frames already available.")
Frames already available.
# Unzip frames
# !unzip /content/data/frames.zip -d /content/data/
!ls /content/data/
all.json  frames  frames.zip  train.json  val.json
import json
import matplotlib.pyplot as plt
from PIL import Image
import random

# Load from saved JSON
with open("/content/data/train.json") as f:
    train_data = json.load(f)

with open("/content/data/val.json") as f:
    val_data = json.load(f)

print(f"Train: {len(train_data)} QA pairs")
print(f"Val:   {len(val_data)} QA pairs")

# Show random samples with their images
random.seed(1023)
samples = random.sample(train_data, 6)

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
for ax, sample in zip(axes.flat, samples):
    # sample["image"] is a relative path like "data/frames/frame_0854.png"; prefix the Colab working directory
    img_path = f"/content/{sample['image']}"
    img = Image.open(img_path)
    print(sample["image"])
    ax.imshow(img, cmap="gray")
    # Wrap the Q/A text for display
    q = sample["question"]
    a = sample["answer"]
    ax.set_title(f"Q: {q}\nA: {a}", fontsize=8, wrap=True, pad=10)
    ax.axis("off")

plt.suptitle("Sample VQA pairs from train set", fontsize=14)
plt.tight_layout()
plt.show()
Train: 11011 QA pairs
Val:   1976 QA pairs
data/frames/frame_0854.png
data/frames/frame_0797.png
data/frames/frame_0839.png
data/frames/frame_0226.png
data/frames/frame_0744.png
data/frames/frame_0557.png

#!ln -sf /usr/local/lib/python3.12/dist-packages/nvidia/cu13/lib/libnvrtc-builtins.so.13.0 /usr/lib/libnvrtc-builtins.so.13.0
# import os
# os.environ['LD_LIBRARY_PATH'] = '/usr/local/lib/python3.12/dist-packages/nvidia/nvjitlink/lib:' + os.environ.get('LD_LIBRARY_PATH', '')
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "allenai/Molmo2-4B"

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    device_map="auto",
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
print(f"Model loaded. dtype: {model.dtype}")
`torch_dtype` is deprecated! Use `dtype` instead!
A new version of the following files was downloaded from https://huggingface.co/allenai/Molmo2-4B:
- image_processing_molmo2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/allenai/Molmo2-4B:
- video_processing_molmo2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/allenai/Molmo2-4B:
- processing_molmo2.py
- image_processing_molmo2.py
- video_processing_molmo2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Model loaded. dtype: torch.float32
from PIL import Image

img = Image.open("data/frames/frame_0000.png").convert("RGB")

messages = [
    {
        "role": "user",
        "content": [
            dict(type="image", image=img),
            dict(type="text", text="How many mice are in this image? Describe their positions and postures."),
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=256)

generated_tokens = output[0, inputs["input_ids"].size(1):]
print("Base model response:")
print(processor.tokenizer.decode(generated_tokens, skip_special_tokens=True))
Base model response:
There are two mice visible in this image. They are positioned in the center of the container, facing each other. The mouse on the left is facing left, while the mouse on the right is facing right. Both mice appear to be in a relaxed posture, with their tails visible. The container they're in seems to be a clear plastic enclosure with a gravelly bottom, likely a pet habitat or research setup.
img

As you can see, the model handles the general scene well, but it gets the orientations wrong: it says the mice are facing each other, which they are not. It correctly says the mouse on the left is facing left, but incorrectly says the mouse on the right is facing right.
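Claims like "facing left" can be checked against ground truth with a simple heading: the vector from the tail base (`tb`) to the snout. A minimal sketch, assuming image coordinates (x grows to the right, y grows downward) and using the frame 5259 keypoints printed at the top of the notebook; the labels for frame_0000.png are not shown here, so this illustrates the check rather than scoring the response above.

```python
import math


def heading(snout, tail_base):
    """Return (dx, dy, coarse left/right label, angle in degrees) for tail base -> snout."""
    dx = snout[0] - tail_base[0]
    dy = snout[1] - tail_base[1]
    side = "left" if dx < 0 else "right"
    # Angle measured in image coordinates, where the y axis points down
    angle = math.degrees(math.atan2(dy, dx))
    return dx, dy, side, angle


# (snout, tb) pairs from the frame 5259 printout
mice = {
    "Mouse 0": ((554.0, 328.0), (740.0, 352.0)),
    "Mouse 1": ((392.0, 372.0), (541.8, 324.9)),
}

for name, (snout, tb) in mice.items():
    dx, dy, side, angle = heading(snout, tb)
    print(f"{name}: facing {side} (dx={dx:.1f}, dy={dy:.1f}, angle={angle:.0f} deg)")
```

By this measure both mice in frame 5259 face left, so neither is facing the other.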

Let's fine-tune with LoRA adapters

Step 1: Attach LoRA adapters

LoRA (Low-Rank Adaptation) freezes the base model and injects small trainable low-rank matrices into selected linear layers. Instead of updating all ~4.9B parameters, we train well under 1% of them (about 0.44% with the configuration below).

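Key choices:

  • r=16: the rank of the low-rank matrices (higher means more capacity, but also more trainable parameters).

  • lora_alpha=32: the scaling factor; the adapter's update is multiplied by lora_alpha / r (= 2 here), and a common rule of thumb is to set alpha to 2x the rank.

  • Target modules: the fused query/key/value projection, the attention output projection, and the feed-forward input projection of the language model.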
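The role of r and lora_alpha can be sketched in plain Python. This follows the LoRA paper's formulation h = W x + (alpha / r) · B A x, with W frozen, A (r × d_in) and B (d_out × r) trainable, and B initialized to zero so the adapted layer starts out identical to the base layer. The toy sizes are arbitrary assumptions; PEFT does the same thing with tensors.

```python
import random


def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def lora_forward(W, A, B, x, alpha, r):
    """h = W x + (alpha / r) * B (A x); W is frozen, A and B are trainable."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * d for b, d in zip(base, delta)]


d_in, d_out, r, alpha = 64, 64, 2, 4  # toy sizes; the notebook uses r=16, alpha=32
rng = random.Random(0)
W = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]  # frozen base weight
A = [[rng.gauss(0, 0.01) for _ in range(d_in)] for _ in range(r)]   # r x d_in, random init
B = [[0.0] * r for _ in range(d_out)]                               # d_out x r, zero init
x = [rng.gauss(0, 1) for _ in range(d_in)]

# With B = 0 the adapter contributes nothing, so training starts from the base model
assert lora_forward(W, A, B, x, alpha, r) == matvec(W, x)

# Trainable parameters: full fine-tuning of W vs. the two LoRA factors
print("full:", d_in * d_out, "lora:", r * (d_in + d_out))
```

Even at this toy size the adapter trains 256 parameters instead of 4096; the gap grows with layer width.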

How do we find the target modules we actually want? Let's print the module names to see what the model contains and how its layers are named:


# for name, module in model.named_modules():
#     if "attn" in name.lower() or "proj" in name.lower():
#         print(name)
  • So it looks like we want to add adapters to att_proj, attn_out, and ff_proj.

  • Why?

  • att_proj: the fused projection that produces the query, key, and value vectors in a single matrix multiply. This is where the multimodal model decides what to attend to.

  • attn_out: the output projection W_O, applied right after multi-head attention to mix the concatenated head outputs. In our PaliGemma implementation it looks like:

# self.out_proj applies the final linear projection W_O to the concatenated heads,
# matching the standard formula: Output = Concat(head_1, ..., head_h) · W_O
attn_output = self.out_proj(attn_output)  # needed to mix information between heads

# where out_proj was defined as:
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)  # W_O matrix
  • Lastly, ff_proj: the first projection of the feed-forward (MLP) block that follows attention in each transformer layer (it is not part of multi-head attention itself). This is where the model applies per-token transformations after attention.
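Because names like `attn_proj` and `att_proj` are easy to mix up, it helps to check which module names a `target_modules` list will actually match. As we understand PEFT's behavior for a list of strings, a module is targeted when its full name equals an entry or ends with "." plus the entry. The helper below reimplements that rule for illustration (it is not PEFT's API), and the module names are assumed, modeled on the `...transformer.blocks.0` path printed by the hook cell.

```python
def matches_target(module_name, target_modules):
    """True if module_name equals a target or ends with '.' + target."""
    return any(module_name == t or module_name.endswith("." + t) for t in target_modules)


# Illustrative names only (assumed layout: model.transformer.blocks.<i>.<layer>)
names = [
    f"model.transformer.blocks.{i}.{leaf}"
    for i in range(2)
    for leaf in ("att_proj", "attn_out", "ff_proj", "ff_out")
]

for targets in (["att_proj", "attn_out", "ff_proj"], ["attn_proj", "attn_out", "ff_proj"]):
    matched = sorted({n.rsplit(".", 1)[-1] for n in names if matches_target(n, targets)})
    unmatched = [t for t in targets if t not in matched]
    print(f"{targets} -> matched {matched}, unmatched {unmatched}")
```

With the misspelled list, only attn_out and ff_proj would be wrapped. PEFT is expected to raise an error only when no target matches at all, so comparing the printed trainable-parameter count against the expected value is a quick sanity check.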

LoRA

  • For LoRA, it is standard practice to start with the attention layers. This comes from the original LoRA paper, which adapted only the attention weight matrices and found that, for a fixed parameter budget, adapting the query and value projections together worked best.

  • To be more conservative, we could add adapters to att_proj alone and see how far that gets us for our use case.

  • With all three layers targeted, print_trainable_parameters() reports ~21.5M trainable parameters out of ~4.87B in the full model, so we train about 0.44% of the parameters.
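The ~21.5M figure can be reproduced by hand: a LoRA adapter on a d_in → d_out linear layer adds r · (d_in + d_out) parameters (A is r × d_in, B is d_out × r). The layer sizes below are assumptions, taken to be those of a Qwen3-4B-style language backbone (hidden size 2560, 32 query heads and 8 key/value heads of size 128, MLP width 9728, 36 layers) with fused QKV and fused gate/up projections. They are not read from the notebook, but with them the total matches the printed count.

```python
# Assumed language-model dimensions (Qwen3-4B-style); not read from the model config
hidden = 2560
n_heads, n_kv_heads, head_dim = 32, 8, 128
mlp_width = 9728
n_layers = 36
r = 16


def lora_params(d_in, d_out, r):
    """Parameters added by one LoRA adapter: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


# (d_in, d_out) of each targeted linear layer in one transformer block
layers = {
    "att_proj": (hidden, (n_heads + 2 * n_kv_heads) * head_dim),  # fused Q, K, V
    "attn_out": (n_heads * head_dim, hidden),                      # W_O
    "ff_proj": (hidden, 2 * mlp_width),                            # fused gate + up
}

per_block = sum(lora_params(d_in, d_out, r) for d_in, d_out in layers.values())
total = per_block * n_layers
all_params = 4_872_397_776  # "all params" from model.print_trainable_parameters()

print(f"per block: {per_block:,}")
print(f"total:     {total:,}")
print(f"trainable: {100 * total / all_params:.4f}%")
```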

from peft import LoraConfig, get_peft_model

# Freeze the vision encoder — we only want to fine-tune the language model
for param in model.model.vision_backbone.parameters():
    param.requires_grad = False

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.0,
    target_modules=["att_proj", "attn_out", "ff_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)

# Forward hook to enable gradients through frozen layers.
# We need to find the correct path to the first transformer block.
def make_inputs_require_grad(module, input, output):
    if isinstance(output, tuple):
        output[0].requires_grad_(True)
    else:
        output.requires_grad_(True)

# Find and hook the first transformer block (path varies by how model is wrapped)
hooked = False
for name, module in model.named_modules():
    if name.endswith("blocks.0") and "vision" not in name:
        module.register_forward_hook(make_inputs_require_grad)
        print(f"Gradient hook attached to: {name}")
        hooked = True
        break

if not hooked:
    print("WARNING: Could not find transformer blocks. Listing candidates:")
    for name, module in model.named_modules():
        if "blocks.0" in name:
            print(f"  {name}")

model.print_trainable_parameters()
ERROR:bitsandbytes.cextension:bitsandbytes library load error: libnvJitLink.so.13: cannot open shared object file: No such file or directory
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/bitsandbytes/cextension.py", line 320, in <module>
    lib = get_native_library()
          ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/bitsandbytes/cextension.py", line 298, in get_native_library
    dll = ct.cdll.LoadLibrary(str(binary_path))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/ctypes/__init__.py", line 460, in LoadLibrary
    return self._dlltype(name)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/ctypes/__init__.py", line 379, in __init__
    self._handle = _dlopen(self._name, mode)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
OSError: libnvJitLink.so.13: cannot open shared object file: No such file or directory
Gradient hook attached to: base_model.model.model.transformer.blocks.0
trainable params: 21,528,576 || all params: 4,872,397,776 || trainable%: 0.4418
import json
from datasets import Dataset

# Load VQA data
with open("data/train.json") as f:
    train_raw = json.load(f)
with open("data/val.json") as f:
    val_raw = json.load(f)

def format_to_chat(examples):
    """Convert VQA JSON into chat message format. Stores image paths (lazy loading)."""
    formatted = []
    for ex in examples:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": ex["image"]},
                    {"type": "text", "text": ex["question"]},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": ex["answer"]},
                ],
            },
        ]
        formatted.append({"messages": messages})
    return formatted

train_formatted = format_to_chat(train_raw)
val_formatted = format_to_chat(val_raw)

train_dataset = Dataset.from_list(train_formatted)
val_dataset = Dataset.from_list(val_formatted)

print(f"Train: {len(train_dataset)} examples")
print(f"Val:   {len(val_dataset)} examples")
Train: 11011 examples
Val:   1976 examples
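`collate_fn` only opens images lazily, in the middle of training. Before a multi-hour run, it is therefore worth confirming up front that every record has the three keys `format_to_chat` reads and that each image path exists. A minimal sketch follows. `find_bad_records` is a hypothetical helper, and the demo records are made up.

```python
import os
import tempfile

REQUIRED_KEYS = ("image", "question", "answer")  # the keys format_to_chat reads

def find_bad_records(records, path_exists=os.path.exists):
    """Return (index, reason) pairs for records that would break training."""
    problems = []
    for i, ex in enumerate(records):
        missing = [k for k in REQUIRED_KEYS if k not in ex]
        if missing:
            problems.append((i, f"missing keys: {missing}"))
        elif not path_exists(ex["image"]):
            problems.append((i, f"image not found: {ex['image']}"))
    return problems

# Demo with hypothetical records; in the notebook you would pass train_raw / val_raw.
with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
    demo = [
        {"image": tmp.name, "question": "How many mice?", "answer": "2 mice."},
        {"image": "does/not/exist.png", "question": "q", "answer": "a"},
        {"image": tmp.name, "question": "q"},
    ]
    bad = find_bad_records(demo)

print(bad)
```

On clean data, `find_bad_records(train_raw)` and `find_bad_records(val_raw)` should both return empty lists.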
print(train_dataset)
print(val_dataset)
Dataset({
    features: ['messages'],
    num_rows: 11011
})
Dataset({
    features: ['messages'],
    num_rows: 1976
})
from PIL import Image

def collate_fn(examples):
    """Custom collator that processes one example at a time.
    Molmo2's multi-crop image handling requires batch_size=1."""
    assert len(examples) == 1, "collate_fn expects batch_size=1"
    ex = examples[0]  # single example per batch

    # Build messages, loading each image from disk only when the batch is built
    messages = []
    for msg in ex["messages"]:
        new_content = []
        for content in msg["content"]:
            if content["type"] == "image":
                img = Image.open(content["image"]).convert("RGB")
                new_content.append({"type": "image", "image": img})
            else:
                new_content.append(content)
        messages.append({"role": msg["role"], "content": new_content})

    # Tokenize the text and preprocess the image through the chat template
    batch = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )

    # Create labels, masking tokens the model shouldn't be graded on:
    # 1. Padding tokens
    # 2. Image and special tokens with IDs >= vocab_size (image patch embeddings
    #    come from the vision backbone, so the LM head should not predict them)
    labels = batch["input_ids"].clone()
    vocab_size = processor.tokenizer.vocab_size
    if processor.tokenizer.pad_token_id is not None:
        labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels >= vocab_size] = -100  # mask image patch tokens and other special tokens
    batch["labels"] = labels

    return batch

# Quick test
test_batch = collate_fn([train_formatted[0]])
print(f"Input IDs shape: {test_batch['input_ids'].shape}")
print(f"Keys: {list(test_batch.keys())}")

# Verify no out-of-range labels
labels = test_batch["labels"]
valid_labels = labels[labels != -100]
vocab_size = processor.tokenizer.vocab_size
print(f"Valid label range: {valid_labels.min()} to {valid_labels.max()}")
print(f"Any labels >= vocab_size? {(valid_labels >= vocab_size).any()}")
print(f"Tokens masked as -100: {(labels == -100).sum().item()} out of {labels.numel()}")
Input IDs shape: torch.Size([1, 1011])
Keys: ['input_ids', 'attention_mask', 'token_type_ids', 'pixel_values', 'image_token_pooling', 'image_grids', 'image_num_crops', 'labels']
Valid label range: 13 to 77091
Any labels >= vocab_size? False
Tokens masked as -100: 986 out of 1011

Note on the label masking above

  • Why check whether any labels are >= the vocab size?
  • What I think was happening (I don't know if this is best practice) is that the cross-entropy loss was receiving the image tokens along with the predicted text tokens. That is not what we want: the model should be graded only on the text tokens it produces, not on the image tokens.
  • PyTorch's cross_entropy has an ignore_index, which defaults to -100. So if we take every image token at or above vocab_size and set it to -100, the model is graded only on the text tokens.
  • I believe Molmo2's own fine-tuning code handles this internally, but the plain labels = input_ids.clone() approach in Hugging Face does not.
  • There might be a better practice here, but let's see if this works.
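Here is the ignore_index point made concrete. With the default `reduction="mean"`, positions labeled -100 are dropped from both the sum and the count. The masked loss is therefore identical to the loss computed over the unmasked positions alone. This is a toy check with random logits, not Molmo2 outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 10                                   # toy vocabulary
logits = torch.randn(6, vocab_size)               # 6 positions, one row of logits each
labels = torch.tensor([3, -100, 7, -100, 1, 4])   # two positions masked with -100

# ignore_index defaults to -100, so masked positions contribute nothing
masked_loss = F.cross_entropy(logits, labels)

# Same loss computed on only the unmasked positions
keep = labels != -100
manual_loss = F.cross_entropy(logits[keep], labels[keep])

print(masked_loss.item(), manual_loss.item())
print(torch.allclose(masked_loss, manual_loss))
```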
# What are ALL the image-related token IDs?
vocab_size = processor.tokenizer.vocab_size
print(f"image_patch_id: {model.config.image_patch_id}")

# Check all config attributes for image token IDs (skip booleans, which are also ints)
for attr in dir(model.config):
    if "image" in attr.lower() or "patch" in attr.lower():
        val = getattr(model.config, attr)
        if isinstance(val, int) and not isinstance(val, bool):
            print(f"{attr}: {val} (>= vocab? {val >= vocab_size})")
test_batch = {k: v.to(model.device) if hasattr(v, 'to') else v for k, v in test_batch.items()}
output = model(**test_batch)
print(f"Forward OK! Loss: {output.loss.item():.4f}")

output.loss.backward()
print("Backward OK!")

# Check gradients reached LoRA params
grads_ok = sum(1 for n, p in model.named_parameters() if p.requires_grad and p.grad is not None)
grads_none = sum(1 for n, p in model.named_parameters() if p.requires_grad and p.grad is None)
print(f"Params with gradients: {grads_ok}")
print(f"Params with NO gradients: {grads_none}")

model.zero_grad()

if grads_none == 0:
    print("\nAll good — ready to train!")
Forward OK! Loss: 4.1107
Backward OK!
Params with gradients: 216
Params with NO gradients: 0

All good — ready to train!
test_batch = collate_fn([train_formatted[0]])

vocab_size = processor.tokenizer.vocab_size
print(f"Vocab size: {vocab_size}")

labels = test_batch["labels"]
valid_labels = labels[labels != -100]
print(f"Label range: {valid_labels.min()} to {valid_labels.max()}")
print(f"Any labels >= vocab_size? {(valid_labels >= vocab_size).any()}")

input_ids = test_batch["input_ids"]
print(f"Input ID range: {input_ids.min()} to {input_ids.max()}")
print(f"Any input_ids >= vocab_size? {(input_ids >= vocab_size).any()}")
Vocab size: 151643
Label range: 13 to 77091
Any labels >= vocab_size? False
Input ID range: 13 to 151940
Any input_ids >= vocab_size? True
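These checks confirm that the image tokens (input IDs up to 151940) are gone from the labels. Everything else the chat template produced is still scored, though, including the user's question text. If you only want the model graded on its answer, a common alternative is to also mask every position before the assistant response begins. The sketch below assumes you already know the prompt length. One untested way to get it for Molmo2 is to tokenize the user turn alone with `add_generation_prompt=True` and take its length. `mask_prompt_tokens` is a hypothetical helper, and the token IDs are toy values.

```python
import torch

def mask_prompt_tokens(input_ids, prompt_len, vocab_size, pad_token_id=None, ignore_index=-100):
    """Build labels that score only the assistant response.

    input_ids:  (1, seq_len) tensor for prompt + response
    prompt_len: number of leading tokens that belong to the prompt
    """
    labels = input_ids.clone()
    labels[:, :prompt_len] = ignore_index            # user turn, image tokens, template
    if pad_token_id is not None:
        labels[labels == pad_token_id] = ignore_index
    labels[labels >= vocab_size] = ignore_index      # same rule as collate_fn
    return labels

# Toy example: 5 prompt tokens (two are "image" IDs >= vocab_size) + 3 answer tokens
vocab_size = 100
ids = torch.tensor([[5, 150, 151, 7, 8, 20, 21, 22]])
labels = mask_prompt_tokens(ids, prompt_len=5, vocab_size=vocab_size)
print(labels.tolist())  # [[-100, -100, -100, -100, -100, 20, 21, 22]]
```

Check one thing either way. The `>= vocab_size` rule also masks any template special tokens above the base vocabulary. If one of those is the end-of-turn token, the model is never trained to emit it.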
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./molmo2-mice-lora",
    num_train_epochs=3,
    per_device_train_batch_size=1,    # must be 1 for Molmo2's multi-crop images
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,    # effective batch size = 8
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    fp16=False,                       # no mixed precision — everything is fp32
    bf16=False,
    gradient_checkpointing=False,     # disabled — 98GB VRAM is plenty
    dataloader_pin_memory=False,
    remove_unused_columns=False,
    report_to="tensorboard",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
)

print(f"Training on {len(train_dataset)} examples")
print(f"Evaluating on {len(val_dataset)} examples")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"Steps per epoch: ~{len(train_dataset) // training_args.gradient_accumulation_steps}")
The model is already on multiple devices. Skipping the move to device specified in `args`.
Training on 11011 examples
Evaluating on 1976 examples
Effective batch size: 8
Steps per epoch: ~1376
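The printed "~1376" uses floor division. The progress bar during training reports 4131 total steps, which matches rounding the partial last accumulation step up: ceil(11011 / 8) = 1377 steps per epoch, × 3 = 4131. The worked arithmetic is below. The warmup line assumes the Trainer rounds `warmup_ratio × total_steps` up, as the transformers versions I know of do; treat that line as an assumption for your installed version.

```python
import math

n_train = 11011              # from the printout above
per_device_batch = 1
grad_accum = 8
epochs = 3
warmup_ratio = 0.1
eval_steps = 200
n_val = 1976

effective_batch = per_device_batch * grad_accum               # 8
batches_per_epoch = math.ceil(n_train / per_device_batch)     # 11011 dataloader batches
steps_per_epoch = math.ceil(batches_per_epoch / grad_accum)   # 1377 optimizer steps (last one partial)
total_steps = steps_per_epoch * epochs                        # 4131
warmup_steps = math.ceil(total_steps * warmup_ratio)          # 414 (assumed rounding rule)
n_evals = total_steps // eval_steps                           # 20 full passes over the val set
eval_forwards = n_evals * n_val                               # 39,520 eval forward passes
train_forwards = n_train * epochs                             # 33,033 training micro-batches

print(effective_batch, steps_per_epoch, total_steps, warmup_steps)
print(n_evals, eval_forwards, train_forwards)
```

The last line shows that, at batch size 1, the 20 scheduled evaluations add more forward passes than the training micro-batches themselves. Eval passes skip the backward pass, but `eval_steps` is still a lever on total wall-clock time.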
# 1. Verify LoRA params are trainable and in the right dtype
trainable = [(n, p.dtype, p.shape) for n, p in model.named_parameters() if p.requires_grad]
print(f"Trainable params: {len(trainable)}")
for name, dtype, shape in trainable[:5]:
    print(f"  {name}: {dtype}, {shape}")
Trainable params: 216
  base_model.model.model.transformer.blocks.0.self_attn.att_proj.lora_A.default.weight: torch.float32, torch.Size([16, 2560])
  base_model.model.model.transformer.blocks.0.self_attn.att_proj.lora_B.default.weight: torch.float32, torch.Size([6144, 16])
  base_model.model.model.transformer.blocks.0.self_attn.attn_out.lora_A.default.weight: torch.float32, torch.Size([16, 4096])
  base_model.model.model.transformer.blocks.0.self_attn.attn_out.lora_B.default.weight: torch.float32, torch.Size([2560, 16])
  base_model.model.model.transformer.blocks.0.mlp.ff_proj.lora_A.default.weight: torch.float32, torch.Size([16, 2560])
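You can cross-check the count from `print_trainable_parameters()` against these shapes. A rank-r LoRA adapter on a layer with `in` inputs and `out` outputs adds r·(in + out) parameters. The worked check below assumes every block gets exactly three adapters (att_proj, attn_out, ff_proj), as block 0 does; the printout only lists the first five tensors, so this is not confirmed. Under that assumption the totals are consistent and imply an ff_proj output width of 19,456, a value that is not printed above.

```python
r = 16  # LoRA rank, read off the lora_A shapes above

def lora_params(in_features, out_features, rank=r):
    # lora_A is (rank, in_features) and lora_B is (out_features, rank)
    return rank * in_features + out_features * rank

att_proj = lora_params(2560, 6144)    # 139,264
attn_out = lora_params(4096, 2560)    # 106,496

total_trainable = 21_528_576          # from print_trainable_parameters()
n_modules = 216 // 2                  # each adapted layer has an A and a B tensor -> 108
n_blocks = n_modules // 3             # ASSUMPTION: 3 adapted layers per block -> 36

per_block = total_trainable // n_blocks            # 598,016
ff_proj = per_block - att_proj - attn_out          # 352,256
ff_proj_out = ff_proj // r - 2560                  # implied ff_proj output width: 19,456

print(att_proj, attn_out, n_blocks, per_block, ff_proj, ff_proj_out)
```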
# 2. Verify the gradient hook is working: start with a test forward pass
test_batch = collate_fn([train_formatted[0]])
test_batch = {k: v.to(model.device) if hasattr(v, 'to') else v for k, v in test_batch.items()}

# Print shapes and dtypes of all inputs
for k, v in test_batch.items():
    if hasattr(v, 'shape'):
        print(f"  {k}: shape={v.shape}, dtype={v.dtype}, device={v.device}")
    else:
        print(f"  {k}: {type(v)}")

# Try forward only first
output = model(**test_batch)
print(f"\nForward pass OK! Loss: {output.loss.item():.4f}")
  input_ids: shape=torch.Size([1, 1011]), dtype=torch.int64, device=cuda:0
  attention_mask: shape=torch.Size([1, 1011]), dtype=torch.int64, device=cuda:0
  token_type_ids: shape=torch.Size([1, 1011]), dtype=torch.bool, device=cuda:0
  pixel_values: shape=torch.Size([7, 729, 588]), dtype=torch.float32, device=cuda:0
  image_token_pooling: shape=torch.Size([955, 4]), dtype=torch.int64, device=cuda:0
  image_grids: shape=torch.Size([1, 4]), dtype=torch.int64, device=cuda:0
  image_num_crops: shape=torch.Size([1]), dtype=torch.int64, device=cuda:0
  labels: shape=torch.Size([1, 1011]), dtype=torch.int64, device=cuda:0

Forward pass OK! Loss: 4.1107
trainer.train()  # training in full fp32 (no fp16/bf16), so it's slow
[2289/4131 4:54:57 < 3:57:33, 0.13 it/s, Epoch 1.66/3]
Step   Training Loss   Validation Loss
200    0.214400        0.189229
400    0.165100        0.174236
600    0.178900        0.166137
800    0.173200        0.163261
1000   0.157400        0.160916
1200   0.161900        0.157965
1400   0.148600        0.153788
1600   0.161000        0.154005
1800   0.132300        0.155237
2000   0.165700        0.151982
2200   0.140300        0.150819

KeyboardInterrupt: 
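Training was stopped by hand at step 2289 of 4131 (epoch ~1.66). The adapter saved in the next cell is therefore the model's state at that moment, not a specific checkpoint. Validation loss was still drifting down at the last eval. To use the best saved checkpoint instead, find the step with the lowest validation loss. The small sketch below uses the rows from the table above. In a live session the same numbers are in `trainer.state.log_history`, in the entries that carry an `eval_loss` key. With `save_steps=200` and `save_total_limit=3`, the Trainer should have kept only the three most recent checkpoints (1800, 2000 and 2200) under `./molmo2-mice-lora/`.

```python
# (step, training loss, validation loss), copied from the table above
log = [
    (200, 0.214400, 0.189229),
    (400, 0.165100, 0.174236),
    (600, 0.178900, 0.166137),
    (800, 0.173200, 0.163261),
    (1000, 0.157400, 0.160916),
    (1200, 0.161900, 0.157965),
    (1400, 0.148600, 0.153788),
    (1600, 0.161000, 0.154005),
    (1800, 0.132300, 0.155237),
    (2000, 0.165700, 0.151982),
    (2200, 0.140300, 0.150819),
]
# Live-session equivalent (assumes the Trainer object still exists):
# log = [(h["step"], None, h["eval_loss"]) for h in trainer.state.log_history if "eval_loss" in h]

best_step, _, best_val = min(log, key=lambda row: row[2])
print(f"Lowest validation loss: {best_val:.6f} at step {best_step}")
print(f"Matching checkpoint: ./molmo2-mice-lora/checkpoint-{best_step}")
```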
# Save the LoRA adapter
model.save_pretrained("./molmo2-mice-lora/final_adapter")
# model.push_to_hub("jpoberhauser/molmo2-4b-mice-lora")
# processor.push_to_hub("jpoberhauser/molmo2-4b-mice-lora")
# Test on a validation image and compare to the ground truth
import random
import torch

random.seed(99)
test_sample = random.choice(val_raw)
test_img = Image.open(test_sample["image"]).convert("RGB")

print(f"Question:     {test_sample['question']}")
print(f"Ground truth: {test_sample['answer']}")
print()

messages = [
    {
        "role": "user",
        "content": [
            dict(type="image", image=test_img),
            dict(type="text", text=test_sample["question"]),
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

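model.eval()  # training was interrupted mid-run, so the model may still be in train mode (dropout active)
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=256)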

generated_tokens = output[0, inputs["input_ids"].size(1):]
prediction = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
print(f"Fine-tuned model: {prediction}")
Question:     Where is the mouse near the top located in the image?
Ground truth: The mouse near the top is centered at approximately (588, 312) in pixel coordinates.

Fine-tuned model: The mouse near the top is centered at approximately (615, 305) in pixel coordinates. This mouse is near the top of the image, close to the top edge.

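And look at the results! The answer is far more detailed and grounded than a generic description. It follows the pixel-coordinate format of the training data, and on this example the estimate (615, 305) lands about 28 px from the ground truth (588, 312).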

  • Question: Where is the mouse near the top located in the image?

  • Ground truth: The mouse near the top is centered at approximately (588, 312) in pixel coordinates.

  • Fine-tuned model: The mouse near the top is centered at approximately (615, 305) in pixel coordinates. This mouse is near the top of the image, close to the top edge.
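A single example is encouraging but anecdotal. Because these answers follow a fixed "(x, y)" template, coordinate error is easy to score across many validation samples. The minimal sketch below assumes the coordinate of interest is the first "(x, y)" pair in each answer. That holds for the two strings above, but answers about two mice contain two pairs, so they would need matching logic. `first_xy` and `pixel_error` are hypothetical helpers. Scoring all of `val_raw` would also need a generation loop around the inference cell above.

```python
import math
import re

COORD = re.compile(r"\((\d+(?:\.\d+)?),\s*(\d+(?:\.\d+)?)\)")

def first_xy(text):
    """Return the first '(x, y)' pair in an answer string, or None."""
    m = COORD.search(text)
    return (float(m.group(1)), float(m.group(2))) if m else None

def pixel_error(prediction, ground_truth):
    """Euclidean distance in pixels between the first coordinates of two answers."""
    p, g = first_xy(prediction), first_xy(ground_truth)
    if p is None or g is None:
        return None  # one of the answers had no parsable coordinate
    return math.dist(p, g)

gt = "The mouse near the top is centered at approximately (588, 312) in pixel coordinates."
pred = ("The mouse near the top is centered at approximately (615, 305) in pixel coordinates. "
        "This mouse is near the top of the image, close to the top edge.")
err = pixel_error(pred, gt)
print(f"{err:.1f} px")  # 27.9 px
```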