# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import numpy.random as random

import torch
import torch.nn as nn
import torch.nn.functional as F
# from MinkowskiEngine import SparseTensor

# class MinkowskiGRN(nn.Module):
#     """ GRN layer for sparse tensors.
#     """
#     def __init__(self, dim):
#         super().__init__()
#         self.gamma = nn.Parameter(torch.zeros(1, dim))
#         self.beta = nn.Parameter(torch.zeros(1, dim))

#     def forward(self, x):
#         cm = x.coordinate_manager
#         in_key = x.coordinate_map_key

#         Gx = torch.norm(x.F, p=2, dim=0, keepdim=True)
#         Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
#         return SparseTensor(
#                 self.gamma * (x.F * Nx) + self.beta + x.F,
#                 coordinate_map_key=in_key,
#                 coordinate_manager=cm)

# class MinkowskiDropPath(nn.Module):
#     """ Drop Path for sparse tensors.
#     """

#     def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
#         super(MinkowskiDropPath, self).__init__()
#         self.drop_prob = drop_prob
#         self.scale_by_keep = scale_by_keep
    
#     def forward(self, x):
#         if self.drop_prob == 0. or not self.training:
#             return x
#         cm = x.coordinate_manager
#         in_key = x.coordinate_map_key
#         keep_prob = 1 - self.drop_prob
#         mask = torch.cat([
#             torch.ones(len(_)) if random.uniform(0, 1) > self.drop_prob
#             else torch.zeros(len(_)) for _ in x.decomposed_coordinates
#         ]).view(-1, 1).to(x.device)
#         if keep_prob > 0.0 and self.scale_by_keep:
#             mask.div_(keep_prob)
#         return SparseTensor(
#                 x.F * mask,
#                 coordinate_map_key=in_key,
#                 coordinate_manager=cm)

# class MinkowskiLayerNorm(nn.Module):
#     """ Channel-wise layer normalization for sparse tensors.
#     """

#     def __init__(
#         self,
#         normalized_shape,
#         eps=1e-6,
#     ):
#         super(MinkowskiLayerNorm, self).__init__()
#         self.ln = nn.LayerNorm(normalized_shape, eps=eps)
#     def forward(self, input):
#         output = self.ln(input.F)
#         return SparseTensor(
#             output,
#             coordinate_map_key=input.coordinate_map_key,
#             coordinate_manager=input.coordinate_manager)
            
class LayerNorm(nn.Module):
    """ LayerNorm that supports two data formats: channels_last (default) or channels_first. 
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs 
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

class GRN(nn.Module):
    """ GRN (Global Response Normalization) layer
    """
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x