import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm import tqdm
from skip_mech import skip_mech
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex
import sys

class ArmSwitch(object):
    """Episode-based arm-elimination bandit simulation ("ArmSwitch").

    Within an episode the arms are split into a GOOD set and a BAD set.
    GOOD arms share the sampling mass left over by the BAD arms that are
    currently being replayed; a replayed BAD arm is sampled with probability
    ``1 / nb_arms``. An arm is moved from GOOD to BAD when another active arm
    beats it by more than a confidence threshold on some recent window, and
    a new episode (all arms GOOD again) starts when GOOD becomes empty.

    Parameters
    ----------
    repetitions : int
        Number of independent runs to simulate.
    nb_arms : int
        Number of arms.
    nb_break_points : int
        Stored on the instance; not used by the code in this class.
    horizon : int
        Number of rounds per run.
    Env : object
        Object exposing an ``Env`` attribute that is indexed as
        ``Env.Env[repetition][arm][round]`` to obtain a reward.
    path : str, optional
        Stored on the instance; not used by the code in this class.
        The historical default is the ``str`` type object itself (most
        likely a mistyped annotation); it is normalised to ``None``.
    """

    def __init__(
        self,
        repetitions,
        nb_arms,
        nb_break_points,
        horizon,
        Env,
        path=str,
    ):
        self.repetitions = repetitions
        self.nb_arms = nb_arms
        self.nb_break_points = nb_break_points
        self.horizon = horizon
        self.Env = Env.Env

        # The signature default is kept for backward compatibility, but the
        # `str` class is never a usable path, so treat it as "no path given".
        self.path = None if path is str else path

    def __call__(self):
        """Run :meth:`mainAlg` with its default ``delta`` and return its result."""
        return self.mainAlg()

    def __str__(self):
        return r"ArmSwitch"

    def progress_bar(self, iteration, total, length=40):
        """Write a one-line ``[====   ] iteration/total`` bar to stdout.

        The filled fraction is clamped to ``[0, 1]`` so the bar never grows
        past ``length`` characters, and a non-positive ``total`` draws an
        empty bar instead of raising ``ZeroDivisionError``.
        """
        if total > 0:
            progress = min(max(iteration / total, 0.0), 1.0)
        else:
            progress = 0.0
        filled = max(int(round(progress * length) - 1), 0)
        bar = '=' * filled + ' ' * (length - filled)
        sys.stdout.write(f'\r[{bar}] {iteration}/{total}')
        sys.stdout.flush()

    def _detects_gap(self, reward_map, P, a, a_prime, start, n, both_good, delta):
        """Return True if arm ``a`` beats ``a_prime`` on some window ``[n_prime, n)``.

        For every ``n_prime`` in ``[start, n)`` the observed-reward gap between
        the two arms is compared with the replay threshold and, only when both
        arms are currently GOOD, with the tighter probability-based threshold.
        """
        nb_arms = self.nb_arms
        for n_prime in range(start, n):
            # NOTE(review): the original computed this same quantity twice
            # (as `Delta_tilde_hat` and `Delta_hat`); one of them may have been
            # meant to be a differently weighted estimate.
            gap = np.sum(reward_map[a, n_prime:n]) - np.sum(reward_map[a_prime, n_prime:n])
            C = np.sqrt(np.log((2 * nb_arms * (self.horizon ** 2) * (np.log(n - n_prime + 1) + 2)) / delta))
            if gap > C * np.max([np.sqrt((n - n_prime + 1) / nb_arms), C]):
                return True
            if both_good and gap > C * np.max([np.sqrt(np.max(P[n_prime:n, a])), C]):
                return True
        return False

    def mainAlg(self, delta=0.1):
        """Simulate all repetitions and return summary statistics.

        Parameters
        ----------
        delta : float
            Confidence parameter used inside the elimination threshold.

        Returns
        -------
        X_mean : ndarray, shape (horizon,)
            Mean over repetitions of the per-round reward.
        X_std : ndarray, shape (horizon,)
            Standard deviation over repetitions of the cumulative reward.
        A_mean, A_std : ndarray, shape (nb_arms, horizon)
            Mean / standard deviation over repetitions of the cumulative
            number of pulls of each arm.
        """
        print("Running ... ArmSwitch")
        nb_arms = self.nb_arms
        horizon = self.horizon
        A_mem = np.zeros((self.repetitions, nb_arms, horizon), dtype=int)
        X_mem = np.zeros((self.repetitions, horizon), dtype=float)

        # Replay scales 2^-1, 2^-2, ... depend only on the horizon, so build
        # them once. The guard keeps horizon <= 1 from calling log() on a
        # value the round loop would never have reached.
        if horizon > 1:
            replay_scales = [2 ** (-i) for i in range(1, int(math.ceil(math.log(horizon, 2))))]
        else:
            replay_scales = []

        for r in tqdm(range(self.repetitions)):
            reward_map = np.zeros((nb_arms, horizon))  # observed reward per (arm, round)
            P = np.zeros((horizon, nb_arms))           # sampling distribution per round
            s = 0  # episode counter
            n = 0  # current round index

            while n < horizon:
                self.progress_bar(n + 1, horizon)
                s += 1
                # Every episode starts with all arms GOOD and no replay budget.
                GOOD = list(range(nb_arms))
                BAD = []
                B = np.zeros(nb_arms)
                active = np.ones(nb_arms) * (n + 1)
                new_episode = False
                # NOTE(review): n is advanced before the first round of the
                # episode is played, so round index 0 is never played.
                n += 1
                self.progress_bar(n + 1, horizon)

                while not new_episode and n < horizon:
                    # Randomly schedule replays for BAD arms.
                    for a in BAD:
                        for eps in replay_scales:
                            if np.random.rand() < eps * np.sqrt(s / (nb_arms * horizon)):
                                if B[a] <= 0:
                                    active[a] = n
                                B[a] = max(B[a], 1 / eps ** 2)

                    replaying = [a for a in BAD if B[a] >= 1 / nb_arms]
                    active_set = GOOD + replaying
                    for a in range(nb_arms):
                        if a not in active_set:
                            B[a] = 0
                    m = len(replaying)

                    # Replayed BAD arms get 1/K each; GOOD arms split the rest.
                    for a in active_set:
                        if a in BAD:
                            P[n][a] = 1 / nb_arms
                        else:
                            P[n][a] = (1 - (m / nb_arms)) / len(GOOD)

                    I = np.random.choice(nb_arms, p=P[n])
                    rew = self.Env[r][I][n]
                    reward_map[I][n] = rew
                    X_mem[r][n] = rew
                    A_mem[r][I][n] = 1

                    # Each replayed arm consumes part of its replay budget.
                    for a in replaying:
                        B[a] = B[a] - 1 / nb_arms

                    # NOTE(review): the windows [n_prime, n) scanned below do
                    # not include the reward just observed at round n.
                    for a in active_set:
                        for a_prime in active_set:
                            if a == a_prime:
                                # The gap of an arm against itself is exactly 0
                                # and can never exceed a threshold.
                                continue
                            start = max(int(active[a]), int(active[a_prime]))
                            # Fixed: the original wrote `(a, a_prime in GOOD)`,
                            # a non-empty tuple that is always truthy.
                            both_good = a in GOOD and a_prime in GOOD
                            if self._detects_gap(reward_map, P, a, a_prime, start, n, both_good, delta):
                                if a_prime in GOOD:
                                    print("arm {} from GOOD to BAD".format(a_prime))
                                    GOOD.remove(a_prime)
                                    BAD.append(a_prime)
                                if a_prime in BAD:
                                    # NOTE(review): this clears the budget of the
                                    # winning arm `a`; `B[a_prime]` may have been
                                    # intended - confirm against the algorithm.
                                    B[a] = 0
                            if not GOOD:
                                # Fixed: leave the scan immediately. The original
                                # kept iterating and decremented n once per
                                # remaining pair, replaying earlier rounds.
                                new_episode = True
                                print("New episode")
                                break
                        if new_episode:
                            break

                    # When an episode ends, the enclosing loop performs the
                    # increment to the next round; otherwise advance here.
                    if not new_episode:
                        n += 1
                    self.progress_bar(n + 1, horizon)

        X_mem_cum = np.cumsum(X_mem, axis=1)
        # NOTE(review): the mean is taken over per-round rewards while the
        # standard deviation is taken over cumulative rewards (kept as-is).
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem_cum, axis=0)
        A_mem_cum = np.cumsum(A_mem, axis=2)
        A_std = np.std(A_mem_cum, axis=0)
        A_mean = np.mean(A_mem_cum, axis=0)
        return X_mean, X_std, A_mean, A_std
    
