#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Read experiment logs and produce one figure per noise distribution (three in the standard setup):
- x-axis: function names (exactly those used in run_experiments.py)
- For each x-tick: three side-by-side bars, one for each k ∈ {3, 6, 9}
- Bar height = mean(SHD), grey error bar = std(SHD)
- Saves PDFs to {out_dir}/{distribution}.pdf (out_dir defaults to "plots")

Also:
- Keeps only rows whose success flag equals 1, when a 'success' or 'Success' column is present.
- Uses one fixed color per k: golden-yellow (k=3), wine/magenta (k=6), blue (k=9).
"""

import os
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Display label for each generating function. The keys are the function names
# used in run_experiments.py, and their insertion order is the left-to-right
# order of the ticks on the x-axis.
PRETTY_LABELS = {
    "anm_c2_plus_e": "nonlinear anm",
    "pnl_cpluse_cubed": "pnl",
    "lsnm_c2_times_e": "lsnm",
    "linear_5c_plus_e": "linear anm",
    "c2e_plus_arctan_e": "arbitrary nonlinear (i)",
    "c2_arctan_e_plus_e3": "arbitrary nonlinear (ii)",
    "c2_plus_arctan_c_times_e_plus_c_e3": "arbitrary nonlinear (iii)",
}

# x-axis order of the functions: exactly the label table's key order.
FUNCTION_ORDER = list(PRETTY_LABELS)

# Numbers of environments compared; one bar per value at every tick, in this order.
K_ORDER = [3, 6, 9]

# Fill color for the bars of each k value.
BAR_COLORS = {
    3: "#d4a21f",  # golden-yellow
    6: "#9c144a",  # wine / magenta
    9: "#1e88e5",  # blue
}


# Axis titles for the known noise distributions; any other distribution name
# falls back to a generic "SHD with <name> noise" title.
_DIST_TITLES = {
    "gaussian": "SHD with Gaussian noise",
    "gamma_a1_theta2": r"SHD with Gamma noise, $\alpha \in [0.5, 1]$",
    "gamma_a2_theta2": r"SHD with Gamma noise, $\alpha \in [2, 2.5]$",
}


def _load_results(in_csv):
    """Load the experiment CSV, keep successful runs and validate the columns.

    Parameters
    ----------
    in_csv : str
        Path to the CSV written by run_experiments.py.

    Returns
    -------
    pandas.DataFrame
        Rows whose ``success`` (or, failing that, ``Success``) column equals 1;
        all rows if neither column exists. ``func_name`` is converted to an
        ordered categorical following ``FUNCTION_ORDER``.

    Raises
    ------
    ValueError
        If any column needed for plotting is missing.
    """
    df = pd.read_csv(in_csv)

    # Keep only successful runs; the flag column may be 'success' or 'Success'.
    for flag_col in ("success", "Success"):
        if flag_col in df.columns:
            df = df[df[flag_col] == 1]
            break

    required = {"func_name", "distribution", "k_envs", "seed", "shd"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"CSV is missing required columns: {missing}")

    # Ordered categorical -> consistent ordering of functions on the x-axis.
    # assign() returns a new frame, so we never write into the filtered slice
    # (which would trigger pandas' SettingWithCopyWarning).
    return df.assign(
        func_name=pd.Categorical(df["func_name"], categories=FUNCTION_ORDER, ordered=True)
    )


def _bar_layout(n_bars, group_span=0.66, fill=0.18 / 0.22):
    """Return ``(offsets, width)`` for ``n_bars`` side-by-side bars around a tick.

    Every bar gets an equal slot of ``group_span / n_bars`` and fills ``fill``
    of it. For three bars this gives offsets -0.22, 0, 0.22 and width 0.18.
    """
    slot = group_span / max(n_bars, 1)
    offsets = (np.arange(n_bars) - (n_bars - 1) / 2.0) * slot
    return offsets, fill * slot


def _plot_distribution(sub, dist, out_path, dpi):
    """Draw the grouped mean/std SHD bar chart for one distribution and save it.

    Parameters
    ----------
    sub : pandas.DataFrame
        Results rows belonging to this distribution only.
    dist : str
        Distribution name; picks the title (generic fallback if unknown).
    out_path : str
        Destination file; matplotlib infers the format from the extension.
    dpi : int
        Resolution passed to ``savefig``.
    """
    base_positions = np.arange(len(FUNCTION_ORDER)) + 1  # tick centres 1..N, one per function
    offsets, width = _bar_layout(len(K_ORDER))

    fig, ax = plt.subplots(figsize=(max(8, 1.2 * len(FUNCTION_ORDER)), 5))

    legend_handles = []
    for offset, kval in zip(offsets, K_ORDER):
        positions, heights, errors = [], [], []
        for center, func in zip(base_positions, FUNCTION_ORDER):
            mask = (sub["func_name"] == func) & (sub["k_envs"] == kval)
            values = sub.loc[mask, "shd"].dropna().astype(float).to_numpy()
            if values.size == 0:
                continue  # no data for this (function, k): leave a gap
            positions.append(center + offset)
            heights.append(values.mean())
            # Standard deviation (ddof=0), matching the "± std" y-label.
            # If SHD is 0/1-valued this equals sqrt(p * (1 - p)).
            errors.append(values.std())

        if not heights:
            continue  # nothing to draw for this k; keep it out of the legend

        color = BAR_COLORS.get(kval, "#888888")
        ax.bar(positions, heights, width=width, color=color, edgecolor=None, zorder=2)
        # Std-dev error bars, drawn as thin grey lines on top of the bars.
        ax.errorbar(
            positions,
            heights,
            yerr=errors,
            fmt="none",
            ecolor="#555555",
            elinewidth=1,
            capsize=0,
            alpha=0.7,
            zorder=3,
        )
        legend_handles.append(Patch(facecolor=color, label=f"k={kval}"))

    # X ticks & labels
    ax.set_xticks(base_positions)
    ax.set_xticklabels([PRETTY_LABELS.get(f, f) for f in FUNCTION_ORDER], rotation=30, ha="right")

    # Font sizes and title
    ax.tick_params(axis="y", labelsize=12)
    ax.tick_params(axis="x", labelsize=12)
    ax.set_ylabel("Mean SHD (± std)", fontsize=14)
    ax.set_title(_DIST_TITLES.get(dist, f"SHD with {dist} noise"), fontsize=14)

    if legend_handles:
        ax.legend(
            handles=legend_handles,
            loc="upper left",
            frameon=True,
            title="# environments",
            title_fontsize=12,
            fontsize=12,
        )

    ax.set_yticks(np.arange(0, 1.1, 0.2))  # ticks at 0, 0.2, ..., 1.0

    # Dashed grid behind the bars; no top/right borders.
    ax.yaxis.grid(True, linestyle="--", alpha=0.4)
    ax.set_axisbelow(True)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    fig.tight_layout()
    fig.savefig(out_path, dpi=dpi)
    plt.close(fig)


def main():
    """Command-line entry point: plot mean/std SHD bar charts per distribution.

    Arguments
    ---------
    --in_csv : path to experiments.csv produced by run_experiments.py (required).
    --out_dir : output directory for the figures (default ``plots``; created if missing).
    --dpi : resolution passed to ``savefig`` (default 220).

    Writes one ``{distribution}.pdf`` per distinct non-null value of the
    ``distribution`` column and prints each saved path.

    Raises
    ------
    ValueError
        If the CSV lacks a column required for plotting.
    """
    parser = argparse.ArgumentParser(description="Plot SHD mean/std as bar charts from experiment logs.")
    parser.add_argument("--in_csv", type=str, required=True, help="Path to experiments.csv produced by run_experiments.py")
    parser.add_argument("--out_dir", type=str, default="plots", help="Output directory for figures")
    parser.add_argument("--dpi", type=int, default=220)
    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    df = _load_results(args.in_csv)

    # Skip null distribution labels: they match no rows and would yield "nan.pdf".
    for dist in df["distribution"].dropna().unique():
        sub = df[df["distribution"] == dist]
        out_path = os.path.join(args.out_dir, f"{dist}.pdf")
        _plot_distribution(sub, dist, out_path, args.dpi)
        print(f"✓ Saved {out_path}")

    print("✓ All done.")


# Run the command-line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()
