import torch
import torch.nn as nn
import MinkowskiEngine as ME

from models.common import stride_centroids
from models.transformer_layers import *
from models.transformer_layers_v2 import *


# -------------------------------
#         Residual Blocks
# -------------------------------

class TransformerBasicBlockBase(nn.Module):
    """Residual block: two ``LAYER`` modules, each followed by batch norm.

    Concrete blocks pick the layer type by overriding the ``LAYER`` class
    attribute. Instantiating a class whose ``LAYER`` is still ``None`` fails
    the assertion in ``__init__``.
    """

    expansion = 1
    LAYER = None

    def __init__(self, in_channels, out_channels=None, kernel_size=3, dimension=3):
        """Create both layers, their batch norms and the shared in-place ReLU.

        Args:
            in_channels: First positional argument of the first layer.
            out_channels: Second positional argument of the first layer, first
                positional argument of the second layer, and the size of both
                batch norms. Falls back to ``in_channels`` when ``None``.
            kernel_size: Forwarded to both ``LAYER`` constructors.
            dimension: Forwarded to both ``LAYER`` constructors.
        """
        if out_channels is None:
            out_channels = in_channels
        assert self.LAYER is not None
        super().__init__()

        layer_kwargs = dict(kernel_size=kernel_size, dimension=dimension)
        self.layer1 = self.LAYER(in_channels, out_channels, **layer_kwargs)
        self.norm1 = ME.MinkowskiBatchNorm(out_channels)
        self.layer2 = self.LAYER(out_channels, **layer_kwargs)
        self.norm2 = ME.MinkowskiBatchNorm(out_channels)
        self.relu = ME.MinkowskiReLU(inplace=True)

    def forward(self, stensor, points):
        """Run layer -> norm -> ReLU -> layer -> norm, add the input, then ReLU.

        Args:
            stensor: Block input; also used as the residual term.
            points: Passed unchanged as the second argument of both layers.

        Returns:
            The activated sum of the second norm's output and ``stensor``.
        """
        hidden = self.relu(self.norm1(self.layer1(stensor, points)))
        hidden = self.norm2(self.layer2(hidden, points))

        # Residual connection, kept as an in-place add on the branch output.
        hidden += stensor
        return self.relu(hidden)


class PointTransformerBasicBlock(TransformerBasicBlockBase):
    """Basic residual block whose two layers are ``PointTransformerLayer`` modules."""
    LAYER = PointTransformerLayer
    
class EfficientPointTransformerBasicBlock(TransformerBasicBlockBase):
    """Basic residual block whose layers also take and return a ``meta`` value."""

    LAYER = EfficientPointTransformerLayer

    def forward(self, stensor, points, meta=None):
        """Run the residual block while threading ``meta`` through both layers.

        Args:
            stensor: Block input; also used as the residual term.
            points: Passed unchanged as the second argument of both layers.
            meta: Third argument of the first layer. Whatever that layer
                returns as its second result is handed to the second layer.

        Returns:
            Tuple ``(out, meta)`` where ``meta`` is the value returned by the
            second layer.
        """
        hidden, meta = self.layer1(stensor, points, meta)
        hidden = self.relu(self.norm1(hidden))

        hidden, meta = self.layer2(hidden, points, meta)
        hidden = self.norm2(hidden)

        # Residual connection, kept as an in-place add on the branch output.
        hidden += stensor
        return self.relu(hidden), meta
 
class EfficientPointTransformerBasicBlockInterOnly(EfficientPointTransformerBasicBlock):
    """Single-``meta`` basic block built from ``EfficientPointTransformerLayerInterOnly``."""
    LAYER = EfficientPointTransformerLayerInterOnly
 
class EfficientPointTransformerBasicBlockWithKey(EfficientPointTransformerBasicBlock):
    """Single-``meta`` basic block built from ``EfficientPointTransformerLayerWithKey``."""
    LAYER = EfficientPointTransformerLayerWithKey

class EfficientPointTransformerBasicBlockWithKeyApprox(EfficientPointTransformerBasicBlock):
    """Single-``meta`` basic block built from ``EfficientPointTransformerLayerWithKeyApprox``."""
    LAYER = EfficientPointTransformerLayerWithKeyApprox

