import os
import sys
import json
import glob
import hydra
import torch_scatter
import torch
import pickle
import random
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
torch.set_printoptions(threshold=100000)
import numpy as np
np.set_printoptions(suppress=True)

from utils.data import *
from utils.geometry import apply_se2_transform, angle_sub_tensor

class RLWaymoDatasetFineTuning(Dataset):
    """Fine-tuning dataset that combines simulated offline-RL scenes with real scenes.

    The constructor indexes pickled simulated scenes and samples a subset of
    pickled real scenes (see ``sample_real_indices``).

    The ``*_REW_IDX`` constants below are indices into the last axis of the raw
    per-agent reward arrays; ``compute_rewards`` uses them to pick out the
    individual reward channels.
    """
    # constants defining reward dimensions
    # 0: position target achieved (0 or 1)
    POS_TARGET_ACHIEVED_REW_IDX = 0
    # 1: heading target achieved (0 or 1)
    HEADING_TARGET_ACHIEVED_REW_IDX = 1
    # 2: speed target achieved (0 or 1)
    SPEED_TARGET_ACHIEVED_REW_IDX = 2
    # 3: position goal reward (shaped)
    POS_GOAL_SHAPED_REW_IDX = 3
    # 4: speed goal reward (shaped)
    SPEED_GOAL_SHAPED_REW_IDX = 4
    # 5: heading goal reward (shaped)
    HEADING_GOAL_SHAPED_REW_IDX = 5
    # 6: veh-veh collision reward (0 or 1)
    VEH_VEH_COLLISION_REW_IDX = 6
    # 7: veh-edge collision reward (0 or 1)
    VEH_EDGE_COLLISION_REW_IDX = 7
    
    def __init__(self, cfg):
        """Index the simulated and real scene pickles and sample the real subset.

        Args:
            cfg: configuration object. This constructor reads ``cfg.preprocess``
                and, through ``sample_real_indices``, ``cfg.replay_ratio``.
        """
        super().__init__()
        self.cfg = cfg

        # Simulated (offline-RL) scenes.
        # TODO: the dataset locations below are hard-coded and should be changed.
        self.simulated_files_dir = '/scratch/preprocessed_robust'
        simulated_pattern = self.simulated_files_dir + "/*.pkl"
        self.simulated_files = sorted(glob.glob(simulated_pattern))

        # Real scenes; only a sampled subset of these is used.
        real_pattern = os.path.join('/scratch/preprocess', "train_march_new") + "/*.pkl"
        self.real_files = sorted(glob.glob(real_pattern))
        self.sample_real_indices()

        # Total length = every simulated scene + the sampled real scenes.
        self.dset_len = len(self.simulated_files) + len(self.real_indices)
        self.preprocess = self.cfg.preprocess

        # Directories holding the preprocessed data of each split.
        self.simulated_preprocessed_dir = self.simulated_files_dir
        self.real_preprocessed_dir = '/scratch/preprocess/train_march_new'

    
    def sample_real_indices(self):
        num_samples = int(len(self.simulated_files) * (self.cfg.replay_ratio / (1 - self.cfg.replay_ratio)))
        self.real_indices = np.random.choice(len(self.real_files), size=num_samples, replace=False)
        print("First 10 indices", self.real_indices[:10])


    def get_roads(self, data):
        roads_data = data['roads']
        num_roads = len(roads_data)
        
        final_roads = []
        final_road_types = []
        road_edge_polylines = []
        for n in range(num_roads):
            curr_road_rawdat = roads_data[n]['geometry']
            if isinstance(curr_road_rawdat, dict):
                # for stop sign, repeat x/y coordinate along the point dimension
                final_roads.append(np.array((curr_road_rawdat['x'], curr_road_rawdat['y'], 1.0)).reshape(1, -1).repeat(self.cfg.max_num_road_pts_per_polyline, 0))
                final_road_types.append(get_road_type_onehot(roads_data[n]['type']))
            else:
                if roads_data[n]['type'] == 'road_edge':
                    polyline = []
                    for p in range(len(curr_road_rawdat)):
                        polyline.append(np.array((curr_road_rawdat[p]['x'], curr_road_rawdat[p]['y'])))
                    road_edge_polylines.append(np.array(polyline))
                
                # either we add points until we run out of points and append zeros
                # or we fill up with points until we reach max limit
                curr_road = []
                for p in range(len(curr_road_rawdat)):
                    curr_road.append(np.array((curr_road_rawdat[p]['x'], curr_road_rawdat[p]['y'], 1.0)))
                    if len(curr_road) == self.cfg.max_num_road_pts_per_polyline:
                        final_roads.append(np.array(curr_road))
                        curr_road = []
                        final_road_types.append(get_road_type_onehot(roads_data[n]['type']))
                if len(curr_road) < self.cfg.max_num_road_pts_per_polyline and len(curr_road) > 0:
                    tmp_curr_road = np.zeros((self.cfg.max_num_road_pts_per_polyline, 3))
                    tmp_curr_road[:len(curr_road)] = np.array(curr_road)
                    final_roads.append(tmp_curr_road)
                    final_road_types.append(get_road_type_onehot(roads_data[n]['type']))

        return np.array(final_roads), np.array(final_road_types), road_edge_polylines


    def extract_rawdata(self, agents_data):
        # Get indices of non-parked cars and cars that exist for the entire episode
        agent_data = []
        agent_types = []
        agent_actions = []
        agent_rewards = []
        if self.cfg.goal_fix:
            agent_goals = []
        parked_agent_ids = []
        incomplete_ids = []
        last_exist_timesteps = []

        for n in range(len(agents_data)):
            ag_position = agents_data[n]['position']
            x_values = [entry['x'] for entry in ag_position]
            y_values = [entry['y'] for entry in ag_position]
            ag_position = np.column_stack((x_values, y_values))
            ag_heading = np.array(agents_data[n]['heading']).reshape((-1, 1))
            ag_velocity = agents_data[n]['velocity']
            x_values = [entry['x'] for entry in ag_velocity]
            y_values = [entry['y'] for entry in ag_velocity]
            ag_velocity = np.column_stack((x_values, y_values))
            # parked vehicle: average velocity < 0.05m/s
            if np.linalg.norm(ag_velocity, axis=-1).mean() < self.cfg.parked_car_velocity_threshold:
                parked_agent_ids.append(n)

            ag_existence = np.array(agents_data[n]['existence']).reshape((-1, 1))
            idx_of_first_disappearance = np.where(ag_existence == 0.0)[0]
            # once we find first missing step, all subsequent steps should be missing (as simulation is now undefined)
            if len(idx_of_first_disappearance) > 0:
                assert np.all(ag_existence[idx_of_first_disappearance[0]:] == 0.0)

            # only one timestep in ground-truth trajectory so no valid timesteps in offline RL dataset
            # since we need at least two timesteps in ground-truth trajectory to define a valid action with inverse Bicycle Model
            if len(idx_of_first_disappearance) > 0 and idx_of_first_disappearance[0] == 0:
                incomplete_ids.append(n)
                idx_of_last_existence = -1
            else:
                # for each agent, get the timestep of last existence
                idx_of_last_existence = np.where(ag_existence == 1.0)[0][-1]
                
            last_exist_timesteps.append(idx_of_last_existence)
            ag_actions = np.column_stack((agents_data[n]['acceleration'], agents_data[n]['steering']))

            ag_length = np.ones((len(ag_position), 1)) * agents_data[n]['length']
            ag_width = np.ones((len(ag_position), 1)) * agents_data[n]['width']
            agent_type = get_object_type_onehot(agents_data[n]['type'])
            # zero out reward for missing timesteps 
            rewards = np.array(agents_data[n]['reward']) * ag_existence 

            if self.cfg.goal_fix:
                goal_position_x = agents_data[n]['goal_position']['x']
                goal_position_y = agents_data[n]['goal_position']['y']
                goal_heading = agents_data[n]['goal_heading']
                goal_speed = agents_data[n]['goal_speed']
                goal_velocity_x = goal_speed * np.cos(goal_heading)
                goal_velocity_y = goal_speed * np.sin(goal_heading)
                goal = np.array([goal_position_x, goal_position_y, goal_velocity_x, goal_velocity_y, goal_heading])
                goal = np.repeat(goal[None, :], len(ag_position), 0)

            ag_state = np.concatenate((ag_position, ag_velocity, ag_heading, ag_length, ag_width, ag_existence), axis=-1)
            agent_data.append(ag_state)
            agent_actions.append(ag_actions)
            agent_rewards.append(rewards)
            agent_types.append(agent_type)
            if self.cfg.goal_fix:
                agent_goals.append(goal)
        
        # convert to numpy array
        agent_data = np.array(agent_data)
        agent_actions = np.array(agent_actions)
        agent_rewards = np.array(agent_rewards)
        agent_types = np.array(agent_types)
        if self.cfg.goal_fix:
            agent_goals = np.array(agent_goals)
        else:
            agent_goals = None
        parked_agent_ids = np.array(parked_agent_ids)
        incomplete_ids = np.array(incomplete_ids)
        last_exist_timesteps = np.array(last_exist_timesteps)
        
        return agent_data, agent_actions, agent_rewards, agent_types, agent_goals, parked_agent_ids, incomplete_ids, last_exist_timesteps


    def compute_dist_to_nearest_road_edge_rewards(self, ag_data, road_edge_polylines):
        # get all road edge polylines
        dist_to_road_edge_rewards = []
        for n in range(len(ag_data)):
            dist_to_road_edge = compute_distance_to_road_edge(ag_data[n, :, 0].reshape(1, -1),
                                                                ag_data[n, :, 1].reshape(1, -1), road_edge_polylines)
            dist_to_road_edge_rewards.append(-dist_to_road_edge / self.cfg.dist_to_road_edge_scaling_factor)
        
        dist_to_road_edge_rewards = np.array(dist_to_road_edge_rewards)
        
        return dist_to_road_edge_rewards

    
    def compute_dist_to_nearest_road_edge_rewards_gpu(self, ag_data, road_edge_polylines):
        """CUDA variant of ``compute_dist_to_nearest_road_edge_rewards``.

        ``ag_data`` is converted with ``from_numpy`` and moved to the GPU; each
        agent's distances come from ``compute_distance_to_road_edge_gpu`` and
        are brought back to a NumPy array before being negated and divided by
        ``cfg.dist_to_road_edge_scaling_factor``.

        Returns:
            np.ndarray: the stacked per-agent results.
        """
        ag_data = from_numpy(ag_data).cuda()

        per_agent_rewards = []
        for trajectory in ag_data:
            xs = trajectory[:, 0].reshape(1, -1)
            ys = trajectory[:, 1].reshape(1, -1)
            distances = compute_distance_to_road_edge_gpu(xs, ys, road_edge_polylines)
            distances = distances.cpu().numpy()
            per_agent_rewards.append(-distances / self.cfg.dist_to_road_edge_scaling_factor)

        return np.array(per_agent_rewards)


    def compute_dist_to_nearest_vehicle_rewards(self, ag_data, normalize=True):
        num_timesteps = ag_data.shape[1]
        
        ag_positions = ag_data[:,:,:2]
        ag_existence = ag_data[:,:,-1]

        # set x/y position at each nonexisting timestep to np.inf
        mask = np.repeat(ag_existence[:,:,np.newaxis], repeats=2, axis=-1).astype(bool)
        ag_positions[~mask] = np.inf

        # data[:, np.newaxis] has shape (A, 1, 90, 2) and data[np.newaxis, :] has shape (1, A, 90, 2)
        # Subtracting these gives an array of shape (A, A, 90, 2) with pairwise differences
        diff = ag_positions[:, np.newaxis] - ag_positions[np.newaxis, :]
        squared_dist = np.sum(diff**2, axis=-1)

        # Replace zero distances (distance to self) with np.inf
        for i in range(num_timesteps):
            np.fill_diagonal(squared_dist[:,:,i], np.inf)

        # Find minimum distance for each agent at each timestep, shape (A, 90)
        dist_nearest_vehicle = np.sqrt(np.min(squared_dist, axis=1))
        # handles case when only one valid agent at specific timestep
        dist_nearest_vehicle[dist_nearest_vehicle == np.inf] = np.nan 
        
        # if dist > 15, give 15
        if normalize:
            dist_nearest_vehicle = np.clip(dist_nearest_vehicle * ag_existence, a_min=0.0, a_max=self.cfg.max_veh_veh_distance)
            # given that every reward is in [0, 15], we will normalize this to be between [0, 1] by simply dividing by 15.0
            dist_nearest_vehicle = dist_nearest_vehicle / self.cfg.max_veh_veh_distance
        else:
            dist_nearest_vehicle = dist_nearest_vehicle * ag_existence

        # set reward to 0 when undefined
        dist_nearest_vehicle = np.nan_to_num(dist_nearest_vehicle, nan=0.0)
        
        return dist_nearest_vehicle

    
    def compute_rewards(self, ag_data, ag_rewards, veh_edge_dist_rewards, veh_veh_dist_rewards):
        ag_existence = ag_data[:, :, -1:]

        processed_rewards = np.array(ag_rewards)
        if self.cfg.remove_shaped_goal:
            goal_pos_rewards = processed_rewards[:, :, self.POS_TARGET_ACHIEVED_REW_IDX] * self.cfg.pos_target_achieved_rew_multiplier
        elif self.cfg.only_shaped_goal:
            goal_pos_rewards = (np.clip(processed_rewards[:, :, self.POS_GOAL_SHAPED_REW_IDX], a_min=self.cfg.pos_goal_shaped_min, a_max=self.cfg.pos_goal_shaped_max) - self.cfg.pos_goal_shaped_max) * (1 / self.cfg.pos_goal_shaped_max)
        else:
            goal_pos_rewards = processed_rewards[:, :, self.POS_TARGET_ACHIEVED_REW_IDX] * self.cfg.pos_target_achieved_rew_multiplier \
                 + (np.clip(processed_rewards[:, :, self.POS_GOAL_SHAPED_REW_IDX], a_min=self.cfg.pos_goal_shaped_min, a_max=self.cfg.pos_goal_shaped_max) - self.cfg.pos_goal_shaped_max) * (1 / self.cfg.pos_goal_shaped_max)
        goal_pos_rewards = goal_pos_rewards[:, :, np.newaxis] * ag_existence

        goal_heading_rewards = processed_rewards[:, :, self.HEADING_TARGET_ACHIEVED_REW_IDX] + processed_rewards[:, :, self.HEADING_GOAL_SHAPED_REW_IDX]
        goal_heading_rewards = goal_heading_rewards[:, :, np.newaxis] * ag_existence

        goal_velocity_rewards = processed_rewards[:, :, self.SPEED_TARGET_ACHIEVED_REW_IDX] + processed_rewards[:, :, self.SPEED_GOAL_SHAPED_REW_IDX]
        goal_velocity_rewards = goal_velocity_rewards[:, :, np.newaxis] * ag_existence
        
        if self.cfg.remove_shaped_veh_reward:
            veh_veh_collision_rewards = -1 * processed_rewards[:, :, self.VEH_VEH_COLLISION_REW_IDX] * self.cfg.veh_veh_collision_rew_multiplier
        elif self.cfg.only_shaped_veh_reward:
            veh_veh_collision_rewards = veh_veh_dist_rewards
        else:
            veh_veh_collision_rewards = veh_veh_dist_rewards - \
                processed_rewards[:, :, self.VEH_VEH_COLLISION_REW_IDX] * self.cfg.veh_veh_collision_rew_multiplier

        veh_veh_collision_rewards = veh_veh_collision_rewards[:, :, np.newaxis] * ag_existence
        
        if self.cfg.remove_shaped_edge_reward:
            veh_edge_collision_rewards = -1 * processed_rewards[:, :, self.VEH_EDGE_COLLISION_REW_IDX] * self.cfg.veh_edge_collision_rew_multiplier
        elif self.cfg.only_shaped_edge_reward:
            veh_edge_collision_rewards = np.clip(np.abs(veh_edge_dist_rewards) * self.cfg.dist_to_road_edge_scaling_factor, a_min=0, a_max=5) / 5.
        else:
            veh_edge_collision_rewards = np.clip(np.abs(veh_edge_dist_rewards) * self.cfg.dist_to_road_edge_scaling_factor, a_min=0, a_max=5) / 5. - \
                processed_rewards[:, :, self.VEH_EDGE_COLLISION_REW_IDX] * self.cfg.veh_edge_collision_rew_multiplier

        veh_edge_collision_rewards = veh_edge_collision_rewards[:, :, np.newaxis] * ag_existence

        all_rewards = np.concatenate((goal_pos_rewards, goal_heading_rewards, goal_velocity_rewards,
                                    veh_veh_collision_rewards, veh_edge_collision_rewards), axis=-1)
        return all_rewards

    
    def select_relevant_agents(self, agent_states, agent_types, actions, rtgs, goals, origin_agent_idx, timestep, moving_agent_mask):
        origin_states = agent_states[origin_agent_idx, timestep, :2].reshape(1, -1)
        dist_to_origin = np.linalg.norm(origin_states - agent_states[:, timestep, :2], axis=-1)
        valid_agents = np.where(dist_to_origin < self.cfg.agent_dist_threshold)[0]
    
        final_agent_states = np.zeros((self.cfg.max_num_agents, *agent_states[0].shape))
        final_agent_types = -np.ones((self.cfg.max_num_agents, *agent_types[0].shape))
        final_actions = np.zeros((self.cfg.max_num_agents, *actions[0].shape))
        final_rtgs = np.zeros((self.cfg.max_num_agents, *rtgs[0].shape))
        final_goals = np.zeros((self.cfg.max_num_agents, *goals[0].shape))
        final_moving_agent_mask = np.zeros(self.cfg.max_num_agents)

        closest_ag_ids = np.argsort(dist_to_origin)[:self.cfg.max_num_agents]
        closest_ag_ids = np.intersect1d(closest_ag_ids, valid_agents)
        # shuffle ids so it is not ordered by distance
        np.random.shuffle(closest_ag_ids)
        
        final_agent_states[:len(closest_ag_ids)] = agent_states[closest_ag_ids]
        final_agent_types[:len(closest_ag_ids)] = agent_types[closest_ag_ids]
        final_actions[:len(closest_ag_ids)] = actions[closest_ag_ids]
        final_rtgs[:len(closest_ag_ids)] = rtgs[closest_ag_ids]
        final_goals[:len(closest_ag_ids)] = goals[closest_ag_ids]
        final_moving_agent_mask[:len(closest_ag_ids)] = moving_agent_mask[closest_ag_ids]

        # idx of origin agent in new state tensors
        new_origin_agent_idx = np.where(closest_ag_ids == origin_agent_idx)[0][0]
        
        return final_agent_states, final_agent_types, final_actions, final_rtgs, final_goals, final_moving_agent_mask, new_origin_agent_idx

    
    def select_relevant_agents_new(self, agent_states, agent_types, actions, rtgs, goals, origin_agent_idx, timestep, moving_agent_mask, relevant_agent_idxs):
        origin_states = agent_states[origin_agent_idx, timestep, :2].reshape(1, -1)
        dist_to_origin = np.linalg.norm(origin_states - agent_states[:, timestep, :2], axis=-1)
        valid_agents = np.where(dist_to_origin < self.cfg.agent_dist_threshold)[0]
    
        final_agent_states = np.zeros((self.cfg.max_num_agents, *agent_states[0].shape))
        final_agent_types = -np.ones((self.cfg.max_num_agents, *agent_types[0].shape))
        final_actions = np.zeros((self.cfg.max_num_agents, *actions[0].shape))
        final_rtgs = np.zeros((self.cfg.max_num_agents, *rtgs[0].shape))
        final_goals = np.zeros((self.cfg.max_num_agents, *goals[0].shape))
        final_moving_agent_mask = np.zeros(self.cfg.max_num_agents)

        if len(relevant_agent_idxs) == 0:
            closest_ag_ids = np.argsort(dist_to_origin)[:self.cfg.max_num_agents]
            closest_ag_ids = np.intersect1d(closest_ag_ids, valid_agents)
            # shuffle ids so it is not ordered by distance
            np.random.shuffle(closest_ag_ids)
        else:
            closest_ag_ids = np.array(relevant_agent_idxs).astype(int)
            closest_ag_ids = np.intersect1d(closest_ag_ids, valid_agents)
            
            if len(closest_ag_ids) < len(relevant_agent_idxs):
                out_of_range_vehicles = np.setdiff1d(relevant_agent_idxs, closest_ag_ids)
                relevant_agent_idxs = [idx for idx in relevant_agent_idxs if idx not in out_of_range_vehicles]
        
        final_agent_states[:len(closest_ag_ids)] = agent_states[closest_ag_ids]
        final_agent_types[:len(closest_ag_ids)] = agent_types[closest_ag_ids]
        final_actions[:len(closest_ag_ids)] = actions[closest_ag_ids]
        final_rtgs[:len(closest_ag_ids)] = rtgs[closest_ag_ids]
        final_goals[:len(closest_ag_ids)] = goals[closest_ag_ids]
        final_moving_agent_mask[:len(closest_ag_ids)] = moving_agent_mask[closest_ag_ids]

        # idx of origin agent in new state tensors
        new_agent_idx_dict = {}
        for new_idx, old_idx in enumerate(closest_ag_ids):
            new_agent_idx_dict[old_idx] = new_idx
        new_origin_agent_idx = np.where(closest_ag_ids == origin_agent_idx)[0][0]
        
        return final_agent_states, final_agent_types, final_actions, final_rtgs, final_goals, final_moving_agent_mask, new_agent_idx_dict, relevant_agent_idxs


    def select_random_origin_agent(self, agent_states, moving_mask, focal_agent_idx):
        
        # origin_idx for the simulated scene is the focal sim controlled agent
        origin_idx = None
        if self.cfg.center_on_focal_agent and focal_agent_idx:
            origin_idx = focal_agent_idx
        # search for moving agent that exists at first timestep
        else:
            valid_idxs = np.where((agent_states[:, 0, -1] == 1) * moving_mask)[0]
            rand_idx = np.random.choice(len(valid_idxs))
            origin_idx = valid_idxs[rand_idx]

        return origin_idx


    def undiscretize_actions(self, actions):
        # Initialize the array for the continuous actions
        actions_shape = (actions.shape[0], actions.shape[1], 2)
        continuous_actions = np.zeros(actions_shape)
        
        # Separate the combined actions back into their discretized components
        continuous_actions[:, :, 0] = actions // self.cfg.steer_discretization  # Acceleration component
        continuous_actions[:, :, 1] = actions % self.cfg.steer_discretization   # Steering component
        
        # Reverse the discretization
        continuous_actions[:, :, 0] /= (self.cfg.accel_discretization - 1)
        continuous_actions[:, :, 1] /= (self.cfg.steer_discretization - 1)
        
        # Denormalize to get back the original continuous values
        continuous_actions[:, :, 0] = (continuous_actions[:, :, 0] * (self.cfg.max_accel - self.cfg.min_accel)) + self.cfg.min_accel
        continuous_actions[:, :, 1] = (continuous_actions[:, :, 1] * (self.cfg.max_steer - self.cfg.min_steer)) + self.cfg.min_steer
        
        return continuous_actions

    
    def get_tilt_logits(self, goal_tilt, veh_tilt, road_tilt):
        if self.cfg.use_veh_edge_rtg:
            rtg_bin_values = np.zeros((self.cfg.rtg_discretization, 3))
        else:
            rtg_bin_values = np.zeros((self.cfg.rtg_discretization, 2))
        
        rtg_bin_values[:, 0] = goal_tilt * np.linspace(0, 1, self.cfg.rtg_discretization)
        rtg_bin_values[:, 1] = veh_tilt * np.linspace(0, 1, self.cfg.rtg_discretization)

        if self.cfg.use_veh_edge_rtg:
            rtg_bin_values[:, 2] = road_tilt * np.linspace(0, 1, self.cfg.rtg_discretization)

        return rtg_bin_values


    def undiscretize_rtgs(self, rtgs):
        continuous_rtgs = np.zeros_like(rtgs).astype(float)
        continuous_rtgs[:, :, 0] = rtgs[:, :, 0] / (self.cfg.rtg_discretization - 1)
        continuous_rtgs[:, :, 1] = rtgs[:, :, 1] / (self.cfg.rtg_discretization - 1)

        continuous_rtgs[:, :, 0] = (continuous_rtgs[:, :, 0] * (self.cfg.max_rtg_pos - self.cfg.min_rtg_pos)) + self.cfg.min_rtg_pos
        continuous_rtgs[:, :, 1] = (continuous_rtgs[:, :, 1] * (self.cfg.max_rtg_veh - self.cfg.min_rtg_veh)) + self.cfg.min_rtg_veh

        if self.cfg.use_veh_edge_rtg:
            continuous_rtgs[:, :, 2] = rtgs[:, :, 2] / (self.cfg.rtg_discretization - 1)
            continuous_rtgs[:, :, 2] = (continuous_rtgs[:, :, 2] * (self.cfg.max_rtg_road - self.cfg.min_rtg_road)) + self.cfg.min_rtg_road


        return continuous_rtgs

    
    def discretize_actions(self, actions):
        # normalize
        actions[:, :, 0] = ((np.clip(actions[:, :, 0], a_min=self.cfg.min_accel, a_max=self.cfg.max_accel) - self.cfg.min_accel)
                             / (self.cfg.max_accel - self.cfg.min_accel))
        actions[:, :, 1] = ((np.clip(actions[:, :, 1], a_min=self.cfg.min_steer, a_max=self.cfg.max_steer) - self.cfg.min_steer)
                             / (self.cfg.max_steer - self.cfg.min_steer))

        # discretize the actions
        actions[:, :, 0] = np.round(actions[:, :, 0] * (self.cfg.accel_discretization - 1))
        actions[:, :, 1] = np.round(actions[:, :, 1] * (self.cfg.steer_discretization - 1))

        # combine into a single categorical value
        combined_actions = actions[:, :, 0] * self.cfg.steer_discretization + actions[:, :, 1]

        return combined_actions

    
    def discretize_rtgs(self, rtgs):
        rtgs[:, :, 0] = np.round(rtgs[:, :, 0] * (self.cfg.rtg_discretization - 1))
        rtgs[:, :, 1] = np.round(rtgs[:, :, 1] * (self.cfg.rtg_discretization - 1))

        if self.cfg.use_veh_edge_rtg:
            rtgs[:, :, 2] = np.round(rtgs[:, :, 2] * (self.cfg.rtg_discretization - 1))

        return rtgs

    
    def normalize_scene(self, agent_states, road_points, road_types, goals, origin_agent_idx):
        """Express agents, goals and roads in the origin agent's frame.

        The frame is defined by the origin agent's position and heading at
        timestep 0, using a rotation angle of ``pi/2 - yaw``. Positions are
        transformed with ``apply_se2_transform`` using that translation;
        velocities use a zero translation; headings are shifted with
        ``angle_sub_tensor``. Goal velocity/heading are transformed only when
        ``cfg.goal_dim == 5``.

        Roads are then reduced to the ``cfg.max_num_road_polylines`` polylines
        whose farthest valid point is closest to the origin, or zero-padded
        (types: -1) up to that count.

        NOTE: ``agent_states``, ``goals`` and ``road_points`` are modified in
        place.

        Returns:
            tuple: ``(agent_states, final_road_points, final_road_types, goals)``.
        """
        # normalize scene to ego vehicle (this includes agent states, goals, and roads)
        yaw = agent_states[origin_agent_idx, 0, 4]
        # identical to (pi/2) + sign(-yaw) * |yaw|
        angle_of_rotation = (np.pi / 2) - yaw
        origin_xy = agent_states[origin_agent_idx, 0, :2].copy()
        translation = origin_xy[np.newaxis, np.newaxis, :]
        no_translation = np.zeros_like(translation)

        # --- agents: position, velocity, heading ---
        agent_states[:, :, :2] = apply_se2_transform(
            coordinates=agent_states[:, :, :2], translation=translation, yaw=angle_of_rotation)
        agent_states[:, :, 2:4] = apply_se2_transform(
            coordinates=agent_states[:, :, 2:4], translation=no_translation, yaw=angle_of_rotation)
        agent_states[:, :, 4] = angle_sub_tensor(agent_states[:, :, 4], -angle_of_rotation.reshape(1, 1))
        headings = agent_states[:, :, 4]
        assert np.all(headings <= np.pi) and np.all(headings >= -np.pi)

        # --- goals ---
        goal_translation = translation[:, 0]
        goals[:, :2] = apply_se2_transform(
            coordinates=goals[:, :2], translation=goal_translation, yaw=angle_of_rotation)
        if self.cfg.goal_dim == 5:
            goals[:, 2:4] = apply_se2_transform(
                coordinates=goals[:, 2:4], translation=np.zeros_like(goal_translation), yaw=angle_of_rotation)
            goals[:, 4] = angle_sub_tensor(goals[:, 4], -angle_of_rotation.reshape(1))

        # --- roads ---
        road_points[:, :, :2] = apply_se2_transform(
            coordinates=road_points[:, :, :2], translation=translation, yaw=angle_of_rotation)

        max_polylines = self.cfg.max_num_road_polylines
        if len(road_points) > max_polylines:
            # rank polylines by their farthest valid point (padding rows count as 0)
            point_dist = np.linalg.norm(road_points[:, :, :2], axis=-1) * road_points[:, :, -1]
            closest_roads_to_ego = np.argsort(point_dist.max(1))[:max_polylines]
            return agent_states, road_points[closest_roads_to_ego], road_types[closest_roads_to_ego], goals

        # fewer (or exactly as many) polylines than the limit: pad
        num_roads = len(road_points)
        final_road_points = np.zeros((max_polylines, *road_points.shape[1:]))
        final_road_points[:num_roads] = road_points
        final_road_types = -np.ones((max_polylines, road_types.shape[1]))
        final_road_types[:num_roads] = road_types

        return agent_states, final_road_points, final_road_types, goals


    def get_distance_to_road_edge(self, agent_states, road_feats):
        road_type = np.argmax(road_feats[:, 4:12], axis=1).astype(int)
        mask = road_type == 3
        road_feats = road_feats[mask]
        road_points = np.concatenate([road_feats[:, :2], road_feats[:, 2:4]], axis=0)
        agent_positions = agent_states[:, :, :2].reshape(-1, 2)
        # Compute the difference along each dimension [N, M, 2]
        diff = agent_positions[:, np.newaxis, :] - road_points[np.newaxis, :, :]
        # Compute the squared distances [N, M]
        squared_distances = np.sum(diff ** 2, axis=2)
        # Find the minimum squared distance for each point in array1 [N,]
        min_squared_distances = np.min(squared_distances, axis=1)
        # If you need the actual distances, take the square root
        min_distances = np.sqrt(min_squared_distances)

        min_distances = min_distances.reshape(24, 33)
        return min_distances

    
    def get_data(self, data, idx):
        if self.preprocess:
            idx = data['idx']
            num_agents = data['num_agents']
            road_points = data['road_points']
            road_types = data['road_types']
            ag_data = data['ag_data']
            ag_actions = data['ag_actions']
            ag_types = data['ag_types']
            last_exist_timesteps = data['last_exist_timesteps']
            if not self.cfg.goal_fix:
                rtgs = data['rtgs']
            else:
                ag_rewards = data['ag_rewards']
                veh_edge_dist_rewards = data['veh_edge_dist_rewards']
                veh_veh_dist_rewards = data['veh_veh_dist_rewards']
            filtered_ag_ids = data['filtered_ag_ids']
            if self.cfg.goal_fix:
                ag_goals = data['ag_goals']

            if 'focal_agent_idx' in data.keys():
                focal_agent_idx = data['focal_agent_idx']
            else:
                focal_agent_idx = None
            
        else:
            agent_data = data['objects']
            num_agents = len(agent_data)

            road_points, road_types, road_edge_polylines = self.get_roads(data)
            ag_data, ag_actions, ag_rewards, ag_types, ag_goals, parked_ids, incomplete_ids, last_exist_timesteps = self.extract_rawdata(agent_data)
            # zero out reward when timestep does not exist
            veh_edge_dist_rewards = self.compute_dist_to_nearest_road_edge_rewards(ag_data.copy(), road_edge_polylines) * ag_data[:, :, -1]
            veh_veh_dist_rewards = self.compute_dist_to_nearest_vehicle_rewards(ag_data.copy()) * ag_data[:, :, -1]
            if not self.cfg.goal_fix:
                all_rewards = self.compute_rewards(ag_data, ag_rewards, veh_edge_dist_rewards, veh_veh_dist_rewards)
                rtgs = np.cumsum(all_rewards[:, ::-1], axis=1)[:, ::-1]

            raw_ag_ids = np.arange(num_agents).tolist()
            # no point in training on vehicles that don't have even one valid timestep
            filtered_ag_ids = list(filter(lambda x: x not in incomplete_ids, raw_ag_ids))
            assert len(filtered_ag_ids) > 0
            
            raw_file_name = os.path.splitext(os.path.basename(self.files[idx]))[0]
            to_pickle = dict()
            to_pickle['idx'] = idx
            to_pickle['num_agents'] = num_agents 
            to_pickle['road_points'] = road_points
            to_pickle['road_types'] = road_types
            to_pickle['ag_data'] = ag_data 
            to_pickle['ag_actions'] = ag_actions 
            to_pickle['ag_types'] = ag_types 
            to_pickle['last_exist_timesteps'] = last_exist_timesteps 
            if not self.cfg.goal_fix:
                to_pickle['rtgs'] = rtgs 
            else:
                to_pickle['veh_edge_dist_rewards'] = veh_edge_dist_rewards
                to_pickle['veh_veh_dist_rewards'] = veh_veh_dist_rewards 
                to_pickle['ag_rewards'] = ag_rewards
            to_pickle['filtered_ag_ids'] = filtered_ag_ids
            if self.cfg.preprocess_simulated_data:
                to_pickle['focal_agent_idx'] = data['focal_agent_idx']

            if self.cfg.goal_fix:
                assert ag_goals is not None
                to_pickle['ag_goals'] = ag_goals
            
            with open(os.path.join(self.preprocessed_dir, f'{raw_file_name}.pkl'), 'wb') as f:
                pickle.dump(to_pickle, f, protocol=pickle.HIGHEST_PROTOCOL)

        # we are only preprocessing the simulated data here
        if self.cfg.preprocess_simulated_data:
            return

        ##### VEH_EDGE_REWARD TESTING WITH VISUALIZATION ######
        # for r in range(len(road_edge_polylines)):
        #     # only plot if it is a road edge
        #     plt.plot(road_edge_polylines[r][:, 0], road_edge_polylines[r][:, 1], color='black', linewidth=0.5)
        
        # coordinates = ag_data[:, :, :2]
        # coordinates_mask = ag_data[:, :, -1].astype(bool)
        # for a in range(1):
        #     plt.plot(coordinates[a, :, 0][coordinates_mask[a]], coordinates[a, :, 1][coordinates_mask[a]], color='blue')
        
        # for a in range(1):
        #     if np.sum(coordinates_mask[a])-1 < 0:
        #         continue
        #     plt.scatter(coordinates[a, np.sum(coordinates_mask[a])-1, 0], coordinates[a, np.sum(coordinates_mask[a])-1, 1], color='red', s=10)
        
        # plt.savefig('test_{}.png'.format(idx), dpi=250)
        # plt.clf()
        # print("{}".format(idx), veh_edge_dist_rewards[0][0::10] * 15)
        # print("{}".format(idx), ag_rewards[0, :, self.VEH_EDGE_COLLISION_REW_IDX][0::10])
        #####


        ###### VEH_VEH_REWARD TESTING WITH VISUALIZATION ######
        # for r in range(len(road_edge_polylines)):
        #     # only plot if it is a road edge
        #     plt.plot(road_edge_polylines[r][:, 0], road_edge_polylines[r][:, 1], color='black', linewidth=0.5)

        # coordinates = ag_data[:, :, :2]
        # coordinates_mask = ag_data[:, :, -1].astype(bool)
        # for a in range(1):
        #     plt.plot(coordinates[a, :, 0][coordinates_mask[a]], coordinates[a, :, 1][coordinates_mask[a]], color='blue')
        # for a in range(1, len(coordinates)):
        #     plt.plot(coordinates[a, :, 0][coordinates_mask[a]], coordinates[a, :, 1][coordinates_mask[a]], color='green')
        
        # for a in range(1):
        #     if np.sum(coordinates_mask[a])-1 < 0:
        #         continue
        #     plt.scatter(coordinates[a, np.sum(coordinates_mask[a])-1, 0], coordinates[a, np.sum(coordinates_mask[a])-1, 1], color='red', s=10)
        
        # plt.savefig('test_{}.png'.format(idx), dpi=250)
        # plt.clf()
        # print("{}".format(idx), veh_veh_dist_rewards[0][0::10] * 15)
        # print("{}".format(idx), ag_rewards[0, :, self.VEH_VEH_COLLISION_REW_IDX][0::10])
        ######

        if self.cfg.goal_fix:
            all_rewards = self.compute_rewards(ag_data, ag_rewards, veh_edge_dist_rewards, veh_veh_dist_rewards)
            rtgs = np.cumsum(all_rewards[:, ::-1], axis=1)[:, ::-1]

        if self.cfg.use_veh_edge_rtg:
            rtgs = np.concatenate([rtgs[:, :, :1], rtgs[:, :, 3:5]], axis=2)
        else:
            rtgs = np.concatenate([rtgs[:, :, :1], rtgs[:, :, 3:4]], axis=2)
        rtgs[:, :, 0] = ((np.clip(rtgs[:, :, 0], a_min=self.cfg.min_rtg_pos, a_max=self.cfg.max_rtg_pos) - self.cfg.min_rtg_pos)
                             / (self.cfg.max_rtg_pos - self.cfg.min_rtg_pos))
        rtgs[:, :, 1] = ((np.clip(rtgs[:, :, 1], a_min=self.cfg.min_rtg_veh, a_max=self.cfg.max_rtg_veh) - self.cfg.min_rtg_veh)
                             / (self.cfg.max_rtg_veh - self.cfg.min_rtg_veh))
        if self.cfg.use_veh_edge_rtg:
            rtgs[:, :, 2] = ((np.clip(rtgs[:, :, 2], a_min=self.cfg.min_rtg_road, a_max=self.cfg.max_rtg_road) - self.cfg.min_rtg_road)
                             / (self.cfg.max_rtg_road - self.cfg.min_rtg_road))
        
        if not self.cfg.goal_fix:
            goal_timesteps = np.sum(ag_data[:, :, -1], axis=-1) - 1
            goal_timesteps_repeat = np.repeat(goal_timesteps[:, np.newaxis, np.newaxis], 5, axis=2)
            goals = np.take_along_axis(ag_data[:, :, :5], goal_timesteps_repeat.astype(int), axis=1)
            moving_ids = np.where(np.linalg.norm(ag_data[:, 0, :2] - goals[:, 0, :2], axis=1) > self.cfg.moving_threshold)[0]
            goals = goals[filtered_ag_ids, 0]
        else:
            goals = ag_goals
            moving_ids = np.where(np.linalg.norm(ag_data[:, 0, :2] - goals[:, 0, :2], axis=1) > self.cfg.moving_threshold)[0]
            goals = ag_goals[filtered_ag_ids, 0]
        
        # find max timestep to set as present timestep such that there exists an agent with train_context_length future timesteps
        max_timestep = np.max(last_exist_timesteps[moving_ids]) - (self.cfg.train_context_length - 1)
        # In this case, there will be some future timesteps such that all agents do not exist
        if max_timestep < 0:
            max_timestep = 0
        origin_t = np.random.randint(0, max_timestep+1)

        # original_idxs[i] tells you the original idx before filtering
        if focal_agent_idx is not None:
            assert len(filtered_ag_ids) == ag_data.shape[0]
        
        timesteps = np.arange(self.cfg.train_context_length) + origin_t
        timesteps = np.repeat(timesteps[np.newaxis, :, np.newaxis], self.cfg.max_num_agents, 0)
        agent_states = ag_data[filtered_ag_ids, origin_t:origin_t+self.cfg.train_context_length]
        agent_types = ag_types[filtered_ag_ids]
        actions = ag_actions[filtered_ag_ids, origin_t:origin_t+self.cfg.train_context_length]
        rtgs = rtgs[filtered_ag_ids, origin_t:origin_t+self.cfg.train_context_length]

        # filter for agents that move at least 0.05 metres
        moving_agent_mask = np.isin(filtered_ag_ids, moving_ids)
        # randomly choose moving agent to be at origin
        origin_agent_idx = self.select_random_origin_agent(agent_states, moving_agent_mask, focal_agent_idx)
        
        if self.cfg.supervise_focal_agent:
            moving_agent_mask = np.zeros_like(moving_agent_mask)
            moving_agent_mask[focal_agent_idx] = 1
            moving_agent_mask = moving_agent_mask.astype(bool)

        agent_states, agent_types, actions, rtgs, goals, moving_agent_mask, new_origin_agent_idx = self.select_relevant_agents(agent_states, agent_types, actions, rtgs, goals, origin_agent_idx, 0, moving_agent_mask)
        actions = self.discretize_actions(actions)
        if not self.cfg.decision_transformer:
            rtgs = self.discretize_rtgs(rtgs)
        
        ##### VISUALIZATION FOR TESTING
        # if focal_agent_idx:
        #     for r in range(len(road_points)):
        #         # only plot if it is a road edge
        #         plt.plot(road_points[r, :, 0][road_points[r, :, -1].astype(bool)], road_points[r, :, 1][road_points[r, :, -1].astype(bool)], color='black', linewidth=0.5)
            
        #     coordinates = agent_states[:, :, :2]
        #     coordinates_mask = agent_states[:, :, -1].astype(bool)
        #     for a in range(len(coordinates)):
        #         if a == new_origin_agent_idx:
        #             color='purple'
        #         else:
        #             color='blue'
        #         plt.plot(coordinates[a, :, 0][coordinates_mask[a]], coordinates[a, :, 1][coordinates_mask[a]], color=color)
            
        #     for a in range(len(coordinates)):
        #         if a == new_origin_agent_idx:
        #             color='red'
        #         else:
        #             color='pink'
        #         if np.sum(coordinates_mask[a])-1 < 0:
        #             continue
        #         plt.scatter(goals[a, 0], goals[a, 1], color=color, s=10)

        #     plt.savefig('test_{}.png'.format(idx), dpi=250)
        #     plt.clf()
        #####
        
        num_polylines = len(road_points)
        if num_polylines == 0:
            d = MotionData({})
            no_road_feats = True 
        else:
            agent_states, road_points, road_types, goals = self.normalize_scene(agent_states, road_points, road_types, goals, new_origin_agent_idx)
            d = dict()
            d['idx'] = idx
            # need to add batch dim as pytorch_geometric batches along first dimension of torch Tensors
            d['agent'] = from_numpy({
                'agent_states': add_batch_dim(agent_states),
                'agent_types': add_batch_dim(agent_types), 
                'goals': add_batch_dim(goals),
                'actions': add_batch_dim(actions),
                'rtgs': add_batch_dim(rtgs),
                'timesteps': add_batch_dim(timesteps),
                'moving_agent_mask': add_batch_dim(moving_agent_mask)
            })
            d['map'] = from_numpy({
                'road_points': add_batch_dim(road_points),
                'road_types': add_batch_dim(road_types)
            })
            d = MotionData(d)
            no_road_feats = False

        return d, no_road_feats
        

    def get(self, idx: int):
        if idx < len(self.simulated_files):
            simulated_idx = idx
            raw_path = os.path.join(self.simulated_preprocessed_dir, self.simulated_files[simulated_idx])
            with open(raw_path, 'rb') as f:
                data = pickle.load(f)
            d, no_road_feats = self.get_data(data, simulated_idx)

        else:
            proceed = False 
            while not proceed:
                real_idx = self.real_indices[idx - len(self.simulated_files)]
                raw_path = os.path.join(self.real_preprocessed_dir, self.real_files[real_idx])
                if os.path.exists(raw_path):
                    with open(raw_path, 'rb') as f:
                        data = pickle.load(f)
                    proceed = True
                else:
                    idx += 1

                if proceed:
                    d, no_road_feats = self.get_data(data, real_idx)
                    # only load sample if it has a map
                    if no_road_feats:
                        proceed = False 
                        idx += 1
        
        return d

    
    def len(self):
        return self.dset_len

@hydra.main(version_base=None, config_path="/home/ctrl-sim-dev/cfgs/", config_name="config")
def main(cfg):
    """Smoke-test the fine-tuning dataset by pulling one batch through a DataLoader.

    Args:
        cfg: Hydra config; ``cfg.datasets.rl_waymo`` is handed to the dataset.
    """
    dset = RLWaymoDatasetFineTuning(cfg.datasets.rl_waymo)

    # NOTE(review): the seeds are set after the dataset is constructed, so any
    # random sampling done inside the constructor is not covered by them --
    # confirm whether that is intended.
    np.random.seed(2025)
    random.seed(2025)

    dloader = DataLoader(
        dset,
        batch_size=64,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )

    # One batch is enough to confirm that loading and collation work.
    for _ in tqdm(dloader):
        # sys.exit() rather than the site-provided exit(): the latter is meant
        # for the interactive shell and is missing when site is not loaded.
        sys.exit()


if __name__ == '__main__':
    main()