from __future__ import annotations
import os
import re
from pathlib import Path
from typing import Sequence, Tuple, Dict, List
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from util import *
from bandit import *
from alg_linear import *

# Default experiment configuration read by run_scaling_experiment().
# Each ALGORITHMS entry is a (class, params dict, name) triple; the name is
# used both for progress output and as the CSV column / legend label.
CONST = dict(
    EPISODES=100,            # environments per d; episode index is the seed
    HORIZON=500_000,         # horizon handed to evaluate_compare
    K=50,                    # passed to Bandit(K=...)
    MODEL="linear",          # passed to Bandit(model=...)
    ARM_SET_TYPE="fixed",    # passed to Bandit(arm_set_type=...)
    D_LIST=list(range(1, 16)),  # dimensions swept: d = 1 .. 15
    ALGORITHMS=[
        (LinTS, dict(beta=True), "LinTS"),
        (LinFP, dict(mode="normal", beta=True), "LinFP"),
    ],
    OUT_DIR=Path("Results"),
    CSV_NAME="regret_scaling_d_results.csv",
    FIG_NAME="regret_scaling_plot.pdf",
)

def _parse_mean_std(txt: object) -> Tuple[float, float]:
    """Parse a ``"mean±std"`` cell back into a ``(mean, std)`` float pair.

    '123.45±6.78' → (123.45, 6.78)

    Both numbers may be signed, carry an exponent (``1.2e+05``), or be
    ``nan``/``inf`` (case-insensitive). Whitespace around ``±`` is allowed.
    As before, only a leading match is required, so any text after the std is
    ignored.

    Args:
        txt: The cell value. Non-string values, such as the float NaN pandas
            yields for an empty CSV field, are treated as unparseable.

    Returns:
        ``(mean, std)`` as floats, or ``(nan, nan)`` when ``txt`` does not
        start with a ``number±number`` pair.
    """
    if not isinstance(txt, str):
        return (np.nan, np.nan)
    # One well-formed float literal. The old ``[0-9.]+`` rejected signs and
    # exponents, and accepted e.g. "1.2.3", which then made float() raise.
    num = r"[-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|nan|inf)"
    m = re.match(rf"\s*({num})\s*±\s*({num})", txt, flags=re.IGNORECASE)
    return (float(m.group(1)), float(m.group(2))) if m else (np.nan, np.nan)

def _set_mpl_style() -> None:
    """Install the compact, publication-oriented Matplotlib style for the figure.

    The settings are written into the global ``mpl.rcParams``. They therefore
    stay in effect for every figure created afterwards in this process.
    """
    settings = [
        # Line, patch and tick geometry.
        ("axes.linewidth", .75),
        ("grid.linewidth", .75),
        ("lines.linewidth", 1),
        ("patch.linewidth", 1.5),
        ("xtick.major.size", 3),
        ("ytick.major.size", 3),
        # Embed TrueType (Type 42) fonts in PDF/PS output.
        ("pdf.fonttype", 42),
        ("ps.fonttype", 42),
        # Base font size, plus element sizes relative to it.
        ("font.size", 14),
        ("axes.titlesize", "large"),
        ("axes.labelsize", "medium"),
        ("xtick.labelsize", "medium"),
        ("ytick.labelsize", "medium"),
        ("legend.fontsize", "small"),
        # Use Matplotlib's built-in mathtext. For real LaTeX, set this to True
        # and add ("text.latex.preamble", r"\usepackage{amsmath,amssymb}").
        ("text.usetex", False),
    ]
    mpl.rcParams.update(dict(settings))

def run_scaling_experiment(cfg: Dict | None = None) -> None:
    """Sweep the feature dimension ``d`` and record each algorithm's regret.

    For every ``d`` in ``cfg["D_LIST"]`` this function:

    * builds ``cfg["EPISODES"]`` ``Bandit`` environments, seeded
      ``0 .. EPISODES-1``;
    * evaluates each ``(class, params, name)`` entry of ``cfg["ALGORITHMS"]``
      on them with ``evaluate_compare`` over ``cfg["HORIZON"]``;
    * appends one row of ``mean±std`` cells (two decimals) to the results CSV.

    The CSV is truncated at the start and appended to after each ``d``, so
    rows finished before an interruption stay on disk. Finally the CSV is
    plotted to ``cfg["FIG_NAME"]`` inside ``cfg["OUT_DIR"]``.

    Args:
        cfg: Experiment configuration. Keys it omits fall back to the
            module-level ``CONST`` defaults. ``None`` (the default) runs with
            ``CONST`` unchanged. ``OUT_DIR`` may be a ``str`` or a ``Path``.
    """
    # Overlay the caller's keys on a fresh copy of the defaults. This avoids
    # a shared mutable default argument and allows partial overrides.
    cfg = {**CONST, **(cfg or {})}
    out_dir = Path(cfg["OUT_DIR"])
    # parents=True so a nested OUT_DIR such as "Results/run1" also works.
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_path = out_dir / cfg["CSV_NAME"]

    alg_names = [name for *_, name in cfg["ALGORITHMS"]]
    # Always write UTF-8. The cells contain "±", and pandas.read_csv (used by
    # the plotting step) decodes UTF-8 by default. Without an explicit
    # encoding, open() uses the locale's (e.g. cp1252 on Windows), and reading
    # the file back then fails.
    with csv_path.open("w", encoding="utf-8") as f:
        f.write("d," + ",".join(alg_names) + "\n")

    for d in cfg["D_LIST"]:
        print(f"\n===== d = {d} =====")
        # One environment per episode; the episode index doubles as the seed.
        envs = [
            Bandit(d=d, K=cfg["K"], norm_theta=1, norm_X=1,
                   seed=run, arm_set_type=cfg["ARM_SET_TYPE"],
                   model=cfg["MODEL"])
            for run in range(cfg["EPISODES"])
        ]

        row_values = []
        for alg_class, alg_params, name in cfg["ALGORITHMS"]:
            print(f"▶ {name}")
            # evaluate_compare's third return value is not used here.
            mean_reg, std_reg, _ = evaluate_compare(
                alg_class, alg_params, envs, cfg["HORIZON"])
            row_values.append(f"{mean_reg:.2f}±{std_reg:.2f}")

        # Append rather than buffer, so each finished d is persisted at once.
        with csv_path.open("a", encoding="utf-8") as f:
            f.write(f"{d}," + ",".join(row_values) + "\n")

    print(f"\n✅ file saved: {csv_path}")
    _plot_results(csv_path, out_dir / cfg["FIG_NAME"], alg_names)

def _plot_results(csv_path: Path, fig_path: Path, algs: Sequence[str]) -> None:
    """Plot mean cumulative regret, with ±std error bars, against ``d``.

    Reads the ``d,<alg>,...`` CSV written by ``run_scaling_experiment``, in
    which each algorithm cell is a ``"mean±std"`` string. The error-bar plot
    is saved to ``fig_path``. Y values are divided by 10^4 to match the
    "(Unit: 10^4)" annotation. If the CSV has no data rows, nothing is saved.

    Args:
        csv_path: Results CSV to read (UTF-8).
        fig_path: Output figure path; Matplotlib picks the format from the
            file suffix.
        algs: Column names to plot, in legend order.
    """
    scale = 10_000  # must agree with the "(Unit: $10^{4}$)" annotation below
    colors = {"LinFP": "blue", "LinTS": "black"}  # others: default color cycle

    _set_mpl_style()
    df = pd.read_csv(csv_path, encoding="utf-8")
    if df.empty:
        # No finished rows (e.g. an empty D_LIST). The min/max-based x ticks
        # below would fail, so skip the figure instead of crashing.
        print(f"⚠ no data rows in {csv_path}; graph not saved")
        return
    d_vals = df["d"]

    fig, ax = plt.subplots(figsize=(3.5, 3))
    try:
        y_top = 0.0  # highest finite mean+std over all algorithms, in plot units
        for alg in algs:
            # astype(str): a column pandas read as float (e.g. all cells empty)
            # becomes "nan" strings, which the parser maps to NaN.
            parsed = df[alg].astype(str).apply(_parse_mean_std)
            y = np.array([mean for mean, _ in parsed], dtype=float) / scale
            yerr = np.array([std for _, std in parsed], dtype=float) / scale
            ax.errorbar(d_vals, y, yerr=yerr, label=alg,
                        color=colors.get(alg), marker="o", capsize=3, lw=1)
            upper = y + yerr
            upper = upper[np.isfinite(upper)]
            if upper.size:
                y_top = max(y_top, float(upper.max()))

        ax.set_xlabel(r"$d$")
        ax.set_ylabel("Cumulative Regret")
        ax.text(0.02, 1.03, r"(Unit: $10^{4}$)",
                transform=ax.transAxes, fontsize=10)
        ax.set_xticks(np.arange(d_vals.min(), d_vals.max() + 1, 2))
        # Keep the paper's fixed 0..4 ticks only when the data fits them.
        # Otherwise use Matplotlib's automatic ticks, so the upper part of
        # the axis is still labelled.
        if y_top <= 4:
            ax.set_yticks(np.arange(0, 5, 1))
        ax.grid(True, linestyle="--", alpha=.6)
        ax.legend()
        fig.tight_layout()
        fig.savefig(fig_path, dpi=1200, bbox_inches="tight")
    finally:
        # Release the figure even if parsing, plotting or saving raised.
        plt.close(fig)
    print(f"✅ graph saved: {fig_path}")

if __name__ == "__main__":
    # Script entry point: run the full d-sweep with the module defaults.
    run_scaling_experiment(cfg=CONST)



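# NOTE: Legacy version of this script, superseded by run_scaling_experiment()
# above and kept for reference only. It is not executed. It also contains a
# known bug: `label` is used before assignment in the parsing loop.
# import numpy as np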
# import pandas as pd
# import matplotlib as mpl
# import matplotlib.pyplot as plt
# import os
# import re
# from util import *
# from bandit import *
# from alg_linear import *

# def run_scaling_experiment():
#     episodes = 100
#     horizon = 500000
#     K = 50
#     model = "linear"
#     arm_set_type = "fixed"
#     d_list = list(range(1, 16))

#     algorithms = [
#         (LinTS, {"beta": True}, "LinTS"),
#         (LinFP, {"mode": "normal", "beta": True}, "LinFP"),
#     ]

#     results = {alg_name: [] for _, _, alg_name in algorithms}
#     d_recorded = []

#     os.makedirs("Results", exist_ok=True)
#     csv_path = "Results/regret_scaling_d_results.csv"
#     plot_path = "Results/regret_scaling_plot.png"

#     with open(csv_path, "w") as f:
#         f.write("d," + ",".join(alg_name for _, _, alg_name in algorithms) + "\n")

#     for d in d_list:
#         print(f"\n================= Running d = {d} =================")
#         envs = []
#         for run in range(episodes):
#             np.random.seed(run)
#             env = Bandit(d=d, K=K, norm_theta=1, norm_X=1, seed=run,
#                          arm_set_type=arm_set_type, model=model)
#             envs.append(env)

#         regret_values = []
#         for alg_class, alg_params, alg_name in algorithms:
#             print(f"▶ Running {alg_name}")
#             mean_regret, std_regret, _ = evaluate(alg_class, alg_params, envs, horizon)

#             results[alg_name].append((mean_regret, std_regret))
#             regret_values.append(f"{mean_regret:.2f}±{std_regret:.2f}")

#         d_recorded.append(d)

#         with open(csv_path, "a") as f:
#             f.write(f"{d}," + ",".join(regret_values) + "\n")

#         # Plot with error bars
#         plt.figure()
#         for alg_name in results:
#             means = [m for m, s in results[alg_name]]
#             stds = [s for m, s in results[alg_name]]
#             plt.errorbar(d_recorded, means, yerr=stds, marker='o', label=alg_name, capsize=3)
#         plt.xlabel("d (dimension)")
#         plt.ylabel("Final Regret (mean ± std)")
#         plt.title(f"Regret vs d (T={horizon}, K={K}, episodes={episodes})")
#         plt.legend()
#         plt.grid(True, linestyle="--")
#         plt.tight_layout()
#         plt.savefig(plot_path)
#         plt.close()
#         print(f"📈 Plot updated at: {plot_path}")

#     print(f"\n✅ Final output saved completed: {csv_path}")
#     print(f"✅ Final plot saved completed: {plot_path}")

#     mpl.rcParams["axes.linewidth"] = 0.75
#     mpl.rcParams["grid.linewidth"] = 0.75
#     mpl.rcParams["lines.linewidth"] = 1
#     mpl.rcParams["patch.linewidth"] = 1.5
#     mpl.rcParams["xtick.major.size"] = 3
#     mpl.rcParams["ytick.major.size"] = 3
#     mpl.rcParams["pdf.fonttype"] = 42
#     mpl.rcParams["ps.fonttype"] = 42
#     mpl.rcParams["font.size"] = 14
#     mpl.rcParams["axes.titlesize"] = "large"
#     mpl.rcParams["axes.labelsize"] = "medium"
#     mpl.rcParams["xtick.labelsize"] = "medium"
#     mpl.rcParams["ytick.labelsize"] = "medium"
#     mpl.rcParams["legend.fontsize"] = "small"
#     mpl.rcParams["text.usetex"] = True
#     mpl.rcParams['text.latex.preamble'] = r'\usepackage{amsmath} \usepackage{amssymb}'

#     def parse_mean_std(value):
#         match = re.match(r"([0-9.]+)±([0-9.]+)", value)
#         return (float(match.group(1)), float(match.group(2))) if match else (np.nan, np.nan)

#     # Load and process data
#     df = pd.read_csv("Results/regret_scaling_d_results.csv").iloc[1:16]
#     d_vals = df['d'].values
#     algs = ["LinFP","LinTS"]

#     mean_df, std_df = pd.DataFrame(), pd.DataFrame()
#     for alg in algs:
#         parsed = df[alg].apply(parse_mean_std)
#         mean_df[label] = parsed.apply(lambda x: x[0])
#         std_df[label] = parsed.apply(lambda x: x[1])

#     colors = {"LinFP": "blue", "LinTS": "black"}
#     # Plot
#     plt.figure(figsize=(3.5, 3))
#     for label in algs:
#         y = mean_df[label] / 10000
#         plt.plot(d_vals, y, label=label, color=colors[label], marker='')

#     plt.xlabel(r"$d$")
#     plt.ylabel("Cumulative Regret")
#     plt.text(0, 1.05, "(Unit: $10^4$)", transform=plt.gca().transAxes, ha='center', fontsize=10)
#     plt.xticks(np.arange(d_vals.min(), d_vals.max() + 1, 2))
#     plt.xlim(d_vals.min(), d_vals.max())
#     plt.yticks(np.arange(0, 5, 1))
#     plt.grid(True)
#     plt.legend()
#     plt.tight_layout()
#     plt.savefig("regret_upper_bound_plot.pdf", format="pdf", dpi=1200, bbox_inches="tight")
#     plt.show()


# if __name__ == "__main__":
#     run_scaling_experiment()
