import numpy as np
from joblib import Parallel, delayed
from math import sqrt,log,exp
import random

##############################################################################################################################
                                             # Code to run the bandit setting
##############################################################################################################################

def randmax(A):
    """Return the position of a largest entry of A, breaking ties at random.

    All positions whose value equals max(A) are collected and one of them is
    drawn with np.random.choice (NumPy's global random state).
    """
    best = max(A)
    candidates = [pos for pos, value in enumerate(A) if value == best]
    return np.random.choice(candidates)

def randmin(A):
    """Return the position of a smallest entry of A, breaking ties at random.

    All positions whose value equals min(A) are collected and one of them is
    drawn with np.random.choice (NumPy's global random state).
    """
    lowest = min(A)
    candidates = [pos for pos, value in enumerate(A) if value == lowest]
    return np.random.choice(candidates)

class MAB:
    """Multi-armed bandit environment built from a list of arm objects.

    Each arm must expose a `mean` attribute and a `sample(rng)` method.
    """

    def __init__(self,arms):
        """Store the arms and cache their number, their means and the best arm index."""
        self.arms = arms
        self.nbArms = len(arms)
        # Collect the expected reward of every arm, in arm order.
        means = []
        for one_arm in arms:
            means.append(one_arm.mean)
        self.means = means
        # Index of (the first) arm with the largest mean.
        self.bestarm = np.argmax(means)

    def generateReward(self,arm,rng):
        """Draw one reward from the arm with index `arm`, forwarding `rng` to its sampler."""
        chosen = self.arms[arm]
        return chosen.sample(rng)
       
            
def OneBanditOneLearnerOneRun(bandit, strategy, timeHorizon, prior_distrib,N_sample,dist,sigma_prior, rng):
    """
    Run a bandit strategy (strategy) on a MAB instance (bandit) for (timeHorizon) time steps.

    Before the interaction starts, the strategy is reset with `strategy.clear()`,
    `N_sample[k]` prior rewards are drawn for every arm k from `prior_distrib`,
    and they are handed to the strategy through
    `strategy.get(prior, dist, sigma_prior)`.

    Parameters:
        bandit        : environment exposing `generateReward(arm, rng)`
        strategy      : learner exposing `clear`, `get`, `chooseArmToPlay`
                        (returns an iterable of arm indices) and `receiveReward`
        timeHorizon   : total number of pulls to perform
        prior_distrib : environment the prior samples are drawn from
        N_sample      : number of prior samples to draw for each arm
        dist, sigma_prior : forwarded unchanged to `strategy.get`
        rng           : random generator forwarded to every `generateReward` call

    output : sequence of arms chosen, sequence of rewards obtained
             (both of length timeHorizon)
    """
    selections = []
    rewards = []
    strategy.clear() # reset previous history
    prior = [
        [prior_distrib.generateReward(k, rng) for _ in range(N_sample[k])]
        for k in range(len(N_sample))
    ]
    strategy.get(prior, dist, sigma_prior)
    while len(selections) < timeHorizon:
        # choose the next arm(s) to play with the bandit algorithm
        for a in strategy.chooseArmToPlay():
            # A strategy may return several arms at once (e.g. every arm not
            # played yet); stop as soon as the horizon is reached so the run
            # never exceeds timeHorizon pulls.
            if len(selections) >= timeHorizon:
                break
            # get the reward of the chosen arm
            reward = bandit.generateReward(a, rng)
            strategy.receiveReward(a, reward)
            selections.append(a)
            rewards.append(reward)
    return selections, rewards

    
def CumulativeRegret(bandit,selections):
    """Return the cumulative pseudo-regret of a sequence of arm selections.

    Entry i of the result is the sum, over the first i+1 selections, of the
    gap between the largest arm mean of `bandit` and the mean of the arm
    that was selected.
    """
    arm_means = np.array(bandit.means)
    horizon = len(selections)
    best_mean = max(arm_means)
    # Mean reward collected at each step by always playing the best arm ...
    optimal = best_mean * np.ones(horizon)
    # ... versus the mean of the arm actually played at each step.
    played = arm_means[selections]
    return np.cumsum(optimal - played)
            
