# -*- coding: utf-8 -*-

# License: TDG-Attribution-NonCommercial-NoDistrib


import os
from collections import OrderedDict

import numpy as np
import torch

from opencood.utils.common_utils import torch_tensor_to_numpy
from opencood.utils.box_utils import boxes_to_corners_3d
from opencood.utils import box_utils

from mmdet3d.core.bbox import LiDARInstance3DBoxes


def inference_late_fusion(batch_data, model, dataset):
    """
    Model inference for late fusion.

    The model is run once per entry of ``batch_data``; the per-agent outputs
    are then handed, together with the inputs, to ``dataset.post_process``.

    Parameters
    ----------
    batch_data : dict
        Maps each CAV id to the input that is passed to ``model``.
    model : opencood.object
        Callable invoked as ``model(cav_content)`` for every agent.
    dataset : opencood.LateFusionDataset
        Must provide ``post_process(batch_data, output_dict)`` returning
        three values.

    Returns
    -------
    pred_box_tensor : torch.Tensor
        First value returned by ``dataset.post_process`` (the predicted
        bounding boxes).
    pred_score : torch.Tensor
        Second value returned by ``dataset.post_process`` (presumably the
        confidence of each predicted box -- confirm against the dataset).
    gt_box_tensor : torch.Tensor
        Third value returned by ``dataset.post_process`` (the ground-truth
        bounding boxes).
    """
    # Insertion order follows the iteration order of ``batch_data``.
    per_agent_output = OrderedDict(
        (agent_id, model(agent_input))
        for agent_id, agent_input in batch_data.items()
    )

    boxes, scores, gt_boxes = dataset.post_process(batch_data,
                                                   per_agent_output)
    return boxes, scores, gt_boxes

def get_center_box(box: LiDARInstance3DBoxes):
    """
    Get BBoxes with gravity center from mmdet3d's LiDARInstance3DBoxes.

    Parameters
    ----------
    box : LiDARInstance3DBoxes
        Boxes whose ``tensor`` and ``gravity_center`` attributes are read.

    Returns
    -------
    torch.Tensor
        A new tensor shaped like ``box.tensor``: columns 0-2 hold
        ``box.gravity_center`` and the remaining columns are copied from
        ``box.tensor``. The input box is not modified.
    """
    # Start from a copy of the raw box parameters, then overwrite only the
    # centre columns so the input tensor is left untouched.
    centered = box.tensor.clone()
    centered[:, :3] = box.gravity_center
    return centered

def _unpack_pred_boxes(data):
    """Call ``get_pred_boxes`` and normalise a missing result to ``(None, None)``."""
    result = get_pred_boxes(data)
    return (None, None) if result is None else result


def inference_early_fusion(batch_data, model, dataset):
    """
    Model inference for early fusion.

    Two code paths exist, selected by whether the string ``'fsd'`` occurs in
    ``str(model.__class__)``:

    * FSD models are called as
      ``model(cav_content, return_loss=False, rescale=True)`` and their
      output is converted to corner boxes here, without
      ``dataset.post_process``.
    * Every other model is called as ``model(cav_content)`` and its output
      is decoded by ``dataset.post_process``.

    Parameters
    ----------
    batch_data : dict
        Only ``batch_data['ego']`` is fed to the model. On the FSD path its
        ``'object_bbx_center'`` and ``'object_bbx_mask'`` entries are used
        to build the ground-truth boxes.
    model : opencood.object
    dataset : opencood.EarlyFusionDataset
        Only used on the non-FSD path.

    Returns
    -------
    ret_dicts : dict
        Always contains ``'pred_box'``, ``'pred_score'`` and ``'gt_box'``.
        When an FSD model returns a dict, ``'pred_box'`` and
        ``'pred_score'`` are themselves dicts keyed like the model output.
        When an FSD model returns a list whose first element carries
        ``'seg_logits'``, that entry is forwarded under the same key.

    Raises
    ------
    TypeError
        If an FSD model returns something that is neither a list nor a dict.
    """
    cav_content = batch_data['ego']

    if 'fsd' not in str(model.__class__):
        output_dict = OrderedDict()
        output_dict['ego'] = model(cav_content)
        pred_box_tensor, pred_score, gt_box_tensor = \
            dataset.post_process(batch_data,
                                 output_dict)
        return {
            'pred_box': pred_box_tensor,
            'pred_score': pred_score,
            'gt_box': gt_box_tensor,
        }

    model_output = model(cav_content, return_loss=False, rescale=True)

    # The ``.corners`` property of mmdet3d boxes can't be used here: mmdet3d
    # and opencood follow different box-to-corner rules, which results in a
    # mod-zero error in the IoU calculation. Corners are therefore always
    # produced by opencood's ``boxes_to_corners_3d``.
    valid_gt = cav_content['object_bbx_center'][
        cav_content['object_bbx_mask'] > 0, :]
    gt_box_tensor = get_center_box(
        LiDARInstance3DBoxes(valid_gt, origin=(0.5, 0.5, 0.5)))
    gt_box_tensor = boxes_to_corners_3d(gt_box_tensor, order='lwh',
                                        coord_type='right')

    if isinstance(model_output, list):
        pred_box_tensor, pred_score = _unpack_pred_boxes(model_output)
        ret_dicts = {
            'pred_box': pred_box_tensor,
            'pred_score': pred_score,
            'gt_box': gt_box_tensor,
        }
        # A list has no ``keys()``; look for the optional segmentation
        # output on the first sample instead.
        # NOTE(review): assumes 'seg_logits' lives on the first per-sample
        # dict -- confirm against the FSD model's output format.
        first_sample = model_output[0]
        if isinstance(first_sample, dict) and 'seg_logits' in first_sample:
            ret_dicts['seg_logits'] = first_sample['seg_logits']
        return ret_dicts

    if isinstance(model_output, dict):
        pred_box_tensor_dict = dict()
        pred_score_dict = dict()
        for stage, data in model_output.items():
            pred_box_tensor_dict[stage], pred_score_dict[stage] = \
                _unpack_pred_boxes(data)
        return {
            'pred_box': pred_box_tensor_dict,
            'pred_score': pred_score_dict,
            'gt_box': gt_box_tensor,
        }

    # Fail with a clear message instead of an UnboundLocalError further down.
    raise TypeError(
        'Unsupported FSD model output type: %s (expected list or dict)'
        % type(model_output).__name__)

