#!/usr/bin/env python3
import argparse
import os
from math import pi

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def load_scenario_data(dataset_name, models, scenarios, results_dir="experiments/plot_data/scenario"):
    """Load per-model scenario CSVs and average every numeric column per scenario.

    For each model the file ``scenario_{dataset_name}_{model}.csv`` inside
    ``results_dir`` is read, grouped by its ``scenario`` column and reduced
    with a numeric mean. One output row is produced per (model, scenario)
    pair for which data exists.

    Args:
        dataset_name: Dataset identifier used in the CSV file names.
        models: Iterable of model names; one CSV is looked up per model.
        scenarios: Scenario names to keep from each file.
        results_dir: Directory that contains the CSV files.

    Returns:
        A DataFrame with the averaged numeric columns plus ``model`` and
        ``scenario`` columns.

    Raises:
        ValueError: If no (model, scenario) row could be loaded at all.
    """
    data_rows = []
    for model in models:
        filename = os.path.join(results_dir, f"scenario_{dataset_name}_{model}.csv")
        if not os.path.exists(filename):
            print(f"Missing file: {filename}")
            continue
        try:
            df = pd.read_csv(filename)
        except pd.errors.EmptyDataError:
            # A zero-byte (or header-less) file makes read_csv raise instead of
            # returning an empty frame; treat it like any other empty file so
            # one broken result does not abort the whole run.
            print(f"Warning: empty file {filename}")
            continue
        if df.empty:
            # Header present but no data rows.
            print(f"Warning: empty file {filename}")
            continue
        if "scenario" not in df.columns:
            print(f"Warning: 'scenario' column not found in {filename}")
            continue
        # Average repeated runs of the same scenario; non-numeric columns are dropped.
        grouped = df.groupby("scenario").mean(numeric_only=True)
        for scenario in scenarios:
            if scenario not in grouped.index:
                print(f"Warning: no data for '{scenario}' in {filename}")
                continue
            row = grouped.loc[scenario].to_dict()
            row["model"] = model
            row["scenario"] = scenario
            data_rows.append(row)
        print(f"Loaded and processed: {filename}")
    if not data_rows:
        raise ValueError("No data could be loaded! Please run experiments first.")
    return pd.DataFrame(data_rows)


def find_metric_column(dataframe, metric_name):
    """Return the first column of ``dataframe`` whose name equals ``metric_name`` ignoring case.

    Raises:
        KeyError: If no column matches.
    """
    hits = list(filter(lambda label: label.lower() == metric_name.lower(), dataframe.columns))
    if hits:
        return hits[0]
    raise KeyError(f"Metric '{metric_name}' not found in columns: {list(dataframe.columns)}")


def normalize_values(values, global_min, global_max, invert=False):
    """Min-max scale ``values`` into [0, 1] using a shared (global) range.

    NaN entries are replaced by the mean of the non-NaN entries before
    scaling. When every entry is NaN, or the range is degenerate
    (``global_max - global_min <= 1e-12``), every output is 0.5.

    Args:
        values: Array-like of numbers; it is never modified.
        global_min: Lower bound of the scaling range.
        global_max: Upper bound of the scaling range.
        invert: If true, return ``1 - normalized`` so that smaller raw values
            map to larger outputs.

    Returns:
        A new float ndarray, clipped to [0, 1], with the same shape as the input.
    """
    # np.array (unlike np.asarray) always copies, so the NaN filling below
    # never writes into an ndarray owned by the caller.
    arr = np.array(values, dtype=float)
    missing = np.isnan(arr)
    span = global_max - global_min
    if missing.all() or span <= 1e-12:
        # Nothing usable to scale: place everything at mid-scale.
        normalized = np.full_like(arr, 0.5)
    else:
        if missing.any():
            arr[missing] = arr[~missing].mean()
        normalized = (arr - global_min) / span
    normalized = np.clip(normalized, 0.0, 1.0)
    if invert:
        normalized = 1 - normalized
    return normalized


def polygon_area(values):
    """Return the area of the radar polygon whose radii are ``values``.

    The radii are placed at equally spaced angles starting at 0 and the
    enclosed area is computed with the shoelace formula.
    """
    theta = np.linspace(0, 2 * pi, len(values), endpoint=False)
    xs, ys = values * np.cos(theta), values * np.sin(theta)
    # Shoelace: sum of cross products between each vertex and its successor.
    forward = np.dot(xs, np.roll(ys, -1))
    backward = np.dot(ys, np.roll(xs, -1))
    return 0.5 * np.abs(forward - backward)


