import os
from collections import defaultdict
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Repository root: two levels above the directory that contains this script.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
# Base directory: per-topic inputs are read from, and plots/exports written under, this path.
EVAL_PATH = os.path.join(PROJECT_ROOT, "result", "group_level_eval", "base_analysis")
# Topic slugs; each one is also the name of a sub-directory of EVAL_PATH.
TOPICS = [
    "A_\"body_cleanse,\"_in_which_you_consume_only_particular_kinds_of_nutrients_over_1-3_days,_helps_your_body_to_eliminate_toxins",
    "Angels_are_real",
    "Everything_that_happens_can_eventually_be_explained_by_science",
    "Regular_fasting_will_improve_your_health",
    "The_position_of_the_planets_at_the_time_of_your_birth_can_influence_your_personality",
    "The_United_States_has_the_highest_federal_income_tax_rate_of_any_Western_country",
    "The_US_deficit_increased_after_President_Obama_was_elected",
]
# Ordinal stance scale; used as the heatmap row/column labels.
LIKERT_VALUES = list(range(1, 7))
_N_LEVELS = len(LIKERT_VALUES)


def _empty_counts():
    """Return a fresh square int matrix of zeros sized to the Likert scale."""
    return np.zeros((_N_LEVELS, _N_LEVELS), dtype=int)


# Cross-topic accumulators, summed per topic by the __main__ section.
# Each call creates a separate array, so the accumulators never alias.
AGG_HUMAN = _empty_counts()
AGG_MODEL = _empty_counts()
AGG_TWEET_HUMAN = _empty_counts()
AGG_TWEET_MODEL = _empty_counts()
AGG_HUMAN_SLIDER = _empty_counts()
AGG_MODEL_SLIDER = _empty_counts()
AGG_TWEET_HUMAN_SLIDER = _empty_counts()
AGG_TWEET_MODEL_SLIDER = _empty_counts()
all_transition_rows = []


def _kind_level():
    """Innermost level of the tuple stores: kind ("opinion"/"tweet") -> [ids]."""
    return defaultdict(list)


def _timestamp_level():
    """Middle level of the tuple stores: timestamp -> kind -> [ids]."""
    return defaultdict(_kind_level)


all_human_tuples = defaultdict(_timestamp_level)  # topic → timestamp → type → [ids]
all_llm_tuples = defaultdict(_timestamp_level)

def save_tuples_txt(tuples_dict, output_path, source_label):
    """Write the collected participant IDs to a human-readable text file.

    Output layout (one entry per line)::

        <topic>
        - <time_stamp> - <source_label>
        --[opinion] <human_id>
        --[tweet] <human_id>

    Topics and time stamps are written in sorted order. Under each time stamp
    the "opinion" IDs come first and the "tweet" IDs second, each sorted.

    Args:
        tuples_dict: Nested mapping ``topic -> time_stamp -> kind -> [human_id]``,
            where ``kind`` is ``"opinion"`` or ``"tweet"``. A missing kind is
            simply skipped.
        output_path: File to create or overwrite.
        source_label: Label appended to every time-stamp header line.
    """
    lines = []
    for topic in sorted(tuples_dict):
        lines.append(f"{topic}\n")
        for ts in sorted(tuples_dict[topic]):
            lines.append(f"- {ts} - {source_label}\n")
            for kind in ("opinion", "tweet"):
                # .get() rather than [] so a defaultdict input is not mutated.
                for hid in sorted(tuples_dict[topic][ts].get(kind, [])):
                    lines.append(f"--[{kind}] {hid}\n")

    # Build everything first so a failure while sorting (e.g. mixed id types)
    # does not leave a truncated file behind. Use an explicit encoding so the
    # output does not depend on the platform locale.
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(lines)

