# -*- coding: utf-8 -*-
# Author: 
# License: TDG-Attribution-NonCommercial-NoDistrib


import os
from collections import OrderedDict

import numpy as np
import torch

from opencood.utils.common_utils import torch_tensor_to_numpy
from opencood.utils.box_utils import project_box3d, corner_to_center


def inference_late_fusion(batch_data, model, dataset):
    """
    Run the model on every agent's data and post-process the results.

    Parameters
    ----------
    batch_data : dict
        Mapping from cav id to that cav's model input.
    model : opencood.object
        Callable invoked once per cav with that cav's content.
    dataset : opencood.LateFusionDataset
        Object providing ``post_process(batch_data, output_dict)``.

    Returns
    -------
    pred_box_tensor : torch.Tensor
        First value returned by ``dataset.post_process`` (the predicted
        bounding boxes).
    pred_score : torch.Tensor
        Second value returned by ``dataset.post_process`` (the prediction
        scores).
    gt_box_tensor : torch.Tensor
        Third value returned by ``dataset.post_process`` (the gt bounding
        boxes).
    """
    # One forward pass per cav, stored in the iteration order of batch_data.
    per_cav_output = OrderedDict(
        (cav_id, model(cav_content))
        for cav_id, cav_content in batch_data.items()
    )

    boxes, scores, gt_boxes = dataset.post_process(batch_data, per_cav_output)
    return boxes, scores, gt_boxes


def inference_early_fusion(batch_data, model, dataset):
    """
    Run the model on the ego entry only and post-process the result.

    Parameters
    ----------
    batch_data : dict
        Must contain the key ``'ego'``; only that entry is fed to the model.
    model : opencood.object
        Callable invoked once with ``batch_data['ego']``.
    dataset : opencood.EarlyFusionDataset
        Object providing ``post_process(batch_data, output_dict)``.

    Returns
    -------
    pred_box_tensor : torch.Tensor
        First value returned by ``dataset.post_process`` (the predicted
        bounding boxes).
    pred_score : torch.Tensor
        Second value returned by ``dataset.post_process`` (the prediction
        scores).
    gt_box_tensor : torch.Tensor
        Third value returned by ``dataset.post_process`` (the gt bounding
        boxes).
    """
    ego_output = model(batch_data['ego'])
    # post_process receives the same dict layout as in late fusion, but with
    # a single 'ego' entry.
    ego_only_output = OrderedDict([('ego', ego_output)])

    boxes, scores, gt_boxes = dataset.post_process(batch_data, ego_only_output)
    return boxes, scores, gt_boxes


def inference_intermediate_fusion(batch_data, model, dataset):
    """
    Model inference for intermediate fusion.

    This delegates directly to ``inference_early_fusion``: the arguments are
    forwarded unchanged and its result is returned as is.

    Parameters
    ----------
    batch_data : dict
    model : opencood.object
    dataset : object
        Forwarded to ``inference_early_fusion``, which calls its
        ``post_process`` method.

    Returns
    -------
    tuple
        Whatever ``inference_early_fusion`` returns for the same arguments.
    """
    fused_result = inference_early_fusion(batch_data, model, dataset)
    return fused_result


def save_prediction_gt(pred_tensor, gt_tensor, pcd, timestamp, save_path, pred_score=None):
    """
    Save prediction, gt and point cloud tensors as ``.npy`` files.

    Files are written into ``save_path`` and named with the zero-padded
    timestamp: ``%04d_pred.npy``, ``%04d_pred_score.npy``, ``%04d_pcd.npy``
    and ``%04d_gt.npy``.

    Parameters
    ----------
    pred_tensor : torch.Tensor or None
        Prediction bounding boxes; nothing is written for them when None.
    gt_tensor : torch.Tensor
        Ground truth bounding boxes.
    pcd : torch.Tensor
        Point cloud data.
    timestamp : int
        Frame timestamp/index used as the file name prefix.
    save_path : str
        Directory to save files.
    pred_score : torch.Tensor or None, optional
        Prediction scores; nothing is written for them when None.
    """
    def _write(name_pattern, array):
        # name_pattern carries a single %04d slot filled with the timestamp.
        np.save(os.path.join(save_path, name_pattern % timestamp), array)

    if pred_tensor is not None:
        _write('%04d_pred.npy', torch_tensor_to_numpy(pred_tensor))

    if pred_score is not None:
        _write('%04d_pred_score.npy', torch_tensor_to_numpy(pred_score))

    # Both conversions happen before either file is written.
    gt_array = torch_tensor_to_numpy(gt_tensor)
    pcd_array = torch_tensor_to_numpy(pcd)

    _write('%04d_pcd.npy', pcd_array)
    _write('%04d_gt.npy', gt_array)

def save_prediction_gt_opencda(pred_tensor, gt_tensor, pcd, timestamp, save_path, transformation_matrix):
    """
    Save prediction and gt boxes as ``.npy`` files after projecting them.

    Each box tensor is converted to numpy, projected with
    ``transformation_matrix`` via ``project_box3d`` and converted with
    ``corner_to_center(order='hwl')`` before saving. Per the original
    author's note, the projection maps the boxes back to the Carla world
    coordinate.

    Parameters
    ----------
    pred_tensor : torch.Tensor or None
        Prediction bounding boxes; no prediction file is written when None.
    gt_tensor : torch.Tensor
        Ground truth bounding boxes.
    pcd : torch.Tensor
        Accepted for signature compatibility; not used by this function.
    timestamp : int
        Frame timestamp/index used as the file name prefix.
    save_path : str
        Directory to save files.
    transformation_matrix : object
        Passed unchanged as the second argument of ``project_box3d``.
    """
    def _project_to_centers(box_tensor):
        # tensor -> numpy -> projected boxes -> center representation
        boxes = torch_tensor_to_numpy(box_tensor)
        boxes = project_box3d(boxes, transformation_matrix)
        return corner_to_center(boxes, order='hwl')

    if pred_tensor is not None:
        pred_centers = _project_to_centers(pred_tensor)
        np.save(os.path.join(save_path, '%04d_pred.npy' % timestamp), pred_centers)

    gt_centers = _project_to_centers(gt_tensor)
    np.save(os.path.join(save_path, '%04d_gt.npy' % timestamp), gt_centers)