from tqdm import tqdm

from utils.constants import *
from utils.general_utils import save_results, create_plots_from_files, create_single_plot_from_files
from utils.simulations_utils import comparison, no_regret_dynamics_regret_calc

# TEX_SCALE = 1.5
# PARAMS_FIGSIZE = (6.4 * TEX_SCALE, 4.8 * TEX_SCALE)
# Default graph spec shared by the regret plot loaders below: a single graph
# labelled AVG_REGRET, reading entry 2 of each saved result.
REGRET_GRAPHS = [dict(axis_name=AVG_REGRET, index=2)]

# run functions
def run_calc_regret_lambda_comparison(desc='regret_calc_lambda_comparison', lam_vals=LAM_VALS_, k=BASE_K, n=BASE_N,
                                      s=BASE_S, B=B_, eta=ETA_REGRET_, T=T_REGRET_CALC_):
    """Compare the regret of the proportional ranking functions for different lambda values.

    Runs ``comparison`` with ``no_regret_dynamics_regret_calc`` as the simulation for the
    linear, root and log proportional ranking functions (none of them takes an additional
    parameter), sweeping over ``lam_vals``. The outcome is passed to ``save_results`` under
    ``desc`` with ``CALC_REGRET_SUFFIX``.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'regret_calc_lambda_comparison'.
        lam_vals (iterable, optional): collection of lambda values. One-shot iterators
            (e.g. generators) are materialized into a list first. Defaults to LAM_VALS_.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        n (int, optional): number of publishers. Defaults to BASE_N.
        s (int, optional): demand function support size. Defaults to BASE_S.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_CALC_.
    """
    # The values reach `comparison` wrapped in tqdm and are then reused for
    # `save_results`; iterating the tqdm wrapper would exhaust a one-shot iterator,
    # so materialize it. Re-iterable containers (lists, arrays, ranges) pass through as-is.
    if iter(lam_vals) is lam_vals:
        lam_vals = list(lam_vals)

    np.random.seed(SEED)  # reproducible simulations
    ranking_function_lst = [LINEAR_PROPORTIONAL, ROOT_PROPORTIONAL, LOG_PROPORTIONAL]
    additional_param_lst = [None] * len(ranking_function_lst)
    params = {'eta': eta, 'T': T}

    res = comparison(ranking_function_lst, additional_param_lst,
                     k, n, s, tqdm(lam_vals, miniters=1), B, params,
                     simulation_func=no_regret_dynamics_regret_calc)

    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k, n, s, lam_vals, B, params, suffix=CALC_REGRET_SUFFIX)


def run_calc_regret_n_comparison(desc='regret_calc_n_comparison', n_vals=N_VALS_, k=BASE_K, s=BASE_S, lam=BASE_LAM,
                                 B=B_, eta=ETA_REGRET_, T=T_REGRET_CALC_):
    """Compare the regret of the proportional ranking functions for different n values.

    Runs ``comparison`` with ``no_regret_dynamics_regret_calc`` as the simulation for the
    linear, root and log proportional ranking functions (none of them takes an additional
    parameter), sweeping over ``n_vals``. The outcome is passed to ``save_results`` under
    ``desc`` with ``CALC_REGRET_SUFFIX``.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'regret_calc_n_comparison'.
        n_vals (iterable, optional): collection of n values. One-shot iterators
            (e.g. generators) are materialized into a list first. Defaults to N_VALS_.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        s (int, optional): demand function support size. Defaults to BASE_S.
        lam (float, optional): regularization parameter. Defaults to BASE_LAM.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_CALC_.
    """
    # The values reach `comparison` wrapped in tqdm and are then reused for
    # `save_results`; iterating the tqdm wrapper would exhaust a one-shot iterator,
    # so materialize it. Re-iterable containers (lists, arrays, ranges) pass through as-is.
    if iter(n_vals) is n_vals:
        n_vals = list(n_vals)

    np.random.seed(SEED)  # reproducible simulations
    ranking_function_lst = [LINEAR_PROPORTIONAL, ROOT_PROPORTIONAL, LOG_PROPORTIONAL]
    additional_param_lst = [None] * len(ranking_function_lst)
    params = {'eta': eta, 'T': T}

    res = comparison(ranking_function_lst, additional_param_lst,
                     k, tqdm(n_vals, miniters=1), s, lam, B, params,
                     simulation_func=no_regret_dynamics_regret_calc)

    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k, n_vals, s, lam, B, params, suffix=CALC_REGRET_SUFFIX)


