import argparse
import logging
import os
import random
import sys
from collections import OrderedDict
from copy import deepcopy
from os import path as osp
from os import path as osp
from typing import Dict, List, Optional, Tuple, Union

import mmcv
import numpy as np
from nuscenes.nuscenes import NuScenes
from nuscenes.utils.geometry_utils import view_points
# from pyquaternion import Quaternion
from shapely.geometry import MultiPoint, box

from mmdet3d.core.bbox.box_np_ops import points_cam2img
from mmdet3d.datasets import NuScenesDataset
from nuscenes.map_expansion.map_api import NuScenesMap, NuScenesMapExplorer
from nuscenes.eval.common.utils import quaternion_yaw, Quaternion
from nuscenes.map_expansion.bitmap import BitMap
from matplotlib.patches import Polygon as mPolygon

from functools import partial
from multiprocessing import Pool
import multiprocessing
from tqdm import tqdm

from shapely import affinity, ops
# from shapely.geometry import LineString, box, MultiPolygon, MultiLineString
from shapely.geometry import Polygon, MultiPolygon, LineString, Point, box, MultiLineString
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import networkx as nx
from PIL import Image
sys.path.append('.')

import osm_parser


# Lane-marking type name -> divider label. The labels are the
# ``divider_dashed`` / ``divider_solid`` annotation classes produced below.
NUSC_LANEMARKTYPE_TO_LABEL = dict(
    DOUBLE_DASHED_WHITE='divider_dashed',
    DOUBLE_SOLID_WHITE='divider_solid',
    SINGLE_SOLID_WHITE='divider_solid',
    SINGLE_ZIGZAG_WHITE='divider_solid',
    SINGLE_SOLID_YELLOW='divider_solid',
)


# Candidate groups of element types to mask. In ``'random'`` mode one group is
# drawn per sample; in ``'random_whole_dataset'`` mode every group yields its
# own copy of the sample.
RANDOM_MASKING_LIST = [
    ['centerline'],
    ['ped_crossing'],
    ['boundary'],
    ['divider_solid', 'divider_dashed'],
    ['boundary', 'divider_dashed', 'divider_solid', 'ped_crossing'],
    ['centerline', 'divider_dashed', 'divider_solid', 'ped_crossing'],
    ['boundary', 'centerline', 'divider_dashed', 'divider_solid',
     'ped_crossing'],
]

class CNuScenesMapExplorer(NuScenesMapExplorer):
    """``NuScenesMapExplorer`` that can also extract lane centerlines per patch."""

    def __init__(self, *args, **kwargs):
        # Forward all constructor arguments unchanged to NuScenesMapExplorer.
        super().__init__(*args, **kwargs)

    def _get_centerline(self,
                        patch_box: Tuple[float, float, float, float],
                        patch_angle: float,
                        layer_name: str,
                        return_token: bool = False) -> dict:
        """Retrieve the centerlines of a lane layer within the specified patch.

        :param patch_box: Patch box defined as [x_center, y_center, height, width].
        :param patch_angle: Patch orientation in degrees.
        :param layer_name: Name of the map layer, 'lane' or 'lane_connector'.
        :param return_token: Unused; kept for signature compatibility.
        :return: dict mapping lane token -> dict(centerline, token,
            incoming_tokens, outgoing_tokens), where ``centerline`` is the
            patch-clipped centerline geometry expressed in patch coordinates.
        :raises ValueError: If ``layer_name`` is not a centerline layer.
        """
        if layer_name not in ['lane', 'lane_connector']:
            raise ValueError('{} is not a centerline layer'.format(layer_name))

        patch_x = patch_box[0]
        patch_y = patch_box[1]
        patch = self.get_patch_coord(patch_box, patch_angle)

        centerline_dict = dict()
        for record in getattr(self.map_api, layer_name):
            if record['polygon_token'] is None:
                continue
            polygon = self.map_api.extract_polygon(record['polygon_token'])
            # Invalid lane polygons are skipped (not repaired), as are lanes
            # whose polygon does not touch the patch at all.
            if not polygon.is_valid or polygon.intersection(patch).is_empty:
                continue

            token = record['token']
            poses = self.map_api.discretize_lanes([token], 0.5)[token]
            # A LineString needs at least two points; an empty or single-pose
            # discretization would otherwise raise below.
            if len(poses) < 2:
                continue
            # Keep only the first two columns (x, y), rounded to 3 decimals.
            centerline = LineString(np.array(poses)[:, :2].round(3))
            if centerline.is_empty:
                continue
            centerline = centerline.intersection(patch)
            if centerline.is_empty:
                continue

            centerline_dict[token] = dict(
                centerline=to_patch_coord(centerline, patch_angle, patch_x, patch_y),
                token=token,
                incoming_tokens=self.map_api.get_incoming_lane_ids(token),
                outgoing_tokens=self.map_api.get_outgoing_lane_ids(token),
            )
        return centerline_dict

def to_patch_coord(new_polygon, patch_angle, patch_x, patch_y):
    """Express a global-frame geometry in the local frame of a map patch.

    The geometry is rotated by ``-patch_angle`` degrees about the patch
    centre ``(patch_x, patch_y)``. It is then shifted so that the centre lands
    on the origin.

    Args:
        new_polygon: Shapely geometry in global map coordinates.
        patch_angle (float): Patch orientation in degrees.
        patch_x (float): Patch centre x.
        patch_y (float): Patch centre y.

    Returns:
        The transformed shapely geometry.
    """
    derotated = affinity.rotate(
        new_polygon, -patch_angle, origin=(patch_x, patch_y), use_radians=False)
    # 2-D affine matrix [a, b, d, e, xoff, yoff]: identity rotation/scale plus
    # a translation that moves the patch centre to the origin.
    shift_to_origin = [1.0, 0.0, 0.0, 1.0, -patch_x, -patch_y]
    return affinity.affine_transform(derotated, shift_to_origin)



def get_available_scenes(nusc):
    """Return the scenes whose first LiDAR sweep passes the path check.

    Only the first ``LIDAR_TOP`` sample_data of each scene is inspected. Its
    path is made relative to the current working directory when it lies under
    it. The scene is kept when ``mmcv.is_filepath`` accepts that path.

    NOTE(review): ``mmcv.is_filepath`` appears to be a type check (str/Path)
    rather than an on-disk existence check. The path here is always a ``str``,
    so no scene may ever be filtered out. Confirm whether an existence check
    was intended.

    Args:
        nusc (class): Dataset class in the nuScenes dataset.

    Returns:
        list[dict]: The kept ``scene`` records, in their original order.
    """
    print('total scene num: {}'.format(len(nusc.scene)))
    available_scenes = []
    for scene in nusc.scene:
        scene_rec = nusc.get('scene', scene['token'])
        first_sample = nusc.get('sample', scene_rec['first_sample_token'])
        lidar_rec = nusc.get('sample_data', first_sample['data']['LIDAR_TOP'])
        lidar_path = str(nusc.get_sample_data(lidar_rec['token'])[0])
        cwd = os.getcwd()
        if cwd in lidar_path:
            # Absolute path under the working directory -> relative path.
            lidar_path = lidar_path.split(f'{cwd}/')[-1]
        if mmcv.is_filepath(lidar_path):
            available_scenes.append(scene)
    print('exist scene num: {}'.format(len(available_scenes)))
    return available_scenes

def _get_can_bus_info(nusc, nusc_can_bus, sample):
    scene_name = nusc.get('scene', sample['scene_token'])['name']
    sample_timestamp = sample['timestamp']
    try:
        pose_list = nusc_can_bus.get_messages(scene_name, 'pose')
    except:
        return np.zeros(18)  # server scenes do not have can bus information.
    can_bus = []
    # during each scene, the first timestamp of can_bus may be large than the first sample's timestamp
    last_pose = pose_list[0]
    for i, pose in enumerate(pose_list):
        if pose['utime'] > sample_timestamp:
            break
        last_pose = pose
    _ = last_pose.pop('utime')  # useless
    pos = last_pose.pop('pos')
    rotation = last_pose.pop('orientation')
    can_bus.extend(pos)
    can_bus.extend(rotation)
    for key in last_pose.keys():
        can_bus.extend(pose[key])  # 16 elements
    can_bus.extend([0., 0.])
    return np.array(can_bus)


