'''
Author: 
Email: 
Date: 2021-07-20 01:11:54
LastEditTime: 2022-04-07 16:38:25
Description: 
'''
import numpy as np
import copy
import os, glob 
import gym


from .vehicle_dynamics import Bicycle
from .utils import make_gif


# Per-agent templates consumed by the environments below. Sensor angles are in
# radians in the agent frame (they are rotated by the agent heading when the rays
# are cast); 'name' is the agent-type key the environments dispatch on (note the
# historical spelling 'pedestrain', which is used as-is everywhere).
ego_param = dict(
    name='ego',
    shape=[18, 40],  # [width, length]
    color=[132, 94, 194],  # purple
    n_sensor=15,  # rays spread evenly over sensor_theta_boundary
    sensor_theta_boundary=[-np.pi / 6, np.pi / 6],
    sensor_max=100.0,  # ray length
)

car_param_1 = dict(
    name='block_vehicle',
    shape=[18, 40],  # [width, length]
    color=[255, 111, 145],  # pink red
    n_sensor=15,
    sensor_theta_boundary=[-np.pi / 6, np.pi / 6],
    sensor_max=100.0,
)

# Sensor-less vehicle: with n_sensor == 0 the theta boundary is never used.
car_param_2 = dict(
    name='other_vehicle',
    shape=[18, 40],  # [width, length]
    color=[255, 111, 145],  # pink red
    n_sensor=0,
    sensor_theta_boundary=[-0.001, 0.001],
    sensor_max=100.0,
)

# A single ray; linspace with one sample yields 2*pi - 0.0001, i.e. (almost
# exactly) straight ahead along the heading.
pedestrain_param = dict(
    name='pedestrain',
    shape=[5, 10],  # [width, length]
    color=[255, 199, 95],  # orange
    n_sensor=1,
    sensor_theta_boundary=[2 * np.pi - 0.0001, 2 * np.pi + 0.0001],
    sensor_max=50.0,
)


