# -*- coding: utf-8 -*-

"""
Dataset class for intermediate fusion
"""
from collections import OrderedDict

import numpy as np
import torch

import opencood.data_utils.post_processor as post_processor
from opencood.utils import box_utils
from opencood.data_utils.datasets import basedataset
from opencood.data_utils.pre_processor import build_preprocessor
from opencood.utils.pcd_utils import \
    mask_points_by_range, mask_ego_points, shuffle_points
from opencood.data_utils.augmentor.data_augmentor import DataAugmentor


class IntermediateFusionDatasetV2XReal(basedataset.BaseDataset):
    """
    This class is for intermediate fusion where each vehicle transmit the
    deep features to ego.
    """

    def __init__(self, params, visualize, train=True):
        """
        Set up the fusion flags, the pre/post processors and the augmentor.

        Parameters
        ----------
        params : dict
            Configuration dictionary. The 'fusion', 'preprocess',
            'postprocess' and 'data_augment' entries are read here.
        visualize : bool
            Forwarded to the base dataset constructor.
        train : bool
            Forwarded to the base dataset, both processors and the
            augmentor.
        """
        super().__init__(params, visualize, train)

        fusion_args = params['fusion']['args']

        # If projecting first, each cav's lidar is moved into the ego
        # coordinate frame up front; otherwise the feature is projected
        # instead. Only an explicit falsy 'proj_first' disables it.
        explicitly_disabled = \
            'proj_first' in fusion_args and not fusion_args['proj_first']
        self.proj_first = not explicitly_disabled

        # Whether there is a time delay between the moment a cav projects
        # its lidar to the ego and the moment the ego receives the
        # delivered feature. Defaults to True when not configured.
        if 'cur_ego_pose_flag' in fusion_args:
            self.cur_ego_pose_flag = fusion_args['cur_ego_pose_flag']
        else:
            self.cur_ego_pose_flag = True

        self.pre_processor = build_preprocessor(params['preprocess'], train)
        self.post_processor = post_processor.build_postprocessor(
            params['postprocess'], self.class_names, train)

        # augmentation related
        self.augment_config = params['data_augment']
        self.data_augmentor = DataAugmentor(params['data_augment'],
                                            train,
                                            intermediate=True)

    def generate_augment(self):
        flip = [None, None, None]
        noise_rotation = None
        noise_scale = None

        for aug_ele in self.augment_config:
            # for intermediate fusion only
            if 'random_world_rotation' in aug_ele['NAME']:
                rot_range = \
                    aug_ele['WORLD_ROT_ANGLE']
                if not isinstance(rot_range, list):
                    rot_range = [-rot_range, rot_range]
                noise_rotation = np.random.uniform(rot_range[0],
                                                   rot_range[1])

            if 'random_world_flip' in aug_ele['NAME']:
                for i, cur_axis in enumerate(aug_ele['ALONG_AXIS_LIST']):
                    enable = np.random.choice([False, True], replace=False,
                                              p=[0.5, 0.5])
                    flip[i] = enable

            if 'random_world_scaling' in aug_ele['NAME']:
                scale_range = \
                    aug_ele['WORLD_SCALE_RANGE']
                noise_scale = \
                    np.random.uniform(scale_range[0], scale_range[1])

        return flip, noise_rotation, noise_scale

    def __getitem__(self, idx):
        base_data_dict = self.retrieve_base_data(idx)

        processed_data_dict = OrderedDict()
        processed_data_dict['ego'] = {}

        # augmentation related
        flip, noise_rotation, noise_scale = self.generate_augment()
        # print(noise_rotation)

        ego_lidar_pose = []
        for cav_id, cav_content in base_data_dict.items():
            if cav_content['ego']:
                ego_id = cav_id
                ego_lidar_pose = cav_content['params']['lidar_pose']
                break

        processed_features = [] # processed lidar data with dictornary format
        object_stack = []
        object_id_stack = []

        # The number of sensors are differ for CAV and RSU
        # Final shape: (L, M, H, W, 3)
        cav_camera_data = []
        rsu_camera_data = []
        # (L, M, 3, 3)
        cav_camera_intrinsic = []
        rsu_camera_intrinsic = []
        # (L, M, 4, 4)
        cav_camera2ego = []
        rsu_camera2ego = []

        # (max_cav, 4, 4)
        transformation_matrix = []

        if self.visualize:
            projected_lidar_stack = []

        # loop over all CAVs to process information
        cav_num = 0
        data_loc = ''
        for cav_id, selected_cav_base in base_data_dict.items():

            # augment related
            selected_cav_base['flip'] = flip
            selected_cav_base['noise_rotation'] = noise_rotation
            selected_cav_base['noise_scale'] = noise_scale
            data_loc = selected_cav_base['folder_name'] + '/' + ego_id + '/' + str(selected_cav_base['timestamp_key'])

            selected_cav_processed, void_lidar = self.get_item_single_car(selected_cav_base, ego_lidar_pose)

            if void_lidar:
                continue

            # Get vehicles' sensor data
            if cav_id == '1' or cav_id == '2':
                if self.modality['cav_camera']:
                    cav_camera_data.append(selected_cav_processed['camera']['data'])
                    cav_camera_intrinsic.append(selected_cav_processed['camera']['intrinsic'])
                    cav_camera2ego.append(selected_cav_processed['camera']['extrinsic'])
                if self.modality['cav_lidar']:
                    processed_features.append(
                        selected_cav_processed['processed_features'])

            # Get infrastructure's sensor data
            if cav_id == '-1' or cav_id == '-2':
                if self.modality['rsu_camera']:
                    rsu_camera_data.append(selected_cav_processed['camera']['data'])
                    rsu_camera_intrinsic.append(selected_cav_processed['camera']['intrinsic'])
                    rsu_camera2ego.append(selected_cav_processed['camera']['extrinsic'])
                if self.modality['rsu_lidar']:
                    processed_features.append(
                        selected_cav_processed['processed_features'])
                    if (self.ego_avaliable_mask[idx] == False) and  (cav_id==ego_id): #TODO: We set lidar feature to 0 to represent the ego disconnected.
                        selected_cav_processed['processed_features']['bev_input'] = (
                            np.ones_like(selected_cav_processed['processed_features']['bev_input'])*1e-8)

            # Add transformation matrices
            transformation_matrix.append(selected_cav_base['params']['transformation_matrix'])

            # Add object truth labels
            object_stack.append(selected_cav_processed['object_bbx_center'])
            object_id_stack += selected_cav_processed['object_ids']

            cav_num = cav_num+1

            if self.visualize:
                # Crop the projected lidar to the ego agent's GT_RANGE
                from opencood.data_utils.datasets import GT_RANGE
                projected_lidar_vis = mask_points_by_range(
                    selected_cav_processed['projected_lidar_original'],
                    GT_RANGE)
                projected_lidar_stack.append(
                    projected_lidar_vis)

        # stack all agents together
        if cav_camera_data:
            cav_camera_data = np.stack(cav_camera_data) # CAV 4 camera
            cav_camera_intrinsic = np.stack(cav_camera_intrinsic)
            cav_camera2ego = np.stack(cav_camera2ego)

        if rsu_camera_data:
            rsu_camera_data = np.stack(rsu_camera_data) # RSU 2 camera
            rsu_camera_intrinsic = np.stack(rsu_camera_intrinsic)
            rsu_camera2ego = np.stack(rsu_camera2ego)

        # exclude all repetitive objects
        unique_indices = \
            [object_id_stack.index(x) for x in set(object_id_stack)]
        object_stack = np.vstack(object_stack)
        object_stack = object_stack[unique_indices]

        # make sure bounding boxes across all frames have the same number
        object_bbx_center = \
            np.zeros((self.params['postprocess']['max_num'], 8))
        mask = np.zeros(self.params['postprocess']['max_num'])
        object_bbx_center[:object_stack.shape[0], :] = object_stack
        mask[:object_stack.shape[0]] = 1

        # merge preprocessed features from different cavs into the same dict
        merged_feature_dict = self.merge_features_to_dict(processed_features)

        # generate targets label
        label_dict = \
            self.post_processor.generate_label(
                gt_box_center=object_bbx_center,
                mask=mask)

        transformation_matrix = np.stack(transformation_matrix)
        padding_eye = np.tile(np.eye(4)[None], (self.max_cav - len(
                                               transformation_matrix), 1, 1))
        transformation_matrix = np.concatenate(
            [transformation_matrix, padding_eye], axis=0)

        processed_data_dict['ego'].update(
            {'object_bbx_center': object_bbx_center,
             'object_bbx_mask': mask,
             'object_ids': [object_id_stack[i] for i in unique_indices],
             'all_anchors': None,
             'processed_lidar': merged_feature_dict,
             'label_dict': label_dict,
             'cav_num': cav_num,
             'cav_camera_data': cav_camera_data,
             'cav_camera_intrinsic': cav_camera_intrinsic,
             'cav_camera_extrinsic': cav_camera2ego,
             'rsu_camera_data': rsu_camera_data,
             'rsu_camera_intrinsic': rsu_camera_intrinsic,
             'rsu_camera_extrinsic': rsu_camera2ego,
             'transformation_matrix': transformation_matrix,
             'data_loc': data_loc,
             'ego_avaliable_mask': self.ego_avaliable_mask[idx],
             'agent_avaliable_mask': self.agent_avaliable_mask[idx]})

        if self.visualize:
            processed_data_dict['ego'].update(
                {'origin_lidar': projected_lidar_stack})
        return processed_data_dict

    def padding_rsu_cam(self, cam):
        padding_cam = np.zeros(cam.shape)
        return np.concatenate([cam, padding_cam], axis=0)

    def get_item_single_car(self, selected_cav_base, ego_pose):
        """
        Process one agent: optionally project its lidar to ego space,
        augment, clip to the configured range and preprocess lidar/cameras.

        Parameters
        ----------
        selected_cav_base : dict
            The dictionary contains a single CAV's raw information. It must
            also carry the 'flip', 'noise_rotation' and 'noise_scale'
            augmentation entries.
        ego_pose : list
            The ego vehicle lidar pose under world coordinate.

        Returns
        -------
        selected_cav_processed : dict
            The dictionary contains the cav's processed information, with
            keys 'object_bbx_center', 'object_ids', 'projected_lidar',
            'processed_features', 'projected_lidar_original' and 'camera'.
        void_lidar : bool
            True when no lidar point is left after range filtering.

        Raises
        ------
        ValueError
            Raised by ``np.stack`` when the agent provides no camera image
            (every entry of 'camera_np' is None, or there are none).
        """
        selected_cav_processed = {}

        # transformation matrix of this agent, read from its parameters
        transformation_matrix = \
            selected_cav_base['params']['transformation_matrix']

        # retrieve objects under ego coordinates
        object_bbx_center, object_bbx_mask, object_ids = \
            self.post_processor.generate_object_center([selected_cav_base],
                                                       ego_pose)

        # filter lidar
        lidar_np = selected_cav_base['lidar_np']
        lidar_np = shuffle_points(lidar_np)
        # remove points that hit itself
        lidar_np = mask_ego_points(lidar_np)

        # project the lidar to ego space
        if self.proj_first:
            lidar_np[:, :3] = \
                box_utils.project_points_by_matrix_torch(
                    lidar_np[:, :3], transformation_matrix)
        # Bound on both paths: it used to be assigned only when proj_first
        # was True, so proj_first=False hit an unbound variable below.
        # Without projection the visualization lidar stays in the agent's
        # own frame and still needs changes for the broadcasting approach.
        projected_lidar_original = lidar_np

        # data augmentation
        lidar_np, object_bbx_center, object_bbx_mask = \
            self.augment(lidar_np, object_bbx_center, object_bbx_mask,
                         selected_cav_base['flip'],
                         selected_cav_base['noise_rotation'],
                         selected_cav_base['noise_scale'])
        lidar_range = self.params['preprocess']['cav_lidar_range']
        lidar_np = mask_points_by_range(lidar_np, lidar_range)
        # Check whether the filtered LiDAR points are void
        void_lidar = lidar_np.shape[0] < 1

        # filter out the augmented bbx that is out of range
        object_bbx_center_valid = object_bbx_center[object_bbx_mask == 1]
        object_bbx_center_valid, range_mask = \
            box_utils.mask_boxes_outside_range_numpy(
                object_bbx_center_valid,
                lidar_range,
                self.params['postprocess']['order'],
                return_mask=True)
        object_ids = [int(x) for x in list(np.array(object_ids)[range_mask])]
        processed_lidar = self.pre_processor.preprocess_lidar(lidar_np)

        # collect every available camera with its calibration
        all_camera_data = []
        all_camera_origin = []
        all_camera_intrinsic = []
        all_camera_extrinsic = []
        for camera_id, camera_data in selected_cav_base['camera_np'].items():
            if camera_data is None:
                continue
            all_camera_origin.append(camera_data)
            camera_params = selected_cav_base['params'][camera_id]
            all_camera_data.append(
                self.pre_processor.preprocess_rgb(camera_data,
                                                  self.img_format))
            all_camera_intrinsic.append(camera_params['intrinsic'])
            all_camera_extrinsic.append(camera_params['extrinsic'])

        camera_dict = {
            'origin_data': np.stack(all_camera_origin),
            'data': np.stack(all_camera_data),
            'intrinsic': np.stack(all_camera_intrinsic),
            'extrinsic': np.stack(all_camera_extrinsic)
        }

        selected_cav_processed.update(
            {'object_bbx_center': object_bbx_center_valid,
             'object_ids': object_ids,
             'projected_lidar': lidar_np,
             'processed_features': processed_lidar,
             'projected_lidar_original': projected_lidar_original,
             'camera': camera_dict})

        return selected_cav_processed, void_lidar

    @staticmethod
    def merge_features_to_dict(processed_feature_list):
        """
        Merge the preprocessed features from different cavs to the same
        dictionary.

        Parameters
        ----------
        processed_feature_list : list
            A list of dictionary containing all processed features from
            different cavs.

        Returns
        -------
        merged_feature_dict: dict
            key: feature names, value: list of features.
        """

        merged_feature_dict = OrderedDict()

        for i in range(len(processed_feature_list)):
            for feature_name, feature in processed_feature_list[i].items():
                if feature_name not in merged_feature_dict:
                    merged_feature_dict[feature_name] = []
                if isinstance(feature, list):
                    merged_feature_dict[feature_name] += feature
                else:
                    merged_feature_dict[feature_name].append(feature)

        return merged_feature_dict

    def collate_batch_train(self, batch):
        # Intermediate fusion is different the other two
        output_dict = {'ego': {}}

        object_bbx_center = []
        object_bbx_mask = []
        object_ids = []
        processed_lidar_list = []

        cav_cam_rgb_all_batch = []
        cav_cam_to_ego_all_batch = []
        cav_cam_intrinsic_all_batch = []
        rsu_cam_rgb_all_batch = []
        rsu_cam_to_ego_all_batch = []
        rsu_cam_intrinsic_all_batch = []

        # used to record different scenario
        record_len = []
        label_dict_list = []

        # get transformation matrices
        transformation_matrix_all_batch = []

        agent_avaliable_mask_list = []

        data_loc_batch = []

        if self.visualize:
            origin_lidar = []

        for i in range(len(batch)):
            ego_dict = batch[i]['ego']

            # Get camera data
            cav_camera_data = ego_dict['cav_camera_data']
            cav_camera_intrinsic = ego_dict['cav_camera_intrinsic']
            cav_camera_extrinsic = ego_dict['cav_camera_extrinsic']
            rsu_camera_data = ego_dict['rsu_camera_data']
            rsu_camera_intrinsic = ego_dict['rsu_camera_intrinsic']
            rsu_camera_extrinsic = ego_dict['rsu_camera_extrinsic']

            # Combine camera data
            cav_cam_rgb_all_batch.append(cav_camera_data)
            cav_cam_intrinsic_all_batch.append(cav_camera_intrinsic)
            cav_cam_to_ego_all_batch.append(cav_camera_extrinsic)
            rsu_cam_rgb_all_batch.append(rsu_camera_data)
            rsu_cam_intrinsic_all_batch.append(rsu_camera_intrinsic)
            rsu_cam_to_ego_all_batch.append(rsu_camera_extrinsic)

            # Get object bbx
            object_bbx_center.append(ego_dict['object_bbx_center'])
            object_bbx_mask.append(ego_dict['object_bbx_mask'])
            object_ids.append(ego_dict['object_ids'])

            processed_lidar_list.append(ego_dict['processed_lidar'])
            record_len.append(ego_dict['cav_num'])
            label_dict_list.append(ego_dict['label_dict'])

            transformation_matrix_all_batch.append(
                ego_dict['transformation_matrix'])

            data_loc_batch.append(ego_dict['data_loc'])
            agent_avaliable_mask_list.append(ego_dict['agent_avaliable_mask'])

            if self.visualize:
                origin_lidar.append(ego_dict['origin_lidar'])
        # convert to numpy, (B, max_num, 7)
        object_bbx_center = torch.from_numpy(np.array(object_bbx_center))
        object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask))

        # example: {'voxel_features':[np.array([1,2,3]]),
        # np.array([3,5,6]), ...]}
        merged_feature_dict = self.merge_features_to_dict(processed_lidar_list)
        processed_lidar_torch_dict = \
            self.pre_processor.collate_batch(merged_feature_dict)
        # [2, 3, 4, ..., M], M <= max_cav
        record_len = torch.from_numpy(np.array(record_len, dtype=int))
        label_torch_dict = \
            self.post_processor.collate_batch(label_dict_list)

        # (B*L, M, H, W, C)
        # L: number of agents; M: number of camera per vehicle
        cav_cam_rgb_all_batch = torch.from_numpy(
            np.concatenate(cav_cam_rgb_all_batch, axis=0)).unsqueeze(1).float()
        cav_cam_intrinsic_all_batch = torch.from_numpy(
            np.concatenate(cav_cam_intrinsic_all_batch, axis=0)).unsqueeze(1).float()
        cav_cam_to_ego_all_batch = torch.from_numpy(
            np.concatenate(cav_cam_to_ego_all_batch, axis=0)).unsqueeze(1).float()
        rsu_cam_rgb_all_batch = torch.from_numpy(
            np.concatenate(rsu_cam_rgb_all_batch, axis=0)).unsqueeze(1).float()
        rsu_cam_intrinsic_all_batch = torch.from_numpy(
            np.concatenate(rsu_cam_intrinsic_all_batch, axis=0)).unsqueeze(1).float()
        rsu_cam_to_ego_all_batch = torch.from_numpy(
            np.concatenate(rsu_cam_to_ego_all_batch, axis=0)).unsqueeze(1).float()

        # (B, max_cav)
        transformation_matrix_all_batch = \
            torch.from_numpy(np.stack(transformation_matrix_all_batch)).float()

        agent_avaliable_masks = torch.from_numpy(np.array(agent_avaliable_mask_list))
        # object id is only used during inference, where batch size is 1.
        # so here we only get the first element.
        output_dict['ego'].update({'object_bbx_center': object_bbx_center,
                                   'object_bbx_mask': object_bbx_mask,
                                   'processed_lidar': processed_lidar_torch_dict,
                                   'record_len': record_len,
                                   'label_dict': label_torch_dict,
                                   'object_ids': object_ids[0],
                                   'transformation_matrix': transformation_matrix_all_batch,
                                   'cav_camera': cav_cam_rgb_all_batch,
                                   'cav_intrinsic': cav_cam_intrinsic_all_batch,
                                   'cav_extrinsic': cav_cam_to_ego_all_batch,
                                   'rsu_camera': rsu_cam_rgb_all_batch,
                                   'rsu_intrinsic': rsu_cam_intrinsic_all_batch,
                                   'rsu_extrinsic': rsu_cam_to_ego_all_batch,
                                   'data_locs': data_loc_batch,
                                   'agent_avaliable_masks': agent_avaliable_masks})

        if self.visualize:
            origin_lidar = [torch.from_numpy(lidar) for lidar in origin_lidar[0]]
            output_dict['ego'].update({'origin_lidar': origin_lidar})

        return output_dict

    def collate_batch_test(self, batch):
        assert len(batch) <= 1, "Batch size 1 is required during testing!"
        output_dict = self.collate_batch_train(batch)

        # save the transformation matrix (4, 4) to ego vehicle
        # transformation_matrix_batch = \
        #     torch.from_numpy(np.stack(batch[0]['ego']['transformation_matrix'])).float()
        transformation_matrix_torch = \
            torch.from_numpy(np.identity(4)).float()
        output_dict['ego'].update({'transformation_matrix_test':
                                       transformation_matrix_torch})
                                   # 'transformation_matrix':
                                   #     transformation_matrix_batch})


        return output_dict

    def post_process(self, data_dict, output_dict):
        """
        Process the outputs of the model to 2D/3D bounding box.

        Parameters
        ----------
        data_dict : dict
            The dictionary containing the origin input data of model.

        output_dict :dict
            The dictionary containing the output of the model.

        Returns
        -------
        pred_box_tensor : torch.Tensor
            The tensor of prediction bounding box after NMS.
        gt_box_tensor : torch.Tensor
            The tensor of gt bounding box.
        """
        pred_box_tensor, pred_score = \
            self.post_processor.post_process(data_dict, output_dict) #post_process_real
        gt_box_tensor, gt_label_tensor = self.post_processor.generate_gt_bbx(data_dict)

        return pred_box_tensor, pred_score, gt_box_tensor, gt_label_tensor
