import numpy as np
import matplotlib.pyplot as plt
import math
import random
from tqdm import tqdm_notebook as tqdm

class Environment(object):
    """Piecewise-stationary Bernoulli bandit environments for several experiments.

    Depending on ``experiment`` the constructor builds:

    * ``"M"`` -- one environment per candidate number of segments ``M``
      (``Env_for_M``, ``mu_for_M``, ``mu_max_for_M`` keyed by ``M``);
    * ``"T"`` -- one environment per horizon in ``horizon_list``
      (``Env_for_T``, ``mu_for_T``, ``mu_max_for_T`` keyed by horizon);
    * ``"K"`` -- for each config in ``configuration`` (a list of config dicts),
      one environment per arm count in ``nb_arms_list``
      (``Env_for_K[i]`` etc. are lists indexed by ``k - 1``);
    * anything else -- a single environment (``Env``, ``mu``, ``mu_max``).

    Every environment consists of
    ``Env``    -- int array (repetitions, nb_arms, horizon) of 0/1 Bernoulli draws,
    ``mu``     -- float array (nb_arms, horizon) with each arm's mean at each step,
    ``mu_max`` -- float array (horizon,) with the largest mean of the active segment.
    """

    def __init__(self, configuration, experiment=None, repetitions=None, path=None):
        """Build the environment(s) requested by ``experiment``.

        Args:
            configuration: config dict with ``horizon``, ``repetitions``,
                ``nb_break_points`` and ``environment`` (``env_type`` and
                ``params`` holding ``listOfMeans`` / ``changePoints``).
                For ``experiment == "K"`` a list of such dicts.
            experiment: ``"M"``, ``"T"``, ``"K"`` or anything else (single env).
            repetitions: overrides the configured repetition count when not None.
            path: for M/T/K, directory in which the swept values are saved as CSV.
        """
        self.repetitions = repetitions
        self.experiment = experiment
        if experiment == "M":
            self._read_common_config(configuration, repetitions)
            self.__Preprocessing()
            self.Env_for_M = {}
            self.mu_for_M = {}
            self.mu_max_for_M = {}
            for m in self.nb_break_points_list:
                list_of_means = self.list_of_means_dict[m]
                print(list_of_means)
                self.__initEnvironments__(list_of_means, self.change_point_dict[m])
                self.Env_for_M[m] = self.Env
                self.mu_for_M[m] = self.mu
                self.mu_max_for_M[m] = self.mu_max
            if path is not None:
                print("Regret saving ...")
                filename = path + '/nb_break_points' + '.csv'
                np.savetxt(filename, self.nb_break_points_list)
                # self.repetitions is intentionally NOT reset here: the raw
                # argument may be None when the value came from the config.
        elif experiment == "T":
            self._read_common_config(configuration, repetitions)
            self.__Preprocessing_for_T()
            self.Env_for_T = {}
            self.mu_for_T = {}
            self.mu_max_for_T = {}
            for idx, horizon in enumerate(self.horizon_list):
                # __initEnvironments__ reads the horizon from self.horizon.
                self.horizon = horizon
                list_of_means = self.list_of_means_dict[idx]
                print(list_of_means)
                self.__initEnvironments__(list_of_means, self.change_point_dict[idx])
                self.Env_for_T[horizon] = self.Env
                self.mu_for_T[horizon] = self.mu
                self.mu_max_for_T[horizon] = self.mu_max
            if path is not None:
                print("Regret saving ...")
                filename = path + '/T' + '.csv'
                np.savetxt(filename, self.horizon_list)
        elif experiment == "K":
            self.cfg = configuration  #: List of configuration dictionaries
            self.nb_arms_list = list(range(1, 11))
            self.num_of_instance = len(configuration)
            self.Env_for_K = {}
            self.mu_for_K = {}
            self.mu_max_for_K = {}
            self.__Preprocessing_for_K()
            for i in range(self.num_of_instance):
                # Each instance has its own horizon; __Preprocessing_for_K leaves
                # self.horizon at the LAST instance's value, so reset it here.
                self.horizon = self.cfg[i]['horizon']
                self.Env_for_K[i] = []
                self.mu_for_K[i] = []
                self.mu_max_for_K[i] = []
                for k in self.nb_arms_list:
                    self.nb_arms = k
                    list_of_means = self.list_of_means_dict[i][k - 1]
                    print(list_of_means)
                    self.__initEnvironments__(list_of_means, self.change_point_dict[i][k - 1])
                    self.Env_for_K[i].append(self.Env)
                    self.mu_for_K[i].append(self.mu)
                    self.mu_max_for_K[i].append(self.mu_max)
            if path is not None:
                print("Regret saving ...")
                filename = path + '/nb_arms' + '.csv'
                np.savetxt(filename, self.nb_arms_list)
        else:
            list_of_means, change_point = self._read_common_config(configuration, repetitions)
            self.__initEnvironments__(list_of_means, change_point)

    def _read_common_config(self, configuration, repetitions):
        """Load horizon, repetitions, arm count and env type from one config dict.

        Returns:
            ``(listOfMeans as a numpy array, changePoints)`` from the config.
        """
        self.cfg = configuration  #: Configuration dictionary
        self.nb_break_points = self.cfg.get('nb_break_points')  #: How many random events?
        self.horizon = self.cfg['horizon']  #: Horizon (number of time steps)
        print("Time horizon:", self.horizon)
        # An explicit ``repetitions`` argument wins over the configured value.
        if repetitions is None:
            self.repetitions = self.cfg.get('repetitions')
        else:
            self.repetitions = repetitions
        print("Number of repetitions:", self.repetitions)
        params = self.cfg['environment']["params"]
        self.nb_arms = len(params["listOfMeans"][0])
        self.env_type = self.cfg['environment']["env_type"]
        return np.array(params["listOfMeans"]), params["changePoints"]

    def returnEnv(self):
        """Return ``(Env, mu, mu_max)``.

        For ``experiment == "M"`` these are the per-M dictionaries. For every
        other experiment (including "T" and "K") the arrays of the most recently
        built environment are returned.

        NOTE(review): for "T"/"K" the full collections live in ``Env_for_T`` /
        ``Env_for_K`` (and the matching ``mu_*`` / ``mu_max_*``); the return
        value is kept as-is for backward compatibility.
        """
        if self.experiment == "M":
            return self.Env_for_M, self.mu_for_M, self.mu_max_for_M
        return self.Env, self.mu, self.mu_max

    def __initEnvironments__(self, list_of_means, change_point):
        """Create one piecewise-stationary Bernoulli environment.

        Reads ``self.repetitions``, ``self.nb_arms`` and ``self.horizon`` and
        stores the result in ``self.Env``, ``self.mu`` and ``self.mu_max``.

        Args:
            list_of_means: one row of arm means per segment; only the first
                ``self.nb_arms`` entries of a row are used for the arms.
            change_point: start time of each segment; segment ``s + 1`` becomes
                active at the first ``t >= change_point[s + 1]``.

        Raises:
            IndexError: if an active segment has fewer means than ``nb_arms``
                or there are more change points than rows of means.
        """
        print(" Create environments ...")
        self.Env = np.zeros((self.repetitions, self.nb_arms, self.horizon), dtype=int)
        # Means are Bernoulli parameters (generally fractional); an int dtype
        # would truncate them.
        self.mu = np.zeros((self.nb_arms, self.horizon), dtype=float)
        self.mu_max = np.zeros(self.horizon, dtype=float)
        rows = [np.asarray(row, dtype=float) for row in list_of_means]
        segment = 0
        for t in range(self.horizon):
            # ``>=`` rather than ``==``: change points such as
            # horizon / nb_break_points may be non-integer floats.
            while segment < len(change_point) - 1 and t >= change_point[segment + 1]:
                segment += 1
            row = rows[segment]
            arm_means = row[:self.nb_arms]
            if arm_means.shape[0] < self.nb_arms:
                raise IndexError(
                    "segment %d has %d means but nb_arms is %d"
                    % (segment, arm_means.shape[0], self.nb_arms))
            # Draw arm-major / repetition-minor: the same order of random numbers
            # as calling Bernoulli() per (arm, repetition), so seeded runs match.
            draws = np.random.rand(self.nb_arms, self.repetitions)
            self.Env[:, :, t] = (draws < arm_means[:, None]).T
            self.mu[:, t] = arm_means
            self.mu_max[t] = np.max(row)

    def __Preprocessing(self):
        """Build means/change points for every candidate number of segments M.

        Candidates that do not divide ``self.horizon`` are dropped; the kept ones
        form ``self.nb_break_points_list``. For each kept M the horizon is cut into
        M equal segments starting at ``m * horizon / M``, and segment ``m`` uses
        row ``m % pattern_len`` of ``listOfMeans``, where ``pattern_len`` is
        3 for env_type "3", 2 for "Type-I"/"Type-II", and ``nb_break_points``
        otherwise.
        """
        means = self.cfg['environment']["params"]["listOfMeans"]
        if self.env_type == "3":
            candidates = [2, 3, 4, 5, 6, 8, 10, 15, 16, 20, 25, 30, 32, 40, 50]
            pattern_len = 3
        elif self.env_type in ("Type-I", "Type-II"):
            candidates = [2, 3, 4, 5, 6, 8, 10, 15, 20, 25, 30, 40, 50]
            pattern_len = 2
        else:
            candidates = [j * self.nb_break_points for j in range(1, 5)]
            pattern_len = self.nb_break_points
        # Filter into a new list instead of removing while iterating, which
        # would skip the element following each removed one.
        self.nb_break_points_list = [M for M in candidates if self.horizon % M == 0]
        self.list_of_means_dict = {}
        self.change_point_dict = {}
        for M in self.nb_break_points_list:
            self.list_of_means_dict[M] = [means[m % pattern_len] for m in range(M)]
            self.change_point_dict[M] = [m * (self.horizon / M) for m in range(M)]
        print(self.nb_break_points_list)
        print(self.list_of_means_dict)
        print(self.change_point_dict)

    def __Preprocessing_for_T(self):
        """Build means/change points for every horizon in ``self.horizon_list``.

        Every horizon reuses the configured ``listOfMeans`` and gets
        ``nb_break_points`` equally spaced change points ``m * T / nb_break_points``
        (possibly non-integer). Dictionaries are keyed by the horizon's index.
        """
        self.horizon_list = [2000, 5000, 10000, 20000, 50000, 75000, 100000]
        self.nb_break_points = self.cfg.get('nb_break_points')  #: How many random events?
        if self.repetitions is None:
            self.repetitions = self.cfg.get('repetitions')  #: Number of repetitions
        print("Number of repetitions:", self.repetitions)
        self.env_type = self.cfg['environment']["env_type"]
        means = self.cfg['environment']["params"]["listOfMeans"]
        self.list_of_means_dict = {}
        self.change_point_dict = {}
        for idx, horizon in enumerate(self.horizon_list):
            self.list_of_means_dict[idx] = means
            self.change_point_dict[idx] = [
                m * (horizon / self.nb_break_points) for m in range(self.nb_break_points)
            ]

    def __Preprocessing_for_K(self):
        """Build, per config instance, means truncated to each arm count K.

        ``list_of_means_dict[i][K - 1]`` keeps the first K means of every row of
        instance i; ``change_point_dict[i][K - 1]`` is that instance's configured
        ``changePoints``. Leaves ``self.horizon`` / ``self.env_type`` set from the
        last instance, and fills ``self.repetitions`` from the first config if unset.
        """
        self.list_of_means_dict = {}
        self.change_point_dict = {}
        for i in range(self.num_of_instance):
            instance_cfg = self.cfg[i]
            self.nb_break_points = instance_cfg.get('nb_break_points')  #: How many random events?
            self.horizon = instance_cfg['horizon']  #: Horizon (number of time steps)
            print("Time horizon:", self.horizon)
            if self.repetitions is None:
                self.repetitions = instance_cfg.get('repetitions')  #: Number of repetitions
            print("Number of repetitions:", self.repetitions)
            self.env_type = instance_cfg['environment']["env_type"]
            params = instance_cfg['environment']["params"]
            self.list_of_means_dict[i] = [
                [row[0:K] for row in params["listOfMeans"]] for K in self.nb_arms_list
            ]
            self.change_point_dict[i] = [params["changePoints"] for _ in self.nb_arms_list]

    def Bernoulli(self, p):
        """Draw one Bernoulli sample: 1 if a uniform [0, 1) draw is below ``p``, else 0."""
        return 1 if np.random.rand() < p else 0