import os
from transformers import AutoModel
import pandas as pd
import matplotlib.pyplot as plt
from . import util
import evaluate
import functools
import numpy as np
import functools
from .import llm_report_aggr

# NaN-skipping mean/std, used as groupby aggregation functions in main().
nanmean = functools.partial(pd.Series.mean, skipna = True)
nanstd = functools.partial(pd.Series.std, skipna = True)

# NOTE(review): these run at import time. The embedding model is loaded onto
# "cuda", so merely importing this module requires a GPU and may fetch the model
# and metric scripts from the Hugging Face Hub — consider loading lazily.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True, device_map="cuda")
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu", module_type="metric")
# use_aggregator=False: ROUGE returns one score per (prediction, reference) pair
# rather than a single aggregate, so callers can average or assign per row.
rouge_compute = functools.partial(rouge.compute, use_aggregator=False)
bleu_compute = bleu.compute

def calculate_embeddings(simulation_df: pd.DataFrame, model):
    """
    Encode each distinct non-null human (``text``) and LLM (``llm_text``) string once.

    All texts are sent to ``model.encode`` in a single call with
    ``task="text-matching"``, in first-seen order (human column first, then the
    LLM column).

    Returns a dict mapping each text to its corresponding embedding.
    """
    all_texts = pd.concat([simulation_df["text"], simulation_df["llm_text"]])
    distinct_texts = all_texts.dropna().unique().tolist()
    vectors = model.encode(distinct_texts, task="text-matching")
    return dict(zip(distinct_texts, vectors))


def score_message_cosine(row: pd.Series, embeddings):
    """
    Return the dot-product similarity between a row's human and LLM message.

    ``embeddings`` maps text -> embedding vector (as built by
    ``calculate_embeddings``). The dot product equals cosine similarity only for
    L2-normalised vectors — NOTE(review): assumed for the Jina encoder, confirm.

    Returns ``pd.NA`` when either ``text`` or ``llm_text`` is missing, or when
    either text has no cached embedding. Previously a missing human text crashed
    on ``None @ array``.
    """
    human_text = row["text"]
    llm_text = row["llm_text"]
    if pd.isna(human_text) or pd.isna(llm_text):
        return pd.NA

    human_emb = embeddings.get(human_text)
    llm_emb = embeddings.get(llm_text)
    if human_emb is None or llm_emb is None:
        return pd.NA
    return (human_emb @ llm_emb.T).item()

def score_message_bleu(rows):
    """
    Per-row sentence BLEU of ``llm_text`` (prediction) against ``text`` (reference).

    Returns a Series aligned with ``rows.index``. A row gets ``pd.NA`` when:
    - either text is missing, or
    - BLEU is undefined for it. The ``evaluate`` BLEU implementation divides by
      the prediction/reference length and can raise ``ZeroDivisionError`` for
      empty or whitespace-only strings.
    """
    def _row_bleu(row):
        prediction = row["llm_text"]
        reference = row["text"]
        if pd.isna(prediction) or pd.isna(reference):
            return pd.NA
        try:
            return bleu_compute(predictions=[prediction], references=[[reference]])["bleu"]
        except ZeroDivisionError:
            return pd.NA

    # DataFrame.apply(axis=1) on an empty frame returns a DataFrame, not a Series,
    # which cannot be assigned to a single column by the caller.
    if rows.empty:
        return pd.Series(pd.NA, index=rows.index, dtype=object)
    return rows.apply(_row_bleu, axis=1)