class CrashEnv(gym.Env):
    """Five-agent crossing scenario for generating ego/pedestrian crash episodes.

    Agents, in index order: pedestrian, block vehicle, two other vehicles and
    the ego vehicle. Every agent is simulated with the ``Bicycle`` model. The
    caller controls all agents except the ego vehicle, whose action is
    scripted in :meth:`step`: the lowest allowed acceleration (-1 before
    scaling) once one of its sensor rays reports the pedestrian, no input
    otherwise.

    An episode ends after ``max_step_num`` steps or as soon as a collision
    flag is raised (see :meth:`_update_sensor_vectorized`). The reward is 1
    only when the ego vehicle collided with the pedestrian and no other
    collision flag is set; it is 0 otherwise.

    Args:
        test_mode: 'IID' or 'OOD'; selects how the ego initial speed is
            perturbed in :meth:`_spawn_agents`.
        stage: 'train' or 'test'; only consulted when ``test_mode == 'OOD'``.
        use_render: stored on the instance; not read by this class.
        save_gif: if True, :meth:`render` saves each frame and :meth:`close`
            calls ``make_gif(filename)``.
    """

    def __init__(self, test_mode='IID', stage="train", use_render=False, save_gif=False):
        assert test_mode in ['IID', 'OOD']
        self.step_num = 0
        self.max_step_num = 30
        self.test_mode = test_mode
        self.stage = stage

        self.viewer_xy = (200, 200)
        # closed polyline around the map border (the last point repeats the first)
        self.win_coord = np.array([[0, 0], [self.viewer_xy[0], 0], [*self.viewer_xy], [0, self.viewer_xy[1]], [0, 0]])
        self.dt = 0.3

        self.save_gif = save_gif
        self.use_render = use_render
        self.renderer = None

        # episode outcome flags: raised by _update_sensor_vectorized(), cleared by reset()
        self.ego_collide_pedestrain = False
        self.collide_others = False
        self.block_collide_pedestrain = False
        self.ego_observe_pedestrain = False

        # start poses, presumably [x, y, heading in radians] as expected by Bicycle
        ego_startpoint = np.array([125, 30, np.pi/2])
        ped_startpoint = np.array([190, 140, np.pi])
        surround_startpoints = [
            [153, 50, np.pi/2],
            [45, 160, -np.pi/2],
            [80, 160, -np.pi/2],
        ]

        # Agent parameters, in the same order as start_point_list, speed_list and
        # action_scale. The ego vehicle must stay last because step() appends its
        # scripted action after the caller's. Every template is deep-copied since
        # _spawn_agents() writes 'start_point' and 'speed' into these dicts; using
        # the module-level dicts directly would share that state between instances.
        self.agent_list = [
            copy.deepcopy(pedestrain_param),
            copy.deepcopy(car_param_1),
            copy.deepcopy(car_param_2),
            copy.deepcopy(car_param_2),
            copy.deepcopy(ego_param),
        ]
        self.agent_mapping = {
            'pedestrain': 0,
            'block_vehicle': 1,
            'ego': 4
        }
        self.start_point_list = [
            ped_startpoint,
            surround_startpoints[0],
            surround_startpoints[1],
            surround_startpoints[2],
            ego_startpoint,
        ]
        self.speed_list = [
            10,     # pedestrain
            18,     # other vehicle 1
            10,     # other vehicle 2
            10,     # other vehicle 3
            18,     # ego vehicle
        ]

        self.max_velocity = 18
        self.n_agents = len(self.agent_list)
        self.agent_state_dim = 4   # [x, y, v, theta] per agent
        self.agent_action_dim = 2  # [acceleration, steering] per agent
        self.collision_dim = 2     # collision one-hot appended to the state
        self.action_dim = self.agent_action_dim*(self.n_agents-1)  # the ego vehicle is not controllable
        self.state_dim = self.agent_state_dim*self.n_agents + self.collision_dim

        # theta/pi is negative for agents heading -pi/2, so the lower bound must be
        # -1, not 0. Positions and speeds are normalised but not clipped, so values
        # can still leave [-1, 1] (e.g. an ego speed above max_velocity).
        self.observation_space = gym.spaces.Box(-1., 1., shape=(self.state_dim, ), dtype=np.float32)
        self.action_space = gym.spaces.Box(-1., 1., shape=(self.action_dim, ), dtype=np.float32)

        self.map_scale = self.viewer_xy
        self.action_scale = [
            [4, 1],   # pedestrain
            [15, 1],  # other vehicle 1
            [6, 1],   # other vehicle 2
            [6, 1],   # other vehicle 3
            [4, 1],   # ego vehicle
        ]

        self.action_range = [-1, 1]
        self.collision_threshold = 25

    def _truncated_normal_distribution(self, side='left', scale=2):
        """Draw one N(0, scale^2) sample folded onto one side of zero.

        'left' returns a non-positive value, 'right' a non-negative one; any
        other ``side`` returns the raw sample.
        """
        sample = np.random.randn(1)[0]*scale
        if side == 'left':
            if sample >= 0:
                sample = -sample
        elif side == 'right':
            if sample <= 0:
                sample = -sample
        return sample

    def _spawn_agents(self):
        """(Re)create every agent and its sensor buffer for a new episode.

        Agent ``i`` starts at ``start_point_list[i]`` with speed
        ``speed_list[i]`` plus an offset. The pedestrian, the block vehicle and
        the other vehicles get one N(0, 2^2) offset per type (both other
        vehicles share theirs). The ego offset is 0 in 'IID' mode; in 'OOD'
        mode it is a non-negative half-normal sample for stage 'train' and a
        non-positive one for stage 'test', so train and test ego offsets have
        opposite signs.

        Raises:
            ValueError: on an unknown ``test_mode``, an unknown ``stage`` in
                'OOD' mode, or an unknown agent name.
        """
        self.agents = []
        self.sensors = []

        if self.test_mode not in ('IID', 'OOD'):
            raise ValueError(f'Unknown test_mode: {self.test_mode!r}')
        if self.test_mode == 'OOD' and self.stage not in ('train', 'test'):
            raise ValueError(f'Unknown stage for OOD mode: {self.stage!r}')

        # Fixed draw order (pedestrian, block, other, then ego) keeps seeded runs
        # reproducible.
        delta_by_name = {}
        delta_by_name['pedestrain'] = np.random.randn(1)[0]*2
        delta_by_name['block_vehicle'] = np.random.randn(1)[0]*2
        delta_by_name['other_vehicle'] = np.random.randn(1)[0]*2
        if self.test_mode == 'IID':
            delta_by_name['ego'] = 0
        elif self.stage == 'train':
            delta_by_name['ego'] = self._truncated_normal_distribution('right', 2)
        else:
            delta_by_name['ego'] = self._truncated_normal_distribution('left', 2)

        for a_i in range(self.n_agents):
            agent_param = self.agent_list[a_i]
            agent_param['start_point'] = self.start_point_list[a_i]
            if agent_param['name'] not in delta_by_name:
                raise ValueError('Unknown agent type')
            agent_param['speed'] = self.speed_list[a_i] + delta_by_name[agent_param['name']]

            agent = Bicycle(start_point=np.array(agent_param['start_point']), v=agent_param['speed'], shape=agent_param['shape'])
            self.agents.append(agent)
            # one row per ray: (distance, end_x, end_y), initialised to sensor_max
            self.sensors.append(np.full((agent_param['n_sensor'], 3), agent_param['sensor_max'], dtype=float))

    def _reward(self):
        """Return 1 if the ego vehicle hit the pedestrian and no other collision is flagged, else 0."""
        if self.collide_others or self.block_collide_pedestrain:
            return 0
        return 1 if self.ego_collide_pedestrain else 0

    def random_action(self):
        """Sample a uniform acceleration in [-1, 1] for every non-ego agent (steering fixed to 0).

        Returns a flat array laid out as [acc, steer, acc, steer, ...] with
        ``agent_action_dim * (n_agents - 1)`` entries.
        """
        actions = []
        for _ in range(self.n_agents - 1):  # the ego vehicle is not controllable
            one_agent = np.random.uniform(-1, 1, size=(self.agent_action_dim,))
            one_agent[1] = 0
            actions.append(one_agent)
        return np.concatenate(actions, axis=0)

    def _get_obs(self):
        """Return the flat observation vector.

        Per agent: ``[x / viewer_xy[0], y / viewer_xy[1], v / max_velocity,
        theta / pi]``, followed by a two-entry collision one-hot. Index 0 is
        set when the ego vehicle hit the pedestrian; index 1 is set when the
        block vehicle hit the pedestrian or any other collision happened, and
        takes precedence. Both are 0 while no collision is flagged.
        """
        states = []
        for agent in self.agents:
            # Copy (as float) before normalising in place: if get_info() returned a
            # view of the agent's internal state, dividing the slice directly would
            # corrupt the simulation.
            state = np.array(agent.get_info()[:4], dtype=float)  # [x, y, v, theta]
            state[0] /= self.viewer_xy[0]
            state[1] /= self.viewer_xy[1]
            state[2] /= self.max_velocity
            state[3] /= np.pi
            states.append(state)

        collision = np.zeros((2,))
        if self.block_collide_pedestrain or self.collide_others:
            collision[1] = 1
        elif self.ego_collide_pedestrain:
            collision[0] = 1
        states.append(collision)

        return np.concatenate(states, axis=0)  # [agent_state_dim * n_agents + 2]

    def step(self, action):
        """Advance every agent by one ``dt``.

        Args:
            action: flat array of length ``action_dim`` holding [acceleration,
                steering] for each non-ego agent in ``agent_list`` order. Values
                are clipped to ``action_range`` and scaled by ``action_scale``.

        Returns:
            ``(state, reward, done, info)`` with an empty ``info`` dict.
        """
        # scripted ego policy: lowest allowed acceleration once the pedestrian is seen
        ego_action = np.array([-1.0, 0.0]) if self.ego_observe_pedestrain else np.array([0.0, 0.0])
        action = np.concatenate([action, ego_action], axis=0)
        assert action.shape[0] == self.n_agents * self.agent_action_dim

        dim = self.agent_action_dim
        low, high = self.action_range
        for a_i in range(self.n_agents):
            acceleration, steering = np.clip(action[a_i*dim:(a_i+1)*dim], low, high)
            acc_scale, steer_scale = self.action_scale[a_i]
            self.agents[a_i].step(acceleration * acc_scale, steering * steer_scale, dt=self.dt)
        self.step_num += 1

        self._update_sensor_vectorized()
        done = self._update_terminal()
        reward = self._reward()
        state = self._get_obs()
        return state, reward, done, {}

    def reset(self):
        """Start a new episode and return its first observation.

        Clears the step counter and outcome flags, closes any open renderer (the
        next render() creates a new one), respawns the agents and takes an
        initial sensor reading.
        """
        self.step_num = 0
        self.collide_others = False
        self.ego_observe_pedestrain = False
        self.ego_collide_pedestrain = False
        self.block_collide_pedestrain = False
        self.terminal = False
        if self.renderer is not None:
            self.renderer.close()
            self.renderer = None

        self._spawn_agents()
        self._update_sensor_vectorized()
        return self._get_obs()

    def render(self):
        """Draw the current frame, creating the renderer lazily on first use.

        Also saves the frame as an image when ``save_gif`` is set.
        """
        if self.renderer is None:
            from .renderer import Renderer
            self.renderer = Renderer(*self.viewer_xy, self.agent_list)

        agent_info = [agent.get_info() for agent in self.agents]
        label_info = {
            'frame_num': self.step_num,
            # resolve roles through agent_mapping: index 0 is the pedestrian and
            # index 1 the block vehicle, not the ego vehicle / pedestrian
            'ego_speed': self.agents[self.agent_mapping['ego']].get_speed(),
            'ped_speed': self.agents[self.agent_mapping['pedestrain']].get_speed(),
        }
        self.renderer.render(agent_info, self.sensors, label_info)
        if self.save_gif:
            self.renderer.save_image()

    def close(self, filename=None):
        """Finish the run: optionally write the gif, then release the renderer.

        Args:
            filename: passed to ``make_gif`` when ``save_gif`` is set; with
                ``None`` no gif is written.
        """
        if self.save_gif and filename is not None:
            make_gif(filename)
        # release the renderer the same way reset() does
        if self.renderer is not None:
            self.renderer.close()
            self.renderer = None

    def _get_sensor(self):
        """Return, per agent, its ray distances divided by that agent's ``sensor_max``."""
        return [
            self.sensors[s_i][:, 0].flatten() / self.agent_list[s_i]['sensor_max']
            for s_i in range(self.n_agents)
        ]

    def _update_terminal(self):
        """Return True once the step budget is used up or any collision flag is set."""
        max_step = self.step_num >= self.max_step_num
        return max_step or self.ego_collide_pedestrain or self.collide_others or self.block_collide_pedestrain

    @staticmethod
    def _cross2d(a, b):
        """Return the z-component of the 2-D cross product ``a x b``, broadcast over leading axes.

        Written out because ``np.cross`` on 2-D vectors is deprecated since
        NumPy 2.0; the arithmetic is the same ``a0*b1 - a1*b0``.
        """
        return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]

    @classmethod
    def _ray_segment_hits(cls, q, s, p, r):
        """Intersect the rays ``q + u*s`` with the segments ``p + t*r``.

        Args:
            q: (2,) common ray origin.
            s: (n_rays, 2) ray vectors; a ray covers u in [0, 1].
            p: (..., 2) segment start points.
            r: (..., 2) segment vectors, same shape as ``p``.

        Returns:
            u: ray parameters, shape ``p.shape[:-1] + (n_rays,)``.
            idx: ``(p.ndim, k)`` integer array of ``(*segment_index, ray)``
                for every hit with t and u in [0, 1], in row-major order, or
                an empty array (len 0) when nothing is hit. Parallel pairs
                (``r x s == 0``) are never reported.
        """
        qp = q - p
        den = cls._cross2d(r[..., None, :], s)
        t = cls._cross2d(qp[..., None, :], s) / (den + 1e-20)
        u = cls._cross2d(qp, r)[..., None] / (den + 1e-20)
        hit = (np.abs(t - 1/2.) <= 1/2.) & (np.abs(u - 1/2.) <= 1/2.) & (den != 0)
        idx = np.argwhere(hit)
        return u, (idx.T if len(idx) else np.array([]))

    def _obstacle_boxes(self, a_i):
        """Return the box corners of every agent except ``a_i`` and their names.

        Boxes are axis-aligned (width along x, length along y); agent headings
        are not applied. Consecutive corners (with wrap-around) form the edges.

        Returns:
            ``(boxes, names)`` where ``boxes`` has shape (n_agents - 1, 4, 2).
        """
        boxes, names = [], []
        for a_j in range(self.n_agents):
            if a_j == a_i:  # an agent is not an obstacle to itself
                continue
            ox, oy, _, _, width, length = self.agents[a_j].get_info()
            half_w, half_l = width / 2, length / 2
            # TODO: apply rotation
            boxes.append([
                [ox - half_w, oy - half_l],
                [ox + half_w, oy - half_l],
                [ox + half_w, oy + half_l],
                [ox - half_w, oy + half_l],
            ])
            names.append(self.agent_list[a_j]['name'])
        return np.array(boxes), names

    def _raise_collision_flags(self, agent_name, distances, ray_names):
        """Raise collision flags for rays of ``agent_name`` shorter than ``collision_threshold``.

        * ego ray on the pedestrian            -> ego_collide_pedestrain
        * block-vehicle ray on the pedestrian  -> block_collide_pedestrain
        * pedestrian ray shorter than 4        -> collide_others
        * any other short ray (incl. the map border 'windows') -> collide_others

        Flags are only ever set here, never cleared.
        """
        for dist, hit_name in zip(distances, ray_names):
            if dist < self.collision_threshold:
                if agent_name == 'ego' and hit_name == 'pedestrain':
                    self.ego_collide_pedestrain = True
                elif agent_name == 'block_vehicle' and hit_name == 'pedestrain':
                    self.block_collide_pedestrain = True
                elif agent_name == 'pedestrain':
                    # the pedestrian only counts contacts much closer than the threshold
                    if dist < 4:
                        self.collide_others = True
                else:
                    self.collide_others = True

    def _update_sensor_vectorized(self):
        """Ray-cast every agent's sensors, then refresh the perception and collision flags.

        For each agent with ``n_sensor > 0`` the rays span
        ``sensor_theta_boundary`` around the agent heading with length
        ``sensor_max``. They are intersected with the other agents' boxes and
        with the map border. Afterwards each row of ``self.sensors[a_i]`` is
        ``(distance, x, y)`` of the nearest hit, or ``(sensor_max, ray tip)``
        when the ray hits nothing.

        Side effects: ``ego_observe_pedestrain`` is set from the ego rays, and
        the collision flags may be raised (see :meth:`_raise_collision_flags`).
        """
        # map-border segments: from each polyline point to the next one
        win_r = np.roll(self.win_coord, -1, axis=0) - self.win_coord

        for a_i in range(self.n_agents):
            agent_param = self.agent_list[a_i]
            n_sensor = agent_param['n_sensor']
            if n_sensor == 0:
                continue
            sensor_max = agent_param['sensor_max']
            agent_name = agent_param['name']
            theta_lo, theta_hi = agent_param['sensor_theta_boundary']

            # ray tips: a fan of angles in the agent frame, rotated by the heading
            cx, cy, _, rotation = self.agents[a_i].get_info()[0:4]
            sensor_theta = np.linspace(theta_lo, theta_hi, n_sensor)
            # offsets are taken relative to the absolute tip so seeded runs stay
            # bit-for-bit reproducible
            dx = (cx + sensor_max * np.cos(sensor_theta)) - cx
            dy = (cy + sensor_max * np.sin(sensor_theta)) - cy
            cos_r, sin_r = np.cos(rotation), np.sin(rotation)
            self.sensors[a_i][:, -2:] = np.column_stack([
                dx * cos_r - dy * sin_r + cx,
                dx * sin_r + dy * cos_r + cy,
            ])

            q = np.array([cx, cy])                # ray origin
            s = self.sensors[a_i][:, -2:] - q     # one ray vector per sensor

            hit_points, hit_rays, hit_names = [], [], []

            # other agents: box edges run from each corner to the next one
            boxes, box_names = self._obstacle_boxes(a_i)
            box_r = np.roll(boxes, -1, axis=1) - boxes
            u, idx = self._ray_segment_hits(q, s, boxes, box_r)
            if len(idx):
                idx = tuple(idx)  # (agent, edge, ray)
                hit_points.append(q + np.expand_dims(u[idx], 1) * s[idx[2], :])
                hit_rays.append(idx[2])
                hit_names.extend(box_names[o] for o in idx[0])

            # map border
            u, idx = self._ray_segment_hits(q, s, self.win_coord, win_r)
            if len(idx):
                idx = tuple(idx)  # (edge, ray)
                hit_points.append(q + np.expand_dims(u[idx], 1) * s[idx[1], :])
                hit_rays.append(idx[1])
                hit_names.extend('windows' for _ in idx[0])

            if hit_points:
                points = np.concatenate(hit_points, axis=0)
                dists = np.linalg.norm(points - q, axis=1)
                rays = np.concatenate(hit_rays)
            else:
                points = dists = None
                rays = np.empty(0, dtype=int)

            # nearest hit per ray; ties go to the earliest hit (agents before border)
            ray_names = []
            for ray in range(n_sensor):
                candidates = np.flatnonzero(rays == ray)
                if candidates.size == 0:
                    self.sensors[a_i][ray, 0] = sensor_max
                    ray_names.append('farest')
                    continue
                best = candidates[np.argmin(dists[candidates])]
                self.sensors[a_i][ray, 0] = dists[best]
                self.sensors[a_i][ray, -2:] = points[best]
                ray_names.append(hit_names[best])

            if agent_name == 'ego':
                self.ego_observe_pedestrain = 'pedestrain' in ray_names

            self.sensor_collision = False  # legacy attribute; never set to True
            self._raise_collision_flags(agent_name, self.sensors[a_i][:, 0], ray_names)

