Building a Text Encoder from Scratch
text = "the cat sat on the mat"
Stage 1: Understanding Tokenization
The core question is how to turn text into a sequence of integers. We need a tokenizer that maps text to integers, and the easiest place to start is the character level.
We will start with character-level tokenization to understand the concept, then build a BPE tokenizer, and finally compare our simple tokenizer with one from Hugging Face.
# We need a vocabulary: every unique character in the text, sorted
chars = sorted(set(text))
print(chars)
[' ', 'a', 'c', 'e', 'h', 'm', 'n', 'o', 's', 't']
# ...and then map each character to an integer (and back)
char_to_id = {ch: i for i, ch in enumerate(chars)}
id_to_char = {i: ch for ch, i in char_to_id.items()}
print("Vocabulary:", char_to_id)
print(f"Vocab size: {len(chars)}")
Vocabulary: {' ': 0, 'a': 1, 'c': 2, 'e': 3, 'h': 4, 'm': 5, 'n': 6, 'o': 7, 's': 8, 't': 9}
Vocab size: 10
# The encoder is very simple:
encoded = [char_to_id[ch] for ch in text]
print("Encoded:", encoded)
Encoded: [9, 4, 3, 0, 2, 1, 9, 0, 8, 1, 9, 0, 7, 6, 0, 9, 4, 3, 0, 5, 1, 9]
# The decoder is easy as well:
decoded = "".join(id_to_char[i] for i in encoded)
print("Decoded:", decoded)
assert decoded == text
Decoded: the cat sat on the mat
This is simple and easy to understand, but it has some issues. One is that sequences get long: every character becomes its own integer, so our 22-character text turns into 22 tokens.
Another is that each character is encoded on its own, so the tokens capture nothing larger, such as a full word.
BPE (Byte Pair Encoding)
BPE starts at the character level, like we did above, and then iteratively merges the most frequent pair of adjacent tokens into a new token.
For example, suppose the word ‘low’ appears 5 times in the entire corpus and the word ‘lower’ appears 2 times.
The recipe, from this blog post:
1. Identify frequent pairs
   - In each iteration, scan the text to find the most commonly occurring pair of adjacent bytes (or characters).
2. Replace and record
   - Replace that pair with a new placeholder ID (one that isn't already in use).
   - Record this mapping in a lookup table.
   - The vocabulary size (base symbols plus entries in this table) is a knob we can tune. For GPT-2 it is 50,257.
3. Repeat until there are no gains
   - Keep repeating steps 1 and 2, continually merging the most frequent pairs.
   - Stop when no further compression is possible (or when we reach the target vocabulary size).
4. Decompression (decoding)
   - To restore the original text, reverse the process by substituting each ID with its corresponding pair from the lookup table we built.
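The recipe works on any sequence of bytes, not just words. Here is a minimal sketch of it in plain Python, assuming byte-level input (UTF-8 bytes, so IDs 0-255 are taken and 256 is the first unused ID) and a stopping rule of "stop once the best pair occurs fewer than twice", since a pair that occurs once saves one ID in the sequence but costs one lookup-table entry. The function names are illustrative, not from the blog post. GPT-2's 50,257 breaks down the same way: 256 byte tokens + 50,000 merges + 1 special end-of-text token.

```python
from collections import Counter


def most_common_pair(ids):
    """Return (pair, count) for the most frequent adjacent pair; first seen wins ties."""
    counts = Counter(zip(ids, ids[1:]))
    if not counts:
        return None, 0
    pair = max(counts, key=counts.get)
    return pair, counts[pair]


def replace_pair(ids, pair, new_id):
    """Replace non-overlapping occurrences of `pair` (scanning left to right) with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out


def train_bpe(text, num_merges):
    """Steps 1-3: repeatedly replace the most frequent pair and record it in a lookup table."""
    ids = list(text.encode("utf-8"))  # raw bytes: IDs 0..255
    lookup = {}                       # new_id -> (left_id, right_id)
    for new_id in range(256, 256 + num_merges):
        pair, count = most_common_pair(ids)
        if count < 2:
            break  # no further compression possible
        ids = replace_pair(ids, pair, new_id)
        lookup[new_id] = pair
    return ids, lookup


def decode(ids, lookup):
    """Step 4: expand each merged ID back into its pair, recursively, down to bytes."""
    def expand(i):
        if i < 256:
            return bytes([i])
        left, right = lookup[i]
        return expand(left) + expand(right)
    return b"".join(expand(i) for i in ids).decode("utf-8")


text = "low low low low low lower lower"
ids, lookup = train_bpe(text, num_merges=10)
print(len(text.encode("utf-8")), "bytes ->", len(ids), "IDs")
print(decode(ids, lookup) == text)
```

Note that decoding only needs the lookup table, which is why the table (not the compressed corpus) is the thing we keep.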
So for the example above, we (1) count all the adjacent pairs: (l, o) and (o, w) each appear 7 times (5 from ‘low’ plus 2 from ‘lower’), while (w, e) and (e, r) each appear twice. Then (2) we merge a most frequent pair into a new token; with a tie like this one we simply take the first, (l, o), giving the new token lo. Now the words are tokenized as lo w and lo w e r. We repeat steps 1 and 2 until we hit the target vocabulary size.
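To double-check the counts in this worked example, here is the first iteration on the toy corpus as a small self-contained sketch. count_pairs and merge are illustrative names; they mirror the fuller get_pair_counts and merge_pair functions later in this notebook.

```python
from collections import Counter

# Toy corpus from the example: word (as a tuple of symbols) -> frequency
corpus = {("l", "o", "w"): 5, ("l", "o", "w", "e", "r"): 2}


def count_pairs(corpus):
    """Step 1: count adjacent pairs, weighted by word frequency."""
    counts = Counter()
    for symbols, freq in corpus.items():
        for pair in zip(symbols, symbols[1:]):
            counts[pair] += freq
    return counts


def merge(corpus, pair):
    """Step 2: fuse every occurrence of `pair` into one symbol."""
    merged = {}
    for symbols, freq in corpus.items():
        out, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = freq
    return merged


counts = count_pairs(corpus)
print(dict(counts))  # {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 2, ('e', 'r'): 2}
best = max(counts, key=counts.get)  # tie: the first pair counted, ('l', 'o'), wins
corpus = merge(corpus, best)
print(list(corpus))  # [('lo', 'w'), ('lo', 'w', 'e', 'r')]
```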
As a result, common words like “the” or “low” end up as single tokens, while rare words are split into sub-pieces: “lowest”, for example, could become low + est.
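To see how a rare word gets split, here is a sketch that applies merge rules in the order they were learned. The four merge rules are hypothetical, picked by hand for illustration rather than learned from any corpus, and the end-of-word marker is left out for simplicity.

```python
def apply_merges(word, merges):
    """Split `word` into characters, then apply each merge rule in learned order."""
    tokens = list(word)
    for left, right in merges:
        i = 0
        while i < len(tokens) - 1:
            if tokens[i] == left and tokens[i + 1] == right:
                tokens[i:i + 2] = [left + right]  # fuse in place, re-check this position
            else:
                i += 1
    return tokens


# Hypothetical merge rules, hand-picked for this illustration
merges = [("l", "o"), ("lo", "w"), ("e", "s"), ("es", "t")]
print(apply_merges("low", merges))     # ['low']
print(apply_merges("lowest", merges))  # ['low', 'est']
```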
OK, now let's see it in code:
import re
from collections import Counter, defaultdict

text = """The cat sat next to the dog. The cat is much smaller than the dog. The fox jumped over the brown fence. It was the strangest day of the year."""
words = re.findall(r'\S+', text.lower())
print("Words:", words)
Words: ['the', 'cat', 'sat', 'next', 'to', 'the', 'dog.', 'the', 'cat', 'is', 'much', 'smaller', 'than', 'the', 'dog.', 'the', 'fox', 'jumped', 'over', 'the', 'brown', 'fence.', 'it', 'was', 'the', 'strangest', 'day', 'of', 'the', 'year.']
word_freqs = Counter(words)
print("Word frequencies:", word_freqs)
Word frequencies: Counter({'the': 8, 'cat': 2, 'dog.': 2, 'sat': 1, 'next': 1, 'to': 1, 'is': 1, 'much': 1, 'smaller': 1, 'than': 1, 'fox': 1, 'jumped': 1, 'over': 1, 'brown': 1, 'fence.': 1, 'it': 1, 'was': 1, 'strangest': 1, 'day': 1, 'of': 1, 'year.': 1})
# Represent each word as a tuple of characters
# We append a special end-of-word token "</w>" so the model knows where words end
vocab_words = {tuple(word) + ('</w>',): freq for word, freq in word_freqs.items()}
print("\nInitial token sequences:")
for word, freq in list(vocab_words.items())[:5]:
    print(f" {word} × {freq}")
Initial token sequences:
('t', 'h', 'e', '</w>') × 8
('c', 'a', 't', '</w>') × 2
('s', 'a', 't', '</w>') × 1
('n', 'e', 'x', 't', '</w>') × 1
('t', 'o', '</w>') × 1
def get_pair_counts(vocab_words):
    """Count frequency of all adjacent token pairs across the vocabulary."""
    pairs = defaultdict(int)
    for tokens, freq in vocab_words.items():
        for i in range(len(tokens) - 1):
            pairs[(tokens[i], tokens[i + 1])] += freq
    return pairs


def merge_pair(pair, vocab_words):
    """Merge every occurrence of `pair` into a single new token."""
    new_vocab = {}
    bigram = pair  # e.g. ('t', 'h')
    replacement = pair[0] + pair[1]  # e.g. 'th'
    for tokens, freq in vocab_words.items():
        new_tokens = []
        i = 0
        while i < len(tokens):
            # If we find the pair, merge it
            if i < len(tokens) - 1 and tokens[i] == bigram[0] and tokens[i + 1] == bigram[1]:
                new_tokens.append(replacement)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        new_vocab[tuple(new_tokens)] = freq
    return new_vocab


num_merges = 20  # In practice this would be ~49k for CLIP
merges = []  # This is the learned model

for step in range(num_merges):
    pairs = get_pair_counts(vocab_words)
    if not pairs:
        break
    # Find the most frequent pair
    best_pair = max(pairs, key=pairs.get)
    best_count = pairs[best_pair]
    # Merge it
    vocab_words = merge_pair(best_pair, vocab_words)
    merges.append(best_pair)
    print(f"Step {step + 1}: merge {best_pair} → '{best_pair[0] + best_pair[1]}' (appeared {best_count} times)")

print(f"\nLearned {len(merges)} merges")
Step 1: merge ('t', 'h') → 'th' (appeared 9 times)
Step 2: merge ('th', 'e') → 'the' (appeared 8 times)
Step 3: merge ('the', '</w>') → 'the</w>' (appeared 8 times)
Step 4: merge ('t', '</w>') → 't</w>' (appeared 6 times)
Step 5: merge ('.', '</w>') → '.</w>' (appeared 4 times)
Step 6: merge ('a', 't</w>') → 'at</w>' (appeared 3 times)
Step 7: merge ('c', 'at</w>') → 'cat</w>' (appeared 2 times)
Step 8: merge ('d', 'o') → 'do' (appeared 2 times)
Step 9: merge ('do', 'g') → 'dog' (appeared 2 times)
Step 10: merge ('dog', '.</w>') → 'dog.</w>' (appeared 2 times)
Step 11: merge ('s', '</w>') → 's</w>' (appeared 2 times)
Step 12: merge ('e', 'r') → 'er' (appeared 2 times)
Step 13: merge ('er', '</w>') → 'er</w>' (appeared 2 times)
Step 14: merge ('a', 'n') → 'an' (appeared 2 times)
Step 15: merge ('s', 'at</w>') → 'sat</w>' (appeared 1 times)
Step 16: merge ('n', 'e') → 'ne' (appeared 1 times)
Step 17: merge ('ne', 'x') → 'nex' (appeared 1 times)
Step 18: merge ('nex', 't</w>') → 'next</w>' (appeared 1 times)
Step 19: merge ('t', 'o') → 'to' (appeared 1 times)
Step 20: merge ('to', '</w>') → 'to</w>' (appeared 1 times)
Learned 20 merges
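Several steps above are ties (many pairs appear exactly twice). max(pairs, key=pairs.get) returns the first maximal key it encounters, and Python dicts iterate in insertion order, so a tie goes to whichever pair get_pair_counts counted first, which depends on the order of the words in the corpus. A tiny check with hand-picked counts (not taken from the run above):

```python
# Hand-picked counts where every pair ties at 2
pairs = {("s", "</w>"): 2, ("d", "o"): 2, ("e", "r"): 2, ("a", "n"): 2}
best = max(pairs, key=pairs.get)
print(best)  # ('s', '</w>'): the first tied pair in insertion order
```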
def encode(text, merges):
    """Tokenize a string using learned BPE merges."""
    words = re.findall(r'\S+', text.lower())
    all_tokens = []
    for word in words:
        # Start with character-level tokens
        tokens = list(word) + ['</w>']
        # Apply merges in priority order
        for pair in merges:
            i = 0
            while i < len(tokens) - 1:
                if tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
                    tokens[i] = pair[0] + pair[1]
                    del tokens[i + 1]
                else:
                    i += 1
        all_tokens.extend(tokens)
    return all_tokens


# Build the token -> ID table from every token in the final training vocabulary
all_tokens_set = set()
for tokens, _ in vocab_words.items():
    all_tokens_set.update(tokens)
# Also add both sides of every merge, so intermediate tokens such as 'the' (every training
# 'the' has become 'the</w>') get IDs too, since encode() can stop at them for new words
for pair in merges:
    all_tokens_set.add(pair[0])
    all_tokens_set.add(pair[1])
token_to_id = {tok: i for i, tok in enumerate(sorted(all_tokens_set))}

test_text = "the cat sat on the mat and then jumped over the dog"
tokens = encode(test_text, merges)
token_ids = [token_to_id[t] for t in tokens]

print(f"Text: '{test_text}'")
print(f"Tokens: {tokens}")
print(f"Token IDs: {token_ids}")
print(f"Num tokens: {len(tokens)} (vs {len(test_text)} characters)")
Text: 'the cat sat on the mat and then jumped over the dog'
Tokens: ['the</w>', 'cat</w>', 'sat</w>', 'o', 'n', '</w>', 'the</w>', 'm', 'at</w>', 'an', 'd', '</w>', 'the', 'n', '</w>', 'j', 'u', 'm', 'p', 'e', 'd', '</w>', 'o', 'v', 'er</w>', 'the</w>', 'dog', '</w>']
Token IDs: [37, 8, 32, 27, 23, 2, 37, 22, 5, 4, 9, 2, 36, 23, 2, 20, 40, 22, 28, 13, 9, 2, 27, 41, 15, 37, 11, 2]
Num tokens: 28 (vs 51 characters)
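So far we only encode. With this </w> scheme, decoding can be as simple as concatenating the tokens and turning each end-of-word marker back into a space. A minimal sketch; decode_tokens is a hypothetical helper, and since encode() lowercases and splits on whitespace, the original casing and exact spacing are not recoverable.

```python
def decode_tokens(tokens):
    """Join BPE tokens and turn each end-of-word marker back into a space."""
    return "".join(tokens).replace("</w>", " ").strip()


tokens = ['the</w>', 'cat</w>', 'sat</w>', 'o', 'n', '</w>', 'the</w>', 'm', 'at</w>']
print(decode_tokens(tokens))  # the cat sat on the mat
```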
OK, let's compare against a tokenizer trained on a much larger corpus: CLIP's tokenizer from Hugging Face.
from transformers import CLIPTokenizer
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
print(f"Vocab size: {tokenizer.vocab_size}")
print(f"Special tokens: {tokenizer.special_tokens_map}")
print(f"All special token IDs: {tokenizer.all_special_ids}")
Vocab size: 49408
Special tokens: {'bos_token': '<|startoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}
All special token IDs: [49406, 49407]
test_sentences = [
    "the cat sat on the mat",
    "a photograph of a dog",
    "supercalifragilisticexpialidocious",  # rare word: watch it get split into subwords
]

for sent in test_sentences:
    # Hugging Face encoding
    hf_encoded = tokenizer(sent)
    hf_tokens = tokenizer.convert_ids_to_tokens(hf_encoded["input_ids"])
    # Our BPE encoding (from the previous section)
    our_tokens = encode(sent, merges)
    print(f"\nText: '{sent}'")
    print(f" HF tokens ({len(hf_tokens)}): {hf_tokens}")
    print(f" HF IDs: {hf_encoded['input_ids']}")
    print(f" Our tokens ({len(our_tokens)}): {our_tokens}")
Text: 'the cat sat on the mat'
HF tokens (8): ['<|startoftext|>', 'the</w>', 'cat</w>', 'sat</w>', 'on</w>', 'the</w>', 'mat</w>', '<|endoftext|>']
HF IDs: [49406, 518, 2368, 3279, 525, 518, 9063, 49407]
Our tokens (9): ['the</w>', 'cat</w>', 'sat</w>', 'o', 'n', '</w>', 'the</w>', 'm', 'at</w>']
Text: 'a photograph of a dog'
HF tokens (7): ['<|startoftext|>', 'a</w>', 'photograph</w>', 'of</w>', 'a</w>', 'dog</w>', '<|endoftext|>']
HF IDs: [49406, 320, 8853, 539, 320, 1929, 49407]
Our tokens (19): ['a', '</w>', 'p', 'h', 'o', 'to', 'g', 'r', 'a', 'p', 'h', '</w>', 'o', 'f', '</w>', 'a', '</w>', 'dog', '</w>']
Text: 'supercalifragilisticexpialidocious'
HF tokens (11): ['<|startoftext|>', 'super', 'cali', 'frag', 'ili', 'stic', 'expi', 'ali', 'do', 'cious</w>', '<|endoftext|>']
HF IDs: [49406, 1642, 2857, 13093, 2076, 5868, 26850, 835, 639, 38466, 49407]
Our tokens (32): ['s', 'u', 'p', 'er', 'c', 'a', 'l', 'i', 'f', 'r', 'a', 'g', 'i', 'l', 'i', 's', 't', 'i', 'c', 'e', 'x', 'p', 'i', 'a', 'l', 'i', 'do', 'c', 'i', 'o', 'u', 's</w>']
# One more thing we need is padding: shorter sentences are filled with the pad token
# (for this CLIP tokenizer that is <|endoftext|>, ID 49407) up to the longest sentence
# in the batch. The attention mask marks real tokens with 1 and padding with 0.
batch = tokenizer(
    ["a short sentence", "a much longer sentence with more words in it"],
    padding=True,
    truncation=True,
    max_length=77,  # CLIP's context length
    return_tensors="pt"
)
print("input_ids shape:", batch["input_ids"].shape)
print("attention_mask shape:", batch["attention_mask"].shape)
print("\nSentence 1 IDs:", batch["input_ids"][0].tolist())
print("Sentence 2 IDs:", batch["input_ids"][1].tolist())
print("\nSentence 1 mask:", batch["attention_mask"][0].tolist())
input_ids shape: torch.Size([2, 11])
attention_mask shape: torch.Size([2, 11])
Sentence 1 IDs: [49406, 320, 3005, 12737, 49407, 49407, 49407, 49407, 49407, 49407, 49407]
Sentence 2 IDs: [49406, 320, 1238, 5349, 12737, 593, 750, 2709, 530, 585, 49407]
Sentence 1 mask: [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
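Under the hood, padding just appends a pad ID to each shorter sequence until it matches the longest one, plus a mask with 1 for real tokens and 0 for padding. A plain-Python sketch, where pad_batch is a hypothetical helper and pad_id=49407 matches the CLIP pad token shown above. Because this tokenizer pads with the same ID as <|endoftext|>, the first occurrence of the maximum ID in each row is the real end-of-text token; list.index(max(row)) returns that first position, which is the same tie rule torch.argmax documents (first maximal index).

```python
def pad_batch(sequences, pad_id):
    """Right-pad every sequence to the longest length; return (ids, attention_mask)."""
    max_len = max(len(seq) for seq in sequences)
    ids, mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        ids.append(seq + [pad_id] * n_pad)
        mask.append([1] * len(seq) + [0] * n_pad)
    return ids, mask


# Token IDs of the two sentences above, before padding
short = [49406, 320, 3005, 12737, 49407]
longer = [49406, 320, 1238, 5349, 12737, 593, 750, 2709, 530, 585, 49407]
ids, mask = pad_batch([short, longer], pad_id=49407)
print(ids[0])   # [49406, 320, 3005, 12737, 49407, 49407, 49407, 49407, 49407, 49407, 49407]
print(mask[0])  # [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]

# First position of the maximum ID in each row = the real <|endoftext|>
eos_positions = [row.index(max(row)) for row in ids]
print(eos_positions)  # [4, 10]
```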
Now we can move on to Stage 2.
Here, we will build the text encoder itself. We can actually reuse some of the building blocks from our vanilla ViT:
- get token embeddings and positional embeddings
- run multi-head attention (MHA), but with causal masking.
Compared with our ViT, instead of patch embeddings we use token embeddings, and instead of adding a cls token we take the transformer's output at the eos token and use it in place of the cls token output.
We then run a forward pass that takes us from text -> tokens -> token embeddings -> Transformer blocks -> pick out the output at the eos position and use that.
import torch
import torch.nn as nn


class TextEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_seq_len):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_seq_len, embed_dim)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len)
        seq_len = token_ids.shape[1]
        positions = torch.arange(seq_len, device=token_ids.device)
        x = self.token_embedding(token_ids) + self.position_embedding(positions)
        return x  # (batch, seq_len, embed_dim)


emb = TextEmbedding(vocab_size=49408, embed_dim=512, max_seq_len=77)
dummy_ids = torch.randint(0, 49408, (2, 10))
print(emb(dummy_ids).shape)
torch.Size([2, 10, 512])
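Why add the position embedding? Attention on its own does not know token order, so without it the same token would enter the transformer with an identical vector wherever it appears. A tiny standalone check of the token-plus-position sum used in TextEmbedding, with a toy vocabulary (the sizes and token IDs here are arbitrary, not CLIP's):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, embed_dim, max_seq_len = 100, 8, 77
token_embedding = nn.Embedding(vocab_size, embed_dim)
position_embedding = nn.Embedding(max_seq_len, embed_dim)

token_ids = torch.tensor([[5, 7, 5]])  # arbitrary IDs: token 5 appears at positions 0 and 2
positions = torch.arange(token_ids.shape[1])
tok = token_embedding(token_ids)       # (1, 3, 8)
x = tok + position_embedding(positions)  # (1, 3, 8), broadcast over the batch

same_token = torch.equal(tok[0, 0], tok[0, 2])  # identical before adding positions
same_input = torch.allclose(x[0, 0], x[0, 2])   # different after adding positions
print(same_token, same_input)
```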
Let's import the MultiHeadAttention from the ViT we built in the first module and compare. The key difference should be the causal mask, which we didn't need in the ViT, since there we want the attention scores to cover the entire sequence of patches.
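To see what the causal mask does, here is a tiny standalone example with random stand-in scores for a 4-token sequence. It builds the mask the same way CausalSelfAttention does below (torch.triu with diagonal=1), so row i can only put weight on positions 0..i; in the ViT, with no mask, every row spreads over all positions.

```python
import torch

torch.manual_seed(0)
T = 4
scores = torch.randn(T, T)  # stand-in attention scores for a 4-token sequence

# True above the diagonal = "future" positions that token i must not see
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
print(mask.int())

# -inf becomes exactly 0 after softmax, so future tokens get no weight
weights = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)
print(weights)  # lower-triangular; each row still sums to 1
```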
import sys
import inspect

sys.path.append('../module1_vision_transformer_foundations')
from vit import AttentionHead, MultiHeadAttention

print(inspect.getsource(MultiHeadAttention))
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_heads, dropout=0.1):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.head_dim = dim // num_heads
        self.num_heads = num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(dim, self.head_dim, dropout) for _ in range(num_heads)]
        )
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]
        out = torch.cat(head_outputs, dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out
class CausalSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape
        # Project to Q, K, V
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, T, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Attention scores
        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale  # (B, heads, T, T)
        # Causal mask: prevent attending to future tokens
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        attn = attn.masked_fill(mask, float('-inf'))
        attn = attn.softmax(dim=-1)
        # Weighted sum of values
        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        return self.proj(out)


csa = CausalSelfAttention(embed_dim=512, num_heads=8)
dummy = torch.randn(2, 10, 512)
print(csa(dummy).shape)  # (2, 10, 512)
torch.Size([2, 10, 512])
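As a sanity check, PyTorch 2.x provides torch.nn.functional.scaled_dot_product_attention, which accepts is_causal=True and applies the same upper-triangular mask with the same 1/sqrt(head_dim) default scale. The sketch below assumes PyTorch 2.0 or later; the random q, k, v are stand-ins for the per-head tensors inside CausalSelfAttention. It also checks the property we actually care about: changing the last token's keys and values must not change any earlier output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 8, 10, 64  # batch, heads, sequence length, head_dim
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))


def manual_causal_attention(q, k, v):
    """Same steps as CausalSelfAttention.forward, per head: scale, mask, softmax, weight."""
    T = q.shape[-2]
    attn = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
    mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    attn = attn.masked_fill(mask, float("-inf")).softmax(dim=-1)
    return attn @ v


manual = manual_causal_attention(q, k, v)
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, builtin, atol=1e-5))

# Causality: perturb only the last position's keys/values
k2, v2 = k.clone(), v.clone()
k2[:, :, -1] += 1.0
v2[:, :, -1] += 1.0
changed = manual_causal_attention(q, k2, v2)
print(torch.allclose(manual[:, :, :-1], changed[:, :, :-1]))  # earlier outputs unchanged
```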
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = CausalSelfAttention(embed_dim, num_heads)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(embed_dim * mlp_ratio, embed_dim),
        )

    def forward(self, x):
        # Pre-norm residual blocks: attention, then MLP
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class TextEncoder(nn.Module):
    def __init__(self, vocab_size=49408, embed_dim=512, num_heads=8,
                 num_layers=12, max_seq_len=77, projection_dim=512):
        super().__init__()
        self.embedding = TextEmbedding(vocab_size, embed_dim, max_seq_len)
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads) for _ in range(num_layers)
        ])
        self.ln_final = nn.LayerNorm(embed_dim)
        self.projection = nn.Linear(embed_dim, projection_dim, bias=False)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len), includes <startoftext> and <endoftext>
        x = self.embedding(token_ids)  # (B, T, embed_dim)
        x = self.blocks(x)             # (B, T, embed_dim)
        x = self.ln_final(x)           # (B, T, embed_dim)
        # Take the EOS token's representation as the sentence embedding;
        # thanks to the causal mask, the EOS position has attended to every token before it.
        # <endoftext> has the highest ID (49407). Padding reuses that ID, but argmax returns
        # the first maximal index, which is the real EOS.
        eos_indices = token_ids.argmax(dim=-1)
        sentence_emb = x[torch.arange(x.shape[0]), eos_indices]  # (B, embed_dim)
        # Project to shared embedding space + L2 normalize
        sentence_emb = self.projection(sentence_emb)
        sentence_emb = sentence_emb / sentence_emb.norm(dim=-1, keepdim=True)
        return sentence_emb  # (B, projection_dim), unit normalized
And finally, Stage 3.
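Here, we add a projection head (an nn.Linear) that maps the transformer output to an embedding of fixed dimension (512), and then L2-normalize the embeddings so that the dot product of two embeddings is their cosine similarity. Both steps are already part of TextEncoder.forward above.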
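Why normalize? Once vectors have unit length, a plain dot product equals cosine similarity, so a whole similarity matrix is a single matrix multiplication. A small standalone check with random stand-in vectors (not real encoder outputs), using the same normalization as TextEncoder.forward and PyTorch's F.cosine_similarity for comparison:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(3, 512)  # stand-ins for three projected text embeddings
b = torch.randn(3, 512)

a_unit = a / a.norm(dim=-1, keepdim=True)  # same normalization as TextEncoder.forward
b_unit = b / b.norm(dim=-1, keepdim=True)

dot = (a_unit * b_unit).sum(dim=-1)      # row-wise dot product of unit vectors
cos = F.cosine_similarity(a, b, dim=-1)  # cosine similarity of the raw vectors
print(torch.allclose(dot, cos, atol=1e-6))

# All pairwise similarities at once: entry [i, j] = cos(a_i, b_j)
sim = a_unit @ b_unit.T  # (3, 3)
print(sim.shape)
```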
Let's test it by running some text through it!
model = TextEncoder(num_layers=4)  # fewer layers for quick testing
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
Parameters: 38,209,024

texts = ["a photo of a cat", "a photo of a dog", "the eiffel tower at sunset"]
batch = tokenizer(texts, padding=True, truncation=True, max_length=77, return_tensors="pt")

with torch.no_grad():
    embeddings = model(batch["input_ids"])

print(f"Embedding shape: {embeddings.shape}")  # (3, 512)
print(f"Norms (should be ~1.0): {embeddings.norm(dim=-1)}")
Embedding shape: torch.Size([3, 512])
Norms (should be ~1.0): tensor([1.0000, 1.0000, 1.0000])
sim_matrix = embeddings @ embeddings.T
print(f"\nSimilarity matrix:\n{sim_matrix}")  # the model is untrained, so don't expect these values to make much sense
Similarity matrix:
tensor([[1.0000, 0.9973, 0.9801],
[0.9973, 1.0000, 0.9829],
[0.9801, 0.9829, 1.0000]])