from joblib import Parallel, delayed
import multiprocessing as mp
import numpy as np
import pandas as pd
import pickle as pkl
from time import time
import os
import matplotlib.pyplot as plt
import seaborn as sns
from .BernoulliMAB import BetaBernoulliMAB
from .BetaMAB import BetaMAB
from .DiracMAB import DiracMAB
from .EmpiricalMAB import EmpiricalMAB
from .ExponentialMAB import ExponentialMAB
from .GaussianMAB import GaussianMAB
from .GaussianMixtureMAB import GaussianMixtureMAB
from .LogGaussianMAB import LogGaussianMAB
from .MultinomialMAB import MultinomialMAB
from .PoissonMAB import PoissonMAB
from .ParetoMAB import ParetoMAB
from .RademacherMAB import RadeMAB
from .Trunc_GaussianMAB import TruncGaussianMAB
from .UniformMAB import UniformMAB
from .MAB import GenericMAB
from tqdm import tqdm


# Distribution key -> bandit model class. Looked up as ``mapping[bandit](p)``
# by MC_xp and multiprocess_MC, i.e. each class is built from the arm
# parameters ``p``.
mapping = dict(
    B=BetaBernoulliMAB,
    Beta=BetaMAB,
    Dirac=DiracMAB,
    Emp=EmpiricalMAB,
    Exp=ExponentialMAB,
    G=GaussianMAB,
    GMx=GaussianMixtureMAB,
    LG=LogGaussianMAB,
    M=MultinomialMAB,
    P=PoissonMAB,
    Par=ParetoMAB,
    TG=TruncGaussianMAB,
    U=UniformMAB,
)

# Distribution key -> human-readable name. It also names 'GMisE' and 'GMisP',
# which have no model class in ``mapping``.
mapping_name = dict(
    B='Bernoulli',
    Beta='Beta',
    Dirac='Dirac',
    Emp='Empirical',
    Exp='Exponential',
    G='Gaussian',
    GMisE='Gaussian Misspecified Exponential',
    GMisP='Gaussian Misspecified Pareto',
    GMx='Gaussian Mixture',
    LG='LogGaussian',
    M='Multinomial',
    P='Poisson',
    Par='Pareto',
    TG='Truncated Gaussian',
    U='Uniform',
)

# Per-method plotting colours, indexed as
# (curve colour, quantile-band fill colour, quantile-line colour).
# Methods missing here fall back to the 'default' entry.
colorset = {
    'kl_ucb': ('darkviolet', 'plum', 'darkviolet'),
    'IMED': ('darkviolet', 'plum', 'darkviolet'),
    'empirical_kl_ucb': ('deepskyblue', 'paleturquoise', 'deepskyblue'),
    'empirical_IMED': ('deepskyblue', 'paleturquoise', 'deepskyblue'),
    'UCB1': ('gold', 'ivory', 'gold'),
    'SSMC': ('darkolivegreen', 'lightgreen', 'darkolivegreen'),
    'RB_SDA': ('darkolivegreen', 'lightgreen', 'darkolivegreen'),
    'LB_SDA': ('darkolivegreen', 'lightgreen', 'darkolivegreen'),
    'TS': ('darkorange', 'antiquewhite', 'orange'),
    'TS_binarized': ('darkorange', 'antiquewhite', 'orange'),
    'NPTS': ('grey', 'lightgrey', 'grey'),
    'BDS': ('darkred', 'peachpuff', 'red'),
    'RDS': ('darkred', 'peachpuff', 'red'),
    'QDS': ('darkred', 'peachpuff', 'red'),
    # Must be a real tuple: ('rosybrown') is just a string, so [0] would
    # yield 'r' and the curve would be drawn in plain red.
    'lower bound': ('rosybrown', 'mistyrose', 'rosybrown'),
    'default': ('dodgerblue', 'aliceblue', 'dodgerblue'),
}