def obtain_sensor2top(nusc,
                      sensor_token,
                      l2e_t,
                      l2e_r_mat,
                      e2g_t,
                      e2g_r_mat,
                      sensor_type='lidar'):
    """Describe a sensor reading and its rigid transform into the key-frame Top LiDAR.

    The returned transform maps row-vector points from the sensor frame to
    the key-frame LiDAR frame as ``points @ R.T + T``. The chain is
    sensor -> ego (this reading) -> global -> ego' -> lidar (key frame).

    Args:
        nusc (class): Dataset class in the nuScenes dataset.
        sensor_token (str): Sample data token corresponding to the
            specific sensor type.
        l2e_t (np.ndarray): Translation from lidar to ego in shape (1, 3).
        l2e_r_mat (np.ndarray): Rotation matrix from lidar to ego
            in shape (3, 3).
        e2g_t (np.ndarray): Translation from ego to global in shape (1, 3).
        e2g_r_mat (np.ndarray): Rotation matrix from ego to global
            in shape (3, 3).
        sensor_type (str): Sensor to calibrate. Default: 'lidar'.

    Returns:
        dict: Path, calibration and pose of the reading, plus
        ``sensor2lidar_rotation`` / ``sensor2lidar_translation``.
    """
    sd_rec = nusc.get('sample_data', sensor_token)
    calib = nusc.get('calibrated_sensor', sd_rec['calibrated_sensor_token'])
    ego_pose = nusc.get('ego_pose', sd_rec['ego_pose_token'])

    data_path = str(nusc.get_sample_data_path(sd_rec['token']))
    cwd = os.getcwd()
    if cwd in data_path:
        # Absolute path under the working directory -> relative path.
        data_path = data_path.split(f'{cwd}/')[-1]

    sweep = dict(
        data_path=data_path,
        type=sensor_type,
        sample_data_token=sd_rec['token'],
        sensor2ego_translation=calib['translation'],
        sensor2ego_rotation=calib['rotation'],
        ego2global_translation=ego_pose['translation'],
        ego2global_rotation=ego_pose['rotation'],
        timestamp=sd_rec['timestamp'],
    )

    s2e_r_mat = Quaternion(calib['rotation']).rotation_matrix
    e2g_r_s_mat = Quaternion(ego_pose['rotation']).rotation_matrix
    # Row-vector form of the inverse key-frame chain global -> ego -> lidar.
    inv_l2e_t_mat = np.linalg.inv(l2e_r_mat).T
    global2lidar = np.linalg.inv(e2g_r_mat).T @ inv_l2e_t_mat

    R = (s2e_r_mat.T @ e2g_r_s_mat.T) @ global2lidar
    T = (calib['translation'] @ e2g_r_s_mat.T + ego_pose['translation']) @ global2lidar
    T -= e2g_t @ global2lidar + l2e_t @ inv_l2e_t_mat

    sweep['sensor2lidar_rotation'] = R.T  # points @ R.T + T
    sweep['sensor2lidar_translation'] = T
    return sweep

def _fill_trainval_infos_one_sample(sample_tuple,
                                    nusc,
                                    nusc_can_bus,
                                    nusc_maps,
                                    osm_maps,
                                    map_explorer,
                                    masked_elements,
                                    remove_not_relevant_keys,
                                    test=False,
                                    max_sweeps=10,
                                    point_cloud_range=[-15.0, -30.0, -10.0, 15.0, 30.0, 10.0]):
    """Build the info dict(s) for one key-frame sample.

    Args:
        sample_tuple (tuple): ``(frame_idx, sample)`` where ``sample`` is a
            nuScenes ``sample`` record.
        nusc (:obj:`NuScenes`): Dataset class in the nuScenes dataset.
        nusc_can_bus: CAN bus API passed to ``_get_can_bus_info``.
        nusc_maps (dict): Map API per map location name.
        osm_maps (dict): OSM map per map location name.
        map_explorer (dict): Map explorer per map location name.
        masked_elements (list[str]): Element types to mark as masked.
            - If it contains ``'random_whole_dataset'``, one info is produced
              per entry of ``RANDOM_MASKING_LIST``.
            - Otherwise, if it contains ``'random'``, one entry of that list
              is drawn at random.
        remove_not_relevant_keys: Forwarded to ``obtain_vectormap``.
        test (bool): Unused.
        max_sweeps (int): Max number of previous LiDAR sweeps. Default: 10.
        point_cloud_range (list[float]): Indices 0/3 and 1/4 give the x and y
            extent of the map patch.

    Returns:
        list[dict]: Info dict(s) for this sample, each with an ``annotation`` key.
    """
    frame_idx, sample = sample_tuple

    scene = nusc.get('scene', sample['scene_token'])
    map_location = nusc.get('log', scene['log_token'])['location']

    lidar_token = sample['data']['LIDAR_TOP']
    sd_rec = nusc.get('sample_data', lidar_token)
    cs_record = nusc.get('calibrated_sensor', sd_rec['calibrated_sensor_token'])
    pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token'])
    lidar_path, _, _ = nusc.get_sample_data(lidar_token)

    mmcv.check_file_exist(lidar_path)
    can_bus = _get_can_bus_info(nusc, nusc_can_bus, sample)

    info = {
        'lidar_path': lidar_path,
        'token': sample['token'],
        'prev': sample['prev'],
        'next': sample['next'],
        'can_bus': can_bus,
        'frame_idx': frame_idx,  # temporal related info
        'sweeps': [],
        'cams': dict(),
        'map_location': map_location,
        'scene_token': sample['scene_token'],  # temporal related info
        'lidar2ego_translation': cs_record['translation'],
        'lidar2ego_rotation': cs_record['rotation'],
        'ego2global_translation': pose_record['translation'],
        'ego2global_rotation': pose_record['rotation'],
        'timestamp': sample['timestamp'],
    }

    l2e_t = info['lidar2ego_translation']
    e2g_t = info['ego2global_translation']
    l2e_r_mat = Quaternion(info['lidar2ego_rotation']).rotation_matrix
    e2g_r_mat = Quaternion(info['ego2global_rotation']).rotation_matrix

    # Information for the 6 cameras of this frame.
    camera_types = (
        'CAM_FRONT',
        'CAM_FRONT_RIGHT',
        'CAM_FRONT_LEFT',
        'CAM_BACK',
        'CAM_BACK_LEFT',
        'CAM_BACK_RIGHT',
    )
    for cam in camera_types:
        cam_token = sample['data'][cam]
        _, _, cam_intrinsic = nusc.get_sample_data(cam_token)
        cam_info = obtain_sensor2top(nusc, cam_token, l2e_t, l2e_r_mat,
                                     e2g_t, e2g_r_mat, cam)
        cam_info.update(cam_intrinsic=cam_intrinsic)
        info['cams'].update({cam: cam_info})

    # Walk backwards from the key-frame LiDAR record through ``prev`` links to
    # collect up to ``max_sweeps`` earlier sweeps.
    sweeps = []
    while len(sweeps) < max_sweeps and sd_rec['prev'] != '':
        sweeps.append(obtain_sensor2top(nusc, sd_rec['prev'], l2e_t,
                                        l2e_r_mat, e2g_t, e2g_r_mat, 'lidar'))
        sd_rec = nusc.get('sample_data', sd_rec['prev'])
    info['sweeps'] = sweeps

    def _with_vectormap(target_info, elements):
        # Attach the map annotation computed with ``elements`` masked.
        return obtain_vectormap(nusc_maps, osm_maps, map_explorer, target_info,
                                point_cloud_range, elements,
                                remove_not_relevant_keys)

    if 'random_whole_dataset' in masked_elements:
        # Each masking variant gets an independent copy of the base info.
        return [_with_vectormap(deepcopy(info), elements)
                for elements in RANDOM_MASKING_LIST]
    if 'random' in masked_elements:
        return [_with_vectormap(info, random.choice(RANDOM_MASKING_LIST))]
    return [_with_vectormap(info, masked_elements)]

