# Runs the baseline experiments for each value in a list of T values.
import numpy as np
import time
import random
from scipy.stats import ortho_group
from sklearn.cluster import DBSCAN
from itertools import product
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
import sys

# Add project root to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from OffClusBandit.core.Environment import Environment
from OffClusBandit.core.utils import generate_items, edge_probability, generate_gap_items
from OffClusBandit.core.Tools import build_method_filename, create_algorithm_instance
from OffClusBandit.configs.default import methods_baseline,alpha_list, gamma_list, choose_gamma_alpha_list


def run_experiment(params):
    """Run one baseline configuration and save its results to an ``.npz`` file.

    ``params`` is a 14-element tuple laid out as
    ``(is_empirical, method_name, envir, seed, dataset, uniforms_str, T, nu,
    d, m, L, choose_gamma_alpha, alpha, gamma)``.

    Any exception raised while creating, running, or saving the experiment is
    caught and reported on stdout, so in that case nothing is written and the
    function still returns ``None``. A tuple of the wrong length raises
    ``ValueError`` because the unpacking happens before the ``try``.
    """
    (
        is_empirical, method_name, envir, seed, dataset, uniforms_str,
        T, nu, d, m, L, choose_gamma_alpha, alpha, gamma,
    ) = params
    try:
        algorithm = create_algorithm_instance(
            method_name, nu, d, T, L, choose_gamma_alpha, alpha, gamma
        )

        # Only the call to run() is timed; setup and saving are excluded.
        started_at = time.time()
        gamma_estimate = algorithm.run(envir)
        run_time = time.time() - started_at

        file_name = build_method_filename(
            is_empirical=is_empirical,
            dataset=dataset,
            method=methods_baseline[method_name],
            uniforms_str=uniforms_str,
            choose_gamma_alpha=choose_gamma_alpha,
            alpha=alpha,
            gamma=gamma,
            T=T,
            nu=nu,
            d=d,
            m=m,
            L=L,
            seed=seed,
        )

        # Arrays stored in the archive, keyed by the name used to load them.
        results = {
            "seed": seed,
            "test_rewards": algorithm.test_rewards,
            "best_test_rewards": algorithm.best_test_rewards,
            "run_time": run_time,
            "gamma_estimate": gamma_estimate,
        }
        np.savez(file_name, **results)
        print(f"Completed: {file_name}")
    except Exception as e:
        # Best-effort: one failing configuration must not abort the sweep.
        print(f"Error in experiment {method_name} with alpha={alpha}, gamma={gamma}, seed={seed}: {e}")

def get_parameter_list(config, gamma_list, alpha_list, choose_gamma_alpha_list):
    """Pick the hyper-parameter values to sweep for one method.

    For each of the three flags in ``config`` the corresponding candidate list
    is returned when the flag is truthy; otherwise a single placeholder value
    is used (``[0]`` for gamma, ``[0.1]`` for alpha and for
    choose_gamma_alpha).

    Args:
        config: Mapping with the keys ``'gamma_required'``,
            ``'alpha_varying'`` and ``'choose_gamma_alpha_required'``.
        gamma_list: Candidate gamma values.
        alpha_list: Candidate alpha values.
        choose_gamma_alpha_list: Candidate choose_gamma_alpha values.

    Returns:
        Tuple ``(gammas, alphas, choose_gamma_alphas)``.

    Raises:
        KeyError: If ``config`` lacks one of the three keys.
    """
    if config['gamma_required']:
        gammas = gamma_list
    else:
        gammas = [0]

    if config['alpha_varying']:
        alphas = alpha_list
    else:
        alphas = [0.1]

    if config['choose_gamma_alpha_required']:
        choose_gamma_alphas = choose_gamma_alpha_list
    else:
        choose_gamma_alphas = [0.1]

    return gammas, alphas, choose_gamma_alphas

