from time import sleep

import torch
import torch.nn as nn
import MinkowskiEngine as ME

from models.common import downsample_points
from models.resunet import Res16UNetBase
from models.transformer_layers import *
from models.transformer_modules import *


class Transformer16UNetBase(Res16UNetBase):
    """Common scaffolding for the attention-based encoder/decoder networks below.

    This class only builds the sub-modules and provides the voxelize /
    devoxelize helpers; the ``forward`` pass is supplied by subclasses.

    Subclasses must override the four class attributes declared here.
    ``INIT_DIM``, ``PLANES`` and ``LAYERS`` are read by
    ``network_initialization`` but are not defined in this class
    (presumably inherited from ``Res16UNetBase`` or set by subclasses --
    TODO confirm).
    """

    # All four are checked for ``None`` in ``network_initialization``.
    LAYER = None        # called as LAYER(c_in, c_out, kernel_size=..., dimension=...)
    BLOCK = None        # called as BLOCK(channels, kernel_size=..., dimension=...)
    KERNEL_SIZE = None  # kernel size passed to every LAYER/BLOCK except the stem
    POOL_LAYER = None   # called as POOL_LAYER(dimension=...)

    def network_initialization(self, in_channels, out_channels, D):
        """Instantiate every sub-module referenced by the subclasses' ``forward``.

        Args:
            in_channels: channel count of the input features.
            out_channels: channel count produced by the final 1x1 convolution.
            D: spatial dimension forwarded to every MinkowskiEngine module.

        Raises:
            AssertionError: if ``LAYER``, ``BLOCK``, ``KERNEL_SIZE`` or
                ``POOL_LAYER`` is still ``None``.

        NOTE(review): attributes are assigned in the same order as before on
        purpose; assuming this is an ``nn.Module`` (not visible here), that
        order determines sub-module registration order.
        """
        for required in (self.LAYER, self.BLOCK, self.KERNEL_SIZE, self.POOL_LAYER):
            assert required is not None

        # Keyword arguments shared by every attention layer and block except
        # the stem, which uses a fixed kernel size of 5.
        shared = dict(kernel_size=self.KERNEL_SIZE, dimension=D)

        def attention(c_in, c_out):
            return self.LAYER(c_in, c_out, **shared)

        def batch_norm(width):
            return ME.MinkowskiBatchNorm(width)

        def pool():
            return self.POOL_LAYER(dimension=D)

        def stage(width, depth):
            return nn.ModuleList([self.BLOCK(width, **shared) for _ in range(depth)])

        # Stem.
        stem_width = self.INIT_DIM
        self.attn0p1 = self.LAYER(in_channels, stem_width, kernel_size=5, dimension=D)
        self.bn0 = batch_norm(stem_width)

        planes, depths = self.PLANES, self.LAYERS

        # Encoder: attention + norm, then a pooling layer and a stack of blocks.
        self.attn1p1 = attention(stem_width, planes[0])
        self.bn1 = batch_norm(planes[0])
        self.pool1p1s2 = pool()
        self.block1 = stage(planes[0], depths[0])

        self.attn2p2 = attention(planes[0], planes[1])
        self.bn2 = batch_norm(planes[1])
        self.pool2p2s2 = pool()
        self.block2 = stage(planes[1], depths[1])

        self.attn3p4 = attention(planes[1], planes[2])
        self.bn3 = batch_norm(planes[2])
        self.pool3p4s2 = pool()
        self.block3 = stage(planes[2], depths[2])

        self.attn4p8 = attention(planes[2], planes[3])
        self.bn4 = batch_norm(planes[3])
        self.pool4p8s2 = pool()
        self.block4 = stage(planes[3], depths[3])

        # Decoder: each attention layer's input width is the sum of the
        # previous stage's width and the width of the encoder output it is
        # concatenated with in ``forward``.
        self.attn5p8 = attention(planes[3] + planes[3], planes[4])
        self.bn5 = batch_norm(planes[4])
        self.block5 = stage(planes[4], depths[4])

        self.attn6p4 = attention(planes[4] + planes[2], planes[5])
        self.bn6 = batch_norm(planes[5])
        self.block6 = stage(planes[5], depths[5])

        self.attn7p2 = attention(planes[5] + planes[1], planes[6])
        self.bn7 = batch_norm(planes[6])
        self.block7 = stage(planes[6], depths[6])

        self.attn8p1 = attention(planes[6] + planes[0], planes[7])
        self.bn8 = batch_norm(planes[7])
        self.block8 = stage(planes[7], depths[7])

        # Output head and the modules shared across all stages.
        self.final = ME.MinkowskiConvolution(
            planes[7], out_channels, kernel_size=1, stride=1, bias=True, dimension=D
        )
        self.relu = ME.MinkowskiReLU(inplace=True)
        self.pooltr = ME.MinkowskiPoolingTranspose(kernel_size=2, stride=2, dimension=D)

    def voxelize_with_centroids(self, x: ME.TensorField):
        """Convert a tensor field to a sparse tensor plus per-voxel point data.

        Args:
            x: input tensor field.

        Returns:
            A 4-tuple ``(sparse_tensor, points, counts, None)``. ``points``
            and ``counts`` are whatever ``downsample_points`` returns
            (presumably per-voxel centroids and point counts -- TODO confirm
            against ``models.common``). The trailing ``None`` is the slot
            that callers unpack as ``h_embs``.
        """
        manager = x.coordinate_manager
        # Skip column 0 of the coordinates (the batch index in MinkowskiEngine).
        xyz = x.C[:, 1:]

        voxels = x.sparse()
        tensor_map, field_map = manager.field_to_sparse_map(
            x.coordinate_key, voxels.coordinate_key
        )
        shape = torch.Size([len(voxels), len(x)])
        centroids, counts = downsample_points(xyz, tensor_map, field_map, shape)

        return voxels, centroids, counts, None

    def devoxelize_with_centroids(self, out: ME.SparseTensor, x: ME.TensorField, h_embs):
        """Apply the output head and map the result back onto the input field.

        Args:
            out: sparse tensor produced by the last decoder stage.
            x: the original tensor field the output is sliced onto.
            h_embs: unused by this implementation.

        Returns:
            The feature matrix (``.F``) of ``self.final(out)`` sliced onto ``x``.
        """
        logits = self.final(out)
        per_point = logits.slice(x)

        return per_point.F