class EfficientPointTransformerBasicBlockWithKeySoftmax(EfficientPointTransformerBasicBlock):
    """Single-``meta`` basic block built from ``EfficientPointTransformerLayerWithKeySoftmax``."""
    LAYER = EfficientPointTransformerLayerWithKeySoftmax

class EfficientPointTransformerBasicBlockWithKeyScaled(EfficientPointTransformerBasicBlock):
    """Single-``meta`` basic block built from ``EfficientPointTransformerLayerWithKeyScaled``."""
    LAYER = EfficientPointTransformerLayerWithKeyScaled

class EfficientPointTransformerBasicBlockLinear(EfficientPointTransformerBasicBlock):
    """Single-``meta`` basic block built from ``EfficientPointTransformerLayerLinear``."""
    LAYER = EfficientPointTransformerLayerLinear
    
class EfficientPointTransformerBasicBlockLinearSigmoid(EfficientPointTransformerBasicBlock):
    """Single-``meta`` basic block built from ``EfficientPointTransformerLayerLinearSigmoid``."""
    LAYER = EfficientPointTransformerLayerLinearSigmoid
    
class EfficientPointTransformerBasicBlockLipschitzShared(TransformerBasicBlockBase):
    """Basic residual block whose layers also take and return ``meta1`` and ``meta2``."""

    LAYER = EfficientPointTransformerLayerLipschitzShared

    def forward(self, stensor, points, meta1=None, meta2=None):
        """Run the residual block while threading two meta values through both layers.

        Args:
            stensor: Block input; also used as the residual term.
            points: Passed unchanged as the second argument of both layers.
            meta1: Third argument of the first layer; the value that layer
                returns in its place is handed to the second layer.
            meta2: Fourth argument of the first layer; handled like ``meta1``.

        Returns:
            Tuple ``(out, meta1, meta2)`` where both meta values are the ones
            returned by the second layer.
        """
        hidden, meta1, meta2 = self.layer1(stensor, points, meta1, meta2)
        hidden = self.relu(self.norm1(hidden))

        hidden, meta1, meta2 = self.layer2(hidden, points, meta1, meta2)
        hidden = self.norm2(hidden)

        # Residual connection, kept as an in-place add on the branch output.
        hidden += stensor
        return self.relu(hidden), meta1, meta2
    
class EfficientPointTransformerBasicBlockLipschitz(EfficientPointTransformerBasicBlockLipschitzShared):
    """Two-meta basic block built from ``EfficientPointTransformerLayerLipschitz``."""
    LAYER = EfficientPointTransformerLayerLipschitz
    
class EfficientPointTransformerBasicBlockMappingShared(EfficientPointTransformerBasicBlockLipschitzShared):
    """Two-meta basic block built from ``EfficientPointTransformerLayerMappingShared``."""
    LAYER = EfficientPointTransformerLayerMappingShared
    
class EfficientPointTransformerBasicBlockMapping(EfficientPointTransformerBasicBlockLipschitzShared):
    """Two-meta basic block built from ``EfficientPointTransformerLayerMapping``."""
    LAYER = EfficientPointTransformerLayerMapping
    
class EfficientPointTransformerBasicBlockMappingLighter(EfficientPointTransformerBasicBlockLipschitzShared):
    """Two-meta basic block built from ``EfficientPointTransformerLayerMappingLighter``."""
    LAYER = EfficientPointTransformerLayerMappingLighter
    
class EfficientPointTransformerBasicBlockMappingFiner(EfficientPointTransformerBasicBlockLipschitzShared):
    """Two-meta basic block built from ``EfficientPointTransformerLayerMappingFiner``."""
    LAYER = EfficientPointTransformerLayerMappingFiner
    
class EfficientPointTransformerBasicBlockLipschitzFiner(EfficientPointTransformerBasicBlockLipschitzShared):
    """Two-meta basic block built from ``EfficientPointTransformerLayerLipschitzFiner``."""
    LAYER = EfficientPointTransformerLayerLipschitzFiner
    
class EfficientPointTransformerBasicBlockMHSA(EfficientPointTransformerBasicBlockLipschitzShared):
    """Two-meta basic block built from ``EfficientPointTransformerLayerMHSA``."""
    LAYER = EfficientPointTransformerLayerMHSA
    
class EfficientPointTransformerBasicBlockMappingFinerWithKey(EfficientPointTransformerBasicBlockLipschitzShared):
    """Two-meta basic block built from ``EfficientPointTransformerLayerMappingFinerWithKey``."""
    LAYER = EfficientPointTransformerLayerMappingFinerWithKey
    
