from tqdm import tqdm

from utils.constants import *
from utils.general_utils import save_results, create_single_plot_from_files
from utils.simulations_utils import comparison


# run functions
def run_linear_b_no_regret_comparison(desc='proportional_linear_b_comparison_no_regret', b_vals=INTERCEPT_VALS,
                                      k=BASE_K, n=BASE_N, s=BASE_S, lam=BASE_LAM, B=B_,
                                      eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Compare the linear proportional ranking function across intercept values.

    Seeds NumPy's global random generator with SEED, runs ``comparison`` with
    one LINEAR_PROPORTIONAL entry per intercept value, and hands the outcome
    to ``save_results`` together with the full configuration. Nothing is
    returned.

    Args:
        desc (str, optional): Description of the simulation, forwarded to
            ``save_results``. Defaults to 'proportional_linear_b_comparison_no_regret'.
        b_vals (collection, optional): Intercept values; must support ``len()``.
            Defaults to INTERCEPT_VALS.
        k (int, optional): Embedding space dimension. Defaults to BASE_K.
        n (int, optional): Number of publishers. Defaults to BASE_N.
        s (int, optional): Demand function support size. Defaults to BASE_S.
        lam (float, optional): Lambda value. Defaults to BASE_LAM.
        B (int, optional): Number of samples. Defaults to B_.
        eta (float, optional): Learning rate. Defaults to ETA_REGRET_.
        T (int, optional): Number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): Epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    # Fix the global NumPy seed before anything else runs.
    np.random.seed(SEED)

    # One ranking-function entry for every intercept value under test.
    ranking_functions = [LINEAR_PROPORTIONAL for _ in range(len(b_vals))]
    regret_params = {
        'eta': eta,
        'T': T,
        'eps': eps,
    }

    # The ranking-function list is wrapped in a progress bar for comparison().
    progress = tqdm(ranking_functions, miniters=1)
    results = comparison(progress, b_vals, k, n, s, lam, B, regret_params)

    save_results(results, desc, ranking_functions, b_vals, k, n, s, lam, B, regret_params)


def run_root_a_no_regret_comparison(desc='proportional_root_a_comparison_no_regret', a_vals=POWER_VALS,
                                    k=BASE_K, n=BASE_N, s=BASE_S, lam=BASE_LAM, B=B_,
                                    eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Compare the root proportional ranking function across power values.

    Seeds NumPy's global random generator with SEED, runs ``comparison`` with
    one ROOT_PROPORTIONAL entry per power value, and hands the outcome to
    ``save_results`` together with the full configuration. Nothing is
    returned.

    Args:
        desc (str, optional): Description of the simulation, forwarded to
            ``save_results``. Defaults to 'proportional_root_a_comparison_no_regret'.
        a_vals (collection, optional): Power values; must support ``len()``.
            Defaults to POWER_VALS.
        k (int, optional): Embedding space dimension. Defaults to BASE_K.
        n (int, optional): Number of publishers. Defaults to BASE_N.
        s (int, optional): Demand function support size. Defaults to BASE_S.
        lam (float, optional): Lambda value. Defaults to BASE_LAM.
        B (int, optional): Number of samples. Defaults to B_.
        eta (float, optional): Learning rate. Defaults to ETA_REGRET_.
        T (int, optional): Number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): Epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    # Fix the global NumPy seed before anything else runs.
    np.random.seed(SEED)

    # One ranking-function entry for every power value under test.
    ranking_functions = [ROOT_PROPORTIONAL for _ in range(len(a_vals))]
    regret_params = {
        'eta': eta,
        'T': T,
        'eps': eps,
    }

    # The ranking-function list is wrapped in a progress bar for comparison().
    progress = tqdm(ranking_functions, miniters=1)
    results = comparison(progress, a_vals, k, n, s, lam, B, regret_params)

    save_results(results, desc, ranking_functions, a_vals, k, n, s, lam, B, regret_params)


