import os
import argparse
import json
import re
import numpy as np
from tqdm import tqdm
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge
from sentence_transformers import SentenceTransformer
import torch
from transformers import AutoModel, AutoTokenizer
from torch.nn.functional import cosine_similarity
import logging

# Only show messages at ERROR level or above from the "transformers" logger.
logging.getLogger("transformers").setLevel(logging.ERROR)

# Environment variables naming the local model store (presumably read by
# transformers / huggingface_hub as their cache location).
# NOTE(review): these are assigned after `transformers` and
# `sentence_transformers` were imported above -- confirm they still take
# effect, or move them ahead of those imports.
os.environ["TRANSFORMERS_CACHE"] = "/root/storage/models"
os.environ["HF_HOME"] = "/root/storage/models"
# Dataset file that process_file consults to fill in a missing "category".
source_file = "/data/locomo10.json"

def calculate_bert_f1(gold_answers, generated_answers, model, tokenizer, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Compute a BERTScore-style F1 for each (gold, generated) answer pair.

    Each text is encoded on its own with ``model``. Token embeddings are
    L2-normalised, so the matrix product of the two embedding sets is a
    cosine-similarity matrix. Recall is the mean, over reference tokens, of
    the best-matching hypothesis token; precision is the same with the roles
    swapped; F1 is their harmonic mean.

    Args:
        gold_answers: Reference answer strings.
        generated_answers: Generated answer strings, paired by position with
            ``gold_answers``.
        model: Encoder whose output exposes ``last_hidden_state``. It is not
            moved here, so it must already live on ``device``.
        tokenizer: Tokenizer matching ``model``; inputs are truncated to 512
            tokens.
        device: Device the tokenized inputs are moved to.

    Returns:
        A list of Python floats, one per pair. A pair scores 0.0 when either
        text is empty, when no tokens remain after trimming, or when
        precision + recall is not positive.
    """
    model.eval()
    all_f1 = []

    with torch.no_grad():
        for ref, hyp in tqdm(zip(gold_answers, generated_answers), total=len(gold_answers), desc="Calculating BERTScore"):
            if not ref or not hyp:
                all_f1.append(0.0)
                continue

            inputs_ref = tokenizer(ref, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
            inputs_hyp = tokenizer(hyp, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)

            outputs_ref = model(**inputs_ref)
            outputs_hyp = model(**inputs_hyp)

            emb_ref = outputs_ref.last_hidden_state.squeeze(0)  # [seq_len_ref, hidden_size]
            emb_hyp = outputs_hyp.last_hidden_state.squeeze(0)  # [seq_len_hyp, hidden_size]

            mask_ref = inputs_ref['attention_mask'].squeeze(0).bool()
            mask_hyp = inputs_hyp['attention_mask'].squeeze(0).bool()

            # Keep attended positions only, then drop the first and last of
            # them (presumably the tokenizer's [CLS]/[SEP] markers -- confirm
            # for non-BERT tokenizers).
            emb_ref = emb_ref[mask_ref][1:-1]
            emb_hyp = emb_hyp[mask_hyp][1:-1]

            if emb_ref.nelement() == 0 or emb_hyp.nelement() == 0:
                all_f1.append(0.0)
                continue

            # Clamp the norm so an all-zero embedding cannot turn the whole
            # score into NaN through a 0/0 division.
            emb_ref = emb_ref / emb_ref.norm(dim=1, keepdim=True).clamp_min(1e-12)
            emb_hyp = emb_hyp / emb_hyp.norm(dim=1, keepdim=True).clamp_min(1e-12)

            sim_matrix = torch.matmul(emb_ref, emb_hyp.T)

            # Rows are reference tokens, columns are hypothesis tokens.
            recall = sim_matrix.max(dim=1).values.mean()
            precision = sim_matrix.max(dim=0).values.mean()

            # Convert to float inside the positive branch only: the fallback
            # is already a plain float and has no .item().
            total = precision + recall
            if total > 0:
                all_f1.append((2 * precision * recall / total).item())
            else:
                all_f1.append(0.0)

    return all_f1


def calculate_f1_score(gold_answer, generated_answer):
    """Return the token-set F1 between the gold and the generated answer.

    Both strings have punctuation removed and are lower-cased, then split on
    whitespace into *sets* of tokens, so a repeated word counts once.
    Returns 0 when the two sets share no token (or either is empty).
    """
    def _token_set(text):
        # Strip everything that is neither a word character nor whitespace.
        return set(re.sub(r'[^\w\s]', '', text).lower().split())

    reference = _token_set(gold_answer)
    candidate = _token_set(generated_answer)
    overlap = len(reference & candidate)

    precision = overlap / len(candidate) if candidate else 0
    recall = overlap / len(reference) if reference else 0

    if precision + recall > 0:
        return 2 * precision * recall / (precision + recall)
    return 0

def calculate_bleu1_score(gold_answer, generated_answer):
    """Return the smoothed BLEU-1 of the generated answer against the gold one.

    Both strings have punctuation removed and are lower-cased before being
    split on whitespace. The gold tokens form the single reference; all the
    weight is placed on unigrams and NLTK's ``method1`` smoothing is applied.
    """
    reference = re.sub(r'[^\w\s]', '', gold_answer).lower().split()
    hypothesis = re.sub(r'[^\w\s]', '', generated_answer).lower().split()

    smoothing = SmoothingFunction().method1
    return sentence_bleu(
        [reference],  # sentence_bleu expects a list of reference token lists
        hypothesis,
        weights=(1, 0, 0, 0),
        smoothing_function=smoothing,
    )


def clean_answer(answer):
    """Strip trailing model chatter and punctuation from a generated answer.

    The text is cut at the first occurrence of each marker below, in order
    (an opening parenthesis, a newline, the word ``Timestamp`` and two
    end-of-sequence tokens); whatever precedes the marker is kept and
    stripped of surrounding whitespace. Finally every character that is
    neither a word character nor whitespace is removed.

    Args:
        answer: The raw answer. ``None`` is treated as an empty answer and
            any other non-string value is converted with ``str()``.

    Returns:
        The cleaned answer string.
    """
    if answer is None:
        return ""
    if not isinstance(answer, str):
        # Answers loaded from JSON may be numbers; treat them as text.
        answer = str(answer)

    for marker in ('(', '\n', 'Timestamp', '<|endoftext|>', '<|eot_id|>'):
        if marker in answer:
            answer = answer.split(marker, 1)[0].strip()

    return re.sub(r'[^\w\s]', '', answer)


def load_judged_file(judged_file_path):
    """Load a previously written judged file and report which metrics it holds.

    Args:
        judged_file_path: Path to the judged JSON file.

    Returns:
        A ``(calculated_metrics, data)`` tuple. ``data`` is the parsed JSON
        content. ``calculated_metrics`` maps a metric name (``"F1"``,
        ``"BLEU1"``, ``"BERTScore"``, ``"CosineSimilarity"``) to ``True``
        when *every* item carries that metric's score key. Returns
        ``({}, [])`` when the file does not exist.
    """
    if not os.path.exists(judged_file_path):
        return {}, []

    with open(judged_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    metric_keys = {
        "F1": "f1_score",
        "BLEU1": "bleu1_score",
        "BERTScore": "bert_score_f1",
        "CosineSimilarity": "cosine_similarity",
    }
    # A metric only counts as calculated when no item is missing it; a
    # partially judged file would otherwise make callers index a missing key.
    calculated_metrics = {
        metric: True
        for metric, key in metric_keys.items()
        if data and all(key in item for item in data)
    }
    return calculated_metrics, data

# Parsed source files keyed by path, so that looking up many questions does
# not re-read and re-parse the same JSON file once per question.
_SOURCE_DATA_CACHE = {}


def _load_source_data(source_file):
    """Return the parsed JSON content of ``source_file``, reading it at most once."""
    if source_file not in _SOURCE_DATA_CACHE:
        with open(source_file, "r", encoding="utf-8") as f:
            _SOURCE_DATA_CACHE[source_file] = json.load(f)
    return _SOURCE_DATA_CACHE[source_file]


def get_category_from_source(source_file, question):
    """Look up the category of ``question`` in the source dataset.

    The source file is expected to hold a list of blocks, each with a
    ``"qa"`` list of entries. The first entry whose ``"question"`` equals
    ``question`` decides the result.

    Args:
        source_file: Path to the source JSON file. Its parsed content is
            cached for the lifetime of the process.
        question: The question text to search for.

    Returns:
        The matching entry's ``"category"`` value, or ``None`` when that
        entry has no category or no entry matches.
    """
    for qa_block in _load_source_data(source_file):
        for qa_entry in qa_block["qa"]:
            if qa_entry.get("question") == question:
                return qa_entry.get("category", None)
    return None

def _cached_scores(calculated_metrics, metric, judged_data, key, expected_len):
    """Return the saved scores for ``metric``, or None when they cannot be reused."""
    if metric not in calculated_metrics or len(judged_data) != expected_len:
        return None
    if not all(key in item for item in judged_data):
        return None
    return [item[key] for item in judged_data]


def process_file(file_path, embedding_model=None, bert_model_path='bert-base-uncased'):
    """Score every answer in a results file and write a ``Judged_*`` copy.

    For each item the cleaned answer is compared with the ``"standard
    answer"`` using token F1, BLEU-1, BERTScore F1 and (when an embedding
    model is given) cosine similarity. Scores already present in an existing
    judged file next to the input are reused when they cover every item;
    otherwise the metric is recomputed. Items lacking a ``"category"`` get
    one looked up in the module-level ``source_file``.

    Args:
        file_path: Path to the JSON results file (a list of items with
            ``"question"``, ``"answer"`` and ``"standard answer"``).
        embedding_model: Object with an ``encode`` method used for cosine
            similarity, or ``None`` to skip that metric.
        bert_model_path: Model name or path passed to ``from_pretrained``
            when BERTScore has to be computed.

    Returns:
        ``(avg_f1, avg_bleu1, avg_sim, avg_bert, judged_file_path)``. When
        the input is empty or the BERT model cannot be loaded, the averages
        are 0.0 and the path is ``None``. ``avg_sim`` is 0.0 when cosine
        similarity is unavailable.
    """
    # Same arity as the success return so callers can unpack either one.
    failure = (0.0,) * 4 + (None,)

    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not data:
        return failure

    judged_file_name = f"Judged_{os.path.basename(file_path)}"
    judged_file_path = os.path.join(os.path.dirname(file_path), judged_file_name)

    calculated_metrics, judged_data = load_judged_file(judged_file_path)

    for item in data:
        if "category" not in item:
            category = get_category_from_source(source_file, item["question"])
            if category is not None:
                item["category"] = category

    golds = [str(item["standard answer"]) for item in data]
    gens = [clean_answer(item["answer"]) for item in data]
    num_questions = len(data)

    bert_scores = _cached_scores(calculated_metrics, "BERTScore", judged_data, "bert_score_f1", num_questions)
    if bert_scores is None:
        # Load the BERT model only when its scores actually need computing.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            tokenizer = AutoTokenizer.from_pretrained(bert_model_path)
            model = AutoModel.from_pretrained(bert_model_path).to(device)
            print("✓ BERT model loaded successfully")
        except Exception as e:
            print(f"× Failed to load BERT model: {e}")
            return failure
        bert_scores = calculate_bert_f1(golds, gens, model, tokenizer, device)

    f1_scores = _cached_scores(calculated_metrics, "F1", judged_data, "f1_score", num_questions)
    if f1_scores is None:
        f1_scores = [calculate_f1_score(g, gen) for g, gen in zip(golds, gens)]

    bleu1_scores = _cached_scores(calculated_metrics, "BLEU1", judged_data, "bleu1_score", num_questions)
    if bleu1_scores is None:
        bleu1_scores = [calculate_bleu1_score(g, gen) for g, gen in zip(golds, gens)]

    sim_scores = _cached_scores(calculated_metrics, "CosineSimilarity", judged_data, "cosine_similarity", num_questions)
    if sim_scores is None and embedding_model is not None:
        gold_embeddings = embedding_model.encode(golds, convert_to_tensor=True, show_progress_bar=False)
        gen_embeddings = embedding_model.encode(gens, convert_to_tensor=True, show_progress_bar=False)
        sim_scores = cosine_similarity(gen_embeddings, gold_embeddings).tolist()

    for i, item in enumerate(data):
        item["answer"] = gens[i]
        item["f1_score"] = f1_scores[i]
        item["bleu1_score"] = bleu1_scores[i]
        item["bert_score_f1"] = bert_scores[i]
        # Leave the key out when similarity was not computed, so a later run
        # with an embedding model does not mistake a placeholder for a score.
        if sim_scores is not None:
            item["cosine_similarity"] = sim_scores[i]

    avg_f1 = sum(f1_scores) / num_questions
    avg_bleu1 = sum(bleu1_scores) / num_questions
    avg_bert = sum(bert_scores) / num_questions
    avg_sim = sum(sim_scores) / num_questions if sim_scores is not None else 0.0

    with open(judged_file_path, "w", encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)

    return avg_f1, avg_bleu1, avg_sim, avg_bert, judged_file_path

def main():
    """Evaluate one results file with several metrics and print the averages.

    Parses ``--input`` and ``--bert_model_path``, loads the sentence
    embedding model if it exists at its fixed local path (cosine similarity
    is skipped otherwise), runs ``process_file`` and prints the per-metric
    averages together with their unweighted mean.
    """
    parser = argparse.ArgumentParser(description="Evaluate results using multiple metrics.")
    parser.add_argument(
        "--input", type=str, required=True,
        help="Path to the input file or folder containing result files.",
    )
    parser.add_argument(
        "--bert_model_path", type=str, default='/root/storage/models/bert-base-uncased',
        help="Path to the local BERT model or huggingface model name for scoring."
    )
    args = parser.parse_args()

    embedding_model_path = "/root/storage/models/bge-m3"
    embedding_model = None
    if os.path.exists(embedding_model_path):
        print(f"Loading embedding model from {embedding_model_path}...")
        embedding_model = SentenceTransformer(embedding_model_path)
        print("Embedding model loaded.")
    else:
        print(f"Warning: Embedding model not found at {embedding_model_path}. Cosine similarity (Sim) will not be calculated.")

    scores = process_file(args.input, embedding_model, args.bert_model_path)
    # Take the four averages from the front and the path from the end, so the
    # exact tuple length is not assumed.
    avg_f1, avg_b1, avg_sim, avg_bert = scores[:4]
    judged_file_path = scores[-1]

    if judged_file_path is None:
        # process_file signals "nothing was scored" with a None path.
        print(f"No scores were produced for {os.path.basename(args.input)}.")
        return

    # Parenthesised so the mean covers all four metrics, not just the last.
    total_aver_score = (avg_f1 + avg_b1 + avg_sim + avg_bert) / 4
    print(f"\n--- Average Scores for {os.path.basename(args.input)} ---")
    print(f"F1 Score: {avg_f1:.4f}")
    print(f"BLEU-1 (B1): {avg_b1:.4f}")
    print(f"BERTScore F1: {avg_bert:.4f}")
    print(f"Cosine Similarity (Sim): {avg_sim:.4f}")
    print(f"Average Score: {total_aver_score:.4f}")
    print(f"\nJudged results saved to {judged_file_path}")

if __name__ == "__main__":
    main()
