from tqdm import tqdm

from utils.constants import *
from utils.general_utils import save_results, create_plots_from_files, generate_normal_x0_x_star
from utils.simulations_utils import comparison

# Correlation parameters swept by run_all_normal_rho_lambda_comparison (two independent lists).
RHO1_VALS = [0, 0.5, -0.5]
RHO2_VALS = [0, 0.5, -0.5]
# x-axis tick positions used when plotting the lambda comparison results.
LAMBDA_X_TICKS = np.arange(0, 1.1, 0.1).round(1)


# run functions
def run_normal_rho_lambda_comparison(base_desc='normal_rho1_rho2_lambda_comparison_no_regret', rho1=0, rho2=0,
                                     lam_vals=LAM_VALS_, k=BASE_K, n=BASE_N, s=BASE_S, B=B_,
                                     eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Runs the normal distribution lambda comparison simulation of the proportional ranking functions with specific rho values.

    The linear, root and log proportional ranking functions are compared over ``lam_vals`` via
    ``comparison``. The (x0, x*) generator is ``generate_normal_x0_x_star`` with ``rho1``/``rho2`` bound.
    numpy's global RNG is seeded with ``SEED`` before the simulation starts. The results are saved with
    ``save_results`` under ``base_desc``, with its 'rho1'/'rho2' placeholders replaced by ``str(rho1)``/``str(rho2)``.

    Args:
        base_desc (str, optional): Template description of the simulation; must contain the substrings
            'rho1' and 'rho2'. Default to 'normal_rho1_rho2_lambda_comparison_no_regret'.
        rho1 (float, optional): correlation parameter 1. Default to 0.
        rho2 (float, optional): correlation parameter 2. Default to 0.
        lam_vals (iterable, optional): collection of lambda values. A single-pass iterator is
            materialized into a list so it can be both simulated over and saved. Default to LAM_VALS_.
        k (int, optional): embedding space dimension. Default to BASE_K.
        n (int, optional): number of publishers. Default to BASE_N.
        s (int, optional): demand function support size. Default to BASE_S.
        B (int, optional): number of samples. Default to B_.
        eta (float, optional): learning rate. Default to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Default to T_REGRET_.
        eps (float, optional): epsilon value for epsilon-PNE. Default to EPS_REGRET_.

    Raises:
        AssertionError: if ``base_desc`` does not contain both 'rho1' and 'rho2'.
    """
    assert 'rho1' in base_desc and 'rho2' in base_desc, 'The description must contain rho1 and rho2'

    # lam_vals is consumed twice (by the tqdm-wrapped simulation and by save_results). A one-shot
    # iterator would reach save_results already exhausted, so materialize it. Lists/arrays are untouched.
    if iter(lam_vals) is lam_vals:
        lam_vals = list(lam_vals)

    np.random.seed(SEED)
    ranking_function_lst = [LINEAR_PROPORTIONAL, ROOT_PROPORTIONAL, LOG_PROPORTIONAL]
    # No additional parameter for any of the ranking functions; keep the lists aligned.
    additional_param_lst = [None] * len(ranking_function_lst)
    params = {'eta': eta, 'T': T, 'eps': eps}

    def generate_x0_x_star(k, n, s):
        # Bind this run's rho values into the (k, n, s) generator signature that comparison expects.
        # rho1/rho2 are never rebound below, so the closure always sees the numeric values.
        return generate_normal_x0_x_star(k, n, s, rho1, rho2)

    res = comparison(ranking_function_lst, additional_param_lst,
                     k, n, s, tqdm(lam_vals, miniters=1), B, params,
                     generate_x0_x_star=generate_x0_x_star)

    desc = base_desc.replace('rho1', str(rho1)).replace('rho2', str(rho2))
    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k, n, s, lam_vals, B, params)
    
    
def run_all_normal_rho_lambda_comparison(base_desc='normal_rho1_rho2_lambda_comparison_no_regret',
                                         rho1_vals=RHO1_VALS, rho2_vals=RHO2_VALS, lam_vals=LAM_VALS_,
                                         k=BASE_K, n=BASE_N, s=BASE_S, B=B_, eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Runs the normal distribution lambda comparison simulation of the proportional ranking functions with combinations of rho values.

    ``run_normal_rho_lambda_comparison`` is called once for every pair in ``rho1_vals`` x ``rho2_vals``
    (outer loop over rho1, inner loop over rho2), printing the pair before each run.

    Args:
        base_desc (str, optional): Template description of the simulation; must contain the substrings
            'rho1' and 'rho2'. Default to 'normal_rho1_rho2_lambda_comparison_no_regret'.
        rho1_vals (iterable, optional): collection of correlation parameter 1 values. Default to RHO1_VALS.
        rho2_vals (iterable, optional): collection of correlation parameter 2 values. A single-pass
            iterator is materialized into a list, since it is re-iterated for every rho1. Default to RHO2_VALS.
        lam_vals (iterable, optional): collection of lambda values. A single-pass iterator is
            materialized into a list, since it is reused by every run. Default to LAM_VALS_.
        k (int, optional): embedding space dimension. Default to BASE_K.
        n (int, optional): number of publishers. Default to BASE_N.
        s (int, optional): demand function support size. Default to BASE_S.
        B (int, optional): number of samples. Default to B_.
        eta (float, optional): learning rate. Default to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Default to T_REGRET_.
        eps (float, optional): epsilon value for epsilon-PNE. Default to EPS_REGRET_.

    Raises:
        AssertionError: if ``base_desc`` does not contain both 'rho1' and 'rho2'.
    """
    # Checked here as well so a bad template fails before any simulation starts.
    assert 'rho1' in base_desc and 'rho2' in base_desc, 'The description must contain rho1 and rho2'

    # rho2_vals is iterated once per rho1 and lam_vals once per run. A one-shot iterator would be
    # exhausted after the first pass and silently skip the remaining combinations (or run them with no
    # lambdas). Lists/arrays are left untouched.
    if iter(rho2_vals) is rho2_vals:
        rho2_vals = list(rho2_vals)
    if iter(lam_vals) is lam_vals:
        lam_vals = list(lam_vals)

    for rho1 in rho1_vals:
        for rho2 in rho2_vals:
            print(f'Running comparison for rho1={rho1}, rho2={rho2}')
            run_normal_rho_lambda_comparison(base_desc, rho1, rho2, lam_vals,
                                             k, n, s, B, eta, T, eps)


# load functions
def load_normal_rho_lam_comparison(base_desc='normal_rho1_rho2_lambda_comparison_no_regret', rho1=0, rho2=0):
    """Loads and plots the results of the normal distribution lambda comparison simulation with specific rho values.

    The saved results are looked up under ``base_desc``, with its 'rho1'/'rho2' placeholders replaced by
    ``str(rho1)``/``str(rho2)``. They are then plotted with ``create_plots_from_files``, using
    LAMBDA_X_TICKS as x-axis ticks and the proportional ranking-function colors/labels.

    Args:
        base_desc (str, optional): Template description of the simulation; must contain the substrings
            'rho1' and 'rho2'. Default to 'normal_rho1_rho2_lambda_comparison_no_regret'.
        rho1 (float, optional): correlation parameter 1. Default to 0.
        rho2 (float, optional): correlation parameter 2. Default to 0.

    Raises:
        AssertionError: if ``base_desc`` does not contain both 'rho1' and 'rho2'.
    """
    assert 'rho1' in base_desc and 'rho2' in base_desc, 'The description must contain rho1 and rho2'

    # Fill the placeholders in order: 'rho1' first, then 'rho2'.
    desc = base_desc
    for placeholder, value in (('rho1', rho1), ('rho2', rho2)):
        desc = desc.replace(placeholder, str(value))

    create_plots_from_files(
        desc,
        lam_key_function,
        ranking_function_label_function,
        PROPORTIONAL_COLORS,
        x_ticks=LAMBDA_X_TICKS,
        split=True,
        labels_sort_key=ranking_function_label_function,
    )
    