import os

# Cap the BLAS / OpenMP thread pools at one thread per process so the
# multiprocessing workers started below do not oversubscribe the CPU.
# These variables are read when the numerical libraries are first loaded,
# so they must be set BEFORE numpy (and the project `experiment` module,
# which presumably imports numerical libraries) is imported. Setting them
# after the import has no effect on the already-initialised thread pools,
# which forked workers inherit.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'

import pickle
import sys
import multiprocessing as mp
from functools import partial

import numpy as np

from experiment import experiment_synthetic, experiment_real

# Function for single synthetic data experiment
def run_single_synthetic_experiment(setting, K, seed, T, m, d, epsilon, delta, signal_norm, sigma, norm_outliers, eta, lbd_factors, n_fold, T_iteration, path, L, policy, offline_learn_method):
    """Run a single synthetic-data experiment and save its results as ``.npz``.

    Intended as a ``multiprocessing.Pool.starmap`` worker: it never raises.
    Failures are reported on stdout (message) and stderr (traceback), so one
    bad configuration does not abort the whole sweep.

    Parameters
    ----------
    setting, K, signal_norm, sigma, delta, epsilon, norm_outliers :
        Forwarded to ``experiment_synthetic.getsamples``. ``K`` is also
        written into the ``m{}`` field of the output file name.
    seed : int
        Random seed, forwarded to ``getsamples``. Also the last component of
        the output file name.
    T : int
        Total sample budget. Each of the ``m`` users gets ``T // m`` samples.
    m : int
        Number of users.
    d : int
        Feature dimension.
    eta, n_fold, T_iteration :
        Step size, number of CV folds and number of iterations, forwarded to
        ``run``.
    lbd_factors : sequence of float
        Each factor scales ``sqrt(d / n) * ones(m)`` into one candidate
        regularisation vector.
    path : str
        Output directory.
    L : int
        Only used in the output file name.
    policy : str
        ``'uniform'``; any other value is tagged ``'half'``.
    offline_learn_method : str
        ``'LinUCB_ind'``; any other value is treated as ``'random'``.

    Returns
    -------
    None
        Results are written to ``path``. If the result file already exists,
        the experiment is skipped.
    """
    try:
        n = T // m
        if n == 0:
            # Without this guard sqrt(d / n) below fails with a bare
            # "division by zero" that hides the real cause.
            raise ValueError(f'T={T} is smaller than m={m}; each user needs at least one sample')
        ratios = np.sqrt(d / n) * np.ones(m)
        lbd_list = [factor * ratios for factor in lbd_factors]

        dataset_tag = 'synthetic'
        algo_tag = 'ARMUL'
        # The same string serves as the file-name tag and as the argument to getsamples.
        offline_method = 'LinUCB_ind' if offline_learn_method == 'LinUCB_ind' else 'random'
        userdist_tag = 'uniform' if policy == 'uniform' else 'half'
        policy_tag = f'{userdist_tag}_{offline_method}'
        nu = m  # Convention: nu is the number of users (m=1000 in the current experiments)
        # NOTE: the 'm{}' field of the file name holds K, not the user count m.
        save_file = path + '/{}_{}_{}_T{}nu{}d{}m{}L{}_{}.npz'.format(dataset_tag, algo_tag, policy_tag, T, nu, d, K, L, seed)
        if os.path.exists(save_file):
            print('seed {} T {} is done before'.format(seed, T))
            return

        print('starting seed {} T {} (policy: {})'.format(seed, T, policy_tag))

        test = experiment_synthetic(n, m, d)
        test.getsamples(setting=setting, K=K, signal_norm=signal_norm, sigma=sigma,
                        delta=delta, epsilon=epsilon, norm_outliers=norm_outliers, seed=seed, offline_method=offline_method, policy=policy)
        results = test.run(lbd_list, n_fold=n_fold, eta=eta, T=T_iteration)

        # Write to a temporary file and rename it atomically. An interrupted
        # run therefore never leaves a truncated .npz behind, which the
        # existence check above would otherwise treat as "done" forever.
        # Writing through an open file handle stops numpy from appending ".npz".
        tmp_file = save_file + '.tmp'
        with open(tmp_file, 'wb') as fh:
            np.savez_compressed(fh, **results)
        os.replace(tmp_file, save_file)
        print('seed {} T {} finished (policy: {})'.format(seed, T, policy_tag))

    except Exception as e:
        import traceback
        print(f'Error in seed {seed} T {T}: {str(e)}')
        # Keep the worker alive, but do not lose where the failure happened.
        traceback.print_exc()