def get_pred_boxes(data):
    """
    Convert the first element of an mmdet3d-style result list to corner boxes.

    Parameters
    ----------
    data : list
        Only ``data[0]`` is read. It is expected to be a dict with the keys
        ``'boxes_3d'`` (an object accepted by ``get_center_box``) and
        ``'scores_3d'``.

    Returns
    -------
    pred_box_tensor
        Result of ``boxes_to_corners_3d`` applied to the gravity-centred
        boxes, or ``None`` when ``'boxes_3d'`` is missing.
    pred_score
        ``data[0]['scores_3d']`` unchanged, or ``None`` when ``'boxes_3d'``
        is missing.
    """
    sample = data[0]
    if 'boxes_3d' not in sample:
        # Always return a pair so callers can unpack unconditionally.
        return None, None
    # TODO: maybe unknown format error
    pred_box_tensor_by_box, pred_score = sample['boxes_3d'], sample['scores_3d']
    pred_box_tensor = get_center_box(pred_box_tensor_by_box)
    pred_box_tensor = boxes_to_corners_3d(pred_box_tensor, order='lwh', coord_type='right')

    # post process (currently disabled)
    # keep_index_1 = box_utils.remove_large_pred_bbx(pred_box_tensor)
    # keep_index_2 = box_utils.remove_bbx_abnormal_z(pred_box_tensor)
    # keep_index = torch.logical_and(keep_index_1, keep_index_2)
    # pred_box_tensor = pred_box_tensor[keep_index]
    # pred_score = pred_score[keep_index]

    return pred_box_tensor, pred_score

def inference_intermediate_fusion(batch_data, model, dataset):
    """
    Model inference for intermediate fusion.

    The procedure is the same as for early fusion: all three arguments are
    forwarded unchanged to ``inference_early_fusion``.

    Parameters
    ----------
    batch_data : dict
    model : opencood.object
    dataset : object
        Dataset object handed on to ``inference_early_fusion``.

    Returns
    -------
    Whatever ``inference_early_fusion`` returns for the same arguments.
    """
    ret_dicts = inference_early_fusion(batch_data, model, dataset)
    return ret_dicts


def save_prediction_gt(pred_tensor, gt_tensor, pcd, timestamp, save_path):
    """
    Save the point cloud, prediction and gt tensors with ``np.save``.

    Parameters
    ----------
    pred_tensor, gt_tensor, pcd
        Each is converted with ``torch_tensor_to_numpy`` before saving.
    timestamp : int
        Formatted with ``%04d`` as the file-name prefix.
    save_path : str
        Directory the files are written into.
    """
    # Convert everything first (same order as before) so that a failing
    # conversion happens before any file is written.
    pred_np, gt_np, pcd_np = (
        torch_tensor_to_numpy(tensor)
        for tensor in (pred_tensor, gt_tensor, pcd)
    )

    # NOTE(review): the gt name does not end in '.npy'; per the numpy docs
    # np.save appends the extension in that case, so the file on disk is
    # presumably '<ts>_gt.npy_test.npy' -- confirm whether this is intended.
    outputs = (
        ('%04d_pcd.npy', pcd_np),
        ('%04d_pred.npy', pred_np),
        ('%04d_gt.npy_test', gt_np),
    )
    for name_pattern, array in outputs:
        np.save(os.path.join(save_path, name_pattern % timestamp), array)