def collect_metric_values(dataframe, model_name, metric_column, scenarios):
    """Return ``metric_column`` for ``model_name``, one entry per scenario, as a float array.

    The output follows the order of ``scenarios``. A scenario with no row for
    the model yields NaN; if several rows match, the first one is used.
    """
    rows_for_model = dataframe[dataframe["model"] == model_name]

    def first_value(scenario):
        hits = rows_for_model[rows_for_model["scenario"] == scenario]
        return np.nan if hits.empty else hits.iloc[0][metric_column]

    return np.array([first_value(name) for name in scenarios], dtype=float)


def sanitize_raw_metric(values):
    """Return ``values`` as a NaN-free float array clipped to [0, 1].

    NaN entries are replaced by the mean of the non-NaN entries; when every
    entry is NaN the result is filled with 0.5. The input is never modified.
    """
    # np.array always copies, so NaN filling cannot leak into the caller's array.
    arr = np.array(values, dtype=float)
    missing = np.isnan(arr)
    if missing.all():
        return np.full_like(arr, 0.5)
    if missing.any():
        arr[missing] = arr[~missing].mean()
    return np.clip(arr, 0.0, 1.0)


def normalize_with_max(values, max_value):
    """Map ``values`` to ``1 - value / max_value`` after clipping to [0, max_value].

    A raw value of 0 maps to 1.0 and anything at or above ``max_value`` maps
    to 0.0. NaN entries are replaced by the mean of the non-NaN entries
    (computed before clipping); when every entry is NaN they are all treated
    as ``max_value`` and therefore map to 0.0. The input is never modified.

    Args:
        values: Array-like of numbers.
        max_value: Positive upper bound used for clipping and scaling.

    Returns:
        A new float ndarray with values in [0, 1].

    Raises:
        ValueError: If ``max_value`` is not strictly positive (the division
            would otherwise silently produce NaN/inf).
    """
    if not max_value > 0:
        raise ValueError(f"max_value must be positive, got {max_value!r}")
    # np.array always copies, so NaN filling cannot leak into the caller's array.
    arr = np.array(values, dtype=float)
    missing = np.isnan(arr)
    if missing.all():
        arr = np.full_like(arr, max_value)
    elif missing.any():
        arr[missing] = arr[~missing].mean()
    clipped = np.clip(arr, 0.0, max_value)
    return 1.0 - (clipped / max_value)


def build_radial_ticks():
    """Return the radial tick positions (quarter steps up to 1.0) and their two-decimal labels."""
    positions = [step / 4 for step in range(1, 5)]
    labels = [f"{position:.2f}" for position in positions]
    return positions, labels


