from tqdm import tqdm

from utils.constants import *
from utils.general_utils import save_results, create_plots_from_files
from utils.simulations_utils import comparison


# run functions
def run_lambda_no_regret_comparison(desc='lambda_comparison_no_regret', lam_vals=LAM_VALS_, k=BASE_K, n=BASE_N,
                                    s=BASE_S, B=B_, eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Runs the comparison of the proportional ranking functions for different lambda values.

    The linear, root and log proportional ranking functions are compared by ``comparison`` over the
    values in ``lam_vals`` (all other parameters fixed), and the outcome is stored via ``save_results``.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'lambda_comparison_no_regret'.
        lam_vals (iterable, optional): collection of lambda values. Defaults to LAM_VALS_.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        n (int, optional): number of publishers. Defaults to BASE_N.
        s (int, optional): demand function support size. Defaults to BASE_S.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    # lam_vals is consumed (through tqdm) by comparison and then handed to save_results as well.
    # A one-shot iterator/generator would already be exhausted by then, so materialize it first.
    # Re-iterable containers (lists, tuples, numpy arrays) are passed through unchanged.
    if iter(lam_vals) is lam_vals:
        lam_vals = list(lam_vals)

    np.random.seed(SEED)
    ranking_function_lst = [LINEAR_PROPORTIONAL, ROOT_PROPORTIONAL, LOG_PROPORTIONAL]
    # One (absent) additional parameter per ranking function; derived so the lists cannot drift apart.
    additional_param_lst = [None] * len(ranking_function_lst)
    params = {'eta': eta, 'T': T, 'eps': eps}

    res = comparison(ranking_function_lst, additional_param_lst,
                     k, n, s, tqdm(lam_vals, miniters=1), B, params)

    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k, n, s, lam_vals, B, params)


def run_n_no_regret_comparison(desc='n_comparison_no_regret', n_vals=N_VALS_, k=BASE_K, lam=BASE_LAM, s=BASE_S,
                               B=B_, eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Runs the comparison of the proportional ranking functions for different number of publishers.

    The linear, root and log proportional ranking functions are compared by ``comparison`` over the
    values in ``n_vals`` (all other parameters fixed), and the outcome is stored via ``save_results``.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'n_comparison_no_regret'.
        n_vals (iterable, optional): collection of publishers' amounts. Defaults to N_VALS_.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        lam (float, optional): lambda value. Defaults to BASE_LAM.
        s (int, optional): demand function support size. Defaults to BASE_S.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    # n_vals is consumed (through tqdm) by comparison and then handed to save_results as well.
    # A one-shot iterator/generator would already be exhausted by then, so materialize it first.
    # Re-iterable containers (lists, tuples, numpy arrays) are passed through unchanged.
    if iter(n_vals) is n_vals:
        n_vals = list(n_vals)

    np.random.seed(SEED)
    ranking_function_lst = [LINEAR_PROPORTIONAL, ROOT_PROPORTIONAL, LOG_PROPORTIONAL]
    # One (absent) additional parameter per ranking function; derived so the lists cannot drift apart.
    additional_param_lst = [None] * len(ranking_function_lst)
    params = {'eta': eta, 'T': T, 'eps': eps}

    res = comparison(ranking_function_lst, additional_param_lst,
                     k, tqdm(n_vals, miniters=1), s, lam, B, params)

    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k, n_vals, s, lam, B, params)


def run_s_no_regret_comparison(desc='s_comparison_no_regret', s_vals=S_VALS_, k=BASE_K, n=BASE_N, lam=BASE_LAM,
                               B=B_, eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Runs the comparison of the proportional ranking functions for different demand function support sizes.

    The linear, root and log proportional ranking functions are compared by ``comparison`` over the
    values in ``s_vals`` (all other parameters fixed), and the outcome is stored via ``save_results``.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 's_comparison_no_regret'.
        s_vals (iterable, optional): collection of demand function support sizes. Defaults to S_VALS_.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        n (int, optional): number of publishers. Defaults to BASE_N.
        lam (float, optional): lambda value. Defaults to BASE_LAM.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    # s_vals is consumed (through tqdm) by comparison and then handed to save_results as well.
    # A one-shot iterator/generator would already be exhausted by then, so materialize it first.
    # Re-iterable containers (lists, tuples, numpy arrays) are passed through unchanged.
    if iter(s_vals) is s_vals:
        s_vals = list(s_vals)

    np.random.seed(SEED)
    ranking_function_lst = [LINEAR_PROPORTIONAL, ROOT_PROPORTIONAL, LOG_PROPORTIONAL]
    # One (absent) additional parameter per ranking function; derived so the lists cannot drift apart.
    additional_param_lst = [None] * len(ranking_function_lst)
    params = {'eta': eta, 'T': T, 'eps': eps}

    res = comparison(ranking_function_lst, additional_param_lst,
                     k, n, tqdm(s_vals, miniters=1), lam, B, params)

    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k, n, s_vals, lam, B, params)