# Matplotlib linestyle per method name; plot_xp_res uses 'solid' for any
# method not listed here.
linestyleset = {
    'kl_ucb': 'dashdot',
    'IMED': 'solid',
    # Method names are lowercase elsewhere (e.g. colorset uses
    # 'empirical_kl_ucb'); the mixed-case key is kept for existing callers.
    'empirical_kl_ucb': 'dashdot',
    'empirical_KL_UCB': 'dashdot',
    'empirical_IMED': 'solid',
    'SSMC': 'dashed',
    'LB_SDA': 'dashdot',
    'RB_SDA': 'solid',
    'TS': 'solid',
    'TS_binarized': 'dashdot',
    'BDS': 'dashed',
    'RDS': 'solid',
    'QDS': 'dashdot',
    'lower bound': 'dotted',
}

# Matplotlib marker per method name. It is empty, so plot_xp_res draws no
# markers (marker=None) unless a caller supplies its own mapping.
markerset = {}


def MC_xp(args, plot=False, pickle_path=None, caption='xp'):
    """Run a Monte-Carlo regret experiment for several methods on one bandit.

    Args:
        args: ``(bandit, p, T, n_xp, methods, param, store_step)``, optionally
            followed by ``alias``. ``bandit`` is a key of ``mapping`` and the
            model is built as ``mapping[bandit](p)``. ``alias`` gives one name
            per method (defaults to ``methods``); it keys ``param`` and names
            the output columns.
        plot: if True, plot the regret columns with a logarithmic x-axis.
        pickle_path: directory in which to pickle the regret DataFrame, or None.
        caption: file stem of the pickle file.

    Returns:
        ``(df_r, all_traj)``. ``df_r`` is a DataFrame with one column per alias,
        plus a 'lower bound' column when ``model.Cp`` is not None.
        ``all_traj`` maps each alias to the trajectories returned by
        ``model.MC_regret``.
    """
    if len(args) == 7:
        bandit, p, T, n_xp, methods, param, store_step = args
        alias = methods
    else:
        bandit, p, T, n_xp, methods, param, store_step, alias = args
    # A list is required so 'lower bound' can be appended (tuples would fail).
    alias = list(alias)
    model = mapping[bandit](p)
    all_r = []
    all_traj = {}
    for x, x_alias in zip(methods, alias):
        r, traj = model.MC_regret(x, n_xp, T, param[x_alias], store_step)
        all_r.append(r)
        all_traj[x_alias] = traj

    columns = alias
    if model.Cp is not None:
        # Lower-bound curve Cp * log(1 + t). Its first value is zeroed on the
        # array itself: chained DataFrame assignment
        # (df['lower bound'].iloc[0] = 0) is a no-op under pandas Copy-on-Write.
        lower = model.Cp * np.log(1 + np.arange(T))
        lower[0] = 0
        all_r.append(lower)
        columns = alias + ['lower bound']
    df_r = pd.DataFrame(all_r).T
    df_r.columns = columns

    if plot:
        df_r.plot(figsize=(10, 8), logx=True)
    if pickle_path is not None:
        with open(os.path.join(pickle_path, caption + '.pkl'), 'wb') as f:
            pkl.dump(df_r, f)
    return df_r, all_traj