def run_calc_regret_s_comparison(desc='regret_calc_s_comparison', s_vals=S_VALS_, k=BASE_K, n=BASE_N, lam=BASE_LAM,
                                 B=B_, eta=ETA_REGRET_, T=T_REGRET_CALC_):
    """Compare the regret of the proportional ranking functions for different s values.

    Runs ``comparison`` with ``no_regret_dynamics_regret_calc`` as the simulation for the
    linear, root and log proportional ranking functions (none of them takes an additional
    parameter), sweeping over ``s_vals``. The outcome is passed to ``save_results`` under
    ``desc`` with ``CALC_REGRET_SUFFIX``.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'regret_calc_s_comparison'.
        s_vals (iterable, optional): collection of s values. One-shot iterators
            (e.g. generators) are materialized into a list first. Defaults to S_VALS_.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        n (int, optional): number of publishers. Defaults to BASE_N.
        lam (float, optional): regularization parameter. Defaults to BASE_LAM.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_CALC_.
    """
    # The values reach `comparison` wrapped in tqdm and are then reused for
    # `save_results`; iterating the tqdm wrapper would exhaust a one-shot iterator,
    # so materialize it. Re-iterable containers (lists, arrays, ranges) pass through as-is.
    if iter(s_vals) is s_vals:
        s_vals = list(s_vals)

    np.random.seed(SEED)  # reproducible simulations
    ranking_function_lst = [LINEAR_PROPORTIONAL, ROOT_PROPORTIONAL, LOG_PROPORTIONAL]
    additional_param_lst = [None] * len(ranking_function_lst)
    params = {'eta': eta, 'T': T}

    res = comparison(ranking_function_lst, additional_param_lst,
                     k, n, tqdm(s_vals, miniters=1), lam, B, params,
                     simulation_func=no_regret_dynamics_regret_calc)

    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k, n, s_vals, lam, B, params, suffix=CALC_REGRET_SUFFIX)


def run_calc_regret_k_comparison(desc='regret_calc_k_comparison', k_vals=K_VALS_, n=BASE_N, s=BASE_S, lam=BASE_LAM,
                                 B=B_, eta=ETA_REGRET_, T=10 * T_REGRET_CALC_):
    """Compare the regret of the proportional ranking functions for different k values.

    Runs ``comparison`` with ``no_regret_dynamics_regret_calc`` as the simulation for the
    linear, root and log proportional ranking functions (none of them takes an additional
    parameter), sweeping over ``k_vals``. The outcome is passed to ``save_results`` under
    ``desc`` with ``CALC_REGRET_SUFFIX``.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'regret_calc_k_comparison'.
        k_vals (iterable, optional): collection of k values. One-shot iterators
            (e.g. generators) are materialized into a list first. Defaults to K_VALS_.
        n (int, optional): number of publishers. Defaults to BASE_N.
        s (int, optional): demand function support size. Defaults to BASE_S.
        lam (float, optional): regularization parameter. Defaults to BASE_LAM.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to 10 * T_REGRET_CALC_.
    """
    # The values reach `comparison` wrapped in tqdm and are then reused for
    # `save_results`; iterating the tqdm wrapper would exhaust a one-shot iterator,
    # so materialize it. Re-iterable containers (lists, arrays, ranges) pass through as-is.
    if iter(k_vals) is k_vals:
        k_vals = list(k_vals)

    np.random.seed(SEED)  # reproducible simulations
    ranking_function_lst = [LINEAR_PROPORTIONAL, ROOT_PROPORTIONAL, LOG_PROPORTIONAL]
    additional_param_lst = [None] * len(ranking_function_lst)
    params = {'eta': eta, 'T': T}

    res = comparison(ranking_function_lst, additional_param_lst,
                     tqdm(k_vals, miniters=1), n, s, lam, B, params,
                     simulation_func=no_regret_dynamics_regret_calc)

    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k_vals, n, s, lam, B, params, suffix=CALC_REGRET_SUFFIX)


def run_calc_regret_linear_b_comparison(desc='regret_calc_linear_b_comparison', b_vals=INTERCEPT_VALS,
                                        k=BASE_K, n=BASE_N, s=BASE_S, lam=BASE_LAM, B=B_, eta=ETA_REGRET_,
                                        T=T_REGRET_CALC_):
    """Compare the regret of the linear proportional ranking function for different intercept values.

    Every value in ``b_vals`` is paired with ``LINEAR_PROPORTIONAL`` as its additional
    parameter, and ``comparison`` is run with ``no_regret_dynamics_regret_calc`` as the
    simulation. The outcome is passed to ``save_results`` under ``desc`` with
    ``CALC_REGRET_SUFFIX``.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'regret_calc_linear_b_comparison'.
        b_vals (iterable, optional): collection of intercept values. One-shot iterators
            (e.g. generators) are materialized into a list first. Defaults to INTERCEPT_VALS.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        n (int, optional): number of publishers. Defaults to BASE_N.
        s (int, optional): demand function support size. Defaults to BASE_S.
        lam (float, optional): lambda value. Defaults to BASE_LAM.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_CALC_.
    """
    # len() needs a sized container and the values are used twice (comparison and
    # save_results), so a one-shot iterator must be materialized first.
    if iter(b_vals) is b_vals:
        b_vals = list(b_vals)

    np.random.seed(SEED)  # reproducible simulations
    ranking_function_lst = [LINEAR_PROPORTIONAL] * len(b_vals)
    additional_param_lst = b_vals
    params = {'eta': eta, 'T': T}

    res = comparison(tqdm(ranking_function_lst, miniters=1), additional_param_lst,
                     k, n, s, lam, B, params, simulation_func=no_regret_dynamics_regret_calc)

    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k, n, s, lam, B, params, suffix=CALC_REGRET_SUFFIX)


