import numpy as np

class buffer():
    """Per-lane transition store built from offline intersection data.

    Every element of ``dataset`` is a 7-item sequence whose first three items
    are ``state``, ``action`` and ``next_state``; the remaining four items are
    ignored.  Each element is expanded into ``NUM_LANE`` records of the form
    ``[state_row, action_row, next_state_row, reward_row]`` (all 1-D numpy
    arrays) which are appended to ``self.dataset``.

    Configuration keys read from ``dic_traffic_env_conf``:
    ``LIST_STATE_FEATURE``, ``NUM_LANE``, ``PHASE`` and ``DIC_REWARD_INFO``.

    Attributes:
        dic_traffic_env_conf: the configuration mapping passed in.
        state_feature: tuple snapshot of ``LIST_STATE_FEATURE`` taken at
            construction time.
        state_feature_len: number of columns each state feature occupies in a
            per-lane state row (0 until the feature has been formatted once).
        dataset: list of per-lane records.
    """

    # Position of each category inside one stored record.
    _CATEGORY_INDEX = {"state": 0,
                       "action": 1,
                       "next_state": 2,
                       "reward": 3}

    def __init__(self, dataset, dic_traffic_env_conf):
        self.dic_traffic_env_conf = dic_traffic_env_conf
        self.state_feature = tuple(self.dic_traffic_env_conf["LIST_STATE_FEATURE"][:])
        self.state_feature_len = {feature: 0 for feature in self.state_feature}
        self.dataset = []
        for one_data in dataset:
            self.dataset.extend(self.process_one_offline_data(one_data))

    def process_one_offline_data(self, one_data):
        """Expand one 7-item transition into a list of per-lane records."""
        state, action, next_state, _, _, _, _ = one_data

        # The state is formatted first so that ``state_feature_len`` is
        # populated before anything else relies on it.
        formated_state = self.format_one_intersection_state(state, self.state_feature)
        formated_action = self.format_one_intersection_action(action)
        formated_next_state = self.format_one_intersection_state(next_state, self.state_feature)
        formated_reward = self.calc_one_intersection_reward(next_state)

        return [
            [formated_state[index], formated_action[index],
             formated_next_state[index], formated_reward[index]]
            for index in range(formated_state.shape[0])
        ]

    def format_one_intersection_state(self, state, feature_chosen):
        """Return a ``(NUM_LANE, total_feature_len)`` array of per-lane features.

        ``state[feature]`` must be a flat sequence whose length is a multiple
        of ``NUM_LANE``; it is cut into ``NUM_LANE`` equal consecutive slices
        and slice ``i`` is appended to the row of lane ``i``.  The width of a
        feature is recorded in ``self.state_feature_len`` the first time a
        known state feature is seen.

        Raises:
            AssertionError: if a feature's length is not a multiple of
                ``NUM_LANE``.
        """
        num_lane = self.dic_traffic_env_conf["NUM_LANE"]
        outputs = [[] for _ in range(num_lane)]
        total_feature_len = 0
        for feature in feature_chosen:
            values = state[feature]
            assert len(values) % num_lane == 0, \
                "length of feature {!r} is not a multiple of NUM_LANE".format(feature)
            one_feature_len = len(values) // num_lane
            total_feature_len += one_feature_len
            # ``.get`` yields None for non-state features, so only known,
            # not-yet-measured state features are recorded.
            if self.state_feature_len.get(feature) == 0:
                self.state_feature_len[feature] = one_feature_len
            for i in range(num_lane):
                lane_slice = values[i * one_feature_len:(i + 1) * one_feature_len]
                outputs[i].extend(np.array(lane_slice).reshape(one_feature_len, ))
        return np.array(outputs).reshape(num_lane, total_feature_len)

    def format_one_intersection_action(self, action):
        """Return the ``(NUM_LANE, 1)`` per-lane bitmap for ``action``.

        The bitmap is a copy of ``PHASE[action + 1]`` with a ``1`` inserted at
        positions 0, 3, 6 and 9 (the positions the original layout labels
        "right"; "left" is 1, 4, 7, 10 and "through" is 2, 5, 8, 11).

        Raises:
            AssertionError: if ``action`` is greater than 3.
        """
        assert action <= 3, "action must be <= 3, got {!r}".format(action)
        right_positions = (0, 3, 6, 9)
        # Copy so the configured phase list is never mutated by the inserts.
        action_bitmap = self.dic_traffic_env_conf["PHASE"][action + 1].copy()
        for index in right_positions:
            action_bitmap.insert(index, 1)

        return np.array(action_bitmap).reshape(self.dic_traffic_env_conf["NUM_LANE"], 1)

    def format_one_intersection_state_reverse(self, state, feature_chosen):
        """Invert ``format_one_intersection_state`` for the full feature list.

        ``state`` is an iterable of per-lane rows; the result maps every
        feature to the flat list of its values, lane after lane.

        Raises:
            AssertionError: if ``feature_chosen`` differs from the state
                features this buffer was built with.
        """
        assert list(self.state_feature) == list(feature_chosen)
        outputs = {feature: [] for feature in feature_chosen}
        index = 0
        for feature in feature_chosen:
            width = self.state_feature_len[feature]
            for one_lane_state in state:
                outputs[feature].extend(
                    np.array(one_lane_state[index:index + width]).reshape(width, ))
            index += width
        return outputs

    def calc_one_intersection_reward(self, next_state):
        """Return a ``(NUM_LANE, 1)`` array of per-lane rewards.

        A feature contributes when its ``DIC_REWARD_INFO`` entry is non-zero.
        The configured value is only used for that inclusion test: the
        selected ``next_state`` features are summed unweighted.

        Raises:
            AssertionError: if no reward feature has a non-zero entry.
        """
        reward_info = self.dic_traffic_env_conf["DIC_REWARD_INFO"]
        reward_feature = [key for key in reward_info.keys() if reward_info[key] != 0]

        assert len(reward_feature) > 0, "DIC_REWARD_INFO has no non-zero entry"

        formated_reward = self.format_one_intersection_state(next_state, reward_feature)
        formated_reward = [np.array(np.sum(x)) for x in formated_reward]
        return np.array(formated_reward).reshape(self.dic_traffic_env_conf["NUM_LANE"], 1)

    def _clamp_sample_size(self, size):
        """Warn and cap ``size`` at the dataset size when it is too large."""
        if size > len(self.dataset):
            print("[WARNING] Required size is {}, which is larger than dataset size {}".format(size, len(self.dataset)))
            print("[WARNING] Reset to {}".format(len(self.dataset)))
            size = len(self.dataset)
        return size

    def _draw_indices(self, size):
        """Draw ``size`` record indices uniformly, with replacement."""
        if not self.dataset:
            # Nothing to draw from; avoid asking numpy for an empty range.
            return np.empty(0, dtype=int)
        return np.random.randint(0, len(self.dataset), size=size)

    def _feature_columns(self, features):
        """Return the state-row column indices covering ``features`` in order."""
        starts = {}
        offset = 0
        for name in self.state_feature:
            # setdefault keeps the first occurrence if a feature is listed twice.
            starts.setdefault(name, offset)
            offset += self.state_feature_len[name]
        columns = []
        for feature in features:
            begin = starts[feature]
            columns.extend(range(begin, begin + self.state_feature_len[feature]))
        return columns

    def _concat_categories(self, record, catagory):
        """Concatenate the requested categories of one record into a flat list."""
        one_output = []
        for one_catagory in catagory:
            one_output.extend(record[self._CATEGORY_INDEX[one_catagory]].copy())
        return one_output

    def sample(self, size):
        """Return ``size`` records drawn with replacement (shallow copies).

        ``size`` is capped at the dataset size, with a printed warning.
        """
        size = self._clamp_sample_size(size)
        return [self.dataset[indice].copy() for indice in self._draw_indices(size)]

    def sample_all(self):
        """Return a shallow copy of the whole record list."""
        return self.dataset.copy()

    def sample_chosen_data(self, catagory, size):
        """Sample records and concatenate the chosen categories per record.

        Args:
            catagory: iterable of ``"state"``, ``"action"``, ``"next_state"``
                and/or ``"reward"``.
            size: number of records to draw (capped at the dataset size).

        Returns:
            A numpy array with one row per sampled record.
        """
        size = self._clamp_sample_size(size)
        output = [self._concat_categories(self.dataset[indice], catagory)
                  for indice in self._draw_indices(size)]
        return np.array(output)

    def sample_all_chosen_data(self, catagory):
        """Return the chosen categories of every record, one row per record."""
        output = [self._concat_categories(data, catagory) for data in self.dataset]
        return np.array(output)

    def sample_chosen_feature(self, features, size):
        """Sample records and return the chosen state features per record.

        Each row holds the columns of the requested features, in the order
        given.  A feature that spans several columns contributes all of them,
        so the row width is the sum of the requested features' widths.  An
        empty array is returned when any requested feature is unknown.
        """
        size = self._clamp_sample_size(size)
        output = []
        if self.check_features(features):
            columns = self._feature_columns(features)
            for indice in self._draw_indices(size):
                # 1st is index of one data, 2nd is index of state
                lane_state = self.dataset[indice][0]
                output.append([lane_state[column] for column in columns])

        return np.array(output)

    def sample_all_chosen_features(self, features):
        """Return the chosen state features of every record.

        Same row layout as ``sample_chosen_feature``; an empty array is
        returned when any requested feature is unknown.
        """
        output = []
        if self.check_features(features):
            columns = self._feature_columns(features)
            for data in self.dataset:
                lane_state = data[0]
                output.append([lane_state[column] for column in columns])

        return np.array(output)

    def check_features(self, features):
        """Return True when every feature is one of this buffer's state features."""
        return all(feature in self.state_feature for feature in features)

    def get_data_dim(self, catagory):
        """Return the summed row length of the chosen categories.

        The lengths are read from the first record, so the dataset must not be
        empty when ``catagory`` is non-empty.
        """
        output_dim = 0
        for one_catagory in catagory:
            output_dim += self.dataset[0][self._CATEGORY_INDEX[one_catagory]].shape[0]
        return output_dim
    
