import torch
import torch.nn as nn
from MinkowskiEngine import (
    MinkowskiSPMMFunction,
    MinkowskiSPMMAverageFunction,
    MinkowskiDirectMaxPoolingFunction,
    SparseTensor,
    TensorField,
) 

# Small epsilon constant. It is not referenced by the code in this section;
# presumably used elsewhere for numerical stability (e.g. guarding divisions)
# — TODO confirm against its call sites.
EPS = 1e-7


@torch.no_grad()
def downsample_points(points, tensor_map, field_map, size):
    """Average ``points`` into ``size[0]`` output rows and count each row's members.

    ``(tensor_map, field_map)`` is a COO pattern: nonzero ``k`` sends input row
    ``field_map[k]`` of ``points`` to output row ``tensor_map[k]``.

    Args:
        points: Matrix whose rows are pooled (presumably point coordinates,
            one row per entry of the ``size[1]`` input dimension — TODO confirm).
        tensor_map: Output-row index of every nonzero (COO rows).
        field_map: Input-row index of every nonzero (COO cols).
        size: ``(num_output_rows, num_input_rows)`` of the COO matrix.

    Returns:
        ``(down_points, count)``. ``down_points`` holds the per-output-row
        average from ``MinkowskiSPMMAverageFunction``. ``count`` is a
        ``(size[0], 1)`` column giving the number of nonzeros that hit each
        output row, cast to ``down_points``' dtype.
    """
    down_points = MinkowskiSPMMAverageFunction().apply(
        tensor_map, field_map, size, points
    )  # rows == inverse_map
    # bincount is indexed by row id, so count[r] always lines up with
    # down_points[r]. torch.unique would silently drop any output row that
    # receives no entries, shifting every later count. bincount also avoids
    # the sort that unique performs.
    count = torch.bincount(tensor_map, minlength=int(size[0]))
    return down_points, count.unsqueeze(1).type_as(down_points)

@torch.no_grad()
def stride_centroids(points, count, rows, cols, size):
    """Compute count-weighted centroids of ``points`` pooled by a COO pattern.

    Two sparse-dense products over the same ``(rows, cols, size)`` pattern are
    formed with ``MinkowskiSPMMFunction``:

    * numerator: values ``count`` applied to ``points``;
    * denominator: unit values applied to ``count``.

    NOTE(review): the numerator uses ``count`` as per-nonzero values, while the
    denominator gathers ``count[cols]``. The two agree when ``cols`` enumerates
    the inputs in order (e.g. ``arange``). Confirm this against the callers.

    Returns:
        ``(centroids, stride_count)``: the numerator divided by the
        denominator, and the denominator itself.
    """
    # One unit value per input entry; this turns the second product into a
    # plain sum of ``count`` per output row.
    unit_vals = torch.ones(size[1], dtype=points.dtype, device=points.device)

    weighted_sum = MinkowskiSPMMFunction().apply(rows, cols, count, size, points)
    total_count = MinkowskiSPMMFunction().apply(rows, cols, unit_vals, size, count)

    centroids = weighted_sum / total_count
    return centroids, total_count