def OneBanditOneLearnerMultipleRuns(bandit, strategy, timeHorizon, N_exp, prior_distrib,N_sample,dist,sigma_prior,tsave=None, seed = 42):
    """
    Perform N_exp runs of a bandit strategy (strategy) on a MAB instance (bandit) for (timeHorizon) time steps
    and compute the pseudo-regret of each run.

    Parameters:
        bandit, strategy, timeHorizon, prior_distrib, N_sample, dist, sigma_prior :
            forwarded to OneBanditOneLearnerOneRun for every run
        N_exp : number of independent runs; they are dispatched with joblib
                (`Parallel(n_jobs=-1)`)
        tsave : optional sequence (list or array) of 1-based time steps at which
                the regret is stored; None or an empty sequence means every
                step 1..timeHorizon
        seed  : NOTE(review): accepted but currently not used -- run number i
                is seeded with `np.random.RandomState(i)` for i in 0..N_exp-1.
                Kept that way so existing results stay reproducible; confirm
                whether it should offset the per-run seeds.

    output : a table of size N_exp x |tsave| in which each row is the pseudo-regret at the sub-sampled times
    """
    # Convert the 1-based time steps into 0-based positions once.  Going
    # through an array matters: subtracting 1 from a plain list raises a
    # TypeError, which is what happened with the default tsave.
    if tsave is None or len(tsave) == 0:
        save_positions = np.arange(timeHorizon)
    else:
        save_positions = np.asarray(tsave) - 1

    def do_one_xp(run_seed):
        # Each run gets its own generator so runs are independent and reproducible.
        rng = np.random.RandomState(run_seed)
        selections, _ = OneBanditOneLearnerOneRun(
            bandit, strategy, timeHorizon, prior_distrib, N_sample, dist, sigma_prior, rng
        )
        regret_one_run = CumulativeRegret(bandit, selections)
        return np.asarray(regret_one_run)[save_positions]

    Regrets = Parallel(n_jobs=-1)(delayed(do_one_xp)(run_seed) for run_seed in range(N_exp))

    return np.array(Regrets)


##############################################################################################################################
                                             # Gaussian arm
##############################################################################################################################

class Gaussian:
    """Arm whose rewards are normally distributed."""

    def __init__(self,mu,var=1):
        """Record the mean `mu` and the variance `var` (1 by default) of the arm."""
        self.mean = mu
        self.variance = var

    def sample(self, rng):
        """Draw one reward: the mean plus a standard normal draw from `rng` scaled by sqrt(variance)."""
        std_dev = sqrt(self.variance)
        noise = rng.normal()
        return self.mean + std_dev*noise
    

##############################################################################################################################
                                             # KL_UCB_Transfer
##############################################################################################################################

class KL_UCB_Transfer:
    """Upper-confidence index policy that combines online rewards with prior samples.

    For every arm the policy keeps the empirical mean of the rewards observed
    online and, through `get`, the empirical mean of a batch of prior samples
    together with a per-arm offset L.  The index of an arm is the largest q
    satisfying

        alpha * max(q - mu, 0)**2 + beta * max(q - (mu_p + L), 0)**2 <= delta

    with alpha = N / (2 * sigma**2), beta = N_p / (2 * sigmap**2) and
    delta = (1 + epsilon) * log(t).
    """

    def __init__(self, nbArms, sigma,epsilon = 1/20):
        """Create the policy.

        nbArms  : number of arms
        sigma   : scale of the online rewards (enters the index as sigma**2)
        epsilon : exploration parameter; the threshold is (1 + epsilon) * log(t)
        """
        self.nbArms = nbArms
        self.sigma = sigma
        self.epsilon = epsilon
        self.clear()

    def clear(self):
        """Forget every online observation and all prior information."""
        self.Draws = np.zeros(self.nbArms, dtype=int)          # pulls per arm
        self.sum_rewards = np.zeros(self.nbArms, dtype=float)  # reward sum per arm
        self.t = 0                                             # total number of pulls
        self.L = None
        # N_p is the attribute filled by get() and read by chooseArmToPlay();
        # reset it here so no stale prior counts survive a clear().
        self.N_p = None
        # Not read anywhere in this class; kept so code that inspects it still works.
        self.N_prime = None
        self.mu_p = None
        self.sigmap = None

    def get(self, prior_data, L,sigma_prior):
        """Load the prior information.

        prior_data  : prior_data[k] is the sequence of prior samples of arm k
                      (may be empty)
        L           : per-arm offset added to the prior mean in the index
        sigma_prior : scale of the prior samples (enters the index as sigmap**2)
        """
        self.L = np.array(L, dtype=float)
        self.N_p = np.array([len(prior_data[k]) for k in range(self.nbArms)], dtype=float)
        # An arm without prior samples has no prior mean: store NaN directly
        # instead of calling np.mean on an empty list, which yields the same
        # NaN but emits a RuntimeWarning.  _compute_index never reads the prior
        # mean when the prior count is 0.
        self.mu_p = [
            np.mean(prior_data[k]) if len(prior_data[k]) > 0 else np.nan
            for k in range(self.nbArms)
        ]
        self.sigmap = sigma_prior

    def _compute_index(self, mu_a,mup_a,L_a,N_a,Np_a,sigma,sigmap,delta):
        """Return the index of one arm.

        mu_a, N_a   : online empirical mean and number of pulls
        mup_a, Np_a : prior empirical mean and number of prior samples
        L_a         : offset added to the prior mean
        sigma, sigmap : scales of the online and prior samples
        delta       : exploration threshold
        """
        if Np_a == 0:
            # No prior samples: only the online term constrains the index.
            return mu_a + np.sqrt(delta * 2 * (sigma**2)/N_a)
        alpha = N_a/(2* (sigma**2))
        beta = Np_a/(2* (sigmap**2))
        shifted_prior = mup_a + L_a
        if mu_a >= shifted_prior and np.sqrt(delta / beta) <= mu_a - shifted_prior:
            # The prior term alone uses up delta before q reaches mu_a.
            return shifted_prior + np.sqrt(delta / beta)
        elif shifted_prior >= mu_a and np.sqrt(delta / alpha) <= shifted_prior - mu_a:
            # The online term alone uses up delta before q reaches the shifted prior mean.
            return mu_a + np.sqrt(delta / alpha)
        else:
            # Both terms are active: larger root of
            # alpha*(q - mu_a)**2 + beta*(q - shifted_prior)**2 = delta.
            term1 = alpha*mu_a + beta*shifted_prior
            term2 = (alpha+beta)*delta - alpha*beta*((mu_a - shifted_prior)**2)
            return (term1 + np.sqrt(term2))/(alpha + beta)

    def chooseArmToPlay(self):
        """Return the list of arms to play next.

        While some arms have never been played, all of them are returned at
        once; afterwards a single arm with the largest index is returned
        (ties broken by randmax).
        """
        unplayed = [k for k in range(self.nbArms) if self.Draws[k] == 0]
        if unplayed:
            return unplayed
        delta = (1 + self.epsilon)*np.log(self.t)
        indices = [
            self._compute_index(
                self.sum_rewards[a]/self.Draws[a],
                self.mu_p[a],
                self.L[a],
                self.Draws[a],
                self.N_p[a],
                self.sigma,
                self.sigmap,
                delta,
            )
            for a in range(self.nbArms)
        ]
        return [randmax(indices)]

    def receiveReward(self, arm, reward):
        """Record one reward observed for `arm` and advance the time counter."""
        self.Draws[arm] += 1
        self.sum_rewards[arm] += reward
        self.t += 1

    def name(self):
        """Return the display name of the policy."""
        return "KL_UCB_Transfer"
    
