import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm import tqdm
from skip_mech import skip_mech
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex


class DTS(object):
    """Discounted Thompson Sampling (Discounted-TS) bandit simulator.

    Runs ``repetitions`` independent simulations of ``horizon`` rounds over
    ``nb_arms`` arms. Each arm keeps exponentially discounted counts of pulls
    and summed rewards (discount factor ``gamma``). These counts parameterise
    a Beta posterior. At every round one sample is drawn per arm, and the arm
    with the largest sample is played.
    """

    def __init__(
        self,
        repetitions,
        nb_arms,
        nb_break_points,
        horizon,
        Env,
        path=str,
        klUCB=False,
        gamma=None,
    ):
        """Store the experiment configuration.

        Args:
            repetitions: number of independent runs; also the first index
                into ``Env.Env``.
            nb_arms: number of arms.
            nb_break_points: number of break points; used here only to
                derive the default ``gamma``.
            horizon: number of rounds per run.
            Env: object whose ``Env`` attribute is indexed as
                ``Env.Env[run][arm][t]`` to obtain the reward of ``arm`` at
                round ``t`` in run ``run``.
            path: stored on the instance; not read by this class.
                NOTE(review): the default is the builtin type ``str``, which
                was probably meant to be ``None`` — confirm with callers.
            klUCB: stored on the instance; not read by this class.
            gamma: discount factor. When ``None``, it defaults to
                ``1 - sqrt(nb_break_points / horizon)``.
        """
        self.repetitions = repetitions
        self.nb_arms = nb_arms
        self.nb_break_points = nb_break_points
        self.horizon = horizon
        self.Env = Env.Env
        self.path = path
        self.klUCB = klUCB

        if gamma is None:
            gamma = 1 - np.sqrt(self.nb_break_points / self.horizon)
        self.gamma = gamma

    def __call__(self):
        """Run the simulation; same return value as :meth:`mainAlg`."""
        X_mean, X_std, A_mean, A_std = self.mainAlg()
        return X_mean, X_std, A_mean, A_std

    def __str__(self):
        return r"Discounted-TS"

    def mainAlg(self, a=1, b=1):
        """Simulate Discounted-TS and return statistics across repetitions.

        Args:
            a, b: parameters of the Beta(a, b) prior that is added to every
                arm's discounted success / failure counts.
                NOTE(review): presumably rewards lie in [0, 1] so that the
                failure count ``pulls - successes`` stays non-negative — confirm.

        Returns:
            X_mean: shape (horizon,). Mean over repetitions of the reward
                received at each round.
            X_std: shape (horizon,). Std over repetitions of the cumulative
                reward up to each round.
            A_mean: shape (nb_arms, horizon). Mean over repetitions of the
                cumulative number of pulls of each arm.
            A_std: shape (nb_arms, horizon). Std of the same quantity.
        """
        print("Running ... Discounted TS")
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype=int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype=float)
        for i in tqdm(range(self.repetitions)):
            # Fresh discounted statistics for every run. Each repetition uses
            # its own environment Env[i], so posterior state must not carry over.
            disc_number = np.zeros(self.nb_arms, dtype=float)
            disc_sum = np.zeros(self.nb_arms, dtype=float)
            rewards = self.Env[i]
            for t in range(self.horizon):
                # Arms with no discounted pulls keep index 1. A Beta sample is
                # almost surely < 1, so every arm gets tried before sampling decides.
                indices = np.ones(self.nb_arms, dtype=float)
                pulled = disc_number > 0
                indices[pulled] = np.random.beta(
                    a + disc_sum[pulled],
                    b + disc_number[pulled] - disc_sum[pulled],
                )
                arm = np.argmax(indices)
                rew = rewards[arm][t]
                # Discount every arm's statistics, then credit the played arm.
                disc_number *= self.gamma
                disc_sum *= self.gamma
                disc_number[arm] += 1
                disc_sum[arm] += rew
                X_mem[i, t] = rew
                A_mem[i, arm, t] = 1
        X_mem_cum = np.cumsum(X_mem, axis=1)
        # NOTE(review): X_mean averages per-round rewards while X_std is taken
        # over cumulative rewards. Kept as-is because callers may rely on it —
        # confirm this asymmetry is intended.
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem_cum, axis=0)
        A_mem_cum = np.cumsum(A_mem, axis=2)
        A_mean = np.mean(A_mem_cum, axis=0)
        A_std = np.std(A_mem_cum, axis=0)
        return X_mean, X_std, A_mean, A_std