#!/usr/bin/env python3
import argparse
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt


# Branch names, in the order they are drawn and listed in the legend.
BRANCHES = ["ours", "naive", "aaai"]

# One matplotlib colour-cycle alias ("C0", "C1", ...) per branch, in branch order.
COLORS = {name: f"C{position}" for position, name in enumerate(BRANCHES)}

# Column-name prefixes accepted for each branch, tried in list order.
# Every branch matches its own name; "ours" additionally accepts the prefixes
# used by older CSVs (base_* / baseline_*) for backward compatibility.
ALIASES = {name: [name] for name in BRANCHES}
ALIASES["ours"] += ["base", "baseline"]


def find_mean_std(df: pd.DataFrame, branch: str, metric: str):
    """Locate the ``<prefix>_<metric>_mean`` / ``<prefix>_<metric>_std`` column pair.

    The prefixes tried are the entries of ``ALIASES[branch]`` in order, or just
    ``branch`` itself when it has no alias entry. The first prefix for which
    *both* columns exist in ``df`` wins.

    Returns:
        ``(mean_col, std_col)`` column names, or ``(None, None)`` when no
        prefix has both columns.
    """
    available = df.columns
    candidates = ALIASES.get(branch, [branch])
    for prefix in candidates:
        pair = (f"{prefix}_{metric}_mean", f"{prefix}_{metric}_std")
        # A prefix only counts when the mean and the std column are both there.
        if all(name in available for name in pair):
            return pair
    return None, None


def compute_outcome_mean_std(df: pd.DataFrame, branch: str):
    """Derive the outcome ratio and its standard deviation for one branch.

    The outcome is ``ones / (ones + zeros)``, evaluated row by row on the
    branch's ``ones``/``zeros`` mean columns. The standard deviation comes
    from first-order error propagation that ignores the covariance between
    ones and zeros:

        Var(R) ~= (Z / T^2)^2 * Var(O) + (O / T^2)^2 * Var(Z),   T = O + Z

    Rows whose total is exactly zero use 1e-9 as the denominator instead, so
    the division never hits zero.

    Returns:
        ``(outcome_mean, outcome_std)`` as float arrays with one entry per
        row, or ``(None, None)`` if any of the four required columns is
        missing.
    """
    columns = (
        *find_mean_std(df, branch, "ones"),
        *find_mean_std(df, branch, "zeros"),
    )
    if any(name is None for name in columns):
        return None, None
    ones_mean_col, ones_std_col, zeros_mean_col, zeros_std_col = columns

    def as_floats(column):
        # Column values as a float array.
        return df[column].astype(float).values

    ones = as_floats(ones_mean_col)
    zeros = as_floats(zeros_mean_col)

    # Guard the denominator: swap exact zeros for a tiny positive value.
    denom = (ones + zeros).copy()
    denom[denom == 0] = 1e-9

    ratio = ones / denom

    var_ones = as_floats(ones_std_col) ** 2
    var_zeros = as_floats(zeros_std_col) ** 2

    # Partial derivatives of R = O / (O + Z) with respect to O and Z.
    denom_sq = denom ** 2
    grad_ones = zeros / denom_sq
    grad_zeros = -ones / denom_sq

    ratio_var = (grad_ones ** 2) * var_ones + (grad_zeros ** 2) * var_zeros
    return ratio, ratio_var ** 0.5


def compute_cost_mean_std(df: pd.DataFrame, branch: str):
    """Fetch the cost mean and std columns of a branch as float arrays.

    Returns:
        ``(mean, std)`` arrays, or ``(None, None)`` when the branch has no
        matching pair of cost columns in ``df``.
    """
    columns = find_mean_std(df, branch, "cost")
    if columns[0] is None:
        return None, None
    mean_values, std_values = (df[name].astype(float).values for name in columns)
    return mean_values, std_values


