#!/usr/bin/env python
# python add_special_tokens.py --model meta-llama/Llama-3.2-1B-Instruct \
#                              --outdir ./llama3_special

import argparse, torch
from pathlib import Path
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)

# Command-line interface: which checkpoint to extend and where to write the result.
parser = argparse.ArgumentParser(
    description=(
        "Add special tokens to a causal LM's tokenizer, resize the model's "
        "embeddings to match, and save both to an output directory."
    ),
)
parser.add_argument(
    "--model",
    required=True,
    help="Hugging Face model id or local path of the base checkpoint",
)
parser.add_argument(
    "--outdir",
    required=True,
    type=Path,
    help="directory to write the extended tokenizer and model to (created if missing)",
)
args = parser.parse_args()

# ----- 1. Load tokenizer first
tok = AutoTokenizer.from_pretrained(args.model)
# Token ids at or above this value can only belong to tokens added below.
n_before = len(tok)

# ----- 2. Define tokens and register them as *special* tokens so the
# tokenizer keeps each one as a single unit instead of splitting it.
special = [f"<move_{i}>" for i in range(1, 19)]
# The return value is the number of tokens that were actually new to the
# vocabulary; tokens that already exist (e.g. on a re-run) are not counted.
added = tok.add_special_tokens({"additional_special_tokens": special})
print(f"added {added} tokens")

# ----- 3. Load model *after* tokens are ready
model = AutoModelForCausalLM.from_pretrained(
    args.model,
    torch_dtype=torch.bfloat16,  # keeps memory low
    device_map="auto",           # GPUs if present
)
# Compare against the real number of embedding rows rather than
# config.vocab_size; the matrix may be larger than len(tok) (e.g. padded).
old, new = model.get_input_embeddings().weight.shape[0], len(tok)

# ----- 4. Optional: if pad token missing, point it at eos
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# ----- 5. Grow the embeddings only when the tokenizer outgrew them. Never
# shrink: a matrix with spare rows already has a row for every token id.
if new > old:
    print(f"resizing from {old} to {new}")
    model.resize_token_embeddings(new)
    model.tie_weights()  # LM head <-> embedding

# ----- Initialise the rows of tokens added in this run.
# The average-vector trick follows Hewitt 2021's recommendation: setting new
# rows to the mean of the existing ones keeps them from winning the softmax and
# avoids the "all-new-tokens" collapse (safer than random init).
# https://www.cs.columbia.edu/~johnhew//vocab-expansion.html
# Rows are selected by token id rather than "the last len(special) rows", so a
# re-run, or a matrix larger than the vocabulary, never overwrites trained rows.
new_ids = [i for i in tok.convert_tokens_to_ids(special) if i >= n_before]
n_trained = min(n_before, old)  # rows belonging to the original vocabulary
if new_ids:
    with torch.no_grad():
        emb = model.get_input_embeddings().weight
        weights = [emb]
        head = model.get_output_embeddings()
        # An untied LM head has its own rows for the new ids; a tied one *is* emb.
        if head is not None and head.weight is not emb:
            weights.append(head.weight)
        for w in weights:
            idx = torch.tensor(new_ids, dtype=torch.long, device=w.device)
            # Average in float32, then cast back to the weight's own dtype.
            w[idx] = w[:n_trained].float().mean(0, keepdim=True).to(w.dtype)
    print(f"initialised {len(new_ids)} new rows to the mean embedding")
                             
# ----- 6. Save: tokenizer files and safetensors weights share one directory,
# so a later from_pretrained(outdir) call finds both.
out_dir = args.outdir
out_dir.mkdir(exist_ok=True, parents=True)
tok.save_pretrained(out_dir)
model.save_pretrained(out_dir, safe_serialization=True)
print("saved to", out_dir)

# ----- 7. Reload from disk to verify the saved checkpoint is self-consistent.
# Release the in-memory copy first so the reload does not double GPU memory.
del model
if torch.cuda.is_available():
    torch.cuda.empty_cache()

tok2 = AutoTokenizer.from_pretrained(args.outdir)
# Deliberately no ignore_mismatched_sizes: a shape mismatch here means the save
# is broken, and that flag would silently re-initialise the whole mismatched
# tensor (not just the new rows), hiding the problem.
model2 = AutoModelForCausalLM.from_pretrained(
    args.outdir,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Every token id the tokenizer can emit needs an embedding row.
n_rows = model2.get_input_embeddings().weight.shape[0]
if n_rows < len(tok2):
    raise RuntimeError(
        f"embedding has {n_rows} rows but the tokenizer has {len(tok2)} tokens"
    )
print("reload ok; vocab size", len(tok2))

# Each added token must encode to exactly one id (i.e. it is not split into
# sub-word pieces). add_special_tokens=False keeps BOS/EOS out of the result.
for token in special:
    token_ids = tok2.encode(token, add_special_tokens=False)
    if len(token_ids) != 1:
        raise RuntimeError(f"{token!r} encodes to {token_ids}, expected a single id")

# Try tokenizing some text
text = "<move_1> <move_2>"
tokens = tok2(text)
print("Tokens:", tokens)
print("Decoded:", tok2.decode(tokens["input_ids"]))
