import pandas as pd
from . import util
import typing
import tqdm
import numpy as np
import os

# Registry of the models that can appear in the aggregated report.
# Each key is a model name; the value carries:
#   "type": label written to the report's "model_type" index level
#   "base": optional name of the model it derives from, written to the
#           "model_base" index level (an empty string is used when absent)
# generate_llm_report_aggr looks every requested model up here, so a model
# missing from this dict (including the commented-out entries below, which are
# currently disabled) raises KeyError when it is passed in `models`.
model_info = {
    "gpt-4o-mini-2024-07-18": { "type": "Instruct" },
    # "gpt-4o-mini-2024-07-18-ablation": { "type": "Instruct" },
    # "gpt-4o-mini-2024-07-18-ablation-v2": { "type": "Instruct" },
    # "ft:gpt-4o-mini-2024-07-18:camer:topic-split-all:BOqtcdMB": { "type": "SFT", "base": "gpt-4o-mini-2024-07-18" },
    # "ft:gpt-4o-mini-2024-07-18:camer:group-split-all:BOvZjzvU": { "type": "SFT", "base": "gpt-4o-mini-2024-07-18" },
    # "ft:gpt-4o-mini-2024-07-18:camer:round-split-all:BOvS862Y": { "type": "SFT", "base": "gpt-4o-mini-2024-07-18" },
    # "ft:gpt-4o-mini-2024-07-18:camer:round-split-valid:BRTJQLtG": { "type": "Instruct", "base": "gpt-4o-mini-2024-07-18" },
    # "Llama-3.1-8B": { "type": "Base" },
    "Llama-3.1-Tulu-3-8B-SFT": { "type": "SFT", "base": "Llama-3.1-8B" },
    # "Llama-3.1-Tulu-3-8B-DPO": { "type": "DPO", "base": "Llama-3.1-Tulu-3-8B-SFT" },
    # "Llama-3.1-Tulu-3-8B": { "type": "RLVR", "base": "Llama-3.1-Tulu-3-8B-DPO" },
    "Llama-3.1-8B-Instruct": { "type": "Instruct", "base": "Llama-3.1-8B" },
    "Llama-3.1-70B-Instruct": { "type": "Instruct", "base": "Llama-3.1-70B" },
    # "ft:Llama-3.1-8B-Instruct-250528:round-split-valid-5epochs": { "type": "Instruct", "base": "Llama-3.1-8B-Instruct" },
    # "ft:Llama-3.1-8B-Instruct:round-split-valid-5epochs": { "type": "Instruct", "base": "Llama-3.1-8B-Instruct" },
    "Qwen2.5-32B-Instruct": { "type": "Instruct" },
    "Mistral-7B-Instruct-v0.3": { "type": "Instruct" },
    # "ft:Llama-3.1-8B-Instruct-SFT-20250711:group-5epochs": { "type": "SFT", "base": "Llama-3.1-8B-Instruct" },
    # "ft:Llama-3.1-8B-Instruct-SFT-20250710:round-5epochs": { "type": "SFT", "base": "Llama-3.1-8B-Instruct" },
    # "ft:Llama-3.1-8B-Instruct-SFT-20250710:topic-5epochs": { "type": "SFT", "base": "Llama-3.1-8B-Instruct" },

    
    # "Llama-3.1-Tulu-3-8B-MT-DDPO-0129": { "type": "DDPO", "base": "Llama-3.1-Tulu-3-8B-SFT" }
    "ft:Llama-3.1-8B-Instruct-SFT-20250710:round-5epochs": { "type": "SFT", "base": "Llama-3.1-8B-Instruct" },
    "ft:Llama-3.1-8B-Instruct-SFT-20250710:topic-5epochs": { "type": "SFT", "base": "Llama-3.1-8B-Instruct" },
    "ft:Llama-3.1-8B-Instruct-SFT-20250711:group-5epochs": { "type": "SFT", "base": "Llama-3.1-8B-Instruct" },
}

def train_or_test(data_prefix: str, split: str):
    """
    Tells which side of a fine-tuning split an experiment belongs to.

    Args:
        data_prefix (str): Experiment identifier; matched as a substring of the
            file names found in the split directories.
        split (str): Split method. "round" is answered without touching the
            file system; any other value selects
            ../../data/finetune_data/{split}_split_data/{train,test}.
    Returns:
        str: "round" for the round split; otherwise "train", "test" (train is
             checked first) or "unknown" when no file name contains the prefix.
    Raises:
        OSError: If either split directory cannot be listed.
    """
    if split == "round":
        return "round"

    # Both directories are listed before any matching, train first.
    listings = {
        "train": os.listdir(f"../../data/finetune_data/{split}_split_data/train"),
        "test": os.listdir(f"../../data/finetune_data/{split}_split_data/test"),
    }
    for label, file_names in listings.items():
        for file_name in file_names:
            if data_prefix in file_name:
                return label
    return "unknown"