class PointTransformer16UNetBase(Transformer16UNetBase):
    """Encoder/decoder network built from ``PointTransformerLayer``.

    ``BLOCK`` and ``KERNEL_SIZE`` are not set here and must be provided by
    a further subclass before ``network_initialization`` runs.
    """

    LAYER = PointTransformerLayer
    POOL_LAYER = StridedMaxPoolLayer

    def forward(self, x: ME.TensorField):
        """Run the full network on a tensor field.

        Args:
            x: input tensor field.

        Returns:
            Whatever ``devoxelize_with_centroids`` returns for the final
            decoder output (in the base class: a feature matrix sliced onto ``x``).
        """
        act = self.relu

        def run_stage(stage, feats, pts):
            # Apply every block of a stage with the same point set.
            for layer in stage:
                feats = layer(feats, pts)
            return feats

        feats, pts, cnt, h_embs = self.voxelize_with_centroids(x)

        # Stem: two attention layers at the input resolution.
        feats = act(self.bn0(self.attn0p1(feats, pts)))
        feats = act(self.bn1(self.attn1p1(feats, pts)))

        # Encoder. Before each pooling step the current features and points
        # are pushed so the decoder can concatenate with them later.
        skips = []
        encoder = (
            (self.pool1p1s2, self.block1, self.attn2p2, self.bn2),
            (self.pool2p2s2, self.block2, self.attn3p4, self.bn3),
            (self.pool3p4s2, self.block3, self.attn4p8, self.bn4),
        )
        for pool, stage, attn, norm in encoder:
            skips.append((feats, pts))
            feats, pts, cnt = pool(feats, pts, cnt)
            feats = run_stage(stage, feats, pts)
            feats = act(norm(attn(feats, pts)))

        # Coarsest level: pooling and blocks only, no trailing attention layer.
        skips.append((feats, pts))
        feats, pts, _ = self.pool4p8s2(feats, pts, cnt)
        feats = run_stage(self.block4, feats, pts)

        # Decoder: unpool, concatenate with the matching encoder output, then
        # attention + blocks using that encoder level's points.
        decoder = (
            (self.attn5p8, self.bn5, self.block5),
            (self.attn6p4, self.bn6, self.block6),
            (self.attn7p2, self.bn7, self.block7),
            (self.attn8p1, self.bn8, self.block8),
        )
        for attn, norm, stage in decoder:
            skip, pts = skips.pop()
            feats = ME.cat(self.pooltr(feats), skip)
            feats = act(norm(attn(feats, pts)))
            feats = run_stage(stage, feats, pts)

        return self.devoxelize_with_centroids(feats, x, h_embs)
    
    