def _fill_trainval_infos(nusc,
                         nusc_can_bus,
                         nusc_maps,
                         osm_maps,
                         map_explorer,
                         train_scenes,
                         val_scenes,
                         masked_elements,
                         remove_not_relevant_keys,
                         test=False,
                         max_sweeps=10,
                         point_cloud_range=[-15.0, -30.0, -10.0, 15.0, 30.0, 10.0],
                         num_workers=32):
    """Generate the train/val infos from the raw data.

    Every key-frame sample is processed in a multiprocessing pool by
    ``_fill_trainval_infos_one_sample``. Each resulting info goes to the train
    list if its scene token is in ``train_scenes``, and to the val list
    otherwise. If interrupted with Ctrl+C, the infos collected so far are
    returned.

    Args:
        nusc (:obj:`NuScenes`): Dataset class in the nuScenes dataset.
        train_scenes (list[str]): Basic information of training scenes.
        val_scenes (list[str]): Basic information of validation scenes.
            Not consulted; anything not in ``train_scenes`` is treated as val.
        test (bool): Whether use the test mode. In the test mode, no
            annotations can be accessed. Default: False.
        max_sweeps (int): Max number of sweeps. Default: 10.
        num_workers (int | None): Number of worker processes. ``None`` lets
            ``multiprocessing.Pool`` choose. Default: 32.

    Returns:
        tuple[list[dict]]: Information of training set and validation set
            that will be saved to the info file.
    """
    fn = partial(_fill_trainval_infos_one_sample,
                 nusc=nusc,
                 nusc_can_bus=nusc_can_bus,
                 nusc_maps=nusc_maps,
                 osm_maps=osm_maps,
                 map_explorer=map_explorer,
                 masked_elements=masked_elements,
                 remove_not_relevant_keys=remove_not_relevant_keys,
                 test=test,
                 max_sweeps=max_sweeps,
                 point_cloud_range=point_cloud_range)

    # Position of each sample within its scene (0 for the first key-frame).
    # The index is recorded *before* advancing the counter, so the reset after
    # a scene's last sample (``next == ''``) applies to the next scene's first
    # sample. Assumes nusc.sample lists each scene's samples consecutively in
    # temporal order -- TODO confirm.
    frame_indices = np.zeros([len(nusc.sample)], dtype=np.int64)
    frame_idx = 0
    for idx, sample in enumerate(nusc.sample):
        frame_indices[idx] = frame_idx
        frame_idx = 0 if sample['next'] == '' else frame_idx + 1

    train_nusc_infos = []
    val_nusc_infos = []

    print(point_cloud_range)

    pool = multiprocessing.Pool(processes=num_workers)
    try:
        results = pool.imap(fn, zip(frame_indices, nusc.sample), chunksize=100)
        for info_list in tqdm(results, total=len(nusc.sample)):
            for info in info_list:
                if info['scene_token'] in train_scenes:
                    train_nusc_infos.append(info)
                else:
                    val_nusc_infos.append(info)
    except KeyboardInterrupt:
        # Keep the partial results; the pool is torn down in ``finally``.
        logging.warning("got Ctrl+C")
    finally:
        pool.terminate()
        pool.join()

    return train_nusc_infos, val_nusc_infos

def obtain_vectormap(nusc_maps, osm_maps, map_explorer, info, point_cloud_range, masked_elements, remove_not_relevant_keys):
    """Attach vectorized map annotations for the patch around the LiDAR to ``info``.

    The LiDAR-to-global pose is composed from ``info``'s lidar2ego and
    ego2global records. The patch size is taken from ``point_cloud_range``.
    ``info`` is mutated in place (key ``annotation``) and returned.

    Args:
        nusc_maps (dict): Map API per map location name.
        osm_maps (dict): OSM map per map location name.
        map_explorer (dict): Map explorer per map location name.
        info (dict): Sample info with pose fields and ``map_location``.
        point_cloud_range (list[float]): Indices 1/4 give the patch height
            and indices 0/3 the patch width.
        masked_elements (list[str]): Element types to mark as masked.
        remove_not_relevant_keys: Forwarded to OSM element extraction.

    Returns:
        dict: ``info`` with ``info['annotation']`` set.
    """
    lidar2ego = np.eye(4)
    lidar2ego[:3, :3] = Quaternion(info['lidar2ego_rotation']).rotation_matrix
    lidar2ego[:3, 3] = info['lidar2ego_translation']
    ego2global = np.eye(4)
    ego2global[:3, :3] = Quaternion(info['ego2global_rotation']).rotation_matrix
    ego2global[:3, 3] = info['ego2global_translation']

    lidar2global = ego2global @ lidar2ego
    lidar2global_translation = list(lidar2global[:3, 3])
    lidar2global_rotation = list(Quaternion(matrix=lidar2global).q)

    patch_h = point_cloud_range[4] - point_cloud_range[1]
    patch_w = point_cloud_range[3] - point_cloud_range[0]
    patch_size = (patch_h, patch_w)

    location = info['map_location']
    vector_map = VectorizedLocalMap(nusc_maps[location], map_explorer[location], patch_size)
    info["annotation"] = vector_map.gen_vectorized_samples(
        osm_maps[location], lidar2global_translation, lidar2global_rotation,
        masked_elements, remove_not_relevant_keys)
    return info


def generate_osm_map_info(osm_map, patch_box, patch_angle, remove_not_relevant_keys):
    """Collect OSM elements inside a map patch and express them in patch coordinates.

    Args:
        osm_map: OSM map object exposing ``get_elements_in_patch``.
        patch_box (tuple): ``(x_center, y_center, height, width)`` of the patch.
        patch_angle (float): Patch orientation in degrees.
        remove_not_relevant_keys: Forwarded to ``osm_map.get_elements_in_patch``.

    Returns:
        dict: The dict returned by ``get_elements_in_patch``, updated in place:
        - ``osm_map_nodes_pts``, when non-empty, becomes an ``(N, D)`` array
          (still 2-D when N == 1);
        - ``osm_map_ways_pts`` becomes a list of per-way coordinate arrays.
        Both are expressed in patch coordinates.
    """
    patch = NuScenesMapExplorer.get_patch_coord(patch_box, patch_angle)
    patch_x = patch_box[0]
    patch_y = patch_box[1]

    result_dict = osm_map.get_elements_in_patch(patch, remove_not_relevant_keys)

    nodes_pts = result_dict['osm_map_nodes_pts']
    if nodes_pts.size:
        local_nodes = to_patch_coord(MultiPoint(nodes_pts), patch_angle, patch_x, patch_y)
        # Take each point's single coordinate explicitly instead of squeezing,
        # which would collapse a lone node to shape (D,).
        result_dict['osm_map_nodes_pts'] = np.array([p.coords[0] for p in local_nodes.geoms])

    # Presumably each way is a shapely LineString (it is transformed with
    # shapely affinity) -- confirm against osm_parser.
    result_dict['osm_map_ways_pts'] = [
        np.array(to_patch_coord(way, patch_angle, patch_x, patch_y).coords)
        for way in result_dict['osm_map_ways_pts']
    ]

    return result_dict


