import random
import torch
import torch.nn as nn


def add_noise(features):
    """Placeholder for a feature-noise augmentation step.

    Not implemented: ``features`` is ignored and ``None`` is returned.
    NOTE(review): callers that reassign ``features = add_noise(features)``
    would lose their tensor — confirm intended behaviour before use.
    """
    return None


def random_sampling(features, zero_ratio=0.1, fill_value=1e-7):
    """Overwrite a random subset of spatial positions in a 5-D feature tensor.

    For every sample in the batch, ``int(zero_ratio * D * H * W)`` distinct
    flattened ``(D, H, W)`` positions are drawn with :func:`random.sample`
    (so results follow Python's ``random`` seed), and all ``C`` channels at
    those positions are set to ``fill_value``.

    Args:
        features: tensor of shape ``(N, C, D, H, W)``.
        zero_ratio: fraction of the ``D * H * W`` positions to overwrite per
            sample; must lie in ``[0, 1]``.
        fill_value: value written at the selected positions. Defaults to
            ``1e-7``, the value the original implementation hard-coded.

    Returns:
        Tensor of shape ``(N, C, D, H, W)`` with the selected positions filled.

    Raises:
        ValueError: if ``zero_ratio`` is outside ``[0, 1]``.

    Note:
        For contiguous input the write happens in place through a view, so
        ``features`` itself is modified and the result shares its storage.
        Non-contiguous input is copied first and left untouched; always use
        the returned tensor.
    """
    if not 0.0 <= zero_ratio <= 1.0:
        raise ValueError(f"zero_ratio must be in [0, 1], got {zero_ratio}")

    N, C, D, H, W = features.shape
    num_positions = D * H * W
    num_fill = int(zero_ratio * num_positions)

    # reshape (rather than view) accepts non-contiguous inputs; for
    # contiguous inputs it still returns a view, keeping in-place semantics.
    flat = features.reshape(N, C, num_positions)
    if num_fill > 0:
        # random.sample takes a range directly, so the D*H*W indices are
        # never materialised as a tuple/list by this function.
        population = range(num_positions)
        for i in range(N):
            indices = random.sample(population, num_fill)
            flat[i, :, indices] = fill_value

    # flat is contiguous (a view of contiguous input or a fresh copy).
    return flat.view(N, C, D, H, W)


class HeightCompression(nn.Module):
    """Turn a 5-D dense encoder output into a 4-D feature map.

    The encoded sparse tensor is densified to ``(N, C, D, H, W)`` and its
    ``C`` and ``D`` axes are merged, giving ``(N, C * D, H, W)``.
    """

    def __init__(self, model_cfg, **kwargs):
        super().__init__()
        self.model_cfg = model_cfg
        # Output channel count declared by the config; forward() does not
        # check it against the actual C * D.
        self.num_bev_features = model_cfg.NUM_BEV_FEATURES

    def forward(self, batch_dict, **kwargs):
        """Merge the C and D axes of the densified encoder output.

        Args:
            batch_dict: dict containing
                ``encoded_spconv_tensor``: object whose ``dense()`` returns a
                    5-D tensor ``(N, C, D, H, W)``;
                ``encoded_spconv_tensor_stride``: stride of that tensor.

        Returns:
            The same ``batch_dict`` object, updated in place with
                ``spatial_features``: tensor ``(N, C * D, H, W)``;
                ``spatial_features_stride``: copied from
                    ``encoded_spconv_tensor_stride``.
        """
        volume = batch_dict["encoded_spconv_tensor"].dense()
        batch, channels, depth, height, width = volume.shape
        # view() needs a layout where C and D can be merged without copying;
        # presumably dense() returns a contiguous tensor — confirm upstream.
        batch_dict["spatial_features"] = volume.view(
            batch, channels * depth, height, width
        )
        stride = batch_dict["encoded_spconv_tensor_stride"]
        batch_dict["spatial_features_stride"] = stride
        return batch_dict

