import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm import tqdm
from skip_mech import skip_mech
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex
import sys


class AdSwitch(object):
    """Simulate the AdSwitch policy on pre-generated bandit reward tables.

    The policy works in episodes.  Inside an episode it keeps a set of
    "good" arms, evicts arms that look significantly worse than the best
    one ("bad" arms), occasionally schedules forced re-sampling of bad arms
    ("sampling obligations") and restarts the episode when the statistics
    of a good or a bad arm appear to have shifted.

    Parameters
    ----------
    repetitions : int
        Number of independent runs.
    nb_arms : int
        Number of arms.
    nb_break_points : int
        Stored on the instance; not used by the code in this class.
    horizon : int
        Number of rounds per run.
    Env : object
        Object exposing an ``Env`` attribute that is indexed as
        ``Env.Env[repetition][arm][round]`` to obtain a reward.
    path : optional
        Stored on the instance; not used by the code in this class.
        NOTE(review): the default is the ``str`` *type object* itself,
        which was probably meant as an annotation -- confirm with callers
        before changing it.
    sub_sample : int
        Change-detection tests run only on episode rounds ``t_loc`` with
        ``t_loc % sub_sample == 1``.
    sub_sample2 : int
        Reference windows of the good-arm test start/end only on indices
        ``j`` with ``j % sub_sample2 == 1``.
    C : float
        Multiplier inside the confidence radius used to evict arms.
    """

    def __init__(
        self,
        repetitions,
        nb_arms,
        nb_break_points,
        horizon,
        Env,
        path = str,
        sub_sample=50,
        sub_sample2=20,
        C = 1
    ):
        self.repetitions = repetitions
        self.nb_arms = nb_arms
        self.nb_break_points = nb_break_points
        self.horizon = horizon
        self.Env = Env.Env

        self.path = path

        self.sub_sample = sub_sample
        self.sub_sample2 = sub_sample2
        self.C = C

    def __call__(self):
        """Run the simulation; same return value as :meth:`mainAlg`."""
        return self.mainAlg()

    def __str__(self):
        return r"AdSwitch"

    def progress_bar(self, iteration, total, length=40):
        """Redraw a one-line ``[====   ] iteration/total`` bar on stdout."""
        progress = iteration / total
        arrow = '=' * int(round(progress * length) - 1)
        spaces = ' ' * (length - len(arrow))
        sys.stdout.write(f'\r[{arrow}{spaces}] {iteration}/{total}')
        sys.stdout.flush()

    @staticmethod
    def _select_arm(candidate, total_number):
        """Return the least-sampled arm id among ``candidate``.

        Ties are broken uniformly at random.  The returned value is an arm
        id (an element of ``candidate``), not a position inside the list.
        """
        counts = np.asarray(total_number)[candidate]
        least = np.flatnonzero(counts == counts.min())
        return int(candidate[int(np.random.choice(least))])

    @staticmethod
    def _record_pull(DRAWS, SUMS, arm, rew, t_loc):
        """Extend the window statistics with the pull made at round ``t_loc``.

        ``DRAWS[a, s, e]`` / ``SUMS[a, s, e]`` hold the number of pulls and
        the reward sum of arm ``a`` over episode rounds ``s..e`` (0-based).
        Only the pulled arm is credited; all other arms carry their
        previous counts over unchanged.
        """
        last = t_loc - 1
        if last > 0:
            # Windows that started earlier inherit the previous column ...
            DRAWS[:, :last, last] = DRAWS[:, :last, last - 1]
            SUMS[:, :last, last] = SUMS[:, :last, last - 1]
            # ... and only the pulled arm gains a sample in each of them.
            DRAWS[arm, :last, last] += 1
            SUMS[arm, :last, last] += rew
        # The window that starts at the current round contains one sample.
        DRAWS[arm, last, last] = 1
        SUMS[arm, last, last] = rew

    def _eliminate_arms(self, DRAWS, SUMS, t_loc, good, bad, bad_gaps, bad_means):
        """Move arms from ``good`` to ``bad`` when some window shows a gap.

        For every window start ``s`` an arm is evicted if its empirical
        mean is below the best candidate mean by more than
        ``sqrt(C * log(horizon) / (draws - 1))``.  The gap and mean observed
        at eviction time are stored in ``bad_gaps`` / ``bad_means``.
        """
        last = t_loc - 1
        log_T = np.log(self.horizon)
        for s in range(t_loc):
            # `good` may shrink between windows, so rebuild the list each time.
            cand = [a for a in good if DRAWS[a, s, last] > 1]
            if not cand:
                continue
            cand_means = [SUMS[a, s, last] / DRAWS[a, s, last] for a in cand]
            mu_max = max(cand_means)
            for arm, mean in zip(cand, cand_means):
                gap = mu_max - mean
                if gap > np.sqrt(self.C * log_T / (DRAWS[arm, s, last] - 1)):
                    good.remove(arm)
                    bad.append(arm)
                    bad_gaps[arm] = gap
                    bad_means[arm] = mean

    def _bad_arm_changed(self, DRAWS, SUMS, t_loc, bad, bad_gaps, bad_means):
        """Return True if some bad arm's mean drifted from its stored mean."""
        last = t_loc - 1
        log_T = np.log(self.horizon)
        # Window starts 0..t_loc-2, same indexing as the good-arm test.
        for s in range(last):
            for arm in bad:
                draws = DRAWS[arm, s, last]
                if draws > 1 and abs(SUMS[arm, s, last] / draws - bad_means[arm]) > \
                        bad_gaps[arm] / 4 + np.sqrt(2 * log_T / draws):
                    return True
        return False

    def _good_arm_changed(self, DRAWS, SUMS, t_loc, good):
        """Return True if some good arm's mean differs between two windows.

        A sub-sampled reference window ``[s1 - 1, s2 - 1]`` is compared
        against every window ``[s - 1, t_loc - 1]`` ending at the current
        round.
        """
        last = t_loc - 1
        log_T = np.log(self.horizon)
        anchors = [j for j in range(t_loc) if j % self.sub_sample2 == 1]

        # Reference-window statistics do not depend on `s`: compute them once.
        references = []
        for pos, s1 in enumerate(anchors):
            for s2 in anchors[pos:]:
                for arm in good:
                    draws1 = DRAWS[arm, s1 - 1, s2 - 1]
                    if draws1 > 1:
                        references.append((
                            arm,
                            SUMS[arm, s1 - 1, s2 - 1] / draws1,
                            np.sqrt(2 * log_T / draws1),
                        ))
        if not references:
            return False

        for s in range(1, t_loc):
            for arm, ref_mean, ref_radius in references:
                draws2 = DRAWS[arm, s - 1, last]
                if draws2 > 1 and abs(ref_mean - SUMS[arm, s - 1, last] / draws2) > \
                        ref_radius + np.sqrt(2 * log_T / draws2):
                    return True
        return False

    def _add_obligations(self, bad, bad_gaps, obligations, episode):
        """Randomly schedule forced pulls of bad arms (updates ``obligations``)."""
        if not bad:
            return
        log_T = np.log(self.horizon)
        scale = np.sqrt(episode / (self.nb_arms * self.horizon * log_T))
        for bad_arm in bad:
            i = 1
            # Probe gap sizes 2^-i down to 1/16 of the gap seen at eviction.
            while 1 / (2 ** i) >= bad_gaps[bad_arm] / 16:
                if np.random.rand() < (1 / (2 ** i)) * scale:
                    n = int(np.floor(2 ** (2 * i + 1) * log_T))
                    obligations[bad_arm] = max(obligations[bad_arm], n)
                i += 1

    def mainAlg(self):
        """Run all repetitions and aggregate rewards and arm pulls.

        Returns
        -------
        X_mean : ndarray, shape (horizon,)
            Mean over repetitions of the per-round reward.
        X_std : ndarray, shape (horizon,)
            Standard deviation over repetitions of the cumulative reward.
        A_mean, A_std : ndarray, shape (nb_arms, horizon)
            Mean / standard deviation over repetitions of the cumulative
            number of pulls of each arm.

        NOTE(review): ``X_mean`` is computed from per-round rewards while
        ``X_std`` is computed from cumulative rewards.  This is kept as-is
        because callers may rely on it -- confirm whether it is intended.
        """
        print("Running ... AdSwitch")
        A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        for r in tqdm(range(self.repetitions)):
            episode = 1  # index of the current episode
            t = 0  # rounds played so far in this repetition

            while t < self.horizon:
                self.progress_bar(t + 1, self.horizon)
                remaining = self.horizon - t
                SUMS = np.zeros((self.nb_arms, remaining, remaining))   # S[a, s, t]
                DRAWS = np.zeros((self.nb_arms, remaining, remaining))  # N[a, s, t]
                bad_gaps = np.zeros(self.nb_arms)   # gap of each bad arm at eviction
                bad_means = np.zeros(self.nb_arms)  # mean of each bad arm at eviction
                t_loc = 0  # rounds played in the current episode
                total_number = np.zeros(self.nb_arms, dtype=int)  # pulls in the episode

                good = list(range(self.nb_arms))
                bad = []
                obligations = np.zeros(self.nb_arms, dtype=int)
                new_episode = False

                while not new_episode and t < self.horizon:
                    # Candidates: good arms plus arms with pending obligations.
                    candidate = good + [i for i in range(self.nb_arms) if obligations[i] > 0]

                    arm = self._select_arm(candidate, total_number)
                    rew = self.Env[r][arm][t]
                    total_number[arm] += 1
                    X_mem[r][t] = rew
                    A_mem[r][arm][t] = 1
                    if obligations[arm] > 0:
                        # This pull pays off one unit of the arm's obligation.
                        obligations[arm] -= 1

                    t += 1
                    t_loc += 1
                    self.progress_bar(t, self.horizon)

                    self._record_pull(DRAWS, SUMS, arm, rew, t_loc)
                    self._eliminate_arms(DRAWS, SUMS, t_loc, good, bad, bad_gaps, bad_means)

                    # Change-detection tests run only on sub-sampled rounds.
                    if t_loc % self.sub_sample == 1:
                        if (self._bad_arm_changed(DRAWS, SUMS, t_loc, bad, bad_gaps, bad_means)
                                or self._good_arm_changed(DRAWS, SUMS, t_loc, good)):
                            new_episode = True
                            episode += 1

                    self._add_obligations(bad, bad_gaps, obligations, episode)

        X_mem_cum = np.cumsum(X_mem, axis=1)
        X_mean = np.mean(X_mem, axis=0)
        X_std = np.std(X_mem_cum, axis=0)
        A_mem_cum = np.cumsum(A_mem, axis=2)
        A_std = np.std(A_mem_cum, axis=0)
        A_mean = np.mean(A_mem_cum, axis=0)
        return X_mean, X_std, A_mean, A_std