
import time
import itertools
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass, asdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from env_synth_tree import SyntheticTree
from algorithms import CATSOSelector, PATSOSelector, ScalarTSOptSelector, UCTSelector, PowerUCTSelector
from mcts_runner import MCTSRunner


@dataclass
class ExpConfig:
    """Settings for one experiment: algorithm, environment and search budget.

    The environment fields (k, d, sigma, trans_noise, gamma) are forwarded to
    SyntheticTree; the remaining fields configure the selector and the runs.
    """

    # Selector name: 'CATSO', 'PATSO', 'ScalarTSOpt', 'PowerUCT' or 'UCT'.
    algo: str

    # --- environment ---
    k: int = 8
    d: int = 3
    sigma: float = 0.5
    trans_noise: float = 0.5
    gamma: float = 1.0

    # --- selector hyperparameters ---
    # Power-mean exponent (1 for mean, np.inf for max).
    p: float = 1.0
    # Optimism constant for the TS variants.
    C: float = 1.5
    # Number of atoms for CATSO.
    N: int = 100
    # Particle cap for PATSO.
    K: int = 200
    # UCB exploration constant.
    ucb_c: float = 1.0

    # --- run control ---
    # Base seed; repetition `rep` runs with seed + rep.
    seed: int = 0
    # Simulation budgets to evaluate.
    n_sims_schedule: Tuple[int, ...] = (50, 100, 200, 500, 1000)
    # Number of seeds per setting.
    repeats: int = 5


def make_selector(cfg: ExpConfig):
    """Build the action selector named by ``cfg.algo``.

    Each selector receives ``cfg.k`` plus the hyperparameters relevant to it:
    the TS variants get ``C`` and ``p`` (CATSO also ``N``, PATSO also ``K``),
    the UCT variants get ``ucb_c``.

    Raises:
        ValueError: if ``cfg.algo`` is not one of the known selector names.
    """
    # Builders are lazy so that only the requested selector is constructed.
    builders = (
        ("CATSO", lambda: CATSOSelector(k=cfg.k, C=cfg.C, N=cfg.N, p=cfg.p)),
        ("PATSO", lambda: PATSOSelector(k=cfg.k, C=cfg.C, K=cfg.K, p=cfg.p)),
        ("ScalarTSOpt", lambda: ScalarTSOptSelector(k=cfg.k, C=cfg.C, p=cfg.p)),
        ("PowerUCT", lambda: PowerUCTSelector(k=cfg.k, ucb_c=cfg.ucb_c)),
        ("UCT", lambda: UCTSelector(k=cfg.k, ucb_c=cfg.ucb_c)),
    )
    for name, build in builders:
        if cfg.algo == name:
            return build()
    raise ValueError(f"Unknown algo {cfg.algo}")


def run_ablation_component(cfg: ExpConfig) -> pd.DataFrame:
    """Run one algo/env setting over every seed and every simulation budget.

    For each of ``cfg.repeats`` seeds (``cfg.seed + rep``) a SyntheticTree is
    built, and for each budget in ``cfg.n_sims_schedule`` a fresh selector and
    MCTSRunner are created and run. The absolute difference between the
    environment's ``true_V_root`` and the runner's root value estimate is
    recorded.

    Returns a tidy DataFrame with one row per (seed, n_sims) -- no averaging is
    done here -- and columns:
      ['algo','k','d','sigma','trans_noise','p','C','N','K','ucb_c','seed','n_sims','abs_err']
    """
    # Columns that are constant for the whole configuration.
    shared = {
        "algo": cfg.algo,
        "k": cfg.k,
        "d": cfg.d,
        "sigma": cfg.sigma,
        "trans_noise": cfg.trans_noise,
        "p": cfg.p,
        "C": cfg.C,
        "N": cfg.N,
        "K": cfg.K,
        "ucb_c": cfg.ucb_c,
    }
    records = []
    for rep in range(cfg.repeats):
        run_seed = cfg.seed + rep
        tree = SyntheticTree(
            k=cfg.k,
            d=cfg.d,
            sigma=cfg.sigma,
            trans_noise=cfg.trans_noise,
            gamma=cfg.gamma,
            seed=run_seed,
        )
        true_root_value = tree.true_V_root
        for budget in cfg.n_sims_schedule:
            # A new selector and runner per (seed, budget): no state is carried
            # over from a smaller budget to a larger one.
            search = MCTSRunner(env=tree, selector=make_selector(cfg), p=cfg.p, seed=run_seed)
            search.run(budget)
            estimate = search.root_value_estimate()
            records.append({
                **shared,
                "seed": run_seed,
                "n_sims": budget,
                "abs_err": abs(true_root_value - estimate),
            })
    return pd.DataFrame(records)


