import numpy as np
import time
import pickle
import os
from bandit import *
from util import *
from selection import *
import argparse
from tqdm import tqdm


def run_selector_comparison(env, info, T, REPEATS, update_frequency, d, N, K_list, selector_dict, **kwargs):
    """Compare item selectors under PL and rank-breaking (RB) online updates.

    For every ``K`` in ``K_list`` and every selector in ``selector_dict`` the
    experiment is run on ``REPEATS`` pre-generated problem instances for ``T``
    rounds each.  Per-round regret is averaged over blocks of
    ``update_frequency`` rounds (round 1 is additionally reported on its own),
    then the mean and standard deviation across repeats are stored, plotted
    and pickled.

    Args:
        env: Environment identifier; forwarded to ``get_contexts`` and
            embedded in the output file names.
        info: Free-form tag embedded in the output file names.
        T: Number of rounds per repeat.
        REPEATS: Number of independently generated problem instances.
        update_frequency: Length (in rounds) of each regret-averaging block.
        d: Dimension of ``theta_star`` and of the feature vectors.
        N: Number of items per context set (used to enumerate item pairs
            and forwarded to ``get_contexts``).
        K_list: Non-empty collection of subset sizes ``K`` to evaluate.
        selector_dict: Mapping from selector name to callable.  The selector
            named ``"DopeWolfe"`` is called as
            ``selector(X, K, V_inv, z_dict) -> (V_inv, S)`` and keeps separate
            state per context; every other selector is called as
            ``selector(X, H, K) -> (a_ref, indices)``.
        **kwargs: Optional ``seed`` (default 12345), ``C`` (default 50) and
            ``noise_scale`` (default 0.1).  ``C`` and ``noise_scale`` are
            forwarded to ``get_contexts``.

    Returns:
        None.  Side effects: a PDF plot under ``plot/`` and a pickle under
        ``results_checkpoint/``.

    Raises:
        ValueError: If ``K_list`` is empty.
    """
    seed = kwargs.get("seed", 12345)
    C = kwargs.get("C", 50)
    noise_scale = kwargs.get("noise_scale", 0.1)

    # Fail early with a clear message; the plot file name below depends on
    # the last K iterated, which would otherwise be an unbound name.
    if len(K_list) == 0:
        raise ValueError("K_list must contain at least one value of K")

    # Loop-invariant constants: step sizes and the scale of the initial
    # (identity-shaped) Hessian for each of the two update rules.
    eta_pl = (1 + 3 * np.sqrt(2)) / 2
    eta_rb = (1 + np.sqrt(6)) / 2
    h0_pl_scale = 6 * np.sqrt(2) * (1 + 3 * np.sqrt(2))
    h0_rb_scale = 6 * np.sqrt(2) * (1 + np.sqrt(6))

    # The pairwise-difference tables are only read by the "DopeWolfe" branch,
    # so skip building them (O(N^2) entries per context) when not needed.
    needs_pair_diffs = "DopeWolfe" in selector_dict

    # Pre-generate REPEATS different (theta_star, X_set, z_dicts) tuples
    repeat_configs = []
    for i in range(REPEATS):
        np.random.seed(seed + i)
        theta_star = np.random.randn(d)
        theta_star /= np.linalg.norm(theta_star)
        X_set = get_contexts(env, N, d, theta_star, C=C, noise_scale=noise_scale, seed=seed + i)
        if needs_pair_diffs:
            z_dicts = {
                c: {(j, k): X[j] - X[k] for j in range(N) for k in range(j + 1, N)}
                for c, X in enumerate(X_set)
            }
        else:
            z_dicts = None
        repeat_configs.append((theta_star, X_set, z_dicts))

    results_by_K_pl = {K: {} for K in K_list}
    results_by_K_rb = {K: {} for K in K_list}

    for K in K_list:
        for name, selector in selector_dict.items():
            print(f"Running selector: {name} (K={K})")
            per_context_state = name == "DopeWolfe"

            all_trials_pl, all_trials_rb = [], []

            # NOTE(review): the global NumPy RNG is not re-seeded here, so the
            # context sequence drawn below differs between selectors and K
            # values -- confirm this is intended.
            for theta_star, X_set, z_dicts in tqdm(repeat_configs):
                pl_regret_block, rb_regret_block = [], []
                mean_pl_regrets, mean_rb_regrets = [], []

                if per_context_state:
                    contexts = range(len(X_set))
                    V_inv_dict = {c: np.eye(d) for c in contexts}
                    theta_pl_dict = {c: np.zeros(d) for c in contexts}
                    theta_rb_dict = {c: np.zeros(d) for c in contexts}
                    H_pl_dict = {c: h0_pl_scale * np.identity(d) for c in contexts}
                    H_rb_dict = {c: h0_rb_scale * np.identity(d) for c in contexts}
                else:
                    theta_pl, theta_rb = np.zeros(d), np.zeros(d)
                    H_pl = h0_pl_scale * np.identity(d)
                    H_rb = h0_rb_scale * np.identity(d)

                for t in range(1, T + 1):
                    c_t = np.random.randint(len(X_set))
                    X = X_set[c_t]

                    if per_context_state:
                        # Load the state that belongs to the sampled context.
                        theta_pl, theta_rb = theta_pl_dict[c_t], theta_rb_dict[c_t]
                        H_pl, H_rb = H_pl_dict[c_t], H_rb_dict[c_t]
                        V_inv, S = selector(X, K, V_inv_dict[c_t], z_dicts[c_t])
                        V_inv_dict[c_t] = V_inv
                        indices_pl, indices_rb = S, S
                    else:
                        _, indices_pl = selector(X, H_pl, K)
                        _, indices_rb = selector(X, H_rb, K)

                    ranking_pl = generate_pl_ranking(indices_pl, X, theta_star)
                    ranking_rb = generate_pl_ranking(indices_rb, X, theta_star)
                    pairs = break_into_pairs(ranking_rb)

                    theta_pl = online_update_pl(theta_pl, ranking_pl, X, H_pl, eta_pl)
                    theta_rb = online_update_rb(theta_rb, pairs, X, H_rb, eta_rb)

                    H_pl += pl_hessian(theta_pl, X, ranking_pl)
                    H_rb += rb_hessian(theta_rb, X, pairs)

                    if per_context_state:
                        # Write both estimates back to their own tables.  (The
                        # RB estimate used to be stored into theta_pl_dict,
                        # clobbering the PL estimate and never updating
                        # theta_rb_dict.)
                        theta_pl_dict[c_t], theta_rb_dict[c_t] = theta_pl, theta_rb
                        H_pl_dict[c_t], H_rb_dict[c_t] = H_pl, H_rb

                    # Regret of the greedy item under each estimate, measured
                    # with the true parameter.
                    scores_star = X @ theta_star
                    opt_val = np.max(scores_star)
                    pl_val = scores_star[np.argmax(X @ theta_pl)]
                    rb_val = scores_star[np.argmax(X @ theta_rb)]

                    pl_regret_block.append(opt_val - pl_val)
                    rb_regret_block.append(opt_val - rb_val)

                    if t % update_frequency == 0 or t == 1:
                        mean_pl_regrets.append(np.mean(pl_regret_block))
                        mean_rb_regrets.append(np.mean(rb_regret_block))
                        pl_regret_block, rb_regret_block = [], []

                # Flush the trailing partial block, if any.
                if T % update_frequency != 0:
                    mean_pl_regrets.append(np.mean(pl_regret_block))
                    mean_rb_regrets.append(np.mean(rb_regret_block))

                all_trials_pl.append(mean_pl_regrets)
                all_trials_rb.append(mean_rb_regrets)

            mean_pl = np.mean(all_trials_pl, axis=0)
            std_pl = np.std(all_trials_pl, axis=0)
            mean_rb = np.mean(all_trials_rb, axis=0)
            std_rb = np.std(all_trials_rb, axis=0)

            results_by_K_pl[K][name] = (mean_pl, std_pl)
            results_by_K_rb[K][name] = (mean_rb, std_rb)

    # Make sure the plot directory exists, as is done for the checkpoint
    # directory below.
    os.makedirs("plot", exist_ok=True)
    # K here is the last value iterated from K_list.
    plot_selector_regret_subplots_by_method(
        results_by_K_pl,
        results_by_K_rb,
        f"plot/selector_comparison_{env}_T{T}_d{d}_N{N}_K{K}_C{C}_{info}.pdf",
        T,
        K_list,
        update_frequency
    )

    os.makedirs("results_checkpoint", exist_ok=True)
    with open(f"results_checkpoint/selector_comparison_results_{env}_{info}.pkl", "wb") as f:
        pickle.dump({
            "results_by_K_pl": results_by_K_pl,
            "results_by_K_rb": results_by_K_rb,
            "K_list": K_list,
            "T": T,
            "REPEATS": REPEATS
        }, f)