class EfficientPointTransformer16UNetBase(Transformer16UNetBase):
    """Encoder/decoder network built from ``EfficientPointTransformerLayer``.

    Unlike ``PointTransformer16UNetBase``, layers here receive normalized
    point offsets and return ``(features, kq_map)`` pairs; the ``kq_map``
    obtained on the way down is passed again to the decoder layers at the
    same level. ``BLOCK`` and ``KERNEL_SIZE`` must be provided by a further
    subclass.
    """

    LAYER = EfficientPointTransformerLayer
    POOL_LAYER = StridedMaxPoolLayer

    @torch.no_grad()
    def normalize_centroids(self, down_points, coordinates, tensor_stride):
        """Express points relative to their voxel, without tracking gradients.

        Computes ``(down_points - coordinates[:, 1:]) / tensor_stride - 0.5``.

        Args:
            down_points: point positions, one row per voxel.
            coordinates: voxel coordinates; column 0 is skipped.
            tensor_stride: scalar stride the offset is divided by.

        Returns:
            The normalized offsets. NOTE(review): this lies in [-0.5, 0.5)
            only if each point falls inside its own voxel -- confirm against
            ``downsample_points`` / the pooling layer.
        """
        voxel_origin = coordinates[:, 1:]
        offset = down_points - voxel_origin
        scaled = offset / tensor_stride

        return scaled - 0.5

    def voxelize_with_centroids(self, x: ME.TensorField):
        """Voxelize ``x`` and additionally return normalized point offsets.

        Args:
            x: input tensor field.

        Returns:
            A 5-tuple ``(sparse_tensor, normalized_points, points, counts,
            None)``; the trailing ``None`` is unpacked as ``h_embs`` by
            ``forward``.
        """
        # The parent implementation performs the identical voxelization and
        # point downsampling; only the normalization is specific to this class.
        voxels, centroids, counts, _ = super().voxelize_with_centroids(x)
        normalized = self.normalize_centroids(centroids, voxels.C, voxels.tensor_stride[0])

        return voxels, normalized, centroids, counts, None

    def forward(self, x: ME.TensorField):
        """Run the full network on a tensor field.

        Args:
            x: input tensor field.

        Returns:
            Whatever ``devoxelize_with_centroids`` returns for the final
            decoder output (in the base class: a feature matrix sliced onto ``x``).
        """
        act = self.relu

        def encode(stage, feats, norm_pts):
            # Thread the kq_map through the stage: the first block receives
            # None and each later block receives the map the previous one
            # returned. Stays None when the stage is empty.
            kq_map = None
            for layer in stage:
                feats, kq_map = layer(feats, norm_pts, kq_map)
            return feats, kq_map

        feats, norm_pts, pts, cnt, h_embs = self.voxelize_with_centroids(x)

        # Stem. The first layer's kq_map is discarded; the second layer's map
        # is the one reused by the last decoder level.
        feats, _ = self.attn0p1(feats, norm_pts)
        feats = act(self.bn0(feats))
        feats, kq_map = self.attn1p1(feats, norm_pts)
        feats = act(self.bn1(feats))

        # Encoder. Each push records the features, normalized points and
        # kq_map of the level that is about to be pooled away.
        skips = []
        encoder = (
            (self.pool1p1s2, self.block1, self.attn2p2, self.bn2),
            (self.pool2p2s2, self.block2, self.attn3p4, self.bn3),
            (self.pool3p4s2, self.block3, self.attn4p8, self.bn4),
        )
        for pool, stage, attn, norm in encoder:
            skips.append((feats, norm_pts, kq_map))
            feats, pts, cnt = pool(feats, pts, cnt)
            norm_pts = self.normalize_centroids(pts, feats.C, feats.tensor_stride[0])
            feats, kq_map = encode(stage, feats, norm_pts)
            # Only the features of the trailing attention layer are kept; the
            # kq_map saved for the decoder is the one from the block stack.
            feats = act(norm(attn(feats, norm_pts, kq_map)[0]))

        # Coarsest level: pooling and blocks only; its kq_map is not reused.
        skips.append((feats, norm_pts, kq_map))
        feats, pts, _ = self.pool4p8s2(feats, pts, cnt)
        norm_pts = self.normalize_centroids(pts, feats.C, feats.tensor_stride[0])
        feats, _ = encode(self.block4, feats, norm_pts)

        # Decoder: unpool, concatenate with the matching encoder output, and
        # reuse that level's normalized points and kq_map for every layer.
        decoder = (
            (self.attn5p8, self.bn5, self.block5),
            (self.attn6p4, self.bn6, self.block6),
            (self.attn7p2, self.bn7, self.block7),
            (self.attn8p1, self.bn8, self.block8),
        )
        for attn, norm, stage in decoder:
            skip, norm_pts, kq_map = skips.pop()
            feats = ME.cat(self.pooltr(feats), skip)
            feats = act(norm(attn(feats, norm_pts, kq_map)[0]))
            for layer in stage:
                feats = layer(feats, norm_pts, kq_map)[0]

        return self.devoxelize_with_centroids(feats, x, h_embs)