import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm import tqdm
from skip_mech import skip_mech
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex


class DUCB(object):
    """Discounted UCB and discounted kl-UCB simulations on pre-drawn rewards.

    Rewards are read from ``Env.Env``, indexed as ``[repetition][arm][t]``.
    Every repetition is an independent run of ``horizon`` steps: at each
    step an index is computed for every arm from exponentially discounted
    statistics (discount factor ``gamma``), the arm with the largest index
    is played, and all discounted statistics are then decayed by ``gamma``.

    Parameters
    ----------
    repetitions : int
        Number of independent runs.
    nb_arms : int
        Number of arms.
    nb_break_points : int
        Number of break points; only used to derive the default ``gamma``.
    horizon : int
        Number of time steps per run.
    Env : object
        Object whose ``Env`` attribute holds the rewards, indexed as
        ``[repetition][arm][t]``.
    path : str or None
        Stored as ``self.path``; not used by any method of this class.
    klUCB : bool
        If true, calling the instance runs the kl-UCB variant.
    gamma : float or None
        Discount factor. When ``None`` it defaults to
        ``1 - sqrt(nb_break_points / horizon)``.
    """

    # Index given to an arm that has no discounted observation yet. For UCB
    # it is a large optimistic value that forces a first pull; for kl-UCB it
    # is 1, the largest value klUp can return.
    _UNPLAYED_UCB_INDEX = 10000.0
    _UNPLAYED_KL_INDEX = 1.0

    def __init__(
        self,
        repetitions,
        nb_arms,
        nb_break_points,
        horizon,
        Env,
        path=None,
        klUCB=False,
        gamma=None,
    ):
        self.repetitions = repetitions
        self.nb_arms = nb_arms
        self.nb_break_points = nb_break_points
        self.horizon = horizon
        self.Env = Env.Env
        self.path = path
        self.klUCB = klUCB

        if gamma is None:
            gamma = 1 - np.sqrt(self.nb_break_points / self.horizon)
        # Keep the requested (or derived) discount factor instead of
        # overwriting it with a hard-coded constant.
        self.gamma = float(gamma)
        print(self.gamma)

    def __call__(self):
        """Run the configured variant.

        Returns the tuple ``(X_mean, X_std, A_mean, A_std)`` produced by
        ``mainklAlg`` when ``klUCB`` is true, otherwise by ``mainAlg``.
        Both variants use the discount factor stored on the instance.
        """
        if self.klUCB:
            return self.mainklAlg(gamma=self.gamma)
        return self.mainAlg()

    def __str__(self):
        """Return ``"Discounted-klUCB"`` or ``"Discounted-UCB"``."""
        return "Discounted-{}UCB".format("kl" if self.klUCB else "")

    def mainAlg(self):
        """Run Discounted UCB with discount factor ``self.gamma``.

        The index of an arm with discounted count ``N`` and discounted
        reward sum ``S`` is ``S / N + sqrt(2 * log(n) / N)``, where ``n`` is
        the total discounted number of plays (the sum of all ``N``).

        Returns
        -------
        tuple
            ``(X_mean, X_std, A_mean, A_std)``, see ``_summarize``.
        """
        print("Running ... Discounted UCB")

        def ucb_index(mean, count, total):
            return mean + np.sqrt(2 * np.log(total) / count)

        return self._run(self.gamma, ucb_index, self._UNPLAYED_UCB_INDEX)

    def mainklAlg(self, gamma=0.75, path=None):
        """Run Discounted kl-UCB.

        Parameters
        ----------
        gamma : float
            Discount factor used for this run. ``__call__`` passes
            ``self.gamma``; a direct call without arguments uses 0.75.
        path : str or None
            Accepted for interface compatibility; not used.

        Returns
        -------
        tuple
            ``(X_mean, X_std, A_mean, A_std)``, see ``_summarize``.
        """
        print("Running ... Discounted klUCB")

        def kl_index(mean, count, total):
            return self.klUp(mean, np.log(total) / count)

        return self._run(gamma, kl_index, self._UNPLAYED_KL_INDEX)

    def _run(self, gamma, index_fn, unplayed_index):
        """Simulate every repetition with the given index rule.

        ``index_fn(mean, count, total)`` receives the discounted mean reward
        of an arm, its discounted play count, and the total discounted play
        count, and returns the arm's index. Arms whose discounted count is
        zero get ``unplayed_index`` instead.
        """
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype=int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype=float)
        for rep in tqdm(range(self.repetitions)):
            # Discounted statistics are reset here so that repetitions are
            # independent runs that do not inherit state from one another.
            disc_number = np.zeros(self.nb_arms, dtype=float)
            disc_sum = np.zeros(self.nb_arms, dtype=float)
            total = 0.0  # equals disc_number.sum(), maintained incrementally
            for t in range(self.horizon):
                indices = np.full(self.nb_arms, unplayed_index, dtype=float)
                for k in range(self.nb_arms):
                    if disc_number[k] > 0:
                        indices[k] = index_fn(
                            disc_sum[k] / disc_number[k], disc_number[k], total
                        )
                arm = np.argmax(indices)  # ties go to the lowest arm number
                rew = self.Env[rep][arm][t]
                # Decay every statistic, then credit the arm just played.
                disc_number = gamma * disc_number
                disc_number[arm] += 1
                disc_sum = gamma * disc_sum
                disc_sum[arm] += rew
                total = 1 + gamma * total
                X_mem[rep][t] = rew
                A_mem[rep][arm][t] = 1
        return self._summarize(X_mem, A_mem)

    @staticmethod
    def _summarize(X_mem, A_mem):
        """Aggregate per-repetition records over the repetition axis.

        Returns ``(X_mean, X_std, A_mean, A_std)`` where

        * ``X_mean`` is the mean per-step reward, shape ``(horizon,)``;
        * ``X_std`` is the standard deviation of the *cumulative* reward,
          shape ``(horizon,)``;
        * ``A_mean`` / ``A_std`` are the mean and standard deviation of the
          cumulative pull counts, shape ``(nb_arms, horizon)``.

        NOTE(review): ``X_mean`` is per-step while ``X_std`` is cumulative;
        presumably callers cumulate ``X_mean`` themselves. Confirm.
        """
        X_mem_cum = np.cumsum(X_mem, axis=1)
        A_mem_cum = np.cumsum(A_mem, axis=2)
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem_cum, axis=0)
        A_mean = np.mean(A_mem_cum, axis=0)
        A_std = np.std(A_mem_cum, axis=0)
        return X_mean, X_std, A_mean, A_std

    def klChange(self, p, q):
        """Return the binary relative entropy ``kl(p, q)``.

        Returns 0 when ``p == q``. Otherwise both arguments are clamped into
        the open interval (0, 1) by machine epsilon so that neither the
        logarithms nor the divisions can blow up.
        """
        if p == q:
            return 0
        eps = np.finfo(float).eps
        if p <= 0:
            p = eps
        if p >= 1:
            p = 1 - eps
        if q <= 0:
            q = eps
        if q >= 1:
            q = 1 - eps
        return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

    def klUp(self, p, level):
        """Return the KL upper confidence bound for ``p`` at ``level``.

        Approximates, by 16 bisection steps, the value ``u >= p`` such that
        ``kl(p, u) = level`` and returns the upper end of the final bracket.
        A negative ``level`` is treated as 0.
        """
        level = max(level, 0.0)
        lower = p
        # Pinsker's inequality gives kl(p, q) >= 2 (p - q)^2, so the solution
        # cannot exceed p + sqrt(level / 2); it cannot exceed 1 either.
        upper = min(1.0, p + np.sqrt(level / 2))
        for _ in range(16):
            mid = (upper + lower) / 2
            if self.klChange(p, mid) > level:
                upper = mid
            else:
                lower = mid
        return upper