import numpy as np
import time
import traceback  # Put at the top of the file
import random
from scipy.stats import ortho_group
from sklearn.cluster import DBSCAN
from itertools import product
from concurrent.futures import ProcessPoolExecutor, as_completed
import os
import sys

# Add project root to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from OffClusBandit.core.Environment import Environment
from OffClusBandit.core.utils import generate_items, edge_probability, generate_gap_items
from OffClusBandit.core.Tools import build_method_filename, create_algorithm_instance
from OffClusBandit.configs.default import methods, alpha_list, synthetic_gamma_list, choose_gamma_alpha_list, long_T_synthetic_gamma_list,yelp_gamma_list,ml_gamma_list,long_T_yelp_gamma_list,long_T_ml_gamma_list
# Cap the BLAS / OpenMP thread pools at one thread per process, presumably so
# the parallel experiment workers do not oversubscribe the CPU. setdefault
# keeps any value already exported in the shell.
# NOTE(review): numpy is imported at the top of this file *before* these lines
# run, so they probably cannot change thread pools already initialised in this
# process (or inherited by forked workers); they likely only take effect in
# processes that import numpy afresh (e.g. the "spawn" start method).
# Consider moving them above `import numpy` — confirm.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

def run_experiment(params):
    """Run one experiment configuration (worker entry point) and save its results.

    Args:
        params: tuple ``(is_empirical, method_name, envir, seed, dataset,
            uniforms_str, T, nu, d, m, L, choose_gamma_alpha, alpha, gamma)``
            in the order built by ``main_unified``.

    The run is skipped when its output file already exists. Otherwise the
    algorithm from ``create_algorithm_instance`` is run on ``envir``, and
    ``seed``, ``test_rewards``, ``best_test_rewards``, the elapsed run time
    (seconds) and the gamma estimate returned by ``algo.run`` are written
    with ``np.savez``.

    Exceptions raised while building, running or saving are printed with
    their traceback instead of propagated, so one failing configuration does
    not abort the rest of the sweep. Always returns ``None``.
    """
    (is_empirical, method_name, envir, seed, dataset, uniforms_str, T, nu, d, m, L,
     choose_gamma_alpha, alpha, gamma) = params

    # Every filename argument except ``method`` and ``offline_learn_method``.
    filename_kwargs = dict(
        is_empirical=is_empirical, dataset=dataset, uniforms_str=uniforms_str,
        choose_gamma_alpha=choose_gamma_alpha, alpha=alpha, gamma=gamma,
        T=T, nu=nu, d=d, m=m, L=L, seed=seed,
    )

    def _already_saved(path):
        # np.savez appends ".npz" to names that lack it, so also look for the
        # name it actually writes; otherwise finished runs are never skipped.
        path = str(path)
        return os.path.exists(path) or (
            not path.endswith(".npz") and os.path.exists(path + ".npz")
        )

    try:
        method = methods[method_name]
        # The algorithm's real offline_learn_method is only known after the
        # run, so the skip check assumes "random".
        # NOTE(review): if algo.offline_learn_method is not "random", the saved
        # file gets a different name and this check will never skip it.
        file_name = build_method_filename(
            method=method, offline_learn_method="random", **filename_kwargs)
        if _already_saved(file_name):
            print(f"exist: {file_name}")
            return

        algo = create_algorithm_instance(method_name, nu, d, T, L, choose_gamma_alpha, alpha, gamma)
        start_time = time.perf_counter()  # monotonic clock for elapsed time
        gamma_estimate = algo.run(envir)
        run_time = time.perf_counter() - start_time
        offline_learn_method = algo.offline_learn_method

        # Rebuild the filename with the offline_learn_method the algorithm used.
        file_name = build_method_filename(
            method=method, offline_learn_method=offline_learn_method, **filename_kwargs)

        # np.savez does not create missing parent directories.
        out_dir = os.path.dirname(str(file_name))
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        np.savez(
            file_name,
            seed=seed,
            test_rewards=algo.test_rewards,
            best_test_rewards=algo.best_test_rewards,
            run_time=run_time,
            gamma_estimate=gamma_estimate,
        )
        print(f"Completed: {file_name}")
    except Exception as e:
        # Best-effort worker: report the failure (with line numbers) and let
        # the remaining configurations keep running.
        traceback.print_exc()
        print(f"Error in experiment {method_name} with alpha={alpha}, gamma={gamma}, seed={seed}: {e}")