def cosine_plot(simulation_df: pd.DataFrame, round_separators, n_rounds: int, player_name_col: str, output_path: str, version: str):
    """
    Save two human-vs-LLM cosine similarity plots into ``output_path``.

    1. ``human_llm_sim_{version}.svg``:
       - per-message ``score`` against ``chat_order``;
       - one line per (round, player) present in the data, with a fixed colour per player;
       - round separators drawn via ``util.plot_round_separators``.
    2. ``human_llm_avg_sim_{version}.svg``: for each player, the mean ``score``
       of that player's ``message_sent`` rows in each round.

    Rounds ``0 .. n_rounds`` (inclusive) are visited. Rounds missing from the
    data contribute nothing. Both figures are closed after saving, so repeated
    calls do not accumulate open matplotlib figures.
    """
    colors = {}
    round_scores_round = []
    round_scores_player = []
    round_scores_score = []

    fig = plt.figure(figsize=(12, 6))
    try:
        util.plot_round_separators(round_separators, 1.01)
        for round_num in range(n_rounds + 1):
            round_df = simulation_df[simulation_df["chat_round_order"] == round_num]
            for player_id, group in round_df.groupby(player_name_col):
                # Turn pd.NA placeholders into float NaN so they plot as gaps.
                group = group.fillna(float("nan"))
                if player_id not in colors:
                    # First line for this player: label it and remember its colour.
                    line, = plt.plot(group["chat_order"], group["score"], linestyle="-", marker=".", label=player_id)
                    colors[player_id] = line.get_color()
                else:
                    # Reuse the colour without adding a duplicate legend entry.
                    plt.plot(group["chat_order"], group["score"], linestyle="-", marker=".", color=colors[player_id])
                round_scores_round.append(round_num)
                round_scores_player.append(player_id)
                round_scores_score.append(group[group["event_type"] == "message_sent"]["score"].mean())

        plt.title("Human vs LLM Similarity\n")
        plt.ylim(-0.01, 1.01)
        plt.xlabel("Time Step (Chat Order)")
        plt.ylabel("Similarity Score")
        legend = plt.legend(fontsize="x-small", loc="lower left")
        legend.get_frame().set_alpha(0)
        plt.savefig(os.path.join(output_path, f"human_llm_sim_{version}.svg"))
    finally:
        plt.close(fig)

    # plot: x-axis: round, y-axis: averaged similarity score
    round_scores = pd.DataFrame({"round": round_scores_round, "player": round_scores_player, "score": round_scores_score})
    fig = plt.figure()
    try:
        for player_id, group in round_scores.groupby("player"):
            plt.plot(group["round"], group["score"], linestyle="-", marker=".", label=player_id, color=colors[player_id])
        plt.title("Human vs LLM Similarity (Averaged Across Rounds, message_sent)\n")
        plt.ylim(-0.01, 1.01)
        plt.xlabel("Round")
        plt.ylabel("Mean Similarity Score")
        legend = plt.legend(fontsize="x-small", loc="lower left")
        legend.get_frame().set_alpha(0)
        plt.savefig(os.path.join(output_path, f"human_llm_avg_sim_{version}.svg"))
    finally:
        plt.close(fig)


def score_round_cosine(simulation_df: pd.DataFrame, embeddings) -> pd.Series:
    """
    Mean pairwise human/LLM similarity for each (round, sender).

    For every round, and every player that ``llm_report_aggr.get_players`` returns
    for it, the function:
    - selects that player's ``message_sent`` rows;
    - takes the distinct non-null human texts and distinct non-null LLM texts;
    - dot-products the cached embeddings of every (human, llm) pair;
    - writes the mean onto all of those rows.

    A (round, player) with no texts on either side, or no cached embeddings,
    gets ``pd.NA``. Rows not covered stay NaN.

    Returns the ``round_score`` column, aligned with ``simulation_df``'s index.
    """
    scored = simulation_df.copy()
    scored["round_score"] = np.nan

    for round_num in scored["chat_round_order"].unique():
        for player in llm_report_aggr.get_players(simulation_df, round_num):
            # This player's messages in this round, human and LLM side by side.
            mask = (
                (scored["chat_round_order"] == round_num)
                & (scored["event_type"] == "message_sent")
                & (scored["sender_id"] == player)
            )
            subset = scored[mask]
            human_texts = subset["text"].dropna().unique()
            llm_texts = subset["llm_text"].dropna().unique()

            mean_sim = pd.NA
            if len(human_texts) and len(llm_texts):
                pair_sims = []
                for human_text in human_texts:
                    human_vec = embeddings.get(human_text)
                    for llm_text in llm_texts:
                        llm_vec = embeddings.get(llm_text)
                        if human_vec is None or llm_vec is None:
                            continue
                        pair_sims.append((human_vec @ llm_vec.T).item())
                if pair_sims:
                    mean_sim = np.mean(pair_sims)

            scored.loc[mask, "round_score"] = mean_sim

    return scored["round_score"]


