# -*- coding: utf-8 -*-
# Author: Runsheng Xu <rxx3386@ucla.edu>, OpenPCDet
# License: TDG-Attribution-NonCommercial-NoDistrib


import torch
import torch.nn as nn
import numpy as np

from opencood.models.sub_modules.pillar_vfe import PillarVFE
from opencood.models.sub_modules.point_pillar_scatter import PointPillarScatter
from opencood.models.sub_modules.base_bev_backbone import BaseBEVBackbone


class PointPillarSeperate(nn.Module):
    """Two independent PointPillars detectors, one per side.

    The vehicle side (``v_*`` modules) and the infrastructure side
    (``i_*`` modules) each own a pillar VFE, a scatter module, a BEV
    backbone and 1x1 classification / regression heads. No weights or
    features are shared between the two branches.

    Parameters
    ----------
    args : dict
        Model configuration. Keys read here: ``pillar_vfe``,
        ``voxel_size``, ``v_lidar_range``, ``i_lidar_range``,
        ``point_pillar_scatter``, ``base_bev_backbone`` and the anchor
        count under ``anchor_number`` (or the legacy ``anchor_num``).
    """

    # Input channels handed to BaseBEVBackbone; assumed to equal the pillar
    # feature size produced by PillarVFE / PointPillarScatter -- confirm
    # against the ``pillar_vfe`` / ``point_pillar_scatter`` config.
    _BEV_IN_CHANNELS = 64
    # Channels of ``spatial_features_2d`` expected by the heads; presumably
    # three concatenated 128-channel backbone outputs -- confirm against the
    # ``base_bev_backbone`` config.
    _HEAD_IN_CHANNELS = 128 * 3
    # Regression values predicted per anchor.
    _BOX_CODE_SIZE = 7

    def __init__(self, args):
        super(PointPillarSeperate, self).__init__()

        # The original code read 'anchor_number' for the cls heads but
        # 'anchor_num' for the reg heads; resolve it once so both heads agree.
        anchor_number = self._anchor_count(args)

        # Pillar VFE: same config for both sides, but each side uses its own
        # point-cloud range.
        self.v_pillar_vfe = PillarVFE(args['pillar_vfe'],
                                      num_point_features=4,
                                      voxel_size=args['voxel_size'],
                                      point_cloud_range=args['v_lidar_range'])
        self.i_pillar_vfe = PillarVFE(args['pillar_vfe'],
                                      num_point_features=4,
                                      voxel_size=args['voxel_size'],
                                      point_cloud_range=args['i_lidar_range'])

        # The second positional flag is False for the vehicle side and True
        # for the infrastructure side; its meaning is defined by
        # PointPillarScatter -- NOTE(review): confirm there.
        self.v_scatter = PointPillarScatter(args['point_pillar_scatter'], False)
        self.i_scatter = PointPillarScatter(args['point_pillar_scatter'], True)

        self.v_backbone = BaseBEVBackbone(args['base_bev_backbone'],
                                          self._BEV_IN_CHANNELS)
        self.i_backbone = BaseBEVBackbone(args['base_bev_backbone'],
                                          self._BEV_IN_CHANNELS)

        # 1x1 heads: one score channel per anchor, and _BOX_CODE_SIZE
        # regression channels per anchor.
        self.v_cls_head = nn.Conv2d(self._HEAD_IN_CHANNELS, anchor_number,
                                    kernel_size=1)
        self.v_reg_head = nn.Conv2d(self._HEAD_IN_CHANNELS,
                                    self._BOX_CODE_SIZE * anchor_number,
                                    kernel_size=1)

        self.i_cls_head = nn.Conv2d(self._HEAD_IN_CHANNELS, anchor_number,
                                    kernel_size=1)
        self.i_reg_head = nn.Conv2d(self._HEAD_IN_CHANNELS,
                                    self._BOX_CODE_SIZE * anchor_number,
                                    kernel_size=1)

    @staticmethod
    def _anchor_count(args):
        """Return the anchor count, accepting either supported config key.

        ``anchor_number`` takes precedence; ``anchor_num`` is accepted as a
        legacy alias.

        Raises
        ------
        KeyError
            If neither key is present in ``args``.
        """
        if 'anchor_number' in args:
            return args['anchor_number']
        if 'anchor_num' in args:
            return args['anchor_num']
        raise KeyError("model args must define 'anchor_number' "
                       "(or legacy 'anchor_num')")

    @staticmethod
    def _run_branch(agent_dict, pillar_vfe, scatter, backbone,
                    cls_head, reg_head):
        """Run one side's PointPillars pipeline and return its head outputs.

        Parameters
        ----------
        agent_dict : dict
            One side's input; must contain ``processed_lidar`` with
            ``voxel_features``, ``voxel_coords`` and ``voxel_num_points``.
        pillar_vfe, scatter, backbone : callable
            Stages applied in order; each takes and returns a batch dict.
            The backbone's result must contain ``spatial_features_2d``.
        cls_head, reg_head : nn.Module
            Heads applied to ``spatial_features_2d``.

        Returns
        -------
        dict
            ``{'psm': cls_head output, 'rm': reg_head output}``.
        """
        processed_lidar = agent_dict['processed_lidar']
        # Build a fresh dict so keys added by the stages below do not end up
        # in ``processed_lidar`` itself.
        batch_dict = {'voxel_features': processed_lidar['voxel_features'],
                      'voxel_coords': processed_lidar['voxel_coords'],
                      'voxel_num_points': processed_lidar['voxel_num_points']}

        batch_dict = pillar_vfe(batch_dict)
        batch_dict = scatter(batch_dict)
        batch_dict = backbone(batch_dict)

        spatial_features_2d = batch_dict['spatial_features_2d']
        return {'psm': cls_head(spatial_features_2d),
                'rm': reg_head(spatial_features_2d)}

    def forward(self, data_dict):
        """Run the vehicle and/or infrastructure branch independently.

        Parameters
        ----------
        data_dict : dict
            May contain ``'vehicle'`` and/or ``'infra'`` entries, each holding
            a ``processed_lidar`` dict. A side that is missing, ``None`` or
            empty is skipped.

        Returns
        -------
        dict
            Maps ``'vehicle'`` / ``'infra'`` (only for the sides that were
            run) to ``{'psm': ..., 'rm': ...}``.
        """
        output_dict = {}

        vehicle_dict = data_dict.get('vehicle')
        if vehicle_dict:
            output_dict['vehicle'] = self._run_branch(
                vehicle_dict, self.v_pillar_vfe, self.v_scatter,
                self.v_backbone, self.v_cls_head, self.v_reg_head)

        infra_dict = data_dict.get('infra')
        if infra_dict:
            output_dict['infra'] = self._run_branch(
                infra_dict, self.i_pillar_vfe, self.i_scatter,
                self.i_backbone, self.i_cls_head, self.i_reg_head)

        return output_dict