# -*- coding: utf-8 -*-
# Author: Runsheng Xu <rxx3386@ucla.edu>
# License: TDG-Attribution-NonCommercial-NoDistrib

"""
Dataset class for late fusion
"""
import random
import math
from collections import OrderedDict

import numpy as np
import torch
from torch.utils.data import DataLoader

import opencood.data_utils.datasets
from opencood.data_utils.post_processor import build_postprocessor
from opencood.data_utils.datasets import basedataset
from opencood.data_utils.pre_processor import build_preprocessor
from opencood.hypes_yaml.yaml_utils import load_yaml
from opencood.utils import box_utils
from opencood.utils.pcd_utils import \
    mask_points_by_range, mask_ego_points, shuffle_points, \
    downsample_lidar_minimum
from opencood.utils.transformation_utils import x1_to_x2


class LateFusionDataset(basedataset.BaseDataset):
    """
    Dataset class for late fusion, where each agent (vehicle or
    infrastructure) transmits its detection outputs to the ego.

    An agent whose id converts to a negative integer is routed through the
    infrastructure pipeline (``i_lidar_range``); every other agent is routed
    through the vehicle pipeline (``v_lidar_range``).

    Parameters
    ----------
    params : dict
        Hypes/config dictionary. Reads the optional ``'system'`` entry
        (``'V2V'``, ``'I2X'`` or falsy) plus ``'preprocess'`` and
        ``'postprocess'``.
    visualize : bool
        Whether raw point clouds are kept for visualization.
    work : str
        ``'train'`` selects the single-agent training pipeline; any other
        value selects the multi-agent test pipeline.
    """

    def __init__(self, params, visualize, work='train'):
        super().__init__(params, visualize, work)
        # Optional cooperative-perception mode restricting which agent is
        # sampled during training; a missing entry behaves like a falsy one.
        self.system = params.get('system')
        if self.system:
            print(f"Type of CP is {self.system}")
        self.pre_processor = build_preprocessor(params['preprocess'], work)
        self.post_processor = build_postprocessor(params['postprocess'], work)

    def __getitem__(self, idx):
        base_data_dict = self.retrieve_base_data(idx)
        if self.work == 'train':
            return self.get_item_train(base_data_dict)
        return self.get_item_test(base_data_dict)

    def _process_single_agent(self, selected_cav_base, cav_id, ego_id,
                              ego_lidar_pose, lidar_range_key):
        """
        Shared single-agent pipeline used by both the vehicle and the
        infrastructure entry points.

        Parameters
        ----------
        selected_cav_base : dict
            A single agent's raw information (``'lidar_np'``, ``'params'``).
        cav_id : str or int
            The agent's id; must be convertible with ``int()``.
        ego_id : str or int or None
            ``None`` during training (boxes are generated with the agent's
            own lidar pose); otherwise the ego's id, and boxes are generated
            with ``ego_lidar_pose``.
        ego_lidar_pose : list or None
            The ego's lidar pose, only used when ``ego_id`` is not None.
        lidar_range_key : str
            Key under ``params['preprocess']`` holding the crop range.

        Returns
        -------
        selected_cav_processed : dict
            The agent's processed information.
        """
        selected_cav_processed = {}

        # filter lidar
        lidar_np = shuffle_points(selected_cav_base['lidar_np'])
        lidar_np = mask_points_by_range(
            lidar_np, self.params['preprocess'][lidar_range_key])
        # remove points that hit ego vehicle
        lidar_np = mask_ego_points(lidar_np)

        if ego_id is None:
            # boxes relative to the agent's own lidar pose
            is_infra = int(cav_id) == -1
            reference_pose = selected_cav_base['params']['lidar_pose']
        else:
            # boxes relative to the ego's lidar pose; the flag follows the ego
            is_infra = int(ego_id) == -1
            reference_pose = ego_lidar_pose

        # generate the bounding box(n, 7)
        object_bbx_center, object_bbx_mask, object_ids = \
            self.post_processor.generate_object_center([selected_cav_base],
                                                       reference_pose,
                                                       is_infra)

        # data augmentation
        lidar_np, object_bbx_center, object_bbx_mask = \
            self.augment(lidar_np, object_bbx_center, object_bbx_mask)

        if self.visualize:
            selected_cav_processed['origin_lidar'] = lidar_np

        # pre-process the lidar to voxel/bev/downsampled lidar
        lidar_dict = self.pre_processor.preprocess(lidar_np, cav_id)
        # generate the anchor boxes
        anchor_box = self.post_processor.generate_anchor_box(cav_id)
        # generate targets label
        label_dict = self.post_processor.generate_label(
            gt_box_center=object_bbx_center,
            anchors=anchor_box,
            mask=object_bbx_mask)

        selected_cav_processed.update({
            'processed_lidar': lidar_dict,
            'is_infra': is_infra,
            'anchor_box': anchor_box,
            'object_bbx_center': object_bbx_center,
            'object_bbx_mask': object_bbx_mask,
            'object_ids': object_ids,
            'label_dict': label_dict,
        })
        return selected_cav_processed

    def get_item_single_car(self, selected_cav_base, cav_id, ego_id=None,
                            ego_lidar_pose=None):
        """
        Process a single vehicle's information for the train/test pipeline.

        The point cloud is cropped with ``params['preprocess']['v_lidar_range']``.
        See ``_process_single_agent`` for the parameters and return value.
        """
        return self._process_single_agent(selected_cav_base, cav_id, ego_id,
                                          ego_lidar_pose, 'v_lidar_range')

    def get_item_single_infra(self, selected_cav_base, cav_id, ego_id=None,
                              ego_lidar_pose=None):
        """
        Process a single infrastructure agent's information for the
        train/test pipeline.

        The point cloud is cropped with ``params['preprocess']['i_lidar_range']``.
        See ``_process_single_agent`` for the parameters and return value.
        """
        return self._process_single_agent(selected_cav_base, cav_id, ego_id,
                                          ego_lidar_pose, 'i_lidar_range')

    def get_item_train(self, base_data_dict):
        """
        Build a training sample from exactly one agent.

        The agent is chosen according to ``self.system``:
        ``'V2V'`` samples a random non-infrastructure agent, ``'I2X'`` takes
        the first agent (which must be infrastructure), and a falsy value
        samples any agent at random (or the first one when visualizing).

        Returns
        -------
        processed_data_dict : OrderedDict
            ``{'ego': {'vehicle' or 'infra': processed agent dict}}``.

        Raises
        ------
        ValueError
            If ``self.system`` is set to an unsupported value.
        """
        if not self.system:
            items = list(base_data_dict.items())
            # during training, we return a random cav's data
            if self.visualize:
                selected_cav_id, selected_cav_base = items[0]
            else:
                selected_cav_id, selected_cav_base = random.choice(items)
        elif self.system == 'V2V':
            # exclude infrastructure without mutating the caller's dict
            candidates = [(cav_id, cav_base)
                          for cav_id, cav_base in base_data_dict.items()
                          if int(cav_id) != -1]
            selected_cav_id, selected_cav_base = random.choice(candidates)
            assert int(selected_cav_id) != -1
        elif self.system == "I2X":
            selected_cav_id, selected_cav_base = \
                list(base_data_dict.items())[0]
            assert int(selected_cav_id) == -1
        else:
            raise ValueError(
                f"Unsupported cooperative perception system: {self.system!r}")

        if int(selected_cav_id) < 0:
            agent_type = 'infra'
            selected_cav_processed = self.get_item_single_infra(
                selected_cav_base, selected_cav_id)
        else:
            agent_type = 'vehicle'
            selected_cav_processed = self.get_item_single_car(
                selected_cav_base, selected_cav_id)

        processed_data_dict = OrderedDict()
        processed_data_dict['ego'] = OrderedDict()
        processed_data_dict['ego'][agent_type] = selected_cav_processed
        return processed_data_dict

    def get_item_test(self, base_data_dict):
        """
        Build a test sample containing every agent within communication
        range of the ego.

        Returns
        -------
        processed_data_dict : OrderedDict
            Keyed by ``'ego'`` for the ego agent and by the agent id for the
            others; each value is ``{'vehicle' or 'infra': processed dict}``
            whose processed dict also carries ``'transformation_matrix'``
            (agent frame -> ego frame).
        """
        processed_data_dict = OrderedDict()
        ego_id = -1
        ego_lidar_pose = []

        # first find the ego vehicle's lidar pose
        for cav_id, cav_content in base_data_dict.items():
            if cav_content['ego']:
                ego_id = cav_id
                ego_lidar_pose = cav_content['params']['lidar_pose']
                break

        assert len(ego_lidar_pose) > 0, "No ego agent found in the sample"

        # loop over all CAVs to process information
        for cav_id, selected_cav_base in base_data_dict.items():
            cav_lidar_pose = selected_cav_base['params']['lidar_pose']
            # planar (x, y) distance to the ego
            distance = math.hypot(cav_lidar_pose[0] - ego_lidar_pose[0],
                                  cav_lidar_pose[1] - ego_lidar_pose[1])
            if distance > opencood.data_utils.datasets.COM_RANGE:
                continue

            # find the transformation matrix from current cav to ego.
            transformation_matrix = x1_to_x2(cav_lidar_pose, ego_lidar_pose)

            if int(cav_id) < 0:
                agent_type = 'infra'
                selected_cav_processed = self.get_item_single_infra(
                    selected_cav_base, cav_id, ego_id, ego_lidar_pose)
            else:
                agent_type = 'vehicle'
                selected_cav_processed = self.get_item_single_car(
                    selected_cav_base, cav_id, ego_id, ego_lidar_pose)

            selected_cav_processed['transformation_matrix'] = \
                transformation_matrix
            update_cav = "ego" if cav_id == ego_id else cav_id

            processed_data_dict[update_cav] = OrderedDict()
            processed_data_dict[update_cav][agent_type] = \
                selected_cav_processed

        return processed_data_dict

    def collate_batch_train(self, batch):
        """
        Customized collate function for pytorch dataloader during training
        for late fusion dataset.

        Vehicle and infrastructure samples are collated separately.

        Parameters
        ----------
        batch : list
            Samples produced by ``get_item_train``.

        Returns
        -------
        output_dict : dict
            ``{'ego': {'vehicle': {...}, 'infra': {...},
            'label_dict': {'vehicle': ..., 'infra': ...}}}``. An agent
            type with no samples keeps an empty dict and no label entry.
            Adds ``'origin_lidar'`` under ``'ego'`` when visualizing.
        """
        # during training, we only care about ego.
        output_dict = {'ego': {'vehicle': {}, 'infra': {}, 'label_dict': {}}}

        collected_keys = ('object_bbx_center', 'object_bbx_mask',
                          'processed_lidar', 'label_dict', 'is_infra')
        collected = {agent_type: {key: [] for key in collected_keys}
                     for agent_type in ('vehicle', 'infra')}
        origin_lidar = []

        for sample in batch:
            ego_content = sample['ego']
            if 'vehicle' in ego_content:
                agent_type = 'vehicle'
            elif 'infra' in ego_content:
                agent_type = 'infra'
            else:
                # nothing to collate for this sample
                continue

            ego_dict = ego_content[agent_type]
            for key, values in collected[agent_type].items():
                values.append(ego_dict[key])
            if self.visualize:
                origin_lidar.append(ego_dict['origin_lidar'])

        for agent_type, bucket in collected.items():
            if not bucket['object_bbx_center']:
                continue
            output_dict['ego'][agent_type].update({
                'object_bbx_center':
                    torch.from_numpy(np.array(bucket['object_bbx_center'])),
                'object_bbx_mask':
                    torch.from_numpy(np.array(bucket['object_bbx_mask'])),
                'processed_lidar':
                    self.pre_processor.collate_batch(
                        bucket['processed_lidar']),
                'is_infra': bucket['is_infra'],
            })
            output_dict['ego']['label_dict'][agent_type] = \
                self.post_processor.collate_batch(bucket['label_dict'])

        if self.visualize:
            origin_lidar = \
                np.array(downsample_lidar_minimum(pcd_np_list=origin_lidar))
            output_dict['ego']['origin_lidar'] = torch.from_numpy(origin_lidar)

        return output_dict

    def collate_batch_test(self, batch):
        """
        Customized collate function for pytorch dataloader during testing
        for late fusion dataset.

        Parameters
        ----------
        batch : list
            A list holding exactly one sample produced by ``get_item_test``.

        Returns
        -------
        output_dict : dict
            Keyed like the sample (``'ego'`` / agent id); each value maps
            the agent type to its tensors. When visualizing, each entry also
            has ``'origin_lidar'``, and the ``'ego'`` entry's is replaced by
            the projected ego point cloud.

        Raises
        ------
        AssertionError
            If the batch does not contain exactly one sample.
        """
        # currently, we only support batch size of 1 during testing
        assert len(batch) == 1, "Batch size 1 is required during testing!"
        batch = batch[0]

        output_dict = {}
        # for late fusion, we also need to stack the lidar for better
        # visualization
        projected_lidar_list = []

        for cav_id, cav_content_ in batch.items():
            # each entry holds a single sub-dict keyed by 'vehicle' or 'infra'
            agent_type = next(iter(cav_content_))
            cav_content = cav_content_[agent_type]

            # shape: (1, max_num, 7)
            object_bbx_center = \
                torch.from_numpy(np.array([cav_content['object_bbx_center']]))
            object_bbx_mask = \
                torch.from_numpy(np.array([cav_content['object_bbx_mask']]))

            if self.visualize and cav_id == "ego":
                # project a copy so the sample's point cloud is not modified
                projected_lidar = np.copy(cav_content['origin_lidar'])
                projected_lidar[:, :3] = \
                    box_utils.project_points_by_matrix_torch(
                        projected_lidar[:, :3],
                        cav_content['transformation_matrix'])
                projected_lidar_list.append(projected_lidar)

            # processed lidar dictionary
            processed_lidar_torch_dict = \
                self.pre_processor.collate_batch(
                    [cav_content['processed_lidar']])
            # label dictionary
            label_torch_dict = \
                self.post_processor.collate_batch([cav_content['label_dict']])
            # save the transformation matrix (4, 4) to ego vehicle
            transformation_matrix_torch = torch.from_numpy(
                np.array(cav_content['transformation_matrix'])).float()

            agent_output = OrderedDict()
            agent_output.update({
                'object_bbx_center': object_bbx_center,
                'object_bbx_mask': object_bbx_mask,
                'processed_lidar': processed_lidar_torch_dict,
                'label_dict': label_torch_dict,
                'object_ids': cav_content['object_ids'],
                'transformation_matrix': transformation_matrix_torch,
            })
            # the anchor box is usually shared by all samples, thus no batch
            # dimension is added.
            if cav_content['anchor_box'] is not None:
                agent_output['anchor_box'] = \
                    torch.from_numpy(np.array(cav_content['anchor_box']))

            output_dict[cav_id] = {agent_type: agent_output}

            if self.visualize:
                origin_lidar = np.array(downsample_lidar_minimum(
                    pcd_np_list=[cav_content['origin_lidar']]))
                output_dict[cav_id]['origin_lidar'] = \
                    torch.from_numpy(origin_lidar)

        if self.visualize:
            projected_lidar_stack = torch.from_numpy(
                np.vstack(projected_lidar_list))
            output_dict['ego']['origin_lidar'] = projected_lidar_stack

        return output_dict

    def post_process(self, data_dict, output_dict):
        """
        Process the outputs of the model to 2D/3D bounding box.

        Parameters
        ----------
        data_dict : dict
            The dictionary containing the origin input data of model.

        output_dict : dict
            The dictionary containing the output of the model.

        Returns
        -------
        pred_box_tensor : torch.Tensor
            The tensor of prediction bounding box after NMS.
        pred_score : torch.Tensor
            The scores returned by the post-processor alongside the boxes.
        gt_box_tensor : torch.Tensor
            The tensor of gt bounding box.
        """
        pred_box_tensor, pred_score = \
            self.post_processor.post_process(data_dict, output_dict)
        gt_box_tensor = self.post_processor.generate_gt_bbx_late(data_dict)

        return pred_box_tensor, pred_score, gt_box_tensor