def downsample_embeddings(embeddings, inverse_map, size, mode="avg"):
    """Pool per-input embeddings into ``size[0]`` output rows.

    Input row ``i`` of ``embeddings`` is routed to output row
    ``inverse_map[i]``.

    Args:
        embeddings: One row per input; ``len(embeddings)`` must equal ``size[1]``.
        inverse_map: Integer index tensor giving the output row of each input row.
        size: ``(num_output_rows, num_input_rows)``.
        mode: ``"avg"`` pools with ``MinkowskiSPMMAverageFunction``;
            ``"max"`` pools with ``MinkowskiDirectMaxPoolingFunction``.

    Returns:
        The pooled embeddings produced by the selected kernel.

    Raises:
        ValueError: If ``len(embeddings) != size[1]``.
        NotImplementedError: If ``mode`` is neither ``"avg"`` nor ``"max"``.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if len(embeddings) != size[1]:
        raise ValueError(
            f"expected {size[1]} embeddings (size[1]), got {len(embeddings)}"
        )
    if mode not in ("avg", "max"):
        raise NotImplementedError(
            f"unsupported pooling mode {mode!r}; expected 'avg' or 'max'"
        )

    # Input-row ids 0..size[1]-1. Both kernels need them, in different
    # argument positions.
    input_ids = torch.arange(
        size[1], dtype=inverse_map.dtype, device=inverse_map.device
    )

    if mode == "max":
        return MinkowskiDirectMaxPoolingFunction().apply(
            input_ids, inverse_map, embeddings, size[0]
        )
    return MinkowskiSPMMAverageFunction().apply(
        inverse_map, input_ids, size, embeddings
    )

def create_splat_coordinates(coordinates: torch.Tensor) -> torch.Tensor:
    r"""Expand every coordinate into the integer corners of its unit cell.

    Column 0 is floored but never offset. Every other (spatial) column is
    floored and then receives each combination of ``+0`` / ``+1``. This gives
    ``2 ** D`` rows per input row, where ``D = coordinates.shape[1] - 1``.
    Neighbouring inputs can share corners, so the result may contain duplicate
    coordinates.

    Args:
        coordinates: 2-D tensor of shape ``(N, D + 1)``.

    Returns:
        int32 tensor of shape ``(N * 2 ** D, D + 1)``. The corners of input
        ``i`` are contiguous, ordered as a binary count with the last column
        toggling fastest.
    """
    num_spatial = coordinates.shape[1] - 1

    # Corner k sets spatial column d (1-based) to bit (num_spatial - d) of k,
    # i.e. a binary count whose last column changes fastest; column 0 stays 0.
    corner_offsets = torch.tensor(
        [
            [0] + [(k >> (num_spatial - d)) & 1 for d in range(1, num_spatial + 1)]
            for k in range(2 ** num_spatial)
        ],
        dtype=torch.int32,
        device=coordinates.device,
    )

    base = torch.floor(coordinates).int()
    # (N, 1, D+1) + (1, 2**D, D+1) -> (N, 2**D, D+1)
    splatted = base[:, None, :] + corner_offsets[None, :, :]
    return splatted.reshape(-1, num_spatial + 1)


class MinkowskiLayerNorm(nn.Module):
    r"""``torch.nn.LayerNorm`` applied to the features of MinkowskiEngine tensors.

    For a ``TensorField`` or ``SparseTensor``, the feature matrix ``input.F`` is
    normalized. A new tensor of the same kind is then built around the result,
    reusing the input's coordinate key and coordinate manager. A
    ``TensorField`` also keeps its quantization mode. Any other input is passed
    straight to the wrapped ``nn.LayerNorm``.

    Args:
        normalized_shape: Forwarded to ``nn.LayerNorm``.
        eps: Forwarded to ``nn.LayerNorm``.
        elementwise_affine: Forwarded to ``nn.LayerNorm``.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-05,
        elementwise_affine=True,
    ):
        super().__init__()
        # The attribute name ``ln`` fixes the state_dict keys (``ln.weight``,
        # ``ln.bias``), so it must stay stable for checkpoint compatibility.
        self.ln = nn.LayerNorm(
            normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
        )

    def forward(self, input):
        # Plain tensors (anything that is not a Minkowski tensor) go straight through.
        if not isinstance(input, (TensorField, SparseTensor)):
            return self.ln(input)

        normalized = self.ln(input.F)
        if isinstance(input, TensorField):
            return TensorField(
                normalized,
                coordinate_field_map_key=input.coordinate_field_map_key,
                coordinate_manager=input.coordinate_manager,
                quantization_mode=input.quantization_mode,
            )
        return SparseTensor(
            normalized,
            coordinate_map_key=input.coordinate_map_key,
            coordinate_manager=input.coordinate_manager,
        )