import argparse
import os
from statistics import NormalDist
from typing import Iterable, List

import numpy as np
import pandas as pd

from experiments_ablation import (
    component_ablation_suite,
    stochasticity_sweeps,
    sweep_p,
    sweep_C,
    sweep_N,
    sweep_K,
)


def _add_mean_ci(
    df: pd.DataFrame,
    group_cols: Iterable[str],
    value_col: str = "abs_err",
    ci_level: float = 0.95,
) -> pd.DataFrame:
    """
    Compute mean and (two-sided) confidence interval for `value_col`
    within each group defined by `group_cols`.

    Returns a DataFrame with:
        group_cols + ["n_seeds", "mean", "ci_half_width", "ci_lower", "ci_upper"]
    """
    # Normal approximation quantile for 95% CI
    z = 1.96 if abs(ci_level - 0.95) < 1e-9 else 1.96

    def _agg(group: pd.DataFrame) -> pd.Series:
        vals = group[value_col].to_numpy(dtype=float)
        n = vals.size
        mean = float(vals.mean()) if n > 0 else 0.0
        if n > 1:
            std = float(vals.std(ddof=1))
            se = std / np.sqrt(n)
            h = z * se
        else:
            h = 0.0
        return pd.Series(
            {
                "n_seeds": n,
                "mean": mean,
                "ci_half_width": h,
                "ci_lower": mean - h,
                "ci_upper": mean + h,
            }
        )

    grouped = df.groupby(list(group_cols), as_index=False).apply(_agg)

    # groupby(..., as_index=False).apply() can return a MultiIndex; flatten it
    if isinstance(grouped.index, pd.MultiIndex):
        grouped = grouped.reset_index(drop=True)
    return grouped


def _save_raw_and_ci(
    df: pd.DataFrame,
    outdir: str,
    stem: str,
    group_cols: List[str],
    value_col: str = "abs_err",
) -> pd.DataFrame:
    """
    Persist one experiment's results.

    Writes the raw rows to ``<outdir>/<stem>_raw.csv`` first (so they survive
    even if summarising fails), then the mean ± CI summary grouped by
    `group_cols` to ``<outdir>/<stem>_ci.csv``. Returns the summary.
    """
    df.to_csv(os.path.join(outdir, f"{stem}_raw.csv"), index=False)
    summary = _add_mean_ci(df, group_cols=group_cols, value_col=value_col)
    summary.to_csv(os.path.join(outdir, f"{stem}_ci.csv"), index=False)
    return summary


def main() -> None:
    """
    Command-line entry point: run every ablation and sensitivity experiment
    and write raw plus mean ± 95% CI summary CSVs under ``--outdir``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run Ablations & Sensitivity Analysis with multiple random seeds and "
            "report mean ± 95% confidence intervals."
        )
    )
    parser.add_argument(
        "--outdir",
        type=str,
        default="outputs_ci",
        help="Directory where CSV summaries will be written.",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=10,
        help="Number of random seeds per configuration (default: 10).",
    )
    args = parser.parse_args()
    if args.repeats < 1:
        # Zero/negative seed counts would produce empty or meaningless CIs.
        parser.error("--repeats must be a positive integer")
    repeats: int = args.repeats
    outdir: str = args.outdir
    os.makedirs(outdir, exist_ok=True)

    # ------------------------------------------------------------------
    # 1) Component ablations:
    #    - CATSO/PATSO + mean backup (p=1)
    #    - CATSO/PATSO + max backup  (p=inf)
    #    - ScalarTSOpt (scalar TS+optimism)
    #    - Power-UCT
    #    Summarised as mean ± 95% CI per (k, d, algo, p, n_sims).
    # ------------------------------------------------------------------
    kd_list: List[tuple] = [(16, 1), (200, 1), (14, 3), (16, 3), (16, 4), (200, 2)]
    _save_raw_and_ci(
        component_ablation_suite(kd_list, repeats=repeats),
        outdir,
        "component_ablation",
        group_cols=["k", "d", "algo", "p", "n_sims"],
    )

    # ------------------------------------------------------------------
    # 2) Stochasticity sweeps:
    #    Deterministic / Low / Medium / High noise
    # ------------------------------------------------------------------
    _save_raw_and_ci(
        stochasticity_sweeps(repeats=repeats),
        outdir,
        "stochasticity_sweeps",
        group_cols=["noise_setting", "algo", "n_sims"],
    )

    # ------------------------------------------------------------------
    # 3) Hyperparameter sensitivity sweeps
    #    (p, C, N, K) as described in Sec. 5.3
    #    Every sweep shares the same problem setting; each entry is
    #    (hyperparameter label used in the file name, sweep function, algorithm).
    # ------------------------------------------------------------------
    sweep_setting = {"k": 8, "d": 3, "n_sims": 1000}
    sensitivity_runs = [
        ("p", sweep_p, "PATSO"),  # power-mean exponent p
        ("p", sweep_p, "CATSO"),  # power-mean exponent p – optional but handy to inspect
        ("C", sweep_C, "PATSO"),  # optimism constant C
        ("N", sweep_N, "CATSO"),  # number of atoms N
        ("K", sweep_K, "PATSO"),  # particle cap K
    ]
    for label, sweep_fn, algo in sensitivity_runs:
        df_sweep = sweep_fn(repeats=repeats, algo=algo, **sweep_setting)
        _save_raw_and_ci(
            df_sweep,
            outdir,
            f"sweep_{label}_{algo}",
            group_cols=["algo", "hyper"],
        )

    # ------------------------------------------------------------------
    # Done
    # ------------------------------------------------------------------
    print("Finished all ablations & sensitivity sweeps.")
    print(f"Raw CSVs and mean±95% CI summaries written under: {args.outdir}")


if __name__ == "__main__":
    # Run the full experiment pipeline only when executed as a script, not on import.
    main()