class VectorizedLocalMap(object):
    def __init__(self,
                 nusc_map,
                 map_explorer,
                 patch_size,
                 map_classes=['divider_dashed','divider_solid','ped_crossing','boundary','centerline'],
                 line_classes=['road_divider', 'lane_divider'],
                 ped_crossing_classes=['ped_crossing'],
                 contour_classes=['road_segment', 'lane'],
                 centerline_classes=['lane_connector','lane'],
                 use_simplify=True,
                 ):
        """Store the map APIs and class configuration for one map location.

        Args:
            nusc_map: Map API for the location (``obtain_vectormap`` passes
                ``nusc_maps[location]``).
            map_explorer: Map explorer for the same location.
            patch_size (tuple): ``(height, width)`` of the local patch.
            map_classes (list[str]): Output vector classes, processed in order.
            line_classes (list[str]): Map layers used for dividers.
            ped_crossing_classes (list[str]): Map layers used for ped crossings.
            contour_classes (list[str]): Map layers whose polygons yield
                boundaries; stored as ``polygon_classes``.
            centerline_classes (list[str]): Map layers used for centerlines.
            use_simplify (bool): Accepted but not stored or used here.
        """
        super().__init__()
        self.nusc_map = nusc_map
        self.map_explorer = map_explorer
        # Copy the class lists so that mutating them on one instance cannot
        # leak into the shared default-argument lists of other instances.
        self.vec_classes = list(map_classes)
        self.line_classes = list(line_classes)
        self.ped_crossing_classes = list(ped_crossing_classes)
        self.polygon_classes = list(contour_classes)
        self.centerline_classes = list(centerline_classes)
        self.patch_size = patch_size


    def gen_vectorized_samples(self, osm_map, lidar2global_translation, lidar2global_rotation, masked_elements, remove_not_relevant_keys):
        '''
        use lidar2global to get gt map layers
        '''
        
        map_pose = lidar2global_translation[:2]
        rotation = Quaternion(lidar2global_rotation)
        # import ipdb;ipdb.set_trace()
        patch_box = (map_pose[0], map_pose[1], self.patch_size[0], self.patch_size[1])
        patch_angle = quaternion_yaw(rotation) / np.pi * 180

        result_dict_osm_map = generate_osm_map_info( osm_map, patch_box, patch_angle, remove_not_relevant_keys)
        
        # print(result_dict_osm_map['osm_map_ways_pts'])

        map_dict=dict(
            divider_dashed=[],
            divider_solid=[],
            ped_crossing=[],
            boundary=[],
            centerline=[],
            divider_dashed_masked=[],
            divider_solid_masked=[],
            centerline_masked=[],
            boundary_masked=[],
            osm_map_nodes_pts=[],
            osm_map_nodes_tags=[],
            osm_map_ways_pts=[],
            osm_map_ways_tags=[],
            osm_map_relations_tags=[],
            osm_map_relations_node_member_indices=[],
            osm_map_relations_way_member_indices=[],
            osm_map_relations_relation_member_indices=[],
            osm_map_relations_node_member_tags=[],
            osm_map_relations_way_member_tags=[],
            osm_map_relations_relation_member_tags=[]
        )

        map_dict['osm_map_nodes_pts'] = result_dict_osm_map['osm_map_nodes_pts']
        map_dict['osm_map_nodes_tags'] = result_dict_osm_map['osm_map_nodes_tags']
        map_dict['osm_map_ways_pts'] = result_dict_osm_map['osm_map_ways_pts']
        map_dict['osm_map_ways_tags'] = result_dict_osm_map['osm_map_ways_tags']
        map_dict['osm_map_relations_tags'] = result_dict_osm_map['osm_map_relations_tags']
        map_dict['osm_map_relations_node_member_indices'] = result_dict_osm_map['osm_map_relations_node_member_indices']
        map_dict['osm_map_relations_way_member_indices'] = result_dict_osm_map['osm_map_relations_way_member_indices']
        map_dict['osm_map_relations_relation_member_indices'] = result_dict_osm_map['osm_map_relations_relation_member_indices']
        map_dict['osm_map_relations_node_member_tags'] = result_dict_osm_map['osm_map_relations_node_member_tags']
        map_dict['osm_map_relations_way_member_tags'] = result_dict_osm_map['osm_map_relations_way_member_tags']
        map_dict['osm_map_relations_relation_member_tags'] = result_dict_osm_map['osm_map_relations_relation_member_tags']
        
        dividers_processed = False
        vectors = []
        for vec_class in self.vec_classes:
            if vec_class == 'divider_dashed' or vec_class == 'divider_solid':
                if dividers_processed:
                    continue
                line_geom = self.get_map_geom(patch_box, patch_angle, self.line_classes)
                line_instances_dict = self.line_geoms_to_instances(line_geom)     
                for line_type, instances in line_instances_dict.items():

                    # for instance, line_type in instances:
                    #     map_dict[line_type].append(np.array(instance.coords))
                    
                    try:
                        for instance, line_type in instances:
                            map_dict[line_type].append(np.array(instance.coords))
                    except:
                        import pdb;pdb.set_trace()

                map_dict['divider_dashed' + '_masked'] = np.array([False for bd in map_dict['divider_dashed']])
                map_dict['divider_solid' + '_masked'] = np.array([False for bd in map_dict['divider_solid']])

                dividers_processed = True
            elif vec_class == 'ped_crossing':
                ped_geom = self.get_map_geom(patch_box, patch_angle, self.ped_crossing_classes)
                ped_instance_list = self.ped_poly_geoms_to_instances(ped_geom)
                for instance in ped_instance_list:
                    map_dict[vec_class].append(np.array(instance.coords))

                map_dict[vec_class + '_masked'] = np.array([False for bd in map_dict[vec_class]])

            elif vec_class == 'boundary':
                polygon_geom = self.get_map_geom(patch_box, patch_angle, self.polygon_classes)
                poly_bound_list = self.poly_geoms_to_instances(polygon_geom)
                for instance in poly_bound_list:
                    # import ipdb;ipdb.set_trace()
                    map_dict[vec_class].append(np.array(instance.coords))

                map_dict[vec_class + '_masked'] = np.array([False for bd in map_dict[vec_class]])

            elif vec_class =='centerline':
                centerline_geom = self.get_centerline_geom(patch_box, patch_angle, self.centerline_classes)
                centerline_list = self.centerline_geoms_to_instances(centerline_geom)
                for instance in centerline_list:
                    map_dict[vec_class].append(np.array(instance.coords))

                map_dict[vec_class + '_masked'] = np.array([False for bd in map_dict[vec_class]])

            else:
                raise ValueError(f'WRONG vec_class: {vec_class}')
            
        # import pdb;pdb.set_trace()

        for el_type in masked_elements:
            if el_type in map_dict:
                map_dict[el_type + '_masked'] = np.array([True for bd in map_dict[el_type]])
                
        

        #====================================================================================================
        
        # import matplotlib.pyplot as plt
        # car_img = Image.open('./figs/lidar_car.png')
        # colors_plt_dict = {'divider_dashed': 'xkcd:aquamarine', 'divider_solid': 'xkcd:salmon', 'divider_mixed': 'purple', 'divider_virtual': 'dimgrey', 'ped_crossing': 'blue', 'boundary': 'red', 'centerline': 'green', 'masked': 'lightgray'}
        # osm_map_nodes_pts = [map_dict['osm_map_nodes_pts']]
        # osm_map_ways_pts = [map_dict['osm_map_ways_pts']]
        # if len(osm_map_nodes_pts[0].shape) == 1:
        #         osm_map_nodes_pts[0] = np.zeros([1,2])
        # osm_nodes_tags = map_dict['osm_map_nodes_tags']
        # osm_ways_tags = map_dict['osm_map_ways_tags']
        # osm_rels_tags = map_dict['osm_map_relations_tags']      
        # osm_map_relations_node_member_indices = [map_dict['osm_map_relations_node_member_indices']]
        # osm_map_relations_way_member_indices = [map_dict['osm_map_relations_way_member_indices']]
        # osm_map_relations_relation_member_indices = [map_dict['osm_map_relations_relation_member_indices']]
        # plt.figure(figsize=(2, 4))
        # plt.xlim(-self.patch_size[1]/2, self.patch_size[1]/2)
        # plt.ylim(-self.patch_size[0]/2, self.patch_size[0]/2)
        # plt.axis('off')
        # # gt_bboxes_3d[0].fixed_num=30 #TODO, this is a hack
        # for vec_class in self.vec_classes:

        #     for pts in map_dict[vec_class]:
        #         # import pdb;pdb.set_trace() 
        #         x = np.array([pt[0] for pt in pts])
        #         y = np.array([pt[1] for pt in pts])
        #         # plt.quiver(x[:-1], y[:-1], x[1:] - x[:-1], y[1:] - y[:-1], scale_units='xy', angles='xy', scale=1, color=colors_plt[gt_label_3d])
                
        #         if vec_class == 'centerline':
        #             plt.quiver(x[:-1], y[:-1], x[1:] - x[:-1], y[1:] - y[:-1], scale_units='xy', angles='xy', scale=1, color=colors_plt_dict[vec_class])
        #         else:
        #             plt.plot(x, y, color=colors_plt_dict[vec_class],linewidth=1,alpha=0.8,zorder=-1)
        #         # plt.scatter(x, y, color=colors_plt_dict[label_names[gt_label_3d.item()]],s=2,alpha=0.8,zorder=-1)
        #         # plt.plot(x, y, color=colors_plt[gt_label_3d])
        #         # plt.scatter(x, y, color=colors_plt[gt_label_3d],s=1)
        # plt.imshow(car_img, extent=[-1.2, 1.2, -1.5, 1.5])
        # gt_fixedpts_map_path = osp.join('/workspace/sdtagnet/gen_labels/nusc_no_prior_with_osm_map', 'GT_fixednum_pts_MAP.png')
        # plt.savefig(gt_fixedpts_map_path, bbox_inches='tight', format='png',dpi=1200)
        # plt.close() 

        # plt.figure(figsize=(2, 4))
        # plt.xlim(-self.patch_size[1]/2, self.patch_size[1]/2)
        # plt.ylim(-self.patch_size[0]/2, self.patch_size[0]/2)
        # plt.axis('off')

        # drawn_texts = []

        # for i, el in enumerate(osm_nodes_tags):
        #     if el == '':
        #         continue
        #     # if args.clip_text and len(el) > 60:
        #     #     el = el[:60] + '...'
        #     text = plt.text(osm_map_nodes_pts[0][i][0], osm_map_nodes_pts[0][i][1], s=el, wrap=True, color='black', 
        #              horizontalalignment='center', verticalalignment='center', fontsize=3,  
        #              bbox=dict(boxstyle="square", ec=(0.3, 0.3, 0.3, 0.3), fc=(0.3, 0.3, 0.3, 0.3)))
        #     text._get_wrap_line_width = lambda : 1000
        #     drawn_texts.append(text)
        # for i, el in enumerate(osm_ways_tags):
        #     if el == '':
        #         continue
        #     # if args.clip_text and len(el) > 60:
        #     #     el = el[:60] + '...'
        #     text = plt.text(osm_map_ways_pts[0][i][-1][0], osm_map_ways_pts[0][i][-1][1], el, wrap=True, 
        #              color='black', horizontalalignment='center', verticalalignment='center', fontsize=3,
        #              bbox=dict(boxstyle="square", ec=(0.3, 0.3, 0.3, 0.3), fc=(0.3, 0.3, 0.3, 0.3)))
        #     text._get_wrap_line_width = lambda : 1000
        #     drawn_texts.append(text)

        # # print(osm_rels_tags)
        # # print(osm_map_relations_node_member_indices)
        # # print(osm_map_relations_way_member_indices)
        # # print(osm_map_relations_relation_member_indices)
        # for i, el in enumerate(osm_rels_tags):
        #     if el == '':
        #         continue
        #     # if args.clip_text and len(el) > 60:
        #     #     el = el[:60] + '...'
        #     member_pts = []
        #     if len(osm_map_relations_node_member_indices[0][i]):
        #         for idx in osm_map_relations_node_member_indices[0][i]:
        #             member_pts.append(osm_map_nodes_pts[0][idx].squeeze())
        #     if len(osm_map_relations_way_member_indices[0][i]):
        #         for idx in osm_map_relations_way_member_indices[0][i]:
        #             member_pts.append(np.average(osm_map_ways_pts[0][idx].squeeze(), axis=0))
        #     if len(osm_map_relations_relation_member_indices[0][i]):
        #         for idx in osm_map_relations_relation_member_indices[0][i]:
        #             member_pts.append(np.array([0, 0]))
        #     # import pdb;pdb.set_trace()

        #     member_pts = np.vstack(member_pts)
        #     center = np.average(member_pts, axis=0)
        #     for pt in member_pts:
        #         direction = center - pt
        #         plt.arrow(pt[0], pt[1], direction[0], direction[1], color='black',linewidth=0.6, alpha=0.8, zorder=5)

        #     text = plt.text(center[0], center[1], el, wrap=True, color='black', 
        #              horizontalalignment='center', verticalalignment='center', fontsize=3,
        #              bbox=dict(boxstyle="square", ec=(0.3, 0.3, 0.3, 0.3), fc=(0.3, 0.3, 0.3, 0.3)))
        #     text._get_wrap_line_width = lambda : 1000
        #     drawn_texts.append(text)
        
        # # import pdb;pdb.set_trace()
        
        # if len(osm_map_nodes_pts[0]):
        #     plt.scatter(osm_map_nodes_pts[0][:, 0], osm_map_nodes_pts[0][:, 1],linewidth=1.5, color='blue')
        # if len(osm_map_ways_pts[0]):
        #     for i, line in enumerate(osm_map_ways_pts[0]):
        #         if 'highway' in osm_ways_tags[i]:
        #             plt.plot(line[:, 0], line[:, 1],linewidth=2,alpha=0.8, color='green')
        #         else:
        #             plt.plot(line[:, 0], line[:, 1],linewidth=2,alpha=0.8, color='red')

        
        # osm_map_path = osp.join('/workspace/sdtagnet/gen_labels/nusc_no_prior_with_osm_map', 'OSM_MAP.png')
        # # bbox_inches='tight'
        # plt.savefig(osm_map_path, format='png',dpi=1200)
        # plt.close()

        # import pdb;pdb.set_trace()
        #====================================================================================================


        
        return map_dict


    def get_centerline_geom(self, patch_box, patch_angle, layer_names):
        """Collect the centerline geometries of every requested centerline layer.

        Args:
            patch_box: forwarded unchanged to ``map_explorer._get_centerline``.
            patch_angle: forwarded unchanged to ``map_explorer._get_centerline``.
            layer_names (iterable of str): layers to query; names that are not
                in ``self.centerline_classes`` are ignored.

        Returns:
            dict: union of the dicts returned by ``_get_centerline`` (called
            with ``return_token=False``) for each centerline layer; later
            layers overwrite earlier ones on key collisions.
        """
        merged = {}
        wanted_layers = (name for name in layer_names
                         if name in self.centerline_classes)
        for name in wanted_layers:
            centerlines = self.map_explorer._get_centerline(
                patch_box, patch_angle, name, return_token=False)
            if not centerlines:
                continue
            merged.update(centerlines)
        return merged
    def get_map_geom(self, patch_box, patch_angle, layer_names):
        """Extract raw geometries for each requested map layer.

        The extractor is chosen by the first class list the layer belongs to,
        checked in this order: ``line_classes`` -> ``get_divider_line``,
        ``polygon_classes`` -> ``get_contour_line``, ``ped_crossing_classes``
        -> ``get_ped_crossing_line``. Layers in none of them are skipped.

        Returns:
            dict: layer name -> whatever the chosen extractor returned.
        """
        geoms_by_layer = {}
        for name in layer_names:
            if name in self.line_classes:
                geoms = self.get_divider_line(patch_box, patch_angle, name)
            elif name in self.polygon_classes:
                geoms = self.get_contour_line(patch_box, patch_angle, name)
            elif name in self.ped_crossing_classes:
                geoms = self.get_ped_crossing_line(patch_box, patch_angle)
            else:
                continue
            geoms_by_layer[name] = geoms
        return geoms_by_layer

    def get_divider_line(self, patch_box, patch_angle, layer_name):
        """Extract the divider polylines of one line layer inside a rotated patch.

        Geometries are intersected with the patch returned by
        ``self.map_explorer.get_patch_coord``, rotated by ``-patch_angle``
        degrees around ``(patch_box[0], patch_box[1])`` and translated so that
        this point becomes the origin.

        Args:
            patch_box: sequence whose first two entries are the patch origin
                (x, y) in map coordinates (presumably its centre); the whole
                box is also passed to ``get_patch_coord``.
            patch_angle (float): patch rotation in degrees.
            layer_name (str): one of the map API's non-geometric line layers.

        Returns:
            list of (geometry, label) tuples, or None for 'traffic_light'.
            'road_divider' lines are labelled 'divider_solid'. 'lane_divider'
            lines are split wherever the segment type changes; each piece is
            labelled via NUSC_LANEMARKTYPE_TO_LABEL. 'NIL' pieces are dropped
            silently, and pieces of unknown type are dropped with a warning.

        Raises:
            ValueError: if ``layer_name`` is not a line layer.
        """
        map_api = self.map_explorer.map_api
        if layer_name not in map_api.non_geometric_line_layers:
            raise ValueError("{} is not a line layer".format(layer_name))

        if layer_name == 'traffic_light':
            return None

        patch_x = patch_box[0]
        patch_y = patch_box[1]
        patch = self.map_explorer.get_patch_coord(patch_box, patch_angle)

        line_list = []

        def _clip_to_local(line):
            # Clip to the patch and express the result in the patch frame;
            # returns None when no part of the line lies inside the patch.
            new_line = line.intersection(patch)
            if new_line.is_empty:
                return None
            new_line = affinity.rotate(new_line, -patch_angle,
                                       origin=(patch_x, patch_y), use_radians=False)
            return affinity.affine_transform(new_line,
                                             [1.0, 0.0, 0.0, 1.0, -patch_x, -patch_y])

        def _append_lane_segment(node_tokens, segment_type):
            # Build the polyline through ``node_tokens`` and store it with the
            # label mapped from ``segment_type`` if it reaches into the patch.
            coords = []
            for token in node_tokens:
                node = map_api.get('node', token)
                coords.append((node['x'], node['y']))
            new_line = _clip_to_local(LineString(coords))
            if new_line is None:
                return
            if segment_type in NUSC_LANEMARKTYPE_TO_LABEL:
                line_list.append((new_line, NUSC_LANEMARKTYPE_TO_LABEL[segment_type]))
            elif segment_type != 'NIL':
                print("WARNING: UNKNOWN LINE TYPE FOUND: " + segment_type)

        records = getattr(map_api, layer_name)

        if layer_name == 'road_divider':
            for record in records:
                line = map_api.extract_line(record['line_token'])
                if line.is_empty:  # Skip lines without nodes.
                    continue
                new_line = _clip_to_local(line)
                if new_line is not None:
                    line_list.append((new_line, 'divider_solid'))

        elif layer_name == 'lane_divider':
            for record in records:
                complete_line = map_api.extract_line(record['line_token'])
                if complete_line.is_empty:  # Skip lines without nodes.
                    continue

                node_types = {seg['node_token']: seg['segment_type']
                              for seg in record['lane_divider_segments']}
                node_tokens = record['node_tokens']

                # The first segment takes the type of the first node that has
                # one; nodes without a type extend the current segment.
                last_type = next((node_types[token] for token in node_tokens
                                  if token in node_types), None)
                if last_type is None:
                    # No segment type information at all: nothing to label.
                    continue

                line_node_tokens = [node_tokens[0]]
                for node in node_tokens[1:]:
                    line_node_tokens.append(node)
                    if node not in node_types:
                        continue
                    if node_types[node] != last_type:
                        # Type changes here: emit the finished segment (it may
                        # be dropped) and ALWAYS start a new one at this node.
                        _append_lane_segment(line_node_tokens, last_type)
                        line_node_tokens = [node]
                        last_type = node_types[node]

                if len(line_node_tokens) > 1:
                    _append_lane_segment(line_node_tokens, last_type)

        return line_list

    def get_contour_line(self, patch_box, patch_angle, layer_name):
        """Extract the polygons of one polygon layer inside a rotated patch.

        Args:
            patch_box: patch box; its first two entries are the (x, y) rotation
                origin / translation offset, and the whole box is passed to
                ``self.map_explorer.get_patch_coord``.
            patch_angle (float): patch rotation in degrees.
            layer_name (str): one of the map API's non-geometric polygon layers.

        Returns:
            list of MultiPolygon in patch-local coordinates, one per source
            polygon with a non-empty areal intersection with the patch.

        Raises:
            ValueError: if ``layer_name`` is not a polygon layer.
        """
        map_api = self.map_explorer.map_api
        if layer_name not in map_api.non_geometric_polygon_layers:
            raise ValueError('{} is not a polygonal layer'.format(layer_name))

        records = getattr(map_api, layer_name)
        if layer_name == 'drivable_area':
            # drivable_area records reference several polygons each; these
            # are used without a validity check.
            polygons = (map_api.extract_polygon(token)
                        for record in records
                        for token in record['polygon_tokens'])
        else:
            polygons = (polygon
                        for polygon in (map_api.extract_polygon(record['polygon_token'])
                                        for record in records)
                        if polygon.is_valid)
        return self._clip_polygons_to_patch(polygons, patch_box, patch_angle)


    def get_ped_crossing_line(self, patch_box, patch_angle):
        """Extract the valid pedestrian-crossing polygons inside a rotated patch.

        Args:
            patch_box: patch box, interpreted as in ``get_contour_line``.
            patch_angle (float): patch rotation in degrees.

        Returns:
            list of MultiPolygon in patch-local coordinates.
        """
        map_api = self.map_explorer.map_api
        polygons = (polygon
                    for polygon in (map_api.extract_polygon(record['polygon_token'])
                                    for record in getattr(map_api, 'ped_crossing'))
                    if polygon.is_valid)
        return self._clip_polygons_to_patch(polygons, patch_box, patch_angle)

    def _clip_polygons_to_patch(self, polygons, patch_box, patch_angle):
        """Clip polygons to the patch and express them in the patch frame.

        Each polygon is intersected with the patch from
        ``self.map_explorer.get_patch_coord``, rotated by ``-patch_angle``
        degrees around ``(patch_box[0], patch_box[1])`` and translated so that
        this point becomes the origin.

        Returns:
            list of MultiPolygon. Intersections without area (a point or line
            where a polygon only touches the patch) are dropped, and only the
            polygonal members of a GeometryCollection are kept. Downstream
            ``unary_union`` / ``MultiPolygon`` calls therefore see polygons only.
        """
        def _as_multipolygon(geom):
            # Areal part of ``geom`` as a MultiPolygon, or None if it has none.
            if geom.geom_type == 'MultiPolygon':
                return geom
            if geom.geom_type == 'Polygon':
                return MultiPolygon([geom])
            if geom.geom_type != 'GeometryCollection':
                return None
            parts = []
            for member in geom.geoms:
                if member.is_empty:
                    continue
                if member.geom_type == 'Polygon':
                    parts.append(member)
                elif member.geom_type == 'MultiPolygon':
                    parts.extend(member.geoms)
            return MultiPolygon(parts) if parts else None

        patch_x = patch_box[0]
        patch_y = patch_box[1]
        patch = self.map_explorer.get_patch_coord(patch_box, patch_angle)

        polygon_list = []
        for polygon in polygons:
            new_polygon = polygon.intersection(patch)
            if new_polygon.is_empty:
                continue
            new_polygon = affinity.rotate(new_polygon, -patch_angle,
                                          origin=(patch_x, patch_y), use_radians=False)
            new_polygon = affinity.affine_transform(new_polygon,
                                                    [1.0, 0.0, 0.0, 1.0, -patch_x, -patch_y])
            multipolygon = _as_multipolygon(new_polygon)
            if multipolygon is not None:
                polygon_list.append(multipolygon)
        return polygon_list

    def line_geoms_to_instances(self, line_geom):
        """Flatten the divider geometries of every line type into instances.

        Args:
            line_geom (dict): line type -> list of (geometry, label) tuples.

        Returns:
            dict: same keys, each mapped to the output of
            ``_one_type_divider_line_geom_to_instances`` for its list.
        """
        return {
            line_type: self._one_type_divider_line_geom_to_instances(lines)
            for line_type, lines in line_geom.items()
        }

    def _one_type_divider_line_geom_to_instances(self, line_geom):
        line_instances = []

        for line, line_type in line_geom:
            if not line.is_empty:
                if line.geom_type == 'MultiLineString':
                    for single_line in line.geoms:
                        line_instances.append((single_line, line_type))
                elif line.geom_type == 'LineString':
                    line_instances.append((line, line_type))
                else:
                    raise NotImplementedError
        return line_instances

    def _one_type_line_geom_to_instances(self, line_geom):
        line_instances = []
        
        for line in line_geom:
            if not line.is_empty:
                if line.geom_type == 'MultiLineString':
                    for single_line in line.geoms:
                        line_instances.append(single_line)
                elif line.geom_type == 'LineString':
                    line_instances.append(line)
                else:
                    raise NotImplementedError
        return line_instances

    def ped_poly_geoms_to_instances(self, ped_geom):
        """Turn pedestrian-crossing polygons into oriented contour polylines.

        Args:
            ped_geom (dict): must map 'ped_crossing' to an iterable of
                (Multi)Polygon geometries in patch-local coordinates.

        Returns:
            list of LineString (see ``_polygon_contours``), clipped to a box
            that is 0.2 larger than the patch on every side.
        """
        union_segments = ops.unary_union(ped_geom['ped_crossing'])
        max_x = self.patch_size[1] / 2
        max_y = self.patch_size[0] / 2
        # NOTE: the clip box is enlarged here (poly_geoms_to_instances shrinks
        # it instead), presumably so crossing edges on the patch border are kept.
        local_patch = box(-max_x - 0.2, -max_y - 0.2, max_x + 0.2, max_y + 0.2)
        return self._one_type_line_geom_to_instances(
            self._polygon_contours(union_segments, local_patch))


    def poly_geoms_to_instances(self, polygon_geom):
        """Turn road-segment and lane polygons into oriented boundary polylines.

        Args:
            polygon_geom (dict): must map 'road_segment' and 'lane' to
                iterables of (Multi)Polygon geometries in patch-local
                coordinates; both are merged into one union first.

        Returns:
            list of LineString (see ``_polygon_contours``), clipped to a box
            that is 0.2 smaller than the patch on every side.
        """
        union_roads = ops.unary_union(polygon_geom['road_segment'])
        union_lanes = ops.unary_union(polygon_geom['lane'])
        union_segments = ops.unary_union([union_roads, union_lanes])
        max_x = self.patch_size[1] / 2
        max_y = self.patch_size[0] / 2
        local_patch = box(-max_x + 0.2, -max_y + 0.2, max_x - 0.2, max_y - 0.2)
        return self._one_type_line_geom_to_instances(
            self._polygon_contours(union_segments, local_patch))

    def _polygon_contours(self, union_segments, local_patch):
        """Clip the oriented rings of the polygons in ``union_segments``.

        Exterior rings are oriented clockwise and interior rings (holes)
        counter-clockwise via ``orient(..., sign=-1.0)``; this builds new
        geometries instead of assigning to ``ring.coords``. Each ring is
        intersected with ``local_patch`` and MultiLineString results are
        line-merged.

        Args:
            union_segments: result of ``ops.unary_union``; may be empty, a
                Polygon, a MultiPolygon or a GeometryCollection (only its
                polygonal members are used).
            local_patch: clip polygon in patch-local coordinates.

        Returns:
            list of clipped geometries: all exteriors first, then all interiors.
        """
        from shapely.geometry.polygon import orient

        if union_segments.is_empty:
            return []
        if union_segments.geom_type == 'Polygon':
            polygons = [union_segments]
        else:
            polygons = []
            for geom in getattr(union_segments, 'geoms', ()):
                if geom.geom_type == 'Polygon':
                    polygons.append(geom)
                elif geom.geom_type == 'MultiPolygon':
                    polygons.extend(geom.geoms)
        polygons = [orient(poly, sign=-1.0) for poly in polygons if not poly.is_empty]

        rings = [poly.exterior for poly in polygons]
        rings.extend(inter for poly in polygons for inter in poly.interiors)

        results = []
        for ring in rings:
            lines = ring.intersection(local_patch)
            if isinstance(lines, MultiLineString):
                lines = ops.linemerge(lines)
            results.append(lines)
        return results

    def centerline_geoms_to_instances(self, geoms_dict):
        """Merge lane centerlines into paths and flatten them into LineStrings.

        Args:
            geoms_dict (dict): per-lane centerline records, passed unchanged
                to ``union_centerline``.

        Returns:
            list: ``_one_type_line_geom_to_instances`` applied to the merged
            paths (the graph also returned by ``union_centerline`` is discarded).
        """
        merged_paths, _graph = self.union_centerline(geoms_dict)
        return self._one_type_line_geom_to_instances(merged_paths)


    def centerline_geoms2vec(self, centerline_geoms_list):
        """Vectorise merged centerlines.

        Args:
            centerline_geoms_list (list): centerline geometries passed to
                ``self._geom_to_vectors``.

        Returns:
            dict: ``{'centerline': ('centerline', vectors)}``.
        """
        vectors = self._geom_to_vectors(centerline_geoms_list)
        return {'centerline': ('centerline', vectors)}

    def union_centerline(self, centerline_geoms):
        """Merge connected lane centerlines into root-to-leaf polylines.

        Each centerline vertex, rounded to 3 decimals, becomes a node of a
        directed graph, and consecutive vertices are linked. A lane's start is
        linked from the end of each predecessor, and its end to the start of
        each successor, provided that neighbour is also present in
        ``centerline_geoms``. Every simple path from a node without incoming
        edges to any node without outgoing edges is returned as a LineString.

        Args:
            centerline_geoms (dict): token -> dict with keys 'centerline'
                (LineString or MultiLineString), 'incoming_tokens' and
                'outgoing_tokens' (iterables of tokens). Empty centerlines
                are ignored.

        Returns:
            tuple: (list of LineString simplified with tolerance 0.2 and
            ``preserve_topology=True``, the ``nx.DiGraph`` that was built).

        Raises:
            NotImplementedError: if a non-empty centerline is neither a
                LineString nor a MultiLineString.
        """
        def _rounded_parts(geom):
            # Vertex arrays (rounded to 3 decimals) of each line part, in order.
            if geom.geom_type == 'MultiLineString':
                return [np.array(part.coords).round(3) for part in geom.geoms]
            if geom.geom_type == 'LineString':
                return [np.array(geom.coords).round(3)]
            raise NotImplementedError

        pts_G = nx.DiGraph()
        # Edge insertion order is kept (own segments, then predecessor links,
        # then successor links) because it determines the order of the paths.
        for value in centerline_geoms.values():
            centerline_geom = value['centerline']
            if centerline_geom.is_empty:
                continue
            parts = _rounded_parts(centerline_geom)
            start_pt = tuple(parts[0][0])
            end_pt = tuple(parts[-1][-1])
            for part in parts:
                for src, dst in zip(part[:-1], part[1:]):
                    pts_G.add_edge(tuple(src), tuple(dst))

            for pred in value['incoming_tokens']:
                if pred not in centerline_geoms:
                    continue
                pred_geom = centerline_geoms[pred]['centerline']
                if not pred_geom.is_empty:
                    pred_pt = tuple(_rounded_parts(pred_geom)[-1][-1])
                    pts_G.add_edge(pred_pt, start_pt)

            for succ in value['outgoing_tokens']:
                if succ not in centerline_geoms:
                    continue
                succ_geom = centerline_geoms[succ]['centerline']
                if not succ_geom.is_empty:
                    succ_pt = tuple(_rounded_parts(succ_geom)[0][0])
                    pts_G.add_edge(end_pt, succ_pt)

        roots = [node for node, degree in pts_G.in_degree() if degree == 0]
        leaves = [node for node, degree in pts_G.out_degree() if degree == 0]

        final_centerline_paths = []
        for root in roots:
            for path in nx.all_simple_paths(pts_G, root, leaves):
                merged_line = LineString(path).simplify(0.2, preserve_topology=True)
                final_centerline_paths.append(merged_line)
        return final_centerline_paths, pts_G