def score_round_rouge(simulation_df: pd.DataFrame) -> pd.DataFrame:
    """
    Mean pairwise ROUGE for each (round, sender) over ``message_sent`` rows.

    For every round, and every player that ``llm_report_aggr.get_players`` returns
    for it, every (distinct human text, distinct LLM text) pair is scored. The LLM
    text is the prediction and the human text is the reference. The per-pair
    rouge1/rouge2/rougeL/rougeLsum values are averaged and written onto that
    player's ``message_sent`` rows of the round.

    If either side has no texts, those rows get ``pd.NA``. Rows not covered stay NaN.

    Returns a DataFrame with columns round_rouge1, round_rouge2, round_rougeL,
    round_rougeLsum, aligned with ``simulation_df``'s index.
    """
    rouge_keys = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
    round_cols = ["round_rouge1", "round_rouge2", "round_rougeL", "round_rougeLsum"]

    df = simulation_df.copy()
    for col in round_cols:
        df[col] = np.nan

    for round_num in df["chat_round_order"].unique():
        for player in llm_report_aggr.get_players(simulation_df, round_num):
            # Get messages for this round and player from both human and LLM
            mask = (
                (df["chat_round_order"] == round_num)
                & (df["event_type"] == "message_sent")
                & (df["sender_id"] == player)
            )
            subset = df[mask]
            # dropna() already removes missing texts, so every pair below is valid.
            human_texts = subset["text"].dropna().unique()
            llm_texts = subset["llm_text"].dropna().unique()

            preds = [l for _h in human_texts for l in llm_texts]
            refs = [h for h in human_texts for _l in llm_texts]

            if not preds:  # no texts on at least one side
                for col in round_cols:
                    df.loc[mask, col] = pd.NA
                continue

            rouge_score = rouge_compute(predictions=preds, references=refs)
            for key, col in zip(rouge_keys, round_cols):
                df.loc[mask, col] = np.mean(rouge_score[key])

    return df[round_cols]


def score_round_bleu(simulation_df: pd.DataFrame) -> pd.Series:
    """
    Corpus BLEU for each (round, sender) over that sender's ``message_sent`` rows.

    Unlike the round-level cosine/ROUGE scores, BLEU pairs each row's
    ``llm_text`` (prediction) with the same row's ``text`` (reference) instead
    of scoring every (human, llm) combination. Rows missing either text are
    left out of the corpus. The score is written onto all of the player's
    ``message_sent`` rows in that round.

    The score is ``pd.NA`` in two cases:
    - no complete pair exists;
    - BLEU is undefined. The ``evaluate`` BLEU implementation can raise
      ``ZeroDivisionError`` when the tokenised prediction or reference corpus
      is empty.

    Rows not covered stay NaN. Returns the ``round_bleu`` column, aligned with
    ``simulation_df``'s index.
    """
    df = simulation_df.copy()
    df["round_bleu"] = np.nan

    for round_num in df["chat_round_order"].unique():
        for player in llm_report_aggr.get_players(simulation_df, round_num):
            # Get messages for this round and player from both human and LLM
            mask = (
                (df["chat_round_order"] == round_num)
                & (df["event_type"] == "message_sent")
                & (df["sender_id"] == player)
            )
            pairs = df.loc[mask, ["llm_text", "text"]].dropna()

            round_bleu = pd.NA
            if not pairs.empty:
                try:
                    round_bleu = bleu_compute(
                        predictions=pairs["llm_text"].tolist(),
                        references=pairs["text"].tolist(),
                    )["bleu"]
                except ZeroDivisionError:
                    # Empty corpus after tokenisation: BLEU is undefined.
                    round_bleu = pd.NA
            df.loc[mask, "round_bleu"] = round_bleu

    return df["round_bleu"]



def _score_message_rouge(simulation_df: pd.DataFrame) -> pd.DataFrame:
    """
    Per-row ROUGE of ``llm_text`` (prediction) against ``text`` (reference).

    Only rows where both texts are present are scored. The other rows get NaN
    rather than the spurious 0.0 that scoring ``fillna("")`` text produced, which
    matches how the cosine and BLEU message scores treat missing text.

    Returns a DataFrame with columns rouge1, rouge2, rougeL, rougeLsum, aligned
    with ``simulation_df``'s index.
    """
    rouge_cols = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
    result = pd.DataFrame(np.nan, index=simulation_df.index, columns=rouge_cols)
    valid = simulation_df["llm_text"].notna() & simulation_df["text"].notna()
    if valid.any():
        scores = rouge_compute(
            predictions=simulation_df.loc[valid, "llm_text"].tolist(),
            references=simulation_df.loc[valid, "text"].tolist(),
        )
        for col in rouge_cols:
            result.loc[valid, col] = list(scores[col])
    return result