def build_tweet_transition_matrix(df, likert_col):
    """Count Tweet-1 -> Tweet-3 stance transitions, one per participant.

    Only rows with ``event_type == "tweet"`` and ``chat_order`` 1 or 3 are
    considered. A ``human_id`` contributes one transition when it has both a
    chat-order-1 and a chat-order-3 row and both ``likert_col`` values are
    non-null. If several rows share a chat order, the first one (in ``df`` row
    order) is used. Values are converted with ``int()``, so a non-integral
    value such as 3.7 is truncated to 3. Pairs with a value outside 1-6 are
    not counted.

    Args:
        df: DataFrame with ``event_type``, ``chat_order``, ``human_id`` and
            ``likert_col`` columns.
        likert_col: Name of the column holding the stance value.

    Returns:
        A 6x6 ``int`` array where ``[a - 1, b - 1]`` counts participants whose
        Tweet-1 value was ``a`` and Tweet-3 value was ``b``.

    Raises:
        KeyError: If ``likert_col`` is not a column of ``df``. Previously this
            was swallowed per participant and produced an all-zero matrix.
    """
    if likert_col not in df.columns:
        raise KeyError(f"Column {likert_col!r} not found in DataFrame")

    df_tweets = df[(df["event_type"] == "tweet") & (df["chat_order"].isin([1, 3]))]

    matrix = np.zeros((6, 6), dtype=int)
    for human_id, group in df_tweets.groupby("human_id"):
        if set(group["chat_order"]) != {1, 3}:
            continue  # participant lacks Tweet 1 or Tweet 3
        try:
            val1 = group.loc[group["chat_order"] == 1, likert_col].iloc[0]
            val3 = group.loc[group["chat_order"] == 3, likert_col].iloc[0]
            if pd.isna(val1) or pd.isna(val3):
                continue
            start, end = int(val1), int(val3)
        except (TypeError, ValueError, OverflowError) as e:
            # Non-numeric or non-finite cell: skip this participant, keep going.
            print(f"[WARN] Skipping tweet pair for human_id={human_id}: {e}")
            continue
        if 1 <= start <= 6 and 1 <= end <= 6:
            matrix[start - 1, end - 1] += 1

    return matrix

def build_transition_matrix(df, likert_col):
    """Count Initial-Opinion -> Post-Opinion transitions, one per participant.

    Only rows whose ``event_type`` is ``"Initial Opinion"`` or
    ``"Post Opinion"`` are considered. A ``human_id`` contributes one
    transition when it has both event types and both ``likert_col`` values
    are non-null. If several rows share an event type, the first one (in
    ``df`` row order) is used. Values are converted with ``int()``, so a
    non-integral value such as 3.7 is truncated to 3. Pairs with a value
    outside 1-6 are not counted.

    Args:
        df: DataFrame with ``event_type``, ``human_id`` and ``likert_col``
            columns.
        likert_col: Name of the column holding the stance value.

    Returns:
        A 6x6 ``int`` array where ``[a - 1, b - 1]`` counts participants whose
        initial value was ``a`` and post value was ``b``.

    Raises:
        KeyError: If ``likert_col`` is not a column of ``df``. Previously this
            was swallowed per participant and produced an all-zero matrix.
    """
    if likert_col not in df.columns:
        raise KeyError(f"Column {likert_col!r} not found in DataFrame")

    df_filtered = df[df["event_type"].isin(["Initial Opinion", "Post Opinion"])]

    matrix = np.zeros((6, 6), dtype=int)
    for human_id, group in df_filtered.groupby("human_id"):
        if set(group["event_type"]) != {"Initial Opinion", "Post Opinion"}:
            continue  # participant lacks one of the two opinions
        try:
            init_val = group.loc[group["event_type"] == "Initial Opinion", likert_col].iloc[0]
            post_val = group.loc[group["event_type"] == "Post Opinion", likert_col].iloc[0]
            if pd.isna(init_val) or pd.isna(post_val):
                continue
            init, post = int(init_val), int(post_val)
        except (TypeError, ValueError, OverflowError) as e:
            # Non-numeric or non-finite cell: skip this participant, keep going.
            print(f"[WARN] Skipping human_id={human_id} due to error: {e}")
            continue
        if 1 <= init <= 6 and 1 <= post <= 6:
            matrix[init - 1, post - 1] += 1

    return matrix

# def plot_heatmap(matrix, topic, source, output_path, cmap="Blues", center_zero=False):
#     df = pd.DataFrame(matrix, index=LIKERT_VALUES, columns=LIKERT_VALUES)