def run_calc_regret_root_a_comparison(desc='regret_calc_root_a_comparison', a_vals=POWER_VALS,
                                      k=BASE_K, n=BASE_N, s=BASE_S, lam=BASE_LAM, B=B_,
                                      eta=ETA_REGRET_, T=T_REGRET_CALC_):
    """Compare the regret of the root proportional ranking function for different power values.

    Every value in ``a_vals`` is paired with ``ROOT_PROPORTIONAL`` as its additional
    parameter, and ``comparison`` is run with ``no_regret_dynamics_regret_calc`` as the
    simulation. The outcome is passed to ``save_results`` under ``desc`` with
    ``CALC_REGRET_SUFFIX``.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'regret_calc_root_a_comparison'.
        a_vals (iterable, optional): collection of power values. One-shot iterators
            (e.g. generators) are materialized into a list first. Defaults to POWER_VALS.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        n (int, optional): number of publishers. Defaults to BASE_N.
        s (int, optional): demand function support size. Defaults to BASE_S.
        lam (float, optional): lambda value. Defaults to BASE_LAM.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_CALC_.
    """
    # len() needs a sized container and the values are used twice (comparison and
    # save_results), so a one-shot iterator must be materialized first.
    if iter(a_vals) is a_vals:
        a_vals = list(a_vals)

    np.random.seed(SEED)  # reproducible simulations
    ranking_function_lst = [ROOT_PROPORTIONAL] * len(a_vals)
    additional_param_lst = a_vals
    params = {'eta': eta, 'T': T}

    res = comparison(tqdm(ranking_function_lst, miniters=1), additional_param_lst,
                     k, n, s, lam, B, params, simulation_func=no_regret_dynamics_regret_calc)

    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k, n, s, lam, B, params, suffix=CALC_REGRET_SUFFIX)


def run_calc_regret_log_c_comparison(desc='regret_calc_log_c_comparison', c_vals=SHIFT_VALS,
                                     k=BASE_K, n=BASE_N, s=BASE_S, lam=BASE_LAM, B=B_,
                                     eta=ETA_REGRET_, T=T_REGRET_CALC_):
    """Compare the regret of the log proportional ranking function for different shift values.

    Every value in ``c_vals`` is paired with ``LOG_PROPORTIONAL`` as its additional
    parameter, and ``comparison`` is run with ``no_regret_dynamics_regret_calc`` as the
    simulation. The outcome is passed to ``save_results`` under ``desc`` with
    ``CALC_REGRET_SUFFIX``.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'regret_calc_log_c_comparison'.
        c_vals (iterable, optional): collection of shift values. One-shot iterators
            (e.g. generators) are materialized into a list first. Defaults to SHIFT_VALS.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        n (int, optional): number of publishers. Defaults to BASE_N.
        s (int, optional): demand function support size. Defaults to BASE_S.
        lam (float, optional): lambda value. Defaults to BASE_LAM.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_CALC_.
    """
    # len() needs a sized container and the values are used twice (comparison and
    # save_results), so a one-shot iterator must be materialized first.
    if iter(c_vals) is c_vals:
        c_vals = list(c_vals)

    np.random.seed(SEED)  # reproducible simulations
    ranking_function_lst = [LOG_PROPORTIONAL] * len(c_vals)
    additional_param_lst = c_vals
    params = {'eta': eta, 'T': T}

    res = comparison(tqdm(ranking_function_lst, miniters=1), additional_param_lst,
                     k, n, s, lam, B, params, simulation_func=no_regret_dynamics_regret_calc)

    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k, n, s, lam, B, params, suffix=CALC_REGRET_SUFFIX)