def create_nuscenes_infos(root_path,
                          osm_root_path,
                          out_path,
                          pc_range,
                          can_bus_root_path,
                          info_prefix,
                          masked_elements,
                          remove_not_relevant_keys,
                          version='v1.0-trainval',
                          max_sweeps=10):
    """Create the temporal map info files of the nuScenes dataset.

    Loads nuScenes, its CAN bus expansion, the four nuScenes maps and one OSM
    map per nuScenes map. It then generates the info dicts via
    ``_fill_trainval_infos`` and dumps them as pkl files into ``out_path``:
    ``<prefix>_map_infos_temporal_test.pkl`` for a test version, otherwise
    ``..._train.pkl`` and ``..._val.pkl``.

    Args:
        root_path (str): Path of the nuScenes data root (also the map dataroot).
        osm_root_path (str): Directory containing ``<map_name>.osm`` for each map.
        out_path (str): Output directory of the pkl files; created if missing.
        pc_range (list[float]): Forwarded to ``_fill_trainval_infos`` as
            ``point_cloud_range``.
        can_bus_root_path (str): Data root of the nuScenes CAN bus expansion.
        info_prefix (str): Prefix of the info files to be generated.
        masked_elements: Forwarded unchanged to ``_fill_trainval_infos``.
        remove_not_relevant_keys (bool): Forwarded unchanged to
            ``_fill_trainval_infos``.
        version (str): Version of the data, one of 'v1.0-trainval',
            'v1.0-test', 'v1.0-mini'. Default: 'v1.0-trainval'
        max_sweeps (int): Max number of sweeps. Default: 10

    Raises:
        ValueError: if ``version`` is not one of the supported versions.
    """
    from nuscenes.nuscenes import NuScenes
    from nuscenes.can_bus.can_bus_api import NuScenesCanBus
    from nuscenes.utils import splits

    # Validate the version before the expensive dataset/map loading below.
    # An explicit raise (not ``assert``) keeps the check under ``python -O``.
    version_splits = {
        'v1.0-trainval': (splits.train, splits.val),
        'v1.0-test': (splits.test, []),
        'v1.0-mini': (splits.mini_train, splits.mini_val),
    }
    if version not in version_splits:
        raise ValueError('unknown nuScenes version {!r}; expected one of {}'.format(
            version, sorted(version_splits)))
    train_scene_names, val_scene_names = version_splits[version]

    print(version, root_path)
    nusc = NuScenes(version=version, dataroot=root_path, verbose=True)
    nusc_can_bus = NuScenesCanBus(dataroot=can_bus_root_path)
    MAPS = ['boston-seaport', 'singapore-hollandvillage',
            'singapore-onenorth', 'singapore-queenstown']
    nusc_maps = {}
    map_explorer = {}
    osm_maps = {}

    for loc in MAPS:
        nusc_maps[loc] = NuScenesMap(dataroot=root_path, map_name=loc)
        map_explorer[loc] = CNuScenesMapExplorer(nusc_maps[loc])

        osm_file_path = osp.join(osm_root_path, loc + '.osm')
        print(osm_file_path)

        with open(osm_file_path, 'r') as osm_file:
            osm_map = osm_parser.parse(osm_file)
            print("OSM map parsed!")

        osm_map.build_node_way_lists(loc, nusc_mode=True)
        osm_maps[loc] = osm_map

        print("loaded OSM map " + osm_file_path)

    # Keep only scenes available on disk and map scene names to scene tokens
    # (the first scene with a given name wins).
    available_scenes = get_available_scenes(nusc)
    name_to_token = {}
    for scene in available_scenes:
        name_to_token.setdefault(scene['name'], scene['token'])
    train_scenes = {name_to_token[name] for name in train_scene_names
                    if name in name_to_token}
    val_scenes = {name_to_token[name] for name in val_scene_names
                  if name in name_to_token}

    test = 'test' in version
    if test:
        print('test scene: {}'.format(len(train_scenes)))
    else:
        print('train scene: {}, val scene: {}'.format(
            len(train_scenes), len(val_scenes)))

    train_nusc_infos, val_nusc_infos = _fill_trainval_infos(
        nusc, nusc_can_bus, nusc_maps, osm_maps, map_explorer, train_scenes, val_scenes,
        masked_elements, remove_not_relevant_keys, test,
        max_sweeps=max_sweeps, point_cloud_range=pc_range)

    if out_path:
        os.makedirs(out_path, exist_ok=True)

    metadata = dict(version=version)
    if test:
        print('test sample: {}'.format(len(train_nusc_infos)))
        data = dict(infos=train_nusc_infos, metadata=metadata)
        info_path = osp.join(out_path,
                             '{}_map_infos_temporal_test.pkl'.format(info_prefix))
        mmcv.dump(data, info_path)
    else:
        print('train sample: {}, val sample: {}'.format(
            len(train_nusc_infos), len(val_nusc_infos)))
        data = dict(infos=train_nusc_infos, metadata=metadata)
        info_path = osp.join(out_path,
                             '{}_map_infos_temporal_train.pkl'.format(info_prefix))
        mmcv.dump(data, info_path)
        data['infos'] = val_nusc_infos
        info_val_path = osp.join(out_path,
                                 '{}_map_infos_temporal_val.pkl'.format(info_prefix))
        mmcv.dump(data, info_val_path)