def component_ablation_suite(kd_list: List[Tuple[int,int]], repeats: int = 5,
                             n_sims_schedule: Tuple[int, ...] = (50,100,200,500,1000),
                             sigma: float = 0.5, trans_noise: float = 0.5) -> pd.DataFrame:
    """Implements Sec. 5.3 Component ablations:
        1) CATSO/PATSO + mean backup (p=1)
        2) CATSO/PATSO + max backup (p=inf)
        3) Scalar TS + optimism (scalar mean)
        4) Power-UCT

    Args:
        kd_list: (k, d) pairs; every variant above is run on each pair.
        repeats: number of seeds per setting.
        n_sims_schedule: simulation budgets evaluated for each setting.
        sigma, trans_noise: noise levels forwarded to the experiment config.

    Returns:
        The per-run rows of all settings concatenated into one DataFrame
        (see run_ablation_component for the columns). An empty ``kd_list``
        yields an empty DataFrame instead of the ValueError that
        ``pd.concat`` raises when given nothing to concatenate.
    """
    algos = [
        ("CATSO", 1.0), ("CATSO", float("inf")),
        ("PATSO", 1.0), ("PATSO", float("inf")),
        ("ScalarTSOpt", 1.0),
        ("PowerUCT", 1.0),
    ]
    frames = [
        run_ablation_component(
            ExpConfig(algo=algo, k=k, d=d, sigma=sigma, trans_noise=trans_noise,
                      p=p, repeats=repeats, n_sims_schedule=n_sims_schedule)
        )
        for (k, d) in kd_list
        for algo, p in algos
    ]
    if not frames:
        # Nothing was run; pd.concat would raise on an empty list.
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)


def stochasticity_sweeps(k: int = 8, d: int = 3, repeats: int = 5,
                         n_sims_schedule: Tuple[int, ...] = (50,100,200,500,1000)) -> pd.DataFrame:
    """Implements Sec. 5.3 stochasticity sweeps for Deterministic/Low/Medium/High noise levels.

    CATSO, PATSO, ScalarTSOpt and PowerUCT (all with p=1) are run under each
    noise level. The returned frame holds the per-run rows of every combination
    plus a ``noise_setting`` column carrying the level's label.
    """
    # (label, sigma, trans_noise)
    noise_levels = (
        ("deterministic", 0.0, 0.0),
        ("low", 0.25, 0.25),
        ("medium", 0.5, 0.5),
        ("high", 1.0, 0.75),
    )
    # (algo, p)
    variants = (("CATSO", 1.0), ("PATSO", 1.0), ("ScalarTSOpt", 1.0), ("PowerUCT", 1.0))

    frames = []
    # product() iterates noise levels in the outer position, variants in the inner.
    for (label, sigma, trans_noise), (algo, p) in itertools.product(noise_levels, variants):
        cfg = ExpConfig(algo=algo, k=k, d=d, sigma=sigma, trans_noise=trans_noise,
                        p=p, repeats=repeats, n_sims_schedule=n_sims_schedule)
        frames.append(run_ablation_component(cfg).assign(noise_setting=label))
    return pd.concat(frames, ignore_index=True)


