import numpy as np
import os
from util import *
from bandit import *
from alg_linear import *
from alg_logistic import *
from alg_new import *

if __name__ == "__main__":
    # -------------------
    # Experiment settings
    # -------------------
    episodes = 20
    # (d, horizon) pairs; one environment family is run per pair.
    env_settings = [
        (10, 2000),
        (50, 2000),
    ]
    K = 100
    C = 1  # C == 1 -> "fixed" arm set, C > 1 -> "contextual" arm set
    model = "logistic"
    noise_var = 1.0
    norm_theta = 4.0
    norm_X = 1.0

    # Fail fast on settings that would otherwise be silently misinterpreted:
    # any model other than "linear" falls into the logistic algorithm list
    # below, and any C != 1 is treated as "contextual".
    if model not in ("linear", "logistic"):
        raise ValueError(f"unknown model {model!r}; expected 'linear' or 'logistic'")
    if C < 1:
        raise ValueError(f"C must be >= 1, got {C}")
    arm_set_type = "fixed" if C == 1 else "contextual"

    # Result directory. The descriptive name encodes the settings; `run_tag`,
    # when not None, overrides it with a hand-picked folder name.
    # NOTE: results are skipped when their CSV already exists, and that cache
    # is keyed only on this name -- with a fixed run_tag, changing the settings
    # above reuses stale CSVs. Set run_tag = None to use the descriptive name.
    run_tag = "test3"
    filename = f"N{episodes}_{model}_{arm_set_type}_C{C}_σ{noise_var}"
    if run_tag is not None:
        filename = run_tag
    base_dir = os.path.join("Results", model, filename)
    os.makedirs(base_dir, exist_ok=True)

    # -------------------
    # Algorithm list: (class, constructor kwargs, display/CSV name)
    # -------------------
    if model == "linear":
        algorithms = [
            (EpsilonGreedy, {"forced_exploration":K},                                                    "e-greedy"),
            (Greedy,        {"forced_exploration":K},                                                    "greedy"),
            (LinUCB,        {"sigma":np.sqrt(noise_var), "L":norm_X, "S":norm_theta},                    "LinUCB"),
            (LinTS,         {"beta":False},                                                              "LinTS"),
            (LinPHE,        {"a":0.5},                                                                   "LinPHE"),
            (RandLinUCB,    {"is_coupled":True, "sigma":np.sqrt(noise_var), "L":norm_X, "S":norm_theta}, "RandLinUCB"),
            (LinFP,         {"beta":False},                                                              "LinFP"),
            (FGTS,          {"eta":1.0, "lam":1.0, "b":float('inf'), "delta":0.01, "rho":100.0},         "FGTS"),
            (OPAS_FGP,      {"eta":1.0, "lam":1.0, "b":float('inf'), "delta":0.01, "rho":100.0},         "OPAS-FGP"),
            # Disabled: (FGFP, {}, "FGFP"),
        ]
    else:
        algorithms = [
            (LogGreedy,  {"epsilon": 0.05},                                                                                "log_e-greedy"),
            (UCBLog,     {"beta": True, "crs": 1.0, "delta": 0.01, "lam0": 1e-4, "S": norm_theta},                         "UCB-GLM"),
            (RandUCBLog, {"beta": False, "crs": 1.0, "pdist": "normal", "pnormal_std": 0.125, "is_coupled": True},         "RandUCBLog"),
            (LogTS,      {"beta": False, "crs": 1.0},                                                                      "GLM-TS"),
            (LogFPL,     {"a": 0.5},                                                                                       "LogFPL"),
            (LogFP,      {"beta": False, "crs": 1.0},                                                                      "GLM-FP"),
            (RSGLinCB,   {"S": norm_theta, "delta": 0.01, "lazy_update_fr": 1},                                            "RS-GLinCB"),
        ]

    # (directory-friendly environment name, dimension d, horizon T)
    environments = [(f"{model} bandit ({arm_set_type}, K={K}, d={d}, C={C})", d, T)
                    for d, T in env_settings]

    for env_name, d, T in environments:
        print(f"\n===== Environment: {env_name}, T={T} =====")

        res_dir = os.path.join(base_dir, env_name)
        os.makedirs(res_dir, exist_ok=True)

        # One Bandit per episode, seeded by episode index. Built lazily right
        # before the first algorithm that actually needs to run, so a fully
        # cached environment costs nothing; all algorithms share the same list.
        envs = None

        for Alg, params, name in algorithms:
            out_path = os.path.join(res_dir, name + ".csv")
            if os.path.exists(out_path):
                print(f"[{name}] exists, skip.")
                continue
            if envs is None:
                envs = [Bandit(d=d, K=K, C=C, arm_set_type=arm_set_type, model=model,
                               norm_theta=norm_theta, norm_X=norm_X, noise_var=noise_var, seed=run)
                        for run in range(episodes)]
            regret, _ = evaluate(Alg, params, envs, T)
            # Cumulative regret along axis 0 (presumably the time axis of the
            # regret array returned by `evaluate` -- TODO confirm its layout).
            cum = regret.cumsum(axis=0)
            np.savetxt(out_path, cum, delimiter=",")

    alg_names = [name for _, _, name in algorithms]
    plot_results(f"{model}/{filename}", environments, alg_names)