def nuscenes_data_prep(root_path,
                       osm_root_path,
                       can_bus_root_path,
                       info_prefix,
                       version,
                       dataset_name,
                       out_dir,
                       pc_range,
                       masked_elements,
                       remove_not_relevant_keys,
                       max_sweeps=10):
    """Prepare the nuScenes map info files for one dataset version.

    Thin wrapper that forwards its arguments to ``create_nuscenes_infos``.

    Args:
        root_path (str): Path of the dataset root.
        osm_root_path (str): Directory holding the OSM map files.
        can_bus_root_path (str): Root path of the nuScenes CAN bus data.
        info_prefix (str): The prefix of info filenames.
        version (str): Dataset version, e.g. 'v1.0-trainval'.
        dataset_name (str): The dataset class name; not used here.
        out_dir (str): Output directory of the info files.
        pc_range (list[float]): Perception point cloud range.
        masked_elements: Map elements to mask; forwarded unchanged.
        remove_not_relevant_keys (bool): Forwarded unchanged.
        max_sweeps (int): Number of input consecutive frames. Default: 10
    """
    create_nuscenes_infos(
        root_path=root_path,
        osm_root_path=osm_root_path,
        out_path=out_dir,
        pc_range=pc_range,
        can_bus_root_path=can_bus_root_path,
        info_prefix=info_prefix,
        masked_elements=masked_elements,
        remove_not_relevant_keys=remove_not_relevant_keys,
        version=version,
        max_sweeps=max_sweeps,
    )