class EfficientPointTransformerBasicBlockToken(EfficientPointTransformerBasicBlockLipschitzShared):
    """Two-meta basic block built from ``EfficientPointTransformerLayerToken``."""
    LAYER = EfficientPointTransformerLayerToken
    
class PointTransformerBasicBlockShared(EfficientPointTransformerBasicBlockLipschitzShared):
    """Two-meta basic block built from ``PointTransformerLayerShared``."""
    LAYER = PointTransformerLayerShared
    
class EfficientPointTransformerBasicBlockV2Shared(EfficientPointTransformerBasicBlockLipschitzShared):
    """Two-meta basic block built from ``EfficientPointTransformerLayerV2Shared``."""
    LAYER = EfficientPointTransformerLayerV2Shared


class TransformerBottleneckBase(nn.Module):
    """Residual bottleneck: 1x1 convolution, one ``LAYER`` module, 1x1 convolution.

    Each stage is followed by batch norm. Concrete blocks pick the layer type
    by overriding the ``LAYER`` class attribute. Instantiating a class whose
    ``LAYER`` is still ``None`` fails the assertion in ``__init__``.
    """

    expansion = 1
    LAYER = None

    def __init__(self, in_channels, out_channels=None, kernel_size=3, dimension=3):
        """Create the three stages, their batch norms and the shared in-place ReLU.

        Args:
            in_channels: Channel count used by every sub-module of the block.
            out_channels: Accepted but never read; every sub-module is sized
                with ``in_channels``.
            kernel_size: Forwarded to the ``LAYER`` constructor only; both
                convolutions use a kernel size of 1.
            dimension: Forwarded to the convolutions and to ``LAYER``.
        """
        assert self.LAYER is not None
        super().__init__()

        pointwise_kwargs = dict(kernel_size=1, bias=False, dimension=dimension)
        self.linear1 = ME.MinkowskiConvolution(in_channels, in_channels, **pointwise_kwargs)
        self.norm1 = ME.MinkowskiBatchNorm(in_channels)
        self.layer2 = self.LAYER(in_channels, kernel_size=kernel_size, dimension=dimension)
        self.norm2 = ME.MinkowskiBatchNorm(in_channels)
        self.linear3 = ME.MinkowskiConvolution(in_channels, in_channels, **pointwise_kwargs)
        self.norm3 = ME.MinkowskiBatchNorm(in_channels)
        self.relu = ME.MinkowskiReLU(inplace=True)

    def forward(self, stensor, points):
        """Run the three stages, add the input, then apply the final ReLU.

        Args:
            stensor: Block input; also used as the residual term.
            points: Passed as the second argument of ``layer2`` only.

        Returns:
            The activated sum of the third norm's output and ``stensor``.
        """
        hidden = self.relu(self.norm1(self.linear1(stensor)))
        hidden = self.relu(self.norm2(self.layer2(hidden, points)))
        hidden = self.norm3(self.linear3(hidden))

        # Residual connection, kept as an in-place add on the branch output.
        hidden += stensor
        return self.relu(hidden)

class EfficientPointTransformerBottleneckV2Shared(TransformerBottleneckBase):
    """Bottleneck block whose middle stage is ``EfficientPointTransformerLayerV2Shared``."""

    LAYER = EfficientPointTransformerLayerV2Shared

    def forward(self, stensor, points, meta1=None, meta2=None):
        """Run the bottleneck, passing ``meta1``/``meta2`` into the middle layer.

        Args:
            stensor: Block input; also used as the residual term.
            points: Passed as the second argument of ``layer2``.
            meta1: Third argument of ``layer2``.
            meta2: Fourth argument of ``layer2``.

        Returns:
            Tuple ``(out, meta1, meta2)``. The two meta values are exactly the
            ones passed in, because only element ``[0]`` of the layer's result
            is kept.
        """
        hidden = self.relu(self.norm1(self.linear1(stensor)))

        # NOTE(review): the layer's other results are discarded here, whereas
        # EfficientPointTransformerBasicBlockV2Shared (same LAYER) hands them
        # back to its caller. Confirm that dropping them is intentional.
        hidden = self.layer2(hidden, points, meta1, meta2)[0]
        hidden = self.relu(self.norm2(hidden))

        hidden = self.norm3(self.linear3(hidden))

        # Residual connection, kept as an in-place add on the branch output.
        hidden += stensor
        return self.relu(hidden), meta1, meta2
    

