# ------------------------------------------------------------------------
# Copyright (c) 2023 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR3D (https://github.com/WangYueFt/detr3d)
# Copyright (c) 2021 Wang, Yue
# ------------------------------------------------------------------------
# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------

import torch
import numpy as np
from mmdet.datasets import DATASETS
from mmdet.datasets.custom import CustomDataset
from mmdet3d.datasets import NuScenesDataset
from sklearn.metrics import hamming_loss, accuracy_score, f1_score, jaccard_score, roc_auc_score

import json
import itertools
from copy import deepcopy
from typing import List, Dict


def get_combinations(input_list):
    """Enumerate every proper sub-combination of ``input_list``.

    Combinations are produced for sizes ``0`` up to ``len(input_list) - 1``,
    so the empty combination is included while the combination holding
    every element is not.  Results are ordered by size first and, within a
    size, in the order ``itertools.combinations`` yields them.

    Args:
        input_list (list): Items to combine.

    Returns:
        list[list]: One sorted list per combination.
    """
    subset_sizes = range(len(input_list))
    return [
        sorted(subset)
        for size in subset_sizes
        for subset in itertools.combinations(input_list, size)
    ]

@DATASETS.register_module()
class SceneClassificationDataset(NuScenesDataset):
    r"""NuScenes dataset whose label is the set of dropped cameras per sample.

    ``scene_ann_file`` maps each sample token to a list of dropped camera
    names.  Every combination returned by ``get_combinations(CAM_NAMES)``
    (all subsets except the full camera set, including the empty subset) is
    one class.  The class id is the position of the combination in that list
    and the class name is the sorted, comma-joined camera names.

    Args:
        scene_ann_file (str): Path to a JSON file containing a list of
            records, each with a ``token`` and a ``drop_cams`` list.
        *args: Forwarded to ``NuScenesDataset``.
        return_gt_info (bool): If True, ``get_data_info`` also returns the
            raw info record under the ``info`` key.
        used_cams (Iterable[str], optional): Cameras to keep in every
            sample.  ``None`` keeps whatever cameras the info record holds.
        limit (int, optional): Keep only the first ``limit`` samples.
        **kwargs: Forwarded to ``NuScenesDataset``.
    """

    CAM_NAMES = ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT', 'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT']

    def __init__(self, scene_ann_file, *args, return_gt_info=False, used_cams=None, limit=None, **kwargs):
        super(SceneClassificationDataset, self).__init__(*args, **kwargs)
        self.scene_ann_file = scene_ann_file

        with open(self.scene_ann_file, 'r') as f:
            scene_anns = json.load(f)
        # token -> canonical label name (sorted, comma-joined dropped cameras).
        self.scene_anns_map = {
            scene_ann['token']: ','.join(sorted(scene_ann['drop_cams']))
            for scene_ann in scene_anns
        }

        cam_names = SceneClassificationDataset.CAM_NAMES
        label_names = [','.join(sorted(combo)) for combo in get_combinations(cam_names)]
        self.label_name_to_id = {label_name: idx for idx, label_name in enumerate(label_names)}
        self.label_id_to_name = label_names

        # label id -> multi-hot vector over CAM_NAMES (1 = camera is dropped).
        # The empty label name splits to [''] and therefore maps to all zeros.
        label_id_to_binary = dict()
        for label_name in label_names:
            dropped = set(label_name.split(','))
            label_id = self.label_name_to_id[label_name]
            assert label_id not in label_id_to_binary
            label_id_to_binary[label_id] = [int(cam_name in dropped) for cam_name in cam_names]
        self.label_id_to_binary = label_id_to_binary

        self.return_gt_info = return_gt_info
        self.used_cams = None if used_cams is None else set(used_cams)

        self.limit = limit
        if self.limit is not None:
            self.data_infos = self.data_infos[:self.limit]
            # NOTE(review): the base dataset is presumed to build per-sample
            # sampler group flags from ``data_infos`` during construction;
            # rebuild them so they match the truncated length.  Confirm
            # against the mmdet3d base class in use.
            if not self.test_mode and hasattr(self, '_set_group_flag'):
                self._set_group_flag()

        print(f'>>>>>> {len(self.data_infos)=}')

    def get_data_info(self, index):
        """Get data info according to the given index.

        Args:
            index (int): Index of the sample data to get.

        Returns:
            dict: Data information that will be passed to the data
                preprocessing pipelines. It includes the following keys:

                - sample_idx (str): Sample token.
                - pts_filename (str): Filename of point clouds.
                - sweeps (list[dict]): Infos of sweeps.
                - timestamp (float): Sample timestamp divided by 1e6.
                - img_sweeps: ``info['img_sweeps']`` if present, else None.
                - radar_info: ``info['radars']`` if present, else None.
                - cls_labels (int): Class id of the sample's dropped-camera
                    combination.
                - info (dict, optional): Raw info record, only when
                    ``return_gt_info`` is set.
                - img_timestamp, img_filename, lidar2img, cam_intrinsic,
                    lidar2cam (list, optional): Per-camera entries, only
                    when the camera modality is enabled.
                - ann_info (dict, optional): Annotation info, only outside
                    test mode.

        Raises:
            AssertionError: If ``used_cams`` was given and the cameras left
                in the record do not match it exactly.
        """
        info = self.data_infos[index]

        if self.used_cams is not None:
            # Cameras are removed from the stored record in place, so the
            # filtering is visible in the optional ``info`` output as well.
            for cam_name in list(info['cams'].keys()):
                if cam_name not in self.used_cams:
                    del info['cams'][cam_name]
            # Only meaningful when a camera subset was requested; with
            # ``used_cams=None`` there is nothing to compare against.
            assert set(info['cams'].keys()) == self.used_cams, f'available cameras: {set(info["cams"].keys())}, self.used_cams: {self.used_cams}'

        # standard protocol modified from SECOND.Pytorch
        input_dict = dict(
            sample_idx=info['token'],
            pts_filename=info['lidar_path'],
            sweeps=info['sweeps'],
            timestamp=info['timestamp'] / 1e6,
            img_sweeps=info.get('img_sweeps'),
            radar_info=info.get('radars'),
            cls_labels=self.label_name_to_id[self.scene_anns_map[info['token']]],
        )

        if self.return_gt_info:
            input_dict['info'] = info

        if self.modality['use_camera']:
            input_dict.update(self._collect_camera_inputs(info['cams']))

        if not self.test_mode:
            input_dict['ann_info'] = self.get_ann_info(index)

        return input_dict

    @staticmethod
    def _collect_camera_inputs(cams):
        """Build per-camera timestamps, image paths and transform matrices.

        Args:
            cams (dict): Camera name -> camera record with ``timestamp``,
                ``data_path``, ``sensor2lidar_rotation``,
                ``sensor2lidar_translation`` and ``cam_intrinsic``.

        Returns:
            dict: Lists, one entry per camera in ``cams`` iteration order,
                under the keys ``img_timestamp``, ``img_filename``,
                ``lidar2img``, ``cam_intrinsic`` and ``lidar2cam``.
        """
        image_paths = []
        lidar2img_rts = []
        lidar2cam_rts = []
        cam_intrinsics = []
        img_timestamp = []
        for cam_info in cams.values():
            img_timestamp.append(cam_info['timestamp'] / 1e6)
            image_paths.append(cam_info['data_path'])
            # obtain lidar to image transformation matrix
            lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation'])
            lidar2cam_t = cam_info['sensor2lidar_translation'] @ lidar2cam_r.T
            # Assembled in transposed form (translation in the last row);
            # the ``.T`` below turns it into the usual 4x4 layout.
            lidar2cam_rt = np.eye(4)
            lidar2cam_rt[:3, :3] = lidar2cam_r.T
            lidar2cam_rt[3, :3] = -lidar2cam_t
            intrinsic = cam_info['cam_intrinsic']
            # Pad the intrinsic matrix to 4x4 so it composes with lidar2cam.
            viewpad = np.eye(4)
            viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic
            lidar2img_rts.append(viewpad @ lidar2cam_rt.T)

            cam_intrinsics.append(viewpad)
            lidar2cam_rts.append(lidar2cam_rt.T)

        return dict(
            img_timestamp=img_timestamp,
            img_filename=image_paths,
            lidar2img=lidar2img_rts,
            cam_intrinsic=cam_intrinsics,
            lidar2cam=lidar2cam_rts,
        )

    def evaluate(self, results, metric='bbox', logger=None, jsonfile_prefix=None, result_names=..., show=False, out_dir=None, pipeline=None):
        """Score predicted dropped-camera classes against the labels.

        Only ``results`` is used; the remaining parameters are accepted for
        signature compatibility and ignored.

        Args:
            results (list[dict]): One dict per sample with ``pred_top3``
                (predicted class ids, best first) and ``label`` (true class
                id).

        Returns:
            dict: ``total`` (number of results), ``acc_top1`` and
                ``acc_top3`` (percent), ``acc_num`` (percent of samples whose
                top-1 prediction drops the same number of cameras as the
                label) and ``hamming_loss`` (mean per-sample Hamming loss
                between the top-1 and label multi-hot vectors).

        Raises:
            ZeroDivisionError: If ``results`` is empty.
        """
        total = len(results)
        top1_correct = 0
        top3_correct = 0
        num_correct = 0
        hamming_loss_total = 0
        for result in results:
            pred_top3 = result['pred_top3']
            label = result['label']
            top1 = pred_top3[0]
            if top1 == label:
                top1_correct += 1
            if label in pred_top3:
                top3_correct += 1
            pred_binary = self.label_id_to_binary[top1]
            label_binary = self.label_id_to_binary[label]
            # Count dropped cameras from the multi-hot vectors: splitting the
            # label name would report one camera for the empty combination.
            if sum(pred_binary) == sum(label_binary):
                num_correct += 1
            hamming_loss_total += hamming_loss([label_binary], [pred_binary])
        acc_top1 = top1_correct / total * 100
        acc_top3 = top3_correct / total * 100
        acc_num = num_correct / total * 100
        hamming_loss_avg = hamming_loss_total / total
        return {'total': total, 'acc_top1': acc_top1, 'acc_top3': acc_top3, 'acc_num': acc_num, 'hamming_loss': hamming_loss_avg}


