## This script runs the experiments for every value of T in T_list (across all configured methods and gamma values).
import numpy as np
import time
import random
from scipy.stats import ortho_group
from sklearn.cluster import DBSCAN
from itertools import product
from concurrent.futures import ProcessPoolExecutor, as_completed
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from OffClusBandit.core.Environment import Environment
from OffClusBandit.core.utils import generate_items, edge_probability, generate_gap_items
from OffClusBandit.core.Tools import build_method_filename, create_algorithm_instance
from OffClusBandit.configs.default import alpha_list, choose_gamma_alpha_list
from OffClusBandit.configs.default import plot_gamma_methods as methods
from OffClusBandit.configs.default import gamma_vary_list as gamma_list
# from OffClusBandit.configs.default import methods_baseline as methods


def run_experiment(params):
    """Run a single experiment configuration and save its results with ``np.savez``.

    ``params`` is the 14-tuple
    ``(is_empirical, method_name, envir, seed, dataset, uniforms_str, T, nu, d, m, L,
    choose_gamma_alpha, alpha, gamma)``.

    The saved archive contains ``seed``, the algorithm's ``test_rewards`` and
    ``best_test_rewards``, the elapsed ``run_time`` of ``algo.run`` and the
    ``gamma_estimate`` that ``algo.run`` returned.

    Any exception raised while building, running or saving the experiment is
    reported (message plus full traceback) and swallowed, so one failing
    configuration does not abort the rest of a process-pool sweep.

    Returns:
        The result file name on success, ``None`` if the experiment failed.
    """
    import traceback

    (is_empirical, method_name, envir, seed, dataset, uniforms_str, T, nu, d, m, L,
     choose_gamma_alpha, alpha, gamma) = params
    try:
        algo = create_algorithm_instance(method_name, nu, d, T, L, choose_gamma_alpha, alpha, gamma)
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments during long runs.
        start_time = time.perf_counter()
        gamma_estimate = algo.run(envir)
        run_time = time.perf_counter() - start_time
        method = methods[method_name]
        file_name = build_method_filename(
            is_empirical=is_empirical, dataset=dataset, method=method,
            uniforms_str=uniforms_str, choose_gamma_alpha=choose_gamma_alpha, alpha=alpha,
            gamma=gamma, T=T, nu=nu, d=d, m=m, L=L, seed=seed,
            offline_learn_method=algo.offline_learn_method,
        )

        np.savez(
            file_name,
            seed=seed,
            test_rewards=algo.test_rewards,
            best_test_rewards=algo.best_test_rewards,
            run_time=run_time,
            gamma_estimate=gamma_estimate
        )
    except Exception as e:
        print(f"Error in experiment {method_name} with alpha={alpha}, gamma={gamma}, seed={seed}: {e}")
        # Without the traceback, errors raised inside worker processes are
        # nearly impossible to locate from the one-line message alone.
        traceback.print_exc()
        return None
    print(f"Completed: {file_name}")
    return file_name

def get_parameter_list(config, gamma_list, alpha_list, choose_gamma_alpha_list):
    """Select the gamma, alpha and choose-gamma-alpha values to sweep for one method.

    Each list argument is used only when its matching flag in ``config`` is
    truthy (``'gamma_required'``, ``'alpha_varying'``,
    ``'choose_gamma_alpha_required'`` respectively). Otherwise a single fixed
    fallback is used: ``[0]`` for gamma and ``[0.1]`` for both alpha settings.

    Returns:
        tuple: ``(gammas, alphas, choose_gamma_alphas)``.
    """
    def _pick(flag, values, fallback):
        # Use the full sweep list when the flag is set, else a one-element list.
        return values if config[flag] else [fallback]

    return (
        _pick('gamma_required', gamma_list, 0),
        _pick('alpha_varying', alpha_list, 0.1),
        _pick('choose_gamma_alpha_required', choose_gamma_alpha_list, 0.1),
    )

