######################## 文件简介 ########################
# 将smec_liftsim模拟器包装成训练环境的文件，实现了一个SmecRLEnv类
from smec_liftsim.data_generator.generator_proxy import set_seed
from smec_liftsim.data_generator.random_data_generator import RandomDataGenerator
from smec_liftsim.mansion_configs import MansionConfig
from smec_liftsim.mansion_manager import MansionManager
from smec_liftsim.utils import ElevatorHallCall
from smec_rl_components.smec_graph_build import *
import configparser
import random
import os
import torch
import gym
from gym.spaces import Discrete, Dict, Box
from gym.vector.async_vector_env import AsyncVectorEnv
from smec_rl_components.smec_reward import *
from smec_rl_components.normalization import *
from smec_liftsim.smec_constants import *
from copy import deepcopy

class TimeEstimater:
    def __init__(self):
        """Set the fixed building layout assumed by the estimator."""
        # The layout is hard-coded: 4 elevators, 16 floors.
        self.elev_num, self.floor_num = 4, 16

    def df2time(self, df, elev=None):
        if df == 0:
            if elev.is_fully_open or elev._is_door_opening:
                open_time = (1 - elev._door_open_rate) / elev._door_open_velocity
                lag_time = elev._keep_door_open_left
                close_time = 1 / elev._door_open_velocity
                return open_time + lag_time + close_time
            elif elev._is_door_closing:
                close_time = (1 - elev._door_open_rate) / elev._door_open_velocity
                return close_time
            return 7.3
        if df == 1:
            return 12
        elif df == 2:
            return 13.5
        else:
            return 13.5 + 1.2 * (df - 2)

    def cal_stop_floor(self, x, v, max_a, floor_height):
        """Return the nearest floor index at which the elevator can still stop.

        The braking distance ``v**2 / (2 * max_a)`` is added in the direction of
        travel; the result is then snapped to the next floor in that direction
        (up: next floor at or above, down: next floor at or below). A tolerance
        of 0.001 lets a car that is practically level with a floor count as
        being at that floor.

        Args:
            x: Current vertical position (same unit as ``floor_height``).
            v: Signed velocity; positive means moving up.
            max_a: Maximum deceleration (must be non-zero).
            floor_height: Height of one floor.

        Returns:
            Floor index as ``int``, never below 0.
        """
        tolerance = 0.001
        least_run_dis = v * v / (2 * max_a)
        if v > 0:
            # Round up: first floor at or above the earliest stop position.
            snapped = x + least_run_dis + (floor_height - tolerance)
        elif v < 0:
            # Round down: first floor at or below the earliest stop position.
            # Subtracting (floor_height - tolerance) here would skip one floor
            # because floor division already rounds towards the lower floor.
            snapped = x - least_run_dis + tolerance
        else:
            snapped = x
        return int(max(snapped, 0) // floor_height)

    # 计算电梯第一次停下来所需时间，因为可能有初速度，所以要特殊处理
    def cal_first_stop_time(self, elev, first_stop_flr, state):
        # self.elevator_mask[elev.elev_idx] = 1
        first_stop_pos = first_stop_flr * elev._floor_height
        # 如果第一落点是纯hallcall且满载了，那么这个电梯不能选
        if (state == 2 or state == 4) and elev._is_overloaded:
            # self.elevator_mask[elev.elev_idx] = 0
            return 10000
        cur_spd = elev._current_velocity
        df = abs(elev._current_position - first_stop_pos) / elev._floor_height
        df = int(round(df))
        consume_time = max(self.df2time(df, elev) - cur_spd, 0)  # 观察经验公式
        return consume_time

    # 根据电梯的初速度，已分配的carcall和hallcall，可以把电梯接下来的运动分为三段：
    # 如电梯现在正在5楼往上行，那么
    # r1: 5楼到16楼之间的up call，以及car call。
    # r2: 16楼到1楼之间的dn call。
    # r3: 1楼到5楼之间的up call。
    # 按常理出牌的话，car call只会存在于r1。
    # 这里只记录电梯要停靠的楼层的位置以及state（便于结合权重）。
    # 在上面那个例子中，加入up call为[1,14], dn call为[16,7], car call为[8]的话，route就为：
    # [8, 14, 16, 7, 1]，这样就可以算出df来计算运行时间了。
    def get_elev_route(self, elev, srv_dir, stp_flr, cur_pos, car_call, hall_up_dn_call):
        route = []
        # 正常来说，carcall只在r1
        f_1 = (srv_dir + 1) * self.floor_num // 2  # f when 1, 0 when -1
        f_m1 = (-srv_dir + 1) * self.floor_num // 2  # f when -1, 0 when 1
        one_1 = (srv_dir + 1) * 1 // 2  # 1 when 1, 0 when -1
        one_m1 = (-srv_dir + 1) * 1 // 2  # 1 when -1, 0 when 1
        # srv_dir=1 010, srv_dir=-1 101
        rparam = [(stp_flr, f_1, srv_dir, one_m1), (f_1, f_m1, -srv_dir, one_1), (f_m1, stp_flr, srv_dir, one_m1)]
        for rnum, rp in enumerate(rparam):
            for i in range(rp[0], rp[1], rp[2]):
                # state: 8 park call; 1 car, 2 up, 3 car and up, 4 dn, 5 car and dn.
                state = 0
                if rnum == 0 and i in car_call:
                    state += 1
                if i in hall_up_dn_call[rp[3]]:
                    state += 2 * (1 + rp[3])
                if state != 0:
                    route.append((i, state))
            if rnum == 1:
                # 电梯运行方向上必须要有一个目的地，如果没有call，就是被重新分配搞的，要手动加一个停靠位置
                if route == [] or (route[0][0] * elev._floor_height - cur_pos) * srv_dir < 0:
                    route.insert(0, (stp_flr, 8))
        return route

    # 计算loss时的权重
    def floor_state2weight(self, floor, state, hallcall_weight):
        """Look up the loss weight of a stop from its state code.

        ``state`` is a sum of 1 (car call), 2 (up hall call) and 4 (down hall
        call). Up-call weights are stored at index ``floor``, down-call weights
        at ``floor + self.floor_num``. When both hall bits are set the
        down-call weight wins; stops without a hall call weigh 0.
        """
        has_up_call = state // 2 % 2 == 1
        has_dn_call = state // 4 % 2 == 1
        weight = hallcall_weight[floor] if has_up_call else 0
        if has_dn_call:
            weight = hallcall_weight[floor + self.floor_num]
        return weight

    # 不用模拟，用两个delta floor的距离来近似计算，计算公式由实验得出。
    def estimate_elev_route_loss(self, elev, hallcall_weight, hallcall=None):
        copy_elev = deepcopy(elev)
        if hallcall:
            copy_elev.replace_hall_call(hallcall)
        cur_flr = copy_elev._sync_floor
        cur_pos = copy_elev._current_position
        cur_spd = copy_elev._current_velocity
        srv_dir = copy_elev._service_direction
        car_call = copy_elev._car_call
        hall_up_dn_call = [copy_elev._hall_up_call, copy_elev._hall_dn_call]
        stp_flr = self.cal_stop_floor(cur_pos, cur_spd, 0.557, 3.0)

        # 如果电梯之前是空闲的，可能分配了hallcall之后srv_dir也是0没来得及更新，先运行一个dt给他更新一下。
        if srv_dir == 0:
            if hall_up_dn_call[0] + hall_up_dn_call[1] == []:
                return 0
            else:
                copy_elev.run_elevator()
                srv_dir = copy_elev._service_direction
                # print(hall_up_dn_call, cur_flr, cur_spd, srv_dir)

        route = self.get_elev_route(copy_elev, srv_dir, stp_flr, cur_pos, car_call, hall_up_dn_call)

        # 从cur_flr以cur_spd
        # 从cur_pos以cur_spd开始完成电梯的旅程route，特殊处理第一次停靠。
        loss = 0
        accumulate_time = 0
        # 加入floor_weights, TODO: carcall可能应该用前一时间片的权重呢。
        # route肯定不为空
        assert len(route) > 0
        first_stop_flr = route[0][0]

        # 第一次停靠因为可能有初速度，需要特殊处理，还要处理超载的问题
        consume_time = self.cal_first_stop_time(copy_elev, first_stop_flr, route[0][1])
        accumulate_time += consume_time
        loss += accumulate_time * self.floor_state2weight(first_stop_flr, route[0][1], hallcall_weight)

        # 其他段路可以直接用实验公式计算。
        last_flr = first_stop_flr
        for stop_flr in route[1:]:
            df = abs(stop_flr[0] - last_flr)
            consume_time = self.df2time(df, copy_elev)
            accumulate_time += consume_time
            loss += accumulate_time * self.floor_state2weight(stop_flr[0], stop_flr[1], hallcall_weight)
            last_flr = stop_flr[0]
        # maximum_weight = copy_elev._maximum_capacity
        # load_weight = copy_elev._load_weight
        # loss += load_weight / (maximum_weight * 0.8) * 100
        # print(route, loss)
        return loss


class SmecRLEnv(gym.Env):
    """
    RL environment for SMEC elevators.
    """
    def __init__(self, data_file='./smec_rl/simple_dataset_v2.csv', config_file=None, render=True, forbid_unrequired=True, seed=None, forbid_uncalled=False,
    # def __init__(self, data_file='train_data/new/lunchpeak/LunchPeak1_elvx.csv', config_file=None, render=True, forbid_unrequired=True, seed=None, forbid_uncalled=False,
                 use_graph=True, gamma=0.99, real_data=True, use_advice=False, special_reward=False, data_dir=None, file_begin_idx=None, dos=''):
        """Build the simulator, the graph builder and the gym spaces.

        Args:
            data_file: Not used by the active code (only by a commented-out
                ``FixedDataGenerator`` call).
            config_file: Path of the ini file; defaults to ``rl_config2.ini``
                next to this module.
            render: When true a ``Render`` viewer is created and ``self.open_render``
                is set.
            forbid_unrequired: Stored on ``self.forbid_unrequired``.
            seed: When not ``None`` it is passed to ``self.seed``; also kept in
                ``self.seed_c``.
            forbid_uncalled: Stored on ``self.forbid_uncalled``.
            use_graph: Must be truthy (asserted below).
            gamma: Discount factor; stored and passed to ``RewardFilter``.
            real_data: Not referenced in this constructor.
            use_advice: When true the candidate dimension of ``legal_masks``
                gets one extra slot.
            special_reward: Stored on ``self.special_reward``.
            data_dir: Forwarded to ``RandomDataGenerator``.
            file_begin_idx: Not used by the active code.
            dos: Section string of the form ``'A:B-C:D'``; the start offset is
                computed as ``A * 60 + B``. Empty string means offset 0.
        """

        self.id = "Liftsim"

        if not config_file:
            config_file = os.path.join(os.path.dirname(__file__) + '/rl_config2.ini')
        file_name = config_file
        self.forbid_uncalled = forbid_uncalled
        config = configparser.ConfigParser()
        # NOTE(review): ConfigParser.read() silently ignores a missing file; the
        # lookups below would then fail with KeyError.
        config.read(file_name)

        time_step = float(config['Configuration']['RunningTimeStep'])
        assert time_step <= 1, 'RunningTimeStep in config.ini must be less than 1 in order to ensure accuracy.'
        # dos = '06:00-12:00'
        # dos = '00:00-06:00'
        # dos = '50:00-60:00'
        # dos = ''
        # Derive the start offset from the first half of the section string.
        if dos == '':
            st = 0
        else:
            ts = dos.split('-')[0].split(':')
            st = int(ts[0]) * 60 + int(ts[1])

        # person_generator = FixedDataGenerator(data_file=data_file, data_dir=data_dir, file_begin_idx=file_begin_idx, data_of_section=dos)
        person_generator = RandomDataGenerator(data_dir=data_dir, data_of_section=dos)
        # person_generator = RandomDataGenerator(data_dir=data_dir, data_of_section=dos, random_or_load_or_save=3)
        # person_generator = RandomDataGenerator(data_dir=data_dir, data_of_section=dos, random_or_load_or_save=1, save_idx=60)

        # Physical parameters of the building are all read from the ini file.
        self._config = MansionConfig(
            dt=time_step,
            number_of_floors=int(config['MansionInfo']['NumberOfFloors']),
            floor_height=float(config['MansionInfo']['FloorHeight']),
            maximum_acceleration=float(config['MansionInfo']['Acceleration']),
            maximum_speed=float(config['MansionInfo']['RateSpeed']),
            person_entering_time=float(config['MansionInfo']['PersonEnterTime']),
            door_opening_time=float(config['MansionInfo']['DoorOpeningTime']),
            door_closing_time=float(config['MansionInfo']['DoorClosingTime']),
            keep_door_open_lag=float(config['MansionInfo']['DoorKeepOpenLagTime']),
            door_well_time2=float(config['MansionInfo']['DwellTime2']),
            maximum_parallel_entering_exiting_number=int(config['MansionInfo']['ParallelEnterNum']),
            rated_load=int(config['MansionInfo']['RateLoad']),
            start_time=st
        )

        self.mansion = MansionManager(int(config['MansionInfo']['ElevatorNumber']), person_generator, self._config,
                                      config['MansionInfo']['Name'])
        self.use_graph = use_graph
        self.viewer = None
        self.open_render = render
        if render:
        # if True:
            # Imported lazily so the rendering module is only needed when rendering.
            from smec_liftsim.rendering import Render
            self.viewer = Render(self.mansion)
        self.elevator_num = self.mansion.attribute.ElevatorNumber
        self.floor_num = int(config['MansionInfo']['NumberOfFloors'])
        self.waiting_times = []
        self.forbid_unrequired = forbid_unrequired

        if seed is not None:
            self.seed(seed)
        # Remembered so reset() can derive a new seed for every episode.
        self.seed_c = seed

        # gym specific settings
        self.metadata = {'render.modes': []}
        self.gamma = gamma
        self.reward_range = (-float('inf'), float('inf'))
        self.reward_threshold = 0.0
        self.trials = 100
        self.spec = None
        self._max_episode_steps = 10000

        # Set these in ALL subclasses
        # self.action_space = Box(low=0, high=self.floor_num * 2 - 1, shape=(self.elevator_num,), dtype=np.int64)
        self.action_space = Discrete(self.elevator_num)

        # NOTE(review): ele_f is not used anywhere below.
        ele_f = (self.elevator_num, self.floor_num)
        self.graph_node_num = (self.elevator_num + self.floor_num) * 2
        self.gb = GraphBuilder(self.elevator_num, self.floor_num, 'cpu')
        # Keep pristine copies so reset() can restore the graph state.
        self.empty_adj_matrix = self.gb.get_zero_adj_matrix()
        self.cur_adj_matrix = self.empty_adj_matrix.clone()
        self.empty_node_feature = self.gb.get_zero_node_feature()
        self.cur_node_feature = self.empty_node_feature.clone()
        assert self.use_graph

        self.use_advice = use_advice
        self.special_reward = special_reward
        # With advice enabled one extra candidate slot is appended.
        candidate_num = self.elevator_num
        if use_advice:
            candidate_num += 1

        self.observation_space = Dict(
            {'adj_m': Box(low=-float('inf'), high=float('inf'), shape=(self.graph_node_num, self.graph_node_num)),
             'node_feature_m': Box(low=-float('inf'), high=float('inf'), shape=(self.graph_node_num, 3)),
             'legal_masks': Box(low=-float('inf'), high=float('inf'), shape=(self.floor_num * 2, candidate_num,)),
             'elevator_mask': Box(low=-float('inf'), high=float('inf'), shape=(self.elevator_num, self.floor_num * 2,)),
             'floor_mask': Box(low=-float('inf'), high=float('inf'), shape=(self.floor_num * 2,)),
             'distances': Box(low=-float('inf'), high=float('inf'), shape=(self.floor_num * 2, self.elevator_num,)),
             'valid_action_mask': Box(low=0, high=1, shape=(self.floor_num * 2,))
             })

        # reward normalization
        self.reward_filter = Identity()
        self.reward_filter = RewardFilter(self.reward_filter, shape=(), gamma=gamma, clip=None)

        # state normalization
        self.state_filter = Identity()
        self.state_filter = ZFilter(self.state_filter, shape=[self.graph_node_num, 3], clip=None)

        # self.real_dataset = generate_dataset()
        # self.data_idx = 0
        # self.next_generate_person = self.real_dataset[self.data_idx]
        # Counters updated by the step methods for every dispatched hall call.
        self.evaluate_info = {'valid_up_action': 0,
                              'advice_up_action': 0,
                              'valid_dn_action': 0,
                              'advice_dn_action': 0}
        self.estimater = TimeEstimater()

    @staticmethod
    def get_filter_by_list(list_len, query):
        """Return a float tensor of length ``list_len`` that is 1.0 at every index in ``query`` and 0.0 elsewhere."""
        mask = torch.tensor([0.0] * list_len)
        for position in query:
            mask[position] = 1.0
        return mask

    def get_action_mask(self, device):
        """Build the per-call mask of elevators that may be chosen.

        One row is produced per (floor, direction): first all up calls by floor,
        then all down calls by floor. A row is all ones unless the call is
        currently unallocated and the mansion reports at least one convenient
        elevator; in that case only those elevators (plus the advice slot when
        ``use_advice`` is set) are 1.

        Returns:
            Tensor of shape ``(2 * floor_num, candidate_num)`` on ``device``.
        """
        # With advice enabled there is one extra candidate slot at the end.
        candidate_num = self.elevator_num + 1 if self.use_advice else self.elevator_num

        unallocated_up, unallocated_dn = self.mansion.get_unallocated_floors()
        rows = []
        for going_up, pending in ((True, unallocated_up), (False, unallocated_dn)):
            for floor in range(self.floor_num):
                convenient = []
                if floor in pending:
                    convenient = self.mansion.get_convenience_elevators(up_or_down=going_up, floor_id=floor)
                if len(convenient) > 0:
                    row = self.get_filter_by_list(candidate_num, convenient)
                    if self.use_advice:
                        row[-1] = 1.0
                else:
                    # Nothing to restrict: every candidate is allowed.
                    row = torch.tensor([1.0] * candidate_num)
                rows.append(row)

        return torch.stack(rows).to(device)

    def get_action_mask_plus(self, device):
        """Return the convenience mask of the elevators and a per-call floor weight.

        The floor weight of an up call at floor ``f`` is the sum of the
        generator's ``prob`` row ``f`` from column ``f`` onwards, the weight of a
        down call is the sum of the columns below ``f``. For currently
        unallocated calls the number of waiting persons is added on top.

        Returns:
            ``(elevator_mask, floor_mask)``: ``elevator_mask`` is the mansion's
            convenience mask moved to ``device``; ``floor_mask`` has length
            ``2 * floor_num``.
            NOTE(review): ``floor_mask`` is not moved to ``device`` — confirm
            that this is intended.
        """
        pending_up, pending_dn = self.mansion.get_unallocated_floors()

        # One prob matrix per minute; indices beyond 59 reuse the last one.
        minute_idx = min(int(self._config.raw_time)//60, 59)
        prob = self.mansion._person_generator.prob[minute_idx]
        weights = np.zeros(self.floor_num*2)
        for src in range(self.floor_num):
            row = prob[src]
            dn_sum = np.sum(row[:src])
            up_sum = np.sum(row[src:])
            weights[src] = up_sum
            weights[src+self.floor_num] = dn_sum
        floor_mask = torch.from_numpy(weights)

        # Add the queue length of every call that still needs an elevator.
        for flr in pending_up:
            floor_mask[flr] += len(self.mansion._wait_upward_persons_queue[flr])
        for flr in pending_dn:
            floor_mask[flr+self.floor_num] += len(self.mansion._wait_downward_persons_queue[flr])

        # The elevator mask ignores the generation probability and only
        # reflects which elevators are currently convenient.
        elevator_mask = torch.from_numpy(self.mansion.get_convenience_mask()).to(device)
        return elevator_mask, floor_mask

    def get_time(self):
        """Split the raw simulation time into ``[day, hour, minute, second]``.

        The hour is shifted by 7 (presumably the simulated day starts at
        07:00 — TODO confirm).
        """
        seconds_per_day = 24 * 3600
        elapsed = self._config.raw_time
        second_of_day = elapsed % seconds_per_day
        hour = int(second_of_day // 3600 + 7)
        minute = int(second_of_day % 3600 // 60)
        second = int(second_of_day % 60)
        return [elapsed // seconds_per_day, hour, minute, second]

    def step(self, actions):
        """Gym ``step`` entry point; delegates to ``step_rl_dp``."""
        return self.step_rl_dp(actions)

    def step_rl_dp(self, actions):
        """Dispatch the pending hall calls and run until the next one appears.

        ``actions`` is either a numpy array holding one elevator index per
        (floor, direction) slot, or a tensor whose first ``2 * floor_num``
        entries are those indices and whose remaining entries are forwarded to
        the simulator as ``advantage_floor``. Up calls use slot ``floor``, down
        calls use slot ``floor + floor_num``.

        The simulator is then ticked until a new unallocated hall call exists or
        the episode is done. On every tick a reward based on the estimated
        route loss of all elevators is accumulated; the sum is divided by the
        number of ticks.

        NOTE: the advice option (``use_advice``) is not applied when mapping
        actions to elevators.

        Returns:
            ``(observation, reward, done, info)`` where ``reward`` is an array
            of length ``2 * floor_num``. ``info['awt']`` is ``None`` when no
            tick was executed (episode already done on entry).
        """
        slot_num = self.floor_num * 2
        if isinstance(actions, np.ndarray):
            floor2elevators = torch.from_numpy(actions)
            advantage_floor = None
        else:
            # Split at the number of call slots instead of a hard-coded 32.
            floor2elevators, advantage_floor = actions.split(slot_num, 0)
            assert isinstance(floor2elevators, torch.Tensor), "only support tensor action"  # unwrapped raw action.
            floor2elevators = floor2elevators.squeeze()
            advantage_floor = advantage_floor.squeeze()

        # Group the pending calls by the elevator the action assigns them to.
        unallocated_up, unallocated_dn = self.mansion.get_unallocated_floors()
        all_elv_up_fs = [[] for _ in range(self.elevator_num)]
        all_elv_down_fs = [[] for _ in range(self.elevator_num)]
        for up_floor in unallocated_up:
            self.evaluate_info['valid_up_action'] += 1
            cur_elev = int(floor2elevators[up_floor].item())
            all_elv_up_fs[cur_elev].append(up_floor)
        for dn_floor in unallocated_dn:
            self.evaluate_info['valid_dn_action'] += 1
            cur_elev = int(floor2elevators[dn_floor + self.floor_num].item())
            all_elv_down_fs[cur_elev].append(dn_floor)
        action_to_execute = [ElevatorHallCall(ups, dns)
                             for ups, dns in zip(all_elv_up_fs, all_elv_down_fs)]

        # Tick the simulator until the next hall call needs a decision.
        next_call_come = False
        cur_time = self._config.raw_time
        reward = np.zeros((slot_num, ))
        arrive_wts = [[] for _ in range(self.elevator_num)]
        total_energy = 0
        reward_list_for_eval = []
        awt = None  # stays None when the loop body never runs
        while not next_call_come and not self.mansion.is_done:
            (_calling_wt, arrive_wt, _loaded_num, _enter_num, _no_io_masks, awt,
             hall_waiting_rewards, _car_waiting_rewards, energy, _deliver_person_num) \
                = self.mansion.run_mansion(action_to_execute, special_reward=True, advantage_floor=advantage_floor)
            for i in range(self.elevator_num):
                arrive_wts[i] += arrive_wt[i]
            self.mansion.generate_person()
            if self.open_render:
                self.render()
            unallocated_up, unallocated_dn = self.mansion.get_unallocated_floors()
            # The assignment is only sent once; later ticks carry empty calls.
            action_to_execute = [ElevatorHallCall([], []) for _ in range(self.elevator_num)]
            next_call_come = unallocated_up != [] or unallocated_dn != []

            # Reward: negative estimated average waiting time, i.e. the summed
            # route loss of all elevators normalised by the hall waiting reward.
            hwr = sum(hall_waiting_rewards)
            total_loss = 0
            for elev in self.mansion._elevators:
                total_loss += self.estimater.estimate_elev_route_loss(elev, hall_waiting_rewards)
            estimate_awt = total_loss / (hwr + 1)
            r6 = -0.01 * np.array([estimate_awt for _ in range(slot_num)])
            reward += r6
            reward_list_for_eval.append(sum(r6))
            total_energy += energy

        # Average the accumulated reward over the elapsed ticks; skip the
        # division when no time has passed to avoid dividing by zero.
        delta_t = self._config.raw_time - cur_time
        timestep = delta_t / self._config._delta_t
        if timestep > 0:
            reward = reward / timestep
        info = {'waiting_time': concate_list(arrive_wts), 'sum_wait_rew': 0, 'sum_io_rew': 0,
                'sum_enter_rew': 0, 'awt': awt, 'total_energy': total_energy, 'reward_list_for_eval': reward_list_for_eval}
        new_obs = self.get_smec_state()
        self.mansion.generate_person()
        done = self.mansion.is_done

        return new_obs, reward, done, info
        # return new_obs, new_reward, done, info

    # 不计算奖励的版本，加快速度
    def step_eval(self, actions):
        """Faster step for evaluation: same dispatching, no reward computation.

        ``actions`` has the same layout as for ``step_rl_dp``: a numpy array of
        elevator indices per (floor, direction) slot, or a tensor whose first
        ``2 * floor_num`` entries are those indices and whose remaining entries
        are forwarded as ``advantage_floor``.

        Returns:
            ``(observation, reward, done, info)``; ``reward`` is always an
            all-zero array of length ``2 * floor_num`` and
            ``info['reward_list_for_eval']`` is always empty. ``info['awt']`` is
            ``None`` when no tick was executed (episode already done on entry).
        """
        slot_num = self.floor_num * 2
        if isinstance(actions, np.ndarray):
            floor2elevators = torch.from_numpy(actions)
            advantage_floor = None
        else:
            # Split at the number of call slots instead of a hard-coded 32.
            floor2elevators, advantage_floor = actions.split(slot_num, 0)
            assert isinstance(floor2elevators, torch.Tensor), "only support tensor action"  # unwrapped raw action.
            floor2elevators = floor2elevators.squeeze()
            advantage_floor = advantage_floor.squeeze()

        # Group the pending calls by the elevator the action assigns them to.
        unallocated_up, unallocated_dn = self.mansion.get_unallocated_floors()
        all_elv_up_fs = [[] for _ in range(self.elevator_num)]
        all_elv_down_fs = [[] for _ in range(self.elevator_num)]
        for up_floor in unallocated_up:
            self.evaluate_info['valid_up_action'] += 1
            cur_elev = int(floor2elevators[up_floor].item())
            all_elv_up_fs[cur_elev].append(up_floor)
        for dn_floor in unallocated_dn:
            self.evaluate_info['valid_dn_action'] += 1
            cur_elev = int(floor2elevators[dn_floor + self.floor_num].item())
            all_elv_down_fs[cur_elev].append(dn_floor)
        action_to_execute = [ElevatorHallCall(ups, dns)
                             for ups, dns in zip(all_elv_up_fs, all_elv_down_fs)]

        # Tick the simulator until the next hall call needs a decision.
        next_call_come = False
        reward = np.zeros((slot_num,))
        arrive_wts = [[] for _ in range(self.elevator_num)]
        total_energy = 0
        reward_list_for_eval = []
        awt = None  # stays None when the loop body never runs
        while not next_call_come and not self.mansion.is_done:
            (_calling_wt, arrive_wt, _loaded_num, _enter_num, _no_io_masks, awt,
             _hall_waiting_rewards, _car_waiting_rewards, energy, _deliver_person_num) \
                = self.mansion.run_mansion(action_to_execute, special_reward=True, advantage_floor=advantage_floor)
            for i in range(self.elevator_num):
                arrive_wts[i] += arrive_wt[i]
            self.mansion.generate_person()
            if self.open_render:
                self.render()
            unallocated_up, unallocated_dn = self.mansion.get_unallocated_floors()
            # The assignment is only sent once; later ticks carry empty calls.
            action_to_execute = [ElevatorHallCall([], []) for _ in range(self.elevator_num)]
            next_call_come = unallocated_up != [] or unallocated_dn != []

            total_energy += energy
        info = {'waiting_time': concate_list(arrive_wts), 'sum_wait_rew': 0, 'sum_io_rew': 0,
                'sum_enter_rew': 0, 'awt': awt, 'total_energy': total_energy,
                'reward_list_for_eval': reward_list_for_eval}
        new_obs = self.get_smec_state()
        self.mansion.generate_person()
        done = self.mansion.is_done

        return new_obs, reward, done, info

    def get_floor2elevator_dis(self, device):
        """Compute a normalised travel distance from every elevator to every hall call.

        Rows are ordered as all up calls by floor followed by all down calls by
        floor; columns follow ``self.mansion._elevators``. The distance assumes
        the elevator keeps its service direction, travels to the end of the
        shaft and turns around when the call is not directly on its way. An
        elevator whose service direction is not 1 is treated as going down.
        Distances are divided by ``floor_num``.

        Returns:
            Tensor of shape ``(2 * floor_num, number of elevators)`` on ``device``.
        """
        floors = self.floor_num

        def effective_floor(elev, call_floor):
            # An elevator that is leaving the call floor (door closing or
            # running) is nudged past it so the call counts as already missed.
            position = elev._sync_floor
            if call_floor == position and \
                    (elev._run_state == ELEVATOR_STOP_DOOR_CLOSING or elev._run_state == ELEVATOR_RUN):
                if elev._service_direction == 1:
                    position += 0.01
                elif elev._service_direction == -1:
                    position -= 0.01
            return position

        table = []

        # up calls
        for call_floor in range(floors):
            row = []
            for elev in self.mansion._elevators:
                elevator_floor = effective_floor(elev, call_floor)
                if elev._service_direction != 1:
                    # down to the bottom, then up to the call
                    distance = elevator_floor + call_floor
                elif call_floor >= elevator_floor:
                    # call is straight ahead
                    distance = call_floor - elevator_floor
                else:
                    # up to the top, down to the bottom, up to the call
                    distance = (floors - elevator_floor) + floors + call_floor
                row.append(distance / floors)
            table.append(row)

        # down calls
        for call_floor in range(floors):
            row = []
            for elev in self.mansion._elevators:
                elevator_floor = effective_floor(elev, call_floor)
                if elev._service_direction == 1:
                    # up to the top, then down to the call
                    distance = (floors - elevator_floor) + (floors - call_floor)
                elif call_floor <= elevator_floor:
                    # call is straight ahead
                    distance = elevator_floor - call_floor
                else:
                    # down to the bottom, up to the top, down to the call
                    distance = elevator_floor + floors + (
                                floors - call_floor)
                row.append(distance / floors)
            table.append(row)

        return torch.tensor(table).to(device)

    # no attention mask, pure convenience mask
    def get_smec_state(self):
        """Assemble the observation dictionary from the current mansion state.

        Updates ``self.cur_adj_matrix`` and ``self.cur_node_feature`` through
        the graph builder and returns them together with the action masks, the
        floor-to-elevator distances and the mask of unallocated calls.
        """
        rl_state = self.mansion.get_rl_state(encode=True)
        up_wait, down_wait, loading, location, up_call, down_call, load_up, load_down = rl_state
        up_wait = torch.tensor(up_wait)
        down_wait = torch.tensor(down_wait)
        loading = torch.tensor(loading)
        location = torch.tensor(location)
        # All derived tensors are placed on the device of up_wait.
        device = up_wait.device

        self.cur_adj_matrix = self.gb.update_adj_matrix(self.cur_adj_matrix, up_call, down_call)
        self.cur_node_feature = self.gb.update_node_feature(self.cur_node_feature, up_wait, down_wait, load_up,
                                                            load_down, location)

        observation = {'adj_m': self.cur_adj_matrix, 'node_feature_m': self.cur_node_feature}
        distances = self.get_floor2elevator_dis(device)
        valid_action_mask = torch.tensor(self.mansion.get_unallocated_floors_mask()).to(device)
        observation['legal_masks'] = self.get_action_mask(device)
        elevator_mask, floor_mask = self.get_action_mask_plus(device)
        observation['elevator_mask'] = elevator_mask
        observation['floor_mask'] = floor_mask
        observation['distances'] = distances
        observation['valid_action_mask'] = valid_action_mask
        return observation

    def seed(self, seed=None):
        """Forward ``seed`` to ``set_seed`` of the data-generator proxy module."""
        set_seed(seed)

    def reset(self):
        """Reset the simulator and the graph state and return the first observation.

        When the environment was created with a seed, the stored seed is
        advanced by 100 and applied again so that every episode uses a
        different, reproducible seed.
        """
        self.mansion.reset_env()
        self.cur_node_feature = self.empty_node_feature.clone()
        self.cur_adj_matrix = self.empty_adj_matrix.clone()
        state = self.get_smec_state()
        # Compare against None (as __init__ does): a plain truth test would
        # treat the valid seed 0 as "no seed" and never reseed.
        if self.seed_c is not None:
            self.seed_c += 100
            self.seed(self.seed_c)
        self.reward_filter.reset()
        self.state_filter.reset()
        return state

    def render(self, **kwargs):
        """Draw the current state with the viewer.

        Does nothing when the environment was created with ``render=False``,
        because no viewer exists in that case.
        """
        if self.viewer is None:
            return
        self.viewer.view()

    def close(self):
        """No-op; nothing is released here (the viewer, if any, is left untouched)."""
        pass

    @property
    def attribute(self):
        """Return ``self.mansion.attribute``."""
        return self.mansion.attribute

    @property
    def state(self):
        """Return ``self.mansion.state``."""
        return self.mansion.state

    @property
    def statistics(self):
        """Return the result of ``self.mansion.get_statistics()``."""
        return self.mansion.get_statistics()

    @property
    def log_debug(self):
        """Return ``self._config.log_notice``.

        NOTE(review): despite its name this property returns ``log_notice`` —
        confirm whether ``self._config.log_debug`` was intended.
        """
        return self._config.log_notice

    @property
    def log_notice(self):
        """Return ``self._config.log_notice``."""
        return self._config.log_notice

    @property
    def log_warning(self):
        """Return ``self._config.log_warning``."""
        return self._config.log_warning

    @property
    def log_fatal(self):
        """Return ``self._config.log_fatal``."""
        return self._config.log_fatal


# NOTE(review): not referenced anywhere else in the visible part of this file.
iteration = 1000


def uniform_dispatch(up_floors, down_floors, elev_num, random_one=True):
    """Spread the pending hall calls evenly over the elevators.

    The calls (all up calls followed by all down calls) are cut into
    ``elev_num`` consecutive chunks of equal size; the last elevator also
    receives the remainder. With ``random_one`` set, each elevator only gets
    one randomly chosen call from its chunk.

    Returns:
        A list with one ``ElevatorHallCall`` per elevator.
    """
    calls = [('up', flr) for flr in up_floors] + [('down', flr) for flr in down_floors]
    chunk = len(calls) // elev_num
    last_id = elev_num - 1

    assignment = []
    for ele_id in range(elev_num):
        begin = ele_id * chunk
        # The last elevator takes everything that is left.
        end = None if ele_id == last_id else begin + chunk
        share = calls[begin:end]
        if random_one and share:
            share = [random.choice(share)]
        ups = [flr for direction, flr in share if direction == 'up']
        downs = [flr for direction, flr in share if direction != 'up']
        assignment.append(ElevatorHallCall(ups, downs))
    return assignment


def identity_dispatch(up_floors, down_floors, elev_num):
    """Give every elevator the same hall-call object containing all calls.

    Note that the returned list holds ``elev_num`` references to one single
    ``ElevatorHallCall`` instance.
    """
    shared_call = ElevatorHallCall(up_floors, down_floors)
    return [shared_call] * elev_num


def make_env(seed=0, render=False, forbid_uncalled=False, use_graph=True, gamma=0.99, real_data=True,
             use_advice=False, special_reward=False, data_dir=None, file_begin_idx=None, dos=''):
    """Return a zero-argument factory that builds a ``SmecRLEnv`` with the given options.

    The environment is only constructed when the returned callable is invoked,
    which is the form expected by vectorised environment wrappers.
    """
    env_kwargs = dict(
        render=render,
        seed=seed,
        forbid_uncalled=forbid_uncalled,
        use_graph=use_graph,
        gamma=gamma,
        real_data=real_data,
        use_advice=use_advice,
        special_reward=special_reward,
        data_dir=data_dir,
        file_begin_idx=file_begin_idx,
        dos=dos,
    )

    def _thunk():
        return SmecRLEnv(**env_kwargs)

    return _thunk


def test_multi_env(num_processes):
    """Smoke test: build ``num_processes`` environments in an ``AsyncVectorEnv``, reset them and take one step."""
    env_fns = [make_env(seed=rank) for rank in range(num_processes)]
    vec_env = AsyncVectorEnv(env_fns=env_fns)
    vec_env.reset()
    # Every environment receives the same constant action.
    batch_action = [torch.tensor([1] * 6) for _ in range(num_processes)]
    vec_env.step(batch_action)
    return


if __name__ == '__main__':
    # Manual smoke run: print the persons produced by the generator and step
    # the environment with a fixed action.
    import numpy as np  # local import so the script does not rely on wildcard imports

    # test_multi_env(8)
    eval_env = make_env(seed=0, render=False,
                        real_data=True, data_dir='../train_data/new/lunchpeak')()

    # step() requires an action: one elevator index per (floor, direction)
    # slot. Send every pending hall call to elevator 0.
    dummy_action = np.zeros(eval_env.floor_num * 2, dtype=np.int64)

    for t in range(360):

        a = eval_env.mansion._person_generator.generate_person()
        print(t, a)
        _, _, done, _ = eval_env.step(dummy_action)
        if done:
            break