def draw_model_pair(ax, dataframe, metric_name, metric_column, scenarios, base_model, pep_model,
                    display_names, invert_metric, color_map, scenario_display_names,
                    global_min, global_max):
    """Draw one radar chart comparing a backbone model with its PEP variant.

    Scaling depends on the metric name (case-insensitive):
    ``f1`` is only sanitized/clipped to [0, 1]; ``shd`` and ``sid`` are scaled
    and inverted against fixed caps of 40 and 90; any other metric is min-max
    scaled with ``global_min``/``global_max`` and inverted when
    ``invert_metric`` is true.

    Args:
        ax: Polar axes to draw on.
        dataframe: Frame with ``model``, ``scenario`` and metric columns.
        metric_name: Logical metric name used to pick the scaling rule.
        metric_column: Actual column name holding the metric.
        scenarios: Scenario names, one radar spoke each.
        base_model: Name of the backbone model.
        pep_model: Name of the PEP counterpart.
        display_names: Mapping from model name to legend label.
        invert_metric: Whether to invert the generic min-max scaling.
        color_map: Mapping from model name to line/fill colour.
        scenario_display_names: Mapping from scenario name to spoke label;
            missing scenarios fall back to the upper-cased name.
        global_min: Lower bound for the generic min-max scaling.
        global_max: Upper bound for the generic min-max scaling.

    Returns:
        Relative change of the PEP polygon area over the backbone polygon
        area in percent, or NaN when the backbone area is at most 1e-8.
    """
    raw_base = collect_metric_values(dataframe, base_model, metric_column, scenarios)
    raw_pep = collect_metric_values(dataframe, pep_model, metric_column, scenarios)

    fixed_caps = {"shd": 40.0, "sid": 90.0}
    metric_key = metric_name.lower()

    def rescale(raw):
        # Pick the scaling rule that belongs to this metric.
        if metric_key == "f1":
            return sanitize_raw_metric(raw)
        if metric_key in fixed_caps:
            return normalize_with_max(raw, fixed_caps[metric_key])
        return normalize_values(raw, global_min, global_max, invert=invert_metric)

    base_norm = rescale(raw_base)
    pep_norm = rescale(raw_pep)

    axis_count = len(scenarios)
    spoke_angles = [idx / float(axis_count) * 2 * pi for idx in range(axis_count)]
    # Repeat the first angle/value so the outline closes on itself.
    closed_angles = spoke_angles + spoke_angles[:1]

    for model, scaled in ((base_model, base_norm), (pep_model, pep_norm)):
        ring = scaled.tolist()
        ring += ring[:1]
        ax.plot(closed_angles, ring, linewidth=2, linestyle="-", label=display_names[model],
                color=color_map[model], marker="o", markersize=4)
        ax.fill(closed_angles, ring, alpha=0.08, color=color_map[model])

    ax.set_xticks(spoke_angles)
    ax.set_xticklabels([scenario_display_names.get(name, name.upper()) for name in scenarios], size=16)
    ax.set_rlabel_position(235)

    radial_positions, radial_labels = build_radial_ticks()
    ax.set_yticks(radial_positions)
    ax.set_yticklabels(radial_labels, size=14)
    ax.set_ylim(0, 1.0)

    base_area = polygon_area(base_norm)
    pep_area = polygon_area(pep_norm)
    # A (near-)zero backbone area makes the relative change meaningless.
    if base_area <= 1e-8:
        improvement = np.nan
    else:
        improvement = (pep_area - base_area) / base_area * 100.0

    if np.isnan(improvement):
        area_caption = "Delta Area N/A"
    else:
        area_caption = rf"$\Delta$ Area {improvement:+.1f}%"

    ax.text(0.16, 1.09, area_caption, transform=ax.transAxes,
            ha="left", va="center", fontsize=18, fontstyle="italic", fontweight='bold', color='red')

    return improvement