class buffer_old():
    """Legacy per-lane transition store (one value per lane and feature).

    Each element of ``dataset`` is a 7-item sequence; only its first three
    items (``state``, ``action``, ``next_state``) are used.  Every element is
    turned into ``NUM_LANE`` records ``[state_row, action_row,
    next_state_row, reward_row]`` that are collected in ``self.dataset``.
    """

    # Slot of every category inside one stored record.
    _RECORD_SLOT = {"state" : 0,
                    "action" : 1,
                    "next_state" : 2,
                    "reward" : 3}

    def __init__(self, dataset, dic_traffic_env_conf):
        self.dic_traffic_env_conf = dic_traffic_env_conf
        self.dataset = []
        for transition in dataset:
            self.dataset += self.process_one_offline_data(transition)

    def process_one_offline_data(self, one_data):
        """Split one transition into a list of per-lane records."""
        feature_names = self.dic_traffic_env_conf["LIST_STATE_FEATURE"][:]
        current, chosen_action, following, _, _, _, _ = one_data

        lane_states = self.format_one_intersection_state(current, feature_names)
        lane_actions = self.format_one_intersection_action(chosen_action)
        lane_next_states = self.format_one_intersection_state(following, feature_names)
        lane_rewards = self.calc_one_intersection_reward(following)

        # All four arrays have NUM_LANE rows, so zipping walks them lane by lane.
        return [list(parts)
                for parts in zip(lane_states, lane_actions, lane_next_states, lane_rewards)]

    def format_one_intersection_state(self, state, feature_chosen):
        """Build a ``(NUM_LANE, len(feature_chosen))`` array of lane features.

        Raises AssertionError when a feature does not hold exactly
        ``NUM_LANE`` values.
        """
        lane_count = self.dic_traffic_env_conf["NUM_LANE"]
        for name in feature_chosen:
            assert len(state[name]) == lane_count
        rows = [[state[name][lane] for name in feature_chosen]
                for lane in range(lane_count)]
        return np.array(rows).reshape(lane_count, len(feature_chosen))

    def format_one_intersection_action(self, action):
        """Build the ``(NUM_LANE, 1)`` lane bitmap of ``action``.

        A copy of ``PHASE[action + 1]`` receives a ``1`` at the positions
        labelled "right" below.  Raises AssertionError when ``action > 3``.
        """
        assert action <= 3
        positions = {"right": [0, 3, 6, 9], "left": [1, 4, 7, 10], "through": [2, 5, 8, 11]}
        bitmap = self.dic_traffic_env_conf["PHASE"][action + 1].copy()
        for slot in positions["right"]:
            bitmap.insert(slot, 1)
        lane_count = self.dic_traffic_env_conf["NUM_LANE"]
        return np.array(bitmap).reshape(lane_count, 1)

    def calc_one_intersection_reward(self, next_state):
        """Sum, per lane, the ``next_state`` features with a non-zero reward entry.

        Returns a ``(NUM_LANE, 1)`` array.  Raises AssertionError when no
        entry of ``DIC_REWARD_INFO`` is non-zero.
        """
        weights = self.dic_traffic_env_conf["DIC_REWARD_INFO"]
        active = [name for name, weight in weights.items() if weight != 0]

        assert len(active) > 0

        per_lane = self.format_one_intersection_state(next_state, active)
        totals = np.array([row.sum() for row in per_lane])
        return totals.reshape(self.dic_traffic_env_conf["NUM_LANE"], 1)

    def _limit_size(self, size):
        """Cap ``size`` at the number of stored records, warning when capped."""
        available = len(self.dataset)
        if size > available:
            print("[WARNING] Required size is {}, which is larger than dataset size {}".format(size, available))
            print("[WARNING] Reset to {}".format(available))
            return available
        return size

    def sample(self, size):
        """Draw ``size`` records with replacement; each is a shallow copy."""
        size = self._limit_size(size)
        picks = np.random.randint(0, len(self.dataset), size=size)
        return [self.dataset[pick].copy() for pick in picks]

    def sample_all(self):
        """Return a shallow copy of the record list."""
        return self.dataset.copy()

    def sample_chosen_data(self, catagory, size):
        """Draw records and flatten the requested categories of each one."""
        size = self._limit_size(size)
        slot_of = self._RECORD_SLOT
        picks = np.random.randint(0, len(self.dataset), size=size)
        rows = []
        for pick in picks:
            record = self.dataset[pick]
            rows.append([value
                         for name in catagory
                         for value in record[slot_of[name]]])
        return np.array(rows)

    def sample_all_chosen_data(self, catagory):
        """Flatten the requested categories of every stored record."""
        slot_of = self._RECORD_SLOT
        rows = []
        for record in self.dataset:
            rows.append([value
                         for name in catagory
                         for value in record[slot_of[name]]])
        return np.array(rows)

    def sample_chosen_feature(self, features, size):
        """Draw records and return the requested state features of each.

        Returns an empty array when a requested feature is not configured.
        """
        size = self._limit_size(size)
        if not self.check_features(features):
            return np.array([])
        known = self.dic_traffic_env_conf["LIST_STATE_FEATURE"]
        # Column of a feature inside the state row equals its list position.
        columns = [known.index(name) for name in features]
        picks = np.random.randint(0, len(self.dataset), size=size)
        # record[0] is the state row of the picked record.
        return np.array([[self.dataset[pick][0][column] for column in columns]
                         for pick in picks])

    def sample_all_chosen_features(self, features):
        """Return the requested state features of every stored record.

        Returns an empty array when a requested feature is not configured.
        """
        if not self.check_features(features):
            return np.array([])
        known = self.dic_traffic_env_conf["LIST_STATE_FEATURE"]
        columns = [known.index(name) for name in features]
        # record[0] is the state row.
        return np.array([[record[0][column] for column in columns]
                         for record in self.dataset])

    def check_features(self, features):
        """Tell whether every name is listed in ``LIST_STATE_FEATURE``."""
        known = self.dic_traffic_env_conf["LIST_STATE_FEATURE"]
        return all(name in known for name in features)

    def get_data_dim(self, catagory):
        """Return the combined row length of the categories in the first record."""
        slot_of = self._RECORD_SLOT
        return sum(self.dataset[0][slot_of[name]].shape[0] for name in catagory)