def get_pairs(user_data, round_number):
    """
    Retrieves the unique pairs of players that exchanged non-tweet events in a round.

    Args:
        user_data (pd.DataFrame): The DataFrame containing message data. Must have the
            columns "chat_round_order", "event_type", "sender_id" and "recipient_id".
        round_number (int): The round number to retrieve pairs for.
    Returns:
        list: A list of 2-tuples of player ids. Each tuple is sorted, so a pair is
              direction-less: (a, b) is reported once no matter who was the sender.
              Rows with a missing sender or recipient are ignored. Pairs are listed
              in the order they first appear in `user_data`, so the result is
              reproducible across runs.
    """
    in_round = (user_data["chat_round_order"] == round_number) & (user_data["event_type"] != "tweet")
    endpoints = user_data.loc[in_round, ["sender_id", "recipient_id"]].dropna()
    # dict.fromkeys de-duplicates while keeping insertion order; a plain set would
    # return string ids in a hash-dependent (run-to-run varying) order.
    unique_pairs = dict.fromkeys(tuple(sorted(pair)) for pair in endpoints.values.tolist())
    return list(unique_pairs)

def get_players(user_data, round_number):
    """
    Lists the distinct senders of non-tweet events in a round.

    Args:
        user_data (pd.DataFrame): Message data with the columns "chat_round_order",
            "event_type" and "sender_id".
        round_number (int): The round to inspect.
    Returns:
        list: Sender ids in order of first appearance; missing ids are skipped.
    """
    is_round = user_data["chat_round_order"] == round_number
    not_tweet = user_data["event_type"] != "tweet"
    senders = user_data.loc[is_round & not_tweet, "sender_id"]
    distinct_senders = senders.dropna().drop_duplicates()
    return distinct_senders.values.tolist()

def _assert_word_counts(rows, text_column):
    """Sanity-check that each row's "n_words" equals the whitespace word count of its text."""
    for text, n_words in zip(rows[text_column], rows["n_words"]):
        if pd.isna(text):
            # Missing text has no word count to verify; its NaN metric is ignored
            # by the NaN-aware mean in the caller.
            continue
        assert len(text.split()) == n_words, (
            f"n_words={n_words!r} does not match the word count of {text_column}={text!r}"
        )


def metrics_per_round(human_df, llm_df, metric_name: str, abs_diff: bool):
    """
    Computes, per (round, player), the mean human-vs-LLM difference of a metric.

    For every round in `human_df` and every player returned by `get_players` for
    that round, the difference `human - llm` is taken over ALL combinations of
    that player's human rows and LLM rows in the round, and the NaN-ignoring mean
    of those differences is written to every one of the player's rows in that
    round. (For signed differences this mean equals mean(human) - mean(llm).)

    Args:
        human_df (pd.DataFrame): Human rows. Needs "chat_round_order", "sender_id",
            "event_type" and `metric_name`; for metric_name == "n_words" also "text".
        llm_df (pd.DataFrame): LLM rows. Needs "chat_round_order", "sender_id" and
            `metric_name`; for metric_name == "n_words" also "llm_text".
        metric_name (str): Name of the numeric column to compare.
        abs_diff (bool): Average absolute differences instead of signed ones.
    Returns:
        pd.DataFrame: A copy of `human_df` with an extra column
            f"{metric_name}_round". Rows not belonging to any listed player keep NaN.
    Raises:
        AssertionError: If a player has a different number of human and LLM rows
            in a round, or (for "n_words") a stored word count disagrees with the text.
    """
    out_col = metric_name + "_round"
    # Work on a copy so the caller's frame is left untouched.
    df = human_df.copy()
    df[out_col] = np.nan

    for round_num in df["chat_round_order"].unique():
        for player in get_players(human_df, round_num):
            human_mask = (human_df["chat_round_order"] == round_num) & (human_df["sender_id"] == player)
            llm_mask = (llm_df["chat_round_order"] == round_num) & (llm_df["sender_id"] == player)
            human_rows = human_df[human_mask]
            llm_rows = llm_df[llm_mask]
            assert len(human_rows) == len(llm_rows), (
                f"round {round_num!r}, player {player!r}: "
                f"{len(human_rows)} human rows vs {len(llm_rows)} LLM rows"
            )

            if metric_name == "n_words":
                # Each row is verified once instead of once per human/LLM combination.
                _assert_word_counts(human_rows, "text")
                _assert_word_counts(llm_rows, "llm_text")

            # diffs[i * m + j] = human[i] - llm[j]: every human/LLM combination,
            # in the same order a nested loop over the rows would produce.
            human_vals = human_rows[metric_name].to_numpy(dtype=float)
            llm_vals = llm_rows[metric_name].to_numpy(dtype=float)
            diffs = np.subtract.outer(human_vals, llm_vals).ravel()
            if abs_diff:
                diffs = np.abs(diffs)

            if diffs.size == 0 or np.isnan(diffs).all():
                # Nothing to average; avoids numpy's "Mean of empty slice" warning.
                round_avg_diff_player = np.nan
            else:
                round_avg_diff_player = np.nanmean(diffs)

            # df shares human_df's index, so the human mask addresses the same rows.
            # The average is repeated on each of the player's rows in the round.
            df.loc[human_mask, out_col] = round_avg_diff_player

    return df