def main_unified(is_empirical, T_list, num_users, d, m, L, pj, seed, dataset='', max_workers=None):
    """Build one environment and run every baseline configuration on it in parallel.

    Generates ``m`` vectors with ``generate_gap_items``, assigns them to users
    in contiguous blocks, selects the user-frequency vector indexed by ``pj``,
    and submits one ``run_experiment`` task per combination of T value,
    baseline method, and hyper-parameter values to a process pool.

    Args:
        is_empirical: Forwarded unchanged to every ``run_experiment`` task.
        T_list: Iterable of T values; every configuration is run once per value.
        num_users: Number of users in the environment.
        d: Passed as ``d`` to ``generate_gap_items`` and ``Environment``
            (presumably the vector dimension; confirm against those helpers).
        m: Number of user groups; one vector is generated per group.
        L: Forwarded to ``Environment`` and to every task.
        pj: Index selecting the frequency vector: 0 'uniform', 1 'half',
            2 'arbitrary'.
        seed: Seed applied to both ``numpy.random`` and ``random``.
        dataset: Label forwarded to every task.
        max_workers: Passed to ``ProcessPoolExecutor``.

    Raises:
        IndexError: If ``pj`` is outside the three available settings.
    """

    def _get_theta(thetam, num_users, m):
        # Users are split into m contiguous blocks of k users each, and every
        # user in block j shares thetam[j].
        # NOTE(review): when num_users is not a multiple of m, the trailing
        # users get no entry; confirm callers always pass a multiple.
        k = int(num_users / m)
        return {i: thetam[j] for j in range(m) for i in range(k * j, k * (j + 1))}

    # NOTE(review): these vectors are generated before the RNGs are seeded
    # below, so they may not be reproducible from `seed` if
    # generate_gap_items draws from the global random state. The order is kept
    # so that existing random streams are unchanged; confirm the intent.
    thetam = generate_gap_items(num_items=m, d=d, gap=0.05)
    theta = _get_theta(thetam, num_users, m)

    print(f"Seed = {seed}")
    np.random.seed(seed)
    random.seed(seed)

    uniforms = ['uniform', 'half', 'arbitrary']

    def _get_half_frequency_vector(num_users, m):
        # Draw one probability per group, then spread it evenly over the k
        # users of that group.
        p0 = list(np.random.dirichlet(np.ones(m)))
        p = np.ones(num_users)
        k_ = int(num_users / m)
        for jj_ in range(m):
            p[k_ * jj_:k_ * (jj_ + 1)] = p0[jj_] / k_
        return list(p)

    # All three vectors are built even though only one is used: the
    # 'arbitrary' draw comes after the 'half' draw, so skipping either would
    # change the random numbers produced for a given seed.
    ps = [
        list(np.ones(num_users) / num_users),
        _get_half_frequency_vector(num_users=num_users, m=m),
        list(np.random.dirichlet(np.ones(num_users))),
    ]
    p = ps[pj]
    uniforms_str = uniforms[pj]
    envir = Environment(L=L, d=d, m=m, num_users=num_users, p=p, theta=theta)

    # One task per (T, method, choose_gamma_alpha, alpha, gamma) combination.
    tasks = []
    for T in T_list:
        for method_name, config in methods_baseline.items():
            gammas_to_run, alphas_to_run, choose_gamma_alphas_to_run = get_parameter_list(
                config, gamma_list, alpha_list, choose_gamma_alpha_list
            )
            for choose_gamma_alpha_, alpha_, gamma_ in product(
                choose_gamma_alphas_to_run, alphas_to_run, gammas_to_run
            ):
                tasks.append((
                    is_empirical, method_name, envir, seed, dataset, uniforms_str, T,
                    num_users, d, m, L, choose_gamma_alpha_, alpha_, gamma_,
                ))

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(run_experiment, task): task for task in tasks}
        for future in as_completed(futures):
            # run_experiment handles its own exceptions, but failures outside
            # it (e.g. arguments that cannot be pickled, or a worker process
            # dying) only surface through result(). Report them instead of
            # dropping them, and keep going with the remaining tasks.
            try:
                future.result()
            except Exception as e:
                failed = futures[future]
                print(
                    f"Worker failed for experiment {failed[1]} with T={failed[6]}, "
                    f"alpha={failed[12]}, gamma={failed[13]}, seed={failed[3]}: {e}"
                )


def main_synthetic():
    """Run the baseline sweep under the 'synthetic' dataset label.

    Calls ``main_unified`` once per user-frequency setting ``pj`` in
    ``range(2)`` and per seed in ``range(1, 2)``, with all other arguments
    fixed.
    """
    # Arguments that stay the same for every call.
    shared_args = {
        'is_empirical': 0,
        'T_list': [100000 * i for i in range(1, 2)],
        'num_users': 1000,
        'd': 20,
        'm': 10,
        'L': 20,
        'dataset': 'synthetic',
        'max_workers': 10,
    }
    for pj in range(2):
        for seed in range(1, 2):
            main_unified(pj=pj, seed=seed, **shared_args)

def main_real():
    """Run the baseline sweep under the 'ml' and 'yelp' dataset labels.

    Calls ``main_unified`` once per dataset label and per seed in
    ``range(1, 2)``, always with ``pj=0`` and ``is_empirical=1``.
    """
    # Arguments that stay the same for every call.
    shared_args = {
        'is_empirical': 1,
        'T_list': [100000 * i for i in range(1, 2)],
        'num_users': 1000,
        'd': 20,
        'm': 10,
        'L': 20,
        'pj': 0,
        'max_workers': 10,
    }
    for dataset in ('ml', 'yelp'):
        for seed in range(1, 2):
            main_unified(seed=seed, dataset=dataset, **shared_args)

if __name__ == "__main__":
    # The guard keeps the sweeps from starting when this module is imported
    # rather than run as a script.
    # Real-data sweep first, then the synthetic one.
    for entry_point in (main_real, main_synthetic):
        entry_point()