import torch
import torch.nn as nn


class PointPillarScatterV2(nn.Module):
    """Scatter per-pillar features back onto a dense BEV grid.

    Besides the dense feature map, this variant also emits a boolean
    occupancy mask (``s_mask``) marking which grid cells received a pillar.

    Args:
        model_cfg: Config mapping. Reads ``'num_features'`` and exactly one of
            ``'grid_size'`` / ``'i_grid_size'`` / ``'v_grid_size'`` (each an
            ``(nx, ny, nz)`` triple; ``nz`` must be 1).
        grid_H: Stored as ``self.grid_H``; not used by this class.
        grid_W: Stored as ``self.grid_W``; not used by this class.
        infra: ``None`` selects ``'grid_size'``; a truthy value selects
            ``'i_grid_size'``; any other (falsy) value selects
            ``'v_grid_size'``.

    Raises:
        AssertionError: if the selected grid has ``nz != 1``.
    """

    def __init__(self, model_cfg, grid_H, grid_W, infra=None):
        super().__init__()

        self.model_cfg = model_cfg
        self.num_bev_features = self.model_cfg['num_features']

        # Not referenced inside this class; kept as public attributes so
        # external code that reads them keeps working.
        self.grid_H = grid_H
        self.grid_W = grid_W

        if infra is None:
            grid_key = 'grid_size'
        elif infra:
            grid_key = 'i_grid_size'
        else:
            # Plain ``else`` rather than ``elif ~infra``: bitwise NOT on a bool
            # is always truthy (~False == -1), is deprecated on bool since
            # Python 3.12, and raises TypeError for e.g. float inputs.
            grid_key = 'v_grid_size'
        self.nx, self.ny, self.nz = model_cfg[grid_key]

        assert self.nz == 1, (
            f"PointPillarScatterV2 requires nz == 1, got nz={self.nz}")

    def forward(self, batch_dict):
        """Scatter ``batch_dict['pillar_features']`` into a dense BEV map.

        Args:
            batch_dict: Reads ``'pillar_features'`` with shape (P, C), or (C,)
                for a single pillar, and ``'voxel_coords'`` with shape (P, 4).
                Coordinate column 0 is the sample index within the batch.
                Columns 1-3 are used as (z, y, x) grid indices.

        Returns:
            The same dict, updated in place with:
              * ``'spatial_features'``: (B, C * nz, ny, nx) tensor with the
                dtype and device of the pillar features;
              * ``'grid_size'``: ``[ny, nx]``;
              * ``'s_mask'``: (B, ny, nx) bool tensor, True where a pillar
                was scattered. The nz axis is squeezed because nz == 1.
        """
        pillar_features = batch_dict['pillar_features']
        coords = batch_dict['voxel_coords']

        # A lone pillar may arrive with its leading (pillar) dim squeezed away.
        if pillar_features.ndim == 1:
            pillar_features = pillar_features.unsqueeze(0)

        # NOTE(review): batch size is inferred from the largest sample index,
        # so an empty ``voxel_coords`` tensor fails here (max of empty tensor).
        batch_size = coords[:, 0].max().int().item() + 1
        num_cells = self.nz * self.ny * self.nx

        batch_spatial_features = []
        batch_s_mask = []
        for batch_idx in range(batch_size):
            spatial_feature = torch.zeros(
                self.num_bev_features,
                num_cells,
                dtype=pillar_features.dtype,
                device=pillar_features.device)
            s_mask = torch.zeros(num_cells,
                                 dtype=torch.bool,
                                 device=pillar_features.device)

            batch_mask = coords[:, 0] == batch_idx
            this_coords = coords[batch_mask, :]

            # Flatten (z, y, x) row-major into the (nz, ny, nx) layout that
            # the ``view`` below assumes. Since nz == 1, z is 0 for every
            # valid pillar, so this reduces to y * nx + x.
            indices = (this_coords[:, 1] * (self.ny * self.nx)
                       + this_coords[:, 2] * self.nx
                       + this_coords[:, 3]).type(torch.long)

            # (P_b, C) -> (C, P_b) so each pillar fills one column.
            spatial_feature[:, indices] = pillar_features[batch_mask, :].t()
            s_mask[indices] = True

            batch_spatial_features.append(spatial_feature)
            batch_s_mask.append(s_mask)

        batch_spatial_features = torch.stack(batch_spatial_features, 0).view(
            batch_size, self.num_bev_features * self.nz, self.ny, self.nx)

        batch_s_mask = torch.stack(batch_s_mask, 0).view(
            batch_size, self.nz, self.ny, self.nx)
        if self.nz == 1:
            batch_s_mask = batch_s_mask.squeeze(1)

        batch_dict['spatial_features'] = batch_spatial_features
        batch_dict['grid_size'] = [self.ny, self.nx]
        batch_dict['s_mask'] = batch_s_mask

        return batch_dict

