from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

from utils.constants import *
from utils.general_utils import save_results, create_plots_from_files, create_single_plot_from_files
from utils.simulations_utils import comparison, no_regret_dynamics_regret_calc, calc_regret
from utils.publishers_game import GAME_TYPES

# regret over time functions
def _regret_over_rounds(game, T, eta, range_vals, n):
    """Run no-regret dynamics on `game` and return per-publisher regret at each round in `range_vals`.

    Row i of the returned ``(len(range_vals), n)`` array is ``calc_regret`` evaluated on the
    first ``range_vals[i]`` entries of the dynamics history.
    """
    history, _ = no_regret_dynamics_regret_calc(game, T, eta, return_hist=True)
    regrets = np.zeros(shape=(len(range_vals), n))
    for i, t in enumerate(tqdm(range_vals)):
        regrets[i] = calc_regret(history[:t], game)
    return regrets


def plot_regret_over_time_softmax(desc='softmax_regret_over_time', beta=10, other=LINEAR_PROPORTIONAL,
                                  k=BASE_K, n=BASE_N, s=BASE_S, lam=BASE_LAM,
                                  eta=ETA_REGRET_, T=200, step_size=2):
    """Plots each publisher's regret over time under softmax ranking, optionally against another ranking function.

    A random instance (demand points and initial documents) is drawn with a fixed seed, the
    no-regret dynamics are run for `T` rounds, and the regret of every publisher is evaluated on
    history prefixes of length 1, 1 + step_size, ... up to T. The figure is saved to
    ``{FIGURES_PATH}/{desc}.png`` and then shown.

    Args:
        desc (str, optional): figure file name (without extension). Defaults to 'softmax_regret_over_time'.
        beta (float, optional): beta passed to the softmax game. Defaults to 10.
        other (str or None, optional): key into GAME_TYPES of a second ranking function to plot
            (dashed lines) on the same instance; None to plot softmax only. Defaults to LINEAR_PROPORTIONAL.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        n (int, optional): number of publishers. Defaults to BASE_N.
        s (int, optional): demand function support size. Defaults to BASE_S.
        lam (float, optional): lambda value. Defaults to BASE_LAM.
        eta (float, optional): learning rate of the dynamics. Defaults to ETA_REGRET_.
        T (int, optional): number of dynamics rounds. Defaults to 200.
        step_size (int, optional): spacing between evaluated round counts. Defaults to 2.
    """
    # Fixed seed 0 (not SEED) so the plotted instance is reproducible across runs.
    np.random.seed(0)
    x_star_lst = np.random.rand(s, k)
    x_0 = np.random.rand(n, k)
    range_vals = range(1, T + 1, step_size)

    G1 = GAME_TYPES[SOFTMAX](k, n, s, lam, x_0, x_star_lst, beta=beta)
    regret_per_round_softmax = _regret_over_rounds(G1, T, eta, range_vals, n)

    regret_per_round_other = None
    if other is not None:
        # Same instance (x_0, x_star_lst) so the two curves are directly comparable.
        G2 = GAME_TYPES[other](k, n, s, lam, x_0, x_star_lst)
        regret_per_round_other = _regret_over_rounds(G2, T, eta, range_vals, n)

    print()
    print("Initial documents:")
    for i in range(n):
        print("Publisher", i + 1, "initial document:", x_0[i])
    print()
    print("Demand distribution:")
    # Presumably the game puts uniform mass 1/s on each support point — TODO confirm in GAME_TYPES.
    for i in range(s):
        print(f"Probability 1/{s} for document", x_star_lst[i])

    plt.figure(figsize=(6.4 * 1.5, 4.8 * 1.5))
    colors_v2 = ['#9467bd', '#8c564b', '#bcbd22', '#17becf']
    for i in range(n):
        if i >= len(colors_v2):
            # More publishers than predefined colors: draw a random hex color.
            current_color = '#' + ''.join([np.random.choice(list('0123456789ABCDEF')) for _ in range(6)])
        else:
            current_color = colors_v2[i]
        plt.plot(range_vals, regret_per_round_softmax[:, i], label=f"{SOFTMAX}: Publisher {i+1}", color=current_color)
        if other is not None:
            plt.plot(range_vals, regret_per_round_other[:, i], label=f"{other}: Publisher {i+1}", color=current_color, linestyle='--')
    plt.xlabel("Round", fontsize=LABELS_FONT_SIZE)
    plt.ylabel("Regret", fontsize=LABELS_FONT_SIZE)
    plt.tick_params(axis='both', which='major', labelsize=TICKS_FONT_SIZE)
    plt.legend(fontsize=LEGEND_FONT_SIZE)
    plt.tight_layout()
    plt.savefig(f"{FIGURES_PATH}/{desc}.png")
    plt.show()