def plot_outcome(csv_path: Path, out_path: Path | None):
    """Plot the outcome ratio (mean with a ±1 std band) per branch against ``step``.

    Reads ``csv_path``, computes outcome = ones / (ones + zeros) for every
    branch in ``BRANCHES`` whose columns are present, and draws one line plus
    a shaded band per branch, coloured according to ``COLORS``.

    Args:
        csv_path: CSV file to read; it must contain a ``step`` column.
        out_path: Image path to write (missing parent directories are
            created). If ``None``, the figure is displayed with
            ``plt.show()`` instead of being saved.

    Raises:
        ValueError: If the CSV has no ``step`` column, or if no branch has
            the columns required to compute the outcome.
    """
    df = pd.read_csv(csv_path)
    if "step" not in df.columns:
        raise ValueError("CSV must contain a 'step' column")

    x = df["step"].values

    # Collect the drawable series first, so that a CSV without usable columns
    # fails before any figure is created (no orphaned empty figure).
    series = []
    for branch in BRANCHES:
        y, ystd = compute_outcome_mean_std(df, branch)
        if y is not None:
            series.append((branch, y, ystd))

    if not series:
        available = ", ".join(df.columns)
        raise ValueError(
            f"No matching columns found to compute outcome. Available columns: {available}"
        )

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for branch, y, ystd in series:
            color = COLORS.get(branch)
            ax.plot(x, y, label=branch, color=color)
            ax.fill_between(x, y - ystd, y + ystd, color=color, alpha=0.2)

        ax.set_title(f"Simulation results: outcome (mean ± std)\n{csv_path.name}")
        ax.set_xlabel("step")
        ax.set_ylabel("outcome = ones / (ones + zeros)")
        ax.set_ylim(-0.05, 1.05)
        ax.legend()
        ax.grid(True, linestyle=":", alpha=0.4)
        fig.tight_layout()

        if out_path is not None:
            out_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(out_path, dpi=150)
        else:
            plt.show()
    finally:
        # pyplot keeps every figure alive until it is closed explicitly. A
        # saved figure is no longer needed, and leaving it open would leak it
        # and make it pop up in any later plt.show(). In show mode the figure
        # is left to the GUI, which may still be displaying it.
        if out_path is not None:
            plt.close(fig)


def plot_cost(csv_path: Path, out_path: Path | None):
    """Plot the cost metric (mean with a ±1 std band) per branch against ``step``.

    Reads ``csv_path`` and, for every branch in ``BRANCHES`` that has cost
    mean/std columns, draws one line plus a shaded band, coloured according
    to ``COLORS``.

    Args:
        csv_path: CSV file to read; it must contain a ``step`` column.
        out_path: Image path to write (missing parent directories are
            created). If ``None``, the figure is displayed with
            ``plt.show()`` instead of being saved.

    Raises:
        ValueError: If the CSV has no ``step`` column, or if no branch has
            cost columns.
    """
    df = pd.read_csv(csv_path)
    if "step" not in df.columns:
        raise ValueError("CSV must contain a 'step' column")

    x = df["step"].values

    # Collect the drawable series first, so that a CSV without cost columns
    # fails before any figure is created (no orphaned empty figure).
    series = []
    for branch in BRANCHES:
        y, ystd = compute_cost_mean_std(df, branch)
        if y is not None:
            series.append((branch, y, ystd))

    if not series:
        available = ", ".join(df.columns)
        raise ValueError(
            f"No matching cost columns found. Available columns: {available}"
        )

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for branch, y, ystd in series:
            color = COLORS.get(branch)
            ax.plot(x, y, label=branch, color=color)
            ax.fill_between(x, y - ystd, y + ystd, color=color, alpha=0.2)

        ax.set_title(f"Simulation results: cost (mean ± std)\n{csv_path.name}")
        ax.set_xlabel("step")
        ax.set_ylabel("cost")
        ax.legend()
        ax.grid(True, linestyle=":", alpha=0.4)
        fig.tight_layout()

        if out_path is not None:
            out_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(out_path, dpi=150)
        else:
            plt.show()
    finally:
        # pyplot keeps every figure alive until it is closed explicitly. A
        # saved figure is no longer needed, and leaving it open would leak it
        # and make it pop up in any later plt.show(). In show mode the figure
        # is left to the GUI, which may still be displaying it.
        if out_path is not None:
            plt.close(fig)


def main():
    """Command-line entry point.

    Parses the CSV path plus the optional ``--out`` and ``--cost-out`` image
    paths, always produces the outcome plot, and produces the cost plot only
    when ``--cost-out`` is given.
    """
    parser = argparse.ArgumentParser(
        description="Plot outcome (mean±std) and optionally cost for ours/naive/aaai from simulation CSV"
    )
    parser.add_argument("csv", type=Path, help="Path to CSV produced by run_sim")

    # Both output options are optional paths that default to None.
    output_options = {
        "--out": "Output image path for outcome (e.g., outcome.png). If omitted, shows an interactive window.",
        "--cost-out": "If set, write a separate cost plot image to this path (e.g., cost.png).",
    }
    for flag, help_text in output_options.items():
        parser.add_argument(flag, type=Path, default=None, help=help_text)

    options = parser.parse_args()

    plot_outcome(options.csv, options.out)

    cost_target = options.cost_out
    if cost_target:
        plot_cost(options.csv, cost_target)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