def main():
    """Render a grid of spider plots (metrics x backbones) for one dataset.

    Command-line arguments:
        --dataset: Dataset name used to locate the scenario CSV files (required).
        --output_dir: Directory the PNG is written to (created if missing).
        --results_dir: Directory that contains the scenario CSV files.

    Each subplot compares a backbone model with its PEP counterpart across
    the configured scenarios. The figure is saved as
    ``{dataset}_six_backbones_spider.png`` inside ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description="Generate 3x6 spider plots (backbone vs PEP) across metrics")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Dataset name (e.g., sachs, SynER4, SynSF4)")
    parser.add_argument("--output_dir", type=str, default="fig/scenarios",
                        help="Output directory for plots")
    parser.add_argument("--results_dir", type=str, default="experiments/plot_data/scenario",
                        help="Directory containing scenario CSV files")
    args = parser.parse_args()

    backbones = ["CAM", "SCORE", "DAS", "NoGAM", "DiffAN", "CaPS"]
    pep_counterparts = {
        "CAM": "CAM-P2",
        "SCORE": "SCORE-P2",
        "DAS": "DAS-P2",
        "NoGAM": "NoGAM-P2",
        "DiffAN": "DiffAN-P2",
        "CaPS": "OURS",
    }

    # Interleave every backbone with its PEP variant: [base, pep, base, pep, ...].
    models = [name for base in backbones for name in (base, pep_counterparts[base])]

    display_names = {
        "CAM": "CAM",
        "CAM-P2": "CAM w/ PEP",
        "SCORE": "SCORE",
        "SCORE-P2": "SCORE w/ PEP",
        "DAS": "DAS",
        "DAS-P2": "DAS w/ PEP",
        "NoGAM": "NoGAM",
        "NoGAM-P2": "NoGAM w/ PEP",
        "DiffAN": "DiffAN",
        "DiffAN-P2": "DiffAN w/ PEP",
        "CaPS": "CaPS",
        "OURS": "CaPS w/ PEP",
    }

    scenario_display_names = {
        "pnl": "PNL",
        "lingam": "LiNGAM",
        "confounded": "Confound",
        "measure_err": "M-Err",
        "timino": "Non-i.i.d",
        "unfaithful": "Unfaithful",
    }

    # NOTE: the "vanilla" scenario is not plotted; only these six form the radar spokes.
    scenarios = ["pnl", "lingam", "confounded", "measure_err", "timino", "unfaithful"]
    metrics = ["shd", "sid", "F1"]
    lower_is_better_metrics = {"shd", "sid"}
    metric_row_labels = {
        "shd": "SHD (normalized & inverted)",
        "sid": "SID (normalized & inverted)",
        "F1": "F1",
    }

    print(f"Generating 3x6 spider plots for {args.dataset}")
    data_df = load_scenario_data(args.dataset, models, scenarios, args.results_dir)
    os.makedirs(args.output_dir, exist_ok=True)

    metric_columns = {metric: find_metric_column(data_df, metric) for metric in metrics}

    # SHD and SID use fixed ranges; every other metric gets a data-driven range.
    fixed_ranges = {"shd": (0.0, 40.0), "sid": (0.0, 90.0)}
    metric_ranges = {}
    for metric, column in metric_columns.items():
        if metric.lower() in fixed_ranges:
            metric_ranges[metric] = fixed_ranges[metric.lower()]
            continue
        values = data_df[column].to_numpy(dtype=float)
        values = values[~np.isnan(values)]
        if values.size == 0:
            metric_ranges[metric] = (0.0, 1.0)
            continue
        metric_min = values.min()
        metric_max = values.max()
        if np.isclose(metric_max - metric_min, 0.0):
            # Degenerate range: widen it so later scaling has a non-zero span.
            if metric_min == 0.0:
                metric_max = 1.0
            else:
                # Scaling by 0.9/1.1 swaps the order for negative values, so
                # sort the two padded bounds to keep min <= max.
                padded_low, padded_high = metric_min * 0.9, metric_max * 1.1
                metric_min = min(padded_low, padded_high)
                metric_max = max(padded_low, padded_high)
        metric_ranges[metric] = (metric_min, metric_max)

    colors_base = sns.color_palette("tab20c", 20)
    colors = sns.color_palette("Paired", 12)
    # Each backbone and its PEP variant take neighbouring palette entries.
    color_map = {
        "CAM": colors_base[19],
        "CAM-P2": colors_base[17],  # Match CAM hue
        "SCORE": colors[6],
        "SCORE-P2": colors[7],  # Match SCORE hue
        "DAS": colors[0],
        "DAS-P2": colors[1],  # Match DAS hue
        "NoGAM": colors[2],
        "NoGAM-P2": colors[3],  # Match NoGAM hue
        "DiffAN": colors[8],
        "DiffAN-P2": colors[9],  # Match DiffAN hue
        "CaPS": colors[4],
        "OURS": colors[5],  # Match CaPS hue
    }

    fig, axes = plt.subplots(len(metrics), len(backbones), figsize=(24, 12), subplot_kw=dict(polar=True))

    for row, metric in enumerate(metrics):
        metric_column = metric_columns[metric]
        global_min, global_max = metric_ranges[metric]
        invert_flag = metric.lower() in lower_is_better_metrics
        for col, backbone in enumerate(backbones):
            ax = axes[row, col]
            pep_model = pep_counterparts[backbone]
            draw_model_pair(
                ax,
                data_df,
                metric,
                metric_column,
                scenarios,
                backbone,
                pep_model,
                display_names,
                invert_flag,
                color_map,
                scenario_display_names,
                global_min,
                global_max,
            )
            # Column captions (backbone names) go under the bottom row only.
            if row == len(metrics) - 1:
                ax.set_xlabel(display_names[backbone], fontsize=18, labelpad=15, fontweight="bold")
            # Row captions (metric names) go left of the first column only.
            if col == 0:
                ax.text(-0.2, 0.5, metric_row_labels[metric], transform=ax.transAxes,
                        rotation=90, va="center", ha="center", fontsize=18, fontweight="bold")

    # No figure-level legend is drawn; columns are captioned with the backbone name instead.
    plt.tight_layout(rect=[0, 0.02, 1, 0.98])
    output_path = os.path.join(args.output_dir, f"{args.dataset}_six_backbones_spider.png")
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved: {output_path}")


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