# load functions
def load_calc_regret_lambda_results(desc='regret_calc_lambda_comparison'):
    """Plot the saved regret results of the lambda comparison.

    Delegates to ``create_plots_from_files`` with lambda-valued x ticks
    0.0, 0.1, ..., 1.0 (rounded to one decimal) and the ``REGRET_GRAPHS`` spec.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'regret_calc_lambda_comparison'.
    """
    lambda_ticks = np.arange(0, 1.1, 0.1)
    lambda_ticks = np.round(lambda_ticks, 1)  # strip float noise from the tick labels
    create_plots_from_files(
        desc,
        lam_key_function,
        ranking_function_label_function,
        PROPORTIONAL_COLORS,
        x_ticks=lambda_ticks,
        graphs=REGRET_GRAPHS,
        labels_sort_key=ranking_function_label_function,
        figsize=None,
    )


def load_calc_regret_n_results(desc='regret_calc_n_comparison'):
    """Plot the saved regret results of the n (number of publishers) comparison.

    Delegates to ``create_plots_from_files`` with ``x_ticks=ALL`` and the
    ``REGRET_GRAPHS`` spec.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'regret_calc_n_comparison'.
    """
    create_plots_from_files(
        desc, n_key_function, ranking_function_label_function, PROPORTIONAL_COLORS,
        graphs=REGRET_GRAPHS,
        x_ticks=ALL,
        figsize=None,
        labels_sort_key=ranking_function_label_function,
    )


def load_calc_regret_s_results(desc='regret_calc_s_comparison'):
    """Plot the saved regret results of the s (demand support size) comparison.

    Delegates to ``create_plots_from_files`` with x ticks at 1 and then every
    multiple of 5 up to 50, using the ``REGRET_GRAPHS`` spec.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'regret_calc_s_comparison'.
    """
    support_ticks = [1, *range(5, 51, 5)]
    create_plots_from_files(
        desc, s_key_function, ranking_function_label_function, PROPORTIONAL_COLORS,
        graphs=REGRET_GRAPHS,
        x_ticks=support_ticks,
        figsize=None,
        labels_sort_key=ranking_function_label_function,
    )


def load_calc_regret_k_results(desc='regret_calc_k_comparison'):
    """Plot the saved regret results of the k (embedding dimension) comparison.

    Unlike the other loaders, this one uses a custom graph spec: the average-regret
    graph gets both its limits and its ticks from ``EXTENDED_AVG_REGRET``.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'regret_calc_k_comparison'.
    """
    extended_regret_graph = dict(axis_name=AVG_REGRET, index=2,
                                 limits=EXTENDED_AVG_REGRET, ticks=EXTENDED_AVG_REGRET)
    create_plots_from_files(
        desc, k_key_function, ranking_function_label_function, PROPORTIONAL_COLORS,
        graphs=[extended_regret_graph],
        x_ticks=ALL,
        figsize=None,
        labels_sort_key=ranking_function_label_function,
    )


def load_calc_regret_linear_b_results(desc='regret_calc_linear_b_comparison'):
    """Plot the saved regret results of the linear proportional intercept sweep.

    Delegates to ``create_single_plot_from_files`` using the linear proportional
    color, ``INTERCEPT_STR`` as the parameter name and ``float_sort_key`` to order
    the curves.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'regret_calc_linear_b_comparison'.
    """
    line_color = PROPORTIONAL_COLORS[LINEAR_PROPORTIONAL]
    create_single_plot_from_files(
        desc, additional_param_label_function, line_color, INTERCEPT_STR,
        graphs=REGRET_GRAPHS,
        x_ticks=ALL,
        figsize=None,
        functions_sort_key=float_sort_key,
    )


def load_calc_regret_root_a_results(desc='regret_calc_root_a_comparison'):
    """Plot the saved regret results of the root proportional power sweep.

    Delegates to ``create_single_plot_from_files`` using the root proportional
    color, ``POWER_STR`` as the parameter name and ``float_sort_key`` to order
    the curves.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'regret_calc_root_a_comparison'.
    """
    line_color = PROPORTIONAL_COLORS[ROOT_PROPORTIONAL]
    create_single_plot_from_files(
        desc, additional_param_label_function, line_color, POWER_STR,
        graphs=REGRET_GRAPHS,
        x_ticks=ALL,
        figsize=None,
        functions_sort_key=float_sort_key,
    )


def load_calc_regret_log_c_results(desc='regret_calc_log_c_comparison'):
    """Plot the saved regret results of the logarithmic proportional shift sweep.

    Delegates to ``create_single_plot_from_files`` using the log proportional
    color, ``SHIFT_STR`` as the parameter name and ``float_sort_key`` to order
    the curves.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'regret_calc_log_c_comparison'.
    """
    line_color = PROPORTIONAL_COLORS[LOG_PROPORTIONAL]
    create_single_plot_from_files(
        desc, additional_param_label_function, line_color, SHIFT_STR,
        graphs=REGRET_GRAPHS,
        x_ticks=ALL,
        figsize=None,
        functions_sort_key=float_sort_key,
    )