def run_k_no_regret_comparison(desc='k_comparison_no_regret', k_vals=K_VALS_, n=BASE_N, lam=BASE_LAM, s=BASE_S,
                               B=B_, eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Runs the comparison of the proportional ranking functions for different embedding space dimensions.

    The linear, root and log proportional ranking functions are compared by ``comparison`` over the
    values in ``k_vals`` (all other parameters fixed), and the outcome is stored via ``save_results``.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'k_comparison_no_regret'.
        k_vals (iterable, optional): collection of embedding space dimensions. Defaults to K_VALS_.
        n (int, optional): number of publishers. Defaults to BASE_N.
        lam (float, optional): lambda value. Defaults to BASE_LAM.
        s (int, optional): demand function support size. Defaults to BASE_S.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    # k_vals is consumed (through tqdm) by comparison and then handed to save_results as well.
    # A one-shot iterator/generator would already be exhausted by then, so materialize it first.
    # Re-iterable containers (lists, tuples, numpy arrays) are passed through unchanged.
    if iter(k_vals) is k_vals:
        k_vals = list(k_vals)

    np.random.seed(SEED)
    ranking_function_lst = [LINEAR_PROPORTIONAL, ROOT_PROPORTIONAL, LOG_PROPORTIONAL]
    # One (absent) additional parameter per ranking function; derived so the lists cannot drift apart.
    additional_param_lst = [None] * len(ranking_function_lst)
    params = {'eta': eta, 'T': T, 'eps': eps}

    res = comparison(ranking_function_lst, additional_param_lst,
                     tqdm(k_vals, miniters=1), n, s, lam, B, params)

    save_results(res, desc, ranking_function_lst, additional_param_lst,
                 k_vals, n, s, lam, B, params)


# load functions
def load_lambda_no_regret_results(desc='lambda_comparison_no_regret'):
    """Loads and plots the results of the lambda comparison simulation.

    The x axis is ticked at 0.0, 0.1, ..., 1.0.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'lambda_comparison_no_regret'.
    """
    # linspace with an explicit count avoids the float-step pitfalls of np.arange;
    # rounding strips representation noise such as 0.30000000000000004 from the tick labels.
    x_ticks = np.round(np.linspace(0, 1, 11), 1)
    create_plots_from_files(desc, lam_key_function, ranking_function_label_function,
                            PROPORTIONAL_COLORS, x_ticks=x_ticks, labels_sort_key=ranking_function_label_function)


def load_n_no_regret_results(desc='n_comparison_no_regret'):
    """Loads the saved publishers'-amount (n) comparison results and plots them.

    The x axis is ticked with the ``ALL`` option.

    Args:
        desc (str, optional): directory name of the results. Default to 'n_comparison_no_regret'.
    """
    label_fn = ranking_function_label_function
    create_plots_from_files(
        desc,
        n_key_function,
        label_fn,
        PROPORTIONAL_COLORS,
        x_ticks=ALL,
        labels_sort_key=label_fn,
    )


def load_s_no_regret_results(desc='s_comparison_no_regret'):
    """Loads the saved demand-function-support-size (s) comparison results and plots them.

    The x axis is ticked at 1 and then every 5 from 5 to 50.

    Args:
        desc (str, optional): directory name of the results. Default to 's_comparison_no_regret'.
    """
    support_ticks = [1, *range(5, 51, 5)]
    label_fn = ranking_function_label_function
    create_plots_from_files(
        desc,
        s_key_function,
        label_fn,
        PROPORTIONAL_COLORS,
        x_ticks=support_ticks,
        labels_sort_key=label_fn,
    )


def load_k_no_regret_results(desc='k_comparison_no_regret'):
    """Loads the saved embedding-space-dimension (k) comparison results and plots them.

    Three graphs are requested: publishers' welfare, users' welfare, and a convergence-rate graph.
    The convergence-rate graph uses EXTENDED_CONVERGENCE_RATE for both its limits and its ticks.

    Args:
        desc (str, optional): directory name of the results. Default to 'k_comparison_no_regret'.
    """
    convergence_graph = dict(
        axis_name=CONVERGENCE_RATE,
        limits=EXTENDED_CONVERGENCE_RATE,
        ticks=EXTENDED_CONVERGENCE_RATE,
    )
    requested_graphs = [(PUBLISHERS_WELFARE,), (USERS_WELFARE,), convergence_graph]
    label_fn = ranking_function_label_function
    create_plots_from_files(
        desc,
        k_key_function,
        label_fn,
        PROPORTIONAL_COLORS,
        x_ticks=ALL,
        graphs=requested_graphs,
        labels_sort_key=label_fn,
    )
