import numpy as np


class Environment:
    """Piecewise-stationary Bernoulli bandit environment with pre-sampled rewards.

    Rewards for ``repetitions`` independent runs of an ``nb_arms``-armed
    bandit over ``horizon`` rounds are drawn up front from numpy's global RNG.
    Arm means are piecewise constant: segment ``s`` uses the row
    ``listOfMeans[s]`` and starts at round ``changePoints[s]``.

    Attributes:
        Env: int array of shape (repetitions, nb_arms, horizon); ``Env[i][k][t]``
            is the 0/1 reward of arm ``k`` at round ``t`` in repetition ``i``.
        mu: float array of shape (nb_arms, horizon); mean of each arm per round.
        mu_max: float array of shape (horizon,); largest mean per round.
    """

    def __init__(self, cfg, env_mean):
        """Read the configuration and sample the environment.

        Args:
            cfg: mapping read via ``.get`` for 'repetitions', 'nb_break_points',
                'horizon' and 'nb_arms'.
            env_mean: dict with ``env_mean["params"]["listOfMeans"]`` (one row of
                per-arm means per segment) and ``env_mean["params"]["changePoints"]``
                (the start round of each segment).
        """
        self.cfg = cfg
        self.repetitions = self.cfg.get('repetitions')
        self.nb_break_points = self.cfg.get('nb_break_points')
        self.horizon = self.cfg.get('horizon')
        self.nb_arms = self.cfg.get('nb_arms')
        list_of_means = np.array(env_mean["params"]["listOfMeans"])
        change_point = env_mean["params"]["changePoints"]

        self.__initEnvironments__(list_of_means, change_point)
        print(self.Env)

    def __initEnvironments__(self, list_of_means, change_point):
        """Create environments: fill ``mu``, ``mu_max`` and the sampled ``Env``.

        The active segment advances as soon as round ``t`` reaches the next
        change point. The first segment is used from round 0, i.e.
        ``change_point[0]`` is presumably 0 -- TODO confirm for other callers.
        """
        print(" Create environments ...")
        # Means are probabilities; an int array would truncate e.g. 0.7 to 0.
        self.mu = np.zeros((self.nb_arms, self.horizon), dtype=float)
        self.mu_max = np.zeros(self.horizon, dtype=float)
        segment = 0
        for t in range(self.horizon):
            # Decide the segment once per round (not per arm) so all arms of a
            # round agree; '>=' also tolerates float change points like 250.0.
            while segment < len(change_point) - 1 and t >= change_point[segment + 1]:
                segment += 1
            self.mu[:, t] = list_of_means[segment][:self.nb_arms]
            self.mu_max[t] = np.max(list_of_means[segment][:])
        # Draw all uniforms at once in C-order over (t, k, i) -- exactly the
        # order the former scalar loop consumed the global RNG, so seeded runs
        # are unchanged -- then put the repetition axis first.
        uniforms = np.random.rand(self.horizon, self.nb_arms, self.repetitions)
        rewards = (uniforms < self.mu.T[:, :, None]).astype(int)
        self.Env = np.ascontiguousarray(rewards.transpose(2, 1, 0))

    def Bernoulli(self, p):
        """Return 1 with probability ``p`` (numpy global RNG), otherwise 0."""
        return int(np.random.rand() < p)
        
class EnvInM:
    """Build piecewise-stationary environments for several numbers of segments.

    For every candidate number of segments ``M`` that divides the horizon, the
    horizon is cut into ``M`` equal segments and an ``Environment`` is built
    from the per-segment means chosen by ``env_mean['env_type']``.
    """

    # Candidate segment counts for the "3", "Type-I", "Type-II" and "rand" types.
    _CANDIDATE_SEGMENTS = (2, 3, 4, 5, 6, 8, 10, 15, 16, 20, 25, 30, 32, 40, 50)

    def __init__(self, cfg):
        """Store ``cfg``, a mapping read via ``.get`` ('horizon', 'nb_arms', ...)."""
        self.cfg = cfg

    def __call__(self, env_mean):
        """Return ``(nb_break_points_list, env_samples_dict)``.

        ``env_samples_dict`` maps each kept segment count ``M`` to an
        ``Environment`` sampled with that segmentation.
        """
        nb_break_points_list, list_of_means_dict, change_point_dict = self.Preprocessing(env_mean)
        env_samples_dict = {}
        for m in nb_break_points_list:
            env_mean_m = {
                "params": {
                    "listOfMeans": list_of_means_dict[m],
                    "changePoints": change_point_dict[m],
                },
            }
            env_samples_dict[m] = Environment(self.cfg, env_mean_m)
        return nb_break_points_list, env_samples_dict

    def Preprocessing(self, env_mean):
        """Compute per-segment means and change points for each segment count.

        Segment-mean rule by ``env_mean.get('env_type')``:
            * "3": cycle through ``listOfMeans[0..2]``.
            * "Type-I" / "Type-II": alternate ``listOfMeans[0]`` and ``[1]``.
            * "rand": ``nb_arms`` uniform means rounded to 2 decimals per segment.
            * anything else: candidates are 1-4 times ``cfg['nb_break_points']``
              and means cycle through the first ``nb_break_points`` rows.

        Returns:
            (nb_break_points_list, list_of_means_dict, change_point_dict):
            the candidate counts ``M`` that divide the horizon (candidate
            order kept); for each ``M`` the list of ``M`` mean rows; and for
            each ``M`` the ``M`` integer segment start rounds
            ``0, h/M, 2h/M, ...``.
        """
        horizon = self.cfg.get('horizon')
        nb_arms = self.cfg.get('nb_arms')
        env_type = env_mean.get('env_type')

        if env_type in ("3", "Type-I", "Type-II"):
            candidates = list(self._CANDIDATE_SEGMENTS)
            period = 3 if env_type == "3" else 2

            def means_for(m):
                return env_mean["params"]["listOfMeans"][m % period]
        elif env_type == "rand":
            candidates = list(self._CANDIDATE_SEGMENTS)

            def means_for(m):
                return [round(np.random.rand(), 2) for _ in range(nb_arms)]
        else:
            nb_break_points = self.cfg.get('nb_break_points')
            candidates = [nb_break_points * factor for factor in range(1, 5)]

            def means_for(m):
                return env_mean["params"]["listOfMeans"][m % nb_break_points]

        nb_break_points_list, list_of_means_dict, change_point_dict = self._build_schedule(
            horizon, candidates, means_for)
        print(nb_break_points_list)
        print(list_of_means_dict)
        print(change_point_dict)
        return nb_break_points_list, list_of_means_dict, change_point_dict

    @staticmethod
    def _build_schedule(horizon, candidates, means_for):
        """Keep candidates dividing ``horizon``; build their means and change points."""
        kept = [M for M in candidates if horizon % M == 0]
        list_of_means_dict = {}
        change_point_dict = {}
        for M in kept:
            # horizon % M == 0, so integer division gives exact segment starts.
            segment_length = horizon // M
            # Means are generated in (M, m) order, matching the RNG draw order
            # of the original nested loops for the "rand" type.
            list_of_means_dict[M] = [means_for(m) for m in range(M)]
            change_point_dict[M] = [m * segment_length for m in range(M)]
        return kept, list_of_means_dict, change_point_dict