# Use when aligning LLM data (i.e., when opinion_proc.py is run with llm_column=True)

import pandas as pd
import matplotlib.pyplot as plt
from . import util


def main(data_prefix: str, model_name: str, eval_model_save_name: str, player_name_col: str, version: str, is_memory: bool = True):
    """Aggregate and plot human vs. LLM (agent) opinion trajectories for one evaluation run.

    Reads the human and LLM opinion CSVs, writes an outer-merged CSV of both,
    and saves four SVG figures: human+LLM per-message trajectories,
    human-only, LLM-only, and per-round mean Likert scores.

    Args:
        data_prefix: Sub-directory under ``../../result/eval/human_llm``.
        model_name: Sub-directory under ``data_prefix``.
        eval_model_save_name: Evaluator name embedded in input/output file names.
        player_name_col: Column identifying the player in both CSVs.
        version: Version tag embedded in input/output file names.
        is_memory: Must be True; only the "memory" file layout is supported.

    Raises:
        ValueError: If ``is_memory`` is False.
    """
    if not is_memory:
        raise ValueError("Independent evaluation is no longer supported.")

    base_dir = f"../../result/eval/human_llm/{data_prefix}/{model_name}"
    suffix = f"{eval_model_save_name}_{version}"
    human_file = f"{base_dir}/opinion_human_memory_{suffix}.csv"
    llm_file = f"{base_dir}/opinion_llm_memory_{suffix}.csv"
    output_aggregated_scores = f"{base_dir}/opinion_memory_{suffix}.csv"
    output_human_llm_individual_file = f"{base_dir}/opinion_plot_memory_{suffix}.svg"
    output_human_only_individual_file = f"{base_dir}/opinion_plot_memory_human_only_{suffix}.svg"
    output_llm_only_individual_file = f"{base_dir}/opinion_plot_memory_llm_only_{suffix}.svg"
    output_averaged_file = f"{base_dir}/opinion_plot_memory_avg_{suffix}.svg"

    human_df = pd.read_csv(human_file)
    llm_df = pd.read_csv(llm_file)

    # Add chat-order columns and find the round boundaries to draw as separators.
    human_df, round_separators = util.get_chat_order_and_separators(human_df)
    llm_df, _ = util.get_chat_order_and_separators(llm_df)
    # NOTE(review): index-aligned copy; assumes both frames hold the same messages
    # in the same row order after get_chat_order_and_separators — confirm.
    llm_df["chat_order"] = human_df["chat_order"]
    n_rounds = round_separators.shape[0] + 3  # should be +2 (initial and post opinions), +1 for upper bound

    # Player -> line colour; shared by every figure so a player keeps one colour throughout.
    colors = {}
    # Legend labels already attached on the current figure; reset in finish_figure().
    labeled = []
    # Per-round mean scores as (round, player, score) rows, filled by the first (combined) figure.
    human_round_rows = []
    llm_round_rows = []

    def aggregate_scores(human_df: pd.DataFrame, llm_df: pd.DataFrame, output_file: str):
        """Outer-merge human and LLM rows on their shared columns and write the result to ``output_file``.

        Columns listed in ``differ_fields`` are kept side by side, renamed from
        pandas' ``_human``/``_llm`` suffixes to ``human_``/``llm_`` prefixes; the
        remaining human columns (minus ``skipped_fields``) form the merge key.
        """
        differ_fields = ["message", "likert_pred", "label_pred", "conversation", "prompt"]
        skipped_fields = ["human_validity", "llm_validity", "validity_reason", "input_prompt_validation"]
        differ_fields_human_prefix = ["human_" + f for f in differ_fields] + ["human_validity"]
        differ_fields_llm_prefix = ["llm_" + f for f in differ_fields] + ["llm_validity"]
        differ_fields_human_suffix = [f + "_human" for f in differ_fields] + ["human_validity_human"]
        differ_fields_llm_suffix = [f + "_llm" for f in differ_fields] + ["llm_validity_llm"]
        other_fields = [f for f in human_df.columns.tolist() if f not in differ_fields and f not in skipped_fields]
        merged_df = human_df.merge(llm_df, on=other_fields, how="outer", suffixes=("_human", "_llm"))
        # Change suffix names to prefix names, e.g. "message_human" -> "human_message".
        merged_df.rename(columns=dict(zip(differ_fields_human_suffix, differ_fields_human_prefix)), inplace=True)
        merged_df.rename(columns=dict(zip(differ_fields_llm_suffix, differ_fields_llm_prefix)), inplace=True)
        assert len(merged_df) == len(human_df) == len(llm_df), "Merged dataframe has different length than human or LLM dataframe"
        merged_df.to_csv(output_file, index=False)

    def plot_player(group: pd.DataFrame, player_id, label: str, **style):
        """Draw one player's trajectory segment on the current figure.

        Reuses the player's colour if one was assigned earlier (otherwise lets
        matplotlib pick one and records it), and attaches ``label`` only the
        first time it appears on the current figure so each legend entry is unique.
        """
        if player_id in colors:
            style["color"] = colors[player_id]
        if label not in labeled:
            style["label"] = label
            labeled.append(label)
        line, = plt.plot(group["chat_order"], group["likert_pred"], **style)
        colors.setdefault(player_id, line.get_color())

    def plot_source(df: pd.DataFrame, label_fmt: str, round_rows, **style):
        """Plot every player's per-round trajectory from ``df``.

        If ``round_rows`` is a list, append one (round, player, mean likert_pred) row
        per player per round to it.
        """
        for round_idx in range(n_rounds):
            round_df = df[df["chat_round_order"] == round_idx]
            for player_id, group in round_df.groupby(player_name_col):
                plot_player(group, player_id, label_fmt.format(player_id), **style)
                if round_rows is not None:
                    round_rows.append((round_idx, player_id, group["likert_pred"].mean()))

    def plot_human(accu_round_scores: bool):
        plot_source(human_df, "{}", human_round_rows if accu_round_scores else None,
                    linestyle="-", marker="+", alpha=0.5)

    def plot_llm(accu_round_scores: bool):
        plot_source(llm_df, "{} (Agent)", llm_round_rows if accu_round_scores else None,
                    linestyle="dotted", marker="x", alpha=0.7)

    def new_trajectory_figure():
        """Open a wide figure with the round separators already drawn."""
        plt.figure(figsize=(12, 6))
        util.plot_round_separators(round_separators, 2.1)

    def finish_figure(title: str, xlabel: str, ylabel: str, output_file: str):
        """Apply shared axes/legend styling, save the current figure, and close it."""
        plt.title(title)
        plt.ylim(0.9, 6.1)
        plt.yticks([1, 2, 3, 4, 5, 6])
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        legend = plt.legend(fontsize="small", loc="lower left")
        legend.get_frame().set_alpha(0)
        labeled.clear()  # legend bookkeeping is per figure
        plt.savefig(output_file)
        plt.close()  # release the figure; otherwise repeated calls accumulate open figures

    aggregate_scores(human_df, llm_df, output_aggregated_scores)

    # The combined figure runs first: it assigns player colours and collects round means.
    new_trajectory_figure()
    plot_human(accu_round_scores=True)
    plot_llm(accu_round_scores=True)
    finish_figure("Human vs LLM Opinion Trajectory\n", "Time Step (Chat Order)", "Likert Score",
                  output_human_llm_individual_file)

    new_trajectory_figure()
    plot_human(accu_round_scores=False)
    finish_figure("Human Opinion Trajectory\n", "Time Step (Chat Order)", "Likert Score",
                  output_human_only_individual_file)

    new_trajectory_figure()
    plot_llm(accu_round_scores=False)
    finish_figure("LLM Opinion Trajectory\n", "Time Step (Chat Order)", "Likert Score",
                  output_llm_only_individual_file)

    # Per-round mean Likert scores for each player.
    human_round_scores = pd.DataFrame(human_round_rows, columns=["round", "player", "score"])
    llm_round_scores = pd.DataFrame(llm_round_rows, columns=["round", "player", "score"])
    plt.figure()
    for player_id, group in human_round_scores.groupby("player"):
        plt.plot(group["round"], group["score"], label=player_id, marker='+', linestyle='-', color=colors[player_id], alpha=0.5)
    for player_id, group in llm_round_scores.groupby("player"):
        plt.plot(group["round"], group["score"], label=f"{player_id} (Agent)", marker='x', linestyle='dotted', color=colors[player_id], alpha=0.7)
    finish_figure("Human vs LLM Opinion Trajectory (Averaged Across Rounds)\n", "Time Step (Round)", "Mean Likert Score",
                  output_averaged_file)
