""" Packages import """
import numpy as np
import MAB.arms as arms
# from tqdm import tqdm_notebook as tqdm  # Modify when running in console
from tqdm import tqdm
from .utils import rd_argmax, rollavg_bottlneck, get_leader
from .tracker import Tracker2
from scipy.optimize import minimize_scalar, root_scalar
import cvxpy as cp


# Lookup table from the short distribution code given in `methods` to the arm
# class of MAB.arms that GenericMAB.generate_arms instantiates for it.
mapping = dict(
    B=arms.ArmBernoulli,
    Beta=arms.ArmBeta,
    Dirac=arms.ArmDirac,
    Emp=arms.ArmEmpirical,
    Exp=arms.ArmExponential,
    G=arms.ArmGaussian,
    GMx=arms.ArmGaussianMixture,
    LG=arms.ArmLogGaussian,
    M=arms.ArmMultinomial,
    # NegExp=arms.ArmNegativeExponential,  (currently disabled)
    P=arms.ArmPoisson,
    Par=arms.ArmPareto,
    TG=arms.ArmTG,
    U=arms.ArmUniform,
    Rade=arms.ArmRademacher,
)


def default_exp(x):
    """
    Default forced-exploration function of LB_SDA: sqrt(log(1 + x))
    :param x: number (LB_SDA passes its round counter)
    :return: float, sqrt(log(1 + x))
    """
    log_term = np.log(x + 1)
    return np.sqrt(log_term)


