import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd

from opencood.models.sub_modules.pillar_vfe import PillarVFE
from opencood.models.sub_modules.point_pillar_scatter import PointPillarScatter
# from opencood.models.sub_modules.point_pillar_scatter_v2 import PointPillarScatterV2
from opencood.models.sub_modules.sparse_resnet_backbone import SparseResNet
from opencood.models.fuse_modules.fuse_utils import regroup
from opencood.models.sub_modules.naive_compress import NaiveCompressor
# from opencood.models.fuse_modules.v2xm2c_basic import V2XTransformer
from opencood.models.fuse_modules.parcon_basic import V2XTransformer

class SparseResnetConcatTransformer(nn.Module):
    """Detector built from a pillar VFE, a sparse ResNet and a transformer fusion.

    Pipeline, in the order the sub-modules are applied in ``forward``:
    ``PillarVFE`` -> ``PointPillarScatter`` -> ``SparseResNet`` ->
    optional ``NaiveCompressor`` -> ``regroup`` -> concatenation of the prior
    encoding -> ``V2XTransformer`` -> 1x1 convolution classification and
    regression heads.

    Keys read from ``args``:
        max_cav, pillar_vfe, voxel_size, lidar_range, point_pillar_scatter,
        sparse_resnet_backbone, transformer, anchor_number (all required);
        compression (optional, default 0: no compressor is built);
        backbone_fix (optional, default False: nothing is frozen).
    """

    def __init__(self, args):
        super().__init__()

        # Width handed to the backbone and the compressor, and used as the
        # input channel count of both heads.
        select_dim = 256

        self.max_cav = args['max_cav']
        self.pillar_vfe = PillarVFE(args['pillar_vfe'],
                                    num_point_features=4,
                                    voxel_size=args['voxel_size'],
                                    point_cloud_range=args['lidar_range'])
        self.scatter = PointPillarScatter(args['point_pillar_scatter'])
        self.backbone = SparseResNet(args['sparse_resnet_backbone'],
                                     64, select_dim)

        # A missing 'compression' entry means "no compression" instead of
        # raising KeyError.
        compression_ratio = args.get('compression', 0)
        self.compression = bool(compression_ratio > 0)
        if self.compression:
            self.naive_compressor = NaiveCompressor(select_dim,
                                                    compression_ratio)

        self.fusion_net = V2XTransformer(args['transformer'])

        self.cls_head = nn.Conv2d(select_dim, args['anchor_number'],
                                  kernel_size=1)
        self.reg_head = nn.Conv2d(select_dim, 7 * args['anchor_number'],
                                  kernel_size=1)

        # Optional switch: absent is treated the same as False.
        if args.get('backbone_fix', False):
            self.backbone_fix()

    def backbone_fix(self):
        """Freeze every sub-module except the fusion network.

        Sets ``requires_grad = False`` on the parameters of the pillar VFE,
        the scatter module, the backbone, the compressor (only when
        compression is enabled) and both detection heads, so that only
        ``fusion_net`` keeps trainable parameters. Intended for fine-tuning
        on time delay.
        """
        frozen_modules = [self.pillar_vfe, self.scatter, self.backbone,
                          self.cls_head, self.reg_head]
        # naive_compressor only exists as an attribute when compression is on.
        if self.compression:
            frozen_modules.append(self.naive_compressor)

        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False

    def forward(self, data_dict):
        """Run detection on one batch.

        Keys read from ``data_dict``:
            processed_lidar: mapping with 'voxel_features', 'voxel_coords'
                and 'voxel_num_points'.
            record_len: forwarded to the VFE/backbone pipeline and ``regroup``.
            spatial_correction_matrix: forwarded to the fusion network.
            prior_encoding: tensor; per the original shape notes it is
                (B, max_cav, 3) holding (dt, dv, infra) -- TODO confirm
                against the data loader.
            cav_list, scene_info: optional, ``None`` when absent; forwarded
                to ``regroup``.

        Returns:
            dict with 'psm' (output of the classification head) and 'rm'
            (output of the regression head).
        """
        lidar = data_dict['processed_lidar']
        record_len = data_dict['record_len']
        spatial_correction_matrix = data_dict['spatial_correction_matrix']

        cav_list = data_dict.get('cav_list')
        scene_info = data_dict.get('scene_info')

        # B, max_cav, 3(dt dv infra) -> B, max_cav, 3, 1, 1
        prior_encoding = \
            data_dict['prior_encoding'].unsqueeze(-1).unsqueeze(-1)

        batch_dict = {'voxel_features': lidar['voxel_features'],
                      'voxel_coords': lidar['voxel_coords'],
                      'voxel_num_points': lidar['voxel_num_points'],
                      'record_len': record_len}

        # n, 4 -> n, c
        batch_dict = self.pillar_vfe(batch_dict)
        # n, c -> N, C, H, W
        batch_dict = self.scatter(batch_dict)
        batch_dict = self.backbone(batch_dict)

        spatial_features_2d = batch_dict['spatial_features_2d']
        if self.compression:
            spatial_features_2d = self.naive_compressor(spatial_features_2d)

        # N, C, H, W -> B, L, C, H, W
        regroup_feature, mask = regroup(spatial_features_2d,
                                        record_len,
                                        self.max_cav,
                                        cav_list,
                                        scene_info)

        # Broadcast the prior encoding over the spatial grid. expand() creates
        # a view, so the only copy made is the one torch.cat performs anyway.
        prior_encoding = prior_encoding.expand(-1, -1, -1,
                                               regroup_feature.shape[3],
                                               regroup_feature.shape[4])
        regroup_feature = torch.cat([regroup_feature, prior_encoding], dim=2)

        # b l c h w -> b l h w c
        regroup_feature = regroup_feature.permute(0, 1, 3, 4, 2)
        # transformer fusion
        fused_feature = self.fusion_net(regroup_feature, mask,
                                        spatial_correction_matrix)
        # b h w c -> b c h w
        fused_feature = fused_feature.permute(0, 3, 1, 2)

        return {'psm': self.cls_head(fused_feature),
                'rm': self.reg_head(fused_feature)}
