import os
import torch
import numpy as np
from transformers import AutoModelForCausalLM
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.save_mapping import load_temperature_dict_npz, save_temperature_dict_npz
import argparse
from tqdm import tqdm

# Usage examples (when --output_path is omitted, bias.npz is written next to the delta file):
# python tools/delta_to_bias.py --model_path /path/to/model --delta_path /path/to/delta_map.npz --output_path /path/to/bias_map.npz
# python tools/delta_to_bias.py --model_path /path/to/models/llama-3.2-1b-instruct --delta_path /path/to/calibration_output/llama-3.2-1b-instruct/delta_training_with_completions/n1_32_k_8/delta.npz

def _parse_args():
    """Build the command-line parser and return the parsed arguments.

    ``--device`` defaults to ``None`` so that ``main`` can pick CUDA when it is
    available and fall back to CPU otherwise; an explicit value is used as-is.
    """
    parser = argparse.ArgumentParser(description="Convert delta_map (hidden) to bias_map (vocab) using model's lm_head.")
    parser.add_argument('--model_path', type=str, required=True, help='Path to HuggingFace model directory')
    parser.add_argument('--delta_path', type=str, required=True, help='Path to delta_map npz file')
    parser.add_argument('--output_path', type=str, required=False, help='Output path for bias_map npz file (default: same dir as delta_path, name bias.npz)')
    parser.add_argument('--device', type=str, default=None, help='Device for model (cuda or cpu; default: cuda if available, otherwise cpu)')
    return parser.parse_args()


def _delta_to_bias(lm_head, delta, device, unique_id):
    """Project one hidden-space delta through ``lm_head``.

    Args:
        lm_head: The model's output projection module (already on ``device``).
        delta: Array-like delta; per the original note its shape is expected to
            be ``[1, 1, hidden_size]`` or ``[hidden_size]``.
        device: Device on which the projection is computed.
        unique_id: Key of the entry, used only in the debug output.

    Returns:
        The result of ``lm_head(delta)`` as a NumPy array on the CPU.
    """
    # np.asarray leaves ndarrays untouched and also accepts lists/tuples.
    delta = np.asarray(delta)
    # tqdm.write keeps the active progress bar intact while still printing to stdout.
    tqdm.write(f"[DEBUG] unique_id={unique_id} raw delta shape: {delta.shape}")
    delta_tensor = torch.from_numpy(delta).float().to(device)
    tqdm.write(f"[DEBUG] unique_id={unique_id} torch delta_tensor shape before squeeze: {delta_tensor.shape}")
    # delta shape: [1, 1, hidden_size] or [hidden_size]
    if delta_tensor.dim() == 3:
        delta_tensor = delta_tensor.squeeze(0).squeeze(0)
    tqdm.write(f"[DEBUG] unique_id={unique_id} torch delta_tensor shape after squeeze: {delta_tensor.shape}")
    # NOTE(review): calling the module applies lm_head's own bias term too, if
    # it has one -- confirm that is intended for heads created with bias=True.
    return lm_head(delta_tensor).cpu().numpy()  # [vocab_size]


def main():
    """Convert a hidden-space delta map into a vocab-space bias map.

    Reads the npz file given by ``--delta_path`` (a mapping whose entries carry
    a ``"delta"`` value), pushes every delta through the ``lm_head`` of the
    model at ``--model_path`` and saves ``{unique_id: {"bias": array}}`` to
    ``--output_path`` (default: ``bias.npz`` next to the delta file).

    Raises:
        RuntimeError: If the loaded model has no ``lm_head`` attribute.
    """
    args = _parse_args()

    device = args.device
    if device is None:
        # No explicit choice: prefer the GPU but keep CPU-only machines working.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"No --device given, using {device}")

    # Resolve and prepare the destination first so an unusable output location
    # is reported before the model is loaded and the conversion is run.
    output_path = args.output_path
    if not output_path:
        output_path = os.path.join(os.path.dirname(args.delta_path), "bias.npz")
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Loading model from {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(args.model_path)
    model.eval()
    lm_head = getattr(model, 'lm_head', None)
    if lm_head is None:
        raise RuntimeError('Model does not have lm_head attribute!')
    # Only the output projection is used, so only it is moved to the device.
    # It is cast to float32 to match the float32 deltas built in _delta_to_bias,
    # whatever dtype the checkpoint was loaded in.
    lm_head = lm_head.to(device=device, dtype=torch.float32)

    print(f"Loading delta map from {args.delta_path}")
    delta_map = load_temperature_dict_npz(args.delta_path)
    bias_map = {}

    with torch.no_grad():
        for unique_id, entry in tqdm(delta_map.items(), desc="Converting delta to bias"):
            bias = _delta_to_bias(lm_head, entry["delta"], device, unique_id)
            bias_map[unique_id] = {"bias": bias}

    print(f"Saving bias map to {output_path}")
    save_temperature_dict_npz(bias_map, output_path)
    print("Done!")

# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