#     plt.figure(figsize=(7, 6))
#     sns.heatmap(
#         df,
#         annot=True,
#         fmt=".0f" if center_zero else "d",
#         cmap=cmap,
#         cbar=True,
#         center=0 if center_zero else None
#     )

#     title = (
#         f"{source} − Human Transition Difference"
#         if center_zero else f"{source} Transition Count"
#     )
#     plt.title(f"{title}: {topic.replace('_', ' ')}")

#     if "tweet" in output_path.lower():
#         plt.ylabel("Tweet 1 Stance")
#         plt.xlabel("Tweet 3 Stance")
#     else:
#         plt.ylabel("Initial Classified Stance")
#         plt.xlabel("Final Classified Stance")

#     plt.tight_layout()
#     plt.savefig(output_path, format="svg")
#     print(f"[INFO] Saved heatmap to: {output_path}")
#     plt.close()


def plot_heatmap(matrix, topic, source, output_path, cmap="Blues", center_zero=False):
    """Render a Likert transition matrix as an annotated heatmap and save it as SVG.

    Args:
        matrix: 2-D integer array indexed ``[initial - 1][final - 1]``. Rows
            and columns are labelled with ``LIKERT_VALUES``.
        topic: Topic slug; underscores are shown as spaces in the title.
        source: Label for the data source (e.g. ``"human"`` or a model name).
        output_path: Destination ``.svg`` file. If its *file name* contains
            ``"tweet"`` (case-insensitive), the axes are labelled as Tweet 1 ->
            Tweet 3 stances. Otherwise they are labelled as initial -> final
            classified stances.
        cmap: Matplotlib colormap name.
        center_zero: Treat ``matrix`` as a signed difference. This centres the
            colour scale at 0 and titles the plot as a difference against human.
    """
    df = pd.DataFrame(matrix, index=LIKERT_VALUES, columns=LIKERT_VALUES)
    # Inspect only the file name. Directories above it (project root, topic,
    # model name) must not be able to flip the axis labels.
    is_tweet = "tweet" in os.path.basename(output_path).lower()

    fig = plt.figure(figsize=(7, 6))
    try:
        sns.heatmap(
            df,
            annot=True,
            fmt=".0f" if center_zero else "d",
            cmap=cmap,
            cbar=True,
            center=0 if center_zero else None
        )

        title = (
            f"{source} − Human Transition Difference"
            if center_zero else f"{source} Transition Count"
        )
        plt.title(f"{title}: {topic.replace('_', ' ')}")

        if is_tweet:
            plt.ylabel("Tweet 1 Stance")
            plt.xlabel("Tweet 3 Stance")
        else:
            plt.ylabel("Initial Classified Stance")
            plt.xlabel("Final Classified Stance")

        plt.tight_layout()
        plt.savefig(output_path, format="svg")
        print(f"[INFO] Saved heatmap to: {output_path}")
    finally:
        # Release the figure even if drawing or saving raises.
        plt.close(fig)


def plot_heatmap_slider(matrix, topic, source, output_path, cmap="Greens", center_zero=False):
    """Render a slider-value transition matrix as an annotated heatmap and save it as SVG.

    Same layout as the classified-stance heatmap, but with slider-specific
    titles and axis labels.

    Args:
        matrix: 2-D integer array indexed ``[initial - 1][final - 1]``. Rows
            and columns are labelled with ``LIKERT_VALUES``.
        topic: Topic slug; underscores are shown as spaces in the title.
        source: Label for the data source (e.g. ``"human"`` or a model name).
        output_path: Destination ``.svg`` file. If its *file name* contains
            ``"tweet"`` (case-insensitive), the axes are labelled as Tweet 1 ->
            Tweet 3 stances. Otherwise they are labelled as initial -> final
            slider.
        cmap: Matplotlib colormap name.
        center_zero: Treat ``matrix`` as a signed difference. This centres the
            colour scale at 0 and titles the plot as a difference against human.
    """
    df = pd.DataFrame(matrix, index=LIKERT_VALUES, columns=LIKERT_VALUES)
    # Inspect only the file name. Directories above it (project root, topic,
    # model name) must not be able to flip the axis labels.
    is_tweet = "tweet" in os.path.basename(output_path).lower()

    fig = plt.figure(figsize=(7, 6))
    try:
        sns.heatmap(
            df,
            annot=True,
            fmt=".0f" if center_zero else "d",
            cmap=cmap,
            cbar=True,
            center=0 if center_zero else None
        )

        title = (
            f"{source} − Human Slider Difference"
            if center_zero else f"{source} Slider Count"
        )
        plt.title(f"{title}: {topic.replace('_', ' ')}")

        if is_tweet:
            plt.ylabel("Tweet 1 Stance")
            plt.xlabel("Tweet 3 Stance")
        else:
            plt.ylabel("Initial Slider")
            plt.xlabel("Final Slider")

        plt.tight_layout()
        plt.savefig(output_path, format="svg")
        print(f"[INFO] Saved heatmap to: {output_path}")
    finally:
        # Release the figure even if drawing or saving raises.
        plt.close(fig)


