import os
import pandas as pd
import numpy as np
import logging

# Logging: level name followed by the message, INFO and above
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)

# Filesystem layout, anchored two directories above this file
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
_GROUP_EVAL_DIR = os.path.join(PROJECT_ROOT, "result", "group_level_eval")
PREPROCESSED_PATH = os.path.join(_GROUP_EVAL_DIR, "preprocessed_depth_valid.csv")
ROUND_SPLIT_DIR = os.path.join(PROJECT_ROOT, "data", "finetune_data", "round_split_data")
OUTPUT_ROOT = os.path.join(_GROUP_EVAL_DIR, "round_split_depth")

# Metric name -> (value matched against "event_type", value matched against
# "chat_order" or None when chat_order is not filtered on)
EVENT_KEYS = dict(
    [
        ("Initial", ("Initial Opinion", None)),
        ("Post", ("Post Opinion", None)),
    ]
    + [(f"Tweet{position}", ("tweet", position)) for position in (1, 2, 3)]
)

def load_train_test_exp_dirs():
    """Collect the CSV base names found in the train and test split folders.

    Returns:
        A ``(train_exp_dirs, test_exp_dirs)`` tuple of sets. Each set holds the
        names, with the extension stripped, of the ``.csv`` entries listed
        directly inside ``ROUND_SPLIT_DIR/train`` and ``ROUND_SPLIT_DIR/test``.

    Raises:
        OSError: propagated from ``os.listdir`` if a split folder cannot be read.
    """
    def csv_stems(split_name):
        folder = os.path.join(ROUND_SPLIT_DIR, split_name)
        stems = set()
        for entry in os.listdir(folder):
            if entry.endswith(".csv"):
                stems.add(os.path.splitext(entry)[0])
        return stems

    return csv_stems("train"), csv_stems("test")

def normalize_column(series):
    """Coerce a column to numbers and shift every non-zero value down by 3.5.

    Values that are missing or cannot be parsed as numbers become 0, and any
    0 (original or filled in) is left at 0 instead of being shifted.
    NOTE(review): 3.5 is presumably the midpoint of the rating scale - confirm.
    The filled-in zeros stay in the column, so they take part in any mean or
    standard deviation computed on the result.

    Args:
        series: pandas Series of raw ratings (numbers, numeric strings, blanks).

    Returns:
        A float64 Series with the same index as ``series``.
    """
    # Cast to float up front so the result dtype does not depend on the data
    # (e.g. an all-zero or empty column).
    numeric = pd.to_numeric(series, errors="coerce").fillna(0).astype(float)
    # Vectorized: keep entries equal to 0, replace all others with value - 3.5.
    return numeric.where(numeric == 0, numeric - 3.5)

def calculate_bias_diversity(series):
    """Summarize a Series as a ``(bias, diversity)`` pair.

    Bias is the mean of the values; diversity is their population standard
    deviation (``ddof=0``).
    """
    bias = series.mean()
    diversity = series.std(ddof=0)
    return bias, diversity

