import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
from vllm import LLM, SamplingParams
import gc

def accumulate_fisher_full(problem, answer):
    """Accumulate one sample's squared gradients into ``fisher_dict`` (full parameter version).

    The problem/answer pair is rendered with the tokenizer's chat template as a
    user turn followed by an assistant turn, tokenized (truncated to
    ``args.max_len`` tokens), and passed through the model with the input ids
    as labels. The resulting loss is backpropagated and, for every trainable
    parameter that received a gradient, the element-wise squared gradient is
    added to ``fisher_dict[name]``.

    Relies on the module-level ``tokenizer``, ``model``, ``args`` and
    ``fisher_dict``. This function does not reset gradients; since
    ``backward()`` adds into existing ``.grad`` tensors, the caller should
    clear them (e.g. ``model.zero_grad()``) between calls.

    Squared gradients are computed and stored in float32 regardless of the
    parameter dtype: squaring half-precision gradients directly flushes small
    values to zero and loses precision in the running sum.

    Args:
        problem: Text used as the user message.
        answer: Text used as the assistant message.
    """
    text = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": problem},
            {"role": "assistant", "content": answer}
        ],
        tokenize=False,
        add_generation_prompt=False
    )

    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=args.max_len
    ).to(model.device)

    # Labels are the inputs themselves, so the loss covers every token of the
    # rendered conversation (prompt and answer).
    outputs = model(**enc, labels=enc["input_ids"])
    loss = outputs.loss
    loss.backward()

    # Iterate through all parameters and accumulate gradient squared
    for name, param in model.named_parameters():
        if not param.requires_grad or param.grad is None:
            continue
        # Upcast before squaring to avoid fp16 underflow. `** 2` always
        # returns a new tensor, so no extra clone is needed and param.grad is
        # never aliased by the accumulator.
        grad_sq = param.grad.detach().float() ** 2
        if name in fisher_dict:
            fisher_dict[name] += grad_sq
        else:
            fisher_dict[name] = grad_sq

def ref_dataset_generate(dataset, tokenizer):
    """Generate reference answers using vLLM.

    Each example's ``"problem"`` field is wrapped as a single user message with
    the tokenizer's chat template (generation prompt appended), and all prompts
    are sent to a vLLM engine built from ``args.model`` in one ``generate``
    call. Generation length is capped by ``args.max_new_tokens``.

    The engine is released before returning (reference dropped, garbage
    collection forced, CUDA cache emptied) so that the GPU memory it held can
    be reused by whatever the caller loads next. This is best-effort: how much
    memory vLLM actually gives back may depend on its version and on the
    tensor-parallel setup.

    Args:
        dataset: Iterable of examples, each exposing a ``"problem"`` entry.
        tokenizer: Tokenizer providing ``apply_chat_template``.

    Returns:
        A list with the text of the first completion for each generated
        output, in the order returned by ``llm.generate``.
    """
    # Build the prompts first so a malformed example fails before the
    # (expensive) engine is created.
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": example["problem"]}],
            tokenize=False,
            add_generation_prompt=True
        )
        for example in dataset
    ]
    sampling_params = SamplingParams(max_tokens=args.max_new_tokens)

    llm = LLM(
        model=args.model,
        trust_remote_code=True,
        dtype="float16",
        max_model_len=16384,
        tensor_parallel_size=torch.cuda.device_count()
    )
    try:
        outputs = llm.generate(prompts, sampling_params)
        answers = [output.outputs[0].text for output in outputs]
    finally:
        # Drop the engine explicitly; it may sit in reference cycles, so force
        # a collection and then hand cached blocks back to the driver.
        del llm
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return answers

if __name__ == '__main__':
    # Script entry point: sample problems, generate reference answers with
    # vLLM, accumulate squared gradients over all parameters, average them and
    # save the result under results/.
    import os  # only needed by the script entry point

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Extract Fisher information on full model parameters.")
    parser.add_argument("--model", type=str, default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
    parser.add_argument("--dataset", type=str, default="HuggingFaceH4/MATH-500")
    parser.add_argument("--subset", type=str, default="test")
    parser.add_argument("--num_samples", type=int, default=100)
    parser.add_argument("--max_len", type=int, default=1024)
    parser.add_argument("--max_new_tokens", type=int, default=8192)
    args = parser.parse_args()

    # Load dataset
    print(f"Loading dataset: {args.dataset} (subset: {args.subset})")
    dataset = load_dataset(args.dataset, split=args.subset)
    # Selecting more rows than the split contains raises, so clamp the request.
    # args is updated too, keeping the saved args and the output filename
    # consistent with what was actually processed.
    if args.num_samples > len(dataset):
        print(
            f"Requested {args.num_samples} samples but the split only has "
            f"{len(dataset)}; using all of them."
        )
        args.num_samples = len(dataset)
    dataset = dataset.shuffle(seed=42).select(range(args.num_samples))
    print(f"Prepared {len(dataset)} samples.")

    # Resolve the output path and create its directory before the expensive
    # work, so a missing or unwritable results/ directory fails early rather
    # than after all gradients have been computed.
    safe_model_name = args.model.replace("/", "--")
    output_filename = (
        f"results/"
        f"fisher-full_{safe_model_name}"
        f"_samples{args.num_samples}.pt"
    )
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False, trust_remote_code=True)
    answers = ref_dataset_generate(dataset, tokenizer)

    # Load base model
    print(f"Loading base model: {args.model}...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    print("Model loaded successfully.")
    model.eval()

    # Calculate Fisher information on full model parameters
    print("\n🔬 Starting Fisher information calculation on **full model parameters**...")
    fisher_dict = {}

    for example, answer in tqdm(zip(dataset, answers), total=len(dataset), desc="Computing Fisher information"):
        problem = example["problem"]
        # Calculate Fisher information for this sample
        accumulate_fisher_full(problem, answer)
        # Clear gradients so the next sample's backward pass starts from zero
        model.zero_grad()

    # Normalize Fisher information by the number of samples actually processed
    print("Normalizing Fisher information...")
    num_processed = len(dataset)
    if num_processed > 0:
        for name in fisher_dict:
            fisher_dict[name] /= num_processed

    results_to_save = {
        'args': vars(args),
        'fisher_information': fisher_dict
    }

    torch.save(results_to_save, output_filename)

    print("="*50)
    print("✅ Full parameter Fisher information extraction completed!")
    print(f"   Results saved to: {output_filename}")
    print("="*50)