# -*- coding: utf-8 -*-

"""
Base dataset class for all kinds of fusion.
"""

import bisect
import math
import os
import random
from collections import OrderedDict

import torch
import numpy as np
from torch.utils.data import Dataset

import opencood.utils.pcd_utils as pcd_utils
from opencood.utils.camera_utils import load_rgb_from_files
from opencood.data_utils.augmentor.data_augmentor import DataAugmentor
from opencood.hypes_yaml.yaml_utils import load_pkl
from opencood.utils.pcd_utils import downsample_lidar_minimum
from opencood.utils.transformation_utils import x1_to_x2
from opencood.data_utils import SUPER_CLASS_MAP
import opencood.data_utils.datasets

class BaseDataset(Dataset):
    """
    Base dataset for all kinds of fusion. Mainly used to initialize the
    database and associate the ``__getitem__`` index with the correct
    timestamp and scenario.

    Parameters
    ----------
    params : dict
        The dictionary contains all parameters for training/testing.

    visualize : bool
        If set to true, the raw point cloud will be saved in the memory
        for visualization.

    train : bool
        Selects ``params['root_dir']`` (True) or ``params['validate_dir']``
        (False) as the data root; also forwarded to the data augmentor.

    Attributes
    ----------
    scenario_database : OrderedDict
        A structured dictionary with all file information, laid out as
        {scenario_index: {timestamp: {cav_id: {...}}}}.

    len_record : list
        Cumulative number of selected timestamps per scenario. This is used
        to map a flat dataset index to (scenario, timestamp).

    pre_processor : opencood.pre_processor
        Used to preprocess the raw data (None here; presumably assigned by
        subclasses).

    post_processor : opencood.post_processor
        Used to generate training labels and convert the model outputs to
        bbx formats (None here; presumably assigned by subclasses).

    data_augmentor : opencood.data_augmentor
        Used to augment data.
    """

    def __init__(self, params, visualize, train=True):
        self.params = params
        self.visualize = visualize
        self.train = train

        self.dataset = params['dataset']
        assert self.dataset in ['V2XReal', 'V2XSim']
        if self.dataset == 'V2XReal':
            self.rsu_num = 2
            self.max_camera_num = 4
        elif self.dataset == 'V2XSim':
            self.rsu_num = 1
            self.max_camera_num = 6

        # 'ic': RSU agents only; 'v2i': RSU agents followed by vehicles.
        self.dataset_mode = params['dataset_mode']
        assert self.dataset_mode in ['ic', 'v2i']

        self.pre_processor = None
        self.post_processor = None
        self.data_augmentor = DataAugmentor(params['data_augment'], train)
        self.class_names = SUPER_CLASS_MAP.keys()
        self.build_inverse_super_class_map()
        self.build_class_name2int_map()
        self.modality = params['modality']

        # Rates used in reinitialize() to build the ego/agent availability
        # masks. Each one independently defaults to 0 when absent.
        self.ego_disconnected_rate = params.get('ego_disconnected_rate', 0)
        self.agent_disconnected_rate = params.get('agent_disconnected_rate', 0)

        try:
            self.img_format = params['img_format']
        except KeyError as err:
            # Camera file names are built from this suffix in reinitialize()
            # and retrieve_base_data(); fail here with a clear message
            # instead of an AttributeError later on.
            raise KeyError(
                'img_format args should be included, npy or jpg') from err

        # if the training/testing include noisy setting
        if 'wild_setting' in params:
            wild_setting = params['wild_setting']
            self.seed = wild_setting['seed']
            # whether to add time delay
            self.async_flag = wild_setting['async']
            self.async_mode = wild_setting.get('async_mode', 'sim')
            self.async_overhead = wild_setting['async_overhead']

            # localization error
            self.loc_err_flag = wild_setting['loc_err']
            self.xyz_noise_std = wild_setting['xyz_std']
            self.ryp_noise_std = wild_setting['ryp_std']

            # transmission data size
            self.data_size = wild_setting.get('data_size', 0)
            self.transmission_speed = wild_setting.get(
                'transmission_speed', 27)
            self.backbone_delay = wild_setting.get('backbone_delay', 0)
        else:
            self.async_flag = False
            self.async_overhead = 0  # ms
            self.async_mode = 'sim'
            self.loc_err_flag = False
            self.xyz_noise_std = 0
            self.ryp_noise_std = 0
            self.data_size = 0  # Mb (Megabits)
            self.transmission_speed = 27  # Mbps
            self.backbone_delay = 0  # ms

        root_dir = params['root_dir'] if self.train else params['validate_dir']

        train_params = params.get('train_params') or {}
        self.max_cav = train_params.get('max_cav', 4)

        # first load all paths of different scenarios
        self.scenario_folders = sorted(
            os.path.join(root_dir, x) for x in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, x)))
        # NOTE: dataset_mode is asserted to be 'ic' or 'v2i' above, so this
        # filter currently applies to every non-training split.
        if not self.train and self.dataset_mode != "v2v":
            self.scenario_folders = [
                folder for folder in self.scenario_folders
                if "2023-04-07" not in os.path.basename(folder)]
            print(self.scenario_folders)

        # Structure: {scenario_index : {timestamp : {cav_id : {pkl: path,
        # lidar: path, cameras: list of path, ...}}}}
        self.reinitialize()

    def build_inverse_super_class_map(self):
        """
        Build ``self.INVERSE_SUPER_CLASS_MAP``, which maps every class name
        listed in ``SUPER_CLASS_MAP`` to its super class name.
        """
        self.INVERSE_SUPER_CLASS_MAP = {}
        for super_class_name, member_names in SUPER_CLASS_MAP.items():
            for class_name in member_names:
                self.INVERSE_SUPER_CLASS_MAP[class_name] = super_class_name

    def build_class_name2int_map(self):
        """
        Build ``self.class_name2int``, assigning consecutive integer ids
        starting at 1 to ``self.class_names`` in iteration order.
        """
        self.class_name2int = {
            class_name: i
            for i, class_name in enumerate(self.class_names, start=1)}

    def _shuffle_agents(self, cav_list):
        """
        Order the agents of one scenario for the current dataset mode.

        The first ``self.rsu_num`` entries of the sorted ``cav_list`` (RSUs)
        come first, in random order. In 'v2i' mode the remaining agents
        (vehicles) follow, also shuffled; in 'ic' mode they are dropped.

        Returns
        -------
        names : list of str
            Agent folder names in their new order; ``names[0]`` is the ego.
        indices : list of int
            For each entry in ``names``, its position in the original sorted
            ``cav_list`` (the order the distance masks follow).
        """
        indexed = list(enumerate(cav_list))
        ordered = indexed[:self.rsu_num]
        random.shuffle(ordered)
        if self.dataset_mode != 'ic':
            vehicles = indexed[self.rsu_num:]
            random.shuffle(vehicles)
            ordered += vehicles
        indices = [index for index, _ in ordered]
        names = [name for _, name in ordered]
        return names, indices

    def _select_timestamps(self, scenario_folder, ego_id, ego_index):
        """
        Choose which timestamps of one scenario are used.

        Candidate timestamps are the ``.pkl`` files in the ego folder whose
        name does not contain 'additional'. In 'ic' mode all of them are
        kept and ``masks`` is None. In 'v2i' mode a timestamp is kept only if
        more than two entries of the ego's row in ``distance.npy`` are below
        ``COM_RANGE``. The boolean masks for the kept timestamps are
        returned alongside; they are indexed in the original sorted agent
        order.
        """
        ego_path = os.path.join(scenario_folder, ego_id)
        pkl_names = sorted(x for x in os.listdir(ego_path)
                           if x.endswith('.pkl') and 'additional' not in x)
        timestamps = [x.replace('.pkl', '') for x in pkl_names]

        if self.dataset_mode == 'ic':
            return timestamps, None

        distance_list = np.load(os.path.join(scenario_folder, 'distance.npy'))
        com_range = opencood.data_utils.datasets.COM_RANGE
        selected_timestamps, masks = [], []
        for timestamp in timestamps:
            # assumes row int(timestamp) of distance.npy holds the pairwise
            # agent distances of that frame -- TODO confirm with data tools
            mask = distance_list[int(timestamp)][ego_index] < com_range
            # RSU 1, RSU 2, and at least a CAV.
            # NOTE(review): threshold is fixed at 2 regardless of
            # self.rsu_num (which is 1 for V2XSim) -- confirm intent.
            if mask.sum() > 2:
                selected_timestamps.append(timestamp)
                masks.append(mask)
        return selected_timestamps, masks

    @staticmethod
    def _random_availability_mask(total, disconnected_rate):
        """
        Return a shuffled boolean array of length ``total`` in which
        ``round(total * disconnected_rate)`` entries are False.
        """
        num_falses = round(total * disconnected_rate)
        num_trues = total - num_falses
        mask = np.array([True] * num_trues + [False] * num_falses, dtype=bool)
        np.random.shuffle(mask)
        return mask

    def reinitialize(self):
        """
        (Re)build ``scenario_database``, ``len_record`` and the ego/agent
        availability masks from ``self.scenario_folders``.

        The agent order, and therefore which RSU is the ego, is drawn
        randomly on every call.
        """
        self.scenario_database = OrderedDict()
        self.len_record = []

        # loop over all scenarios
        for i, scenario_folder in enumerate(self.scenario_folders):
            cav_list = sorted(
                [x for x in os.listdir(scenario_folder)
                 if os.path.isdir(os.path.join(scenario_folder, x))],
                key=int)
            cav_list, cav_list_indices = self._shuffle_agents(cav_list)

            self.scenario_database[i] = OrderedDict()

            # at least 1 cav should show up
            print(cav_list)
            assert len(cav_list) > 0

            selected_timestamps, masks = self._select_timestamps(
                scenario_folder, cav_list[0], cav_list_indices[0])

            for t, timestamp in enumerate(selected_timestamps):
                timestamp_db = OrderedDict()
                self.scenario_database[i][timestamp] = timestamp_db
                mask = None if masks is None else masks[t]
                ego_cav_content = None
                cav_num = 0
                for j, cav_id in enumerate(cav_list):
                    if cav_num > self.max_cav - 1:
                        break
                    # masks follow the original (sorted) agent order
                    if mask is not None and not mask[cav_list_indices[j]]:
                        continue

                    cav_content = OrderedDict()
                    timestamp_db[cav_id] = cav_content
                    # the first agent of the shuffled list is the ego
                    cav_content['ego'] = j == 0
                    if j == 0:
                        ego_cav_content = cav_content

                    cav_path = os.path.join(scenario_folder, cav_id)
                    if self.dataset == 'V2XSim':
                        cav_content['seg_label'] = os.path.join(
                            cav_path, timestamp + '_bev.npy')
                    cav_content['pkl'] = os.path.join(
                        cav_path, timestamp + '.pkl')
                    cav_content['lidar'] = os.path.join(
                        cav_path, timestamp + '.bin')
                    cav_content['cameras'] = self.load_camera_files(
                        cav_path, timestamp, self.img_format,
                        self.max_camera_num)
                    # load the corresponding data into the dictionary
                    cav_content['camera_params'] = \
                        self.reform_camera_param(cav_content, ego_cav_content)
                    cav_content['params'] = \
                        self.reform_param(cav_content, ego_cav_content)

                    cav_num += 1

            previous_total = self.len_record[-1] if self.len_record else 0
            self.len_record.append(previous_total + len(selected_timestamps))

        total = self.len_record[-1] if self.len_record else 0
        # Random assign infrastructure disconnected based on rate
        self.ego_avaliable_mask = self._random_availability_mask(
            total, self.ego_disconnected_rate)
        # Random assign vehicle disconnected based on rate
        self.agent_avaliable_mask = self._random_availability_mask(
            total, self.agent_disconnected_rate)

    def __len__(self):
        """Total number of selected timestamps over all scenarios."""
        return self.len_record[-1] if self.len_record else 0

    def __getitem__(self, idx):
        """
        Abstract method, needs to be defined by the children class.
        """
        raise NotImplementedError

    def retrieve_base_data(self, idx):
        """
        Given the index, return the corresponding data.

        Parameters
        ----------
        idx : int
            Index given by dataloader.

        Returns
        -------
        data : dict
            The dictionary contains loaded pkl params and lidar data for
            each cav.

        Raises
        ------
        IndexError
            If ``idx`` is not smaller than the dataset length.
        """
        # len_record is cumulative, so the scenario is the first entry
        # strictly greater than idx.
        scenario_index = bisect.bisect_right(self.len_record, idx)
        if scenario_index >= len(self.len_record):
            raise IndexError('index %s out of range for dataset of length %s'
                             % (idx, len(self)))
        scenario_database = self.scenario_database[scenario_index]

        # check the timestamp index
        timestamp_index = idx if scenario_index == 0 else \
            idx - self.len_record[scenario_index - 1]
        # retrieve the corresponding timestamp key
        timestamp_key = self.return_timestamp_key(scenario_database,
                                                  timestamp_index)

        data = OrderedDict()
        # load files for all CAVs
        for cav_id, cav_content in scenario_database[timestamp_key].items():
            cav_data = OrderedDict()
            data[cav_id] = cav_data
            cav_data['ego'] = cav_content['ego']

            # parameters precomputed in reinitialize()
            cav_data['camera_params'] = cav_content['camera_params']
            cav_data['params'] = cav_content['params']

            # hard coded, no intensity
            cav_data['lidar_np'] = pcd_utils.load_lidar_bin(
                cav_content['lidar'], self.dataset, zero_intensity=True)
            cav_data['camera_np'] = load_rgb_from_files(
                cav_content['cameras'], self.img_format)

            if self.dataset == 'V2XSim':
                cav_data['seg_label_np'] = np.load(cav_content['seg_label'])

            # lidar path is <scenario_folder>/<cav_id>/<timestamp>.bin, so
            # two dirname() calls reach the scenario folder.
            cav_data['folder_name'] = os.path.basename(
                os.path.dirname(os.path.dirname(cav_content['lidar'])))
            cav_data['index'] = timestamp_index
            cav_data['cav_id'] = int(cav_id)
            cav_data['timestamp_key'] = timestamp_key
        return data

    @staticmethod
    def return_timestamp_key(scenario_database, timestamp_index):
        """
        Given the timestamp index, return the correct timestamp key, e.g.
        2 --> '000078'.

        Parameters
        ----------
        scenario_database : OrderedDict
            The dictionary contains all contents in the current scenario.

        timestamp_index : int
            The index for timestamp.

        Returns
        -------
        timestamp_key : str
            The timestamp key saved in the cav dictionary.
        """
        return list(scenario_database)[timestamp_index]

    def calc_dist_to_ego(self, scenario_database, timestamp_key):
        """
        Calculate the planar (x, y) distance to ego for each cav and store it
        in ``cav_content['distance_to_ego']``.

        NOTE(review): this expects a {cav_id: {'ego': bool, timestamp:
        {'pkl': path}}} layout, which differs from the
        {timestamp: {cav_id: ...}} layout built by reinitialize().

        Returns
        -------
        ego_cav_content : dict
            The content of the agent flagged as ego.
        """
        ego_lidar_pose = None
        ego_cav_content = None
        # Find ego pose first
        for cav_content in scenario_database.values():
            if cav_content['ego']:
                ego_cav_content = cav_content
                ego_lidar_pose = \
                    load_pkl(cav_content[timestamp_key]['pkl'])['lidar_pose']
                break

        assert ego_lidar_pose is not None

        # calculate the distance (cav_content is mutated in place)
        for cav_content in scenario_database.values():
            cur_lidar_pose = \
                load_pkl(cav_content[timestamp_key]['pkl'])['lidar_pose']
            cav_content['distance_to_ego'] = math.hypot(
                cur_lidar_pose[0] - ego_lidar_pose[0],
                cur_lidar_pose[1] - ego_lidar_pose[1])

        return ego_cav_content

    def time_delay_calculation(self, ego_flag):
        """
        Calculate the time delay for a certain vehicle.

        Parameters
        ----------
        ego_flag : boolean
            Whether the current cav is ego.

        Return
        ------
        time_delay : int
            The time delay in units of 100 ms frames (0 when ``async_flag``
            is off or for the ego).

        Raises
        ------
        ValueError
            If ``self.async_mode`` is neither 'real' nor 'sim'.
        """
        # there is no time delay for ego vehicle
        if ego_flag:
            return 0
        if self.async_mode == 'real':
            # in the real mode, time delay = systematic async time + data
            # transmission time + backbone computation time
            overhead_noise = np.random.uniform(0, self.async_overhead)
            tc = self.data_size / self.transmission_speed * 1000
            time_delay = int(overhead_noise + tc + self.backbone_delay)
        elif self.async_mode == 'sim':
            # in the simulation mode, the time delay is constant
            time_delay = np.abs(self.async_overhead)
        else:
            raise ValueError('Unknown async_mode: %s' % self.async_mode)

        # the data is assumed to be 10 hz (one frame per 100 ms)
        time_delay = time_delay // 100
        return time_delay if self.async_flag else 0

    def add_loc_noise(self, pose, xyz_std, ryp_std):
        """
        Add localization noise to the pose.

        Gaussian noise is added to x, y, z and to ``pose[4]`` (yaw); roll and
        pitch are left unchanged. During evaluation the global NumPy RNG is
        reseeded with ``self.seed`` first.

        Parameters
        ----------
        pose : list
            x,y,z,roll,yaw,pitch

        xyz_std : float
            std of the gaussian noise on xyz

        ryp_std : float
            std of the gaussian noise on the angles

        Returns
        -------
        noise_pose : list
            The perturbed pose.
        """
        if not self.train:
            np.random.seed(self.seed)
        xyz_noise = np.random.normal(0, xyz_std, 3)
        ryp_noise = np.random.normal(0, ryp_std, 3)
        noise_pose = [pose[0] + xyz_noise[0],
                      pose[1] + xyz_noise[1],
                      pose[2] + xyz_noise[2],
                      pose[3],
                      pose[4] + ryp_noise[1],
                      pose[5]]
        return noise_pose

    def reform_param(self, cav_content, ego_content):
        """
        Load a cav's pkl params and attach the cav-to-ego transformation
        matrices and the class-filtered ground-truth objects.

        Parameters
        ----------
        cav_content : dict
            Dictionary that contains all file paths in the current cav/rsu.

        ego_content : dict
            Ego vehicle content.

        Return
        ------
        The merged parameters.
        """
        cur_params = load_pkl(cav_content['pkl'])
        cur_ego_params = load_pkl(ego_content['pkl'])

        # transformation matrix from cav to ego at the current timestamp
        cur_ego_lidar_pose = cur_ego_params['lidar_pose']
        cur_cav_lidar_pose = cur_params['lidar_pose']

        # TODO: localization noise (self.loc_err_flag / add_loc_noise) is
        # not applied here.
        transformation_matrix = x1_to_x2(cur_cav_lidar_pose,
                                         cur_ego_lidar_pose)
        spatial_correction_matrix = np.eye(4)

        # This is only used for late fusion, as it did the transformation
        # in the postprocess, so we want the gt object transformation use
        # the correct one. Computed separately so the two entries do not
        # share one object.
        gt_transformation_matrix = x1_to_x2(cur_cav_lidar_pose,
                                            cur_ego_lidar_pose)
        # Map each category class name to a super class name. This relies on
        # the in-place rewrite of obj['obj_type']; the returned dict is
        # discarded.
        # NOTE(review): unmapped objects reach filter_boxes_by_class and are
        # kept if their lower-cased raw type is a super class name -- confirm.
        self.map_class_name_to_super_class_name(cur_params['vehicles'])

        # we always use current timestamp's gt bbx to gain a fair evaluation
        cur_params['vehicles'] = self.filter_boxes_by_class(
            cur_params['vehicles'])
        cur_params['transformation_matrix'] = transformation_matrix
        cur_params['gt_transformation_matrix'] = gt_transformation_matrix
        cur_params['spatial_correction_matrix'] = spatial_correction_matrix

        return cur_params

    def reform_camera_param(self, cav_content, ego_content):
        """
        Load camera extrinsic and intrinsic into a proper format. todo:
        Enable delay and localization error.

        Cameras ``camera0`` .. ``camera{max_camera_num-1}`` that are missing
        or incomplete in the cav's pkl are skipped.

        Returns
        -------
        The camera params dictionary.
        """
        camera_params = OrderedDict()

        cav_params = load_pkl(cav_content['pkl'])
        ego_params = load_pkl(ego_content['pkl'])
        ego_lidar_pose = ego_params['lidar_pose']
        # TODO: 'true_ego_pos' is a typo inherited from the OPV2V repository;
        # fall back to the correctly spelled key.
        if 'true_ego_pos' in ego_params:
            ego_pose = ego_params['true_ego_pos']
        else:
            ego_pose = ego_params['true_ego_pose']

        # load each camera's world coordinates, extrinsic (lidar to camera)
        # pose and intrinsics.
        for i in range(self.max_camera_num):
            camera_key = 'camera%d' % i
            try:
                camera_info = cav_params[camera_key]
                camera_coords = camera_info['cords']
                camera_extrinsic = np.array(camera_info['extrinsic'])
                camera_intrinsic = np.array(camera_info['intrinsic'])
            except (KeyError, TypeError):
                # camera absent or incomplete for this agent
                continue

            camera_params[camera_key] = {
                'camera_coords': camera_coords,
                'camera_extrinsic': camera_extrinsic,
                'camera_intrinsic': camera_intrinsic,
                'camera_extrinsic_to_ego_lidar':
                    x1_to_x2(camera_coords, ego_lidar_pose),
                'camera_extrinsic_to_ego':
                    x1_to_x2(camera_coords, ego_pose)}
        return camera_params

    def filter_boxes_by_class(self, object_dict):
        """
        Keep only objects whose lower-cased ``obj_type`` is in
        ``self.class_names``, replacing ``obj_type`` (in place) with its
        integer id wrapped in a shape-(1,) np.array.

        Returns
        -------
        filtered_object_dict : OrderedDict
        """
        filtered_object_dict = OrderedDict()
        for obj_id, obj in object_dict.items():
            class_name = obj['obj_type'].lower()
            if class_name not in self.class_names:
                continue
            obj['obj_type'] = np.array([self.class_name2int[class_name]])
            filtered_object_dict[obj_id] = obj
        return filtered_object_dict

    def map_class_name_to_super_class_name(self, object_dict):
        """
        Rewrite (in place) each object's ``obj_type`` to its super class
        name and return a new dict holding only the objects that could be
        mapped. Objects with unknown types are left untouched in
        ``object_dict``.
        """
        new_object_dict = OrderedDict()
        for obj_id, obj in object_dict.items():
            super_class_name = self.INVERSE_SUPER_CLASS_MAP.get(
                obj['obj_type'])
            if super_class_name is None:
                continue
            obj['obj_type'] = super_class_name
            new_object_dict[obj_id] = obj
        return new_object_dict

    @staticmethod
    def load_camera_files(cav_path, timestamp, img_format, max_camera_num):
        """
        Retrieve the paths to all camera files.

        Parameters
        ----------
        cav_path : str
            The full file path of current cav.

        timestamp : str
            Current timestamp

        img_format : str
            Image file extension, e.g. 'jpg' or 'npy'.

        max_camera_num : int
            Cameras ``0 .. max_camera_num - 1`` are probed.

        Returns
        -------
        camera_files : list
            Paths of the existing ``<timestamp>_camera<i>.<img_format>``
            files, in camera-index order.
        """
        candidates = (
            os.path.join(cav_path,
                         timestamp + '_camera%d.%s' % (i, img_format))
            for i in range(max_camera_num))
        return [cam_file for cam_file in candidates
                if os.path.isfile(cam_file)]

    def project_points_to_bev_map(self, points, ratio=0.1):
        """
        Project points to BEV occupancy map with default ratio=0.1.

        Parameters
        ----------
        points : np.ndarray
            (N, 3) / (N, 4)

        ratio : float
            Discretization parameters. Default is 0.1.

        Returns
        -------
        bev_map : np.ndarray
            BEV occupancy map including projected points
            with shape (img_row, img_col).

        """
        return self.pre_processor.project_points_to_bev_map(points, ratio)

    def augment(self, lidar_np, object_bbx_center, object_bbx_mask,
                flip=None, rotation=None, scale=None):
        """
        Given the raw point cloud, augment by flipping and rotation.

        Parameters
        ----------
        lidar_np : np.ndarray
            (n, 4) shape

        object_bbx_center : np.ndarray
            (n, 8) shape to represent bbx's x, y, z, h, w, l, yaw, class.
            Only the first 7 columns are augmented; it is updated in place.

        object_bbx_mask : np.ndarray
            Indicate which elements in object_bbx_center are padded.

        Returns
        -------
        lidar_np, object_bbx_center, object_bbx_mask
            The augmented inputs.
        """
        tmp_dict = {'lidar_np': lidar_np,
                    'object_bbx_center': object_bbx_center[:, :7],
                    'object_bbx_mask': object_bbx_mask,
                    'flip': flip,
                    'noise_rotation': rotation,
                    'noise_scale': scale}
        tmp_dict = self.data_augmentor.forward(tmp_dict)

        lidar_np = tmp_dict['lidar_np']
        object_bbx_center[:, :7] = tmp_dict['object_bbx_center']
        object_bbx_mask = tmp_dict['object_bbx_mask']

        return lidar_np, object_bbx_center, object_bbx_mask

    def collate_batch_train(self, batch):
        """
        Customized collate function for pytorch dataloader during training
        for early and late fusion dataset.

        Parameters
        ----------
        batch : list
            Samples, each a dict with an 'ego' entry.

        Returns
        -------
        batch : dict
            Reformatted batch.
        """
        # during training, we only care about ego.
        output_dict = {'ego': {}}

        object_bbx_center = []
        object_bbx_mask = []
        processed_lidar_list = []
        label_dict_list = []
        origin_lidar = []

        for sample in batch:
            ego_dict = sample['ego']
            object_bbx_center.append(ego_dict['object_bbx_center'])
            object_bbx_mask.append(ego_dict['object_bbx_mask'])
            processed_lidar_list.append(ego_dict['processed_lidar'])
            label_dict_list.append(ego_dict['label_dict'])

            if self.visualize:
                origin_lidar.append(ego_dict['origin_lidar'])

        # convert to numpy, (B, max_num, 7)
        object_bbx_center = torch.from_numpy(np.array(object_bbx_center))
        object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask))

        processed_lidar_torch_dict = \
            self.pre_processor.collate_batch(processed_lidar_list)
        label_torch_dict = \
            self.post_processor.collate_batch(label_dict_list)
        output_dict['ego'].update({'object_bbx_center': object_bbx_center,
                                   'object_bbx_mask': object_bbx_mask,
                                   'processed_lidar': processed_lidar_torch_dict,
                                   'label_dict': label_torch_dict})
        if self.visualize:
            origin_lidar = \
                np.array(downsample_lidar_minimum(pcd_np_list=origin_lidar))
            origin_lidar = torch.from_numpy(origin_lidar)
            output_dict['ego'].update({'origin_lidar': origin_lidar})

        return output_dict

    def visualize_result(self, pred_box_tensor,
                         gt_box_tensor,
                         origin_lidar,
                         map_lidar,
                         show_vis,
                         save_path,
                         dataset=None):
        """
        Visualize the model output via ``self.post_processor.visualize``.

        NOTE: ``map_lidar`` is accepted for interface compatibility but is
        not used.
        """
        self.post_processor.visualize(pred_box_tensor,
                                      gt_box_tensor,
                                      origin_lidar,
                                      show_vis,
                                      save_path,
                                      dataset=dataset)