# Model whose matrices are compared against the human baseline, both in the
# per-topic diff plots and in every aggregated heatmap.
REFERENCE_MODEL = "gpt-4o-mini-2024-07-18"


def _opinion_rows(df):
    """Return the rows of ``df`` holding an Initial or Post opinion."""
    return df[df["event_type"].isin(["Initial Opinion", "Post Opinion"])]


def _tweet_rows(df):
    """Return the tweet rows of ``df`` with chat_order 1 or 3."""
    return df[(df["event_type"] == "tweet") & (df["chat_order"].isin([1, 3]))]


def _rows_with_both_endpoints(df_subset, endpoint_col):
    """Keep rows whose ``human_id`` has exactly two distinct ``endpoint_col`` values."""
    counts = df_subset.groupby("human_id")[endpoint_col].nunique()
    return df_subset[df_subset["human_id"].isin(counts[counts == 2].index)]


def _record_complete_pairs(df_subset, endpoint_col, store, topic, kind):
    """Append to ``store[topic][time_stamp][kind]`` every ``human_id`` that has
    exactly two distinct ``endpoint_col`` values within one ``time_stamp``."""
    counts = df_subset.groupby(["time_stamp", "human_id"])[endpoint_col].nunique()
    # Iterate the MultiIndex directly. Unlike iterrows(), this does not upcast
    # ids to a common dtype (e.g. int ids to float when time_stamp is float).
    for ts, human_id in counts[counts == 2].index:
        store[topic][ts][kind].append(human_id)