def _build_theta(thetam, num_users, m):
    """Assign each user the parameter vector of its group.

    Users are split into ``m`` contiguous groups of ``k = num_users // m``:
    user ``i`` (for ``i < k * m``) gets ``thetam[i // k]``.
    NOTE(review): if ``num_users`` is not a multiple of ``m``, the last
    ``num_users % m`` users get no entry at all — confirm callers avoid this.
    """
    k = num_users // m
    return {i: thetam[i // k] for i in range(k * m)}


def _half_frequency_vector(num_users, m):
    """Build a per-user frequency vector that is uniform within each user group.

    Group weights are drawn from a flat Dirichlet over ``m`` groups, and each
    group's weight is split evenly among its ``k = num_users // m`` users.
    NOTE(review): users beyond ``k * m`` keep the placeholder value 1.0, so the
    vector only sums to 1 when ``num_users`` is a multiple of ``m``.
    """
    group_weights = list(np.random.dirichlet(np.ones(m)))
    k = num_users // m
    p = np.ones(num_users)
    for j, weight in enumerate(group_weights):
        for i in range(k * j, k * (j + 1)):
            p[i] = weight / k
    return list(p)


def _seed_everything(seed):
    """Seed both NumPy's global RNG and Python's ``random`` module."""
    np.random.seed(seed)
    random.seed(seed)


def main_unified(is_empirical, T_list, num_users, d, m, L, pj, seed, dataset='', filename='', max_workers=None,gamma_list=None):
    """Build one environment and run every (T, method, gamma) experiment on it in parallel.

    Args:
        is_empirical: flag forwarded to ``run_experiment`` (and from there into
            the result file name).
        T_list: values of T to sweep.
        num_users, d, m, L: sizes forwarded to ``Environment`` and the algorithms;
            ``m`` is also the number of user groups used to build theta and the
            "half" frequency vector.
        pj: index of the user-frequency setting: 0 = uniform, 1 = "half"
            (uniform within Dirichlet-weighted groups), 2 = "arbitrary"
            (Dirichlet over all users).
        seed: seed for NumPy and ``random``; also forwarded with every task.
        dataset: dataset name forwarded to each task.
        filename: path passed to ``np.load`` for per-user theta; when empty,
            synthetic theta is generated with ``generate_gap_items``.
        max_workers: process-pool size (``None`` lets the executor decide).
        gamma_list: gamma values to sweep; defaults to the configured
            ``gamma_vary_list`` when ``None``.
    """
    if gamma_list is None:
        # The parameter shadows the module-level import of the same name, so
        # fetch the configured default explicitly instead of iterating None.
        from OffClusBandit.configs.default import gamma_vary_list as gamma_list

    print(f"Seed = {seed}")
    # Seed before generating synthetic theta so it is reproducible from `seed`.
    _seed_everything(seed)
    if filename != '':
        theta = np.load(filename)
    else:
        thetam = generate_gap_items(num_items=m, d=d, gap=0.05)
        theta = _build_theta(thetam, num_users, m)
    # Re-seed so the frequency draws below do not depend on how much randomness
    # theta generation consumed (and match between loaded and generated theta).
    _seed_everything(seed)

    uniforms = ['uniform', 'half', 'arbitrary']
    # All three vectors are drawn, in this order, regardless of pj; keep it that
    # way so the global RNG draws do not depend on which setting is selected.
    ps = [
        list(np.ones(num_users) / num_users),
        _half_frequency_vector(num_users=num_users, m=m),
        list(np.random.dirichlet(np.ones(num_users))),
    ]
    p = ps[pj]

    envir = Environment(L=L, d=d, m=m, num_users=num_users, p=p, theta=theta)

    # One task per (T, method, gamma), T outermost; the tuple layout must match
    # the unpacking at the top of run_experiment.
    tasks = [
        (is_empirical, method_name, envir, seed, dataset, uniforms[pj], T, num_users,
         d, m, L, choose_gamma_alpha_list[0], alpha_list[0], gamma)
        for T, method_name, gamma in product(T_list, methods, gamma_list)
    ]

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        future_to_task = {executor.submit(run_experiment, task): task for task in tasks}
        for future in as_completed(future_to_task):
            try:
                future.result()
            except Exception as exc:
                # run_experiment catches its own exceptions, so anything raised
                # here is a pool-level failure (e.g. a worker process died or
                # the task could not be pickled). Report it instead of dropping it.
                task = future_to_task[future]
                print(f"Worker failure for method={task[1]}, T={task[6]}, "
                      f"gamma={task[13]}, seed={task[3]}: {exc!r}")


def main_synthetic():
    """Run the T sweep on synthetic (generated) user parameters.

    Configuration: 1000 users, m=10, d=20, L=20, T in {30000, 35000, 40000},
    frequency setting pj=0, seed 100, 10 worker processes, and the
    module-level ``gamma_list``.
    """
    dataset = "synthetic"
    is_empirical = 0
    nu = 1000
    m = 10
    T_list = [5000 * i for i in range(6, 9)]
    # Widen these ranges to sweep more frequency settings (pj in 0..2) or seeds.
    for pj in range(1):
        for seed in range(100, 101):
            main_unified(
                is_empirical=is_empirical, T_list=T_list, num_users=nu, d=20, m=m, L=20,
                pj=pj, seed=seed, dataset=dataset,
                max_workers=10, gamma_list=gamma_list,
            )
def main_real():
    """Run the T sweep on each pre-computed dataset, "yelp" then "ml".

    Per-user parameters are loaded from
    ``OffClusBandit/data/datasets/<dataset>_1000user_d20.npy`` (a path relative
    to the current working directory). Shared settings: 1000 users, m=10,
    d=20, L=20, pj=0, seed 100, 10 worker processes, module-level ``gamma_list``.
    """
    shared_kwargs = dict(
        is_empirical=1,
        T_list=[5000 * i for i in range(6, 9)],
        num_users=1000,
        d=20,
        m=10,
        L=20,
        pj=0,
        seed=100,
        max_workers=10,
        gamma_list=gamma_list,
    )
    for dataset in ("yelp", "ml"):
        data_path = f'OffClusBandit/data/datasets/{dataset}_1000user_d20.npy'
        main_unified(dataset=dataset, filename=data_path, **shared_kwargs)
if __name__ == "__main__":
    # Synthetic sweep first, then the real-data sweep.
    for _entry_point in (main_synthetic, main_real):
        _entry_point()