# -------------------------------
#         Pooling Layers
# -------------------------------

class StridedPoolLayerBase(nn.Module):
    """Pool a sparse tensor and reduce its per-entry points and counts to match.

    Concrete layers pick the pooling operator by overriding the ``POOL_FUNC``
    class attribute. Instantiating a class whose ``POOL_FUNC`` is still
    ``None`` fails the first assertion in ``__init__``.
    """

    POOL_FUNC = None

    def __init__(self, kernel_size=2, stride=2, dimension=3):
        """Instantiate the pooling operator.

        Args:
            kernel_size: Must be 2 (asserted); forwarded to ``POOL_FUNC``.
            stride: Must be 2 (asserted); forwarded to ``POOL_FUNC``.
            dimension: Forwarded to ``POOL_FUNC``.
        """
        assert self.POOL_FUNC is not None
        assert kernel_size == 2
        assert stride == 2
        super().__init__()

        pool_kwargs = dict(kernel_size=kernel_size, stride=stride, dimension=dimension)
        self.pool = self.POOL_FUNC(**pool_kwargs)

    def forward(self, stensor, points, count):
        """Pool ``stensor`` and compute the matching pooled points and counts.

        Args:
            stensor: Input sparse tensor; its length must equal ``len(points)``.
            points: Per-entry values handed to ``stride_centroids``.
            count: Per-entry counts handed to ``stride_centroids``.

        Returns:
            Tuple ``(down_stensor, down_points, down_count)``: the pooled
            tensor plus the two results of ``stride_centroids``.
        """
        assert len(stensor) == len(points)
        manager = stensor.coordinate_manager

        pooled = self.pool(stensor)
        # Index pairs relating the input coordinates to the pooled ones.
        cols, rows = manager.stride_map(stensor.coordinate_key, pooled.coordinate_key)
        # (number of pooled entries, number of input entries)
        size = torch.Size((len(pooled), len(stensor)))
        pooled_points, pooled_count = stride_centroids(points, count, rows, cols, size)

        return pooled, pooled_points, pooled_count
    
class StridedMaxPoolLayer(StridedPoolLayerBase):
    """Strided pooling layer that uses ``ME.MinkowskiMaxPooling`` as its pooling operator."""
    POOL_FUNC = ME.MinkowskiMaxPooling
    
class StridedAvgPoolLayer(StridedPoolLayerBase):
    """Strided pooling layer that uses ``ME.MinkowskiAvgPooling`` as its pooling operator."""
    POOL_FUNC = ME.MinkowskiAvgPooling
    
    
# class AttentiveStridedPoolLayerBase(nn.Module):
#     POOL_FUNC = None
    
#     def __init__(self, channels, kernel_size=2, stride=2, dimension=3):
#         assert self.POOL_FUNC is not None
#         assert kernel_size == 2
#         assert stride == 2
#         super(AttentiveStridedPoolLayerBase, self).__init__()

#         self.pool = self.POOL_FUNC(kernel_size=kernel_size, stride=stride, dimension=dimension)
#         self.sum_pool = ME.MinkowskiSumPooling(kernel_size=kernel_size, stride=stride, dimension=dimension)
        
#     def forward(self, stensor, points, count):
#         assert len(stensor) == len(points)
        
#         down_stensor = self.pool(stensor)
        
#         with torch.no_grad():
#             points_count_stensor = ME.SparseTensor(
#                 torch.cat([points * count, count], dim=1),
#                 coordinate_map_key=stensor.coordinate_key,
#                 coordinate_manager=stensor.coordinate_manager
#             )
#             down_points_, down_count = self.sum_pool(points_count_stensor).F.split([points.shape[1], 1], dim=1)
#             down_points = torch.true_divide(down_points_, down_count)
        
#         return down_stensor, down_points, down_count
    
# class AttentiveStridedMaxPoolLayer(AttentiveStridedPoolLayerBase):
#     POOL_FUNC = ME.MinkowskiMaxPooling
    
# class AttentiveStridedAvgPoolLayer(AttentiveStridedPoolLayerBase):
#     POOL_FUNC = ME.MinkowskiAvgPooling