import numpy as np

class mopo_buffer():
    """Offline transition buffer storing flattened numpy records.

    Every raw sample is converted once, at construction time, into a record
    ``[state, action, next_state, reward]`` whose entries are 1-D numpy
    arrays. The sampling helpers then return whole records, concatenated
    categories, or selected state features.
    """

    # Position of each category inside a stored record.
    _CATEGORY_INDEX = {"state": 0,
                       "action": 1,
                       "next_state": 2,
                       "reward": 3}

    def __init__(self, dataset, dic_traffic_env_conf):
        """Convert every item of ``dataset`` into a flattened record.

        Args:
            dataset: iterable of raw samples; each one must unpack into seven
                values (see ``process_one_offline_data``).
            dic_traffic_env_conf: configuration dict. This class reads the
                keys "LIST_STATE_FEATURE", "NUM_LANE", "PHASE" and
                "DIC_REWARD_INFO".
        """
        self.dic_traffic_env_conf = dic_traffic_env_conf
        self.state_feature = tuple(self.dic_traffic_env_conf["LIST_STATE_FEATURE"][:])
        # Width of each state feature; filled in the first time a state is formatted.
        self.state_feature_len = {feature: 0 for feature in self.state_feature}
        self.dataset = [self.process_one_offline_data(one_data) for one_data in dataset]

    def process_one_offline_data(self, one_data):
        """Turn one raw sample into ``[state, action, next_state, reward]``.

        ``one_data`` must unpack into exactly seven values; only the first
        three (state, action, next_state) are used. The reward is recomputed
        from ``next_state`` rather than taken from the sample.
        """
        state, action, next_state, _, _, _, _ = one_data

        formated_state = self.format_one_intersection_state(state, self.state_feature)
        formated_action = self.format_one_intersection_action(action)
        formated_next_state = self.format_one_intersection_state(next_state, self.state_feature)
        formated_reward = self.calc_one_intersection_reward(next_state)

        return [formated_state, formated_action, formated_next_state, formated_reward]

    def format_one_intersection_state(self, state, feature_chosen):
        """Concatenate ``state[feature]`` for each chosen feature into a 1-D array.

        Each feature's length must be a multiple of ``NUM_LANE``. The first
        non-zero length seen for a feature listed in ``state_feature_len`` is
        recorded there so the flattening can be undone or sliced later.

        Raises:
            AssertionError: if a feature's length is not a multiple of NUM_LANE.
        """
        output = []
        total_feature_len = 0
        for feature in feature_chosen:
            values = state[feature]
            one_feature_len = len(values)
            assert one_feature_len % self.dic_traffic_env_conf["NUM_LANE"] == 0
            total_feature_len += one_feature_len
            if feature in self.state_feature_len and self.state_feature_len[feature] == 0:
                self.state_feature_len[feature] = one_feature_len
            output.extend(values)
        return np.array(output).reshape(total_feature_len, )

    def format_one_intersection_action(self, action):
        """Return the ``PHASE`` entry stored under ``action + 1`` as a 1-D array.

        The result is reshaped to length ``NUM_LANE * 2 // 3``, so the
        configured phase entry must contain exactly that many values.

        Raises:
            AssertionError: if ``action`` is greater than 3.
        """
        assert action <= 3
        # Copy so the configured phase table is never aliased by the buffer.
        action_bitmap = self.dic_traffic_env_conf["PHASE"][action + 1].copy()

        return np.array(action_bitmap).reshape(self.dic_traffic_env_conf["NUM_LANE"] * 2 // 3, )

    def calc_one_intersection_reward(self, next_state):
        """Compute the scalar reward for ``next_state`` as a shape-(1,) array.

        For every key of ``DIC_REWARD_INFO`` with a non-zero weight, the mean
        of ``next_state[key]`` is multiplied by that weight; the reward is the
        mean of those weighted values.

        Raises:
            AssertionError: if no reward feature has a non-zero weight, or if
                a feature's length is not a multiple of NUM_LANE.
        """
        reward_info = self.dic_traffic_env_conf["DIC_REWARD_INFO"]
        reward_feature = [key for key in reward_info if reward_info[key] != 0]

        assert len(reward_feature) > 0

        # Average each feature over its own values. ndarray.reshape returns a
        # new array instead of reshaping in place, so indexing a flat vector
        # by feature position would only pick out a single element.
        weighted_means = [
            np.mean(self.format_one_intersection_state(next_state, [feature])) * reward_info[feature]
            for feature in reward_feature
        ]
        return np.array(np.mean(weighted_means)).reshape(1, )

    def _clip_sample_size(self, size):
        """Return ``size`` limited to the dataset length, warning when it is reduced."""
        if size > len(self.dataset):
            print("[WARNING] Required size is {}, which is larger than dataset size {}".format(size, len(self.dataset)))
            print("[WARNING] Reset to {}".format(len(self.dataset)))
            size = len(self.dataset)
        return size

    def _collect_categories(self, record, catagory):
        """Concatenate the requested categories of one record into a flat list."""
        one_output = []
        for one_catagory in catagory:
            one_output.extend(record[self._CATEGORY_INDEX[one_catagory]])
        return one_output

    def _feature_slice(self, feature):
        """Return the slice that ``feature`` occupies inside a flattened state."""
        start = 0
        for name in self.state_feature:
            length = self.state_feature_len[name]
            if name == feature:
                return slice(start, start + length)
            start += length
        raise KeyError(feature)

    def _collect_features(self, record, features):
        """Concatenate the requested state features of one record into a flat list."""
        one_output = []
        for feature in features:
            # record[0] is the flattened state of the record.
            one_output.extend(record[0][self._feature_slice(feature)])
        return one_output

    def sample(self, size):
        """Return ``size`` randomly chosen records as shallow list copies.

        Indices are drawn with ``np.random.randint``, so the same record may
        be returned more than once. A ``size`` larger than the dataset is
        reduced to the dataset length with a printed warning.
        """
        size = self._clip_sample_size(size)

        indices = np.random.randint(0, len(self.dataset), size=size)
        return [self.dataset[indice].copy() for indice in indices]

    def sample_all(self):
        """Return a shallow copy of the list of all records."""
        return self.dataset.copy()

    def sample_chosen_data(self, catagory, size):
        """Sample records and concatenate the requested categories of each.

        Args:
            catagory: iterable of names among "state", "action",
                "next_state" and "reward"; they are concatenated in the
                given order.
            size: number of rows; reduced to the dataset length when larger.

        Returns:
            2-D array with one row per sampled record.
        """
        size = self._clip_sample_size(size)

        indices = np.random.randint(0, len(self.dataset), size=size)
        return np.array([self._collect_categories(self.dataset[indice], catagory) for indice in indices])

    def sample_all_chosen_data(self, catagory):
        """Return the requested categories of every record, one row per record."""
        return np.array([self._collect_categories(data, catagory) for data in self.dataset])

    def sample_chosen_feature(self, features, size):
        """Sample records and return only the requested state features.

        Feature positions come from the widths recorded in
        ``state_feature_len``, so features wider than ``NUM_LANE`` are
        extracted in full. An empty array is returned when any requested
        feature is not a configured state feature.
        """
        size = self._clip_sample_size(size)
        output = []
        if self.check_features(features):
            indices = np.random.randint(0, len(self.dataset), size=size)
            output = [self._collect_features(self.dataset[indice], features) for indice in indices]

        return np.array(output)

    def sample_all_chosen_features(self, features):
        """Return the requested state features of every record, one row per record.

        An empty array is returned when any requested feature is not a
        configured state feature.
        """
        output = []
        if self.check_features(features):
            output = [self._collect_features(data, features) for data in self.dataset]

        return np.array(output)

    def check_features(self, features):
        """Return True when every entry of ``features`` is in ``LIST_STATE_FEATURE``."""
        known_features = self.dic_traffic_env_conf["LIST_STATE_FEATURE"]
        return all(feature in known_features for feature in features)

    def get_data_dim(self, catagory):
        """Return the summed length of the requested categories.

        The lengths are read from the first stored record, so the dataset
        must not be empty.
        """
        first_record = self.dataset[0]
        return sum(first_record[self._CATEGORY_INDEX[one_catagory]].shape[0] for one_catagory in catagory)

    def format_one_intersection_state_reverse(self, state, feature_chosen):
        """Split a flattened state back into a ``{feature: 1-D array}`` dict.

        Raises:
            AssertionError: if ``feature_chosen`` differs from the configured
                state features (same names, same order).
        """
        assert list(self.state_feature) == list(feature_chosen)
        outputs = {}
        index = 0
        for feature in feature_chosen:
            feature_len = self.state_feature_len[feature]
            outputs[feature] = np.array(state[index: index + feature_len]).reshape(feature_len, )
            index += feature_len
        return outputs
    