if __name__ == "__main__":
    for topic in TOPICS:
        topic_dir = os.path.join(EVAL_PATH, topic)
        csv_path = os.path.join(topic_dir, "filtered_depth.csv")
        if not os.path.exists(csv_path):
            print(f"[WARN] Missing file for topic: {csv_path}")
            continue

        df_all = pd.read_csv(csv_path)
        df_opinion = _opinion_rows(df_all)
        df_tweet = _tweet_rows(df_all)

        # === COLLECT TRANSITION DATA (HUMAN + MODELS) ===
        # Rows of participants with both endpoints: Initial+Post opinion, or Tweet 1+3.
        df_transition_used = pd.concat([
            _rows_with_both_endpoints(df_opinion, "event_type"),
            _rows_with_both_endpoints(df_tweet, "chat_order"),
        ]).drop_duplicates(
            subset=["human_id", "event_type", "chat_order", "model_name"]
        ).copy()
        df_transition_used["source_topic"] = topic
        df_transition_used.to_csv(os.path.join(topic_dir, "transition_mat_data.csv"), index=False)
        all_transition_rows.append(df_transition_used)

        # === HUMAN ===
        print(f"\n=== Processing Human for Topic: {topic} ===")
        human_matrix = build_transition_matrix(df_all, "human_likert_pred")
        _record_complete_pairs(df_opinion, "event_type", all_human_tuples, topic, "opinion")
        _record_complete_pairs(df_tweet, "chat_order", all_human_tuples, topic, "tweet")

        print("Raw Transition Matrix (Human):")
        print(pd.DataFrame(human_matrix, index=LIKERT_VALUES, columns=LIKERT_VALUES))

        human_dir = os.path.join(topic_dir, "human")
        os.makedirs(human_dir, exist_ok=True)
        plot_heatmap(human_matrix, topic, "human",
                     os.path.join(human_dir, "transition_matrix.svg"), cmap="Blues")

        human_slider_matrix = build_transition_matrix(df_all, "human_slider")
        plot_heatmap_slider(human_slider_matrix, topic, "human",
                            os.path.join(human_dir, "transition_matrix_slider.svg"), cmap="Greens")

        # === MODELS ===
        model_names = df_all["model_name"].dropna().unique()
        model_matrix_map = {}
        model_slider_map = {}
        for model in model_names:
            print(f"\n=== Processing Model '{model}' for Topic: {topic} ===")
            df_model = df_all[df_all["model_name"] == model]
            model_matrix_map[model] = build_transition_matrix(df_model, "llm_likert_pred")

            print(f"Raw Transition Matrix (Model: {model}):")
            print(pd.DataFrame(model_matrix_map[model], index=LIKERT_VALUES, columns=LIKERT_VALUES))

            model_dir = os.path.join(topic_dir, model)
            os.makedirs(model_dir, exist_ok=True)
            plot_heatmap(model_matrix_map[model], topic, model,
                         os.path.join(model_dir, "transition_matrix.svg"), cmap="Reds")

            model_slider_map[model] = build_transition_matrix(df_model, "llm_slider")
            plot_heatmap_slider(model_slider_map[model], topic, model,
                                os.path.join(model_dir, "transition_matrix_slider.svg"), cmap="Greens")

        has_reference = REFERENCE_MODEL in model_matrix_map

        # === LLM TUPLES + DIFF (reference model only) ===
        if has_reference:
            df_ref = df_all[df_all["model_name"] == REFERENCE_MODEL]
            _record_complete_pairs(_opinion_rows(df_ref), "event_type", all_llm_tuples, topic, "opinion")
            _record_complete_pairs(_tweet_rows(df_ref), "chat_order", all_llm_tuples, topic, "tweet")

            diff_matrix = model_matrix_map[REFERENCE_MODEL] - human_matrix
            plot_heatmap(diff_matrix, topic, REFERENCE_MODEL,
                         os.path.join(topic_dir, "transition_matrix_diff.svg"),
                         cmap="RdBu_r", center_zero=True)

            # Accumulate human and model totals over the SAME set of topics, so
            # the aggregated diff heatmaps compare like with like. The human
            # totals used to include topics that lacked the reference model.
            AGG_HUMAN += human_matrix
            AGG_MODEL += model_matrix_map[REFERENCE_MODEL]
            AGG_HUMAN_SLIDER += human_slider_matrix
            AGG_MODEL_SLIDER += model_slider_map[REFERENCE_MODEL]

        # === TWEET TRANSITIONS ===
        print("\n--- Processing Tweet 1 ➜ 3 for Human ---")
        tweet_human_matrix = build_tweet_transition_matrix(df_all, "human_likert_pred")
        plot_heatmap(tweet_human_matrix, topic, "human (tweet)",
                     os.path.join(human_dir, "tweet_transition_matrix.svg"), cmap="Blues")

        tweet_human_slider_matrix = build_tweet_transition_matrix(df_all, "human_slider")
        plot_heatmap_slider(tweet_human_slider_matrix, topic, "human (tweet)",
                            os.path.join(human_dir, "tweet_transition_matrix_slider.svg"), cmap="Greens")

        tweet_model_matrix_map = {}
        tweet_model_slider_map = {}
        for model in model_names:
            df_model = df_all[df_all["model_name"] == model]
            model_dir = os.path.join(topic_dir, model)

            tweet_model_matrix_map[model] = build_tweet_transition_matrix(df_model, "llm_likert_pred")
            plot_heatmap(tweet_model_matrix_map[model], topic, f"{model} (tweet)",
                         os.path.join(model_dir, "tweet_transition_matrix.svg"), cmap="Reds")

            tweet_model_slider_map[model] = build_tweet_transition_matrix(df_model, "llm_slider")
            plot_heatmap_slider(tweet_model_slider_map[model], topic, f"{model} (tweet)",
                                os.path.join(model_dir, "tweet_transition_matrix_slider.svg"), cmap="Greens")

        # TWEET DIFF
        if has_reference:
            tweet_diff_matrix = tweet_model_matrix_map[REFERENCE_MODEL] - tweet_human_matrix
            plot_heatmap(tweet_diff_matrix, topic, "Tweet Diff",
                         os.path.join(topic_dir, "tweet_transition_matrix_diff.svg"),
                         cmap="RdBu_r", center_zero=True)

            # Same topic set for human and model aggregates (see above).
            AGG_TWEET_HUMAN += tweet_human_matrix
            AGG_TWEET_MODEL += tweet_model_matrix_map[REFERENCE_MODEL]
            AGG_TWEET_HUMAN_SLIDER += tweet_human_slider_matrix
            AGG_TWEET_MODEL_SLIDER += tweet_model_slider_map[REFERENCE_MODEL]

        # === Save tuple files per topic ===
        save_tuples_txt(
            {topic: all_human_tuples[topic]},
            os.path.join(topic_dir, "tuples_human_transition_mat.txt"),
            "Human"
        )
        save_tuples_txt(
            {topic: all_llm_tuples[topic]},
            os.path.join(topic_dir, "tuples_llm_transition_mat.txt"),
            REFERENCE_MODEL
        )

    # === AGGREGATED HEATMAPS ===
    # (banner, plot function, human total, model total, file prefix, topic label,
    #  human colormap, model colormap). Each spec yields <prefix>_human.svg,
    # <prefix>_<REFERENCE_MODEL>.svg and <prefix>_diff.svg under EVAL_PATH.
    aggregate_specs = [
        ("Transition", plot_heatmap, AGG_HUMAN, AGG_MODEL,
         "aggregated_transition", "Aggregated", "Blues", "Reds"),
        ("Tweet Transition", plot_heatmap, AGG_TWEET_HUMAN, AGG_TWEET_MODEL,
         "aggregated_tweet_transition", "Aggregated Tweets", "Blues", "Reds"),
        ("SLIDER Transition", plot_heatmap_slider, AGG_HUMAN_SLIDER, AGG_MODEL_SLIDER,
         "aggregated_transition_slider", "Aggregated", "Greens", "Greens"),
        ("SLIDER Tweet Transition", plot_heatmap_slider, AGG_TWEET_HUMAN_SLIDER, AGG_TWEET_MODEL_SLIDER,
         "aggregated_tweet_transition_slider", "Aggregated Tweets", "Greens", "Greens"),
    ]
    for banner, plot_fn, agg_human, agg_model, prefix, label, human_cmap, model_cmap in aggregate_specs:
        print(f"\n=== Saving Aggregated {banner} Matrices ===")
        plot_fn(agg_human, label, "human",
                os.path.join(EVAL_PATH, f"{prefix}_human.svg"), cmap=human_cmap)
        plot_fn(agg_model, label, REFERENCE_MODEL,
                os.path.join(EVAL_PATH, f"{prefix}_{REFERENCE_MODEL}.svg"), cmap=model_cmap)
        plot_fn(agg_model - agg_human, label, REFERENCE_MODEL,
                os.path.join(EVAL_PATH, f"{prefix}_diff.svg"), cmap="RdBu_r", center_zero=True)

    # === AGGREGATED TRANSITION DATA ACROSS ALL TOPICS ===
    if all_transition_rows:
        df_all_used = pd.concat(all_transition_rows, ignore_index=True)
        df_all_used.to_csv(os.path.join(EVAL_PATH, "transition_mat_data_all.csv"), index=False)
        print("[INFO] ✅ Saved full transition data to transition_mat_data_all.csv")

    save_tuples_txt(all_human_tuples, os.path.join(EVAL_PATH, "tuples_human_transition_mat_all.txt"), "Human")
    save_tuples_txt(all_llm_tuples, os.path.join(EVAL_PATH, "tuples_llm_transition_mat_all.txt"), REFERENCE_MODEL)
