import os
import pickle
import json
import numpy as np
import tensorflow._api.v2.compat.v1 as tf

from utils.mopo.mopo_buffer import mopo_buffer
from utils.bnn import bnn

class MOPOFakeEnv():
    """Learned-model ("fake") traffic environment used for MOPO-style training.

    On construction it loads an offline dataset, wraps it in a ``mopo_buffer``
    and trains two BNN ensembles on it: one predicting the reward and one
    predicting the next state from (state, action). ``reset``/``step`` then
    produce transitions from those models instead of from a simulator.
    """

    def __init__(self, dic_path, dic_traffic_env_conf):
        """Load offline data, build the data buffer and train the models.

        Args:
            dic_path: path config; reads "PATH_TO_OFFLINE_DATA" and
                "PATH_TO_WORK_DIRECTORY".
            dic_traffic_env_conf: traffic env config; reads
                "NUM_INTERSECTIONS", "NUM_ROW", "NUM_COL",
                "LIST_STATE_FEATURE", "ROADNET_FILE", "TOP_K_ADJACENCY"
                and (in ``step``) "MIN_ACTION_TIME".
                NOTE: its "LIST_STATE_FEATURE" list is mutated in place
                (see below) — the caller's dict is affected.
        """
        self.dic_traffic_env_conf = dic_traffic_env_conf
        self.dic_path = dic_path
        dataset = self._load_offline_dataset()

        self.state = None
        self.current_time = 0
        self.current_phase = [0 for _ in range(self.dic_traffic_env_conf["NUM_INTERSECTIONS"])]
        traffic_light_node_dict = self._adjacency_extraction()
        # Entry k is attached to the k-th intersection state in
        # _format_predict_output; assumes states use this same
        # (NUM_COL outer, NUM_ROW inner) ordering — TODO confirm.
        self.adjacency_matrix = [traffic_light_node_dict["intersection_{0}_{1}".format(i+1, j+1)]["adjacency_row"]
                                 for i in range(self.dic_traffic_env_conf["NUM_COL"])
                                 for j in range(self.dic_traffic_env_conf["NUM_ROW"])]

        # "cur_phase" and "adjacency_matrix" are not fed to / predicted by the
        # learned models: they are removed from the feature list (the data
        # buffer below is built from this same config dict) and re-attached in
        # _format_predict_output. The two checks are independent so that both
        # features are handled when both are configured.
        list_state_feature = self.dic_traffic_env_conf["LIST_STATE_FEATURE"]
        self.use_current_phase = "cur_phase" in list_state_feature
        if self.use_current_phase:
            list_state_feature.remove("cur_phase")
        self.use_adjacency_matrix = "adjacency_matrix" in list_state_feature
        if self.use_adjacency_matrix:
            list_state_feature.remove("adjacency_matrix")

        self.data_buffer = mopo_buffer(dataset, dic_traffic_env_conf)

        # construct bnn
        # Don't change the models' names: _train_transition_model picks its
        # target by checking for "reward" in the name, and step() looks the
        # models up by these exact keys.
        models_name = ["reward", "next_state"]
        models_ensemble = [7, 7, 7, 7]
        models_elites = [5, 5, 5, 5]
        self.input_dim = self.data_buffer.get_data_dim(["state", "action"])
        self.next_state_output_dim = self.data_buffer.get_data_dim(["next_state"])
        self.reward_output_dim = self.data_buffer.get_data_dim(["reward"])
        models_arch = [[self.input_dim, 64, 64, self.reward_output_dim],
                       [self.input_dim, 64, 64, self.next_state_output_dim]]
        models_activation = [["swish", "sigmoid", "swish", "ReLU"],
                             ["swish", "sigmoid", "swish", "ReLU"]]
        weight_decay = [[0.00001, 0.00005, 0.00005, 0.0001],
                        [0.00001, 0.00005, 0.00005, 0.0001]]
        epoches = 100
        validation_split = 0.2

        self.models = {}
        for index in range(len(models_name)):
            print("[INFO] Start building model {} | ensembles {} | elites {}".format(models_name[index],
                                                                                     models_ensemble[index],
                                                                                     models_elites[index]))
            self._init_transition_model(name=models_name[index],
                                        ensembles=models_ensemble[index],
                                        elites=models_elites[index],
                                        arch=models_arch[index],
                                        activation=models_activation[index],
                                        weight_decay=weight_decay[index])
            print("[INFO] model input dim {} | output dim {}".format(self.models[models_name[index]].get_input_dim(),
                                                                     self.models[models_name[index]].get_output_dim()))

            self._train_transition_model(models_name[index], epoches, validation_split)

    def _load_offline_dataset(self):
        """Unpickle every file in ``<PATH_TO_OFFLINE_DATA>/data`` into a list.

        Files are read in sorted filename order so the resulting dataset is
        reproducible across runs (``os.listdir`` order is arbitrary).

        NOTE(security): ``pickle.load`` can execute arbitrary code; only point
        PATH_TO_OFFLINE_DATA at trusted data.
        """
        data_dir = os.path.join(self.dic_path["PATH_TO_OFFLINE_DATA"], "data")
        dataset = []
        for file_name in sorted(os.listdir(data_dir)):
            with open(os.path.join(data_dir, file_name), "rb") as f:
                dataset.append(pickle.load(f))
        return dataset

    def _init_transition_model(self, name, ensembles, elites, arch, activation, weight_decay):
        """Create a ``bnn`` named ``name`` and build it with Adam (lr 0.001)."""
        self.models[name] = bnn(name, ensembles, elites)
        self.models[name].build(arch, activation, weight_decay, tf.train.AdamOptimizer, 0.001, loss=None)

    def _train_transition_model(self, name, epoch, validation_split):
        """Train model ``name`` on all buffered (state, action) -> target pairs.

        The target is the "reward" data when ``name`` contains "reward",
        otherwise the "next_state" data.
        """
        x = self.data_buffer.sample_all_chosen_data(["state", "action"])
        output_feature = "reward" if "reward" in name else "next_state"
        y = self.data_buffer.sample_all_chosen_data([output_feature])
        self.models[name].train(x, y, epoch, validation_split)

    # waiting for update
    def _calc_reward_penalty(self, rewards):
        """Return ``-1.0 * std`` of all values in ``rewards`` (a scalar).

        NOTE(review): ``_predict`` subtracts this value, so with the -1.0
        coefficient the ensemble std is *added* to the mean reward. MOPO
        subtracts an uncertainty penalty from the reward — confirm the
        intended sign.
        """
        penalty_coeff = -1.0
        return np.sqrt(np.var(rewards)) * penalty_coeff

    # waiting for update
    def _calc_state_penalty(self, states):
        """Return the std along axis 0 of ``states``, scaled by a coefficient of ones.

        Presumably axis 0 is the ensemble axis of the model output — confirm
        against ``bnn.predict``.
        """
        penalty_coeff = np.array([1.0 for _ in range(np.array(states).shape[1])])
        return np.sqrt(np.var(states, axis=0)) * penalty_coeff

    def _format_predict_input(self, state, action):
        """Build one model input row per intersection: formatted state + action.

        Returns an array whose k-th row is the buffer-formatted k-th state
        concatenated with the buffer-formatted k-th action.
        """
        # state.shape is (num_inter, num_feature)
        # action.shape is (num_inter, 1)
        assert np.array(state).shape[0] == np.array(action).shape[0]
        outputs = []
        for one_inter_state, one_inter_action in zip(state, action):
            # Pass a copy of the feature list so the helper cannot modify the config.
            one_inter_state = self.data_buffer.format_one_intersection_state(one_inter_state, self.dic_traffic_env_conf["LIST_STATE_FEATURE"][:])
            one_inter_action = self.data_buffer.format_one_intersection_action(one_inter_action)
            outputs.append(np.concatenate([one_inter_state, one_inter_action]))
        return np.array(outputs)

    def _format_predict_output(self, state, reward):
        """Convert raw model outputs back into per-intersection state dicts and rewards.

        ``state`` is reshaped to (NUM_INTERSECTIONS, next_state_output_dim) and
        each row is decoded by the data buffer; the stripped "cur_phase" /
        "adjacency_matrix" features are re-attached when configured. Each
        element of ``reward`` is averaged to a scalar.

        Returns:
            (list of per-intersection state objects, list of mean rewards)
        """
        output_state = []
        state = np.reshape(state, (self.dic_traffic_env_conf["NUM_INTERSECTIONS"], self.next_state_output_dim))
        for index, one_intersection_state in enumerate(state):
            one_intersection_output_state = self.data_buffer.format_one_intersection_state_reverse(one_intersection_state, self.dic_traffic_env_conf["LIST_STATE_FEATURE"][:])

            if self.use_current_phase:
                # cur_phase is exposed as current_phase + 1 (current_phase starts at 0).
                one_intersection_output_state["cur_phase"] = [self.current_phase[index] + 1]
            if self.use_adjacency_matrix:
                one_intersection_output_state["adjacency_matrix"] = self.adjacency_matrix[index]

            output_state.append(one_intersection_output_state)

        output_reward = [np.mean(one_inter_reward) for one_inter_reward in reward]
        return output_state, output_reward

    def _predict(self, input, type):
        """Run model ``type`` on one input row and return the penalised ensemble mean.

        The last element of ``input`` comes from the formatted action (see
        ``_format_predict_input``) and is required to be 0 or 1.
        """
        assert input[-1] == 0 or input[-1] == 1
        ensemble_output = self.models[type].predict(input.reshape(1, self.input_dim))
        if type == "next_state":
            penalty = self._calc_state_penalty(ensemble_output)
        else:
            penalty = self._calc_reward_penalty(ensemble_output)
        return np.mean(ensemble_output, axis=0) - penalty

    def step(self, action):
        """Advance the learned environment by one action.

        Args:
            action: one action per intersection; stored as ``current_phase``.

        Returns:
            (shallow copy of the new per-intersection state list,
             np.array of per-intersection mean rewards)
        """
        inputs = self._format_predict_input(self.state, action)
        next_state_output = []
        reward_output = []
        for one_inter_input in inputs:
            next_state_output.append(self._predict(one_inter_input, "next_state").reshape(self.next_state_output_dim, ))
            reward_output.append(self._predict(one_inter_input, "reward").reshape(self.reward_output_dim, ))
        self.state, formated_reward_output = self._format_predict_output(next_state_output, reward_output)

        self.current_time += self.dic_traffic_env_conf["MIN_ACTION_TIME"]
        self.current_phase = action

        return self.state.copy(), np.array(formated_reward_output)

    def reset(self):
        """Start a new episode from states sampled out of the data buffer.

        Resets the clock and the current phases (to the same all-zero value
        used in ``__init__``), so the re-attached "cur_phase" feature does not
        leak the previous episode's last action.

        Returns:
            Shallow copy of the per-intersection state list.
        """
        self.current_time = 0
        self.current_phase = [0 for _ in range(self.dic_traffic_env_conf["NUM_INTERSECTIONS"])]
        self.state = []
        for _ in range(self.dic_traffic_env_conf["NUM_INTERSECTIONS"]):
            one_inter_state = self.data_buffer.sample_chosen_data(["state"], 1)
            self.state.append(np.array(one_inter_state).reshape(-1))

        self.state, _ = self._format_predict_output(self.state, [])
        return self.state.copy()

    def set_state(self, state):
        """Overwrite the current state with ``np.array(state)``."""
        print("[INFO] Set state with shape {}".format(np.array(state).shape))
        self.state = np.array(state)

    def get_current_time(self):
        """Return the simulated time, advanced by MIN_ACTION_TIME per step."""
        return self.current_time

    def test_single_step(self, data):
        """Predict one transition from ``data["state"]`` under ``data["action"]``.

        Returns the (next_state, reward) pair produced by ``step``.
        """
        self.state = data["state"]
        pred_next_state, pred_reward = self.step(data["action"])
        return pred_next_state, pred_reward

    def _adjacency_extraction(self):
        """Parse the roadnet JSON and compute neighbourhood info per intersection.

        Reads ``<PATH_TO_WORK_DIRECTORY>/<ROADNET_FILE>`` and returns a dict
        keyed by the id of every non-virtual intersection, each value holding:
            "location": {"x": float, "y": float},
            "total_inter_num": number of non-virtual intersections,
            "inter_id_to_index": shared mapping intersection id -> index,
            "adjacency_row": own index, followed by the indices of the other
                intersections among the TOP_K_ADJACENCY nearest by Euclidean
                distance (itself counted); all of them if there are fewer,
            "neighbor_ENWS": for the roads ``road_<x>_<y>_0..3`` (presumably
                E, N, W, S), the id of the intersection each road ends at, or
                None when that is not a non-virtual intersection.

        Raises:
            KeyError: if one of the four ``road_<x>_<y>_<j>`` roads is missing.
        """
        roadnet_file = os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], self.dic_traffic_env_conf["ROADNET_FILE"])
        with open(roadnet_file, encoding="utf-8") as json_data:
            net = json.load(json_data)

        traffic_light_node_dict = {}
        for inter in net["intersections"]:
            if not inter["virtual"]:
                traffic_light_node_dict[inter["id"]] = {"location": {"x": float(inter["point"]["x"]),
                                                                     "y": float(inter["point"]["y"])},
                                                        "total_inter_num": None, "adjacency_row": None,
                                                        "inter_id_to_index": None,
                                                        "neighbor_ENWS": None}

        top_k = self.dic_traffic_env_conf["TOP_K_ADJACENCY"]
        total_inter_num = len(traffic_light_node_dict)
        inter_id_to_index = {inter_id: index for index, inter_id in enumerate(traffic_light_node_dict)}
        edge_id_dict = {road["id"]: {"from": road["startIntersection"], "to": road["endIntersection"]}
                        for road in net["roads"]}

        for inter_id, node in traffic_light_node_dict.items():
            location_1 = node["location"]

            # Integer array: distances are truncated to ints before ranking.
            row = np.array([0]*total_inter_num)
            for other_id, other_node in traffic_light_node_dict.items():
                row[inter_id_to_index[other_id]] = self._cal_distance(location_1, other_node["location"])
            if len(row) == top_k:
                adjacency_row_unsorted = np.argpartition(row, -1)[:top_k].tolist()
            elif len(row) > top_k:
                adjacency_row_unsorted = np.argpartition(row, top_k)[:top_k].tolist()
            else:
                adjacency_row_unsorted = list(range(total_inter_num))
            adjacency_row_unsorted.remove(inter_id_to_index[inter_id])
            node["adjacency_row"] = [inter_id_to_index[inter_id]] + adjacency_row_unsorted
            node["total_inter_num"] = total_inter_num
            node["inter_id_to_index"] = inter_id_to_index

            neighbors = []
            for j in range(4):
                road_id = inter_id.replace("intersection", "road") + "_" + str(j)
                end_inter = edge_id_dict[road_id]["to"]
                neighbors.append(end_inter if end_inter in traffic_light_node_dict else None)
            node["neighbor_ENWS"] = neighbors

        return traffic_light_node_dict

    def _cal_distance(self, loc_dict1, loc_dict2):
        """Return the Euclidean distance between two {"x", "y"} points."""
        a = np.array((loc_dict1["x"], loc_dict1["y"]))
        b = np.array((loc_dict2["x"], loc_dict2["y"]))
        return np.sqrt(np.sum((a-b)**2))
    