def sweep_p(k: int = 8, d: int = 3, repeats: int = 5, algo: str = "PATSO",
            n_sims: int = 1000) -> pd.DataFrame:
    """Sweep the power-mean exponent p over {1, 2, 4, 8, inf} at a single budget.

    Returns the per-run rows of every setting, tagged with a ``hyper`` column
    of the form ``"p=<value>"``.
    """
    def run_one(exponent: float) -> pd.DataFrame:
        cfg = ExpConfig(algo=algo, k=k, d=d, p=exponent, repeats=repeats,
                        n_sims_schedule=(n_sims,))
        return run_ablation_component(cfg).assign(hyper=f"p={exponent}")

    exponents = (1.0, 2.0, 4.0, 8.0, float("inf"))
    return pd.concat([run_one(exponent) for exponent in exponents], ignore_index=True)


def sweep_C(k: int = 8, d: int = 3, repeats: int = 5, algo: str = "PATSO",
            n_sims: int = 1000) -> pd.DataFrame:
    """Sweep the optimism constant C over {0.5, 1.0, 2.0, 4.0} at a single budget.

    Returns the per-run rows of every setting, tagged with a ``hyper`` column
    of the form ``"C=<value>"``.
    """
    frames = []
    for optimism in (0.5, 1.0, 2.0, 4.0):
        cfg = ExpConfig(algo=algo, k=k, d=d, C=optimism, repeats=repeats,
                        n_sims_schedule=(n_sims,))
        result = run_ablation_component(cfg)
        frames.append(result.assign(hyper=f"C={optimism}"))
    return pd.concat(frames, ignore_index=True)


def sweep_N(k: int = 8, d: int = 3, repeats: int = 5, algo: str = "CATSO",
            n_sims: int = 1000) -> pd.DataFrame:
    """Sweep N (atoms) in {50,100,200,400} for CATSO at a single budget.

    Returns the per-run rows of every setting, tagged with a ``hyper`` column
    of the form ``"N=<value>"``.
    """
    def run_one(n_atoms: int) -> pd.DataFrame:
        cfg = ExpConfig(algo=algo, k=k, d=d, N=n_atoms, repeats=repeats,
                        n_sims_schedule=(n_sims,))
        return run_ablation_component(cfg).assign(hyper=f"N={n_atoms}")

    return pd.concat(list(map(run_one, (50, 100, 200, 400))), ignore_index=True)


def sweep_K(k: int = 8, d: int = 3, repeats: int = 5, algo: str = "PATSO",
            n_sims: int = 1000) -> pd.DataFrame:
    """Sweep K (particle cap) in {50,100,200,400} for PATSO at a single budget.

    Returns the per-run rows of every setting, tagged with a ``hyper`` column
    of the form ``"K=<value>"``.
    """
    budget = (n_sims,)
    frames = []
    for particle_cap in (50, 100, 200, 400):
        cfg = ExpConfig(algo=algo, k=k, d=d, K=particle_cap, repeats=repeats,
                        n_sims_schedule=budget)
        tagged = run_ablation_component(cfg).assign(hyper=f"K={particle_cap}")
        frames.append(tagged)
    return pd.concat(frames, ignore_index=True)


# --------- plotting helpers ---------
def plot_value_error_curves(df: pd.DataFrame, title: str, savepath: Optional[str] = None):
    """Plot mean absolute root-value error against the number of simulations.

    One curve is drawn per algorithm, averaged over all rows sharing the same
    n_sims. When ``df`` has a ``p`` column and an algorithm appears with more
    than one ``p`` (e.g. mean vs. max backup in the component ablations), that
    algorithm is split into one curve per ``p``, labelled ``"<algo> (p=<p>)"``,
    instead of averaging the variants into a single line.

    Args:
        df: rows with at least 'algo', 'n_sims' and 'abs_err' columns.
        title: figure title.
        savepath: if given, the figure is written there; the figure is closed
            in either case.
    """
    fig = plt.figure()

    # Curve label per row: the algo name, plus the exponent when that algo was
    # run with several values of p.
    curve = df["algo"].astype(str)
    if "p" in df.columns:
        has_several_p = df.groupby("algo")["p"].transform("nunique") > 1
        curve = curve.where(~has_several_p, curve + " (p=" + df["p"].astype(str) + ")")

    # average over seeds
    g = (
        df.assign(_curve=curve)
        .groupby(["_curve", "n_sims"], as_index=False)["abs_err"]
        .mean()
    )
    for name in sorted(g["_curve"].unique()):
        sub = g[g["_curve"] == name]
        plt.plot(sub["n_sims"].values, sub["abs_err"].values, marker="o", label=name)
    plt.xlabel("# simulations")
    plt.ylabel("Root value abs. error")
    plt.title(title)
    plt.legend()
    if savepath:
        plt.savefig(savepath, bbox_inches="tight")
    plt.close(fig)


