import torch
import torch.nn as nn
import MinkowskiEngine as ME

from models.common import *
from models.transformer_layers import *
from models.transformer_layers_v2 import *
from models.transformer_modules import *
from models.transformer_base import *
    

class EPTUNetV2(PointTransformer16UNetBase):
    """U-Net variant of ``PointTransformer16UNetBase`` with point-offset embeddings.

    On top of the base network this class:

    * embeds every input point's offset from the downsampled point of its
      voxel (``h_mlp``) and concatenates the voxel-averaged embedding to the
      voxel features before the first attention layer;
    * overrides ``attn0p1`` and ``final`` so that they accept the extra
      ``H_DIM`` channels;
    * concatenates the per-point embedding once more when voxel features are
      mapped back onto the input points.

    Every other layer used in ``forward`` (``bn*``, ``attn*``, ``pool*``,
    ``block*``, ``pooltr``, ``relu``) is expected to be created by the base
    class's ``network_initialization``.
    """

    # Width of the point-offset embedding produced by ``h_mlp``.
    H_DIM = 32
    # The attributes below are not read in this class; presumably they are
    # consumed by the base class when it builds the layers -- verify there.
    KERNEL_SIZE = 3
    POOL_LAYER = StridedMaxPoolLayer
    LAYER = EfficientPointTransformerLayerV2Shared
    BLOCK = EfficientPointTransformerBasicBlockV2Shared
    LAYERS = (2, 3, 4, 14, 2, 2, 2, 2)
    PLANES = (64, 128, 256, 512, 256, 128, 128, 128)

    def network_initialization(self, in_channels, out_channels, D):
        """Build the base layers, then add/override the embedding-aware ones.

        Args:
            in_channels: number of input feature channels.
            out_channels: number of output channels of the final linear layer.
            D: spatial dimension; also the input width of ``h_mlp``.
        """
        super().network_initialization(in_channels, out_channels, D)

        # NOTE: ``h_mlp`` and ``final`` operate on plain ``torch.Tensor``s
        # (see ``voxelize_with_centroids`` / ``devoxelize_with_centroids``),
        # so the normalisation layers must be ``nn.BatchNorm1d``.
        # ``ME.MinkowskiBatchNorm`` reads ``input.F`` and therefore only
        # accepts MinkowskiEngine sparse tensors / tensor fields.
        self.h_mlp = nn.Sequential(
            nn.Linear(D, self.H_DIM, bias=False),
            nn.BatchNorm1d(self.H_DIM),
            nn.Tanh(),
            nn.Linear(self.H_DIM, self.H_DIM, bias=False),
            nn.BatchNorm1d(self.H_DIM),
            nn.Tanh(),
        )
        # Overrides the base layer: input is widened by the H_DIM embedding.
        self.attn0p1 = self.LAYER(
            in_channels + self.H_DIM, self.INIT_DIM, kernel_size=5, dimension=D
        )
        # Overrides the base head: input is widened by the H_DIM embedding.
        self.final = nn.Sequential(
            nn.Linear(self.PLANES[7] + self.H_DIM, self.PLANES[7], bias=False),
            nn.BatchNorm1d(self.PLANES[7]),
            nn.ReLU(inplace=True),
            nn.Linear(self.PLANES[7], out_channels),
        )

    @torch.no_grad()
    def normalize_points(
        self,
        points: torch.Tensor,
        centroids: torch.Tensor,
        tensor_map: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``points`` minus the centroid row selected for each point.

        Args:
            points: per-point coordinates.
            centroids: rows to subtract; indexed by ``tensor_map``.
            tensor_map: for each point, the row index into ``centroids``.
                Cast to ``int64`` before being used as an index.

        Runs under ``torch.no_grad()``, so no gradient flows through the
        subtraction.
        """
        # ``.long()`` is a no-op when the map is already int64.
        return points - centroids[tensor_map.long()]

    def voxelize_with_centroids(self, x: ME.TensorField):
        """Voxelize ``x`` and append averaged point-offset embeddings.

        Args:
            x: input tensor field.

        Returns:
            A 4-tuple ``(out, points_p1, count_p1, h_embs)``:

            * ``out``: sparse tensor on the coordinates of ``x.sparse()``
              whose features are the voxel features concatenated with the
              downsampled embeddings;
            * ``points_p1``, ``count_p1``: the two values returned by
              ``downsample_points`` (presumably per-voxel centroids and
              per-voxel point counts -- confirm in ``models.common``);
            * ``h_embs``: per-point embeddings from ``h_mlp``, kept for
              ``devoxelize_with_centroids``.
        """
        cm = x.coordinate_manager
        # Drop column 0 of the coordinates (presumably the batch index).
        points = x.C[:, 1:]

        out = x.sparse()
        size = torch.Size([len(out), len(x)])
        tensor_map, field_map = cm.field_to_sparse_map(
            x.coordinate_key, out.coordinate_key
        )
        points_p1, count_p1 = downsample_points(points, tensor_map, field_map, size)
        norm_points = self.normalize_points(points, points_p1, tensor_map)

        # Embed the offsets per point, then average them per voxel.
        h_embs = self.h_mlp(norm_points)
        down_h_embs = downsample_embeddings(h_embs, tensor_map, size, mode="avg")
        out = ME.SparseTensor(
            features=torch.cat([out.F, down_h_embs], dim=1),
            coordinate_map_key=out.coordinate_key,
            coordinate_manager=cm,
        )

        return out, points_p1, count_p1, h_embs

    def devoxelize_with_centroids(
        self, out: ME.SparseTensor, x: ME.TensorField, h_embs
    ) -> torch.Tensor:
        """Map voxel features back to the points of ``x`` and apply ``final``.

        Args:
            out: sparse tensor produced by the decoder.
            x: the original input tensor field; ``out`` is sliced onto it.
            h_embs: per-point embeddings from ``voxelize_with_centroids``.

        Returns:
            The output of ``final`` applied to the sliced features
            concatenated with ``h_embs`` along dim 1.
        """
        point_feats = torch.cat([out.slice(x).F, h_embs], dim=1)
        return self.final(point_feats)

    def forward(self, x: ME.TensorField):
        """Run the encoder-decoder on ``x`` and return per-point outputs.

        The ``kq_idx_*`` / ``rel_pos_*`` values returned by the encoder
        layers at each resolution are passed back into the decoder layers
        at the same resolution.
        """
        # --- stem (p1) ---
        out, points_p1, count_p1, h_embs = self.voxelize_with_centroids(x)
        # attn0p1 uses kernel_size=5; its index outputs are not reused.
        out, _, _ = self.attn0p1(out, points_p1)
        out = self.relu(self.bn0(out))
        out, kq_idx_p1k3, rel_pos_p1k3 = self.attn1p1(out, points_p1)
        out_p1 = self.relu(self.bn1(out))

        # --- encoder stage 1 (p2) ---
        out, points_p2, count_p2 = self.pool1p1s2(out_p1, points_p1, count_p1)
        kq_idx_p2k3, rel_pos_p2k3 = None, None
        for module in self.block1:
            out, kq_idx_p2k3, rel_pos_p2k3 = module(out, points_p2, kq_idx_p2k3, rel_pos_p2k3)
        out_p2 = self.relu(self.bn2(self.attn2p2(out, points_p2, kq_idx_p2k3, rel_pos_p2k3)[0]))

        # --- encoder stage 2 (p4) ---
        out, points_p4, count_p4 = self.pool2p2s2(out_p2, points_p2, count_p2)
        kq_idx_p4k3, rel_pos_p4k3 = None, None
        for module in self.block2:
            out, kq_idx_p4k3, rel_pos_p4k3 = module(out, points_p4, kq_idx_p4k3, rel_pos_p4k3)
        out_p4 = self.relu(self.bn3(self.attn3p4(out, points_p4, kq_idx_p4k3, rel_pos_p4k3)[0]))

        # --- encoder stage 3 (p8) ---
        out, points_p8, count_p8 = self.pool3p4s2(out_p4, points_p4, count_p4)
        kq_idx_p8k3, rel_pos_p8k3 = None, None
        for module in self.block3:
            out, kq_idx_p8k3, rel_pos_p8k3 = module(out, points_p8, kq_idx_p8k3, rel_pos_p8k3)
        out_p8 = self.relu(self.bn4(self.attn4p8(out, points_p8, kq_idx_p8k3, rel_pos_p8k3)[0]))

        # --- bottleneck (p16); the point count is not needed any further ---
        out, points_p16, _ = self.pool4p8s2(out_p8, points_p8, count_p8)
        kq_idx_p16k3, rel_pos_p16k3 = None, None
        for module in self.block4:
            out, kq_idx_p16k3, rel_pos_p16k3 = module(out, points_p16, kq_idx_p16k3, rel_pos_p16k3)

        # --- decoder stage 5 (back to p8, skip connection from out_p8) ---
        out = self.pooltr(out)
        out = ME.cat(out, out_p8)
        out = self.relu(self.bn5(self.attn5p8(out, points_p8, kq_idx_p8k3, rel_pos_p8k3)[0]))
        for module in self.block5:
            out = module(out, points_p8, kq_idx_p8k3, rel_pos_p8k3)[0]

        # --- decoder stage 6 (back to p4, skip connection from out_p4) ---
        out = self.pooltr(out)
        out = ME.cat(out, out_p4)
        out = self.relu(self.bn6(self.attn6p4(out, points_p4, kq_idx_p4k3, rel_pos_p4k3)[0]))
        for module in self.block6:
            out = module(out, points_p4, kq_idx_p4k3, rel_pos_p4k3)[0]

        # --- decoder stage 7 (back to p2, skip connection from out_p2) ---
        out = self.pooltr(out)
        out = ME.cat(out, out_p2)
        out = self.relu(self.bn7(self.attn7p2(out, points_p2, kq_idx_p2k3, rel_pos_p2k3)[0]))
        for module in self.block7:
            out = module(out, points_p2, kq_idx_p2k3, rel_pos_p2k3)[0]

        # --- decoder stage 8 (back to p1, skip connection from out_p1) ---
        out = self.pooltr(out)
        out = ME.cat(out, out_p1)
        out = self.relu(self.bn8(self.attn8p1(out, points_p1, kq_idx_p1k3, rel_pos_p1k3)[0]))
        for module in self.block8:
            out = module(out, points_p1, kq_idx_p1k3, rel_pos_p1k3)[0]

        # --- per-point head ---
        return self.devoxelize_with_centroids(out, x, h_embs)