# run synthetic data experiments
def run_synthetic(T_lists, L = 20, n_processes=None, policy='uniform', offline_learn_method='random'):
    """Run the synthetic-data sweep for seeds 1..10 and every T in ``T_lists``.

    Each (seed, T) combination is dispatched to
    ``run_single_synthetic_experiment`` through a ``multiprocessing.Pool``.
    Results are written under ``artifacts/output_data``.

    Parameters
    ----------
    T_lists : iterable of int
        Total sample budgets to sweep. Any iterable is accepted, including a
        one-shot generator.
    L : int
        Forwarded to the worker, which uses it only in the output file name.
    n_processes : int or None
        Pool size. ``None`` means ``min(cpu_count, number of experiments)``.
    policy : str
        User distribution, forwarded to the worker.
    offline_learn_method : str
        Offline selection method, forwarded to the worker.
    """
    setting = 'clustered'
    K = 10  # number of clusters (increased from 3 to 10)
    S = 10  # number of random seeds
    seed_list = np.arange(1, S + 1)  # random seeds 1..S
    m = 1000  # number of users
    T_iteration = 10
    n_fold = 2  # for CV

    d = 20
    epsilon = 0  # single value
    delta = 0    # single value
    signal_norm = 2
    sigma = 1
    norm_outliers = 2

    eta = 0.01  # step size (reduced learning rate)
    lbd_factors = [0.5]

    path = 'artifacts/output_data'

    # Materialise once: the original nested loop re-iterated T_lists for every
    # seed, so a generator would have been exhausted after the first seed.
    T_values = list(T_lists)

    # One argument tuple per (seed, T), in seed-major order.
    experiments = [
        (setting, K, seed, T, m, d, epsilon, delta,
         signal_norm, sigma, norm_outliers, eta, lbd_factors,
         n_fold, T_iteration, path, L, policy, offline_learn_method)
        for seed in seed_list
        for T in T_values
    ]
    if not experiments:
        # Nothing to run. This also avoids mp.Pool(processes=0), which raises ValueError.
        return

    os.makedirs(path, exist_ok=True)

    if n_processes is None:
        n_processes = min(mp.cpu_count(), len(experiments))

    with mp.Pool(processes=n_processes) as pool:
        pool.starmap(run_single_synthetic_experiment, experiments)

# Function for single real data experiment
def run_single_real_experiment(setting, K, seed, T, m, d, epsilon, delta, signal_norm, sigma, norm_outliers, eta, lbd_factors, n_fold, T_iteration, path, dataset_name, L, policy, offline_learn_method):
    """Run a single real-data experiment and save its results as ``.npz``.

    Intended as a ``multiprocessing.Pool.starmap`` worker: it never raises.
    Failures are reported on stdout (message) and stderr (traceback).

    Parameters
    ----------
    setting, K, signal_norm, sigma, delta, epsilon, norm_outliers :
        Forwarded to ``experiment_real.getsamples``. ``K`` is also written into
        the ``m{}`` field of the output file name.
    seed : int
        Random seed, forwarded to ``getsamples``. Also the last component of
        the output file name.
    T : int
        Total sample budget. Each of the ``m`` users gets ``T // m`` samples.
    m : int
        Number of users.
    d : int
        Feature dimension.
    eta, n_fold, T_iteration :
        Step size, number of CV folds and number of iterations, forwarded to
        ``run``.
    lbd_factors : sequence of float
        Each factor scales ``sqrt(d / n) * ones(m)`` into one candidate
        regularisation vector.
    path : str
        Output directory.
    dataset_name : str
        ``'ml'`` or ``'yelp'``. Used as the file-name tag and passed to
        ``getsamples``.
    L : int
        Only used in the output file name.
    policy : str
        ``'uniform'``; any other value is tagged ``'half'``.
    offline_learn_method : str
        ``'LinUCB_ind'`` (passed on as ``'linucb'``); any other value is
        treated as ``'random'``.

    Returns
    -------
    None
        Results are written to ``path``. Unlike the synthetic runner, an
        existing result file is NOT skipped: it is recomputed and overwritten.
    """
    try:
        n = T // m
        if n == 0:
            # Without this guard sqrt(d / n) below fails with a bare
            # "division by zero" that hides the real cause.
            raise ValueError(f'T={T} is smaller than m={m}; each user needs at least one sample')
        ratios = np.sqrt(d / n) * np.ones(m)
        lbd_list = [factor * ratios for factor in lbd_factors]

        dataset_tag = dataset_name  # 'ml' or 'yelp'
        algo_tag = 'ARMUL'
        offline_tag = 'LinUCB_ind' if offline_learn_method == 'LinUCB_ind' else 'random'
        userdist_tag = 'uniform' if policy == 'uniform' else 'half'
        policy_tag = f'{userdist_tag}_{offline_tag}'
        nu = m  # Convention: nu is the number of users (m=1000 in the current experiments)
        # NOTE: the 'm{}' field of the file name holds K, not the user count m.
        save_file = path + '/{}_{}_{}_T{}nu{}d{}m{}L{}_{}.npz'.format(dataset_tag, algo_tag, policy_tag, T, nu, d, K, L, seed)
        # NOTE(review): there is no "already done" check here (it was commented
        # out), so existing results are recomputed and overwritten.

        print('starting seed {} T {} for dataset {} (policy: {})'.format(seed, T, dataset_name, policy_tag))

        test = experiment_real(n, m, d)
        # experiment_real expects the lowercase 'linucb' spelling.
        offline_method = 'linucb' if offline_learn_method == 'LinUCB_ind' else 'random'
        test.getsamples(setting=setting, K=K, signal_norm=signal_norm, sigma=sigma,
                        delta=delta, epsilon=epsilon, norm_outliers=norm_outliers, seed=seed, dataset=dataset_name, offline_method=offline_method, policy=policy)
        results = test.run(lbd_list, n_fold=n_fold, eta=eta, T=T_iteration)

        # Write to a temporary file and rename it atomically, so an interrupted
        # run never leaves a truncated .npz behind. Writing through an open
        # file handle stops numpy from appending ".npz".
        tmp_file = save_file + '.tmp'
        with open(tmp_file, 'wb') as fh:
            np.savez_compressed(fh, **results)
        os.replace(tmp_file, save_file)

        print('seed {} T {} finished for dataset {} (policy: {})'.format(seed, T, dataset_name, policy_tag))

    except Exception as e:
        import traceback
        print(f'Error in seed {seed} T {T} for dataset {dataset_name}: {str(e)}')
        # Keep the worker alive, but do not lose where the failure happened.
        traceback.print_exc()

