import numpy as np
import pickle
import os
import traceback
import random

# Seed the global `random` generator at import time so the random.sample calls
# used when mixing in offline samples are repeatable across runs. This is
# process-wide state shared with every other user of the `random` module.
random.seed(1937)

def get_reward_from_features(rs):
    """Extract the reward components from one logged record.

    Two record layouts are supported:

    * a record with a precomputed ``"reward"`` entry, whose value (as a numpy
      array) is used for both components;
    * a record whose ``"state"`` dict contains
      ``"lane_num_waiting_vehicle_in"``: that entry is summed, and the sum of
      ``"traffic_movement_pressure_queue_efficient"`` is taken in absolute value.

    Args:
        rs: one log record (a dict).

    Returns:
        dict with keys ``"lane_num_waiting_vehicle_in"`` and
        ``"traffic_movement_pressure_queue_efficient"``.

    Raises:
        AssertionError: if the record matches neither layout (including a record
            without a ``"state"`` entry).
    """
    if "reward" in rs:
        # Two separate arrays, exactly as before, so neither aliases the other.
        return {"lane_num_waiting_vehicle_in": np.array(rs["reward"]),
                "traffic_movement_pressure_queue_efficient": np.array(rs["reward"])}

    state = rs.get("state", {})
    if "lane_num_waiting_vehicle_in" in state:
        return {"lane_num_waiting_vehicle_in": np.sum(state["lane_num_waiting_vehicle_in"]),
                "traffic_movement_pressure_queue_efficient": np.absolute(
                    np.sum(state["traffic_movement_pressure_queue_efficient"]))}

    raise AssertionError(
        "record has neither a 'reward' entry nor 'lane_num_waiting_vehicle_in' "
        "in its 'state'; record keys: {0}".format(list(rs)))
    


def cal_reward(rs, rewards_components):
    """Return the weighted sum of the reward components in ``rs``.

    Components whose weight is 0, that are absent from ``rs``, or whose value
    is None contribute nothing. With no contributing component the result is 0.
    """
    total = 0
    for name, coeff in rewards_components.items():
        skip = coeff == 0 or name not in rs or rs[name] is None
        if skip:
            continue
        total += rs[name] * coeff
    return total