def get_parameter_list(config, gamma_list, alpha_list, choose_gamma_alpha_list):
    """Resolve which gamma / alpha / choose-gamma-alpha values to sweep for a method.

    Each candidate list is used only when its flag in ``config`` is truthy;
    otherwise a single fixed fallback value is used instead.

    Args:
        config: mapping with the keys ``'gamma_required'``, ``'alpha_varying'``
            and ``'choose_gamma_alpha_required'``.
        gamma_list: candidate gamma values.
        alpha_list: candidate alpha values.
        choose_gamma_alpha_list: candidate choose-gamma-alpha values.

    Returns:
        Tuple ``(gammas, alphas, choose_gamma_alphas)``. The fallbacks are
        ``[0]``, ``[0.1]`` and ``[0.1]`` respectively.
    """
    def pick(flag, candidates, fallback):
        return candidates if config[flag] else fallback

    return (
        pick('gamma_required', gamma_list, [0]),
        pick('alpha_varying', alpha_list, [0.1]),
        pick('choose_gamma_alpha_required', choose_gamma_alpha_list, [0.1]),
    )

def main_unified(is_empirical, T_list, num_users, d, m, L, pj, seed, dataset='', filename='', max_workers=None, best_gamma_list=None):
    """Build one environment and run every configured method on it in parallel.

    The user vectors ``theta`` are either loaded from ``filename`` with
    ``np.load`` or generated synthetically from ``m`` cluster vectors. The
    user-frequency vector is selected by ``pj``. One ``run_experiment`` task
    is submitted per (horizon ``T`` in ``T_list``, method in ``methods``) pair.

    Args:
        is_empirical: flag forwarded to the output-filename builder.
        T_list: horizons to run.
        num_users: number of users in the environment.
        d: feature dimension.
        m: number of user clusters.
        L: forwarded to ``Environment`` and the algorithms.
        pj: user-frequency vector: 0 = uniform, 1 = "half" (a Dirichlet
            weight per cluster, split evenly among its users), 2 = "arbitrary"
            (Dirichlet over all users).
        seed: seed for ``np.random`` and ``random``.
        dataset: dataset label used in output filenames.
        filename: path of a saved ``theta``; empty string => synthetic theta.
        max_workers: process-pool size (``None`` => executor default).
        best_gamma_list: gamma per horizon, indexed by position in ``T_list``.

    Raises:
        ValueError: if ``best_gamma_list`` is missing or shorter than ``T_list``.
    """
    # Fail fast instead of hitting a TypeError/IndexError after setup.
    if T_list and (best_gamma_list is None or len(best_gamma_list) < len(T_list)):
        raise ValueError(
            "best_gamma_list must provide one gamma per entry of T_list "
            f"(got {best_gamma_list!r} for {len(T_list)} horizons)"
        )

    def _get_theta(thetam, num_users, m):
        """Map user ids to cluster vectors: m contiguous, equally sized blocks."""
        k = num_users // m
        # NOTE(review): when num_users % m != 0 the trailing users get no entry.
        return {i: thetam[i // k] for i in range(k * m)}

    def _get_half_frequency_vector(num_users, m):
        """Draw a Dirichlet weight per cluster and split it evenly among its users."""
        p0 = list(np.random.dirichlet(np.ones(m)))
        p = np.ones(num_users)
        k_ = num_users // m
        # NOTE(review): trailing users (num_users % m of them) keep weight 1.
        for jj_ in range(m):
            for ii_ in range(k_ * jj_, k_ * (jj_ + 1)):
                p[ii_] = p0[jj_] / k_
        return list(p)

    print(f"Seed = {seed}")
    # Seed before building theta so the synthetic environment is reproducible.
    np.random.seed(seed)
    random.seed(seed)
    if filename != '':
        theta = np.load(filename)
    else:
        thetam = generate_gap_items(num_items=m, d=d, gap=0.05)
        theta = _get_theta(thetam, num_users, m)
    # Re-seed so the frequency draws below (and the RNG state forked into the
    # workers) depend only on `seed`, however many draws theta generation used.
    np.random.seed(seed)
    random.seed(seed)

    uniforms = ['uniform', 'half', 'arbitrary']
    # All three vectors are drawn, in this order, so the RNG stream does not
    # depend on pj.
    ps = [
        list(np.ones(num_users) / num_users),
        _get_half_frequency_vector(num_users=num_users, m=m),
        list(np.random.dirichlet(np.ones(num_users))),
    ]
    p = ps[pj]

    envir = Environment(L=L, d=d, m=m, num_users=num_users, p=p, theta=theta)

    tasks = [
        (is_empirical, method_name, envir, seed, dataset, uniforms[pj], T, num_users, d, m, L,
         choose_gamma_alpha_list[0], alpha_list[0], best_gamma_list[T_index])
        for T_index, T in enumerate(T_list)
        for method_name in methods
    ]
    if not tasks:
        return

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        future_to_task = {executor.submit(run_experiment, task): task for task in tasks}
        # run_experiment reports its own errors. Failures outside it (e.g.
        # pickling a task, a crashed worker) only surface via Future.result(),
        # so collect every result while the pool is still open.
        for future in as_completed(future_to_task):
            try:
                future.result()
            except Exception as e:
                task = future_to_task[future]
                traceback.print_exc()
                print(f"Worker failed for method {task[1]} with T={task[6]}, gamma={task[13]}, seed={seed}: {e}")


def main_synthetic():
    """Run all configured methods on the synthetic clustered environment.

    For each horizon list in ``T_lists``, the tuned gamma list is picked by
    horizon length: ``long_T_synthetic_gamma_list`` when any horizon exceeds
    100000, ``synthetic_gamma_list`` otherwise. ``main_unified`` is then called
    once per (pj, seed) pair with 1000 users, 10 clusters, d = 20 and L = 20.
    """
    dataset = "synthetic"
    is_empirical = 0
    nu = 1000
    m = 10
    # Horizon grids used in other runs:
    #   short: [5000 * i for i in range(1, 21)]    -> 5k .. 100k
    #   long:  [200000 * i for i in range(1, 6)]   -> 200k .. 1M
    T_lists = [[200000 * i for i in range(1, 6)]]
    for T_list in T_lists:
        if not T_list:
            continue
        # Short horizons go up to 100000; anything longer uses the long-T tuning.
        # NOTE(review): main_unified indexes gammas by position in T_list, so
        # T_list is presumably expected to align with the tuned grid — confirm.
        if max(T_list) > 100000:
            best_gamma_list = long_T_synthetic_gamma_list
        else:
            best_gamma_list = synthetic_gamma_list
        for pj in range(1):
            for seed in range(12, 13):
                main_unified(
                    is_empirical=is_empirical, T_list=T_list, num_users=nu, d=20, m=m, L=20,
                    pj=pj, seed=seed,
                    dataset=dataset, max_workers=len(T_list), best_gamma_list=best_gamma_list,
                )


def main_real():
    """Run all configured methods on the Yelp and MovieLens user embeddings.

    For each dataset, ``theta`` is loaded from
    ``OffClusBandit/data/datasets/{dataset}_1000user_d20.npy``. The tuned
    gamma list is picked by horizon length: the ``long_T_*`` list when any
    horizon exceeds 100000, the short-horizon list otherwise. Previously an
    exact ``T_list[-1] == 100000`` test sent short grids such as ``[5000]`` to
    the long-horizon gammas.
    """
    # Horizon grids used in other runs:
    #   short: [5000 * i for i in range(1, 21)]    -> 5k .. 100k
    #   long:  [200000 * i for i in range(1, 6)]   -> 200k .. 1M
    T_lists = [[5000 * i for i in range(1, 2)]]
    nu = 1000
    m = 10
    pj = 0
    is_empirical = 1
    for dataset in ["yelp", "ml"]:
        # NOTE(review): path is relative to the current working directory.
        filename = f'OffClusBandit/data/datasets/{dataset}_1000user_d20.npy'
        for T_list in T_lists:
            if not T_list:
                continue
            long_horizon = max(T_list) > 100000
            if dataset == "yelp":
                best_gamma_list = long_T_yelp_gamma_list if long_horizon else yelp_gamma_list
            else:
                best_gamma_list = long_T_ml_gamma_list if long_horizon else ml_gamma_list
            for seed in range(12, 13):
                main_unified(
                    is_empirical=is_empirical, T_list=T_list, num_users=nu, d=20, m=m, L=20,
                    pj=pj, seed=seed,
                    dataset=dataset, filename=filename,
                    max_workers=len(T_list), best_gamma_list=best_gamma_list,
                )


if __name__ == "__main__":
    # Script entry point: runs the synthetic experiments. To run the
    # Yelp / MovieLens experiments instead, enable the main_real() call below.
    main_synthetic()
    # main_real()