# -*- coding: utf-8 -*-

"""
Convert lidar point clouds to bird's-eye-view (BEV) grids and preprocess camera images.
"""

import numpy as np
import torch
from opencood.data_utils.pre_processor.base_preprocessor import \
    BasePreprocessor
import cv2
from torchvision import transforms

class BevPreprocessor(BasePreprocessor):
    """
    Rasterize lidar point clouds into a BEV grid and prepare camera images.

    The grid geometry (origin ``L1``/``W1``/``H1``, cell size ``res`` and
    ``input_shape``) is read from ``preprocess_params["geometry_param"]``.
    The last channel of the grid holds per-cell mean intensity; the other
    channels are binary occupancy slices along the height axis.
    """

    def __init__(self, preprocess_params, train):
        super(BevPreprocessor, self).__init__(preprocess_params, train)
        self.lidar_range = self.params['cav_lidar_range']
        self.geometry_param = preprocess_params["geometry_param"]
        # Used only on the 'npy' path of preprocess_rgb, where images were
        # already channel-swapped and resized offline.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=self.params['args']['mean'],
                                 std=self.params['args']['std'])
        ])

    def preprocess_lidar(self, pcd_raw):
        """
        Preprocess the lidar points to BEV representations.

        Each point marks its (x, y, z) voxel as occupied (value 1). The last
        channel stores, for every (x, y) cell, the mean of column 3 of the
        points falling into that cell. Points whose voxel lies outside the
        grid are discarded. This includes height indices that would land on
        the reserved intensity channel.

        Parameters
        ----------
        pcd_raw : np.ndarray
            The raw lidar, one point per row. Columns 0-2 are used as
            coordinates and column 3 as intensity.

        Returns
        -------
        data_dict : dict
            {"bev_input": grid}, where grid is ``input_shape`` transposed
            to channel-first order (axes 2, 0, 1).
        """
        bev = np.zeros(self.geometry_param['input_shape'], dtype=np.float32)
        intensity_map_count = np.zeros((bev.shape[0], bev.shape[1]),
                                       dtype=int)
        bev_origin = np.array(
            [self.geometry_param["L1"], self.geometry_param["W1"],
             self.geometry_param["H1"]]).reshape(1, -1)

        # floor (not truncation toward zero) so points just below the origin
        # get a negative index and are dropped instead of landing in cell 0.
        indices = np.floor((pcd_raw[:, :3] - bev_origin) /
                           self.geometry_param["res"]).astype(int)

        # Keep only voxels inside the grid. Otherwise large indices raise
        # IndexError and negative ones silently wrap to the opposite border.
        # The last channel is reserved for intensity, so valid height indices
        # stop one short of it.
        valid = ((indices >= 0).all(axis=1) &
                 (indices[:, 0] < bev.shape[0]) &
                 (indices[:, 1] < bev.shape[1]) &
                 (indices[:, 2] < bev.shape[2] - 1))
        indices = indices[valid]
        intensity = pcd_raw[valid, 3]
        x_idx, y_idx, z_idx = indices[:, 0], indices[:, 1], indices[:, 2]

        bev[x_idx, y_idx, z_idx] = 1
        # np.add.at accumulates over repeated (x, y) cells, which plain
        # fancy-index ``+=`` would not.
        np.add.at(bev, (x_idx, y_idx, -1), intensity)
        np.add.at(intensity_map_count, (x_idx, y_idx), 1)

        # Turn the per-cell intensity sum into a mean; empty cells stay 0.
        divide_mask = intensity_map_count != 0
        bev[divide_mask, -1] = np.divide(bev[divide_mask, -1],
                                         intensity_map_count[divide_mask])

        data_dict = {
            "bev_input": np.transpose(bev, (2, 0, 1))
        }
        return data_dict

    def preprocess_rgb(self, rgb_image, img_format):
        """
        Prepare a camera image for the network.

        Parameters
        ----------
        rgb_image : np.ndarray
            The image. Presumably H x W x C, as ``transforms.ToTensor`` and
            the cv2 helpers expect; confirm against the loader.
        img_format : str
            'npy' for images that were already channel-swapped and resized
            offline. Any other value runs swap, resize, scale to [0, 1] and
            mean/std standardization here.

        Returns
        -------
        torch.Tensor or np.ndarray
            A tensor on the 'npy' path and a numpy array otherwise.
            Channel-first tensors are permuted to [H, W, C].
        """
        if img_format == 'npy':  # Only npy file has been channel swapped and resized
            rgb_image = self.transform(rgb_image)
        else:
            rgb_image = self.channel_swap(rgb_image)
            rgb_image = self.resize_image(rgb_image)
            rgb_image = self.normalize(rgb_image)
            rgb_image = self.standalize(rgb_image)

        # Only the torch path yields a channel-first result; numpy arrays
        # have no ``permute`` and are left untouched.
        if torch.is_tensor(rgb_image) and rgb_image.shape[0] == 3:
            rgb_image = rgb_image.permute(1, 2, 0)  # Make sure the dimension follows with [H, W, C] order
        return rgb_image

    def standalize(self, rgb_image):
        """
        Standardize an image with the configured per-channel statistics.

        Returns ``(rgb_image - mean) / std``, where mean and std come from
        ``params['args']`` and broadcast against the image's last axis.
        """
        mean = np.array(self.params['args']['mean'])
        std = np.array(self.params['args']['std'])

        return (rgb_image - mean) / std

    def normalize(self, rgb_image):
        """
        Return the image as a float array divided by 255.

        This maps 8-bit pixel values to [0, 1].
        """
        return np.array(rgb_image, dtype=float) / 255.

    def channel_swap(self, rgb_image):
        """
        Convert BGR to RGB if ``params['args']['bgr2rgb']`` is set.

        Otherwise the image is returned unchanged.
        """
        if self.params['args']['bgr2rgb']:
            return cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)
        return rgb_image

    def resize_image(self, rgb_image):
        """
        Resize image to the configured resolution.

        cv2.resize takes ``dsize`` as (width, height), so ``resize_x`` is the
        output width and ``resize_y`` the output height.
        """
        resize_x = self.params['args']['resize_x']
        resize_y = self.params['args']['resize_y']

        return cv2.resize(rgb_image, (resize_x, resize_y))

    @staticmethod
    def collate_batch_list(batch):
        """
        Customized pytorch data loader collate function.

        Parameters
        ----------
        batch : list
            List of dictionary. Each dictionary represents a single frame
            and must contain a "bev_input" array.

        Returns
        -------
        processed_batch : dict
            {"bev_input": tensor}, with the frames stacked along a new
            leading batch axis.
        """
        return {
            "bev_input": torch.from_numpy(
                np.stack([x["bev_input"] for x in batch], axis=0))
        }

    @staticmethod
    def collate_batch_dict(batch):
        """
        Customized pytorch data loader collate function.

        Parameters
        ----------
        batch : dict
            Dict of list. ``batch["bev_input"]`` holds one array per CAV.

        Returns
        -------
        processed_batch : dict
            {"bev_input": tensor}, with the CAV arrays stacked along a new
            leading batch axis.
        """
        return {
            "bev_input": torch.from_numpy(
                np.stack(list(batch["bev_input"]), axis=0))
        }

    def collate_batch(self, batch):
        """
        Customized pytorch data loader collate function.

        Parameters
        ----------
        batch : list / dict
            Batched data.

        Returns
        -------
        processed_batch : dict
            Updated lidar batch.

        Raises
        ------
        TypeError
            If ``batch`` is neither a list nor a dict.
        """
        if isinstance(batch, list):
            return self.collate_batch_list(batch)
        if isinstance(batch, dict):
            return self.collate_batch_dict(batch)
        # The previous ``raise NotImplemented`` raised an opaque TypeError
        # (NotImplemented is not an exception); keep the type, add a message.
        raise TypeError("collate_batch expects a list or dict, got %s"
                        % type(batch).__name__)