def plot_bar_sensitivity(df: pd.DataFrame, title: str, savepath: Optional[str] = None):
    """Bar plot for single n_sims sweeps (e.g., hyperparameter sweeps).

    Draws one bar per distinct ``hyper`` label showing the mean ``abs_err`` of
    its rows. Bars follow the order in which the labels first appear in ``df``
    (i.e. the sweep order), not lexicographic order, which would place
    "K=50" after "K=400".

    Args:
        df: rows with at least 'hyper' and 'abs_err' columns.
        title: figure title.
        savepath: if given, the figure is written there; the figure is closed
            in either case.
    """
    fig = plt.figure()
    # sort=False keeps groups in first-appearance order.
    g = df.groupby("hyper", as_index=False, sort=False)["abs_err"].mean()
    plt.bar(g["hyper"].values, g["abs_err"].values)
    plt.ylabel("Abs. error (mean over seeds)")
    plt.title(title)
    plt.xticks(rotation=45, ha="right")
    if savepath:
        plt.savefig(savepath, bbox_inches="tight")
    plt.close(fig)


# --------- CLI examples ---------
if __name__ == "__main__":
    # Demo driver: runs every experiment with a small number of repeats and
    # writes CSVs plus example plots into --outdir.
    import argparse
    import os

    cli = argparse.ArgumentParser()
    cli.add_argument("--outdir", type=str, default="outputs")
    outdir = cli.parse_args().outdir
    os.makedirs(outdir, exist_ok=True)

    def out_path(filename):
        """Return the path of `filename` inside the output directory."""
        return os.path.join(outdir, filename)

    # 1) Component ablations on multiple (k,d) pairs similar to Fig. 3
    kd_pairs = [(16, 1), (200, 1), (14, 3), (16, 3), (16, 4), (200, 2)]
    ablation = component_ablation_suite(kd_pairs, repeats=3)  # keep repeats small for demo
    ablation.to_csv(out_path("ablation_components.csv"), index=False)
    # Plot a single (k, d) pair as an example.
    is_example = (
        (ablation["k"] == 16)
        & (ablation["d"] == 3)
        & (ablation["sigma"] == 0.5)
        & (ablation["trans_noise"] == 0.5)
    )
    plot_value_error_curves(
        ablation[is_example],
        title="Component ablations (k=16,d=3)",
        savepath=out_path("ablation_k16d3.png"),
    )

    # 2) Stochasticity sweeps
    noise = stochasticity_sweeps(k=8, d=3, repeats=3)
    noise.to_csv(out_path("stochasticity_sweeps.csv"), index=False)

    # 3) Hyperparameter sensitivity: (sweep function, hyperparameter, algorithm)
    sensitivity_runs = (
        (sweep_p, "p", "PATSO"),
        (sweep_C, "C", "PATSO"),
        (sweep_N, "N", "CATSO"),
        (sweep_K, "K", "PATSO"),
    )
    for sweep_fn, hyper_name, algo_name in sensitivity_runs:
        sweep_df = sweep_fn(k=8, d=3, repeats=3, algo=algo_name, n_sims=1000)
        stem = f"sweep_{hyper_name}_{algo_name}"
        sweep_df.to_csv(out_path(stem + ".csv"), index=False)
        plot_bar_sensitivity(
            sweep_df,
            title=f"{algo_name}: sensitivity to {hyper_name} (n_sims=1000)",
            savepath=out_path(stem + ".png"),
        )

    print("Done. CSVs and example plots written to", outdir)