# run real data experiments
def run_real(T_lists, dataset_name='ml', L = 20, n_processes=None, policy='uniform', offline_learn_method='random'):
    """Run the real-data sweep for seeds 1..10 and every T in ``T_lists``.

    Each (T, seed) combination is dispatched to ``run_single_real_experiment``
    through a ``multiprocessing.Pool``. Results are written under
    ``artifacts/output_data_real``.

    Parameters
    ----------
    T_lists : iterable of int
        Total sample budgets to sweep. Any iterable is accepted, including a
        one-shot generator.
    dataset_name : str
        ``'ml'`` or ``'yelp'``.
    L : int
        Forwarded to the worker, which uses it only in the output file name.
    n_processes : int or None
        Pool size. ``None`` means ``min(cpu_count, number of experiments)``.
    policy : str
        User distribution, forwarded to the worker.
    offline_learn_method : str
        Offline selection method, forwarded to the worker.
    """
    setting = 'clustered'
    K = 10  # number of clusters (increased from 3 to 10)
    S = 10  # number of random seeds
    seed_list = np.arange(1, S + 1)  # random seeds 1..S
    d = 20
    epsilon = 0  # single value
    delta = 0    # single value

    signal_norm = 2
    sigma = 1
    norm_outliers = 2

    eta = 0.01  # step size (reduced learning rate)
    lbd_factors = [0.5]
    path = 'artifacts/output_data_real'
    m = 1000  # number of users
    T_iteration = 10
    n_fold = 2  # for CV

    # One argument tuple per (T, seed), in T-major order. Iterating T_lists
    # only once keeps one-shot generators working.
    experiments = [
        (setting, K, seed, T, m, d, epsilon, delta,
         signal_norm, sigma, norm_outliers, eta, lbd_factors,
         n_fold, T_iteration, path, dataset_name, L, policy, offline_learn_method)
        for T in T_lists
        for seed in seed_list
    ]
    if not experiments:
        # Nothing to run. This also avoids mp.Pool(processes=0), which raises ValueError.
        print(f'No experiments to run for dataset {dataset_name}')
        return

    os.makedirs(path, exist_ok=True)

    if n_processes is None:
        n_processes = min(mp.cpu_count(), len(experiments))

    print(f'Running {len(experiments)} experiments with {n_processes} processes for dataset {dataset_name}')

    with mp.Pool(processes=n_processes) as pool:
        pool.starmap(run_single_real_experiment, experiments)

if __name__ == "__main__":
    # Sweep 1: short horizons T = 5000..100000. Both offline selection methods
    # are run. Real datasets use the uniform user distribution; synthetic data
    # is run under both user distributions.
    # (Pass offline_learn_method='LinUCB_ind' to use LinUCB offline selection.)
    short_horizons = [5000 * k for k in range(1, 21)]
    for method in ['random', 'LinUCB_ind']:
        for name in ['ml', 'yelp']:
            run_real(short_horizons, dataset_name=name, n_processes=10, policy='uniform', offline_learn_method=method)
        for user_policy in ['uniform', 'half']:
            run_synthetic(short_horizons, n_processes=10, policy=user_policy, offline_learn_method=method)

    # Sweep 2: long horizons T = 200000..1000000, with uniform users and
    # random offline selection only.
    long_horizons = [200000 * k for k in range(1, 6)]
    for name in ['ml', 'yelp']:
        run_real(long_horizons, dataset_name=name, n_processes=10, policy='uniform', offline_learn_method='random')
    run_synthetic(long_horizons, n_processes=10, policy='uniform', offline_learn_method='random')