import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm import tqdm
from skip_mech import skip_mech
from plotsettings import BBOX_INCHES, signature, maximizeWindow, palette, makemarkers, add_percent_formatter, legend, show_and_save, nrows_ncols, violin_or_box_plot, adjust_xticks_subplots, table_to_latex


class Meta(object):
    """Simulate an elimination-based bandit strategy with randomly scheduled replays.

    For every repetition the horizon is split into episodes.  Within an
    episode, ``BaseAlg`` plays arms uniformly at random from a candidate set,
    recursively starts "replay" runs according to the Bernoulli schedule
    ``self.B`` and evicts arms flagged by ``BadArm``.  A new episode starts
    once the episode-wide candidate set ``self.A_master`` is empty.

    NOTE(review): the structure resembles the META / Base-Alg procedure of
    Suk & Kpotufe -- confirm against the intended reference.
    """

    def __init__(
        self,
        repetitions,
        nb_arms,
        nb_break_points,
        horizon,
        Env,
        path = str
    ):
        """Store the experiment configuration.

        Args:
            repetitions: number of independent runs to simulate.
            nb_arms: number of arms.
            nb_break_points: stored as-is; not used by the methods below.
            horizon: number of rounds per run.
            Env: object whose ``Env`` attribute is indexed as
                ``[repetition][arm][round]`` to obtain the reward of a pull.
            path: stored as-is; not used by the methods below.
                NOTE(review): the default is the builtin ``str`` type object,
                which looks like a misplaced annotation -- confirm before
                relying on ``self.path``.
        """
        self.repetitions = repetitions
        self.nb_arms = nb_arms
        self.nb_break_points = nb_break_points
        self.horizon = horizon
        self.Env = Env.Env

        self.path = path

    def __call__(self):
        """Run the simulation and return ``(X_mean, X_std, A_mean, A_std)``."""
        # mainAlg returns only the reward statistics; the arm statistics are
        # stored on the instance by mainAlg.
        X_mean, X_std = self.mainAlg()
        return X_mean, X_std, self.A_mean, self.A_std

    def Bernoulli(self, p):
        """Return 1 with probability ``p`` and 0 otherwise."""
        return int(np.random.random() < p)

    def _start_episode(self, t):
        """Reset the per-episode state for an episode starting at round ``t``."""
        self.new_episode = False
        self.tl = t
        self.A_master = list(range(self.nb_arms))
        # Replay lengths 2, 4, ..., 2**ceil(ln(horizon)).
        # NOTE(review): natural logarithm is used here; confirm whether a
        # base-2 logarithm was intended.
        n_scales = int(np.ceil(np.log(self.horizon))) if self.horizon > 0 else 0
        self.m_list = [2 ** i for i in range(1, n_scales + 1)]
        # B[s][m] == 1 schedules a replay of length m starting at round s.
        self.B = {
            s: {
                m: self.Bernoulli(1 / np.sqrt(m * (s - self.tl)))
                for m in self.m_list
            }
            for s in range(self.tl + 1, self.horizon)
        }

    def mainAlg(self, path = None):
        """Run all repetitions and return the reward statistics.

        Fills ``self.X_mem`` (reward per repetition and round) and
        ``self.A_mem`` (one-hot arm choice per repetition, arm and round),
        stores the cumulative pull statistics in ``self.A_mean`` and
        ``self.A_std``, and calls ``self.plot_arm(A_mean, A_std)`` when the
        instance provides such a method.

        Args:
            path: currently unused.

        Returns:
            ``(X_mean, X_std)``: the per-round mean reward across repetitions
            and the standard deviation of the cumulative reward across
            repetitions.
        """
        print("Running ... Meta")
        self.A_mem = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype = int)
        self.X_mem = np.zeros((self.repetitions, self.horizon), dtype = float)
        for rep in tqdm(range(self.repetitions)):
            self.r = rep  # read by BaseAlg to index the environment
            self.chosen_arms = np.zeros(self.horizon, dtype=int)
            self.received_rewards = np.zeros(self.horizon)
            # Observed reward per (arm, round); zero where the arm was not played.
            self.reward_map = np.zeros((self.nb_arms, self.horizon))

            t = 0  # current number of samples
            self._start_episode(t)
            while t < self.horizon:
                if self.new_episode:
                    print("next episode")
                    self._start_episode(t)
                t = self.BaseAlg(self.tl, self.horizon - self.tl)

        X_mem_cum = np.cumsum(self.X_mem, axis=1)
        # NOTE(review): the mean is taken over per-round rewards while the
        # standard deviation is taken over cumulative rewards -- confirm that
        # this asymmetry is intended.
        X_mean = np.mean(self.X_mem, axis=0)
        X_std = np.std(X_mem_cum, axis=0)
        A_mem_cum = np.cumsum(self.A_mem, axis=2)
        self.A_std = np.std(A_mem_cum, axis=0)
        self.A_mean = np.mean(A_mem_cum, axis=0)
        # plot_arm is not defined in this class body; only call it when some
        # other definition (e.g. a subclass) supplies it.
        plot_arm = getattr(self, "plot_arm", None)
        if callable(plot_arm):
            plot_arm(self.A_mean, self.A_std)
        return X_mean, X_std

    def BaseAlg(self, t_start, m0):
        """Play rounds from ``t_start`` on and return the next unplayed round.

        Each round pulls a uniformly random arm from the candidate set
        ``self.At`` and records the reward.  If the schedule ``self.B`` marks
        the following round, a nested run of the largest scheduled length is
        started there.  Afterwards arms flagged by ``BadArm`` are evicted from
        ``self.A_master`` (windows since the episode start) and from the
        candidate set (windows since ``t_start``).

        Args:
            t_start: first round played by this run.
            m0: length budget; the run returns once ``t > t_start + m0``.

        Returns:
            The index of the next round to play.  ``self.new_episode`` is set
            to True when the run stopped because ``self.A_master`` is empty.
        """
        t = t_start
        self.At = list(range(self.nb_arms))
        while t < self.horizon:
            arm = np.random.choice(self.At)
            rew = self.Env[self.r][arm][t]
            self.chosen_arms[t] = arm
            self.received_rewards[t] = rew
            self.reward_map[arm][t] = rew
            self.X_mem[self.r][t] = rew
            self.A_mem[self.r][arm][t] = 1
            # Snapshot in a local: nested runs reassign self.At and
            # self.A_current, so an attribute would not survive the recursion.
            current_arms = list(self.At)
            self.A_current = current_arms
            t = t + 1
            # The schedule only covers rounds before the horizon, hence .get().
            scheduled = [m for m, flag in self.B.get(t, {}).items() if flag > 0]
            if scheduled:
                t = self.BaseAlg(t, max(scheduled))
            if t > t_start + m0:
                return t
            evict_arm = self.BadArm(self.tl, t)
            if evict_arm in self.A_master:
                self.A_master.remove(evict_arm)

            evict_arm = self.BadArm(t_start, t)
            self.At = list(current_arms)
            # Keep at least one arm so the next random draw has a candidate.
            if evict_arm in self.At and len(self.At) > 1:
                self.At.remove(evict_arm)
            if not self.A_master:
                self.new_episode = True
                return t
        return t

    def BadArm(self, t1, t2):
        """Return an arm whose observed reward is significantly low, or -1.

        Windows ``[s1, s2)`` with ``t1 <= s1 < s2 < t2`` are scanned in order
        of increasing ``s1`` and then increasing ``s2``.  For the first window
        in which the gap between the largest and the smallest per-arm sum of
        ``self.reward_map`` exceeds
        ``log(horizon) * sqrt(max(nb_arms * (s2 - s1), nb_arms ** 2))``,
        the arm with the smallest sum is returned (lowest index on ties).

        Args:
            t1: first round that may start a window.
            t2: exclusive upper bound for window ends.

        Returns:
            The index of the flagged arm as ``int``, or -1 if no window
            exceeds the threshold.
        """
        if t2 - t1 < 2:
            # No non-empty window exists.
            return -1
        n_arms = self.nb_arms
        log_horizon = np.log(self.horizon)
        span = np.asarray(self.reward_map, dtype=float)[:n_arms, t1:t2 - 1]
        n_rounds = span.shape[1]
        # prefix[:, j] is the per-arm reward sum over rounds t1 .. t1 + j - 1.
        prefix = np.hstack((np.zeros((span.shape[0], 1)), np.cumsum(span, axis=1)))
        for i in range(n_rounds):
            # Column k holds the sums of the window of length k + 1 starting at t1 + i.
            sums = prefix[:, i + 1:] - prefix[:, i:i + 1]
            gaps = sums.max(axis=0) - sums.min(axis=0)
            lengths = np.arange(1, n_rounds - i + 1)
            thresholds = log_horizon * np.sqrt(
                np.maximum(n_arms * lengths, n_arms * n_arms)
            )
            hits = np.flatnonzero(gaps > thresholds)
            if hits.size:
                return int(np.argmin(sums[:, hits[0]]))
        return -1

    def __str__(self):
        """Return the display name of the algorithm."""
        return r"META"