# -*- coding: utf-8 -*-

import os
from collections import OrderedDict

import numpy as np
import torch
import cv2

from opencood.utils.common_utils import torch_tensor_to_numpy

# Colour assigned to each integer class id when a class-id map is rendered
# as an image (see camera_inference_visualization).
# NOTE(review): the triplets are written in RGB order, but cv2.imwrite treats
# the last axis as BGR, so saved images show red and blue swapped — confirm
# whether that is intended.
CLASS_TO_RGB = {
    0: [0, 0, 0],        # Background
    1: [255, 255, 255],  # Vehicle
    2: [238, 123, 94],   # Road
    3: [41, 132, 199],   # Lane
}

def inference_late_fusion(batch_data, model, dataset):
    """
    Model inference for late fusion.

    Every entry of ``batch_data`` is passed through ``model`` on its own.
    The per-CAV outputs are collected in the iteration order of
    ``batch_data`` and handed, together with the original batch, to
    ``dataset.post_process``.

    Parameters
    ----------
    batch_data : dict
        Mapping from CAV id to that CAV's model input.
    model : opencood.object
        Callable applied to each CAV's input.
    dataset : opencood.LateFusionDataset
        Object providing ``post_process(batch_data, output_dict)``.

    Returns
    -------
    pred_box_tensor : torch.Tensor
        The tensor of prediction bounding box after NMS.
    pred_score
        Prediction scores, as returned by ``dataset.post_process``.
    gt_box_tensor : torch.Tensor
        The tensor of gt bounding box.
    gt_label_tensor
        Ground-truth labels, as returned by ``dataset.post_process``.
    """
    per_cav_outputs = OrderedDict(
        (cav_id, model(cav_input)) for cav_id, cav_input in batch_data.items()
    )

    (pred_box_tensor,
     pred_score,
     gt_box_tensor,
     gt_label_tensor) = dataset.post_process(batch_data, per_cav_outputs)

    return pred_box_tensor, pred_score, gt_box_tensor, gt_label_tensor


def inference_nofusion(batch_data, model, dataset):
    """
    Model inference without fusion (ego-only).

    Only the ``'ego'`` entry of ``batch_data`` is run through ``model``.
    Every other CAV entry is ignored by the model but still forwarded to
    ``dataset.post_process`` as part of ``batch_data``. If ``batch_data`` has
    no ``'ego'`` key, ``post_process`` receives an empty output dict.

    Parameters
    ----------
    batch_data : dict
        Mapping from CAV id to that CAV's model input.
    model : opencood.object
        Callable applied to the ego input.
    dataset : opencood.LateFusionDataset
        Object providing ``post_process(batch_data, output_dict)``.

    Returns
    -------
    pred_box_tensor : torch.Tensor
        The tensor of prediction bounding box after NMS.
    pred_score
        Prediction scores, as returned by ``dataset.post_process``.
    gt_box_tensor : torch.Tensor
        The tensor of gt bounding box.
    gt_label_tensor
        Ground-truth labels, as returned by ``dataset.post_process``.
    """
    output_dict = OrderedDict()

    # Dict keys are unique, so a direct lookup replaces scanning every CAV.
    if "ego" in batch_data:
        output_dict["ego"] = model(batch_data["ego"])

    pred_box_tensor, pred_score, gt_box_tensor, gt_label_tensor = \
        dataset.post_process(batch_data,
                             output_dict)

    return pred_box_tensor, pred_score, gt_box_tensor, gt_label_tensor

def inference_early_fusion(batch_data, model, dataset):
    """
    Model inference for early fusion.

    Only ``batch_data['ego']`` is passed to ``model``; its output is stored
    under the ``'ego'`` key and post-processed by ``dataset.post_process``
    together with the full batch.

    Parameters
    ----------
    batch_data : dict
        Batch dictionary; must contain an ``'ego'`` key (KeyError otherwise).
    model : opencood.object
        Callable applied to the ego input.
    dataset : opencood.EarlyFusionDataset
        Object providing ``post_process(batch_data, output_dict)``.

    Returns
    -------
    pred_box_tensor : torch.Tensor
        The tensor of prediction bounding box after NMS.
    pred_score
        Prediction scores, as returned by ``dataset.post_process``.
    gt_box_tensor : torch.Tensor
        The tensor of gt bounding box.
    gt_label_tensor
        Ground-truth labels, as returned by ``dataset.post_process``.
    """
    ego_input = batch_data['ego']
    outputs = OrderedDict([('ego', model(ego_input))])

    post_processed = dataset.post_process(batch_data, outputs)
    pred_box_tensor, pred_score, gt_box_tensor, gt_label_tensor = post_processed

    return pred_box_tensor, pred_score, gt_box_tensor, gt_label_tensor