class CrashEnv_Collect(object):
    def __init__(self, test_mode='IID', use_render=False, save_gif=False):
        """Build the five-agent data-collection scenario (not a gym.Env).

        Args:
            test_mode: 'IID', 'OOD' or 'OOD-E'; selects how initial speeds are
                perturbed in _spawn_agents(). The stage ('train'/'test') is
                passed to reset() instead of being stored here.
            use_render: stored on the instance.
            save_gif: if True, render() saves each frame and close() calls
                make_gif(filename).
        """
        # 'OOD-E' is implemented by _spawn_agents(), so it is accepted here too
        assert test_mode in ['IID', 'OOD', 'OOD-E']
        self.step_num = 0
        self.max_step_num = 30
        self.test_mode = test_mode

        self.viewer_xy = (200, 200)
        # closed polyline around the map border (the last point repeats the first)
        self.win_coord = np.array([[0, 0], [self.viewer_xy[0], 0], [*self.viewer_xy], [0, self.viewer_xy[1]], [0, 0]])
        self.dt = 0.3

        self.save_gif = save_gif
        self.use_render = use_render
        self.renderer = None

        # episode outcome flags, cleared by reset()
        self.ego_collide_pedestrain = False
        self.collide_others = False
        self.block_collide_pedestrain = False
        self.ego_observe_pedestrain = False

        # start poses, presumably [x, y, heading in radians] as expected by Bicycle
        ego_startpoint = np.array([125, 30, np.pi/2])
        ped_startpoint = np.array([190, 140, np.pi])
        surround_startpoints = [
            [153, 50, np.pi/2],
            [45, 160, -np.pi/2],
            [80, 160, -np.pi/2],
        ]

        # Agent parameters, in the same order as start_point_list, speed_list and
        # action_scale. The ego vehicle must stay last because step() appends its
        # scripted action after the caller's. Every template is deep-copied since
        # _spawn_agents() writes 'start_point' and 'speed' into these dicts; using
        # the module-level dicts directly would share that state between instances.
        self.agent_list = [
            copy.deepcopy(pedestrain_param),
            copy.deepcopy(car_param_1),
            copy.deepcopy(car_param_2),
            copy.deepcopy(car_param_2),
            copy.deepcopy(ego_param),
        ]
        self.agent_mapping = {
            'pedestrain': 0,
            'block_vehicle': 1,
            'ego': 4
        }
        self.start_point_list = [
            ped_startpoint,
            surround_startpoints[0],
            surround_startpoints[1],
            surround_startpoints[2],
            ego_startpoint,
        ]
        self.speed_list = [
            10,     # pedestrain
            18,     # other vehicle 1
            10,     # other vehicle 2
            10,     # other vehicle 3
            18,     # ego vehicle
        ]

        self.max_velocity = 18
        self.n_agents = len(self.agent_list)
        self.agent_state_dim = 4   # [x, y, v, theta] per agent
        self.agent_action_dim = 2  # [acceleration, steering] per agent
        self.collision_dim = 2     # collision one-hot appended to the state
        self.action_dim = self.agent_action_dim*(self.n_agents-1)  # the ego vehicle is not controllable
        self.state_dim = self.agent_state_dim*self.n_agents + self.collision_dim
        self.map_scale = self.viewer_xy
        self.action_scale = [
            [4, 1],   # pedestrain
            [15, 1],  # other vehicle 1
            [6, 1],   # other vehicle 2
            [6, 1],   # other vehicle 3
            [4, 1],   # ego vehicle
        ]

        self.action_range = [-1, 1]
        self.collision_threshold = 25

    def _truncated_normal_distribution(self, side='left', scale=2):
        sample = np.random.randn(1)[0]*scale
        if side == 'left':
            if sample >= 0:
                sample = -sample
        elif side == 'right':
            if sample <= 0:
                sample = -sample
        return sample

    def _spawn_agents(self, stage):
        """(Re)create every agent and its sensor buffer for a new episode.

        Agent ``i`` starts at ``start_point_list[i]`` with speed
        ``speed_list[i]`` plus a random offset that depends on ``test_mode``
        and ``stage``:

        * 'IID': independent N(0, 2^2) offsets for pedestrian, block and ego.
        * 'OOD' / 'train': one N(0, 2^2) offset shared by all three.
        * 'OOD' / 'test': independent N(0, 2^2) offsets.
        * 'OOD-E': independent N(0, 2^2) offsets for pedestrian and block; the
          ego offset is a non-negative half-normal sample for 'train' and a
          non-positive one for 'test'.

        Every 'other_vehicle' gets its own fresh N(0, 2^2) offset.

        Args:
            stage: 'train' or 'test'; only consulted in the OOD modes.

        Raises:
            ValueError: for an unknown ``test_mode``, an unknown ``stage`` in an
                OOD mode, or an unknown agent name.
        """
        self.agents = []
        self.sensors = []

        if self.test_mode in ('OOD', 'OOD-E') and stage not in ('train', 'test'):
            raise ValueError(f'Unknown stage {stage!r} for test_mode {self.test_mode!r}')

        # the draw order below is kept fixed so seeded runs stay reproducible
        if self.test_mode == 'IID':
            delta_ped = np.random.randn(1)[0]*2
            delta_block = np.random.randn(1)[0]*2
            delta_ego = np.random.randn(1)[0]*2
        elif self.test_mode == 'OOD':
            if stage == 'train':
                # keep the speed differences between agents the same as nominal
                delta_ped = np.random.randn(1)[0]*2
                delta_block = delta_ped
                delta_ego = delta_ped
            else:
                delta_ped = np.random.randn(1)[0]*2
                delta_block = np.random.randn(1)[0]*2
                delta_ego = np.random.randn(1)[0]*2
        elif self.test_mode == 'OOD-E':
            delta_ped = np.random.randn(1)[0]*2
            delta_block = np.random.randn(1)[0]*2
            if stage == 'train':
                delta_ego = self._truncated_normal_distribution('right', 2)
            else:
                delta_ego = self._truncated_normal_distribution('left', 2)
        else:
            raise ValueError(f'Unknown test_mode: {self.test_mode!r}')

        for a_i in range(self.n_agents):
            agent_param = self.agent_list[a_i]
            agent_param['start_point'] = self.start_point_list[a_i]

            name = agent_param['name']
            if name == 'ego':
                agent_param['speed'] = self.speed_list[a_i] + delta_ego
            elif name == 'pedestrain':
                agent_param['speed'] = self.speed_list[a_i] + delta_ped
            elif name == 'block_vehicle':
                agent_param['speed'] = self.speed_list[a_i] + delta_block
            elif name == 'other_vehicle':
                agent_param['speed'] = self.speed_list[a_i] + np.random.randn(1)[0]*2
            else:
                raise ValueError('Unknown agent type')

            agent = Bicycle(start_point=np.array(agent_param['start_point']), v=agent_param['speed'], shape=agent_param['shape'])
            self.agents.append(agent)
            # one row per ray: (distance, end_x, end_y), initialised to sensor_max
            sensor_info = agent_param['sensor_max'] + np.zeros((agent_param['n_sensor'], 3))
            self.sensors.append(sensor_info)

    def _reward(self):
        """Return 1 when the ego vehicle hit the pedestrian with no other collision flagged, else 0.

        Any failure collision (block vehicle on the pedestrian, or any other
        collision) forces 0; reaching the step limit yields no bonus.
        """
        failed = self.collide_others or self.block_collide_pedestrain
        if failed:
            return 0
        return 1 if self.ego_collide_pedestrain else 0

    def random_action(self):
        actions = []
        # we cannot control the ego vehicle
        for a_i in range(self.n_agents-1):
            one_agent = np.random.uniform(-1, 1, size=(2,))
            one_agent[1] = 0
            actions.append(one_agent)

        actions = np.concatenate(actions, axis=0)
        return actions

    def _get_obs(self):
        # return the position and velocity
        states = []
        for a_i in range(self.n_agents):
            state = self.agents[a_i].get_info()[:4] # [x, y, v, theta]

            # normalize
            state[0] /= self.viewer_xy[0]
            state[1] /= self.viewer_xy[1]
            state[2] /= self.max_velocity
            state[3] /= np.pi
            states.append(state)

        # onr-hot for collision node, the first one means there is no collision
        collision = np.zeros((2,))
        if self.block_collide_pedestrain or self.collide_others:
            collision[1] = 1
        elif self.ego_collide_pedestrain:
            collision[0] = 1
        states.append(collision)

        '''
        if self.ego_collide_pedestrain and self.block_collide_pedestrain:
            print('ego_ped and block_ped')
            block_cx, block_cy = self.agents[self.agent_mapping['block_vehicle']].get_info()[0:2] # 
            ego_cx, ego_cy = self.agents[self.agent_mapping['ego']].get_info()[0:2] # block
            ped_cx, ped_cy = self.agents[self.agent_mapping['pedestrain']].get_info()[0:2] # 
            print('block', block_cx, block_cy)
            print('ego', ego_cx, ego_cy)
            print('ped', ped_cx, ped_cy)

        if self.ego_collide_pedestrain and self.collide_others:
            print('ego_ped and other')
        '''

        states = np.concatenate(states, axis=0) # [state_dim * agent_num + 2]
        return states

    def step(self, action):
        # update the action of the ego vehicle
        if self.ego_observe_pedestrain:
            ego_action = np.array([-1.0, 0.0])
        else:
            ego_action = np.array([0.0, 0.0])
        action = np.concatenate([action, ego_action], axis=0)
        assert action.shape[0] == self.n_agents * self.agent_action_dim

        # update action of all agents (including ego vehicle)
        for a_i in range(self.n_agents):
            action_one = action[a_i*2:(a_i+1)*2]
            acceleration = action_one[0]
            steering = action_one[1]
            acceleration = np.clip(acceleration, self.action_range[0], self.action_range[1])
            steering = np.clip(steering, self.action_range[0], self.action_range[1])

            acceleration *= self.action_scale[a_i][0]
            steering *= self.action_scale[a_i][1]
            self.agents[a_i].step(acceleration, steering, dt=self.dt)
        self.step_num += 1

        self._update_sensor_vectorized()
        done = self._update_terminal()
        reward = self._reward()
        state = self._get_obs()
        return state, reward, done, None

    def reset(self, stage):
        self.step_num = 0
        self.collide_others = False
        self.ego_observe_pedestrain = False
        self.ego_collide_pedestrain = False
        self.block_collide_pedestrain = False
        self.terminal = False
        if self.renderer is not None:
            self.renderer.close()
            self.renderer = None
        
        self._spawn_agents(stage)
        self._update_sensor_vectorized()
        state = self._get_obs()
        return state

    def render(self):
        # create the renderer the first time
        if self.renderer is None:
            from .renderer import Renderer
            self.renderer = Renderer(*self.viewer_xy, self.agent_list)
        
        agent_info = []
        for a_i in range(self.n_agents):
            agent_info.append(self.agents[a_i].get_info())
        sensor_info = self.sensors
        label_info = {
            'frame_num': self.step_num,
            'ego_speed': self.agents[0].get_speed(),
            'ped_speed': self.agents[1].get_speed(),
        }
        self.renderer.render(agent_info, sensor_info, label_info)
        if self.save_gif:
            self.renderer.save_image()

    def close(self, filename=None):
        if self.save_gif and filename is not None:
            make_gif(filename)

    def _get_sensor(self):
        observation = []
        for s_i in range(self.n_agents):
            s = self.sensors[s_i][:, 0].flatten()/self.agent_list[s_i]['sensor_max']
            observation.append(s)
        return observation

    def _update_terminal(self):
        # check maximium step
        max_step = False
        if self.step_num >= self.max_step_num:
            max_step = True

        return max_step or self.ego_collide_pedestrain or self.collide_others or self.block_collide_pedestrain

    def _update_sensor_vectorized(self):
        """
        Ray-cast every agent's sensors and update the observation/collision flags.

        For each agent with ``n_sensor > 0``:
          * ``n_sensor`` rays of length ``sensor_max`` are spread evenly over
            ``sensor_theta_boundary`` and rotated about the agent centre by
            its heading (4th entry of ``get_info()``).
          * Each ray is intersected with the *axis-aligned* bounding box of
            every other agent (rotation not applied yet, see TODO) and with
            the window border ``self.win_coord``.
          * ``self.sensors[a_i][si, 0]`` receives the distance to the nearest
            hit (``sensor_max`` if none) and ``self.sensors[a_i][si, -2:]``
            the hit point (the ray end point if none).
          * ``self.ego_observe_pedestrain`` is recomputed, and the flags
            ``ego_collide_pedestrain`` / ``block_collide_pedestrain`` /
            ``collide_others`` are latched to True (never cleared here) when a
            ray reading is below ``self.collision_threshold``.

        The segment intersections are computed with batched numpy cross
        products over all obstacles/edges/rays at once (hence "vectorized").
        """
        for a_i in range(self.n_agents):
            n_sensor = self.agent_list[a_i]['n_sensor']
            sensor_max = self.agent_list[a_i]['sensor_max']
            sensor_theta_boundary = self.agent_list[a_i]['sensor_theta_boundary']
            # Agents without sensors are skipped entirely, including the
            # ego_observe_pedestrain / collision-flag updates further below.
            if n_sensor == 0:
                continue

            # update sensors data
            # Ray end points at full range, first in the agent-local angle
            # frame, then rotated about (cx, cy) by the agent's heading. These
            # end points are the default "hit point" when nothing is hit.
            cx, cy, _, rotation = self.agents[a_i].get_info()[0:4]
            sensor_theta = np.linspace(sensor_theta_boundary[0], sensor_theta_boundary[1], n_sensor)
            xs = cx + (np.zeros((n_sensor,))+sensor_max) * np.cos(sensor_theta)
            ys = cy + (np.zeros((n_sensor,))+sensor_max) * np.sin(sensor_theta)
            xys = np.array([[x, y] for x, y in zip(xs, ys)])
            tmp_x = xys[:, 0] - cx
            tmp_y = xys[:, 1] - cy
            rotated_x = tmp_x * np.cos(rotation) - tmp_y * np.sin(rotation)
            rotated_y = tmp_x * np.sin(rotation) + tmp_y * np.cos(rotation)
            self.sensors[a_i][:, -2:] = np.vstack([rotated_x+cx, rotated_y+cy]).T

            # treat other objects all as obstacles
            # Each obstacle is the 4 corners of an axis-aligned box built from
            # the other agent's centre, width and length; obstacle_name keeps
            # the matching agent name so hits can be attributed.
            potential_obstacles = []
            obstacle_name = []
            for a_j in range(self.n_agents):
                if a_i != a_j: # dont need to check itself
                    ox, oy, _, _, width, length = self.agents[a_j].get_info()
                    agent_j_name = self.agent_list[a_j]['name']
                    p1 = [ox - width/2, oy - length/2]
                    p2 = [ox + width/2, oy - length/2]
                    p3 = [ox + width/2, oy + length/2]
                    p4 = [ox - width/2, oy + length/2]
                    # TODO: apply rotation
                    potential_obstacles.append(np.array([p1, p2, p3, p4]))
                    obstacle_name.append(agent_j_name)
            assert len(potential_obstacles) == self.n_agents - 1, 'obstacle number is not correct'
            potential_obstacles = np.array(potential_obstacles)  # (n_agents - 1, 4, 2)
            
            # Ray origin q and ray direction vectors s (one row per sensor);
            # a point on ray k is q + u * s[k] with u in [0, 1].
            q = np.array([cx, cy])
            all_collision_name = []  # nearest-hit name per sensor, filled below
            s = np.array(self.sensors[a_i][:, -2:] - q)
            
            # Candidate hits gathered from obstacles first, then the window;
            # entries at the same position in the three lists belong together.
            possible_sensor_distance = []
            possible_intersections = [] 
            possible_collision_name = []
            
            def _compute_collision(p, r):
                """
                Intersect every sensor ray q + u*s with every edge p + t*r.

                Uses the 2-D cross-product form of segment intersection:
                den = r x s, t = (q - p) x s / den (fraction along the edge),
                u = (q - p) x r / den (fraction along the ray). A hit requires
                0 <= t <= 1, 0 <= u <= 1 and den != 0 (parallel edges are
                excluded; 1e-20 only avoids division by zero).

                @param p: edge start points, shape (n_obstacles, 4, 2) for
                    obstacles or (n_edges, 2) for the window.
                @param r: edge vectors, same shape as p.
                @return:
                    u, index -- u broadcast against den (ray fractions), and
                    a 2-D index array (transposed) of the hits: rows are
                    (obstacle, edge, sensor) for obstacles or (edge, sensor)
                    for the window; an empty array when nothing is hit.
                """
                if len(p.shape) == 3: # obstacle
                    u = np.expand_dims(np.cross(np.squeeze(q-p), r),2)
                    t = np.array([np.array([np.cross(q-p[o_i, o_j], s) for o_j in range(p.shape[1])]) for o_i in range(p.shape[0])])
                    den = np.array([np.array([np.cross(r[o_i, o_j], s) for o_j in range(r.shape[1])]) for o_i in range(r.shape[0])]) # 4-corners
                    den_idx = np.where(den == 0)
                else: # window
                    u = np.expand_dims(np.cross(np.squeeze(q-p), r), 1)
                    t = np.array([np.cross(q-p[o_i], s) for o_i in range(p.shape[0])])
                    den = np.array([np.cross(r[o_i], s) for o_i in range(r.shape[0])]) # 4-corners
                    den_idx = np.where(den == 0)

                t = t / (den + 1e-20)
                u = u / (den + 1e-20)
                # |x - 1/2| <= 1/2 is the compact test for 0 <= x <= 1.
                t_idx = np.where(np.abs(t-1/2.) <= 1/2.)
                u_idx = np.where(np.abs(u-1/2.) <= 1/2.)
                t_set = [np.array(t_idx)[:, i].tolist() for i in range(len(t_idx[0]))]
                u_set = [np.array(u_idx)[:, i].tolist() for i in range(len(u_idx[0]))]
                den_set = [np.array(den_idx)[:, i].tolist() for i in range(len(den_idx[0]))]

                final_idx = np.array([val for val in t_set if val in u_set and val not in den_set]).T
                return u, final_idx
            
            # obstacle collision detection for all the sensors of agent a_i
            # r_obs[k, e] = corner[e+1] - corner[e] (wrapping), i.e. box edges.
            p_obs = potential_obstacles
            r_obs = -potential_obstacles[:, :, :] + potential_obstacles[:, list(range(1, potential_obstacles.shape[1]))+[0], :]      
            
            u, final_idx = _compute_collision(p_obs, r_obs)
            if len(final_idx) >= 1:
                # final_idx rows: (obstacle index, edge index, sensor index)
                final_idx = tuple(final_idx[i] for i in range(3))
                intersection = q + np.expand_dims(u[final_idx],1) * s[final_idx[2],:]
                possible_intersections.extend(intersection.tolist())
                possible_sensor_distance.extend(np.linalg.norm(intersection-q, axis=1).tolist())
                possible_collision_name.extend([obstacle_name[final_idx[0][i]] for i in range(len(final_idx[0]))])

            # window collision detection for all the sensors of agent a_i
            # win_coord repeats its first point at the end, so the wrapped last
            # edge has zero length; its den is 0 and it is filtered out above.
            p_win = self.win_coord
            r_win = -self.win_coord[:, :] + self.win_coord[list(range(1, self.win_coord.shape[0]))+[0], :]

            u_win, final_idx_win = _compute_collision(p_win, r_win)
            if len(final_idx_win) >= 1:
                # final_idx_win rows: (edge index, sensor index)
                final_idx_win = tuple(final_idx_win[i] for i in range(2))
                intersection = q + np.expand_dims(u_win[final_idx_win],1) * s[final_idx_win[1], :]
                possible_intersections.extend(intersection.tolist())
                possible_sensor_distance.extend(np.linalg.norm(intersection-q, axis=1).tolist())
                possible_collision_name.extend(['windows' for i in range(len(final_idx_win[0]))])
            
            # Evaluate for each sensor, if none of them is colliding, simply return
            # (i.e. every reading becomes sensor_max and keeps its ray end point)
            if len(final_idx) < 1 and len(final_idx_win) < 1:
                for si in range(len(self.sensors[a_i])):
                    self.sensors[a_i][si, 0] = sensor_max
                    all_collision_name.append('farest')
            # at least one of the sensor detects something
            else:
                # Window hits were appended after the obstacle hits, so their
                # positions in the possible_* lists are shifted by this amount.
                if len(final_idx) > 0:
                    obs_idx_offset = len(final_idx[0])
                else:
                    obs_idx_offset = 0

                for si in range(len(self.sensors[a_i])):
                    # Collect col_idx: positions in the possible_* lists of all
                    # hits (obstacle and/or window) produced by sensor si.
                    if len(final_idx) >= 1:
                        obs_idx = np.where(final_idx[2] == si)
                        col_idx = obs_idx

                        if len(final_idx_win) >= 1:
                            win_idx = np.where(final_idx_win[1] == si)

                            if len(win_idx[0]) >= 1:
                                for win_i in range(len(win_idx[0])):
                                    win_idx[0][win_i] = win_idx[0][win_i] + obs_idx_offset
                                if len(obs_idx[0]) >= 1:
                                    col_idx = (np.concatenate([col_idx[0], win_idx[0]], 0),) # all possible collision index
                                else:
                                    col_idx = win_idx
                    else:
                        if len(final_idx_win) >= 1:
                            win_idx = np.where(final_idx_win[1] == si)
                            
                            if len(win_idx[0]) >= 1:
                                for win_i in range(len(win_idx[0])):
                                    win_idx[0][win_i] = win_idx[0][win_i] + obs_idx_offset
                                col_idx = win_idx # all possible collision index
                            else:
                                self.sensors[a_i][si, 0] = sensor_max
                                all_collision_name.append('farest')
                                continue

                    tmp_dist = np.array(possible_sensor_distance)[col_idx]
                    tmp_inter = np.array(possible_intersections)[col_idx]
                    tmp_name = [possible_collision_name[i] for i in col_idx[0].tolist()]

                    # Sensor si hit nothing: fall back to the full-range reading.
                    if len(tmp_dist) == 0:
                        tmp_dist = np.array([sensor_max])
                        tmp_inter = np.array([self.sensors[a_i][si, -2:]])
                        tmp_name = ['farest']                   

                    # Keep only the nearest hit for this sensor.
                    self.sensors[a_i][si, 0] = np.min(tmp_dist)
                    self.sensors[a_i][si, -2:] = tmp_inter[np.argmin(tmp_dist)]                
                    all_collision_name.append(tmp_name[np.argmin(tmp_dist)])

            # check ego vehicle observation
            # NOTE(review): this runs for every sensor-equipped agent, so a
            # non-ego agent processed after the ego resets the flag to False.
            # It only reflects the ego's rays if the ego is the last agent with
            # sensors -- confirm the agent ordering.
            agent_i_name = self.agent_list[a_i]['name']
            if (agent_i_name == 'ego') and ('pedestrain' in all_collision_name):
                self.ego_observe_pedestrain = True
            else:
                self.ego_observe_pedestrain = False

            # NOTE(review): the string literal below is an older, disabled
            # variant of the collision check; it is never executed.
            '''
            # check sensors collision
            self.sensor_collision = False
            for s_i in range(len(self.sensors[a_i])):
                if self.sensors[a_i][s_i, 0] < self.collision_threshold:
                    # colliding with pedestrain or other objects
                    if (agent_i_name == 'ego') and (all_collision_name[s_i] == 'pedestrain'):
                        self.ego_collide_pedestrain = True
                    elif (agent_i_name == 'block_vehicle') and (all_collision_name[s_i] == 'pedestrain'):
                        self.block_collide_pedestrain = True
                    elif (agent_i_name == 'pedestrain') and all_collision_name[s_i] == 'ego':
                        # avoid the case that ego intentionally hit other vehicles
                        _, ped_cy = self.agents[self.agent_mapping['pedestrain']].get_info()[0:2] # ped
                        _, ego_cy = self.agents[self.agent_mapping['ego']].get_info()[0:2] # block
                        if ped_cy - ego_cy > 15:
                            self.ego_collide_pedestrain = True
                        else:
                            self.collide_others = True
                            print('intentional collide ego', ego_cy, ped_cy)
                    elif (agent_i_name == 'pedestrain') and all_collision_name[s_i] == 'block_vehicle':
                        # avoid the case that ego intentionally hit other vehicles
                        _, ped_cy = self.agents[self.agent_mapping['pedestrain']].get_info()[0:2] # ped
                        _, block_cy = self.agents[self.agent_mapping['block_vehicle']].get_info()[0:2] # block
                        if ped_cy - block_cy > 15:
                            self.block_collide_pedestrain = True
                        else:
                            self.collide_others = True
                            print('intentional collide block_vehicle', block_cy, ped_cy)
                    else:
                        self.collide_others = True
                        _, block_cy = self.agents[self.agent_mapping['block_vehicle']].get_info()[0:2] # block
                        print(agent_i_name, all_collision_name[s_i], 'block_cy:', block_cy)
            '''

            # check sensors collision
            # Latch collision flags from rays shorter than collision_threshold:
            #   ego ray -> pedestrian            : ego_collide_pedestrain
            #   block_vehicle ray -> pedestrian  : block_collide_pedestrain
            #   pedestrian ray (any target)      : collide_others, but only when
            #                                      the reading is also < 4
            #   any other case (including a ray hitting the 'windows' border
            #   or another vehicle)               : collide_others
            # NOTE(review): self.sensor_collision is assigned but not read in
            # this method.
            self.sensor_collision = False
            for s_i in range(len(self.sensors[a_i])):
                if self.sensors[a_i][s_i, 0] < self.collision_threshold:
                    # colliding with pedestrain or other objects
                    if (agent_i_name == 'ego') and (all_collision_name[s_i] == 'pedestrain'):
                        self.ego_collide_pedestrain = True
                    elif (agent_i_name == 'block_vehicle') and (all_collision_name[s_i] == 'pedestrain'):
                        self.block_collide_pedestrain = True
                    elif agent_i_name == 'pedestrain':
                        if self.sensors[a_i][s_i, 0] < 4:
                            self.collide_others = True
                            #print('ped intentional collide ego or block', self.sensors[a_i][s_i, 0])
                    else:
                        self.collide_others = True
