import json
import torch
import re
import os
import numpy as np
from tqdm import tqdm
import random
from transformers import AutoTokenizer, AutoModelForCausalLM
from rouge_score import rouge_scorer

# Load the QA items to evaluate. Later code reads item["question"] and
# item["answer"] from each entry. JSON text is decoded as UTF-8 explicitly so
# the result does not depend on the platform's default locale encoding.
with open("../../data/news/knowmem/retain_qa.json", "r", encoding="utf-8") as f:
    dataset = json.load(f)

# Root directory under which one sub-directory per run is created.
output_dir = "output"

# Identifier passed to AutoTokenizer.from_pretrained for every checkpoint.
llama_dir = "meta-llama/Llama-2-7b-hf"

# Display name -> checkpoint identifier of each model to evaluate.
# Insertion order is the order in which the models are visited.
model_names = {
    "Original": "muse-bench/MUSE-News_target",
    "SimNPO": "OPTML-Group/SimNPO-MUSE-News-Llama-2-7b",
    "NPO-SAM": "OPTML-Group/NPO-SAM-MUSE-NEWS",
    "Retrain": "muse-bench/MUSE-News_retrain",
}

# Full 3x3 grid of (temperature, top_p) settings; the first component varies
# slowest, so the order is (0.2, 0.2), (0.2, 0.8), ..., (1.0, 1.0).
param_pairs = [
    (temperature, top_p)
    for temperature in (0.2, 0.8, 1.0)
    for top_p in (0.2, 0.8, 1.0)
]
# Evaluation sweep.
#
# For every checkpoint in `model_names` and every (temperature, top_p) pair in
# `param_pairs`, sample n answers per question and score each one with
# ROUGE-L recall against the reference answer. Per question the minimum
# ("worst") and mean recall are kept; per run their averages over the dataset
# are written to
#   <output_dir>/temperature=<t>_top_p=<p>_<method>/gen_scores.json
# and the same summary plus every individual generation to
#   <output_dir>/temperature=<t>_top_p=<p>_<method>/gen_scores_resp.json

if not dataset:
    # Fail before any checkpoint is loaded; the averages below would
    # otherwise divide by zero only after a full model load.
    raise ValueError("dataset is empty; nothing to evaluate")

# The same tokenizer is used with every checkpoint, so build it once.
tokenizer = AutoTokenizer.from_pretrained(llama_dir)
tokenizer.pad_token = tokenizer.eos_token
# Decoder-only generation continues from the last position, so any padding
# has to sit on the left; this also keeps the prompt-length slice below valid.
tokenizer.padding_side = "left"

scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)

# Checkpoints form the outer loop so each one is loaded once and reused for
# all sampling settings, instead of being reloaded for every (t, p) pair.
for method, model_name in model_names.items():
    gen_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype="auto",
    ).eval()

    for t, p in param_pairs:
        results_by_n = {}
        detailed_results_by_n = {}

        for n in [200]:
            worst_rouge_scores = []
            mean_rouge_scores = []
            detailed_results = []

            for item in tqdm(dataset, desc=f"[{method} | T={t}, P={p}] {n}x"):
                prompt = f"[INST] {item['question']} [/INST]"

                # n copies of the prompt -> n independent samples in one batch.
                inputs = tokenizer(
                    [prompt] * n,
                    return_tensors="pt",
                    padding=True,
                    add_special_tokens=True,
                ).to(gen_model.device)
                prompt_len = inputs["input_ids"].shape[1]

                with torch.no_grad():
                    output_ids = gen_model.generate(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_new_tokens=256,
                        # Sampling must be enabled and parameterised by the
                        # swept values; greedy decoding would return n
                        # identical outputs regardless of (t, p).
                        do_sample=True,
                        temperature=t,
                        top_p=p,
                        pad_token_id=tokenizer.pad_token_id,
                    )

                # Keep only the newly generated tokens. Slicing by prompt
                # length is exact, whereas searching the decoded text for the
                # prompt string can fail and leave the question in the
                # scored response.
                decoded_outputs = tokenizer.batch_decode(
                    output_ids[:, prompt_len:],
                    skip_special_tokens=True,
                )

                generation_scores = []
                for gen_text in decoded_outputs:
                    response_clean = gen_text.strip()
                    rouge_recall = scorer.score(item["answer"], response_clean)['rougeL'].recall
                    generation_scores.append({
                        "response": response_clean,
                        "rougeL_recall": rouge_recall
                    })

                # min() returns the first generation with the lowest recall.
                worst_generation = min(generation_scores, key=lambda g: g["rougeL_recall"])
                worst_response = worst_generation["response"]
                worst_rouge_score = worst_generation["rougeL_recall"]

                mean_rouge_score = (
                    sum(g["rougeL_recall"] for g in generation_scores) / len(generation_scores)
                )
                mean_rouge_scores.append(mean_rouge_score)
                worst_rouge_scores.append(worst_rouge_score)

                detailed_results.append({
                    "question": item["question"],
                    "answer": item["answer"],
                    "worst_response": worst_response,
                    "min_rougeL_recall": worst_rouge_score,
                    "mean_rougeL_recall": mean_rouge_score,
                    "all_generations": generation_scores
                })

            # Dataset-level averages of the per-question worst and mean recall.
            mean_worst_rouge = sum(worst_rouge_scores) / len(worst_rouge_scores)
            mean_mean_rouge_score = sum(mean_rouge_scores) / len(mean_rouge_scores)
            results_by_n[n] = {
                "mean_min_rougeL_recall": mean_worst_rouge,
                "mean_mean_rougeL_recall": mean_mean_rouge_score
            }

            detailed_results_by_n[n] = {
                "mean_min_rougeL_recall": mean_worst_rouge,
                "mean_mean_rougeL_recall": mean_mean_rouge_score,
                "per_question": detailed_results
            }

        run_dir = os.path.join(output_dir, f"temperature={t}_top_p={p}_{method}")
        os.makedirs(run_dir, exist_ok=True)

        with open(os.path.join(run_dir, "gen_scores.json"), "w", encoding="utf-8") as f:
            json.dump(results_by_n, f, indent=2)

        with open(os.path.join(run_dir, "gen_scores_resp.json"), "w", encoding="utf-8") as f:
            json.dump(detailed_results_by_n, f, indent=2)

    # Drop the finished checkpoint before the next one is loaded so two
    # models are never resident at the same time.
    del gen_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
