#!/usr/bin/env python3
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
import argparse

from evaluation import eval_UCB, eval_TS, eval_PHE, eval_KLBandit, eval_SupLinUCB

# ---- JSON helpers ----
def _json_default(o):
    if isinstance(o, np.ndarray):
        return o.tolist()
    if isinstance(o, (np.floating, np.integer)):
        return o.item()
    return str(o)

def save_json(path, data):
    """Write *data* to *path* as JSON, creating or truncating the file.

    Values the json module cannot encode natively (NumPy arrays and scalars,
    among others) are routed through ``_json_default``.
    """
    with open(path, mode="w") as out_file:
        json.dump(obj=data, fp=out_file, default=_json_default)

def load_json(path):
    """Read and parse the JSON document stored at *path*.

    The file is decoded as UTF-8, the encoding JSON text is required to use
    (RFC 8259). Without an explicit encoding, ``open`` would fall back to the
    platform's locale default, and non-ASCII content would misdecode on
    non-UTF-8 systems.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def require_regrets(d, name):
    """Extract and validate the ``"regrets"`` matrix from a loaded result.

    Args:
        d: Object loaded from a result JSON file. It is expected to be a dict
            with a ``"regrets"`` entry.
        name: Human-readable algorithm label used in error messages.

    Returns:
        A float ``np.ndarray`` with exactly two dimensions, shape (M, T),
        holding at least one element.

    Raises:
        KeyError: If *d* is not a dict or has no ``"regrets"`` key.
        ValueError: If the regrets are not 2-D or are empty. ``np.asarray``
            also raises ValueError for ragged nested lists.
    """
    if not isinstance(d, dict):
        # Describe the offender by type only. Echoing the whole payload could
        # dump an arbitrarily large JSON document into the traceback.
        raise KeyError(
            f"'{name}' JSON does not contain 'regrets' key: "
            f"expected a dict, got {type(d).__name__}"
        )
    if "regrets" not in d:
        # List the keys that are present instead of the full (possibly huge) dict.
        raise KeyError(
            f"'{name}' JSON does not contain 'regrets' key; "
            f"available keys: {sorted(map(str, d))}"
        )
    arr = np.asarray(d["regrets"], dtype=float)
    if arr.ndim != 2:
        raise ValueError(f"'{name}.regrets' shape must be (M, T). Got shape={arr.shape}")
    if arr.size == 0:
        # A caller averaging over axis 0 would silently get NaN (plus a
        # RuntimeWarning) when M == 0, so fail loudly here instead.
        raise ValueError(f"'{name}.regrets' must be non-empty. Got shape={arr.shape}")
    return arr

def _build_parser():
    """Return the command-line parser for this experiment script."""
    parser = argparse.ArgumentParser(description="Run bandit algorithm experiments (with SupLinUCB).")
    parser.add_argument("--d", type=int, default=10, help="Feature dimension")
    parser.add_argument("--N", type=int, default=20, help="Number of actions")
    parser.add_argument("--T", type=int, default=5000, help="Time horizon")
    parser.add_argument("--M", type=int, default=1, help="Number of simulations")
    parser.add_argument("--lam", type=float, default=1.0, help="Ridge parameter (SupLinUCB)")
    parser.add_argument("--S", type=int, default=None, help="#levels for SupLinUCB (default ceil(ln T))")
    parser.add_argument("--rho", type=float, default=0.5, help="Context corr (generator arg) for SupLinUCB")
    parser.add_argument("--R", type=float, default=1.0, help="Noise std for rewards (SupLinUCB)")
    parser.add_argument("--seed", type=int, default=0, help="Base random seed (SupLinUCB)")
    return parser


def _validate_args(parser, args):
    """Reject non-positive sizes via ``parser.error``, which prints usage and exits with status 2.

    ``d``, ``N``, ``T`` and ``M`` are counts. In particular, ``T < 1`` would
    make ``alpha = sqrt(d * log(T))`` NaN.
    """
    for dest in ("d", "N", "T", "M"):
        value = getattr(args, dest)
        if value < 1:
            parser.error(f"--{dest} must be a positive integer, got {value}")


def _warn_ignored_options(parser, args):
    """Warn on stderr when --lam / --S differ from their parser defaults.

    Neither value is forwarded to ``eval_SupLinUCB`` in ``main``. Without
    this warning, a value the user explicitly chose would be silently ignored.
    NOTE(review): pass them through once eval_SupLinUCB's signature is confirmed.
    """
    import sys

    ignored = [
        f"--{dest}"
        for dest in ("lam", "S")
        if getattr(args, dest) != parser.get_default(dest)
    ]
    if ignored:
        print(
            f"Warning: {', '.join(ignored)} not forwarded to eval_SupLinUCB; value(s) ignored.",
            file=sys.stderr,
        )


def main(argv=None):
    """Run every bandit algorithm, save each result as JSON, and reload it for validation.

    Args:
        argv: Argument list to parse. ``None`` (the default) means the real
            command line, ``sys.argv[1:]``, which matches the original behavior.

    Returns:
        Dict mapping each algorithm label ("UCB", "TS", "PHE", "KL-EXP",
        "SupLinUCB") to a ``(mean, std)`` pair of arrays. Both are taken over
        axis 0 of that algorithm's (M, T) ``regrets`` matrix, i.e. across the
        M simulations.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    _validate_args(parser, args)
    _warn_ignored_options(parser, args)

    d, N, T, M = args.d, args.N, args.T, args.M

    # set hyperparameters
    alpha = np.sqrt(d * np.log(T))
    alpha_set = [alpha]       # UCB
    v_set = [alpha]           # TS
    alpha_PHE_set = [alpha]   # PHE
    eta_set = [np.sqrt(T)]    # KL-EXP
    alpha_sup_set = [alpha]   # SupLinUCB

    # run evaluations
    best_UCB = eval_UCB(N=N, d=d, alpha_set=alpha_set, T=T, M=M, output=True)
    best_TS = eval_TS(N=N, d=d, v_set=v_set, T=T, M=M, output=True)
    best_PHE = eval_PHE(N=N, d=d, alpha_set=alpha_PHE_set, T=T, M=M, output=True)
    best_KL = eval_KLBandit(N=N, d=d, eta_set=eta_set, T=T, M=M, output=True)
    best_SL = eval_SupLinUCB(
        N=N, d=d,
        alpha_set=alpha_sup_set,
        T=T, M=M,
        rho=args.rho, R=args.R,
        seed=args.seed,
        output=True,
    )

    # Each entry is (file tag, display label, evaluator output). The file tags
    # reproduce the results/best_<tag>_N{N}_d{d}.json names.
    runs = [
        ("UCB", "UCB", best_UCB),
        ("TS", "TS", best_TS),
        ("PHE", "PHE", best_PHE),
        ("KL", "KL-EXP", best_KL),
        ("SupLinUCB", "SupLinUCB", best_SL),
    ]

    # create results folder
    os.makedirs("results", exist_ok=True)

    # save raw data (convert NumPy -> Python native types); all files are written
    # before any is reloaded or validated
    paths = {}
    for tag, label, result in runs:
        paths[label] = f"results/best_{tag}_N{N}_d{d}.json"
        save_json(paths[label], result)

    # Reload through JSON so the statistics reflect exactly what was persisted,
    # then validate each regrets matrix.
    loaded = {label: load_json(path) for label, path in paths.items()}
    regrets = {label: require_regrets(data, label) for label, data in loaded.items()}

    # mean and std across the M simulations (axis 0)
    stats = {label: (R.mean(axis=0), R.std(axis=0)) for label, R in regrets.items()}

    print(f"Experiment finished for N={N}, d={d}. Results saved in 'results/' folder.")
    return stats

if __name__ == "__main__":
    # Run the experiments only when this file is executed directly, not when imported.
    main()
