# -*- coding: utf-8 -*-

# License: TDG-Attribution-NonCommercial-NoDistrib

import numpy as np
import torch
from opencood.utils import pcd_utils


class BasePreprocessor(object):
    """
    Basic Lidar pre-processor.

    Parameters
    ----------
    preprocess_params : dict
        The dictionary containing all parameters of the preprocessing.

    train : bool
        Train or test mode.
    """

    def __init__(self, preprocess_params, train):
        self.params = preprocess_params
        self.train = train

    def preprocess(self, pcd_np):
        """
        Preprocess the lidar points by simple sampling.

        The number of points to keep is read from
        ``self.params['args']['sample_num']``.

        Parameters
        ----------
        pcd_np : np.ndarray
            The raw lidar.

        Returns
        -------
        data_dict : dict
            Dictionary with a single key ``'downsample_lidar'`` holding the
            result of ``pcd_utils.downsample_lidar``.
        """
        sample_num = self.params['args']['sample_num']
        sampled = pcd_utils.downsample_lidar(pcd_np, sample_num)
        return {'downsample_lidar': sampled}

    def project_points_to_bev_map(self, points, ratio=0.1):
        """
        Project points to BEV occupancy map with default ratio=0.1.

        Only the first two coordinates of each point are used. Points whose
        cell falls outside the range given by
        ``self.params["cav_lidar_range"]`` are ignored.

        Parameters
        ----------
        points : np.ndarray
            (N, 3) / (N, 4)

        ratio : float
            Discretization parameters. Default is 0.1.

        Returns
        -------
        bev_map : np.ndarray
            BEV occupancy map including projected points with shape
            (img_row, img_col). Occupied cells are 1, all others 0.

        """
        # Range layout is (min0, min1, min2, max0, max1, max2); the third
        # axis is not needed for a 2D occupancy map.
        L1, W1, _, L2, W2, _ = self.params["cav_lidar_range"]
        img_row = int((L2 - L1) / ratio)
        img_col = int((W2 - W1) / ratio)
        bev_map = np.zeros((img_row, img_col))
        bev_origin = np.array([L1, W1]).reshape(1, -1)

        # (N, 2) cell indices. Floor (rather than truncating toward zero)
        # so coordinates just below the lower bound become -1 instead of
        # being folded into cell 0.
        indices = np.floor((points[:, :2] - bev_origin) / ratio).astype(int)
        rows, cols = indices[:, 0], indices[:, 1]

        # Row 0 and column 0 are valid cells, so the lower bound is
        # inclusive.
        in_rows = np.logical_and(rows >= 0, rows < img_row)
        in_cols = np.logical_and(cols >= 0, cols < img_col)
        mask = np.logical_and(in_rows, in_cols)

        bev_map[rows[mask], cols[mask]] = 1
        return bev_map

class IdentityPreprocessor(BasePreprocessor):
    """
    Pre-processor that passes the raw points through without sampling.

    Parameters
    ----------
    preprocess_params : dict
        The dictionary containing all parameters of the preprocessing.

    train : bool
        Train or test mode.
    """

    def __init__(self, preprocess_params, train):
        super().__init__(preprocess_params, train)

    def preprocess(self, pcd_np):
        """
        Wrap the raw lidar points in a tensor.

        Parameters
        ----------
        pcd_np : np.ndarray
            The raw lidar.

        Returns
        -------
        data_dict : dict
            Dictionary with a single key ``'points'`` holding
            ``torch.from_numpy(pcd_np)``.
        """
        return {'points': torch.from_numpy(pcd_np)}

    def collate_batch(self, batch):
        """
        Collate pre-processed samples, dispatching on the batch container.

        Parameters
        ----------
        batch : list or dict
            A list of per-sample dictionaries, or an already collated
            dictionary.

        Returns
        -------
        dict
            The collated batch.

        Raises
        ------
        TypeError
            If ``batch`` is neither a list nor a dict.
        """
        if isinstance(batch, list):
            return self.collate_batch_list(batch)
        if isinstance(batch, dict):
            return self.collate_batch_dict(batch)
        # The previous ``raise NotImplemented`` already surfaced as a
        # TypeError (NotImplemented is not an exception); raise it
        # explicitly so the message names the real problem.
        raise TypeError(
            "batch must be a list or a dict, got "
            + type(batch).__name__
        )

    def collate_batch_list(self, batch):
        """
        Gather the ``'points'`` entry of every sample into one list.

        Parameters
        ----------
        batch : list of dict
            Each element must contain a ``'points'`` key.

        Returns
        -------
        dict
            ``{'points': [...]}`` with the entries in the same order as
            ``batch``.
        """
        return {'points': [sample['points'] for sample in batch]}

    def collate_batch_dict(self, batch):
        """
        Return an already collated batch unchanged.

        Parameters
        ----------
        batch : dict
            The collated batch.

        Returns
        -------
        dict
            The same object that was passed in.
        """
        return batch