def inference_intermediate_fusion(batch_data, model, dataset):
    """
    Model inference for intermediate fusion.

    The inference path is identical to early fusion: only
    ``batch_data['ego']`` is passed to ``model``. Feature-level fusion
    presumably happens inside the model itself (verify against the model
    implementation). This function therefore delegates to
    ``inference_early_fusion``.

    Parameters
    ----------
    batch_data : dict
        Batch dictionary; must contain an ``'ego'`` key.
    model : opencood.object
        Callable applied to the ego input.
    dataset : opencood dataset
        Object providing ``post_process(batch_data, output_dict)``
        (presumably an intermediate-fusion dataset — confirm with callers).

    Returns
    -------
    pred_box_tensor : torch.Tensor
        The tensor of prediction bounding box after NMS.
    pred_score
        Prediction scores, as returned by ``dataset.post_process``.
    gt_box_tensor : torch.Tensor
        The tensor of gt bounding box.
    gt_label_tensor
        Ground-truth labels, as returned by ``dataset.post_process``.
    """
    return inference_early_fusion(batch_data, model, dataset)


def save_prediction_gt(pred_tensor, gt_tensor, pcd, timestamp, save_path):
    """
    Save prediction, ground-truth and point-cloud tensors as ``.npy`` files.

    Three files are written into ``save_path``: ``<ts>_pcd.npy``,
    ``<ts>_pred.npy`` and ``<ts>_gt.npy``. ``<ts>`` is ``timestamp``
    formatted with ``%04d``.

    Parameters
    ----------
    pred_tensor : torch.Tensor
        Prediction tensor, converted with ``torch_tensor_to_numpy``.
    gt_tensor : torch.Tensor
        Ground-truth tensor, converted with ``torch_tensor_to_numpy``.
    pcd : torch.Tensor
        Point-cloud tensor, converted with ``torch_tensor_to_numpy``.
    timestamp : int
        Frame index used as the file-name prefix.
    save_path : str
        Output directory. It must already exist; ``np.save`` does not
        create it.
    """
    # Convert everything first so a failed conversion leaves no partial output.
    pred_np = torch_tensor_to_numpy(pred_tensor)
    gt_np = torch_tensor_to_numpy(gt_tensor)
    pcd_np = torch_tensor_to_numpy(pcd)

    np.save(os.path.join(save_path, '%04d_pcd.npy' % timestamp), pcd_np)
    np.save(os.path.join(save_path, '%04d_pred.npy' % timestamp), pred_np)
    # Use the same '.npy' suffix as the other files: np.save appends '.npy'
    # to any other suffix, so '_gt.npy_test' produced '_gt.npy_test.npy'.
    np.save(os.path.join(save_path, '%04d_gt.npy' % timestamp), gt_np)

def _colorize_label_map(label_map):
    """
    Convert a map of integer class ids into a float64 colour image.

    Each pixel whose id is a key of CLASS_TO_RGB receives that colour; any
    other id stays (0, 0, 0). The result has shape
    ``(label_map.shape[0], label_map.shape[1], 3)``.
    """
    # assumes label_map is a 2-D numpy array of class ids — TODO confirm
    color_map = np.zeros((label_map.shape[0], label_map.shape[1], 3))
    for class_id, color in CLASS_TO_RGB.items():
        color_map[label_map == class_id] = color
    return color_map


def camera_inference_visualization(pred_map,
                                   gt_map,
                                   output_dir,
                                   epoch,
                                   folder_name,
                                   with_true_map=True):
    """
    Render a predicted (and optionally ground-truth) class map to a PNG.

    Each map is colourised with CLASS_TO_RGB, resized to 512x512 with
    cv2.resize's default (bilinear) interpolation, and written to
    ``<output_dir>/<folder_name>/<epoch:04d>.png``. When ``with_true_map``
    is True, the gt map is drawn on the left and the prediction on the
    right, separated by a 50-pixel black gap. Otherwise only the prediction
    is drawn and ``gt_map`` is not used.

    NOTE(review): cv2.imwrite interprets the channel axis as BGR, so the
    RGB triplets in CLASS_TO_RGB are saved with red and blue swapped —
    confirm whether that is intended.

    Parameters
    ----------
    pred_map : np.ndarray
        Predicted class-id map (assumed 2-D).
    gt_map : np.ndarray
        Ground-truth class-id map (assumed 2-D); only read when
        ``with_true_map`` is True.
    output_dir : str
        Root output directory.
    epoch : int
        Used as the zero-padded file name.
    folder_name : str
        Sub-directory of ``output_dir``; created if missing.
    with_true_map : bool
        Whether to place the gt map next to the prediction.
    """
    image_width = 512
    image_height = 512
    offset = 50  # horizontal gap (pixels) between the gt and pred panels

    output_folder = os.path.join(output_dir, folder_name)
    os.makedirs(output_folder, exist_ok=True)

    pred_map_rgb = cv2.resize(_colorize_label_map(pred_map),
                              (image_width, image_height))

    if with_true_map:
        visualize_summary = np.zeros((image_height,
                                      image_width * 2 + offset,
                                      3),
                                     dtype=np.uint8)
        gt_map_rgb = cv2.resize(_colorize_label_map(gt_map),
                                (image_width, image_height))

        visualize_summary[:, :image_width, :] = gt_map_rgb
        visualize_summary[:, image_width + offset:, :] = pred_map_rgb
    else:
        visualize_summary = np.zeros((image_height, image_width, 3),
                                     dtype=np.uint8)
        visualize_summary[:, :image_width, :] = pred_map_rgb

    # Format only the file name: applying '%' to the joined path would
    # misinterpret any '%' in output_dir or folder_name.
    cv2.imwrite(os.path.join(output_folder, '%04d.png' % epoch),
                visualize_summary)