# run functions
def run_softmax_beta_no_regret_comparison(desc='softmax_beta_comparison_no_regret', beta_vals=BETA_VALS_,
                                          k=BASE_K, n=BASE_N, s=BASE_S, lam=BASE_LAM, B=B_,
                                          eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Runs the no-regret comparison of the softmax ranking function across beta values and saves it.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'softmax_beta_comparison_no_regret'.
        beta_vals (iterable, optional): collection of beta values. Defaults to BETA_VALS_.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        n (int, optional): number of publishers. Defaults to BASE_N.
        s (int, optional): demand function support size. Defaults to BASE_S.
        lam (float, optional): lambda value. Defaults to BASE_LAM.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    np.random.seed(SEED)
    # One SOFTMAX entry per beta; each beta is that entry's additional parameter.
    functions = [SOFTMAX] * len(beta_vals)
    params = dict(eta=eta, T=T, eps=eps)

    progress = tqdm(functions, miniters=1)
    res = comparison(progress, beta_vals, k, n, s, lam, B, params)
    save_results(res, desc, functions, beta_vals, k, n, s, lam, B, params)


def run_softmax_lambda_no_regret_comparison(desc='softmax_lambda_comparison_no_regret',
                                            lam_vals=LAM_VALS_, k=BASE_K, n=BASE_N, s=BASE_S,
                                            B=B_, eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Runs the comparison of the softmax ranking function for different lambda values.

    Results are passed to ``save_results`` under `desc`.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'softmax_lambda_comparison_no_regret'.
        lam_vals (iterable, optional): collection of lambda values. Defaults to LAM_VALS_.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        n (int, optional): number of publishers. Defaults to BASE_N.
        s (int, optional): demand function support size. Defaults to BASE_S.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    np.random.seed(SEED)
    # A single softmax entry with no extra parameter; the sweep is over lambda instead.
    ranking_function_lst = [SOFTMAX]
    additional_param_lst = [None]
    params = {'eta': eta, 'T': T, 'eps': eps}

    # The lambda values (wrapped in a progress bar) are passed in the `lam` position.
    res = comparison(ranking_function_lst, additional_param_lst,
                     k, n, s, tqdm(lam_vals, miniters=1), B, params)

    save_results(res, desc, ranking_function_lst, additional_param_lst, k, n, s, lam_vals, B, params)


def run_softmax_n_no_regret_comparison(desc='softmax_n_comparison_no_regret',
                                       n_vals=N_VALS_, k=BASE_K, lam=BASE_LAM, s=BASE_S,
                                       B=B_, eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Runs the comparison of the softmax ranking function for different number of publishers.

    Results are passed to ``save_results`` under `desc`.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'softmax_n_comparison_no_regret'.
        n_vals (iterable, optional): collection of publishers' amounts. Defaults to N_VALS_.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        lam (float, optional): lambda value. Defaults to BASE_LAM.
        s (int, optional): demand function support size. Defaults to BASE_S.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    np.random.seed(SEED)
    # A single softmax entry with no extra parameter; the sweep is over n instead.
    ranking_function_lst = [SOFTMAX]
    additional_param_lst = [None]
    params = {'eta': eta, 'T': T, 'eps': eps}

    # The publisher counts (wrapped in a progress bar) are passed in the `n` position.
    res = comparison(ranking_function_lst, additional_param_lst,
                     k, tqdm(n_vals, miniters=1), s, lam, B, params)

    save_results(res, desc, ranking_function_lst, additional_param_lst, k, n_vals, s, lam, B, params)


def run_softmax_s_no_regret_comparison(desc='softmax_s_comparison_no_regret',
                                       s_vals=S_VALS_, k=BASE_K, n=BASE_N, lam=BASE_LAM,
                                       B=B_, eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Runs the comparison of the softmax ranking function for different demand function support sizes.

    Results are passed to ``save_results`` under `desc`.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'softmax_s_comparison_no_regret'.
        s_vals (iterable, optional): collection of demand function support sizes. Defaults to S_VALS_.
        k (int, optional): embedding space dimension. Defaults to BASE_K.
        n (int, optional): number of publishers. Defaults to BASE_N.
        lam (float, optional): lambda value. Defaults to BASE_LAM.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    np.random.seed(SEED)
    # A single softmax entry with no extra parameter; the sweep is over s instead.
    ranking_function_lst = [SOFTMAX]
    additional_param_lst = [None]
    params = {'eta': eta, 'T': T, 'eps': eps}

    # The support sizes (wrapped in a progress bar) are passed in the `s` position.
    res = comparison(ranking_function_lst, additional_param_lst,
                     k, n, tqdm(s_vals, miniters=1), lam, B, params)

    save_results(res, desc, ranking_function_lst, additional_param_lst, k, n, s_vals, lam, B, params)


def run_softmax_k_no_regret_comparison(desc='softmax_k_comparison_no_regret',
                                       k_vals=K_VALS_, n=BASE_N, lam=BASE_LAM, s=BASE_S,
                                       B=B_, eta=ETA_REGRET_, T=T_REGRET_, eps=EPS_REGRET_):
    """Runs the comparison of the softmax ranking function for different embedding space dimensions.

    Results are passed to ``save_results`` under `desc`.

    Args:
        desc (str, optional): Description of the simulation. Defaults to 'softmax_k_comparison_no_regret'.
        k_vals (iterable, optional): collection of embedding space dimensions. Defaults to K_VALS_.
        n (int, optional): number of publishers. Defaults to BASE_N.
        lam (float, optional): lambda value. Defaults to BASE_LAM.
        s (int, optional): demand function support size. Defaults to BASE_S.
        B (int, optional): number of samples. Defaults to B_.
        eta (float, optional): learning rate. Defaults to ETA_REGRET_.
        T (int, optional): number of simulation rounds. Defaults to T_REGRET_.
        eps (float, optional): epsilon value for epsilon-PNE. Defaults to EPS_REGRET_.
    """
    np.random.seed(SEED)
    # A single softmax entry with no extra parameter; the sweep is over k instead.
    ranking_function_lst = [SOFTMAX]
    additional_param_lst = [None]
    params = {'eta': eta, 'T': T, 'eps': eps}

    # The dimensions (wrapped in a progress bar) are passed in the `k` position.
    res = comparison(ranking_function_lst, additional_param_lst,
                     tqdm(k_vals, miniters=1), n, s, lam, B, params)

    save_results(res, desc, ranking_function_lst, additional_param_lst, k_vals, n, s, lam, B, params)


# load functions
def load_softmax_beta_no_regret_results(desc='softmax_beta_comparison_no_regret'):
    """Loads the saved softmax beta-comparison results and plots them as a single figure.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'softmax_beta_comparison_no_regret'.
    """
    softmax_color = PROPORTIONAL_COLORS[SOFTMAX]
    beta_ticks = [1, 2, 3, 5, 10]
    create_single_plot_from_files(desc, additional_param_label_function, softmax_color, BETA_STR,
                                  x_ticks=beta_ticks, functions_sort_key=float_sort_key)


def load_softmax_lambda_no_regret_results(desc='softmax_lambda_comparison_no_regret'):
    """Loads and plots the softmax lambda-comparison results next to the matching generic results.

    The generic results directory is `desc` with any leading 'softmax_' removed.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'softmax_lambda_comparison_no_regret'.
    """
    generic_desc = desc.removeprefix('softmax_')
    lam_ticks = np.round(np.arange(0, 1.1, 0.1), 1)
    create_plots_from_files([desc, generic_desc], lam_key_function, ranking_function_label_function,
                            PROPORTIONAL_COLORS, x_ticks=lam_ticks,
                            labels_sort_key=ranking_function_label_function)


def load_softmax_n_no_regret_results(desc='softmax_n_comparison_no_regret'):
    """Loads and plots the softmax publisher-count comparison next to the matching generic results.

    The generic results directory is `desc` with any leading 'softmax_' removed.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'softmax_n_comparison_no_regret'.
    """
    generic_desc = desc.removeprefix('softmax_')
    create_plots_from_files([desc, generic_desc], n_key_function, ranking_function_label_function,
                            PROPORTIONAL_COLORS, x_ticks=ALL,
                            labels_sort_key=ranking_function_label_function)


def load_softmax_s_no_regret_results(desc='softmax_s_comparison_no_regret'):
    """Loads and plots the softmax support-size comparison next to the matching generic results.

    The generic results directory is `desc` with any leading 'softmax_' removed.

    Args:
        desc (str, optional): directory name of the results. Defaults to 'softmax_s_comparison_no_regret'.
    """
    generic_desc = desc.removeprefix('softmax_')
    # Ticks at 1 and then every 5 up to 50.
    s_ticks = [1, *range(5, 51, 5)]
    create_plots_from_files([desc, generic_desc], s_key_function, ranking_function_label_function,
                            PROPORTIONAL_COLORS, x_ticks=s_ticks,
                            labels_sort_key=ranking_function_label_function)


def load_softmax_k_no_regret_results(desc='softmax_k_comparison_no_regret'):
    """Loads and plots the softmax embedding-dimension comparison next to the matching generic results.

    Three graphs are drawn: publishers' welfare, users' welfare, and convergence rate (the last
    with EXTENDED_CONVERGENCE_RATE as both axis limits and ticks).

    Args:
        desc (str, optional): directory name of the results. Defaults to 'softmax_k_comparison_no_regret'.
    """
    generic_desc = desc.removeprefix('softmax_')
    convergence_graph = dict(axis_name=CONVERGENCE_RATE,
                             limits=EXTENDED_CONVERGENCE_RATE,
                             ticks=EXTENDED_CONVERGENCE_RATE)
    create_plots_from_files([desc, generic_desc], k_key_function, ranking_function_label_function,
                            PROPORTIONAL_COLORS, x_ticks=ALL,
                            graphs=[(PUBLISHERS_WELFARE,), (USERS_WELFARE,), convergence_graph],
                            labels_sort_key=ranking_function_label_function)