def multiprocess_MC(args,
                    plot=False,
                    pickle_path=None,
                    caption='xp',
                    legend=False,
                    plot_arm_distr=False,
                    bounds_distr=None,
                    lower_bound=False,
                    ):
    """Spread MC_xp runs over every CPU with joblib and merge the results.

    Args:
        args: ``(bandit, p, T, n_xp, methods, param, store_step)``, optionally
            followed by ``alias`` (display names, default ``methods``).
        plot: if True, draw the merged results with ``plot_xp_res``.
        pickle_path: directory in which to pickle the result container
            ``{'df_regret', 'trajectories', 'info'}``, or None.
        caption: file stem of the pickle file.
        legend: forwarded to ``plot_distr`` and ``plot_xp_res``.
        plot_arm_distr: if True, first plot the arm distributions of
            ``mapping[bandit](p)`` via ``plot_distr``.
        bounds_distr: x-axis bounds for the arm-distribution plot.
        lower_bound: forwarded to ``plot_xp_res`` to draw the 'lower bound'
            column.

    Returns:
        ``(df_r, traj)``. ``df_r`` is the regret DataFrame averaged over
        workers. ``traj`` maps each alias to all workers' trajectories,
        concatenated along axis 1.

    Raises:
        ValueError: if ``args`` does not have 7 or 8 elements.
    """
    t0 = time()
    cpu = mp.cpu_count()
    print('Running on %i cpu' % cpu)
    # args is a 7-tuple, or an 8-tuple whose last element is alias.
    if len(args) == 7:
        bandit, p, T, n_xp, methods, param, store_step = args
        alias = methods
    elif len(args) == 8:
        bandit, p, T, n_xp, methods, param, store_step, alias = args
    else:
        # Falling through would leave alias unbound (NameError much later).
        raise ValueError('args must have 7 or 8 elements, got %i' % len(args))

    if plot_arm_distr:
        plot_distr(mapping[bandit](p), bounds=bounds_distr, legend=legend)

    # Every worker gets the same n_xp // cpu + 1 runs. That is at least n_xp in
    # total, and it makes a plain mean of the per-worker frames the overall mean.
    new_args = (bandit, p, T, n_xp // cpu + 1, methods, param, store_step, alias)
    res = Parallel(n_jobs=cpu)(delayed(MC_xp)(new_args) for _ in range(cpu))

    df_r = res[0][0]
    for worker_df, _ in res[1:]:
        df_r = df_r + worker_df
    df_r = df_r / cpu
    traj = {
        x: np.concatenate([worker_traj[x] for _, worker_traj in res], axis=1)
        for x in alias
    }

    # 'N_xp' records the requested count; param is not stored.
    info = {'proba': p, 'N_xp': n_xp, 'T': T, 'methods': alias, 'step_traj': store_step}
    xp_container = {'df_regret': df_r, 'trajectories': traj, 'info': info}

    if plot:
        # markevery must be a positive int; T // 10 is 0 whenever T < 10.
        plot_xp_res(xp_container, markevery=max(T // 10, 1), legend=legend, lower_bound=lower_bound)
    if pickle_path is not None:
        with open(os.path.join(pickle_path, caption + '.pkl'), 'wb') as f:
            pkl.dump(xp_container, f)
    print('Execution time: {:.0f} seconds'.format(time()-t0))
    return df_r, traj


def multiprocess_MC2(args,
                     plot=False,
                     pickle_path=None,
                     caption='xp',
                     legend=False,
                     plot_arm_distr=False,
                     bounds_distr=None,
                     lower_bound=False,
                     ):
    """Spread MC_xp2 runs (GenericMAB models) over every CPU and merge them.

    Args:
        args: ``(bandit, p, T, n_xp, methods, param, store_step)``, optionally
            followed by ``alias`` (display names, default ``methods``).
        plot: if True, draw the merged results with ``plot_xp_res``.
        pickle_path: directory in which to pickle the result container
            ``{'df_regret', 'trajectories', 'info'}``, or None.
        caption: file stem of the pickle file.
        legend: forwarded to ``plot_xp_res``.
        plot_arm_distr, bounds_distr: accepted but currently unused here.
        lower_bound: forwarded to ``plot_xp_res`` to draw the 'lower bound'
            column.

    Returns:
        ``(df_r, traj)``. ``df_r`` is the regret DataFrame averaged over
        workers. ``traj`` maps each alias to all workers' trajectories,
        concatenated along axis 1.

    Raises:
        ValueError: if ``args`` does not have 7 or 8 elements.
    """
    t0 = time()
    cpu = mp.cpu_count()
    print('Running on %i cpu' % cpu)
    # args is a 7-tuple, or an 8-tuple whose last element is alias.
    if len(args) == 7:
        bandit, p, T, n_xp, methods, param, store_step = args
        alias = methods
    elif len(args) == 8:
        bandit, p, T, n_xp, methods, param, store_step, alias = args
    else:
        # Falling through would leave alias unbound (NameError much later).
        raise ValueError('args must have 7 or 8 elements, got %i' % len(args))

    # Every worker gets the same n_xp // cpu + 1 runs. That is at least n_xp in
    # total, and it makes a plain mean of the per-worker frames the overall mean.
    new_args = (bandit, p, T, n_xp // cpu + 1, methods, param, store_step, alias)
    res = Parallel(n_jobs=cpu)(delayed(MC_xp2)(new_args) for _ in range(cpu))

    df_r = res[0][0]
    for worker_df, _ in res[1:]:
        df_r = df_r + worker_df
    df_r = df_r / cpu
    traj = {
        x: np.concatenate([worker_traj[x] for _, worker_traj in res], axis=1)
        for x in alias
    }

    # 'N_xp' records the requested count; param is not stored.
    info = {'proba': p, 'N_xp': n_xp, 'T': T, 'methods': alias, 'step_traj': store_step}
    xp_container = {'df_regret': df_r, 'trajectories': traj, 'info': info}

    if plot:
        # markevery must be a positive int; T // 10 is 0 whenever T < 10.
        plot_xp_res(xp_container, markevery=max(T // 10, 1), legend=legend, lower_bound=lower_bound)
    if pickle_path is not None:
        with open(os.path.join(pickle_path, caption + '.pkl'), 'wb') as f:
            pkl.dump(xp_container, f)
    print('Execution time: {:.0f} seconds'.format(time()-t0))
    return df_r, traj

def multiprocess_MCRad(args,
                       plot=False,
                       pickle_path=None,
                       caption='xp',
                       legend=False,
                       plot_arm_distr=False,
                       bounds_distr=None,
                       lower_bound=False,
                       ):
    """Spread MC_xpRad runs (RadeMAB models) over every CPU and merge them.

    NOTE: duplicates multiprocess_MC apart from the worker function.

    Args:
        args: ``(bandit, p, T, n_xp, methods, param, store_step)``, optionally
            followed by ``alias`` (display names, default ``methods``).
            ``bandit`` is only passed through to MC_xpRad.
        plot: if True, draw the merged results with ``plot_xp_res``.
        pickle_path: directory in which to pickle the result container
            ``{'df_regret', 'trajectories', 'info'}``, or None.
        caption: file stem of the pickle file.
        legend: forwarded to ``plot_xp_res``.
        plot_arm_distr, bounds_distr: accepted but currently unused here.
        lower_bound: forwarded to ``plot_xp_res`` to draw the 'lower bound'
            column.

    Returns:
        ``(df_r, traj)``. ``df_r`` is the regret DataFrame averaged over
        workers. ``traj`` maps each alias to all workers' trajectories,
        concatenated along axis 1.

    Raises:
        ValueError: if ``args`` does not have 7 or 8 elements.
    """
    t0 = time()
    cpu = mp.cpu_count()
    print('Running on %i cpu' % cpu)
    # args is a 7-tuple, or an 8-tuple whose last element is alias.
    if len(args) == 7:
        bandit, p, T, n_xp, methods, param, store_step = args
        alias = methods
    elif len(args) == 8:
        bandit, p, T, n_xp, methods, param, store_step, alias = args
    else:
        # Falling through would leave alias unbound (NameError much later).
        raise ValueError('args must have 7 or 8 elements, got %i' % len(args))

    # Every worker gets the same n_xp // cpu + 1 runs. That is at least n_xp in
    # total, and it makes a plain mean of the per-worker frames the overall mean.
    new_args = (bandit, p, T, n_xp // cpu + 1, methods, param, store_step, alias)
    res = Parallel(n_jobs=cpu)(delayed(MC_xpRad)(new_args) for _ in range(cpu))

    df_r = res[0][0]
    for worker_df, _ in res[1:]:
        df_r = df_r + worker_df
    df_r = df_r / cpu
    traj = {
        x: np.concatenate([worker_traj[x] for _, worker_traj in res], axis=1)
        for x in alias
    }

    # 'N_xp' records the requested count; param is not stored.
    info = {'proba': p, 'N_xp': n_xp, 'T': T, 'methods': alias, 'step_traj': store_step}
    xp_container = {'df_regret': df_r, 'trajectories': traj, 'info': info}

    if plot:
        # markevery must be a positive int; T // 10 is 0 whenever T < 10.
        plot_xp_res(xp_container, markevery=max(T // 10, 1), legend=legend, lower_bound=lower_bound)
    if pickle_path is not None:
        with open(os.path.join(pickle_path, caption + '.pkl'), 'wb') as f:
            pkl.dump(xp_container, f)
    print('Execution time: {:.0f} seconds'.format(time()-t0))
    return df_r, traj


def MC_xpRad(args, plot=False, pickle_path=None, caption='xp'):
    """Run a Monte-Carlo regret experiment on a ``RadeMAB(p)`` model.

    Same as MC_xp, except the model is always ``RadeMAB(p)`` (from
    ``.RademacherMAB``) and ``bandit`` is ignored.

    Args:
        args: ``(bandit, p, T, n_xp, methods, param, store_step)``, optionally
            followed by ``alias``. ``alias`` gives one name per method
            (defaults to ``methods``); it keys ``param`` and names the output
            columns.
        plot: if True, plot the regret columns with a logarithmic x-axis.
        pickle_path: directory in which to pickle the regret DataFrame, or None.
        caption: file stem of the pickle file.

    Returns:
        ``(df_r, all_traj)``. ``df_r`` is a DataFrame with one column per alias,
        plus a 'lower bound' column when ``model.Cp`` is not None.
        ``all_traj`` maps each alias to the trajectories returned by
        ``model.MC_regret``.
    """
    if len(args) == 7:
        bandit, p, T, n_xp, methods, param, store_step = args
        alias = methods
    else:
        bandit, p, T, n_xp, methods, param, store_step, alias = args
    # A list is required so 'lower bound' can be appended (tuples would fail).
    alias = list(alias)
    model = RadeMAB(p)
    all_r = []
    all_traj = {}
    for x, x_alias in zip(methods, alias):
        r, traj = model.MC_regret(x, n_xp, T, param[x_alias], store_step)
        all_r.append(r)
        all_traj[x_alias] = traj

    columns = alias
    if model.Cp is not None:
        # Lower-bound curve Cp * log(1 + t). Its first value is zeroed on the
        # array itself: chained DataFrame assignment
        # (df['lower bound'].iloc[0] = 0) is a no-op under pandas Copy-on-Write.
        lower = model.Cp * np.log(1 + np.arange(T))
        lower[0] = 0
        all_r.append(lower)
        columns = alias + ['lower bound']
    df_r = pd.DataFrame(all_r).T
    df_r.columns = columns

    if plot:
        df_r.plot(figsize=(10, 8), logx=True)
    if pickle_path is not None:
        with open(os.path.join(pickle_path, caption + '.pkl'), 'wb') as f:
            pkl.dump(df_r, f)
    return df_r, all_traj


def MC_xp2(args, plot=False, pickle_path=None, caption='xp'):
    """Run a Monte-Carlo regret experiment on a ``GenericMAB(bandit, p)`` model.

    Same as MC_xp, except the model is built as ``GenericMAB(bandit, p)``
    instead of being looked up in ``mapping``.

    Args:
        args: ``(bandit, p, T, n_xp, methods, param, store_step)``, optionally
            followed by ``alias``. ``alias`` gives one name per method
            (defaults to ``methods``); it keys ``param`` and names the output
            columns.
        plot: if True, plot the regret columns with a logarithmic x-axis.
        pickle_path: directory in which to pickle the regret DataFrame, or None.
        caption: file stem of the pickle file.

    Returns:
        ``(df_r, all_traj)``. ``df_r`` is a DataFrame with one column per alias,
        plus a 'lower bound' column when ``model.Cp`` is not None.
        ``all_traj`` maps each alias to the trajectories returned by
        ``model.MC_regret``.
    """
    if len(args) == 7:
        bandit, p, T, n_xp, methods, param, store_step = args
        alias = methods
    else:
        bandit, p, T, n_xp, methods, param, store_step, alias = args
    # A list is required so 'lower bound' can be appended (tuples would fail).
    alias = list(alias)
    model = GenericMAB(bandit, p)
    all_r = []
    all_traj = {}
    for x, x_alias in zip(methods, alias):
        r, traj = model.MC_regret(x, n_xp, T, param[x_alias], store_step)
        all_r.append(r)
        all_traj[x_alias] = traj

    columns = alias
    if model.Cp is not None:
        # Lower-bound curve Cp * log(1 + t). Its first value is zeroed on the
        # array itself: chained DataFrame assignment
        # (df['lower bound'].iloc[0] = 0) is a no-op under pandas Copy-on-Write.
        lower = model.Cp * np.log(1 + np.arange(T))
        lower[0] = 0
        all_r.append(lower)
        columns = alias + ['lower bound']
    df_r = pd.DataFrame(all_r).T
    df_r.columns = columns

    if plot:
        df_r.plot(figsize=(10, 8), logx=True)
    if pickle_path is not None:
        with open(os.path.join(pickle_path, caption + '.pkl'), 'wb') as f:
            pkl.dump(df_r, f)
    return df_r, all_traj



def plot_distr(model, sample_size=1000, bounds=None, legend=False, axes_style='darkgrid'):
    """Show a histogram with KDE of ``sample_size`` draws from each arm.

    Args:
        model: bandit model whose ``MAB`` attribute iterates over arms that
            expose ``sample(n)``. Its ``means`` are listed in the title when
            ``legend`` is True.
        sample_size: number of draws per arm.
        bounds: optional x-axis limits passed to ``Axes.set_xlim``.
        legend: if True, show the seaborn legend and a title with the means.
        axes_style: seaborn axes style used for the figure.
    """
    with sns.axes_style(axes_style):
        _, axis = plt.subplots(figsize=(10, 7), nrows=1, ncols=1)
        draws = list(map(lambda arm: arm.sample(sample_size), model.MAB))
        sns.histplot(draws, ax=axis, alpha=0.1, kde=True, legend=legend)
        if legend:
            means_text = ', '.join(format(mean, '.2f') for mean in model.means)
            axis.set_title('Means: ' + means_text)
        if bounds is not None:
            axis.set_xlim(bounds)
        plt.tight_layout()
        plt.show()


def _xp_style(method, colorset, linestyleset, markerset):
    """Return ``(colors, linestyle, marker)`` used to draw ``method``.

    ``colors`` is indexed as (curve, band fill, quantile lines). A bare colour
    string (e.g. ``('rosybrown')``, which is a str rather than a tuple) is
    expanded to a 3-tuple so it is not indexed character by character.
    """
    colors = colorset.get(method, colorset.get('default'))
    if isinstance(colors, str):
        colors = (colors, colors, colors)
    return colors, linestyleset.get(method, 'solid'), markerset.get(method, None)


def _plot_regret_curves(ax, xp, methods, X_traj, style, labels, markevery):
    """Draw each method's regret column on ``ax`` with a shaded 5%-95% band."""
    for method in methods:
        colors, linestyle, marker = style(method)
        ax.plot(
            xp['df_regret'][method], color=colors[0], linestyle=linestyle,
            marker=marker, markevery=markevery, label=labels.get(method, method),
            )
        # 'lower bound' is a regret column with no entry in xp['trajectories'].
        if method != 'lower bound':
            q = np.quantile(xp['trajectories'][method], [0.05, 0.95], axis=1).T
            ax.fill_between(X_traj, q[:, 0], q[:, -1], color=colors[1], alpha=0.6)


def plot_xp_res(xp,
                colorset=colorset,
                linestyleset=linestyleset,
                markerset=markerset,
                markevery=1,
                legend=False,
                lower_bound=False,
                axes_style='darkgrid',
                labels=None,
                figure_name='',
                ):
    """Plot an experiment container as built by ``multiprocess_MC``.

    Three kinds of figure are drawn:
      1. every method's regret with a 5%-95% quantile band, saved to
         ``<figure_name>.pdf`` when ``figure_name`` is non-empty;
      2. the same curves on a logarithmic time axis;
      3. one figure per method in ``xp['info']['methods']``, showing the mean,
         the 5%-95% band, the 10%/90% quantile lines and the median. The
         final-time quantiles of each method are printed.

    Args:
        xp: dict with 'df_regret' (DataFrame, one column per method),
            'trajectories' (method -> array; quantiles are taken over axis 1)
            and 'info' (needs 'T', 'step_traj' and 'methods').
        colorset, linestyleset, markerset: per-method style lookups. Unknown
            methods use ``colorset['default']``, 'solid' and no marker.
        markevery: marker spacing passed to ``Axes.plot``.
        legend: if True, add an upper-left legend to every figure.
        lower_bound: if True, also draw the 'lower bound' column.
        axes_style: seaborn axes style.
        labels: optional method -> legend label mapping (default: the name).
        figure_name: file stem for the PDF of the first figure; '' to skip.
    """
    # None instead of a shared mutable {} default.
    if labels is None:
        labels = {}

    def style(method):
        return _xp_style(method, colorset, linestyleset, markerset)

    X_traj = np.arange(xp['info']['T'])[::xp['info']['step_traj']]
    methods = [
        method for method in xp['df_regret'].columns
        if lower_bound or method != 'lower bound'
    ]

    with sns.axes_style(axes_style):
        # Linear time axis. Legend, layout and saving run once, after all
        # curves are drawn (previously the PDF was rewritten for every method).
        _, ax = plt.subplots(figsize=(10, 7), nrows=1, ncols=1)
        _plot_regret_curves(ax, xp, methods, X_traj, style, labels, markevery)
        if legend:
            ax.legend(loc='upper left')
        plt.tight_layout()
        if figure_name:
            plt.savefig(figure_name + '.pdf', format='pdf')

        # Logarithmic time axis, with the same labels as the first figure.
        _, ax = plt.subplots(figsize=(10, 7), nrows=1, ncols=1)
        ax.set_xscale('log')
        _plot_regret_curves(ax, xp, methods, X_traj, style, labels, markevery)
        if legend:
            ax.legend(loc='upper left')
        plt.tight_layout()

        # One figure per method with its full quantile envelope.
        for method in xp['info']['methods']:
            q = np.quantile(xp['trajectories'][method], [0.05, 0.1, 0.5, 0.9, 0.95], axis=1).T
            print(method, q[-1, :])
            colors, linestyle, marker = style(method)
            _, ax = plt.subplots(figsize=(10, 7), nrows=1, ncols=1)
            ax.plot(
                xp['df_regret'][method], color=colors[0], linestyle=linestyle,
                marker=marker, markevery=markevery, label='Average',
                )
            ax.fill_between(X_traj, q[:, 0], q[:, -1], color=colors[1])
            ax.plot(X_traj, q[:, 1], '--', color=colors[2])
            ax.plot(X_traj, q[:, 3], '--', color=colors[2], label='10%-90% quantile')
            ax.plot(X_traj, q[:, 2], '+', color=colors[2], label='Median')
            ax.set_title(method)
            if legend:
                ax.legend(loc='upper left')
            plt.tight_layout()