@DATASETS.register_module()
class NuScenesCameraLiDARWithUnimportantDeviceDataset(NuScenesDataset):
    """NuScenes dataset with per-sample multi-hot device labels.

    Each sample token is mapped, through ``label_file``, to the devices that
    are used and to the devices that are unimportant.  Both are encoded as
    multi-hot vectors over ``CAM_NAMES + LIDAR_NAMES``.

    Args:
        label_file (str): Path to a JSON file containing a list of records,
            each with ``token``, ``used_devices`` and
            ``unimportant_devices``.
        *args: Forwarded to ``NuScenesDataset``.
        return_gt_info (bool): If True, ``get_data_info`` also returns the
            raw info record under the ``info`` key.
        used_cams (Iterable[str], optional): Cameras to keep in every
            sample.  Defaults to all of ``CAM_NAMES``.
        limit (int, optional): Keep only the first ``limit`` samples.
        **kwargs: Forwarded to ``NuScenesDataset``.
    """

    CAM_NAMES = ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT', 'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT']
    LIDAR_NAMES = ['LIDAR_TOP']

    def __init__(self, label_file, *args, return_gt_info=False, used_cams=None, limit=None, **kwargs):
        super(NuScenesCameraLiDARWithUnimportantDeviceDataset, self).__init__(*args, **kwargs)
        if used_cams is None:
            used_cams = NuScenesCameraLiDARWithUnimportantDeviceDataset.CAM_NAMES
        # Always store a set: get_data_info compares it with a set of keys,
        # which could never equal a list or tuple.
        self.used_cams = set(used_cams)

        with open(label_file, 'r') as f:
            label_infos = json.load(f)
        self.token_to_used_devices = {
            label_info['token']: label_info['used_devices'] for label_info in label_infos
        }
        self.token_to_unimportant_devices = {
            label_info['token']: label_info['unimportant_devices'] for label_info in label_infos
        }

        # Class order: all cameras first, then the lidars.
        self.label_id_to_name = list(
            NuScenesCameraLiDARWithUnimportantDeviceDataset.CAM_NAMES
            + NuScenesCameraLiDARWithUnimportantDeviceDataset.LIDAR_NAMES
        )
        self.label_name_to_id = {label_name: idx for idx, label_name in enumerate(self.label_id_to_name)}
        self.num_classes = len(self.label_id_to_name)

        self.return_gt_info = return_gt_info

        self.limit = limit
        if self.limit is not None:
            self.data_infos = self.data_infos[:self.limit]
            # NOTE(review): the base dataset is presumed to build per-sample
            # sampler group flags from ``data_infos`` during construction;
            # rebuild them so they match the truncated length.  Confirm
            # against the mmdet3d base class in use.
            if not self.test_mode and hasattr(self, '_set_group_flag'):
                self._set_group_flag()

    def get_data_info(self, index):
        """Get data info according to the given index.

        Args:
            index (int): Index of the sample data to get.

        Returns:
            dict: Data information that will be passed to the data
                preprocessing pipelines. It includes the following keys:

                - sample_idx (str): Sample token.
                - pts_filename (str): Filename of point clouds.
                - sweeps (list[dict]): Infos of sweeps.
                - timestamp (float): Sample timestamp divided by 1e6.
                - img_sweeps: ``info['img_sweeps']`` if present, else None.
                - radar_info: ``info['radars']`` if present, else None.
                - cls_labels (dict): ``label`` and ``is_unimportant``
                    multi-hot int64 tensors of length ``num_classes``.
                - info (dict, optional): Raw info record, only when
                    ``return_gt_info`` is set.
                - img_timestamp, img_filename, lidar2img, cam_intrinsic,
                    lidar2cam (list, optional): Per-camera entries, only
                    when the camera modality is enabled.
                - ann_info (dict, optional): Annotation info, only outside
                    test mode.

        Raises:
            AssertionError: If the cameras left in the record do not match
                ``used_cams`` exactly.
        """
        info = self.data_infos[index]

        if self.used_cams is not None:
            # Cameras are removed from the stored record in place, so the
            # filtering is visible in the optional ``info`` output as well.
            for cam_name in list(info['cams'].keys()):
                if cam_name not in self.used_cams:
                    del info['cams'][cam_name]
        assert set(info['cams'].keys()) == self.used_cams, f'available cameras: {set(info["cams"].keys())}, self.used_cams: {self.used_cams}'

        token = info['token']
        cls_labels = dict(
            label=self._used_devices_to_label(self.token_to_used_devices[token]),
            is_unimportant=self._used_devices_to_label(self.token_to_unimportant_devices[token]),
        )

        # standard protocol modified from SECOND.Pytorch
        input_dict = dict(
            sample_idx=info['token'],
            pts_filename=info['lidar_path'],
            sweeps=info['sweeps'],
            timestamp=info['timestamp'] / 1e6,
            img_sweeps=info.get('img_sweeps'),
            radar_info=info.get('radars'),
            cls_labels=cls_labels,
        )

        if self.return_gt_info:
            input_dict['info'] = info

        if self.modality['use_camera']:
            input_dict.update(self._collect_camera_inputs(info['cams']))

        if not self.test_mode:
            input_dict['ann_info'] = self.get_ann_info(index)

        return input_dict

    @staticmethod
    def _collect_camera_inputs(cams):
        """Build per-camera timestamps, image paths and transform matrices.

        Args:
            cams (dict): Camera name -> camera record with ``timestamp``,
                ``data_path``, ``sensor2lidar_rotation``,
                ``sensor2lidar_translation`` and ``cam_intrinsic``.

        Returns:
            dict: Lists, one entry per camera in ``cams`` iteration order,
                under the keys ``img_timestamp``, ``img_filename``,
                ``lidar2img``, ``cam_intrinsic`` and ``lidar2cam``.
        """
        image_paths = []
        lidar2img_rts = []
        lidar2cam_rts = []
        cam_intrinsics = []
        img_timestamp = []
        for cam_info in cams.values():
            img_timestamp.append(cam_info['timestamp'] / 1e6)
            image_paths.append(cam_info['data_path'])
            # obtain lidar to image transformation matrix
            lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation'])
            lidar2cam_t = cam_info['sensor2lidar_translation'] @ lidar2cam_r.T
            # Assembled in transposed form (translation in the last row);
            # the ``.T`` below turns it into the usual 4x4 layout.
            lidar2cam_rt = np.eye(4)
            lidar2cam_rt[:3, :3] = lidar2cam_r.T
            lidar2cam_rt[3, :3] = -lidar2cam_t
            intrinsic = cam_info['cam_intrinsic']
            # Pad the intrinsic matrix to 4x4 so it composes with lidar2cam.
            viewpad = np.eye(4)
            viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic
            lidar2img_rts.append(viewpad @ lidar2cam_rt.T)

            cam_intrinsics.append(viewpad)
            lidar2cam_rts.append(lidar2cam_rt.T)

        return dict(
            img_timestamp=img_timestamp,
            img_filename=image_paths,
            lidar2img=lidar2img_rts,
            cam_intrinsic=cam_intrinsics,
            lidar2cam=lidar2cam_rts,
        )

    def evaluate(self, results, metric='bbox', logger=None, jsonfile_prefix=None, result_names=..., show=False, out_dir=None, pipeline=None):
        """Compute per-device accuracy, with and without unimportant devices.

        Only ``results`` is used; the remaining parameters are accepted for
        signature compatibility and ignored.

        Args:
            results (list[dict]): One dict per result with a ``predictions``
                tensor and a ``labels`` dict holding ``label`` and
                ``is_unimportant`` tensors.  Tensors are concatenated along
                dim 0 and are assumed to be ``(rows, num_devices)`` with
                matching shapes (TODO confirm against the model head).

        Returns:
            dict: ``acc_<i>`` and ``acc_mean`` give the percentage of rows
                whose prediction equals the label for device ``i`` and the
                mean over devices.  ``acc_masked_<i>`` and
                ``acc_masked_mean`` are the same, except that entries flagged
                in ``is_unimportant`` always count as correct.
        """
        output = dict()

        # Gather on CPU so evaluation does not need a GPU, whatever device
        # the individual result tensors live on.
        predictions = torch.cat([x['predictions'].cpu() for x in results], dim=0)
        labels = torch.cat([x['labels']['label'].cpu() for x in results], dim=0)
        is_unimportant = torch.cat([x['labels']['is_unimportant'].cpu() for x in results], dim=0)

        # Normalise by the number of concatenated rows; this equals
        # len(results) when every result holds exactly one row.
        total = predictions.shape[0]

        correct = (predictions == labels)
        acc_cls = correct.sum(dim=0) / total * 100
        for idx, acc in enumerate(acc_cls):
            output[f'acc_{idx}'] = acc.item()
        output['acc_mean'] = acc_cls.mean().item()

        # Unimportant devices are treated as correct regardless of prediction.
        correct_masked = correct.clone()
        correct_masked[is_unimportant.bool()] = True
        acc_masked_cls = correct_masked.sum(dim=0) / total * 100
        for idx, acc in enumerate(acc_masked_cls):
            output[f'acc_masked_{idx}'] = acc.item()
        output['acc_masked_mean'] = acc_masked_cls.mean().item()

        return output

    def _used_cams_to_label(self, used_cams: List[str]):
        """Encode camera names as a 0/1 list ordered like ``CAM_NAMES``."""
        return [
            1 if cam_name in used_cams else 0
            for cam_name in NuScenesCameraLiDARWithUnimportantDeviceDataset.CAM_NAMES
        ]

    def _used_devices_to_label(self, used_devices: Dict[str, List[str]]):
        """Encode a device record as a multi-hot tensor over all devices.

        Args:
            used_devices (dict): Record with ``used_cams`` and
                ``used_lidars`` lists of device names.

        Returns:
            torch.Tensor: int64 tensor of length ``num_classes`` with 1 at
                the index of every listed device and 0 elsewhere.

        Raises:
            KeyError: If a listed device name is not a known device.
        """
        device_names = used_devices['used_cams'] + used_devices['used_lidars']
        # An explicit long tensor keeps the indexing valid for empty lists.
        label_ids = torch.tensor(
            [self.label_name_to_id[name] for name in device_names], dtype=torch.long)

        label = torch.zeros(self.num_classes, dtype=torch.int64)
        label[label_ids] = 1

        return label
