import json
import torch
import re
import os
import numpy as np
from tqdm import tqdm
import random
from transformers import AutoTokenizer, AutoModelForCausalLM
from rouge_score import rouge_scorer

def contains_exact_phrase(response: str, reference: str) -> bool:
    """Return True when ``reference`` occurs verbatim inside ``response``.

    The comparison is case-insensitive: both strings are lower-cased and a
    plain substring search is performed. No whitespace or punctuation
    normalisation is applied, and an empty ``reference`` always matches.
    """
    haystack = response.lower()
    needle = reference.lower()
    # str.find returns -1 only when the needle is absent.
    return haystack.find(needle) != -1

# Load the QA items to evaluate. Each item is later read through its
# "question" and "answer" keys. The encoding is pinned to UTF-8 (the standard
# JSON interchange encoding) so decoding does not depend on the platform's
# locale default.
with open("../../data/news/knowmem/forget_qa.json", "r", encoding="utf-8") as f:
    dataset = json.load(f)

# Checkpoint the tokenizer is loaded from (shared by every evaluated model).
llama_dir = "meta-llama/Llama-2-7b-hf"

# Display name -> Hugging Face model id, in evaluation order.
model_names = dict(
    [
        ("Original", "muse-bench/MUSE-News_target"),
        ("SimNPO", "OPTML-Group/SimNPO-MUSE-News-Llama-2-7b"),
        ("NPO-SAM", "OPTML-Group/NPO-SAM-MUSE-NEWS"),
        ("Retrain", "muse-bench/MUSE-News_retrain"),
    ]
)

# Full 3x3 grid of (temperature, top_p) sampling settings; temperature is the
# slower-varying component.
param_pairs = [
    (temperature, top_p)
    for temperature in (0.2, 0.8, 1.0)
    for top_p in (0.2, 0.8, 1.0)
]

# Root directory under which one sub-directory per run is created.
output_dir = "output"

# Best-of-n sampling evaluation.
#
# For every model in `model_names` and every (temperature, top_p) pair in
# `param_pairs`, each question is sampled n times for several values of n.
# Every generation is scored against the reference answer with ROUGE-L recall
# and a case-insensitive substring check; the best generation per question is
# kept. Two JSON files are written per (temperature, top_p, model) run:
#   gen_scores.json      - n -> mean of the per-question best ROUGE-L recall
#   gen_scores_resp.json - the same mean plus every generation and its scores

if not dataset:
    # Fail before any checkpoint is loaded; otherwise the per-n mean below
    # would divide by zero only after a full model load.
    raise ValueError("dataset is empty: no questions to evaluate")

# Number of samples drawn per question for each best-of-n measurement.
sample_counts = [1, 2, 4, 8, 16, 32, 64, 128, 200]

# The tokenizer and the scorer depend on neither the model nor the sampling
# parameters, so they are built once instead of once per run.
tokenizer = AutoTokenizer.from_pretrained(llama_dir)
# EOS doubles as the padding token so that `padding=True` below has a pad id.
tokenizer.pad_token = tokenizer.eos_token
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)

# Models form the outer loop so that each checkpoint is loaded exactly once
# and reused for every sampling configuration.
for method, model_name in model_names.items():
    gen_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype="auto"
    ).eval()

    for t, p in param_pairs:
        results_by_n = {}
        detailed_results_by_n = {}

        for n in sample_counts:
            max_rouge_scores = []
            detailed_results = []

            for item in tqdm(dataset, desc=f"[{method} | T={t}, P={p}] {n}x"):
                prompt = f"[INST] {item['question']} [/INST]"
                # n identical prompts in one batch -> n independent samples.
                prompts = [prompt] * n

                # The tokenizer returns CPU tensors; move them to the device
                # holding the model's first parameters so `generate` does not
                # receive inputs on a different device than the weights.
                inputs = tokenizer(
                    prompts,
                    return_tensors="pt",
                    padding=True,
                    add_special_tokens=True
                ).to(gen_model.device)

                with torch.no_grad():
                    output_ids = gen_model.generate(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_new_tokens=256,
                        do_sample=True,
                        temperature=t,
                        top_p=p
                    )

                decoded_outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

                best_response = ""
                best_rouge_score = -1
                best_exact_score = 0
                generation_scores = []

                for gen_text in decoded_outputs:
                    # Keep only the text after the (last occurrence of the)
                    # prompt; if the decoded text does not contain the prompt
                    # verbatim, the whole decoded text is scored.
                    response_clean = gen_text.split(prompt)[-1].strip() if prompt in gen_text else gen_text.strip()
                    rouge_recall = scorer.score(item["answer"], response_clean)['rougeL'].recall
                    exact_score = 1 if contains_exact_phrase(response_clean, item["answer"]) else 0

                    generation_scores.append({
                        "response": response_clean,
                        "rougeL_recall": rouge_recall,
                        "exact_match": exact_score
                    })

                    # A generation that contains the answer verbatim AND has
                    # recall 1 always becomes the best one (and is the only way
                    # the exact-match flag is set); otherwise a generation wins
                    # only with strictly higher recall than the current best.
                    if exact_score == 1 and rouge_recall == 1:
                        best_response = response_clean
                        best_rouge_score = rouge_recall
                        best_exact_score = 1
                    elif rouge_recall > best_rouge_score:
                        best_response = response_clean
                        best_rouge_score = rouge_recall

                max_rouge_scores.append(best_rouge_score)
                detailed_results.append({
                    "question": item["question"],
                    "answer": item["answer"],
                    "best_response": best_response,
                    "max_rougeL_recall": best_rouge_score,
                    "exact_match_found": bool(best_exact_score),
                    "all_generations": generation_scores
                })

            mean_rouge = sum(max_rouge_scores) / len(max_rouge_scores)
            results_by_n[n] = mean_rouge
            detailed_results_by_n[n] = {
                "mean_rougeL_recall": mean_rouge,
                "per_question": detailed_results
            }

        run_dir = os.path.join(output_dir, f"temperature={t}_top_p={p}_{method}")
        os.makedirs(run_dir, exist_ok=True)

        with open(os.path.join(run_dir, "gen_scores.json"), "w") as f:
            json.dump(results_by_n, f, indent=2)

        with open(os.path.join(run_dir, "gen_scores_resp.json"), "w") as f:
            json.dump(detailed_results_by_n, f, indent=2)

    # Drop the last reference to this model before the next checkpoint is
    # loaded, so two models are never resident at the same time.
    del gen_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
