import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
from peft import LoraConfig, get_peft_model
from vllm import LLM, SamplingParams
import gc

def accumulate_fisher_on_lora(problem, answer):
    """Add one sample's squared LoRA gradients to the global ``fisher_dict``.

    Formats ``(problem, answer)`` as a user/assistant chat exchange and runs a
    forward pass with the input ids as labels. It back-propagates the loss,
    then adds ``grad ** 2`` (in float32) of every trainable parameter whose
    name contains ``"lora_"`` to ``fisher_dict[name]``. An entry is created
    the first time a name is seen.

    Relies on module-level globals created in ``__main__``: ``tokenizer``,
    ``model``, ``args`` (reads ``args.max_len``) and ``fisher_dict``.

    Args:
        problem: Text for the user turn.
        answer: Text for the assistant turn (e.g. a generated reference answer).
    """
    # Start from clean gradients so this sample's contribution is isolated,
    # even if the caller did not zero them after the previous call.
    model.zero_grad()

    text = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": problem},
            {"role": "assistant", "content": answer},
        ],
        tokenize=False,
        add_generation_prompt=False,
    )

    # NOTE(review): the rendered template may already contain a BOS token.
    # Tokenizing it again with default special-token handling could prepend a
    # second one for some models. Confirm this for the target tokenizer.
    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=args.max_len,  # exchanges longer than max_len tokens are cut
    ).to(model.device)

    # Labels are the full input ids, so the loss (and therefore the gradients)
    # covers the prompt tokens as well as the answer tokens.
    outputs = model(**enc, labels=enc["input_ids"])
    outputs.loss.backward()

    for name, param in model.named_parameters():
        if "lora_" not in name:
            continue
        if not param.requires_grad or param.grad is None:
            continue
        # Square and accumulate in float32. In fp16, squares above ~6.5e4
        # (and their running sums) would overflow to inf. `** 2` already
        # allocates a fresh tensor, so no clone is needed.
        fisher_val = param.grad.detach().float() ** 2
        if name in fisher_dict:
            fisher_dict[name] += fisher_val
        else:
            fisher_dict[name] = fisher_val

def ref_dataset_generate(dataset, tokenizer):
    """Sample one reference answer per dataset problem using vLLM.

    Each example's ``"problem"`` field is wrapped as a single user turn with
    the assistant generation prompt appended. All prompts are generated in one
    batched ``generate`` call. Afterwards, the engine reference is dropped and
    the CUDA cache is emptied.

    Reads the module-level ``args`` (``args.model``, ``args.max_new_tokens``).

    Args:
        dataset: Iterable of examples, each indexable by ``"problem"``.
        tokenizer: Provides the chat template and the EOS stop-token id.

    Returns:
        list[str]: The text of the first completion of each request, in the
        order ``llm.generate`` returns them.
    """
    engine = LLM(
        model=args.model,
        trust_remote_code=True,
        dtype="float16",
        max_model_len=16384,
        # Shard the model across every visible GPU.
        tensor_parallel_size=torch.cuda.device_count(),
    )

    decoding = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=args.max_new_tokens,
        stop_token_ids=[tokenizer.eos_token_id],
    )

    def _as_chat_prompt(question):
        # A lone user turn. add_generation_prompt=True opens the assistant
        # turn so the model continues as the assistant.
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": question}],
            tokenize=False,
            add_generation_prompt=True,
        )

    prompt_batch = [_as_chat_prompt(row["problem"]) for row in dataset]

    results = engine.generate(prompt_batch, decoding)
    texts = [res.outputs[0].text for res in results]

    # NOTE(review): dropping the reference may not release every bit of vLLM
    # GPU memory in all versions. Confirm before loading another model.
    del engine
    gc.collect()
    torch.cuda.empty_cache()

    return texts

