import json
import torch
import os
from typing import Dict, Any, List
from transformers import AutoTokenizer, AutoModelForCausalLM


def evaluate_model(test_samples, model, tokenizer, max_new_tokens=None, device=None):
    """Score a causal LM on memorization / generalization / hard-generalization probes.

    Every prompt is answered with greedy decoding. A probe counts as correct
    when its target string occurs as a substring of the *newly generated*
    text; the echoed prompt is excluded from matching.

    Args:
        test_samples: Iterable of dicts. Each may hold the parallel lists
            ``"mem_input"``/``"mem_target"``, ``"gen_input"``/``"gen_target"``
            and ``"hard_gen_input"``/``"hard_gen_target"`` (list[str] each).
            Missing keys are treated as empty lists. If an input list and its
            target list differ in length, the extra items are ignored and a
            warning is printed.
        model: Causal LM exposing ``generate(...)`` and ``device``.
        tokenizer: Matching tokenizer (callable, with ``decode`` and
            ``eos_token_id``).
        max_new_tokens: Generation budget per prompt. Defaults to the
            module-level ``MAX_NEW_TOKENS`` when the script defines it,
            otherwise 10.
        device: Device the tokenized prompt is moved to. Defaults to
            ``model.device``.

    Returns:
        Dict mapping ``"mem"``, ``"gen"`` and ``"hard_gen"`` to
        ``{"correct": int, "total": int, "accuracy": float}``. The same
        numbers are printed as well.
    """
    if max_new_tokens is None:
        # The script entry point defines MAX_NEW_TOKENS as a module global; use
        # it when present so running the file keeps its behavior, but don't
        # require it (otherwise importing and calling this raises NameError).
        max_new_tokens = globals().get("MAX_NEW_TOKENS", 10)
    if device is None:
        # Follow wherever from_pretrained(device_map=...) actually put the model.
        device = model.device

    def generate_response(prompt: str) -> str:
        """Greedily continue ``prompt`` and return only the generated text."""
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decoder-only generate() returns prompt + continuation. Slice the
        # prompt off so a target that merely appears in the prompt is not
        # scored as a correct answer.
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # (split key prefix, label used in the printed report)
    splits = (
        ("mem", "[MEM]   "),
        ("gen", "[GEN]   "),
        ("hard_gen", "[HARD]  "),
    )
    counts = {name: {"correct": 0, "total": 0} for name, _ in splits}

    for sample in test_samples:
        for name, _ in splits:
            prompts = sample.get(f"{name}_input", [])
            targets = sample.get(f"{name}_target", [])
            if len(prompts) != len(targets):
                print(
                    f"[WARN] {name}: {len(prompts)} inputs vs {len(targets)} "
                    "targets; unmatched items are ignored"
                )
            for prompt, target in zip(prompts, targets):
                counts[name]["total"] += 1
                if target in generate_response(prompt):
                    counts[name]["correct"] += 1

    results: Dict[str, Dict[str, Any]] = {}
    for name, label in splits:
        correct = counts[name]["correct"]
        total = counts[name]["total"]
        accuracy = correct / total if total > 0 else 0.0
        results[name] = {"correct": correct, "total": total, "accuracy": accuracy}
        print(f"{label}Accuracy: {accuracy:.3f}  ({correct}/{total})")

    return results



def get_args():
    """Parse the evaluation script's command-line options.

    Returns:
        argparse.Namespace with string attributes ``model_name_or_path``
        (model directory to load) and ``test_data_path`` (JSON test file).
    """
    from argparse import ArgumentParser

    cli = ArgumentParser()
    # (flag, default path) for every option; all are plain strings.
    option_table = (
        ("--model_name_or_path", "/llm_unlearning/wmdp/models/unlearned_model"),
        ("--test_data_path", "/llm_unlearning/wmdp/data/fictional_knowledge.json"),
    )
    for flag, fallback in option_table:
        cli.add_argument(flag, type=str, default=fallback)

    return cli.parse_args()

if __name__ == "__main__":
    args = get_args()

    MODEL_PATH = args.model_name_or_path
    TEST_DATA_PATH = args.test_data_path
    # Module-level settings; evaluate_model may look these up as globals.
    MAX_NEW_TOKENS = 10
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Validate and parse the test data BEFORE the slow, memory-hungry model
    # load, so a bad path or malformed JSON fails immediately instead of
    # after the weights have been loaded.
    if not os.path.isfile(TEST_DATA_PATH):
        raise FileNotFoundError(f"Test file not found: {TEST_DATA_PATH}")

    print("Loading test data...")
    with open(TEST_DATA_PATH, "r", encoding="utf-8") as f:
        test_data = json.load(f)

    # A single top-level JSON object is treated as one sample.
    if not isinstance(test_data, list):
        test_data = [test_data]

    print("Loading model & tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto")
    model.eval()

    evaluate_model(test_data, model, tokenizer)