def process_round_topic_stats(df_topic, topic):
    """Write one topic's rows and its bias/diversity table under OUTPUT_ROOT.

    Two files are written to ``OUTPUT_ROOT/<topic>``: ``round_split_<topic>.csv``
    (the rows of ``df_topic`` as given) and ``stats.csv`` (one row per
    source and ``type`` value). The "human" source reads ``human_likert_pred``
    over all rows; every distinct ``model_name`` is a further source that
    reads ``llm_likert_pred`` over that model's rows only.

    Args:
        df_topic: DataFrame with the columns ``model_name``, ``type``,
            ``event_type``, ``chat_order``, ``human_likert_pred`` and
            ``llm_likert_pred``.
        topic: Topic label; used as the sub-directory name and in the
            ``Topic`` column.

    Returns:
        The statistics DataFrame that was written to ``stats.csv``.
    """
    topic_dir = os.path.join(OUTPUT_ROOT, topic)
    os.makedirs(topic_dir, exist_ok=True)
    df_topic.to_csv(os.path.join(topic_dir, f"round_split_{topic}.csv"), index=False)

    sources = [("human_likert_pred", "human")]
    sources.extend(("llm_likert_pred", model) for model in df_topic["model_name"].unique())

    def event_values(frame, column, etype, order):
        # Rows of the requested event (and chat position, when one is given).
        mask = frame["event_type"] == etype
        if order is not None:
            mask = mask & (frame["chat_order"] == order)
        return frame[mask][column]

    stat_rows = []
    for column, who in sources:
        scoped = df_topic if who == "human" else df_topic[df_topic["model_name"] == who]

        for split in scoped["type"].unique():
            split_rows = scoped[scoped["type"] == split]
            stats = {
                key: calculate_bias_diversity(event_values(split_rows, column, etype, order))
                for key, (etype, order) in EVENT_KEYS.items()
            }
            init_bias, init_div = stats["Initial"]
            post_bias, post_div = stats["Post"]
            t1_bias, t1_div = stats["Tweet1"]
            t2_bias, t2_div = stats["Tweet2"]
            t3_bias, t3_div = stats["Tweet3"]

            stat_rows.append({
                "Topic": topic,
                "Human/LLM": who,
                "Type": split,
                "Initial Opinion - Slider Bias": init_bias,
                "Initial Opinion - Slider Diversity": init_div,
                "Post Opinion - Slider Bias": post_bias,
                "Post Opinion - Slider Diversity": post_div,
                "Post Opinion (Slider) Bias - Initial Opinion (Slider) Bias": post_bias - init_bias,
                "Post Opinion (Slider) Diversity - Initial Opinion (Slider) Diversity": post_div - init_div,
                "Tweet 1 Bias": t1_bias,
                "Tweet 1 Diversity": t1_div,
                "Tweet 2 Bias": t2_bias,
                "Tweet 2 Diversity": t2_div,
                "Tweet 3 Bias": t3_bias,
                "Tweet 3 Diversity": t3_div,
                "Tweet 3 Bias - Tweet 1 Bias": t3_bias - t1_bias,
                "Tweet 3 Diversity - Tweet 1 Diversity": t3_div - t1_div,
            })

    stats_df = pd.DataFrame(stat_rows)
    stats_df.to_csv(os.path.join(topic_dir, "stats.csv"), index=False)
    return stats_df

def _likert_sources(frame, model_names):
    """Yield ``(target column, label, rows)``: humans over all rows, then each model over its own rows."""
    yield "human_likert_pred", "human", frame
    for model in model_names:
        yield "llm_likert_pred", model, frame[frame["model_name"] == model]

def _event_metrics(frame, target):
    """Return ``{EVENT_KEYS name: (bias, diversity)}`` of ``target`` for each event selection."""
    metrics = {}
    for name, (etype, order) in EVENT_KEYS.items():
        selected = frame[frame["event_type"] == etype]
        if order is not None:
            selected = selected[selected["chat_order"] == order]
        metrics[name] = calculate_bias_diversity(selected[target])
    return metrics

def _build_stat_row(frame, target, topic, label, split_type):
    """Build one summary row (identifying columns plus all bias/diversity columns) for ``frame``."""
    metrics = _event_metrics(frame, target)
    return {
        "Topic": topic,
        "Human/LLM": label,
        "Type": split_type,
        "Initial Opinion - Slider Bias": metrics["Initial"][0],
        "Initial Opinion - Slider Diversity": metrics["Initial"][1],
        "Post Opinion - Slider Bias": metrics["Post"][0],
        "Post Opinion - Slider Diversity": metrics["Post"][1],
        "Post Opinion (Slider) Bias - Initial Opinion (Slider) Bias": metrics["Post"][0] - metrics["Initial"][0],
        "Post Opinion (Slider) Diversity - Initial Opinion (Slider) Diversity": metrics["Post"][1] - metrics["Initial"][1],
        "Tweet 1 Bias": metrics["Tweet1"][0],
        "Tweet 1 Diversity": metrics["Tweet1"][1],
        "Tweet 2 Bias": metrics["Tweet2"][0],
        "Tweet 2 Diversity": metrics["Tweet2"][1],
        "Tweet 3 Bias": metrics["Tweet3"][0],
        "Tweet 3 Diversity": metrics["Tweet3"][1],
        "Tweet 3 Bias - Tweet 1 Bias": metrics["Tweet3"][0] - metrics["Tweet1"][0],
        "Tweet 3 Diversity - Tweet 1 Diversity": metrics["Tweet3"][1] - metrics["Tweet1"][1],
    }