if __name__ == '__main__':
    # Script entry point. `args`, `tokenizer`, `model` and `fisher_dict` are
    # deliberately module-level names: the helper functions above read them
    # as globals.
    parser = argparse.ArgumentParser(description="Extract Fisher information on untrained LoRA parameters (preset fixed)")

    parser.add_argument("--model", type=str, default="model_name_placeholder", help="Base model ID to use")
    parser.add_argument("--dataset", type=str, default="HuggingFaceH4/MATH-500", help="Dataset for computation")
    parser.add_argument("--subset", type=str, default="test", help="Dataset subset (e.g., test, train)")
    parser.add_argument("--num_samples", type=int, default=100, help="Number of samples for Fisher information computation")
    parser.add_argument("--max_len", type=int, default=1024, help="Maximum length for sequence processing")
    parser.add_argument("--max_new_tokens", type=int, default=8192, help="Maximum number of new tokens when generating answers")

    parser.add_argument("--lora_r", type=int, default=4, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha parameter")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="Dropout probability for LoRA layers")

    parser.add_argument(
        "--target_preset",
        type=str,
        choices=["mlp", "attn", "both"],
        default="both",
        help="Fixed LoRA target preset: mlp / attn / both"
    )

    args = parser.parse_args()

    if args.num_samples < 1:
        parser.error("--num_samples must be >= 1")

    # Module names (comma-separated) passed to LoraConfig.target_modules per preset.
    PRESET_TO_MODULES = {
        "mlp": "gate_proj,down_proj,up_proj",
        "attn": "q_proj,k_proj,v_proj,o_proj",
        "both": "q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj"
    }
    # Tag embedded in the output file name ("both" is written as "mlp-attn").
    PRESET_TO_TAG = {
        "mlp": "mlp",
        "attn": "attn",
        "both": "mlp-attn",
    }

    target_str = PRESET_TO_MODULES[args.target_preset]
    target_modules = [m.strip() for m in target_str.split(",")]

    # Load dataset and generate reference answers
    print(f"Loading dataset: {args.dataset} (subset: {args.subset})")
    dataset = load_dataset(args.dataset, split=args.subset)
    if args.num_samples > len(dataset):
        # select(range(n)) would raise for n > len(dataset). Clamp instead,
        # and record the real count so the saved args and file name match the run.
        print(f"Requested {args.num_samples} samples but the split only has {len(dataset)}; using all of them.")
        args.num_samples = len(dataset)
    dataset = dataset.shuffle(seed=42).select(range(args.num_samples))
    print(f"Prepared {len(dataset)} samples.")
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False, trust_remote_code=True)
    answers = ref_dataset_generate(dataset, tokenizer)

    # Load base model
    print(f"Loading base model: {args.model}...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    print("Base model loaded.")

    # Apply LoRA adapter to model (without training)
    print("Applying LoRA adapter to model (without training)...")
    # NOTE(review): with PEFT's default initialisation (lora_B = 0), the
    # lora_A gradients are expected to be zero at this point. Confirm this is
    # intended for the downstream use of the Fisher values.
    config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=target_modules,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, config)
    print("LoRA adapter applied. Trainable (target) parameters:")
    model.print_trainable_parameters()
    # eval() switches off dropout layers (including LoRA dropout). Autograd
    # still runs, so backward() works.
    model.eval()

    # Calculate Fisher information on LoRA parameters
    print("\n🔬 Starting Fisher information calculation on LoRA parameters...")
    fisher_dict = {}

    num_processed = 0
    for example, answer in tqdm(zip(dataset, answers), total=len(dataset), desc="Computing Fisher information"):
        accumulate_fisher_on_lora(example["problem"], answer)
        # Clear gradients so they do not leak into the next sample's squares.
        model.zero_grad()
        num_processed += 1

    # Normalize Fisher information: average over the samples actually processed.
    print("Normalizing Fisher information...")
    if num_processed:
        for name in fisher_dict:
            fisher_dict[name] /= num_processed

    results_to_save = {
        'args': vars(args),
        'fisher_information': fisher_dict
    }

    safe_model_name = args.model.replace("/", "--")
    target_tag = PRESET_TO_TAG[args.target_preset]

    output_dir = "results"
    # torch.save does not create parent directories. Without this, a full
    # generation + Fisher run would fail at the very last step.
    os.makedirs(output_dir, exist_ok=True)
    output_filename = (
        f"{output_dir}/"
        f"fisher-results_{safe_model_name}"
        f"_lora-r{args.lora_r}"
        f"_samples{args.num_samples}"
        f"_{target_tag}.pt"
    )

    torch.save(results_to_save, output_filename)

    print("="*50)
    print("✅ Fisher information extraction completed!")
    print(f"   Target preset = {args.target_preset} ({target_str})")
    print("   All results (including run parameters) saved to:")
    print(f"   {output_filename}")
    print("="*50)