def metrics_by_message(human_df, llm_df, metric_name: str, abs_diff: bool):
    """
    (Only for testing purposes) Message-level human-vs-LLM difference of a metric.

    Within each (round, player) group the human rows and the LLM rows are paired
    by position and `human - llm` (absolute value when `abs_diff` is set) is
    stored on the human row. No group is filtered out.

    Args:
        human_df (pd.DataFrame): Human rows with "chat_round_order", "sender_id",
            "event_type" and `metric_name`.
        llm_df (pd.DataFrame): LLM rows with "chat_round_order", "sender_id" and
            `metric_name`.
        metric_name (str): Name of the numeric column to compare.
        abs_diff (bool): Store absolute differences instead of signed ones.
    Returns:
        pd.DataFrame: A copy of `human_df` with an extra column
            f"{metric_name}_message"; rows outside every group keep NaN.
    """
    out_col = metric_name + "_message"
    result = human_df.copy()
    result[out_col] = np.nan

    for round_num in result["chat_round_order"].unique():
        for player in get_players(human_df, round_num):
            human_rows = human_df[(human_df["chat_round_order"] == round_num) & (human_df["sender_id"] == player)]
            llm_rows = llm_df[(llm_df["chat_round_order"] == round_num) & (llm_df["sender_id"] == player)]
            assert len(human_rows) == len(llm_rows)

            # Pair the i-th human row with the i-th LLM row of the same group.
            for position, row_label in enumerate(human_rows.index):
                delta = human_rows.iloc[position][metric_name] - llm_rows.iloc[position][metric_name]
                if abs_diff:
                    delta = abs(delta)
                result.at[row_label, out_col] = delta

    return result


def _init_running_scores() -> typing.Dict[str, list]:
    """Creates one empty bucket per (metric, level) and per train/test split variant of it."""
    running_scores: typing.Dict[str, list] = {}
    for score_name in ["semantic_similarity", "rouge1", "rouge2", "rougeL", "rougeLsum", "bleu", "likert_diff", "n_word_diff", "n_word_abs_diff"]:
        for level in ["message", "round"]:
            running_scores[f"{score_name}_{level}"] = []  # base metrics for message and round level

            for split_type in ["train", "test"]:  # train/test split metrics
                for split_method in ["round", "topic", "group"]:
                    running_scores[f"{score_name}_{level}_{split_type}_{split_method}"] = []
    return running_scores


def _safe_word_count(text):
    """Returns the whitespace word count of `text`, or NaN when the text is missing."""
    try:
        if pd.isna(text) or text is None:
            return np.nan
        return len(str(text).split())
    except (AttributeError, TypeError):
        return np.nan


def _extend_split_scores(running_scores, key_prefix, split_type, split_method, frame, column, dropna):
    """Appends `frame[column]` to the train/test bucket(s) of `key_prefix` selected by the split."""
    def _values(rows):
        series = rows[column]
        return (series.dropna() if dropna else series).tolist()

    if split_type == "round":
        # Round split: round 3 is the test round, every other round is training data.
        round_order = frame["chat_round_order"]
        running_scores[f"{key_prefix}_train_round"].extend(_values(frame[round_order != 3]))
        running_scores[f"{key_prefix}_test_round"].extend(_values(frame[round_order == 3]))
    elif split_type != "unknown":
        # Topic/group split: the whole experiment is either train or test.
        running_scores[f"{key_prefix}_{split_type}_{split_method}"].extend(_values(frame))


