import json
import os
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm

def parse_args():
    """Parse the command-line options for the vocabulary-expansion script.

    Returns:
        argparse.Namespace with ``base_model_path``, ``new_tokens_file``,
        ``output_model_path`` (all required strings) and ``device``
        (one of ``cpu``/``cuda``/``auto``, defaulting to ``cpu``).
    """
    cli = argparse.ArgumentParser(description="Expand model vocabulary and initialize embeddings.")

    # The three path options share the same shape: required string flags.
    required_path_options = (
        ("--base_model_path", "Path to the original base model (e.g., Qwen2.5-7B-Instruct)"),
        ("--new_tokens_file", "Path to the JSON file containing the list of new tokens"),
        ("--output_model_path", "Path to save the modified model (Step 0)"),
    )
    for flag, help_text in required_path_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda", "auto"],
        default="cpu",
        help="Device to load the model on. 'cpu' is safer for RAM, 'cuda' is faster.",
    )

    return cli.parse_args()

def _mean_of_rows(weight, row_ids):
    """Return the mean of the rows ``row_ids`` of ``weight``, in ``weight``'s dtype."""
    # Build the index on the weight's own device: with a sharded device_map the
    # embedding matrix is not guaranteed to live on ``model.device``.
    index = torch.tensor(row_ids, dtype=torch.long, device=weight.device)
    # Accumulate in float32 so bf16/fp16 checkpoints do not lose precision.
    return weight[index].to(torch.float32).mean(dim=0).to(weight.dtype)


def main():
    """Expand a model's vocabulary and mean-initialize the new embedding rows.

    Steps:
      1. Validate the input paths and load the JSON list of new tokens.
      2. Load the tokenizer and model from ``--base_model_path``.
      3. Add the tokens to the tokenizer and grow the embedding matrix if needed.
      4. For each newly added token, set its input-embedding row to the mean of
         the input-embedding rows of the sub-tokens the *original* tokenizer
         splits it into. If the model has a separate (untied) output layer, its
         row is set to the mean of the corresponding *output* rows.
      5. Save the tokenizer and model to ``--output_model_path``.

    Validation failures print an error message and return early without
    raising. Returns ``None``.
    """
    args = parse_args()

    print(f"\n=== Starting Model Surgery (Vocab Expansion) ===")
    print(f"Base Model: {args.base_model_path}")
    print(f"New Tokens: {args.new_tokens_file}")
    print(f"Output Dir: {args.output_model_path}")
    print(f"Device:     {args.device}")

    # 1. Validation
    if not os.path.exists(args.base_model_path):
        print(f"❌ Error: Base model path not found: {args.base_model_path}")
        return
    if not os.path.exists(args.new_tokens_file):
        print(f"❌ Error: New tokens file not found: {args.new_tokens_file}")
        return

    # Read the token list before the (expensive) model load so that a malformed
    # file fails fast.
    with open(args.new_tokens_file, 'r', encoding='utf-8') as f:
        new_tokens = json.load(f)
    if not isinstance(new_tokens, list) or not all(isinstance(tok, str) for tok in new_tokens):
        print(f"❌ Error: New tokens file must contain a JSON list of strings: {args.new_tokens_file}")
        return

    # 2. Load Tokenizer & Model
    print("📂 Loading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path, trust_remote_code=True)

    # torch_dtype="auto" keeps the dtype stored in the checkpoint.
    # Note: For initialization only, CPU is usually fine and avoids OOM.
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model_path,
        torch_dtype="auto",
        device_map=args.device,
        trust_remote_code=True
    )

    original_vocab_size = len(tokenizer)
    print(f"   Original Vocab Size: {original_vocab_size}")

    # 3. Report New Tokens
    print(f"   Loaded {len(new_tokens)} candidate tokens.")

    # 4. Resize Vocabulary
    # Returns the number of tokens actually added (excluding duplicates)
    num_added_toks = tokenizer.add_tokens(new_tokens)

    if num_added_toks > 0:
        print(f"   Resizing model embeddings (+{num_added_toks})...")
        current_rows = model.get_input_embeddings().weight.shape[0]
        if len(tokenizer) > current_rows:
            model.resize_token_embeddings(len(tokenizer))
        else:
            # Some checkpoints pad the embedding matrix beyond the tokenizer
            # size. Resizing to len(tokenizer) would then *shrink* the matrix,
            # so reuse the spare rows instead.
            print(f"   Embedding matrix already has {current_rows} rows; reusing spare rows instead of shrinking.")
        print(f"✅ Resize successful! New Vocab Size: {len(tokenizer)}")
    else:
        print("⚠️  No new tokens were added (all might already exist).")
        print("   Saving vanilla model to new path anyway for consistency.")
        tokenizer.save_pretrained(args.output_model_path)
        model.save_pretrained(args.output_model_path)
        return

    # 5. Smart Mean Pooling Initialization
    print("\n🚀 Executing Smart Mean Pooling Initialization...")

    # Fetch the layers only after resizing: resizing may replace the modules.
    input_weight = model.get_input_embeddings().weight
    output_embeddings = model.get_output_embeddings()
    # Only treat the output layer separately when it does not share its weight
    # with the input embeddings; a tied weight is already covered by the input
    # assignment.
    output_weight = None
    if output_embeddings is not None and output_embeddings.weight is not input_weight:
        output_weight = output_embeddings.weight

    # Load a temporary vanilla tokenizer to calculate old IDs
    # This is crucial because 'tokenizer' now has new IDs that didn't exist before
    vanilla_tokenizer = AutoTokenizer.from_pretrained(args.base_model_path, trust_remote_code=True)

    # In-place writes to parameters must not be tracked by autograd.
    with torch.no_grad():
        for token_str in tqdm(new_tokens, desc="Initializing Embeddings"):
            # a. Get the ID of the token in the RESIZED tokenizer
            new_token_id = tokenizer.convert_tokens_to_ids(token_str)

            # The lookup can yield None when the token string is not found
            # verbatim; there is no row to initialize in that case.
            if new_token_id is None:
                print(f"   ⚠️ Warning: Token '{repr(token_str)}' has no ID in the expanded tokenizer. Skipping.")
                continue

            # Skip tokens that were already in the original vocab
            # (index < original size): only newly added slots are initialized.
            if new_token_id < original_vocab_size:
                continue

            # b. Tokenize using the VANILLA tokenizer to get constituent sub-words
            old_ids = vanilla_tokenizer.encode(token_str, add_special_tokens=False)

            if not old_ids:
                print(f"   ⚠️ Warning: Token '{repr(token_str)}' encodes to empty. Skipping.")
                continue

            # c. New input row = mean of the constituent input rows
            input_weight[new_token_id] = _mean_of_rows(input_weight, old_ids)

            # d. New output row = mean of the constituent OUTPUT rows. Input
            #    and output matrices are different parameters when untied, so
            #    the input-side mean must not be copied into the output layer.
            if output_weight is not None:
                output_weight[new_token_id] = _mean_of_rows(output_weight, old_ids)

    print("✅ Embeddings initialized.")

    # 6. Save Result
    print(f"\n💾 Saving Step-0 Model to: {args.output_model_path}")
    os.makedirs(args.output_model_path, exist_ok=True)

    tokenizer.save_pretrained(args.output_model_path)
    model.save_pretrained(args.output_model_path)

    print("\n🎉 Model Surgery Complete!")
    print(f"   Ready for SFT: {args.output_model_path}")

# Run the vocabulary expansion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()