# -*- coding: utf-8 -*-
# Author: Yifan Lu <yifan_lu@sjtu.edu.cn>
# License: TDG-Attribution-NonCommercial-NoDistrib

import torch.nn as nn

from opencood.models.sub_modules.mean_vfe import MeanVFE
from opencood.models.sub_modules.sparse_backbone_3d import VoxelBackBone8x
from opencood.models.sub_modules.height_compression import HeightCompression
from opencood.models.sub_modules.base_bev_backbone import BaseBEVBackbone
from opencood.models.sub_modules.cia_ssd_utils import SSFA
from opencood.models.sub_modules.downsample_conv import DownsampleConv
import numpy as np
from opencood.utils.model_utils import weight_init

class SecondSSFAUncertainty(nn.Module):
    """Detector that chains MeanVFE -> VoxelBackBone8x -> HeightCompression
    -> SSFA (-> optional DownsampleConv) and predicts, with 1x1 convolution
    heads, classification, box regression, regression uncertainty and
    (optionally) direction bins.

    Parameters
    ----------
    args : dict
        Model configuration. Keys read by this class:

        * ``lidar_range`` -- six values; entries ``[3:6] - [:3]`` divided by
          ``voxel_size`` give the voxel grid size (presumably
          ``[x_min, y_min, z_min, x_max, y_max, z_max]`` -- TODO confirm).
        * ``voxel_size`` -- three values, one per axis.
        * ``mean_vfe`` -- config for ``MeanVFE``; must hold
          ``num_point_features``.
        * ``spconv`` -- config for ``VoxelBackBone8x``; must hold
          ``num_features_in``.
        * ``map2bev`` -- config for ``HeightCompression``.
        * ``ssfa`` -- config for ``SSFA``; ``feature_num`` is the channel
          count fed to the heads unless ``shrink_header`` is given.
        * ``uncertainty_dim`` -- uncertainty channels per anchor.
        * ``anchor_num`` -- number of anchors per location.
        * ``shrink_header`` (optional) -- config for ``DownsampleConv``;
          its ``dim[-1]`` becomes the channel count fed to the heads.
        * ``dir_args`` (optional) -- enables the direction head;
          ``num_bins`` direction channels per anchor.
    """

    def __init__(self, args: dict):
        super().__init__()

        # Number of voxels along each axis of the lidar range.
        lidar_range = np.array(args['lidar_range'])
        voxel_size = np.array(args['voxel_size'])
        grid_size = np.round(
            (lidar_range[3:6] - lidar_range[:3]) / voxel_size
        ).astype(np.int64)

        self.vfe = MeanVFE(args['mean_vfe'],
                           args['mean_vfe']['num_point_features'])
        self.spconv_block = VoxelBackBone8x(
            args['spconv'],
            input_channels=args['spconv']['num_features_in'],
            grid_size=grid_size)
        self.map_to_bev = HeightCompression(args['map2bev'])
        self.ssfa = SSFA(args['ssfa'])
        self.out_channel = args['ssfa']['feature_num']

        uncertainty_dim = args['uncertainty_dim']
        anchor_num = args['anchor_num']

        # Optional shrink convolution; when present it defines the channel
        # count seen by the prediction heads.
        self.shrink_flag = 'shrink_header' in args
        if self.shrink_flag:
            self.shrink_conv = DownsampleConv(args['shrink_header'])
            self.out_channel = args['shrink_header']['dim'][-1]

        self.cls_head = nn.Conv2d(self.out_channel, anchor_num,
                                  kernel_size=1)
        # 7 regression values per anchor.
        self.reg_head = nn.Conv2d(self.out_channel, 7 * anchor_num,
                                  kernel_size=1)
        self.unc_head = nn.Conv2d(self.out_channel,
                                  uncertainty_dim * anchor_num,
                                  kernel_size=1)

        # Optional direction head.
        self.use_dir = 'dir_args' in args
        if self.use_dir:
            self.dir_head = nn.Conv2d(
                self.out_channel,
                args['dir_args']['num_bins'] * anchor_num,
                kernel_size=1)  # BIN_NUM = 2

        # Apply the project's weight initialisation to every sub-module.
        self.apply(weight_init)

    def forward(self, data_dict: dict) -> dict:
        """Run the detector on one batch.

        Parameters
        ----------
        data_dict : dict
            Must contain ``data_dict['processed_lidar']`` with the entries
            ``voxel_features``, ``voxel_coords`` and ``voxel_num_points``.
            Column 0 of ``voxel_coords`` is read as the sample index within
            the batch.

        Returns
        -------
        dict
            ``cls_preds``, ``reg_preds`` and ``unc_preds`` (raw outputs of
            the respective 1x1 conv heads), plus ``dir_preds`` when the
            direction head is enabled.
        """
        lidar = data_dict['processed_lidar']
        voxel_coords = lidar['voxel_coords']

        # The batch index is padded into the first column of the coords, so
        # the batch size is the largest index + 1. Converted to a plain
        # Python int so the sub-modules receive an ordinary integer rather
        # than a 0-dim tensor living on the input's device.
        # NOTE(review): the sparse backbone presumably forwards this value
        # to spconv, which expects an int -- confirm.
        batch_size = int(voxel_coords[:, 0].max()) + 1

        batch_dict = {'voxel_features': lidar['voxel_features'],
                      'voxel_coords': voxel_coords,
                      'voxel_num_points': lidar['voxel_num_points'],
                      'batch_size': batch_size}

        batch_dict = self.vfe(batch_dict)
        batch_dict = self.spconv_block(batch_dict)
        batch_dict = self.map_to_bev(batch_dict)

        out = self.ssfa(batch_dict['spatial_features'])
        if self.shrink_flag:
            out = self.shrink_conv(out)

        output_dict = {
            'cls_preds': self.cls_head(out),
            'reg_preds': self.reg_head(out),
            # s is log(b) or log(sigma^2)
            'unc_preds': self.unc_head(out),
        }

        if self.use_dir:
            output_dict['dir_preds'] = self.dir_head(out)

        return output_dict