def generate_llm_report_aggr(data_prefixes: typing.List[str], models: typing.List[str], eval_model_save_name: str, version: str = "v2", output_file: typing.Optional[str] = None):
    """
    Aggregates human-vs-LLM evaluation scores per model and writes them to a CSV.

    For every model and every experiment in `data_prefixes`, three CSVs are read from
    ../../result/eval/human_llm/{data_prefix}/{model}/ (the similarity score file and
    the human/LLM opinion files); an experiment is skipped when any of them is
    missing. Only rows with event_type == "message_sent" are used. Values are pooled
    over all experiments into buckets named {metric}_{level} (level: message/round)
    and {metric}_{level}_{train|test}_{round|topic|group}; the report stores the
    NaN-ignoring mean of each bucket (NaN for an empty bucket).

    Each experiment contributes once to every bucket; files are loaded and the
    round-level metrics are computed once per experiment, not once per split method.

    Args:
        data_prefixes (List[str]): Experiment identifiers (result sub-directory names).
        models (List[str]): Model names; each must be a key of `model_info`.
        eval_model_save_name (str): Name fragment of the opinion CSV files.
        version (str): Version suffix of the input files and the default output file.
        output_file (Optional[str]): Target CSV path. Defaults to
            ../../result/eval/human_llm/llm_report_{version}.csv. Missing parent
            directories are created.
    Returns:
        None. The report is written to `output_file`, indexed by
        (model_type, model_base, model, type) with a single "score" column.
    Raises:
        KeyError: If a model is not registered in `model_info` (raised before any
            of that model's files are processed).
    """
    if output_file is None:
        output_file = f"../../result/eval/human_llm/llm_report_{version}.csv"

    output_df = pd.DataFrame(columns=["model_type", "model_base", "model", "type"])
    output_df.set_index(["model_type", "model_base", "model", "type"], inplace=True)

    # (column in the score file, metric name used in the report)
    similarity_fields = [
        ("score", "semantic_similarity"),
        ("rouge1", "rouge1"),
        ("rouge2", "rouge2"),
        ("rougeL", "rougeL"),
        ("rougeLsum", "rougeLsum"),
        ("bleu", "bleu"),
    ]

    for model_name in tqdm.tqdm(models, desc="LLM Report for Models"):
        # Resolve the registry entry first so an unregistered model fails before
        # any expensive work is done for it.
        model_type = model_info[model_name]["type"]
        model_base = model_info[model_name].get("base", "")
        running_scores = _init_running_scores()

        for data_prefix in data_prefixes:
            sim_score_file = f"../../result/eval/human_llm/{data_prefix}/{model_name}/human_llm_score_{version}.csv"
            human_likert_file = f"../../result/eval/human_llm/{data_prefix}/{model_name}/opinion_human_memory_{eval_model_save_name}_{version}.csv"
            llm_likert_file = f"../../result/eval/human_llm/{data_prefix}/{model_name}/opinion_llm_memory_{eval_model_save_name}_{version}.csv"
            if not os.path.exists(sim_score_file) or not os.path.exists(human_likert_file) or not os.path.exists(llm_likert_file):
                continue

            # Train/test membership of this experiment under each split method.
            split_types = [(split_method, train_or_test(data_prefix, split_method)) for split_method in ["round", "group", "topic"]]

            # --- Human-LLM similarity metrics ---------------------------------
            sim_score_df = pd.read_csv(sim_score_file)
            sim_score_df = sim_score_df[sim_score_df["event_type"] == "message_sent"]
            has_round_order = "chat_round_order" in sim_score_df.columns
            for score_field, score_name in similarity_fields:
                round_field = f"round_{score_field}"
                # Score files may lack some metric columns; such metrics are skipped
                # explicitly rather than by swallowing every exception.
                if score_field not in sim_score_df.columns:
                    continue
                running_scores[f"{score_name}_message"].extend(sim_score_df[score_field].tolist())
                if round_field not in sim_score_df.columns:
                    # Without the round-level column only the overall message-level
                    # bucket is filled for this metric.
                    continue
                running_scores[f"{score_name}_round"].extend(sim_score_df[round_field].tolist())
                for split_method, split_type in split_types:
                    if split_type == "round" and not has_round_order:
                        continue  # cannot separate the test round without round numbers
                    _extend_split_scores(running_scores, f"{score_name}_message", split_type, split_method, sim_score_df, score_field, dropna=False)
                    _extend_split_scores(running_scores, f"{score_name}_round", split_type, split_method, sim_score_df, round_field, dropna=False)

            # --- Absolute difference in Likert ratings ------------------------
            human_likert_df = pd.read_csv(human_likert_file)
            human_likert_df = human_likert_df[human_likert_df["event_type"] == "message_sent"]
            llm_likert_df = pd.read_csv(llm_likert_file)
            llm_likert_df = llm_likert_df[llm_likert_df["event_type"] == "message_sent"]
            likert_diff_df = pd.DataFrame({
                "chat_round_order": human_likert_df["chat_round_order"],
                "empirica_id": human_likert_df["empirica_id"],
                # Index-aligned subtraction: rows present in only one file yield NaN.
                "score": (human_likert_df["likert_pred"] - llm_likert_df["likert_pred"]).abs()
            })
            running_scores["likert_diff_message"].extend(likert_diff_df["score"].tolist())

            likert_round_df = metrics_per_round(human_likert_df, llm_likert_df, "likert_pred", True)
            running_scores["likert_diff_round"].extend(likert_round_df["likert_pred_round"].dropna().tolist())

            # --- Difference in number of words --------------------------------
            # The word counts come from the human opinion file (already loaded and
            # filtered above), which holds both the human "text" and the "llm_text".
            simulation_df = human_likert_df.copy()
            simulation_df["n_human_words"] = simulation_df["text"].str.split().str.len()
            simulation_df["n_llm_words"] = simulation_df["llm_text"].apply(_safe_word_count)
            simulation_df["diff_n_words"] = simulation_df["n_human_words"] - simulation_df["n_llm_words"]
            simulation_df["abs_diff_n_words"] = simulation_df["diff_n_words"].abs()

            human_word_df = simulation_df[["chat_round_order", "sender_id", "recipient_id", "event_type", "text", "n_human_words"]].rename(columns={"n_human_words": "n_words"})
            llm_word_df = simulation_df[["chat_round_order", "sender_id", "recipient_id", "event_type", "llm_text", "n_llm_words"]].rename(columns={"n_llm_words": "n_words"})

            running_scores["n_word_diff_message"].extend(simulation_df["diff_n_words"].dropna().tolist())
            running_scores["n_word_abs_diff_message"].extend(simulation_df["abs_diff_n_words"].dropna().tolist())

            n_word_diff_round_df = metrics_per_round(human_word_df, llm_word_df, "n_words", False)
            n_word_abs_diff_round_df = metrics_per_round(human_word_df, llm_word_df, "n_words", True)

            running_scores["n_word_diff_round"].extend(n_word_diff_round_df["n_words_round"].dropna().tolist())
            running_scores["n_word_abs_diff_round"].extend(n_word_abs_diff_round_df["n_words_round"].dropna().tolist())

            # --- Train/test buckets for the Likert and word-count metrics -----
            for split_method, split_type in split_types:
                _extend_split_scores(running_scores, "likert_diff_message", split_type, split_method, likert_diff_df, "score", dropna=False)
                _extend_split_scores(running_scores, "likert_diff_round", split_type, split_method, likert_round_df, "likert_pred_round", dropna=True)
                _extend_split_scores(running_scores, "n_word_diff_message", split_type, split_method, simulation_df, "diff_n_words", dropna=False)
                _extend_split_scores(running_scores, "n_word_abs_diff_message", split_type, split_method, simulation_df, "abs_diff_n_words", dropna=False)
                _extend_split_scores(running_scores, "n_word_diff_round", split_type, split_method, n_word_diff_round_df, "n_words_round", dropna=True)
                _extend_split_scores(running_scores, "n_word_abs_diff_round", split_type, split_method, n_word_abs_diff_round_df, "n_words_round", dropna=True)

        for key, values in running_scores.items():
            # An empty bucket is reported as NaN without numpy's empty-slice warning.
            score = float(np.nanmean(values)) if values else float("nan")
            output_df.loc[(model_type, model_base, model_name, key), "score"] = score

    output_dir = os.path.dirname(output_file)
    if output_dir:
        # Do not lose the whole computation to a missing output directory.
        os.makedirs(output_dir, exist_ok=True)
    output_df.to_csv(output_file)