class GenericMAB:
    """
    Generic class to simulate a Multi-Arm Bandit problem
    """
    def __init__(self, methods, p):
        """
        Initialization of the arms
        :param methods: string, probability distribution of each arm
        :param p: np array or list, parameters of the probability distribution of each arm
        """
        self.MAB = self.generate_arms(methods, p)
        self.nb_arms = len(self.MAB)
        self.means = np.array([el.mean for el in self.MAB])
        self.mu_max = np.max(self.means)
        self.mc_regret = None
        self.Cp = None

    @staticmethod
    def generate_arms(methods, p):
        """
        Method for generating different arms
        :param methods: string, probability distribution of each arm
        :param p: np array or list, parameters of the probability distribution of each arm
        :return: list of class objects, list of arms
        """
        arms_list = list()
        for i, m in enumerate(methods):
            args = [p[i]] + [[np.random.randint(1, 312414)]]
            args = sum(args, []) if type(p[i]) == list else args
            alg = mapping[m]
            arms_list.append(alg(*args))
        return arms_list

    @staticmethod
    def kl(x, y):
        return None

    def MC_regret(self, method, N, T, param_dic, store_step=-1):
        """
        Implementation of Monte Carlo method to approximate the expectation of the regret
        :param method: string, method used (UCB, Thomson Sampling, etc..)
        :param N: int, number of independent Monte Carlo simulation
        :param T: int, time horizon
        :param param_dic: dict, parameters for the different methods, can be the value of rho for UCB model or an int
        corresponding to the number of rounds of exploration for the ExploreCommit method
        """
        mc_regret = np.zeros(T)
        store = store_step > 0
        if store:
            all_regret = np.zeros((np.arange(T)[::store_step].shape[0], N))
        alg = self.__getattribute__(method)
        for i in tqdm(range(N), desc='Computing ' + str(N) + ' simulations'):
            tr = alg(T, **param_dic)
            regret = tr.regret()
            mc_regret += regret
            if store:
                all_regret[:, i] = regret[::store_step]
        if store:
            return mc_regret / N, all_regret
        return mc_regret / N

    def ExploreCommit(self, T, m):
        """
        Implementation of Explore-then-Commit algorithm
        :param T: int, time horizon
        :param m: int, number of pulls of each arm before committing to the best one
        :return: Tracker2 object updated with every round of the run
        """
        tracker = Tracker2(self.means, T)
        n_explo = m * self.nb_arms
        # Exploration phase: arms are pulled in round-robin order
        for step in range(n_explo):
            k = step % self.nb_arms
            tracker.update(step, k, self.MAB[k].sample()[0])
        # Commit phase: keep pulling the arm with the best empirical mean
        best = rd_argmax(tracker.Sa / tracker.Na)
        for step in range(n_explo, T):
            tracker.update(step, best, self.MAB[best].sample()[0])
        return tracker

    def Index_Policy(self, T, index_func, start_explo=1, store_rewards_arm=False):
        """
        Implementation of generic Index Policy algorithm
        :param T: int, time horizon
        :param index_func: function which computes the index of every arm from the tracker
        :param start_explo: number of times each arm is pulled (round-robin) before indices are compared
        :param store_rewards_arm: forwarded as third argument to Tracker2
        :return: Tracker2 object updated with every round of the run
        """
        tracker = Tracker2(self.means, T, store_rewards_arm)
        n_forced = self.nb_arms * start_explo
        for step in range(T):
            # Forced round-robin first, then the arm maximising the index
            k = step % self.nb_arms if step < n_forced else rd_argmax(index_func(tracker))
            tracker.update(step, k, self.MAB[k].sample()[0])
        return tracker

    def UCB1(self, T, rho=1.):
        """
        UCB1 index policy: empirical mean plus rho * sqrt(2 * log(t + 1) / Na)
        :param T: Time Horizon
        :param rho: coefficient for the upper bound
        :return: Tracker2 object returned by Index_Policy
        """
        def ucb_index(tracker):
            empirical_mean = tracker.Sa / tracker.Na
            bonus = np.sqrt(2 * np.log(tracker.t + 1) / tracker.Na)
            return empirical_mean + rho * bonus

        return self.Index_Policy(T, ucb_index)


    def Generic_MED(self, T, divergence, start_explo=1, store_rewards_arm=True):
        """
        Randomised policy sampling arm k with probability proportional to exp(-Na[k] * div[k])
        :param T: int, time horizon
        :param divergence: function (k, tracker) -> number, divergence of arm k
        :param start_explo: number of times each arm is pulled (round-robin) before sampling starts
        :param store_rewards_arm: forwarded as third argument to Tracker2
        :return: Tracker2 object updated with every round of the run
        """
        tr = Tracker2(self.means, T, store_rewards_arm)
        arm_ids = np.arange(self.nb_arms)
        n_forced = self.nb_arms * start_explo
        for t in range(T):
            if t < n_forced:
                arm = t % self.nb_arms
            else:
                divs = np.array([divergence(k, tr) for k in range(self.nb_arms)], dtype=float)
                exponents = tr.Na * divs
                # Shift by the smallest exponent: the probabilities are unchanged
                # mathematically, but the largest weight is exp(0) = 1, so the
                # normalising sum can no longer underflow to 0 and produce NaN
                weights = np.exp(exponents.min() - exponents)
                probs = weights / weights.sum()
                arm = np.random.choice(arm_ids, p=probs)
            reward = self.MAB[arm].sample()[0]
            tr.update(t, arm, reward)
        return tr

    def NPTS(self, T, upper_bound=1):
        """
        Implementation of the Non-Parametric Thompson Sampling algorithm
        :param T: Time Horizon
        :param upper_bound: known support upper bound
        :return: Tracker object with the results of the run
        """
        tr = Tracker2(self.means, T)
        if upper_bound is not None:
            X = [[upper_bound] for _ in range(self.nb_arms)]
        tr.Na = tr.Na + 1
        for t in range(T):
            V = np.zeros(self.nb_arms)
            for i in range(self.nb_arms):
                V[i] = np.inner(np.random.dirichlet(np.ones(int(tr.Na[i]))), np.array(X[i]))
            arm = rd_argmax(V)
            tr.update(t, arm, self.MAB[arm].sample()[0])
            X[arm].append(tr.reward[t])
        return tr

    def LB_SDA(self, T, explo_func=default_exp):
        """
        Implementation of the LB-SDA algorithm
        :param T: Time Horizon
        :param explo_func: Forced exploration function, evaluated on the round counter
        :return: Tracker object with the results of the run
        """
        tr = Tracker2(self.means, T, store_rewards_arm=True)
        # Initialisation: every arm is pulled once, in index order
        for k in range(self.nb_arms):
            tr.update(k, k, self.MAB[k].sample()[0])
        t = self.nb_arms
        leader, n_round = -1, 1
        while t < T:
            leader = get_leader(tr.Na, tr.Sa, leader)
            threshold = explo_func(n_round)
            fewer_pulls = tr.Na < tr.Na[leader]
            # 1. for arms pulled less than both the leader and the exploration threshold
            selected = fewer_pulls * (tr.Na < threshold) * 1.
            for j in range(self.nb_arms):
                if selected[j] != 0 or j == leader or not fewer_pulls[j]:
                    continue
                # Duel: mean of arm j against the mean of the leader's last Na[j] rewards
                window = int(tr.Na[j])
                leader_mean = np.mean(tr.rewards_arm[leader][-window:])
                if tr.Sa[j] / tr.Na[j] >= leader_mean:
                    selected[j] = 1
            if selected.sum() == 0:
                # No arm selected: the leader is pulled
                tr.update(t, leader, self.MAB[leader].sample()[0])
                t += 1
            else:
                # Selected arms are pulled in random order until the horizon is reached
                winners = np.where(selected == 1)[0]
                np.random.shuffle(winners)
                for k in winners:
                    if t >= T:
                        break
                    tr.update(t, k, self.MAB[k].sample()[0])
                    t += 1
            n_round += 1
        return tr

    def PHE(self, T, a, distrib):
        """
        Perturbed-history exploration: the index of an arm is the mean of its rewards
        pooled with int(a * Na) + 1 pseudo-rewards drawn from `distrib`
        :param T: Time Horizon
        :param a: proportion of perturbed history. a=1 -> as many pseudo-rewards as real rewards plus one,
        a=0 -> a single pseudo-reward
        :param distrib: object whose rvs(size=n) returns an array of n pseudo-rewards
        (presumably a frozen scipy.stats distribution -- confirm against callers)
        :return: Tracker2 object updated with every round of the run
        """
        tr = Tracker2(self.means, T, store_rewards_arm=True)
        for t in range(T):
            if t < self.nb_arms:
                # Every arm is pulled once first
                arm = t
            else:
                idx_mean = np.zeros(self.nb_arms)
                for k in range(self.nb_arms):
                    # Built-in int: the np.int alias no longer exists in NumPy >= 1.24
                    n_pseudo = int(a * tr.Na[k]) + 1
                    pseudo = distrib.rvs(size=n_pseudo)
                    idx_mean[k] = (tr.Sa[k] + pseudo.sum()) / (tr.Na[k] + n_pseudo)
                arm = rd_argmax(idx_mean)
            reward = self.MAB[arm].sample()[0]
            tr.update(t, arm, reward)
        return tr


    def IMED(self, T):
        """
        Index policy pulling the arm that minimises Na * kl(mean, best mean) + log(Na)
        :param T: Time Horizon
        :return: Tracker2 object returned by Index_Policy
        """
        def imed_index(tracker):
            emp_means = tracker.Sa / tracker.Na
            best_mean = np.max(emp_means)
            scores = [
                tracker.Na[k] * self.kl(emp_means[k], best_mean) + np.log(tracker.Na[k])
                for k in range(self.nb_arms)
            ]
            # Index_Policy maximises the index, hence the sign flip
            return -np.array(scores)

        return self.Index_Policy(T, imed_index)


    def SGB(self, T, eta):
        """
        Stochastic gradient bandit: arms are sampled from the softmax of a preference vector
        theta, updated after each reward by theta_k += eta * reward * (1{k == arm} - p_k)
        :param T: Time Horizon
        :param eta: number, constant step size
        :return: Tracker2 object updated with every round of the run
        """
        tr = Tracker2(self.means, T, store_rewards_arm=True)

        def softmax(theta):
            # Shifting by the maximum leaves the result unchanged mathematically
            # but keeps np.exp from overflowing to inf (and the ratio from
            # becoming NaN) once a preference grows large
            weights = np.exp(theta - np.max(theta))
            return weights / weights.sum()

        arm_ids = np.arange(self.nb_arms)
        theta = np.zeros(self.nb_arms)
        for t in range(T):
            p = softmax(theta)
            arm = np.random.choice(arm_ids, p=p)
            reward = self.MAB[arm].sample()[0]
            tr.update(t, arm, reward)
            # Direction of the step: -p_k for every arm, 1 - p_k for the pulled one
            direction = -p
            direction[arm] += 1.
            theta += eta * direction * reward
        return tr


    def SGB_decay(self, T, eta):
        """
        Stochastic gradient bandit with a time-dependent step size: arms are sampled from
        the softmax of a preference vector theta, updated after each reward by
        theta_k += eta(1 + t) * reward * (1{k == arm} - p_k)
        :param T: Time Horizon
        :param eta: function, step size schedule called with 1 + t at round t
        :return: Tracker2 object updated with every round of the run
        """
        tr = Tracker2(self.means, T, store_rewards_arm=True)

        def softmax(theta):
            # Shifting by the maximum leaves the result unchanged mathematically
            # but keeps np.exp from overflowing to inf (and the ratio from
            # becoming NaN) once a preference grows large
            weights = np.exp(theta - np.max(theta))
            return weights / weights.sum()

        arm_ids = np.arange(self.nb_arms)
        theta = np.zeros(self.nb_arms)
        for t in range(T):
            p = softmax(theta)
            arm = np.random.choice(arm_ids, p=p)
            reward = self.MAB[arm].sample()[0]
            tr.update(t, arm, reward)
            # The step size depends on the round only: evaluate the schedule once
            step_size = eta(1 + t)
            # Direction of the step: -p_k for every arm, 1 - p_k for the pulled one
            direction = -p
            direction[arm] += 1.
            theta += step_size * direction * reward
        return tr

    def SAMBA(self, T, alpha):
        """
        SAMBA policy: arms are sampled from a probability vector that is updated after each reward.
        The guarantees of the algorithm are proved for Bernoulli rewards.
        :param T: Time Horizon
        :param alpha: number, step size of the probability updates
        :return: Tracker2 object updated with every round of the run
        """
        tr = Tracker2(self.means, T, store_rewards_arm=True)
        arm_ids = np.arange(self.nb_arms)
        prob = np.ones(self.nb_arms) / self.nb_arms
        for t in range(T):
            arm = np.random.choice(arm_ids, p=prob)
            reward = self.MAB[arm].sample()[0]
            tr.update(t, arm, reward)
            leader = rd_argmax(prob)
            if arm != leader:
                # A non-leading arm was pulled: only its own probability grows
                prob[arm] += alpha * prob[arm] * reward
            else:
                # The leading arm was pulled: every probability is decreased
                prob = prob - alpha * prob ** 2 * reward / prob[leader]
            # The leader takes whatever mass the other arms leave
            others_mass = prob.sum() - prob[leader]
            prob[leader] = 1 - others_mass
        return tr

    def TS(self, T):
        """
        Thompson Sampling with Beta(Sa + 1, Na - Sa + 1) samples as indices.
        Only for Bernoulli rewards.
        :param T: Time Horizon
        :return: Tracker2 object returned by Index_Policy
        """
        def posterior_sample(tracker):
            successes = tracker.Sa
            failures = tracker.Na - tracker.Sa
            return np.random.beta(successes + 1, failures + 1)

        return self.Index_Policy(T, posterior_sample)