def main(data_prefix: str, model_name: str, player_name_col: str, version: str, 
         run_cosine: bool = True, run_bleu: bool = True, run_rouge: bool = True):
    """
    Score LLM messages against the human messages they simulate, then save the results.

    Data source:
    - If ``human_llm_score_{version}.csv`` already exists in the output
      directory, it is re-loaded and the selected metrics are recomputed on top
      of it. The raw simulation data is then loaded only to obtain round separators.
    - Otherwise the data is loaded with ``util.load_simulation_dfs``
      (``filter_strategy="any"``).

    Writes into ``../../result/eval/human_llm/{data_prefix}/{model_name}``:
    - ``human_llm_score_{version}.csv``: per-row scores.
    - ``human_llm_score_report_{version}.csv``: per-``event_type`` statistics of
      the score columns present — mean/std for message scores, mean/size for
      round scores.
    - The two cosine SVG plots from ``cosine_plot``, when ``run_cosine`` is True.

    Args:
        data_prefix, model_name, version: select the simulation data and the output paths.
        player_name_col: column used to draw one plot line per player.
        run_cosine, run_bleu, run_rouge: toggle each metric family.

    Raises:
        ValueError: if a previous results file exists and every ``run_*`` flag is False.
    """
    output_path = f"../../result/eval/human_llm/{data_prefix}/{model_name}"
    os.makedirs(output_path, exist_ok=True)
    existing_results_path = os.path.join(output_path, f"human_llm_score_{version}.csv")

    # Shared loading options; "any" filters a row if either the human or the LLM
    # message is invalid.
    load_kwargs = dict(
        data_prefix=data_prefix,
        model_name=model_name,
        version=version,
        filter_strategy="any",
        preprocess=True,
        consecutive_messages=True,
    )

    if os.path.exists(existing_results_path):
        print(f"Found existing results file: {existing_results_path}")
        simulation_df = pd.read_csv(existing_results_path)
        if not (run_cosine or run_bleu or run_rouge):
            raise ValueError("No metric to compute, Check your parameters!")
        # The saved CSV carries no separators; reload the raw data just to rebuild them.
        _, temp_df = util.load_simulation_dfs(**load_kwargs)
        _, round_separators = util.get_chat_order_and_separators(temp_df)
        del temp_df
    else:
        print("No existing results file found. Loading from util.load_simulation_dfs")
        _, simulation_df = util.load_simulation_dfs(**load_kwargs)
        simulation_df, round_separators = util.get_chat_order_and_separators(simulation_df)

    if run_cosine:
        print("Computing cosine similarity metrics...")
        embeddings = calculate_embeddings(simulation_df, model)
        simulation_df["round_score"] = score_round_cosine(simulation_df, embeddings)
        simulation_df["score"] = simulation_df.apply(
            functools.partial(score_message_cosine, embeddings=embeddings), axis=1
        )
        del embeddings  # Free memory
    else:
        print("Skipping cosine similarity metrics...")

    if run_rouge:
        print("Computing ROUGE metrics...")
        simulation_df[["round_rouge1", "round_rouge2", "round_rougeL", "round_rougeLsum"]] = score_round_rouge(simulation_df)
        simulation_df[["rouge1", "rouge2", "rougeL", "rougeLsum"]] = _score_message_rouge(simulation_df)
    else:
        print("Skipping ROUGE metrics...")

    if run_bleu:
        print("Computing BLEU metrics...")
        simulation_df["round_bleu"] = score_round_bleu(simulation_df)
        simulation_df["bleu"] = score_message_bleu(simulation_df)
    else:
        print("Skipping BLEU metrics...")
    simulation_df.to_csv(existing_results_path, index=False)

    # Aggregate whichever score columns exist (including ones kept from a previous run).
    score_columns = ["score", "rouge1", "rouge2", "rougeL", "rougeLsum", "bleu"]
    round_score_columns = ["round_score", "round_rouge1", "round_rouge2", "round_rougeL", "round_rougeLsum", "round_bleu"]
    agg_dict = {col: [nanmean, nanstd] for col in score_columns if col in simulation_df.columns}
    agg_dict.update({col: [nanmean, "size"] for col in round_score_columns if col in simulation_df.columns})

    if agg_dict:  # Only generate report if there are metrics to aggregate
        report_df = simulation_df.groupby("event_type").agg(agg_dict)
        report_df = report_df.rename(columns={"nanmean": "mean", "nanstd": "std"})
        if "round_score" in report_df.columns:
            print(f"Message sent mean round score: {report_df.loc['message_sent', ('round_score', 'mean')]:.4f}")
        report_df.to_csv(os.path.join(output_path, f"human_llm_score_report_{version}.csv"))

    # plot: x-axis: time step, y-axis: similarity score (only if cosine similarity is enabled)
    if run_cosine:
        simulation_df = simulation_df[simulation_df["event_type"].isin(["Initial Opinion", "message_sent", "Post Opinion", "tweet"])]
        simulation_df = simulation_df[simulation_df["text"] != ""]
        if simulation_df.empty:
            # max() of an empty column is NaN; int(NaN) would raise here.
            print("No rows left to plot; skipping cosine plots.")
            return
        n_rounds = int(simulation_df["chat_round_order"].max()) + 1
        cosine_plot(simulation_df, round_separators, n_rounds, player_name_col, output_path, version)
