import torch
import torch.nn as nn

from opencood.models.sub_modules.mean_vfe import MeanVFE
from opencood.models.sub_modules.sparse_backbone_3d import VoxelBackBone8x, VoxelDecoder
from opencood.models.fuse_modules.fuse_utils import regroup
from opencood.models.sub_modules.downsample_conv import DownsampleConv
from opencood.models.sub_modules.naive_compress import NaiveCompressor
from opencood.models.sub_modules.sparse_layers import Voxel2PointScatterNeck, VoteSegHead
from opencood.utils.sparse_utils import voxel2point, ClusterAssigner, scatter_v2
from opencood.models.sub_modules.sst_ops import get_inner_win_inds
from mmdet3d.models.backbones.sir import SIR


class SecondSparse(nn.Module):
    def __init__(self, args):
        """Build every sub-module of the sparse SECOND model from ``args``.

        Args:
            args: Configuration mapping. Keys read here: ``max_cav``,
                ``batch_size``, ``pre_voxel_size``, ``lidar_range``,
                ``num_classes``, ``mean_vfe``, ``backbone_3d``, ``grid_size``,
                ``decoder_3d``, ``seg_voxel_size``, ``vote_head``,
                ``cluster``, ``sir``, ``compression``, ``anchor_number``,
                ``backbone_fix`` and, optionally, ``shrink_header``.
        """
        super().__init__()

        # Plain configuration values reused by the later stages.
        self.max_cav = args['max_cav']
        self.batch_size = args['batch_size']
        self.pre_voxel_size = args['pre_voxel_size']
        self.lidar_range = args['lidar_range']
        self.num_classes = args['num_classes']

        # Voxel pipeline: mean VFE -> sparse conv3d encoder -> sparse decoder.
        self.mean_vfe = MeanVFE(args['mean_vfe'], 4)
        self.backbone_3d = VoxelBackBone8x(args['backbone_3d'], 4,
                                           args['grid_size'])
        self.decoder_3d = VoxelDecoder(args['decoder_3d']['decoder_channels'])

        # Point-wise predictor: scatter voxel features to points, then
        # segment / vote per point.
        self.voxel2point = Voxel2PointScatterNeck(args['lidar_range'],
                                                  args['seg_voxel_size'])
        vote_cfg = args['vote_head']
        self.vote_head = VoteSegHead(vote_cfg['in_channel'],
                                     vote_cfg['num_classes'],
                                     vote_cfg['hidden_dim'],
                                     vote_cfg['dropout_ratio'])

        # Foreground point clustering and its selection thresholds.
        cluster_cfg = args['cluster']
        self.cluster_assigner = ClusterAssigner(
            cluster_cfg['cluster_voxel_size'],
            cluster_cfg['min_points'],
            cluster_cfg['point_cloud_range'],
            cluster_cfg['connected_dist'])
        self.score_thresh = cluster_cfg['score_thresh']
        self.fg_topk = cluster_cfg['fg_topk']

        # SIR (Sparse Instance Recognition) feature extractor.
        self.sir_module = SIR(args['sir'])

        # Optional shrink header.
        self.shrink_flag = 'shrink_header' in args
        if self.shrink_flag:
            self.shrink_conv = DownsampleConv(args['shrink_header'])

        # Optional feature compression, enabled by a positive ratio.
        self.compression = bool(args['compression'] > 0)
        if self.compression:
            self.naive_compressor = NaiveCompressor(256, args['compression'])

        # 1x1 convolution detection heads.
        head_in_channels = 128 * 2
        self.cls_head = nn.Conv2d(head_in_channels, args['anchor_number'],
                                  kernel_size=1)
        self.reg_head = nn.Conv2d(head_in_channels,
                                  7 * args['anchor_number'],
                                  kernel_size=1)

        if args['backbone_fix']:
            self.backbone_fix()

    def backbone_fix(self):
        """Freeze the backbone parameters for fine-tuning on time delay.

        Sets ``requires_grad = False`` on every parameter of the voxel
        feature encoder (``mean_vfe``), the sparse 3D encoder
        (``backbone_3d``), the sparse 3D decoder (``decoder_3d``), the
        optional compressor and shrink header (when enabled), and both
        detection heads.

        The previous implementation referenced ``pillar_vfe``, ``scatter``
        and ``backbone``, which this class never defines, so calling it
        raised ``AttributeError``.

        NOTE(review): ``voxel2point``, ``vote_head``, ``cluster_assigner``
        and ``sir_module`` are left trainable here -- confirm whether they
        should be frozen as well.
        """
        frozen_modules = [self.mean_vfe, self.backbone_3d, self.decoder_3d]

        # Optional modules only exist when their flag was set in __init__.
        if self.compression:
            frozen_modules.append(self.naive_compressor)
        if self.shrink_flag:
            frozen_modules.append(self.shrink_conv)

        frozen_modules.extend([self.cls_head, self.reg_head])

        for module in frozen_modules:
            for p in module.parameters():
                p.requires_grad = False

    def forward(self, data_dict):
        """Run the sparse voxel -> point -> cluster pipeline on one batch.

        Reads ``processed_lidar`` (``voxel_features``, ``voxel_coords``,
        ``voxel_num_points``), ``record_len``, ``spatial_correction_matrix``
        and ``prior_encoding`` from ``data_dict``, then chains: mean VFE,
        sparse 3D encoder, sparse decoder, voxel-to-point scatter, vote/seg
        head, re-voxelization (``pre_voxelize``), foreground selection
        (``group_sample``) and clustering.

        NOTE(review): this method looks unfinished. It returns ``None``;
        ``spatial_correction_matrix``, ``prior_encoding``,
        ``pts_cluster_inds`` and ``valid_mask`` are computed but never used,
        and ``sir_module``, ``cls_head`` and ``reg_head`` are never invoked.

        Shape annotations in the comments below are the original author's
        notes and are not verifiable from this file alone.
        """
        voxel_features = data_dict['processed_lidar']['voxel_features']
        voxel_coords = data_dict['processed_lidar']['voxel_coords']
        voxel_num_points = data_dict['processed_lidar']['voxel_num_points']
        record_len = data_dict['record_len']
        # NOTE(review): read but not used anywhere below.
        spatial_correction_matrix = data_dict['spatial_correction_matrix']

        # B, max_cav, 3(dt dv infra), 1, 1
        # NOTE(review): computed but not used anywhere below.
        prior_encoding =\
            data_dict['prior_encoding'].unsqueeze(-1).unsqueeze(-1)
        
        # 'batch_size' is the sum of record_len, converted to a 0-d NumPy
        # array (presumably the total number of agents -- TODO confirm).
        batch_dict = {'voxel_features': voxel_features,     # [N, max_points_per_voxel, 4]
                      'voxel_coords': voxel_coords,         # [N, 4] (B, Z, Y, X)
                      'voxel_num_points': voxel_num_points, # [N](1~max_points_per_voxel)
                      'batch_size': torch.sum(record_len).cpu().numpy(),
                      'record_len': record_len}             # [B] point out the scene of batch
        
        # n, p, 4 -> n, 4 by mean reduce
        batch_dict = self.mean_vfe(batch_dict)      # voxel_features [N, 4]
        # n, 4 -> [(n_1, 16), (n_2, 32), (n_3, 64), (n_4, 64)]
        batch_dict = self.backbone_3d(batch_dict)   # multi-scale voxel features

        # multi-scale -> (n, 16)
        batch_dict = self.decoder_3d(batch_dict)    # upsampled feature
        
        # Assign voxels to points. This uses the raw ``voxel_features`` read
        # from data_dict above, not whatever mean_vfe stored in batch_dict.
        points, voxel2point_inds = voxel2point(voxel_features, voxel_num_points)
        batch_dict.update({
            'points': points,                                           # [N_p, 4]
            'point_coords': voxel_coords[voxel2point_inds, :],          # [N_p, 4]
            'voxel2point_inds': voxel2point_inds                        # [N_p]
        })
        
        # assign voxel feature to point inside voxel
        # modify point_feats, add point_mask. 
        # Note: the mask defaults to all True since no voxel is dropped
        batch_dict = self.voxel2point(batch_dict)
        
        # predict point-wise class and vote for cluster center
        # add class_logits, add vote_preds, add vote_offsets
        batch_dict = self.vote_head(batch_dict)
        
        # re-project points into voxels
        # modify value shape of {'points': [N_v, 4], 'point_feats': [N_v, 19], 
        # 'class_logits': [N_v, 1], 'vote_preds': [N_v, 3], 'batch_idx': [N_v]}
        batch_dict = self.pre_voxelize(batch_dict)
        
        # get foreground mask and center
        # add fg_mask, add fg_center_preds
        batch_dict = self.group_sample(batch_dict)
        # cluster the predicted foreground centers, using the batch index of
        # each foreground entry
        pts_cluster_inds, valid_mask = self.cluster_assigner(batch_dict['fg_center_preds'],
                                        batch_dict['batch_idx'][batch_dict['fg_mask']])
        
        

        # NOTE(review): nothing is returned; the clustering result above is
        # discarded.
        return

    def pre_voxelize(self, data_dict):
        """Re-project points into voxels, averaging float data per voxel.

        Each point is assigned to a cell of size ``self.pre_voxel_size``
        measured from the lower corner of ``self.lidar_range``; the batch
        index (column 0 of ``point_coords``) is prepended so that voxels of
        different samples never merge.

        Args:
            data_dict: Mapping that contains at least ``points`` (xyz in the
                first three columns) and ``point_coords`` (batch index in
                column 0).

        Returns:
            A new dict with the same keys. Entries with dtype
            ``torch.float`` / ``torch.float16`` are reduced per voxel by
            ``scatter_v2(..., mode='avg')``; every other entry is passed
            through unchanged. ``batch_idx`` holds the batch index of each
            voxel.
        """
        batch_idx = data_dict['point_coords'][:, 0]
        points = data_dict['points']

        voxel_size = torch.tensor(self.pre_voxel_size, device=batch_idx.device)
        pc_range = torch.tensor(self.lidar_range, device=points.device)
        coors = torch.div(points[:, :3] - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').long()
        coors = coors[:, [2, 1, 0]]  # to zyx order
        coors = torch.cat([batch_idx[:, None], coors], dim=1)

        new_coors, unq_inv = torch.unique(coors, return_inverse=True, return_counts=False, dim=0)

        # Seed with the unique coordinates so that ``batch_idx`` can still be
        # produced when no float tensor is present. Previously the name was
        # bound only inside the loop, which raised UnboundLocalError.
        voxel_coors = new_coors

        voxelized_data_dict = {}
        for data_name, data in data_dict.items():
            if hasattr(data, 'dtype') and data.dtype in (torch.float, torch.float16):
                voxelized_data, voxel_coors = scatter_v2(data, coors, mode='avg', return_inv=False, new_coors=new_coors, unq_inv=unq_inv)
                voxelized_data_dict[data_name] = voxelized_data
            else:
                # Non-float entries (indices, masks, non-tensors) are kept as is.
                voxelized_data_dict[data_name] = data

        voxelized_data_dict['batch_idx'] = voxel_coors[:, 0]
        return voxelized_data_dict
    
    def group_sample(self, batch_data):
        """Select foreground entries and predict their voted centers.

        The foreground mask comes from the class-0 score (``get_fg_mask``);
        if some sample ends up with no foreground entry, one position per
        sample is forced to foreground. Each foreground center is the xyz
        position plus its class-0 vote offset.

        Args:
            batch_data: Dict with ``batch_idx``, ``class_logits`` (raw
                logits with ``num_classes + 1`` columns, background last),
                ``vote_offsets`` (reshaped to ``[-1, num_classes + 1, 3]``)
                and ``points`` (xyz in the first three columns).

        Returns:
            The same dict, with ``fg_mask`` (bool mask over entries) and
            ``fg_center_preds`` (one xyz row per foreground entry) added.
        """
        batch_idx = batch_data['batch_idx']
        bsz = batch_idx.max().item() + 1
        # combine all classes as fg class.
        seg_logits = batch_data['class_logits'].detach()
        assert (seg_logits < 0).any()  # make sure no sigmoid applied
        assert seg_logits.size(1) == self.num_classes + 1  # we have background class
        seg_scores = seg_logits.softmax(1)[:, :-1]  # without background score

        offset = batch_data['vote_offsets'].detach()
        offset = offset.reshape(-1, self.num_classes + 1, 3)
        seg_points = batch_data['points'][:, :3]

        # get foreground mask of point
        fg_mask = self.get_fg_mask(seg_scores, 0)
        if len(torch.unique(batch_idx[fg_mask])) < bsz:
            one_random_pos_per_sample = self.get_sample_beg_position(batch_idx, fg_mask)
            fg_mask[one_random_pos_per_sample] = True  # at least one point per sample

        # Keep the class axis (``0:1`` rather than ``0``) so the offsets are
        # [M, 1, 3] and line up with the [M, 1, 1] weights. With the axis
        # dropped, [M, 3] * [M, 1, 1] broadcast to [M, M, 3] and every point
        # received the sum of all foreground offsets.
        fg_offset = offset[:, 0:1, :][fg_mask]
        fg_logits = seg_logits[:, 0:1][fg_mask, :]
        offset_weight = self.get_offset_weight(fg_logits)
        assert torch.isclose(offset_weight.sum(1), offset_weight.new_ones(len(offset_weight))).all()
        fg_offset = (fg_offset * offset_weight[:, :, None]).sum(dim=1)
        fg_points = seg_points[fg_mask, :]
        fg_centers = fg_points + fg_offset

        batch_data['fg_mask'] = fg_mask
        batch_data['fg_center_preds'] = fg_centers

        return batch_data
    
    def get_fg_mask(self, seg_scores, cls_id):
        """Return a boolean foreground mask for class ``cls_id``.

        In training mode the ``self.fg_topk`` highest-scoring entries
        (or all of them if there are fewer) are marked foreground. In
        inference mode every entry scoring above ``self.score_thresh`` is.

        Args:
            seg_scores: 2-D score tensor, one column per class.
            cls_id: Column of ``seg_scores`` to use.

        Returns:
            Bool tensor with one element per row of ``seg_scores``.
        """
        cls_scores = seg_scores[:, cls_id]

        if not self.training:
            # inference: keep everything whose score clears the threshold
            return cls_scores > self.score_thresh

        # training: keep a fixed budget of the best-scoring entries
        num_keep = min(self.fg_topk, len(cls_scores))
        fg_mask = torch.zeros_like(cls_scores, dtype=torch.bool)
        fg_mask[torch.topk(cls_scores, num_keep)[1]] = True
        return fg_mask
    
    def get_offset_weight(self, seg_logit):
        """Return per-row weights that select the maximum logit(s).

        Entries within 1e-6 of their row maximum get equal weight and all
        others get zero, so every row sums to one.

        Args:
            seg_logit: 2-D logit tensor.

        Returns:
            Float tensor with the same shape as ``seg_logit``.
        """
        row_peak = seg_logit.max(1)[0]
        gap_to_peak = (seg_logit - row_peak[:, None]).abs()
        is_peak = (gap_to_peak < 1e-6).float()
        assert ((is_peak == 1).any(1)).all()
        # Ties share the weight equally (in case of two max values).
        return is_peak / is_peak.sum(1)[:, None]

    def get_sample_beg_position(self, batch_idx, fg_mask):
        """Return the positions whose inner-window index is zero.

        ``fg_mask`` is only checked for having the same shape as
        ``batch_idx``; it does not influence the result.
        Presumably this yields the first position of each sample -- that
        depends on ``get_inner_win_inds``; TODO confirm.
        """
        assert batch_idx.shape == fg_mask.shape
        is_first = get_inner_win_inds(batch_idx.contiguous()) == 0
        return torch.nonzero(is_first, as_tuple=True)[0]
    
    def update_sample_results_by_mask(self, sampled_out, valid_mask_list):
        """Apply per-group validity masks to the sampled results in place.

        For each key of ``sampled_out`` (a dict of lists):

        * keys containing ``'fg_mask'`` hold boolean masks; the entries
          that are currently True are overwritten with the matching
          validity mask, so invalid foreground entries become False;
        * other keys whose first element has the same length as the first
          validity mask are filtered with ``data[mask]``;
        * everything else is left untouched.

        Args:
            sampled_out: Dict mapping names to lists of tensors.
            valid_mask_list: List of boolean masks, paired by position with
                the lists in ``sampled_out``.

        Returns:
            ``sampled_out`` itself, after modification.
        """
        for key, per_group in sampled_out.items():
            same_len = len(per_group[0]) == len(valid_mask_list[0])
            if not (same_len or 'fg_mask' in key):
                continue

            if 'fg_mask' in key:
                refined = []
                for fg_mask, keep in zip(per_group, valid_mask_list):
                    updated = fg_mask.clone()
                    # Only the currently-True slots receive the validity mask.
                    updated[fg_mask] = keep
                    assert updated.sum() == keep.sum()
                    refined.append(updated)
            else:
                refined = [data[keep] for data, keep in zip(per_group, valid_mask_list)]

            # Rebinding an existing key keeps the dict size fixed, so this
            # is safe while iterating.
            sampled_out[key] = refined
        return sampled_out