import os
import sys
import json
import glob
import hydra
import torch_scatter
import torch
import pickle
import random
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
torch.set_printoptions(threshold=100000)
import numpy as np
np.set_printoptions(suppress=True)

from utils.data import *
from utils.geometry import apply_se2_transform, angle_sub_tensor

class RLWaymoDataset(Dataset):
    # constants defining reward dimensions
    # 0: position target achieved (0 or 1)
    POS_TARGET_ACHIEVED_REW_IDX = 0
    # 1: heading target achieved (0 or 1)
    HEADING_TARGET_ACHIEVED_REW_IDX = 1
    # 2: speed target achieved (0 or 1)
    SPEED_TARGET_ACHIEVED_REW_IDX = 2
    # 3: position goal reward (shaped)
    POS_GOAL_SHAPED_REW_IDX = 3
    # 4: speed goal reward (shaped)
    SPEED_GOAL_SHAPED_REW_IDX = 4
    # 5: heading goal reward (shaped)
    HEADING_GOAL_SHAPED_REW_IDX = 5
    # 6: veh-veh collision reward (0 or 1)
    VEH_VEH_COLLISION_REW_IDX = 6
    # 7: veh-edge collision reward (0 or 1)
    VEH_EDGE_COLLISION_REW_IDX = 7
    
    def __init__(self, cfg, split_name='train', mode='train'):
        """Index the scene files of one dataset split.

        Args:
            cfg: config object. Attributes read here: ``dataset_path``,
                ``preprocess_simulated_data``, ``preprocess``, and either
                ``simulated_dataset`` / ``simulated_dataset_preprocessed_dir``
                (simulated data) or ``preprocess_dir`` / ``goal_fix``.
            split_name: split name used to build directory names under
                ``cfg.preprocess_dir``.
            mode: stored on the instance; agent-selection methods shuffle
                agent order only when it equals ``'train'``.
        """
        super(RLWaymoDataset, self).__init__()
        self.cfg = cfg
        self.data_root = self.cfg.dataset_path
        self.split_name = split_name
        self.mode = mode

        simulated = self.cfg.preprocess_simulated_data
        # Source files: *.pkl files in "<split>_march_new", or the *.json
        # files of a simulated dataset.
        if simulated:
            search_pattern = self.cfg.simulated_dataset + "/*.json"
        else:
            march_dir = os.path.join(self.cfg.preprocess_dir, f"{self.split_name}_march_new")
            search_pattern = march_dir + "/*.pkl"
        self.files = sorted(glob.glob(search_pattern))
        self.dset_len = len(self.files)
        self.preprocess = self.cfg.preprocess

        # Directory where preprocessed samples live; created if missing.
        if simulated:
            self.preprocessed_dir = self.cfg.simulated_dataset_preprocessed_dir
        elif self.cfg.goal_fix:
            self.preprocessed_dir = march_dir
        else:
            self.preprocessed_dir = os.path.join(self.cfg.preprocess_dir, f"{self.split_name}")
        if not os.path.exists(self.preprocessed_dir):
            os.makedirs(self.preprocessed_dir, exist_ok=True)


    def get_roads(self, data):
        """Convert a scene's ``roads`` entries into fixed-length polylines.

        Every point is stored as ``(x, y, 1.0)``; all-zero rows are padding.
        A list geometry is cut into consecutive chunks of
        ``cfg.max_num_road_pts_per_polyline`` points, with the last partial
        chunk zero-padded. A dict geometry (a single point, e.g. a stop sign)
        is repeated to fill one whole chunk.

        Returns:
            (polylines, one_hot_types, road_edge_polylines). The first two
            arrays have one row per chunk. The third is a list holding, for
            every ``'road_edge'`` road with list geometry, its raw ``(x, y)``
            points.
        """
        max_pts = self.cfg.max_num_road_pts_per_polyline
        polylines = []
        polyline_types = []
        edge_polylines = []

        for road in data['roads']:
            geometry = road['geometry']
            road_type = road['type']

            if isinstance(geometry, dict):
                single_point = np.array((geometry['x'], geometry['y'], 1.0)).reshape(1, -1)
                polylines.append(single_point.repeat(max_pts, 0))
                polyline_types.append(get_road_type_onehot(road_type))
                continue

            if road_type == 'road_edge':
                edge_polylines.append(np.array([np.array((pt['x'], pt['y'])) for pt in geometry]))

            points = [np.array((pt['x'], pt['y'], 1.0)) for pt in geometry]
            for start in range(0, len(points), max_pts):
                chunk = points[start:start + max_pts]
                if len(chunk) == max_pts:
                    polylines.append(np.array(chunk))
                else:
                    padded = np.zeros((max_pts, 3))
                    padded[:len(chunk)] = np.array(chunk)
                    polylines.append(padded)
                polyline_types.append(get_road_type_onehot(road_type))

        return np.array(polylines), np.array(polyline_types), edge_polylines


    def extract_rawdata(self, agents_data):
        """Stack per-agent trajectory records into numpy arrays.

        Returns an 8-tuple:
            agent_data: per step ``(x, y, vx, vy, heading, length, width,
                existence)``.
            agent_actions: per step ``(acceleration, steering)``.
            agent_rewards: raw reward vectors multiplied by the existence flag.
            agent_types: ``get_object_type_onehot`` of each agent's type.
            agent_goals: per step ``(goal_x, goal_y, goal_speed*cos(h),
                goal_speed*sin(h), h)`` when ``cfg.goal_fix``, else ``None``.
            parked_agent_ids: agents whose mean speed is below
                ``cfg.parked_car_velocity_threshold``.
            incomplete_ids: agents absent at step 0 (given the contiguity
                assertion, absent at every step).
            last_exist_timesteps: last step at which each agent exists, or
                -1 for incomplete agents.

        Raises:
            AssertionError: if an agent reappears after first going missing.
        """
        goal_fix = self.cfg.goal_fix
        states, actions, rewards, obj_types = [], [], [], []
        goals = [] if goal_fix else None
        parked_ids, incomplete_ids, last_steps = [], [], []

        for agent_idx, agent in enumerate(agents_data):
            position = np.column_stack(([p['x'] for p in agent['position']],
                                        [p['y'] for p in agent['position']]))
            heading = np.array(agent['heading']).reshape((-1, 1))
            velocity = np.column_stack(([v['x'] for v in agent['velocity']],
                                        [v['y'] for v in agent['velocity']]))
            # Parked: mean speed over the episode below the configured threshold.
            if np.linalg.norm(velocity, axis=-1).mean() < self.cfg.parked_car_velocity_threshold:
                parked_ids.append(agent_idx)

            existence = np.array(agent['existence']).reshape((-1, 1))
            missing = np.where(existence == 0.0)[0]
            starts_missing = False
            if len(missing) > 0:
                # Once an agent disappears it must stay missing.
                assert np.all(existence[missing[0]:] == 0.0)
                starts_missing = missing[0] == 0
            if starts_missing:
                # Never present, so no valid action can be derived for it.
                incomplete_ids.append(agent_idx)
                last_steps.append(-1)
            else:
                last_steps.append(np.where(existence == 1.0)[0][-1])

            action = np.column_stack((agent['acceleration'], agent['steering']))
            num_steps = len(position)
            length = np.ones((num_steps, 1)) * agent['length']
            width = np.ones((num_steps, 1)) * agent['width']
            obj_types.append(get_object_type_onehot(agent['type']))
            # Rewards are zeroed wherever the agent does not exist.
            reward = np.array(agent['reward']) * existence

            if goal_fix:
                goal_x = agent['goal_position']['x']
                goal_y = agent['goal_position']['y']
                goal_heading = agent['goal_heading']
                goal_speed = agent['goal_speed']
                goal = np.array([goal_x, goal_y,
                                 goal_speed * np.cos(goal_heading),
                                 goal_speed * np.sin(goal_heading),
                                 goal_heading])
                goals.append(np.repeat(goal[None, :], num_steps, 0))

            states.append(np.concatenate((position, velocity, heading, length, width, existence), axis=-1))
            actions.append(action)
            rewards.append(reward)

        return (np.array(states),
                np.array(actions),
                np.array(rewards),
                np.array(obj_types),
                np.array(goals) if goal_fix else None,
                np.array(parked_ids),
                np.array(incomplete_ids),
                np.array(last_steps))


    def compute_dist_to_nearest_road_edge_rewards(self, ag_data, road_edge_polylines):
        """Return, per agent, the negated and scaled distance to the nearest road edge.

        For each agent, its x and y trajectories are passed as ``(1, T)`` rows
        to ``compute_distance_to_road_edge`` (from ``utils.data``). The result
        is negated and divided by ``cfg.dist_to_road_edge_scaling_factor``.
        """
        per_agent = []
        for trajectory in ag_data:
            xs = trajectory[:, 0].reshape(1, -1)
            ys = trajectory[:, 1].reshape(1, -1)
            edge_dist = compute_distance_to_road_edge(xs, ys, road_edge_polylines)
            per_agent.append(-edge_dist / self.cfg.dist_to_road_edge_scaling_factor)
        return np.array(per_agent)

    
    def compute_dist_to_nearest_road_edge_rewards_gpu(self, ag_data, road_edge_polylines, device='cuda'):
        """Torch variant of ``compute_dist_to_nearest_road_edge_rewards``.

        ``ag_data`` is converted to a tensor on ``device``. Each agent's x and
        y trajectories are passed as ``(1, T)`` tensors to
        ``compute_distance_to_road_edge_gpu`` (from ``utils.data``). The result
        is copied back to numpy, negated, and divided by
        ``cfg.dist_to_road_edge_scaling_factor``.

        Args:
            ag_data: numpy array indexed ``[agent, time, feature]`` with x in
                feature 0 and y in feature 1.
            road_edge_polylines: forwarded unchanged to the distance helper.
                Its expected device is not visible here.
            device: torch device used for the computation. The default
                ``'cuda'`` matches the previous hard-coded ``.cuda()``.

        Returns:
            numpy array with one row of rewards per agent.
        """
        # Explicit ``torch.from_numpy``: the bare ``from_numpy`` used before
        # only resolved if the ``utils.data`` star-import happened to export it.
        agents = torch.from_numpy(ag_data).to(device)

        per_agent = []
        for n in range(len(agents)):
            edge_dist = compute_distance_to_road_edge_gpu(
                agents[n, :, 0].reshape(1, -1),
                agents[n, :, 1].reshape(1, -1),
                road_edge_polylines,
            ).cpu().numpy()
            per_agent.append(-edge_dist / self.cfg.dist_to_road_edge_scaling_factor)
        return np.array(per_agent)


    def compute_dist_to_nearest_vehicle_rewards(self, ag_data, normalize=True):
        """Distance from each agent to its nearest other agent at every step.

        Args:
            ag_data: array indexed ``[agent, time, feature]`` with (x, y) in
                features 0:2 and the existence flag in the last feature. It is
                not modified.
            normalize: if True, distances are clipped to
                ``[0, cfg.max_veh_veh_distance]`` and divided by
                ``cfg.max_veh_veh_distance``.

        Returns:
            ``(A, T)`` array. Entries are 0 where the agent does not exist, or
            where no other agent exists at that step.
        """
        num_agents = ag_data.shape[0]
        ag_existence = ag_data[:, :, -1]

        # Copy before masking. Slicing returns a view, so writing np.inf into
        # it used to overwrite the caller's positions at non-existing steps.
        ag_positions = ag_data[:, :, :2].copy()
        # Non-existing steps are pushed to infinity so they never count as "nearest".
        ag_positions[~ag_existence.astype(bool)] = np.inf

        with np.errstate(invalid='ignore'):
            # (A, 1, T, 2) - (1, A, T, 2) -> (A, A, T, 2) pairwise offsets.
            # inf - inf (two missing agents) yields nan, which is zeroed below.
            diff = ag_positions[:, np.newaxis] - ag_positions[np.newaxis, :]
            squared_dist = np.sum(diff ** 2, axis=-1)
            # An agent is never its own nearest neighbour.
            agent_idx = np.arange(num_agents)
            squared_dist[agent_idx, agent_idx, :] = np.inf
            # Nearest other agent for each (agent, step): shape (A, T).
            dist_nearest_vehicle = np.sqrt(np.min(squared_dist, axis=1))

        # inf means no other agent existed at that step: mark as undefined.
        dist_nearest_vehicle[dist_nearest_vehicle == np.inf] = np.nan

        if normalize:
            dist_nearest_vehicle = np.clip(dist_nearest_vehicle * ag_existence,
                                           a_min=0.0, a_max=self.cfg.max_veh_veh_distance)
            dist_nearest_vehicle = dist_nearest_vehicle / self.cfg.max_veh_veh_distance
        else:
            dist_nearest_vehicle = dist_nearest_vehicle * ag_existence

        # Undefined entries (nan) become a reward of 0.
        return np.nan_to_num(dist_nearest_vehicle, nan=0.0)

    
    def compute_rewards(self, ag_data, ag_rewards, veh_edge_dist_rewards, veh_veh_dist_rewards):
        """Combine raw reward channels into five per-agent reward streams.

        Output last-axis channels:
            0: goal position, 1: goal heading, 2: goal speed,
            3: vehicle-vehicle, 4: vehicle-road-edge.
        Every channel is multiplied by the existence flag (last feature of
        ``ag_data``).

        The ``remove_shaped_*`` and ``only_shaped_*`` config flags decide
        whether channels 0, 3 and 4 use the sparse term, the shaped term, or
        both. Heading and speed always use sparse + shaped.
        """
        cfg = self.cfg
        exists = ag_data[:, :, -1:]
        raw = np.array(ag_rewards)

        def channel(idx):
            return raw[:, :, idx]

        def pos_sparse():
            return channel(self.POS_TARGET_ACHIEVED_REW_IDX) * cfg.pos_target_achieved_rew_multiplier

        def pos_shaped():
            # Clipped shaped reward shifted so its maximum maps to 0, then scaled by 1/max.
            clipped = np.clip(channel(self.POS_GOAL_SHAPED_REW_IDX),
                              a_min=cfg.pos_goal_shaped_min, a_max=cfg.pos_goal_shaped_max)
            return (clipped - cfg.pos_goal_shaped_max) * (1 / cfg.pos_goal_shaped_max)

        def edge_shaped():
            # Unscaled distance to the road edge, capped at 5 and mapped to [0, 1].
            return np.clip(np.abs(veh_edge_dist_rewards) * cfg.dist_to_road_edge_scaling_factor,
                           a_min=0, a_max=5) / 5.

        if cfg.remove_shaped_goal:
            goal_pos = pos_sparse()
        elif cfg.only_shaped_goal:
            goal_pos = pos_shaped()
        else:
            goal_pos = pos_sparse() + pos_shaped()

        goal_heading = channel(self.HEADING_TARGET_ACHIEVED_REW_IDX) + channel(self.HEADING_GOAL_SHAPED_REW_IDX)
        goal_speed = channel(self.SPEED_TARGET_ACHIEVED_REW_IDX) + channel(self.SPEED_GOAL_SHAPED_REW_IDX)

        if cfg.remove_shaped_veh_reward:
            veh_veh = -1 * channel(self.VEH_VEH_COLLISION_REW_IDX) * cfg.veh_veh_collision_rew_multiplier
        elif cfg.only_shaped_veh_reward:
            veh_veh = veh_veh_dist_rewards
        else:
            veh_veh = veh_veh_dist_rewards - \
                channel(self.VEH_VEH_COLLISION_REW_IDX) * cfg.veh_veh_collision_rew_multiplier

        if cfg.remove_shaped_edge_reward:
            veh_edge = -1 * channel(self.VEH_EDGE_COLLISION_REW_IDX) * cfg.veh_edge_collision_rew_multiplier
        elif cfg.only_shaped_edge_reward:
            veh_edge = edge_shaped()
        else:
            veh_edge = edge_shaped() - \
                channel(self.VEH_EDGE_COLLISION_REW_IDX) * cfg.veh_edge_collision_rew_multiplier

        streams = (goal_pos, goal_heading, goal_speed, veh_veh, veh_edge)
        return np.concatenate([s[:, :, np.newaxis] * exists for s in streams], axis=-1)

    
    def select_relevant_agents(self, agent_states, agent_types, actions, rtgs, goals, origin_agent_idx, timestep, moving_agent_mask):
        """Keep agents near the origin agent and pad to ``cfg.max_num_agents``.

        An agent is a candidate when its (x, y) at ``timestep`` is closer than
        ``cfg.agent_dist_threshold`` to the origin agent's. At most
        ``cfg.max_num_agents`` of the nearest agents are kept.
        ``np.intersect1d`` returns the kept ids sorted by agent index; in
        ``'train'`` mode they are then shuffled in place. Padding rows are 0,
        except agent types, which are padded with -1.

        Returns:
            (states, types, actions, rtgs, goals, moving_mask, origin_row).
            ``origin_row`` is the origin agent's row in the returned arrays.
        """
        origin_xy = agent_states[origin_agent_idx, timestep, :2].reshape(1, -1)
        dists = np.linalg.norm(origin_xy - agent_states[:, timestep, :2], axis=-1)
        in_range = np.where(dists < self.cfg.agent_dist_threshold)[0]
        capacity = self.cfg.max_num_agents

        kept = np.intersect1d(np.argsort(dists)[:capacity], in_range)
        if self.mode == 'train':
            # avoid ordering agents by index/distance during training
            np.random.shuffle(kept)
        n_kept = len(kept)

        def pad(source, fill):
            out = np.full((capacity, *source[0].shape), fill, dtype=float)
            out[:n_kept] = source[kept]
            return out

        mask_out = np.zeros(capacity)
        mask_out[:n_kept] = moving_agent_mask[kept]

        origin_row = np.where(kept == origin_agent_idx)[0][0]
        return (pad(agent_states, 0.0), pad(agent_types, -1.0), pad(actions, 0.0),
                pad(rtgs, 0.0), pad(goals, 0.0), mask_out, origin_row)

    
    def select_relevant_agents_new(self, agent_states, agent_types, actions, rtgs, goals, origin_agent_idx, timestep, moving_agent_mask, relevant_agent_idxs):
        """Like ``select_relevant_agents``, but optionally with a fixed agent set.

        With an empty ``relevant_agent_idxs``, the nearest in-range agents are
        chosen (and shuffled in ``'train'`` mode) as in
        ``select_relevant_agents``. Otherwise the given ids are used,
        restricted to those within ``cfg.agent_dist_threshold`` of the origin
        agent at ``timestep``. In that case ids that fell out of range are also
        removed from the returned ``relevant_agent_idxs``.

        NOTE(review): in the fixed-set branch the number of kept ids is not
        capped at ``cfg.max_num_agents``; more ids than that fail on assignment.

        Returns:
            (states, types, actions, rtgs, goals, moving_mask, old_to_new,
            relevant_agent_idxs). ``old_to_new`` maps each kept original agent
            index to its row in the padded arrays.

        Raises:
            IndexError: if the origin agent is not among the kept agents.
        """
        origin_xy = agent_states[origin_agent_idx, timestep, :2].reshape(1, -1)
        dists = np.linalg.norm(origin_xy - agent_states[:, timestep, :2], axis=-1)
        in_range = np.where(dists < self.cfg.agent_dist_threshold)[0]
        capacity = self.cfg.max_num_agents

        if len(relevant_agent_idxs) == 0:
            kept = np.intersect1d(np.argsort(dists)[:capacity], in_range)
            if self.mode == 'train':
                np.random.shuffle(kept)
        else:
            kept = np.intersect1d(np.array(relevant_agent_idxs).astype(int), in_range)
            if len(kept) < len(relevant_agent_idxs):
                dropped = np.setdiff1d(relevant_agent_idxs, kept)
                relevant_agent_idxs = [idx for idx in relevant_agent_idxs if idx not in dropped]
        n_kept = len(kept)

        def pad(source, fill):
            out = np.full((capacity, *source[0].shape), fill, dtype=float)
            out[:n_kept] = source[kept]
            return out

        mask_out = np.zeros(capacity)
        mask_out[:n_kept] = moving_agent_mask[kept]

        old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(kept)}
        # Consistency check only: raises IndexError when the origin agent was dropped.
        _ = np.where(kept == origin_agent_idx)[0][0]

        return (pad(agent_states, 0.0), pad(agent_types, -1.0), pad(actions, 0.0),
                pad(rtgs, 0.0), pad(goals, 0.0), mask_out, old_to_new, relevant_agent_idxs)


    def select_random_origin_agent(self, agent_states, moving_mask):
        """Pick a random moving agent that exists at step ``cfg.input_horizon``.

        Returns:
            (agent_index, True) on success, or (0, False) when no agent is
            both moving and present at that step.
        """
        present = agent_states[:, self.cfg.input_horizon, -1] == 1
        candidates = np.where(present * moving_mask)[0]
        if len(candidates) == 0:
            return 0, False
        return candidates[np.random.choice(len(candidates))], True


    def undiscretize_actions(self, actions):
        """Decode combined action indices back to continuous (accel, steer).

        Each index is split as ``accel_bin = idx // steer_discretization`` and
        ``steer_bin = idx % steer_discretization``. Each bin is then mapped
        linearly onto ``[min_accel, max_accel]`` or ``[min_steer, max_steer]``.

        Returns:
            float array of shape ``actions.shape[:2] + (2,)``.
        """
        cfg = self.cfg
        decoded = np.zeros((actions.shape[0], actions.shape[1], 2))
        decoded[..., 0] = actions // cfg.steer_discretization
        decoded[..., 1] = actions % cfg.steer_discretization

        ranges = ((cfg.accel_discretization, cfg.min_accel, cfg.max_accel),
                  (cfg.steer_discretization, cfg.min_steer, cfg.max_steer))
        for component, (n_bins, low, high) in enumerate(ranges):
            decoded[..., component] /= (n_bins - 1)
            decoded[..., component] = (decoded[..., component] * (high - low)) + low
        return decoded

    
    def get_tilt_logits(self, goal_tilt, veh_tilt, road_tilt):
        """Per-bin tilt values for the return-to-go channels.

        Column ``k`` equals ``tilt_k * linspace(0, 1, cfg.rtg_discretization)``
        for the goal and vehicle tilts. A third road-edge column is included
        only when ``cfg.use_veh_edge_rtg`` is set.

        Returns:
            array of shape ``(cfg.rtg_discretization, 2 or 3)``.
        """
        num_bins = self.cfg.rtg_discretization
        num_cols = 3 if self.cfg.use_veh_edge_rtg else 2
        ramp = np.linspace(0, 1, num_bins)

        rtg_bin_values = np.zeros((num_bins, num_cols))
        for col, tilt in enumerate((goal_tilt, veh_tilt, road_tilt)[:num_cols]):
            rtg_bin_values[:, col] = tilt * ramp
        return rtg_bin_values


    def undiscretize_rtgs(self, rtgs):
        """Map discretized return-to-go bins back to continuous values.

        Each bin ``b`` in channel ``k`` becomes
        ``b / (rtg_discretization - 1) * (max_k - min_k) + min_k``.
        Channel 0 uses the ``*_rtg_pos`` bounds and channel 1 the
        ``*_rtg_veh`` bounds. Channel 2 uses the ``*_rtg_road`` bounds and is
        decoded only when ``cfg.use_veh_edge_rtg``; otherwise any extra
        channels stay 0.
        """
        cfg = self.cfg
        denom = cfg.rtg_discretization - 1
        decoded = np.zeros_like(rtgs).astype(float)

        channel_bounds = [(0, cfg.min_rtg_pos, cfg.max_rtg_pos),
                          (1, cfg.min_rtg_veh, cfg.max_rtg_veh)]
        if cfg.use_veh_edge_rtg:
            channel_bounds.append((2, cfg.min_rtg_road, cfg.max_rtg_road))

        for ch, low, high in channel_bounds:
            decoded[:, :, ch] = rtgs[:, :, ch] / denom
            decoded[:, :, ch] = (decoded[:, :, ch] * (high - low)) + low
        return decoded

    
    def discretize_actions(self, actions):
        """Encode continuous (acceleration, steering) pairs as one action index.

        Each component is clipped to its configured range, normalised to
        [0, 1], and rounded to one of ``accel_discretization`` or
        ``steer_discretization`` bins. The bins are combined as
        ``accel_bin * steer_discretization + steer_bin``, the layout that
        ``undiscretize_actions`` decodes.

        Args:
            actions: array indexed ``[agent, time, component]`` with
                acceleration in component 0 and steering in component 1.
                It is not modified.

        Returns:
            float array of shape ``actions.shape[:2]`` with combined indices.

        The previous implementation wrote the normalised/rounded values back
        into ``actions``, corrupting the caller's array. For integer inputs it
        also truncated the [0, 1] normalisation to 0 or 1.
        """
        cfg = self.cfg
        actions = np.asarray(actions)
        # Keep floating inputs at their own precision; promote integers to float64.
        work_dtype = actions.dtype if np.issubdtype(actions.dtype, np.floating) else np.float64
        accel = actions[:, :, 0].astype(work_dtype)
        steer = actions[:, :, 1].astype(work_dtype)

        # normalise to [0, 1]
        accel = ((np.clip(accel, a_min=cfg.min_accel, a_max=cfg.max_accel) - cfg.min_accel)
                 / (cfg.max_accel - cfg.min_accel))
        steer = ((np.clip(steer, a_min=cfg.min_steer, a_max=cfg.max_steer) - cfg.min_steer)
                 / (cfg.max_steer - cfg.min_steer))

        # snap to bin indices
        accel_bin = np.round(accel * (cfg.accel_discretization - 1))
        steer_bin = np.round(steer * (cfg.steer_discretization - 1))

        # combine into a single categorical value
        return accel_bin * cfg.steer_discretization + steer_bin

    
    def discretize_rtgs(self, rtgs):
        """Round normalised return-to-go channels to bin indices, in place.

        Channels 0 and 1 (and channel 2 when ``cfg.use_veh_edge_rtg``) are
        multiplied by ``rtg_discretization - 1`` and rounded. The same array
        is modified and returned.
        """
        scale = self.cfg.rtg_discretization - 1
        num_channels = 3 if self.cfg.use_veh_edge_rtg else 2
        for ch in range(num_channels):
            rtgs[:, :, ch] = np.round(rtgs[:, :, ch] * scale)
        return rtgs

    
    def normalize_scene(self, agent_states, road_points, road_types, goals, origin_agent_idx):
        """Re-express agents, goals and roads relative to the origin agent at step 0.

        The translation is the origin agent's (x, y) at step 0. The rotation
        angle passed to ``apply_se2_transform`` is
        ``pi/2 + sign(-yaw) * |yaw|`` (i.e. ``pi/2 - yaw``), where ``yaw`` is
        the origin agent's heading (feature 4) at step 0. Positions are
        translated and rotated, velocities only rotated, and headings shifted
        via ``angle_sub_tensor``. ``agent_states``, ``goals`` and the xy of
        ``road_points`` are overwritten in place.

        Roads are then fitted to ``cfg.max_num_road_polylines``. When there
        are too many, the polylines whose largest masked point norm (point norm
        times the last road-point channel) is smallest are kept. Otherwise
        points are padded with 0 and types with -1.

        Returns:
            (agent_states, road_points, road_types, goals) after the transform.

        Raises:
            AssertionError: if a transformed agent heading leaves [-pi, pi].
        """
        ego_yaw = agent_states[origin_agent_idx, 0, 4]
        rotation = (np.pi / 2) + np.sign(-ego_yaw) * np.abs(ego_yaw)
        origin_xy = agent_states[origin_agent_idx, 0, :2].copy()[np.newaxis, np.newaxis, :]

        # agents: positions translated + rotated, velocities rotated, headings shifted
        agent_states[:, :, :2] = apply_se2_transform(coordinates=agent_states[:, :, :2],
                                                     translation=origin_xy,
                                                     yaw=rotation)
        agent_states[:, :, 2:4] = apply_se2_transform(coordinates=agent_states[:, :, 2:4],
                                                      translation=np.zeros_like(origin_xy),
                                                      yaw=rotation)
        agent_states[:, :, 4] = angle_sub_tensor(agent_states[:, :, 4], -rotation.reshape(1, 1))
        assert np.all(agent_states[:, :, 4] <= np.pi) and np.all(agent_states[:, :, 4] >= -np.pi)

        # goals: position always; velocity and heading only for 5-d goals
        goals[:, :2] = apply_se2_transform(coordinates=goals[:, :2],
                                           translation=origin_xy[:, 0],
                                           yaw=rotation)
        if self.cfg.goal_dim == 5:
            goals[:, 2:4] = apply_se2_transform(coordinates=goals[:, 2:4],
                                                translation=np.zeros_like(origin_xy[:, 0]),
                                                yaw=rotation)
            goals[:, 4] = angle_sub_tensor(goals[:, 4], -rotation.reshape(1))

        road_points[:, :, :2] = apply_se2_transform(coordinates=road_points[:, :, :2],
                                                    translation=origin_xy,
                                                    yaw=rotation)

        max_roads = self.cfg.max_num_road_polylines
        if len(road_points) <= max_roads:
            kept_points = np.zeros((max_roads, *road_points.shape[1:]))
            kept_points[:len(road_points)] = road_points
            kept_types = -np.ones((max_roads, road_types.shape[1]))
            kept_types[:len(road_points)] = road_types
        else:
            farthest_point = (np.linalg.norm(road_points[:, :, :2], axis=-1) * road_points[:, :, -1]).max(1)
            nearest_ids = np.argsort(farthest_point)[:max_roads]
            kept_points = road_points[nearest_ids]
            kept_types = road_types[nearest_ids]

        return agent_states, kept_points, kept_types, goals

    def get_distance_to_road_edge(self, agent_states, road_feats):
        """Distance from every agent position to the nearest road-edge endpoint.

        Each row of ``road_feats`` holds two (x, y) points in columns 0:2 and
        2:4, and a type one-hot in columns 4:12. Rows whose argmax type index
        is 3 are treated as road edges (presumably the road-edge slot of
        ``get_road_type_onehot`` — confirm). Only those endpoints are compared
        against; the segment between them is not.

        Args:
            agent_states: array indexed ``[agent, time, feature]`` with (x, y)
                in features 0:2.
            road_feats: ``(R, >=12)`` road feature rows as described above.

        Returns:
            ``(A, T)`` array of minimum distances. It is all ``inf`` when no
            road-edge rows are present.

        Previously the result was reshaped to a hard-coded ``(24, 33)``, which
        failed for any other number of agents/timesteps. An empty edge set
        also raised from ``np.min`` over an empty axis.
        """
        road_type = np.argmax(road_feats[:, 4:12], axis=1).astype(int)
        edge_feats = road_feats[road_type == 3]
        edge_points = np.concatenate([edge_feats[:, :2], edge_feats[:, 2:4]], axis=0)

        num_agents, num_steps = agent_states.shape[:2]
        if len(edge_points) == 0:
            return np.full((num_agents, num_steps), np.inf)

        agent_positions = agent_states[:, :, :2].reshape(-1, 2)
        # (N, 1, 2) - (1, M, 2) -> (N, M, 2) offsets to every edge endpoint
        diff = agent_positions[:, np.newaxis, :] - edge_points[np.newaxis, :, :]
        min_squared_distances = np.min(np.sum(diff ** 2, axis=2), axis=1)
        return np.sqrt(min_squared_distances).reshape(num_agents, num_steps)

    
    def _get_constant_velocity_futures(self, present_states):
        """Roll present agent states forward under constant velocity.

        ``present_states`` is ``(A, 13)`` laid out as:
            0-1 local pos, 2-3 local vel, 4 local yaw,
            5-6 global pos, 7-8 global vel, 9 global yaw,
            10 length, 11 width, 12 existence.

        Produces ``cfg.train_context_length - 10`` future steps spaced 0.1 s
        apart, starting one step after the present. Local and global positions
        advance by ``velocity * elapsed``; every other feature is copied
        unchanged from the present state. The input is not modified.

        Returns:
            array of shape ``(A, T, 13)``.
        """
        states = present_states.copy()
        horizon = self.cfg.train_context_length - 10
        num_agents, state_dim = states.shape

        dt = 0.1
        elapsed = (np.arange(horizon) * dt + dt)[np.newaxis, :]  # (1, T): dt, 2*dt, ...

        futures = np.zeros((num_agents, horizon, state_dim))
        futures[:, :, 2:5] = states[:, np.newaxis, 2:5]
        futures[:, :, 7:] = states[:, np.newaxis, 7:]
        # (position feature, velocity feature) pairs: local x/y, global x/y
        for pos_dim, vel_dim in ((0, 2), (1, 3), (5, 7), (6, 8)):
            futures[:, :, pos_dim] = states[:, pos_dim:pos_dim + 1] + elapsed * states[:, vel_dim:vel_dim + 1]
        return futures



    def _prepare_relative_encodings(self, in_agents, present_states):
        """Pairwise features of every agent relative to every other agent.

        Args:
            in_agents: ``(A, T, D)`` states. Reads local velocity (features
                2:4), global position (5:7) and global yaw (9).
            present_states: ``(A, 1, D)`` reference states with the same layout.

        Returns:
            ``(A, A, T, 7)`` array. Entry ``[i, j, t]`` describes agent ``j``
            at step ``t`` relative to agent ``i``'s reference state:
              0-1: position offset ``p_j(t) - p_i`` rotated by ``pi/2 - yaw_i``
                   (the rotation that maps heading ``yaw_i`` onto +y),
              2-3: cos / sin of ``dyaw = yaw_j(t) - yaw_i``,
              4:   ``v_j(t) * cos(dyaw) - v_i``,
              5:   ``v_j(t) * sin(dyaw)``,
              6:   distance between the global positions of ``j`` and ``i`` at
                   step ``t``.
            Speeds ``v`` are norms of the velocity features.

        Channels 4/5 previously read the wrong channels (4 used the sine in
        channel 3, and 5 multiplied speed by channel 4). That mixed speed²
        terms into what should be relative velocity components.
        """
        relative_encoding_dimension = 7
        num_agents, num_steps = in_agents.shape[:2]
        relative_encodings = np.zeros((num_agents, num_agents, num_steps, relative_encoding_dimension))

        global_yaw_dimension = 9
        present_headings = present_states[:, 0, global_yaw_dimension].copy()

        # [N_agents, 2, 2] rotation by (pi/2 - heading) for every reference agent
        cosines = np.cos(-present_headings + np.pi / 2)
        sines = np.sin(-present_headings + np.pi / 2)
        rotation_matrices = np.array([
            [cosines, -sines],
            [sines, cosines]
        ]).transpose(2, 0, 1)

        global_positions_all = in_agents[:, :, 5:7].copy()
        present_positions_all = present_states[:, 0, 5:7].copy()
        global_yaws_all = in_agents[:, :, 9].copy()
        present_yaws_all = present_states[:, 0, 9].copy()
        global_speeds_all = np.linalg.norm(in_agents[:, :, 2:4], ord=2, axis=-1)
        present_speeds_all = np.linalg.norm(present_states[:, :1, 2:4], ord=2, axis=-1)

        for i in range(rotation_matrices.shape[0]):
            # (A, T, 2) offsets from agent i's reference position, in i's frame
            offsets = global_positions_all - present_positions_all[i]
            relative_encodings[i, :, :, :2] = np.matmul(offsets, rotation_matrices[i].T)

            yaw_offsets = global_yaws_all - present_yaws_all[i]
            cos_yaw = np.cos(yaw_offsets)
            sin_yaw = np.sin(yaw_offsets)
            relative_encodings[i, :, :, 2] = cos_yaw
            relative_encodings[i, :, :, 3] = sin_yaw

            relative_encodings[i, :, :, 4] = global_speeds_all * cos_yaw - present_speeds_all[i, 0]
            relative_encodings[i, :, :, 5] = global_speeds_all * sin_yaw

        # [i, j, t] = || p_j(t) - p_i(t) ||
        relative_encodings[:, :, :, 6] = np.linalg.norm(
            np.expand_dims(global_positions_all, 0) - np.expand_dims(global_positions_all, 1), ord=2, axis=-1)

        return relative_encodings
    
    def select_indiv_agent_roads(self, agent_states, road_points, road_types):
        num_agents = agent_states.shape[0]
        if len(road_points) > self.cfg.max_num_road_polylines:   
            road_existence = road_points[None, :, :, -1].copy()
            road_existence[np.where(road_existence == 0.)] = np.nan
            max_dist_to_road = np.nanmax(np.linalg.norm(road_points[None, :, :, :2] - 
                                            agent_states[:, -1:, None, :2], axis=-1) * road_existence, axis=2)
            closest_roads_to_agent = np.argsort(max_dist_to_road, axis=-1)[:, :self.cfg.max_num_road_polylines]
            repeated_road_points = np.repeat(road_points[None], num_agents, axis=0)
            final_road_points = np.take_along_axis(repeated_road_points, closest_roads_to_agent[:, :, None, None], axis=1)
            repeated_road_types = np.repeat(road_types[None], num_agents, axis=0)
            final_road_types = np.take_along_axis(repeated_road_types, closest_roads_to_agent[:, : , None], axis=1)
        else:
            final_road_points = np.zeros((num_agents, self.cfg.max_num_road_polylines, *road_points.shape[1:]))
            final_road_points[:, :len(road_points)] = road_points[None]
            final_road_types = -np.ones((num_agents, self.cfg.max_num_road_polylines, road_types.shape[1]))
            final_road_types[:, :len(road_points)] = road_types[None]

        final_road_points[:, :, :, -1] = agent_states[:, -1:, -1:] * final_road_points[:, :, :, -1]
        final_road_types = final_road_types * agent_states[:, -1:, -1:]

        return final_road_points, final_road_types

    
    def _get_agents_local_frame_eval(self, agent_pasts, road_points, goals):
        """Re-express each agent's history, roads and goal in its own present-time frame.

        Same transform as ``_get_agents_local_frame`` but without future states.
        For every agent whose last past step has a truthy existence flag
        (``agent_pasts[n, -1, -1]``), the frame origin is ``agent_pasts[n, -1, :2]``.
        The rotation angle is ``pi/2 - yaw``, with yaw read from column 4 of that step.
        NOTE(review): presumably this aligns the heading with +y; depends on the
        convention of ``apply_se2_transform`` — confirm.

        Args:
            agent_pasts: array ``[num_agents, T, F]``; columns 0:2 are translated and
                rotated, 2:4 only rotated, 4 is an angle, -1 is existence.
            road_points: per-agent roads indexed ``[agent, polyline, point, channel]``;
                last channel is point existence.
            goals: per-agent goals; columns 0:2 translated and rotated and, when
                ``cfg.goal_dim == 5``, 2:4 rotated and 4 angle-shifted. Modified in place.

        Returns:
            ``(new_agents_pasts, new_roads, goals)``. ``new_agents_pasts`` has shape
            ``[num_agents, T, F + 5]``: 5 local-frame columns followed by the original
            features. Skipped agents keep zeros in the local columns and
            untouched roads/goals.
        """
        def _se2(coords, origin, theta):
            return apply_se2_transform(coordinates=coords, translation=origin, yaw=theta)

        past_out = np.zeros((agent_pasts.shape[0], agent_pasts.shape[1], agent_pasts.shape[2] + 5))
        past_out[:, :, 5:] = agent_pasts
        roads_out = road_points.copy()

        for n, alive_now in enumerate(agent_pasts[:, -1, -1]):
            if not alive_now:
                continue

            theta = np.pi / 2 - agent_pasts[n, -1, 4]
            origin = agent_pasts[n, -1, :2]

            past_out[n, :, :2] = _se2(agent_pasts[n, :, :2], origin.reshape(1, -1), theta)
            past_out[n, :, 2:4] = _se2(agent_pasts[n, :, 2:4], np.zeros_like(origin).reshape(1, -1), theta)
            past_out[n, :, 4] = angle_sub_tensor(agent_pasts[n, :, 4], -theta)

            roads_out[n, :, :, :2] = _se2(road_points[n, :, :, :2], origin.reshape(1, 1, -1), theta)
            # Missing road points are forced back to all-zero after the transform.
            roads_out[n][roads_out[n, :, :, -1] == 0] = 0.0

            goals[n:n+1, :2] = _se2(goals[n:n+1, :2], origin.reshape(1, -1), theta)
            if self.cfg.goal_dim == 5:
                goals[n:n+1, 2:4] = _se2(goals[n:n+1, 2:4], np.zeros_like(origin).reshape(1, -1), theta)
                goals[n:n+1, 4] = angle_sub_tensor(goals[n:n+1, 4], -theta)

        return past_out, roads_out, goals

    
    def _get_agents_local_frame(self, agent_pasts, agent_futures, road_points, goals):
        """Re-express pasts, futures, roads and goals in each agent's present-time frame.

        For every agent whose last past step has a truthy existence flag
        (``agent_pasts[n, -1, -1]``), the frame origin is ``agent_pasts[n, -1, :2]``.
        The rotation angle is ``pi/2 - yaw``, with yaw read from column 4 of that step.
        NOTE(review): presumably this aligns the heading with +y; depends on the
        convention of ``apply_se2_transform`` — confirm.

        Args:
            agent_pasts: array ``[num_agents, T_past, F]``; columns 0:2 are
                translated and rotated, 2:4 only rotated, 4 is an angle, -1 is existence.
            agent_futures: array ``[num_agents, T_future, F]`` with the same layout;
                transformed into the frame defined by the *past* last step.
            road_points: per-agent roads indexed ``[agent, polyline, point, channel]``;
                last channel is point existence.
            goals: per-agent goals; columns 0:2 translated and rotated and, when
                ``cfg.goal_dim == 5``, 2:4 rotated and 4 angle-shifted. Modified in place.

        Returns:
            ``(new_agents_pasts, new_agents_futures, new_roads, goals)``; the state
            arrays gain 5 leading local-frame columns ahead of the original features.
            Skipped agents keep zeros in the local columns and untouched roads/goals.
        """
        def _se2(coords, origin, theta):
            return apply_se2_transform(coordinates=coords, translation=origin, yaw=theta)

        past_out = np.zeros((agent_pasts.shape[0], agent_pasts.shape[1], agent_pasts.shape[2] + 5))
        past_out[:, :, 5:] = agent_pasts
        future_out = np.zeros((agent_futures.shape[0], agent_futures.shape[1], agent_futures.shape[2] + 5))
        future_out[:, :, 5:] = agent_futures
        roads_out = road_points.copy()

        for n, alive_now in enumerate(agent_pasts[:, -1, -1]):
            if not alive_now:
                continue

            theta = np.pi / 2 - agent_pasts[n, -1, 4]
            origin = agent_pasts[n, -1, :2]

            # Past first, then future: both use the same present-time frame.
            for src, dst in ((agent_pasts, past_out), (agent_futures, future_out)):
                dst[n, :, :2] = _se2(src[n, :, :2], origin.reshape(1, -1), theta)
                dst[n, :, 2:4] = _se2(src[n, :, 2:4], np.zeros_like(origin).reshape(1, -1), theta)
                dst[n, :, 4] = angle_sub_tensor(src[n, :, 4], -theta)

            roads_out[n, :, :, :2] = _se2(road_points[n, :, :, :2], origin.reshape(1, 1, -1), theta)
            # Missing road points are forced back to all-zero after the transform.
            roads_out[n][roads_out[n, :, :, -1] == 0] = 0.0

            goals[n:n+1, :2] = _se2(goals[n:n+1, :2], origin.reshape(1, -1), theta)
            if self.cfg.goal_dim == 5:
                goals[n:n+1, 2:4] = _se2(goals[n:n+1, 2:4], np.zeros_like(origin).reshape(1, -1), theta)
                goals[n:n+1, 4] = angle_sub_tensor(goals[n:n+1, 4], -theta)

        return past_out, future_out, roads_out, goals

    def _normalize_actions(self, actions):
        actions[:, :, 0] = ((np.clip(actions[:, :, 0], a_min=self.cfg.min_accel, a_max=self.cfg.max_accel) - self.cfg.min_accel)
                             / (self.cfg.max_accel - self.cfg.min_accel))
        actions[:, :, 1] = ((np.clip(actions[:, :, 1], a_min=self.cfg.min_steer, a_max=self.cfg.max_steer) - self.cfg.min_steer)
                             / (self.cfg.max_steer - self.cfg.min_steer))
        
        return (2.0 * actions) - 1.0

    def _unnormalize_actions(self, actions):
        actions = (actions + 1.0) / 2.0
        actions[:, :, 0] = actions[:, :, 0] * (self.cfg.max_accel - self.cfg.min_accel) + self.cfg.min_accel
        actions[:, :, 1] = actions[:, :, 1] * (self.cfg.max_steer - self.cfg.min_steer) + self.cfg.min_steer
        
        return actions

    def get_data(self, data, idx):
        """Build one sample (or eval arrays) from a preprocessed or raw scene.

        With ``self.preprocess`` the arrays are read from a preprocessed dict.
        Otherwise ``data`` is a raw scene with an ``'objects'`` entry. The extracted
        arrays are then pickled to ``self.preprocessed_dir/<file stem>.pkl``, where
        the stem comes from ``self.files[idx]``.

        Args:
            data: preprocessed dict (``self.preprocess``) or raw scene dict.
            idx: index into ``self.files``. For preprocessed input it is replaced by
                ``data['idx']``.

        Returns:
            * ``None`` when ``cfg.preprocess_simulated_data`` is set (preprocessing only).
            * ``(rtgs, road_points, road_types)`` when ``self.mode == 'eval'``.
            * Otherwise ``(d, no_road_feats)``. ``d`` is a ``MotionData``; ``no_road_feats``
              is ``True`` (with an empty ``MotionData``) when the scene has no road
              polylines, ``select_random_origin_agent`` reports it invalid, or no
              agent's goal lies farther than ``cfg.moving_threshold`` from its
              first position.
        """
        if self.preprocess:
            idx = data['idx']
            num_agents = data['num_agents']
            road_points = data['road_points']
            road_types = data['road_types']
            ag_data = data['ag_data']
            ag_actions = data['ag_actions']
            ag_types = data['ag_types']
            last_exist_timesteps = data['last_exist_timesteps']
            if not self.cfg.goal_fix:
                rtgs = data['rtgs']
            else:
                ag_rewards = data['ag_rewards']
                veh_edge_dist_rewards = data['veh_edge_dist_rewards']
                veh_veh_dist_rewards = data['veh_veh_dist_rewards']
            filtered_ag_ids = data['filtered_ag_ids']
            if self.cfg.goal_fix:
                ag_goals = data['ag_goals']
            
        else:
            agent_data = data['objects']
            num_agents = len(agent_data)

            road_points, road_types, road_edge_polylines = self.get_roads(data)
            ag_data, ag_actions, ag_rewards, ag_types, ag_goals, parked_ids, incomplete_ids, last_exist_timesteps = self.extract_rawdata(agent_data)
            # zero out reward when timestep does not exist
            veh_edge_dist_rewards = self.compute_dist_to_nearest_road_edge_rewards(ag_data.copy(), road_edge_polylines) * ag_data[:, :, -1]
            veh_veh_dist_rewards = self.compute_dist_to_nearest_vehicle_rewards(ag_data.copy()) * ag_data[:, :, -1]
            if not self.cfg.goal_fix:
                all_rewards = self.compute_rewards(ag_data, ag_rewards, veh_edge_dist_rewards, veh_veh_dist_rewards)
                # reverse cumulative sum over time (return-to-go)
                rtgs = np.cumsum(all_rewards[:, ::-1], axis=1)[:, ::-1]

            raw_ag_ids = np.arange(num_agents).tolist()
            # no point in training on vehicles that don't have even one valid timestep
            filtered_ag_ids = list(filter(lambda x: x not in incomplete_ids, raw_ag_ids))
            assert len(filtered_ag_ids) > 0
            
            raw_file_name = os.path.splitext(os.path.basename(self.files[idx]))[0]
            to_pickle = dict()
            to_pickle['idx'] = idx
            to_pickle['num_agents'] = num_agents 
            to_pickle['road_points'] = road_points
            to_pickle['road_types'] = road_types
            to_pickle['ag_data'] = ag_data 
            to_pickle['ag_actions'] = ag_actions 
            to_pickle['ag_types'] = ag_types 
            to_pickle['last_exist_timesteps'] = last_exist_timesteps 
            if not self.cfg.goal_fix:
                to_pickle['rtgs'] = rtgs 
            else:
                to_pickle['veh_edge_dist_rewards'] = veh_edge_dist_rewards
                to_pickle['veh_veh_dist_rewards'] = veh_veh_dist_rewards 
                to_pickle['ag_rewards'] = ag_rewards
            to_pickle['filtered_ag_ids'] = filtered_ag_ids
            if self.cfg.preprocess_simulated_data:
                to_pickle['focal_agent_idx'] = data['focal_agent_idx']

            if self.cfg.goal_fix:
                assert ag_goals is not None
                to_pickle['ag_goals'] = ag_goals
            
            with open(os.path.join(self.preprocessed_dir, f'{raw_file_name}.pkl'), 'wb') as f:
                pickle.dump(to_pickle, f, protocol=pickle.HIGHEST_PROTOCOL)

        # we are only preprocessing the simulated data here
        if self.cfg.preprocess_simulated_data:
            return
        
        if self.cfg.goal_fix:
            all_rewards = self.compute_rewards(ag_data, ag_rewards, veh_edge_dist_rewards, veh_veh_dist_rewards)
            rtgs = np.cumsum(all_rewards[:, ::-1], axis=1)[:, ::-1]
        
        if self.mode == 'eval':
            return rtgs, road_points, road_types

        # Keep rtg channel 0 and 3 (plus 4 with use_veh_edge_rtg). After the concat
        # they are clipped/min-max scaled with the *_rtg_pos, *_rtg_veh and
        # *_rtg_road bounds respectively.
        if self.cfg.use_veh_edge_rtg:
            rtgs = np.concatenate([rtgs[:, :, :1], rtgs[:, :, 3:5]], axis=2)
        else:
            rtgs = np.concatenate([rtgs[:, :, :1], rtgs[:, :, 3:4]], axis=2)
        rtgs[:, :, 0] = ((np.clip(rtgs[:, :, 0], a_min=self.cfg.min_rtg_pos, a_max=self.cfg.max_rtg_pos) - self.cfg.min_rtg_pos)
                             / (self.cfg.max_rtg_pos - self.cfg.min_rtg_pos))
        rtgs[:, :, 1] = ((np.clip(rtgs[:, :, 1], a_min=self.cfg.min_rtg_veh, a_max=self.cfg.max_rtg_veh) - self.cfg.min_rtg_veh)
                             / (self.cfg.max_rtg_veh - self.cfg.min_rtg_veh))
        if self.cfg.use_veh_edge_rtg:
            rtgs[:, :, 2] = ((np.clip(rtgs[:, :, 2], a_min=self.cfg.min_rtg_road, a_max=self.cfg.max_rtg_road) - self.cfg.min_rtg_road)
                             / (self.cfg.max_rtg_road - self.cfg.min_rtg_road))
        
        if not self.cfg.goal_fix:
            # goal = state at the index (number of existing timesteps - 1)
            goal_timesteps = np.sum(ag_data[:, :, -1], axis=-1) - 1
            goal_timesteps_repeat = np.repeat(goal_timesteps[:, np.newaxis, np.newaxis], 5, axis=2)
            goals = np.take_along_axis(ag_data[:, :, :5], goal_timesteps_repeat.astype(int), axis=1)
            moving_ids = np.where(np.linalg.norm(ag_data[:, 0, :2] - goals[:, 0, :2], axis=1) > self.cfg.moving_threshold)[0]
            goals = goals[filtered_ag_ids, 0]
        else:
            moving_ids = np.where(np.linalg.norm(ag_data[:, 0, :2] - ag_goals[:, 0, :2], axis=1) > self.cfg.moving_threshold)[0]
            goals = ag_goals[filtered_ag_ids, 0]

        # No agent moves far enough: there is no candidate origin agent, and the
        # np.max below would raise on an empty selection. Report the scene as
        # invalid so the caller can skip it.
        if len(moving_ids) == 0:
            return MotionData({}), True
        
        # find max timestep to set as present timestep such that there exists an agent with train_context_length future timesteps
        max_timestep = np.max(last_exist_timesteps[moving_ids]) - (self.cfg.input_horizon + 1)
        # In this case, there will be some future timesteps such that all agents do not exist
        if max_timestep < 0:
            max_timestep = 0
        origin_t = np.random.randint(0, max_timestep+1)

        # Every agent/step carries the same value: the last past timestep.
        timesteps = np.full((self.cfg.max_num_agents, self.cfg.train_context_length, 1),
                            origin_t + self.cfg.input_horizon - 1)
        agent_states = ag_data[filtered_ag_ids, origin_t:origin_t+self.cfg.train_context_length]
        agent_types = ag_types[filtered_ag_ids]
        # actions are shifted by one step (action t-1 accompanies state t); a
        # zero action is prepended when starting at t = 0
        if origin_t == 0:
            _actions = ag_actions[filtered_ag_ids, origin_t:origin_t+self.cfg.train_context_length-1]
            actions = np.concatenate((np.zeros((len(filtered_ag_ids), 1, _actions.shape[-1])), _actions), axis=1)
        else:
            actions = ag_actions[filtered_ag_ids, origin_t-1:origin_t+self.cfg.train_context_length-1]
        rtgs = rtgs[filtered_ag_ids, origin_t:origin_t+self.cfg.train_context_length]

        # padding incomplete scenes
        current_num_timesteps = agent_states.shape[1]
        if current_num_timesteps < self.cfg.train_context_length:
            # states
            padded_agent_states = np.zeros((len(filtered_ag_ids), self.cfg.train_context_length, agent_states.shape[-1]))
            padded_agent_states[:, :current_num_timesteps] = agent_states
            agent_states = padded_agent_states.copy()
            # actions
            padded_agent_actions = np.zeros((len(filtered_ag_ids), self.cfg.train_context_length, actions.shape[-1]))
            padded_agent_actions[:, :current_num_timesteps] = actions[:, :-1]  # this -1 is to account for timestep offset.
            actions = padded_agent_actions.copy()
            # rtgs
            padded_agent_rtgs = np.zeros((len(filtered_ag_ids), self.cfg.train_context_length, rtgs.shape[-1]))
            padded_agent_rtgs[:, :current_num_timesteps] = rtgs
            rtgs = padded_agent_rtgs.copy()

        # mask of filtered agents whose displacement exceeds cfg.moving_threshold
        moving_agent_mask = np.isin(filtered_ag_ids, moving_ids)
        # randomly choose moving agent to be at origin
        # if the scene does not contain any agent with >10 timesteps of existence, scene is invalid
        origin_agent_idx, valid_scene = self.select_random_origin_agent(agent_states, moving_agent_mask)
        
        if valid_scene:
            agent_states, agent_types, actions, rtgs, goals, moving_agent_mask, new_origin_agent_idx = self.select_relevant_agents(agent_states, agent_types, actions, rtgs, goals, origin_agent_idx, 0, moving_agent_mask)
        
        num_polylines = len(road_points)
        if num_polylines == 0 or not valid_scene:
            d = MotionData({})
            no_road_feats = True 
        else:
            agent_states_past = agent_states[:, :self.cfg.input_horizon]
            agent_states_future = agent_states[:, self.cfg.input_horizon:]
            road_points, road_types = self.select_indiv_agent_roads(agent_states_past, road_points, road_types)

            # per-agent (x, y, rotation angle) of the local frame, returned to the model
            yaw = agent_states_past[:, -1, 4].copy()
            angle_of_rotation = (np.pi / 2) + np.sign(-yaw) * np.abs(yaw)
            translation = agent_states_past[:, -1, :2].copy()
            translation_yaws = np.concatenate((translation, angle_of_rotation[:, None]), axis=-1)
            
            agent_states_past, agent_states_future, road_points, goals =\
                self._get_agents_local_frame(agent_states_past, agent_states_future, road_points, goals)
            
            past_relative_encoding = self._prepare_relative_encodings(agent_states_past, agent_states_past[:, -1:, :])
            if self.cfg.future_relative_encoding:
                if self.mode == 'train':
                    future_relative_encoding = self._prepare_relative_encodings(agent_states_future, agent_states_past[:, -1:, :])
                else:
                    # use rel_ag_states_future here to compare
                    cv_future_states = self._get_constant_velocity_futures(agent_states_past[:, -1])
                    future_relative_encoding = self._prepare_relative_encodings(cv_future_states, agent_states_past[:, -1:, :])
            else:
                future_relative_encoding = past_relative_encoding[:, :, -1:].repeat(agent_states_future.shape[1], axis=2)

            # remove global coordinates
            agent_states_past = np.concatenate((agent_states_past[:, :, 0:5], agent_states_past[:, :, 10:]), axis=-1)
            agent_states_future = np.concatenate((agent_states_future[:, :, 0:5], agent_states_future[:, :, -1:]), axis=-1)

            # apply normalization for diffusion, approx around 0.0
            agent_states_past[:, :, :2] /= self.cfg.state_normalizer.pos_div
            agent_states_past[:, :, 2:4] /= self.cfg.state_normalizer.vel_div
            agent_states_future[:, :, :2] /= self.cfg.state_normalizer.pos_div
            agent_states_future[:, :, 2:4] /= self.cfg.state_normalizer.vel_div
            goals[:, :2] /= self.cfg.state_normalizer.pos_div
            goals[:, 2:4] /= self.cfg.state_normalizer.vel_div
            road_points[:, :, :, :2] /= self.cfg.state_normalizer.pos_div

            actions = self._normalize_actions(actions)
            agent_actions_past = actions[:, :self.cfg.input_horizon]
            agent_actions_future = actions[:, self.cfg.input_horizon:]

            rtgs = self.discretize_rtgs(rtgs)
            rtgs = rtgs[:, :self.cfg.input_horizon]

            d = dict()
            d['idx'] = idx
            # need to add batch dim as pytorch_geometric batches along first dimension of torch Tensors
            d['agent'] = from_numpy({
                'agent_past_states': add_batch_dim(agent_states_past),
                'agent_past_actions': add_batch_dim(agent_actions_past),
                'agent_future_states': add_batch_dim(agent_states_future),
                'agent_future_actions': add_batch_dim(agent_actions_future),
                'past_relative_encodings': add_batch_dim(past_relative_encoding),
                'future_relative_encodings': add_batch_dim(future_relative_encoding),
                'agent_types': add_batch_dim(agent_types), 
                'goals': add_batch_dim(goals),
                'rtgs': add_batch_dim(rtgs),
                'timesteps': add_batch_dim(timesteps),
                'moving_agent_mask': add_batch_dim(moving_agent_mask),
                'agent_translation_yaws': add_batch_dim(translation_yaws)
            })
            d['map'] = from_numpy({
                'road_points': add_batch_dim(road_points),
                'road_types': add_batch_dim(road_types)
            })
            d = MotionData(d)
            no_road_feats = False

        return d, no_road_feats
        

    def get(self, idx: int):
        # search for file with at least 2 agents
        if not self.cfg.preprocess:
            proceed = False 
            while not proceed:
                with open(self.files[idx], 'r') as file:
                    data = json.load(file)
                    if len(data['objects']) == 1:
                        idx += 1
                    else:
                        proceed = True 
            
            d, no_road_feats = self.get_data(data, idx)

        else:
            proceed = False 
            while not proceed:
                raw_file_name = os.path.splitext(os.path.basename(self.files[idx]))[0]
                raw_path = os.path.join(self.preprocessed_dir, f'{raw_file_name}.pkl')
                if os.path.exists(raw_path):
                    with open(raw_path, 'rb') as f:
                        data = pickle.load(f)
                    proceed = True
                else:
                    idx += 1

                if proceed:
                    if self.mode =='train':
                        d, no_road_feats = self.get_data(data, idx)
                        # only load sample if it has a map
                        if no_road_feats:
                            proceed = False 
                            idx += 1
                    else:
                        rtgs, road_points, road_types = self.get_data(data, idx)
                        d = {
                            'rtgs': rtgs,
                            'road_points': road_points,
                            'road_types': road_types
                        }
        
        return d

    
    def len(self):
        return self.dset_len

@hydra.main(version_base=None, config_path="/home/ctrl-sim-dev/cfgs/", config_name="config")
def main(cfg):
    """Build the training split and wrap it in a shuffled DataLoader (smoke test).

    NOTE(review): ``config_path`` is an absolute, machine-specific path; adjust it
    (or make it relative) when running on another machine.
    """
    dataset_cfg = cfg.datasets.rl_waymo_diffusion
    dset = RLWaymoDataset(dataset_cfg, split_name='train')

    # Seed every RNG the dataset or loader may draw from.
    seed = 10
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    loader_kwargs = dict(
        batch_size=64,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )
    dloader = DataLoader(dset, **loader_kwargs)


if __name__ == '__main__':
    # hydra.main parses the command line and supplies `cfg`.
    main()