##############################################################################################################################
                                             # AST_UCB
##############################################################################################################################

class AST_UCB:
    """UCB-style index policy that also uses a batch of prior samples per arm.

    For every arm two upper bounds are computed: one from the online rewards
    only, and one from the online rewards pooled with the prior samples (with
    an additional N_p / (N + N_p) term).  The index of the arm is the smaller
    of the two bounds.
    """

    def __init__(self, nbArms, sigma,alpha):
        """Create the policy.

        nbArms : number of arms
        sigma  : scale of the rewards (enters the bonus as sigma**2)
        alpha  : exploration parameter multiplying log(t) in the bonus
        """
        self.nbArms = nbArms
        self.sigma = sigma
        self.alpha = alpha
        self.clear()

    def clear(self):
        """Forget every online observation and all prior information."""
        self.Draws = np.zeros(self.nbArms, dtype=int)          # pulls per arm
        self.sum_rewards = np.zeros(self.nbArms, dtype=float)  # reward sum per arm
        self.t = 0                                             # total number of pulls
        self.L = None
        # N_p and sum_p are the attributes filled by get() and read when the
        # indices are computed; reset them so no stale prior survives a clear().
        self.N_p = None
        self.sum_p = None
        # Not read anywhere in this class; kept so code that inspects them still works.
        self.N_prime = None
        self.mu_p = None
        self.sigmap = None

    def get(self, prior_data, L,sigma_prior):
        """Load the prior information.

        prior_data  : prior_data[k] is the sequence of prior samples of arm k
                      (may be empty)
        L           : per-arm values, stored in self.L
        sigma_prior : stored in self.sigmap
        """
        self.L = np.array(L, dtype=float)
        self.N_p = np.array([len(prior_data[k]) for k in range(self.nbArms)], dtype=float)
        self.sum_p = [sum(prior_data[k]) for k in range(self.nbArms)]
        self.sigmap = sigma_prior

    def _compute_indices(self):
        """Return the array of indices of all arms; every arm must have been played once."""
        # NOTE(review): self.L and self.sigmap are stored by get() but do not
        # enter the index; the extra term of the pooled bound is N_p / S with
        # no L factor -- confirm this is intended.
        pooled_counts = self.Draws + self.N_p
        # Part of the bonus shared by every arm, computed once.
        scale = 2*self.alpha*np.log(self.t)*(self.sigma**2)

        online_bound = self.sum_rewards/self.Draws + np.sqrt(scale/self.Draws)

        pooled_mean = (self.sum_rewards + np.asarray(self.sum_p, dtype=float))/pooled_counts
        pooled_bound = pooled_mean + np.sqrt(scale/pooled_counts) + self.N_p/pooled_counts

        return np.minimum(online_bound, pooled_bound)

    def chooseArmToPlay(self):
        """Return the list of arms to play next.

        While some arms have never been played, all of them are returned at
        once; afterwards a single arm with the largest index is returned
        (ties broken by randmax).
        """
        unplayed = [k for k in range(self.nbArms) if self.Draws[k] == 0]
        if unplayed:
            return unplayed
        return [randmax(self._compute_indices())]

    def receiveReward(self, arm, reward):
        """Record one reward observed for `arm` and advance the time counter."""
        self.Draws[arm] += 1
        self.sum_rewards[arm] += reward
        self.t += 1

    def name(self):
        """Return the display name of the policy."""
        return "AST_UCB"