import os
import pickle
import numpy as np
import tensorflow._api.v2.compat.v1 as tf

from utils.mo_base.mo_base_buffer import base_buffer
from utils.bnn import bnn

class MOBASEFakeEnv:
    """Model-based surrogate traffic environment learned from an offline dataset.

    On construction every regular file in ``path_to_offline_dataset`` is
    unpickled into a ``base_buffer`` and four ``bnn`` ensembles are trained:
    ``reward_go`` / ``reward_stop`` predict the per-lane reward and
    ``next_state_go`` / ``next_state_stop`` predict the per-lane next state.
    ``go`` models are trained on samples whose last input feature (the
    action) is 1, ``stop`` models on samples where it is 0.

    ``reset`` samples an initial state from the buffer; ``step`` rolls it
    forward with the learned models. Each prediction is the ensemble mean
    minus an uncertainty penalty derived from the ensemble spread.
    """

    def __init__(self, path_to_offline_dataset, dic_traffic_env_conf):
        """Load the offline dataset and train the reward / transition ensembles.

        Args:
            path_to_offline_dataset: directory whose regular files are each a
                pickled object; all of them are handed to ``base_buffer``.
            dic_traffic_env_conf: environment config dict. Keys read by this
                class: ``NUM_INTERSECTIONS``, ``NUM_LANE``,
                ``LIST_STATE_FEATURE`` and ``MIN_ACTION_TIME``.
        """
        self.dic_traffic_env_conf = dic_traffic_env_conf

        # SECURITY: pickle.load can execute arbitrary code -- only point this
        # at trusted, locally produced datasets.
        # Files are read in sorted order so the buffer content (and thus
        # training) does not depend on os.listdir's arbitrary ordering, and
        # sub-directories are skipped instead of crashing open().
        dataset = []
        for file_name in sorted(os.listdir(path_to_offline_dataset)):
            file_path = os.path.join(path_to_offline_dataset, file_name)
            if not os.path.isfile(file_path):
                continue
            with open(file_path, "rb") as f:
                dataset.append(pickle.load(f))
        self.data_buffer = base_buffer(dataset, dic_traffic_env_conf)
        self.state = None
        self.current_time = 0

        # construct bnn
        # Model names are parsed by _train_transition_model / _predict
        # ("reward" vs "next_state" target, "go" vs "stop" action) -- don't change them.
        models_name = ["reward_go", "reward_stop", "next_state_go", "next_state_stop"]
        models_ensemble = [7, 7, 7, 7]
        models_elites = [5, 5, 5, 5]
        self.input_dim = self.data_buffer.get_data_dim(["state", "action"])
        self.next_state_output_dim = self.data_buffer.get_data_dim(["next_state"])
        self.reward_output_dim = self.data_buffer.get_data_dim(["reward"])
        models_arch = [
            [self.input_dim, 64, 64, self.reward_output_dim],
            [self.input_dim, 64, 64, self.reward_output_dim],
            [self.input_dim, 64, 64, self.next_state_output_dim],
            [self.input_dim, 64, 64, self.next_state_output_dim],
        ]
        models_activation = [
            ["swish", "sigmoid", "swish", "ReLU"],
            ["swish", "sigmoid", "swish", "ReLU"],
            ["swish", "sigmoid", "swish", "ReLU"],
            ["swish", "sigmoid", "swish", "ReLU"],
        ]
        weight_decay = [
            [0.00001, 0.00005, 0.00005, 0.0001],
            [0.00001, 0.00005, 0.00005, 0.0001],
            [0.00001, 0.00005, 0.00005, 0.0001],
            [0.00001, 0.00005, 0.00005, 0.0001],
        ]
        epochs = 100
        validation_split = 0.2

        self.models = {}
        for name, ensembles, elites, arch, activation, decay in zip(
                models_name, models_ensemble, models_elites,
                models_arch, models_activation, weight_decay):
            print("[INFO] Start building model {} | ensembles {} | elites {}".format(name,
                                                                                     ensembles,
                                                                                     elites))
            self._init_transition_model(name=name,
                                        ensembles=ensembles,
                                        elites=elites,
                                        arch=arch,
                                        activation=activation,
                                        weight_decay=decay)
            print("[INFO] model input dim {} | output dim {}".format(self.models[name].get_input_dim(),
                                                                     self.models[name].get_output_dim()))

            self._train_transition_model(name, epochs, validation_split)

    def _init_transition_model(self, name, ensembles, elites, arch, activation, weight_decay):
        """Create ``bnn(name, ensembles, elites)``, build it with Adam (lr 0.001) and store it in ``self.models[name]``."""
        self.models[name] = bnn(name, ensembles, elites)
        self.models[name].build(arch, activation, weight_decay, tf.train.AdamOptimizer, 0.001, loss=None)

    def _train_transition_model(self, name, epoch, validation_split):
        """Train model ``name`` on the buffer samples matching its target and action.

        The target is "reward" if ``name`` contains "reward", otherwise
        "next_state"; samples are restricted to action 0 when ``name``
        contains "stop" and to action 1 otherwise.
        """
        x = self.data_buffer.sample_all_chosen_data(["state", "action"])
        output_feature = "reward" if "reward" in name else "next_state"
        y = self.data_buffer.sample_all_chosen_data([output_feature])
        is_stop = "stop" in name
        x, y = self._filter_by_action(x, y, is_stop)
        self.models[name].train(x, y, epoch, validation_split)

    # TODO: penalty coefficient is a placeholder ("waiting for update").
    def _calc_reward_penalty(self, rewards):
        """Return the standard deviation of all ensemble reward outputs (times a 1.0 coefficient)."""
        penalty_coeff = 1.0
        return np.std(rewards) * penalty_coeff

    # TODO: per-feature penalty coefficients are placeholders ("waiting for update").
    def _calc_state_penalty(self, states):
        """Return the per-column standard deviation of ``states`` across axis 0 (times all-ones coefficients)."""
        penalty_coeff = np.ones(np.shape(states)[1])
        return np.std(states, axis=0) * penalty_coeff

    def _format_predict_input(self, state, action):
        """Build per-lane model inputs from an intersection-level state and action.

        Args:
            state: one entry per intersection, converted to per-lane features by
                ``data_buffer.format_one_intersection_state``.
            action: one entry per intersection, converted to per-lane actions by
                ``data_buffer.format_one_intersection_action``.

        Returns:
            ``np.array`` where element ``[i][j]`` is the concatenation of lane
            ``j``'s state and action for intersection ``i``.

        Raises:
            ValueError: if ``state`` and ``action`` differ in length.
        """
        if len(state) != len(action):
            raise ValueError("state and action must cover the same number of intersections "
                             "({} != {})".format(len(state), len(action)))
        outputs = []
        for one_inter_state, one_inter_action in zip(state, action):
            # A fresh copy of the feature list is passed each time, as in the
            # original code, in case the buffer helper mutates it.
            one_inter_state_in_lane = self.data_buffer.format_one_intersection_state(
                one_inter_state, self.dic_traffic_env_conf["LIST_STATE_FEATURE"][:])
            one_inter_action_in_lane = self.data_buffer.format_one_intersection_action(one_inter_action)
            outputs.append([np.concatenate([one_lane_state, one_lane_action])
                            for one_lane_state, one_lane_action
                            in zip(one_inter_state_in_lane, one_inter_action_in_lane)])
        return np.array(outputs)

    def _format_predict_output(self, state, reward):
        """Convert flat predictions back into the environment's state / reward format.

        ``state`` is reshaped to
        ``(NUM_INTERSECTIONS, NUM_LANE, next_state_output_dim)``. Column ``k``
        of each intersection becomes the entry for the ``k``-th name in
        ``LIST_STATE_FEATURE``. This assumes each feature occupies exactly one
        column per lane, in list order; TODO confirm against the buffer format.

        Returns:
            ``(output_state, output_reward)``. ``output_state`` is a list of
            ``{feature: per-lane values}`` dicts, one per intersection.
            ``output_reward`` holds the sum of each intersection's entries in
            ``reward``.
        """
        state = np.reshape(state, (self.dic_traffic_env_conf["NUM_INTERSECTIONS"],
                                   self.dic_traffic_env_conf["NUM_LANE"],
                                   self.next_state_output_dim))
        feature_names = self.dic_traffic_env_conf["LIST_STATE_FEATURE"]
        output_state = [
            {feature: one_intersection_state[:, index] for index, feature in enumerate(feature_names)}
            for one_intersection_state in state
        ]
        output_reward = [np.sum(one_inter_reward) for one_inter_reward in reward]
        return output_state, output_reward

    def _predict(self, input, type):
        """Predict a penalised ``type`` ("reward" or "next_state") for one lane input.

        The last element of ``input`` is the action and selects the ``go``
        (1) or ``stop`` (0) model. Returns the ensemble mean along axis 0
        minus the uncertainty penalty.

        Raises:
            ValueError: if the action element is neither 0 nor 1.
        """
        action = input[-1]
        if action not in (0, 1):
            raise ValueError("expected action 0 or 1 as the last input feature, got {}".format(action))
        name = "{}_{}".format(type, "go" if action == 1 else "stop")
        # NOTE(review): assumes predict returns one row per ensemble member -- confirm against bnn.
        ensemble_output = self.models[name].predict(input.reshape(1, self.input_dim))
        if type == "next_state":
            penalty = self._calc_state_penalty(ensemble_output)
        else:
            penalty = self._calc_reward_penalty(ensemble_output)
        return np.mean(ensemble_output, axis=0) - penalty

    def _filter_by_action(self, inputs, outputs, is_stop):
        """Keep only the (input, output) pairs whose last input feature is 0 (``is_stop``) or 1 (otherwise)."""
        target_action = 0 if is_stop else 1
        filtered_inputs = []
        filtered_outputs = []
        for one_input, one_output in zip(inputs, outputs):
            if one_input[-1] == target_action:
                filtered_inputs.append(one_input)
                filtered_outputs.append(one_output)
        return np.array(filtered_inputs), np.array(filtered_outputs)

    def step(self, action):
        """Advance the surrogate environment by one decision step.

        Args:
            action: one entry per intersection (see ``_format_predict_input``).

        Returns:
            ``(next_state, reward)``. ``next_state`` is a shallow copy of the new
            list of per-intersection feature dicts. ``reward`` is an
            ``np.array`` with one summed reward per intersection.

        Raises:
            RuntimeError: if neither ``reset`` nor ``set_state`` was called first.
        """
        if self.state is None:
            raise RuntimeError("call reset() or set_state() before step()")
        inputs = self._format_predict_input(self.state, action)
        next_state_output = []
        reward_output = []
        for one_inter_input in inputs:
            one_inter_next_state = []
            one_inter_reward = []
            for one_lane_input in one_inter_input:
                one_inter_next_state.extend(
                    self._predict(one_lane_input, "next_state").reshape(self.next_state_output_dim, ))
                one_inter_reward.extend(
                    self._predict(one_lane_input, "reward").reshape(self.reward_output_dim, ))
            next_state_output.append(one_inter_next_state)
            reward_output.append(one_inter_reward)
        self.state, formated_reward_output = self._format_predict_output(next_state_output, reward_output)

        self.current_time += self.dic_traffic_env_conf["MIN_ACTION_TIME"]

        return self.state.copy(), np.array(formated_reward_output)

    def reset(self):
        """Sample a fresh initial state from the buffer and reset the clock.

        For each intersection, ``NUM_LANE`` "state" samples are drawn and
        flattened, matching the per-lane reshape in ``_format_predict_output``.

        Returns:
            A shallow copy of the new list of per-intersection feature dicts.
        """
        num_lane = self.dic_traffic_env_conf["NUM_LANE"]
        raw_state = [
            np.array(self.data_buffer.sample_chosen_data(["state"], num_lane)).reshape(-1)
            for _ in range(self.dic_traffic_env_conf["NUM_INTERSECTIONS"])
        ]
        self.current_time = 0
        self.state, _ = self._format_predict_output(raw_state, [])
        return self.state.copy()

    def set_state(self, state):
        """Overwrite the current state (stored as ``np.array(state)``)."""
        print("[INFO] Set state with shape {}".format(np.array(state).shape))
        self.state = np.array(state)

    def get_current_time(self):
        """Return the simulated time, advanced by ``MIN_ACTION_TIME`` per ``step``."""
        return self.current_time

    def test_single_step(self, data):
        """Run one ``step`` from ``data["state"]`` with ``data["action"]`` and return the predictions."""
        self.state = data["state"]
        pred_next_state, pred_reward = self.step(data["action"])
        return pred_next_state, pred_reward
            