class ConstructSample:
    """Turn the per-step logs written by online generators into training samples.

    For round ``cnt_round`` the logs are read from
    ``<path_to_samples>/round_<cnt_round>/<generator folder>/inter_<i>.pkl`` and
    each intersection's samples are appended (one pickle per call) to
    ``<path_to_samples>/total_samples_inter_<i>.pkl``.

    A sample is ``[state, action, next_state, reward_average, reward_instant,
    time, "<folder>-round_<cnt_round>"]``.
    """

    def __init__(self, path_to_samples, cnt_round, dic_traffic_env_conf, dic_path):
        self.parent_dir = path_to_samples
        self.path_to_samples = path_to_samples + "/round_" + str(cnt_round)
        self.cnt_round = cnt_round
        self.dic_traffic_env_conf = dic_traffic_env_conf
        self.dic_path = dic_path
        print("Sample Config: {}".format(dic_traffic_env_conf))

        # Records of every intersection for the generator folder being processed.
        self.logging_data_list_per_gen = None
        self.hidden_states_list = None
        self.samples = []
        # Samples accumulated per intersection over all generator folders.
        self.samples_all_intersection = [None]*self.dic_traffic_env_conf['NUM_INTERSECTIONS']

        # Step between consecutive samples, and length of the reward window.
        self.interval = self.dic_traffic_env_conf["MIN_ACTION_TIME"]
        self.measure_time = self.dic_traffic_env_conf["MEASURE_TIME"]

    def load_data(self, folder, i):
        '''
            Load the records stored in '<folder>/inter_{i}.pkl' of this round.

            Returns (1, records) on success, or (0, None) after printing the
            traceback if the file cannot be opened or unpickled.
        '''
        try:
            path = os.path.join(self.path_to_samples, folder, "inter_{0}.pkl".format(i))
            # ``with`` closes the file even when unpickling fails.
            with open(path, "rb") as f_logging_data:
                logging_data = pickle.load(f_logging_data)
            return 1, logging_data

        except Exception:
            print("Error occurs when making samples for inter {0}".format(i))
            print('traceback.format_exc():\n%s' % traceback.format_exc())
            return 0, None

    def load_data_for_system(self, folder):
        '''
            Load the records of every intersection from ``folder``.

            Returns 1 if all intersections loaded, or 0 as soon as one fails
            (``logging_data_list_per_gen`` is then only partially filled).
        '''
        self.logging_data_list_per_gen = []
        print("Load data for system in ", folder)
        for i in range(self.dic_traffic_env_conf['NUM_INTERSECTIONS']):
            pass_code, logging_data = self.load_data(folder, i)     # load inter_{}.pkl
            if pass_code == 0:
                return 0
            self.logging_data_list_per_gen.append(logging_data)
        return 1

    def construct_state(self, features, time, i):
        '''
            Return the entries of the record's "state" whose keys are in
            ``features``, for intersection ``i`` at log index ``time``.

            With BINARY_PHASE_EXPANSION enabled, each selected key containing
            "cur_phase" is replaced by ``PHASE[value[0]]``.
        '''
        state = self.logging_data_list_per_gen[i][time]
        # time is the loop index in make_reward and state["time"] the timestamp
        # stored in the record; if they differ, the log cannot be indexed by time.
        assert time == state["time"]
        if self.dic_traffic_env_conf["BINARY_PHASE_EXPANSION"]:
            state_after_selection = {}
            for key, value in state["state"].items():
                if key in features:
                    if "cur_phase" in key:
                        state_after_selection[key] = self.dic_traffic_env_conf['PHASE'][value[0]]
                    else:
                        state_after_selection[key] = value
        else:
            state_after_selection = {key: value for key, value in state["state"].items() if key in features}
        return state_after_selection

    def construct_reward(self, rewards_components, time, i):
        '''
            Compute the rewards of the sample taken at ``time``.

            Returns ``(r_instant, r_average)``: the weighted reward of the last
            record of the window ``[time, time + measure_time)`` and the mean
            weighted reward over the whole window.
        '''
        rs = self.logging_data_list_per_gen[i][time + self.measure_time - 1]

        assert time + self.measure_time - 1 == rs["time"]
        rs = get_reward_from_features(rs)
        r_instant = cal_reward(rs, rewards_components)
        # average
        list_r = []
        for t in range(time, time + self.measure_time):
            rs = self.logging_data_list_per_gen[i][t]
            assert t == rs["time"]
            rs = get_reward_from_features(rs)
            r = cal_reward(rs, rewards_components)
            list_r.append(r)
        r_average = np.average(list_r)

        return r_instant, r_average

    def judge_action(self, time, i):
        '''
            Return the action logged at ``time``; raise ValueError if it is -1.
        '''
        action = self.logging_data_list_per_gen[i][time]['action']
        if action == -1:
            raise ValueError("action -1 logged at time {0} for intersection {1}".format(time, i))
        return action

    def _next_state_time(self, time, total_time):
        '''
            Return the log index used as next_state for the sample at ``time``.

            Normally ``time + interval``; clipped to the last record
            (``total_time - 1``) so the final sample reuses the last state
            instead of indexing past the end of the log.
        '''
        return min(time + self.interval, total_time - 1)

    def _mix_in_offline_samples(self, i, offline_data_ratio):
        '''
            Keep a random 95% of intersection ``i``'s accumulated samples and
            append ``min(int(n * offline_data_ratio) + 1, n)`` samples drawn from
            the n files in ``<PATH_TO_OFFLINE_DATA>/inter_data/inter_<i>``.
        '''
        path_to_offline_data_i = os.path.join(
            self.dic_path["PATH_TO_OFFLINE_DATA"], "inter_data", "inter_{}".format(i))
        # Sorted so the seeded draw below does not depend on filesystem listing order.
        path_dir = sorted(os.listdir(path_to_offline_data_i))
        num_samples = min(int(len(path_dir) * offline_data_ratio) + 1, len(path_dir))     # follow the setting in MOPO

        offline_samples = []
        for sample_path in random.sample(path_dir, num_samples):
            with open(os.path.join(path_to_offline_data_i, sample_path), "rb") as sample_file:
                offline_samples.append(pickle.load(sample_file))

        # delete some generated samples: random.sample picks the same positions
        # it would pick from range(len(current)), so this keeps the same subset.
        current = self.samples_all_intersection[i]
        kept = random.sample(current, int(len(current) * 0.95))

        # add offline data samples
        self.samples_all_intersection[i] = kept + offline_samples

    def make_reward(self, folder, i, offline_data_ratio):
        '''
            Build samples for intersection ``i`` from the records loaded from
            ``folder`` and append them to ``samples_all_intersection[i]``. When
            ``dic_path`` has "PATH_TO_OFFLINE_DATA", offline samples are then
            mixed in using ``offline_data_ratio``.

            Returns 1 on success, or 0 after printing the traceback on failure.
        '''
        if self.samples_all_intersection[i] is None:
            self.samples_all_intersection[i] = []
        if i % 100 == 0:
            print("make reward for inter {0} in folder {1}".format(i, folder))
        list_samples = []
        try:
            features = self.dic_traffic_env_conf["LIST_STATE_FEATURE"]
            reward_info = self.dic_traffic_env_conf["DIC_REWARD_INFO"]
            if len(self.logging_data_list_per_gen[i]) == 0:
                total_time = 0
            else:
                total_time = int(self.logging_data_list_per_gen[i][-1]['time'] + 1)
            for time in range(0, total_time - self.measure_time + 1, self.interval):
                state = self.construct_state(features, time, i)
                reward_instant, reward_average = self.construct_reward(reward_info, time, i)
                action = self.judge_action(time, i)
                next_state = self.construct_state(features, self._next_state_time(time, total_time), i)
                sample = [state, action, next_state, reward_average, reward_instant, time,
                          folder+"-"+"round_{0}".format(self.cnt_round)]
                list_samples.append(sample)

            self.samples_all_intersection[i].extend(list_samples)

            # NOTE(review): make_reward runs once per generator folder, so the
            # accumulated samples (including offline samples added for earlier
            # folders) are re-subsampled and offline data is appended again for
            # every folder; confirm this is intended.
            if "PATH_TO_OFFLINE_DATA" in self.dic_path:
                self._mix_in_offline_samples(i, offline_data_ratio)

            return 1
        except Exception:
            print("Error occurs when making rewards in generator {0} for intersection {1}".format(folder, i))
            print('traceback.format_exc():\n%s' % traceback.format_exc())
            return 0

    def make_reward_for_system(self):
        '''
            Build samples from every "generator" folder of this round, then
            append each intersection's samples to total_samples_inter_<i>.pkl.
        '''
        # Depends only on the config, so compute it once.
        offline_data_ratio = 0.05
        if "OFFLINE_METHOD" in self.dic_traffic_env_conf and self.dic_traffic_env_conf["OFFLINE_METHOD"] == "SIMPLE":
            offline_data_ratio = 1.00

        # Sorted so the processing order, and thus the seeded random draws, is reproducible.
        for folder in sorted(os.listdir(self.path_to_samples)):
            print(folder)
            if "generator" not in folder:
                continue
            if not self.load_data_for_system(folder):
                continue
            for i in range(self.dic_traffic_env_conf['NUM_INTERSECTIONS']):
                # Failures are reported inside make_reward; keep going with the rest.
                self.make_reward(folder, i, offline_data_ratio)

        for i in range(self.dic_traffic_env_conf['NUM_INTERSECTIONS']):
            samples = self.samples_all_intersection[i]
            # An intersection no folder was processed for still holds None;
            # dump an empty list so readers always get a list.
            if samples is None:
                samples = []
            self.dump_sample(samples, "inter_{0}".format(i))

    def dump_sample(self, samples, folder):
        '''
            Pickle ``samples``; the destination depends on ``folder``:

            * ""                -> append to <parent_dir>/total_samples.pkl
            * contains "inter"  -> append to <parent_dir>/total_samples_<folder>.pkl
            * anything else     -> overwrite <path_to_samples>/<folder>/samples_<folder>.pkl

            The append modes add one more pickle to the file on every call.
        '''
        if folder == "":
            with open(os.path.join(self.parent_dir, "total_samples.pkl"), "ab+") as f:
                pickle.dump(samples, f, -1)
        elif "inter" in folder:
            with open(os.path.join(self.parent_dir, "total_samples_{0}.pkl".format(folder)), "ab+") as f:
                pickle.dump(samples, f, -1)
        else:
            with open(os.path.join(self.path_to_samples, folder, "samples_{0}.pkl".format(folder)), 'wb') as f:
                pickle.dump(samples, f, -1)
                
            