def run_log_c_no_regret_comparison(desc='proportional_log_c_comparison_no_regret', c_vals=SHIFT_VALS,
                                   k=BASE_K, n=BASE_N, s=BASE_S, lam=BASE_LAM, B=B_,
                                   eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Compare the logarithmic proportional ranking function across shift values.

    Seeds NumPy's global random generator with SEED, runs ``comparison`` with
    one LOG_PROPORTIONAL entry per shift value, and hands the outcome to
    ``save_results`` together with the full configuration. Nothing is
    returned.

    Args:
        desc (str, optional): Description of the simulation, forwarded to
            ``save_results``. Defaults to 'proportional_log_c_comparison_no_regret'.
        c_vals (collection, optional): Shift values; must support ``len()``.
            Defaults to SHIFT_VALS.
        k (int, optional): Embedding space dimension. Defaults to BASE_K.
        n (int, optional): Number of publishers. Defaults to BASE_N.
        s (int, optional): Demand function support size. Defaults to BASE_S.
        lam (float, optional): Lambda value. Defaults to BASE_LAM.
        B (int, optional): Number of samples. Defaults to B_.
        eta (float, optional): Learning rate. Defaults to ETA_REGRET_.
        T (int, optional): Number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): Epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    # Fix the global NumPy seed before anything else runs.
    np.random.seed(SEED)

    # One ranking-function entry for every shift value under test.
    ranking_functions = [LOG_PROPORTIONAL for _ in range(len(c_vals))]
    regret_params = {
        'eta': eta,
        'T': T,
        'eps': eps,
    }

    # The ranking-function list is wrapped in a progress bar for comparison().
    progress = tqdm(ranking_functions, miniters=1)
    results = comparison(progress, c_vals, k, n, s, lam, B, regret_params)

    save_results(results, desc, ranking_functions, c_vals, k, n, s, lam, B, regret_params)


# load functions
def load_linear_b_no_regret_results(desc='proportional_linear_b_comparison_no_regret'):
    """Plot the stored linear proportional comparison over intercept values.

    Delegates to ``create_single_plot_from_files`` using the colour registered
    for LINEAR_PROPORTIONAL and INTERCEPT_STR as the parameter label string.
    Nothing is returned.

    Args:
        desc (str, optional): Directory name of the results. Defaults to
            'proportional_linear_b_comparison_no_regret'.
    """
    plot_color = PROPORTIONAL_COLORS[LINEAR_PROPORTIONAL]
    create_single_plot_from_files(
        desc,
        additional_param_label_function,
        plot_color,
        INTERCEPT_STR,
        x_ticks=ALL,
        functions_sort_key=float_sort_key,
    )


def load_root_a_no_regret_results(desc='proportional_root_a_comparison_no_regret'):
    """Plot the stored root proportional comparison over power values.

    Delegates to ``create_single_plot_from_files`` using the colour registered
    for ROOT_PROPORTIONAL and POWER_STR as the parameter label string.
    Nothing is returned.

    Args:
        desc (str, optional): Directory name of the results. Defaults to
            'proportional_root_a_comparison_no_regret'.
    """
    plot_color = PROPORTIONAL_COLORS[ROOT_PROPORTIONAL]
    create_single_plot_from_files(
        desc,
        additional_param_label_function,
        plot_color,
        POWER_STR,
        x_ticks=ALL,
        functions_sort_key=float_sort_key,
    )


def load_log_c_no_regret_results(desc='proportional_log_c_comparison_no_regret'):
    """Plot the stored logarithmic proportional comparison over shift values.

    Delegates to ``create_single_plot_from_files`` using the colour registered
    for LOG_PROPORTIONAL and SHIFT_STR as the parameter label string.
    Nothing is returned.

    Args:
        desc (str, optional): Directory name of the results. Defaults to
            'proportional_log_c_comparison_no_regret'.
    """
    plot_color = PROPORTIONAL_COLORS[LOG_PROPORTIONAL]
    create_single_plot_from_files(
        desc,
        additional_param_label_function,
        plot_color,
        SHIFT_STR,
        x_ticks=ALL,
        functions_sort_key=float_sort_key,
    )