def round_split_analysis():
    """Run the round-split bias/diversity analysis and write all result files.

    Steps:
      1. Read ``PREPROCESSED_PATH``; keep rows whose ``ft_type`` is "round" or
         "base" and whose ``llm_likert_pred`` is neither missing nor empty.
      2. Normalize ``human_likert_pred`` and ``llm_likert_pred`` with
         ``normalize_column``.
      3. Tag each row's ``type`` as "train", "test" or "unknown" depending on
         which split set (if any) contains its ``exp_dir``.
      4. Save the resulting rows to ``OUTPUT_ROOT/round_split.csv``.
      5. Write per-topic files via ``process_round_topic_stats`` and save the
         combined table to ``OUTPUT_ROOT/stats_summary.csv``. That table holds,
         in order: per-topic/per-type rows, "All"-topic per-type rows,
         per-topic "train+test" rows, and "All"-topic "train+test" rows.

    Returns:
        None. If the preprocessed file is missing an error is logged and
        nothing is written; if no rows survive filtering only
        ``round_split.csv`` is written and a warning is logged.
    """
    if not os.path.exists(PREPROCESSED_PATH):
        logging.error(f"Preprocessed file not found: {PREPROCESSED_PATH}")
        return

    df = pd.read_csv(PREPROCESSED_PATH)
    keep = (
        df["ft_type"].isin(["round", "base"])
        & df["llm_likert_pred"].notna()
        & (df["llm_likert_pred"] != "")
    )
    # Copy after the *final* filter so the column assignments below operate on
    # an independent frame rather than on a slice of another one.
    df = df[keep].copy()

    df["human_likert_pred"] = normalize_column(df["human_likert_pred"])
    df["llm_likert_pred"] = normalize_column(df["llm_likert_pred"])

    train_exp_dirs, test_exp_dirs = load_train_test_exp_dirs()

    def infer_type(exp_dir):
        if exp_dir in train_exp_dirs:
            return "train"
        if exp_dir in test_exp_dirs:
            return "test"
        return "unknown"

    df["type"] = df["exp_dir"].apply(infer_type)

    # Save filtered + normalized flat data for round split
    os.makedirs(OUTPUT_ROOT, exist_ok=True)
    flat_output_path = os.path.join(OUTPUT_ROOT, "round_split.csv")
    df.to_csv(flat_output_path, index=False)
    logging.info(f"Saved full filtered data to {flat_output_path}")

    if df.empty:
        # Without this guard pd.concat below would raise on an empty list.
        logging.warning("No rows left after filtering; skipping summary statistics")
        return

    topics = df["topic"].unique()
    model_names = df["model_name"].unique()

    # Per-topic, per-type statistics (also writes the per-topic files).
    stat_frames = []
    for topic in topics:
        stat_frames.append(process_round_topic_stats(df[df["topic"] == topic], topic))

    # Aggregated 'All' rows, one per source and type.
    agg_rows = []
    for target, label, rows in _likert_sources(df, model_names):
        for split_type in rows["type"].unique():
            split_rows = rows[rows["type"] == split_type]
            agg_rows.append(_build_stat_row(split_rows, target, "All", label, split_type))

    # "train+test" rows: every topic first, then the 'All' aggregation.
    combined_rows = []
    for topic in topics:
        df_topic = df[df["topic"] == topic]
        for target, label, rows in _likert_sources(df_topic, model_names):
            combined_rows.append(_build_stat_row(rows, target, topic, label, "train+test"))
    for target, label, rows in _likert_sources(df, model_names):
        combined_rows.append(_build_stat_row(rows, target, "All", label, "train+test"))

    final_stats = pd.concat(
        stat_frames + [pd.DataFrame(agg_rows), pd.DataFrame(combined_rows)],
        ignore_index=True,
    )
    final_stats.to_csv(os.path.join(OUTPUT_ROOT, "stats_summary.csv"), index=False)
    logging.info(f"Saved round split summary with {len(final_stats)} rows to stats_summary.csv")

# Run the full round-split analysis when this file is executed as a script.
if __name__ == "__main__":
    round_split_analysis()