def _str2bool(value):
    """Convert a command-line string such as 'true' or 'False' into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(value)``, which is True for
    every non-empty string, including 'False', so an option parsed that way
    can never be switched off from the command line.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Data converter arg parser')
parser.add_argument(
    '--root-path',
    type=str,
    default='./data/kitti',
    help='specify the root path of dataset')
parser.add_argument(
    '--canbus',
    type=str,
    default='./data',
    help='specify the root path of nuScenes canbus')
parser.add_argument(
    '--version',
    type=str,
    default='v1.0',
    required=False,
    help='specify the dataset version, no need for kitti')
parser.add_argument(
    '--max-sweeps',
    type=int,
    default=10,
    required=False,
    help='specify sweeps of lidar per example')
parser.add_argument(
    '--out-dir',
    type=str,
    default='./data/kitti',
    # A real bool: the string 'False' is truthy and made this option mandatory.
    required=False,
    help='output directory of the info pkl files')
parser.add_argument(
    '--pc-range',
    type=float,
    nargs='+',
    default=[-15.0, -30.0, -5.0, 15.0, 30.0, 3.0],
    help='specify the perception point cloud range')
parser.add_argument(
    '--remove-not-relevant-keys',
    type=_str2bool,
    # Accepts "--remove-not-relevant-keys", "... true" and "... false".
    nargs='?',
    const=True,
    default=False,
    required=False,
    help='If true, not relevant keys will be removed from osm tags')
parser.add_argument(
    '--osm-map-root',
    type=str,
    # No usable default exists: the OSM files are opened unconditionally.
    required=True,
    help='specify the root path of osm maps')
parser.add_argument('--extra-tag', type=str, default='nuscenes')
parser.add_argument(
    '--workers', type=int, default=4, help='number of threads to be used')
parser.add_argument(
    '--masked-elements',
    nargs='+',
    type=str,
    default=None,
    required=False,
    help="""
         Elements that should be masked from the map information given to the network. 
         If none are given, the network will not receive any map information during training.
         Possible Options:
         ego_lane: masks out all labels associated with the ego lane
         ego_road: masks out all labels associated with the ego road
         random: randomly selects a masking type for a sample
         random_whole_dataset: duplicates each sample for each available masking type (e.g. 8 masking types = 
         8x the dataset annotations stacked, once for each masking type)
         <list of element types>: masks out all elements with the specified type, e.g. divider_solid, divider_dashed, centerline etc.
         """)
args = parser.parse_args()


if __name__ == '__main__':
    # Generate the info files for the trainval split first, then for the test
    # split; apart from the dataset version both runs use the same settings.
    for version_suffix in ('trainval', 'test'):
        nuscenes_data_prep(
            root_path=args.root_path,
            osm_root_path=args.osm_map_root,
            can_bus_root_path=args.canbus,
            info_prefix=args.extra_tag,
            version=f'{args.version}-{version_suffix}',
            dataset_name='NuScenesDataset',
            out_dir=args.out_dir,
            pc_range=args.pc_range,
            masked_elements=args.masked_elements,
            remove_not_relevant_keys=args.remove_not_relevant_keys,
            max_sweeps=args.max_sweeps,
        )