class OfflineConstructSample:
    """Turn one offline log directory into samples and write each one to disk.

    Records are read from ``<path_to_samples>/inter_<i>.pkl``. Each sample
    ``[state, action, next_state, reward_average, reward_instant, time,
    'offline_data']`` is written to ``<path_to_data>/sample_<i>_<j>.pkl`` and to
    ``<dirname(path_to_data)>/inter_data/inter_<i>/sample_<i>_<j>.pkl``.

    Log indices are compared with record times modulo CITYFLOW_MAX_TIME
    (presumably because the log concatenates several episodes; confirm).
    """

    def __init__(self, path_to_samples, path_to_data, cnt_round, dic_traffic_env_conf):
        self.parent_dir = path_to_data
        self.path_to_samples = path_to_samples
        self.cnt_round = cnt_round
        self.dic_traffic_env_conf = dic_traffic_env_conf
        print("Sample Config: {}".format(dic_traffic_env_conf))

        # Records of every intersection, filled by load_data_for_system.
        self.logging_data_list_per_gen = None
        self.hidden_states_list = None
        self.samples = []
        self.samples_all_intersection = [None]*self.dic_traffic_env_conf['NUM_INTERSECTIONS']

        # Offline samples are taken every MEASURE_TIME steps (interval == window length).
        self.interval = self.dic_traffic_env_conf["MEASURE_TIME"]
        self.measure_time = self.dic_traffic_env_conf["MEASURE_TIME"]

    def load_data(self, i):
        '''
            Load the records stored in '<path_to_samples>/inter_{i}.pkl'.

            Returns (1, records) on success, or (0, None) after printing the
            traceback if the file cannot be opened or unpickled.
        '''
        try:
            path = os.path.join(self.path_to_samples, "inter_{0}.pkl".format(i))
            # ``with`` closes the file even when unpickling fails.
            with open(path, "rb") as f_logging_data:
                logging_data = pickle.load(f_logging_data)
            return 1, logging_data

        except Exception:
            print("Error occurs when making samples for inter {0}".format(i))
            print('traceback.format_exc():\n%s' % traceback.format_exc())
            return 0, None

    def load_data_for_system(self):
        '''
            Load the records of every intersection.

            Returns 1 if all intersections loaded, or 0 as soon as one fails.
        '''
        self.logging_data_list_per_gen = []
        for i in range(self.dic_traffic_env_conf['NUM_INTERSECTIONS']):
            pass_code, logging_data = self.load_data(i)     # load inter_{}.pkl
            if pass_code == 0:
                return 0
            self.logging_data_list_per_gen.append(logging_data)
        return 1

    def construct_state(self, features, time, i):
        '''
            Return the entries of the record's "state" whose keys are in
            ``features``, for intersection ``i`` at log index ``time``.

            With BINARY_PHASE_EXPANSION enabled, each selected key containing
            "cur_phase" is replaced by ``PHASE[value[0]]``.
        '''
        state = self.logging_data_list_per_gen[i][time]

        # time is the loop index in make_reward and state["time"] the timestamp
        # stored in the record (which restarts every CITYFLOW_MAX_TIME steps);
        # if they disagree, the log cannot be indexed this way.
        assert time % self.dic_traffic_env_conf["CITYFLOW_MAX_TIME"] == state["time"]
        if self.dic_traffic_env_conf["BINARY_PHASE_EXPANSION"]:
            state_after_selection = {}
            for key, value in state["state"].items():
                if key in features:
                    if "cur_phase" in key:
                        state_after_selection[key] = self.dic_traffic_env_conf['PHASE'][value[0]]
                    else:
                        state_after_selection[key] = value
        else:
            state_after_selection = {key: value for key, value in state["state"].items() if key in features}
        return state_after_selection

    def construct_reward(self, rewards_components, time, i):
        '''
            Compute the rewards of the sample taken at ``time``.

            Returns ``(r_instant, r_average)``: the weighted reward of the last
            record of the window ``[time, time + measure_time)`` and the mean
            weighted reward over the whole window.
        '''
        rs = self.logging_data_list_per_gen[i][time + self.measure_time - 1]
        # NOTE(review): this compares against (time % MAX) + measure_time - 1, so it
        # presumably rejects windows whose record times wrap past CITYFLOW_MAX_TIME
        # (i.e. span two episodes); confirm.
        assert time % self.dic_traffic_env_conf["CITYFLOW_MAX_TIME"] + self.measure_time - 1 == rs["time"]
        rs = get_reward_from_features(rs)
        r_instant = cal_reward(rs, rewards_components)
        # average
        list_r = []
        for t in range(time, time + self.measure_time):
            rs = self.logging_data_list_per_gen[i][t]
            assert t % self.dic_traffic_env_conf["CITYFLOW_MAX_TIME"] == rs["time"]
            rs = get_reward_from_features(rs)
            r = cal_reward(rs, rewards_components)
            list_r.append(r)
        r_average = np.average(list_r)

        return r_instant, r_average

    def judge_action(self, time, i):
        '''
            Return the action logged at ``time``; raise ValueError if it is -1.
        '''
        action = self.logging_data_list_per_gen[i][time]['action']
        if action == -1:
            raise ValueError("action -1 logged at time {0} for intersection {1}".format(time, i))
        return action

    def _next_state_time(self, time, total_time):
        '''
            Return the log index used as next_state for the sample at ``time``.

            Normally ``time + interval``; clipped to the last record
            (``total_time - 1``) so the final sample reuses the last state
            instead of indexing past the end of the log.
        '''
        return min(time + self.interval, total_time - 1)

    def make_reward(self, i):
        '''
            Build samples for intersection ``i`` from its loaded records and
            append them to ``samples_all_intersection[i]``.

            Returns 1 on success, or 0 after printing the traceback on failure.
        '''
        if self.samples_all_intersection[i] is None:
            self.samples_all_intersection[i] = []
        if i % 100 == 0:
            print("make reward for inter {0}".format(i))
        list_samples = []
        try:
            features = self.dic_traffic_env_conf["LIST_STATE_FEATURE"]
            reward_info = self.dic_traffic_env_conf["DIC_REWARD_INFO"]
            total_time = len(self.logging_data_list_per_gen[i])
            for time in range(0, total_time - self.measure_time + 1, self.interval):
                state = self.construct_state(features, time, i)
                reward_instant, reward_average = self.construct_reward(reward_info, time, i)
                action = self.judge_action(time, i)
                # NOTE(review): if the log concatenates episodes of CITYFLOW_MAX_TIME
                # steps, the last sample of each episode takes the first record of the
                # next episode as next_state; confirm this is intended.
                next_state = self.construct_state(features, self._next_state_time(time, total_time), i)
                sample = [state, action, next_state, reward_average, reward_instant, time, 'offline_data']
                list_samples.append(sample)

            self.samples_all_intersection[i].extend(list_samples)
            return 1
        except Exception:
            print("Error occurs when making offline rewards for intersection {0}".format(i))
            print('traceback.format_exc():\n%s' % traceback.format_exc())
            return 0

    # specifically for offline sampling
    def make_reward_for_system(self):
        '''
            Load all intersections, build their samples, then dump them.
            Failures are reported inside load_data / make_reward.
        '''
        if self.load_data_for_system():
            for i in range(self.dic_traffic_env_conf['NUM_INTERSECTIONS']):
                # Errors are printed inside make_reward; keep going with the rest.
                self.make_reward(i)

        self.dump_sample()

    def dump_sample(self):
        '''
            Write every sample to its own pickle file, once under ``parent_dir``
            (append mode) and once under
            ``<dirname(parent_dir)>/inter_data/inter_<i>`` (overwrite mode).
            Intersections without samples only get their directory created.
        '''
        inter_data_path = os.path.join(os.path.dirname(self.parent_dir), "inter_data")
        for i in range(self.dic_traffic_env_conf['NUM_INTERSECTIONS']):
            one_inter_path = os.path.join(inter_data_path, "inter_{}".format(i))
            os.makedirs(one_inter_path, exist_ok=True)
            samples = self.samples_all_intersection[i]
            # Still None when the intersection was never processed (e.g. loading failed).
            if samples is None:
                continue
            for j, sample in enumerate(samples):
                file_name = "sample_{0}_{1}.pkl".format(i, j)
                # NOTE(review): "ab+" appends another pickle if the file already exists
                # from an earlier run, whereas the inter_data copy is overwritten.
                with open(os.path.join(self.parent_dir, file_name), "ab+") as f:
                    pickle.dump(sample, f, -1)
                with open(os.path.join(one_inter_path, file_name), "wb") as f:
                    pickle.dump(sample, f, -1)
        
            
