import torch.nn as nn
import torchsparse
import torchsparse.nn as spnn
import torch 
import numpy as np 

from torch import nn
from torchsparse import PointTensor, SparseTensor
import numpy as np
from torch_scatter import scatter

from pcdet.models.model_utils.spvnas_utils import initial_voxelize, point_to_voxel, voxel_to_point
from pcdet.models.model_utils import graph_utils



# Star-imports of this module expose only the MinkUNet backbone; the block
# classes defined in this file are not listed here.
__all__ = ['MinkUNet']


class BasicConvolutionBlock(nn.Module):
    """Sparse 3D convolution followed by batch normalisation and ReLU.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: kernel size forwarded to ``spnn.Conv3d``.
        stride: stride forwarded to ``spnn.Conv3d``.
        dilation: dilation forwarded to ``spnn.Conv3d``.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        conv = spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride)
        norm = spnn.BatchNorm(outc)
        # ``True`` is passed positionally; presumably the in-place flag.
        activation = spnn.ReLU(True)
        # Positional construction keeps the sub-module names "0", "1", "2".
        self.net = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> ReLU and return the result."""
        return self.net(x)


class BasicDeconvolutionBlock(nn.Module):
    """Transposed sparse 3D convolution followed by batch norm and ReLU.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: kernel size forwarded to ``spnn.Conv3d``.
        stride: stride forwarded to ``spnn.Conv3d``.
    """

    def __init__(self, inc, outc, ks=3, stride=1):
        super().__init__()
        deconv = spnn.Conv3d(inc, outc, kernel_size=ks, stride=stride, transposed=True)
        norm = spnn.BatchNorm(outc)
        # ``True`` is passed positionally; presumably the in-place flag.
        activation = spnn.ReLU(True)
        # Positional construction keeps the sub-module names "0", "1", "2".
        self.net = nn.Sequential(deconv, norm, activation)

    def forward(self, x):
        """Run ``x`` through deconv -> batch norm -> ReLU and return the result."""
        upsampled = self.net(x)
        return upsampled


class ResidualBlock(nn.Module):
    """Two sparse convolutions with a shortcut branch added before the ReLU.

    The main branch is conv -> BN -> ReLU -> conv -> BN. The shortcut is an
    empty ``nn.Sequential`` when the channel count and resolution are
    unchanged, otherwise a 1x1x1 convolution plus batch norm that matches the
    main branch's channels and stride.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: kernel size of both main-branch convolutions.
        stride: stride of the first main-branch convolution and of the
            projection shortcut.
        dilation: dilation of both main-branch convolutions.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()

        main_layers = [
            spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
            # The second convolution never strides; only the first one may
            # change the resolution.
            spnn.Conv3d(outc, outc, kernel_size=ks, dilation=dilation, stride=1),
            spnn.BatchNorm(outc),
        ]
        self.net = nn.Sequential(*main_layers)

        needs_projection = inc != outc or stride != 1
        if needs_projection:
            self.downsample = nn.Sequential(
                spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride),
                spnn.BatchNorm(outc),
            )
        else:
            self.downsample = nn.Sequential()

        self.relu = spnn.ReLU(True)

    def forward(self, x):
        """Return ``relu(net(x) + downsample(x))``."""
        main = self.net(x)
        shortcut = self.downsample(x)
        return self.relu(main + shortcut)


class MinkUNet(nn.Module):
    """Sparse 3D U-Net backbone built from torchsparse convolutions.

    A two-convolution stem is followed by four stride-2 encoder stages and
    four stride-2 transposed-convolution decoder stages. Each decoder stage
    concatenates its upsampled features with the encoder output of the
    matching level before two residual blocks.

    ``forward`` reads point data from ``batch_dict`` and writes the decoder
    features and voxel positions back into it. There is no classification
    head; ``num_point_features`` gives the channel count of the final decoder
    features.

    Args:
        model_cfg: mapping-like config (must support ``.get``). Keys read:
            ``INPUT_CHANNELS`` (default 6), ``run_up`` (default True),
            ``GRID_SIZE`` (default ``[0.1, 0.1, 0.15]``) and
            ``POINT_CLOUD_RANGE``
            (default ``[-75.2, -75.2, -2, 75.2, 75.2, 4]``).
        runtime_cfg: accepted for interface compatibility; not used here.
        **kwargs: ``cr`` (default 1.0) scales every channel width.
    """

    def __init__(self, model_cfg, runtime_cfg, **kwargs):
        super().__init__()
        self.model_cfg = model_cfg

        cr = kwargs.get('cr', 1.0)
        cs = [int(cr * x) for x in (32, 32, 64, 128, 256, 256, 128, 96, 96)]
        input_channels = model_cfg.get('INPUT_CHANNELS', 6)

        self.run_up = model_cfg.get('run_up', True)
        grid_size = model_cfg.get('GRID_SIZE', [0.1, 0.1, 0.15])
        point_cloud_range = model_cfg.get(
            'POINT_CLOUD_RANGE', [-75.2, -75.2, -2, 75.2, 75.2, 4])
        # NOTE: these are plain tensor attributes, not registered buffers, so
        # they are absent from the state_dict and are not moved by
        # ``Module.to``; ``forward`` moves them to the coordinate device
        # itself. They are created on the current CUDA device, so building
        # this module requires CUDA.
        self.grid_size = torch.from_numpy(np.array(grid_size)).cuda()
        self.point_cloud_range = torch.from_numpy(np.array(point_cloud_range)).cuda()

        # Only referenced by the currently disabled segmentation-label
        # transfer (see ``forward``); kept so the attribute stays available.
        self.graph = graph_utils.KNNGraph({}, dict(NUM_NEIGHBORS=1))

        self.stem = nn.Sequential(
            spnn.Conv3d(input_channels, cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]), spnn.ReLU(True),
            spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]), spnn.ReLU(True))

        # Encoder: every stage halves the resolution.
        self.stage1 = self._encoder_stage(cs[0], cs[1])
        self.stage2 = self._encoder_stage(cs[1], cs[2])
        self.stage3 = self._encoder_stage(cs[2], cs[3])
        self.stage4 = self._encoder_stage(cs[3], cs[4])

        # Decoder: every stage upsamples by 2 and fuses the encoder output of
        # the same resolution (x3, x2, x1, x0 respectively).
        self.up1 = self._decoder_stage(cs[4], cs[3], cs[5])
        self.up2 = self._decoder_stage(cs[5], cs[2], cs[6])
        self.up3 = self._decoder_stage(cs[6], cs[1], cs[7])
        self.up4 = self._decoder_stage(cs[7], cs[0], cs[8])

        self.weight_initialization()
        # Not applied anywhere in ``forward``; kept as an attribute.
        self.dropout = nn.Dropout(0.3, True)
        self.num_point_features = cs[8]

    @staticmethod
    def _encoder_stage(inc, outc):
        """Build a stride-2 downsampling conv followed by two residual blocks."""
        return nn.Sequential(
            BasicConvolutionBlock(inc, inc, ks=2, stride=2, dilation=1),
            ResidualBlock(inc, outc, ks=3, stride=1, dilation=1),
            ResidualBlock(outc, outc, ks=3, stride=1, dilation=1),
        )

    @staticmethod
    def _decoder_stage(inc, skipc, outc):
        """Build ``[upsampling deconv, residual blocks applied after the skip concat]``."""
        return nn.ModuleList([
            BasicDeconvolutionBlock(inc, outc, ks=2, stride=2),
            nn.Sequential(
                ResidualBlock(outc + skipc, outc, ks=3, stride=1, dilation=1),
                ResidualBlock(outc, outc, ks=3, stride=1, dilation=1),
            ),
        ])

    @staticmethod
    def _decode(stage, x, skip):
        """Upsample ``x``, concatenate ``skip`` and apply the stage's residual blocks."""
        upsampled = stage[0](x)
        fused = torchsparse.cat([upsampled, skip])
        return stage[1](fused)

    @staticmethod
    def _voxel_to_world(coords, voxel_size, origin):
        """Convert voxel coordinates into ``[batch, x, y, z]`` metric rows.

        Args:
            coords: ``(N, 4)`` tensor whose first three columns are voxel
                indices and whose fourth column is the batch index.
            voxel_size: ``(3,)`` tensor, metric size of one voxel step.
            origin: ``(3,)`` tensor, metric position of voxel index 0.

        Returns:
            ``(N, 4)`` tensor: the batch index followed by
            ``index * voxel_size + origin``. No half-voxel offset is added,
            so the position is the voxel's minimum corner, not its centre.
        """
        xyz = coords[:, 0:3] * voxel_size + origin
        return torch.cat([coords[:, 3:4], xyz], dim=-1)

    def weight_initialization(self):
        """Set weight to 1 and bias to 0 on every ``nn.BatchNorm1d`` sub-module.

        ``spnn.BatchNorm`` layers are only affected if that class derives
        from ``nn.BatchNorm1d`` (TODO confirm against the installed
        torchsparse).
        """
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, batch_dict):
        """Run the U-Net and store its outputs in ``batch_dict``.

        Reads:
            ``spv_point_cloud``: column 0 is used as the batch index, the
                remaining columns as per-point input features (presumably
                x, y, z).
            ``spv_point_feat``: extra per-point features concatenated after
                the coordinates.
            ``spv_coord``: voxel coordinates of the points; the batch index
                is appended as the last column.

        Writes:
            ``spv_unet_up_feat{k}`` / ``spv_unet_up_bcenter{k}`` for
                k = 1..4: features and float32 ``[batch, x, y, z]`` positions
                of decoder output ``y{k}``.
            ``spv_unet_fea`` / ``spv_unet_indices``: features and raw
                coordinates of the final decoder output.
            ``spv_point_bxyz``: ``[batch, x, y, z]`` positions of the final
                decoder output, in the dtype of the stored grid parameters
                (float64 with the default config), unlike the float32
                ``bcenter`` entries.

        Despite the "center" in the key names, all positions are voxel
        minimum corners.

        Returns:
            The same ``batch_dict`` object, updated in place.
        """
        point_xyz = batch_dict['spv_point_cloud'][:, 1:]
        point_feat = batch_dict['spv_point_feat']

        spv_coord = torch.cat(
            (batch_dict['spv_coord'], batch_dict['spv_point_cloud'][:, 0].unsqueeze(1)), 1)
        x = SparseTensor(torch.cat((point_xyz, point_feat), 1), spv_coord)
        x.C = x.C.int()

        x0 = self.stem(x)
        x1 = self.stage1(x0)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)

        y1 = self._decode(self.up1, x4, x3)
        y2 = self._decode(self.up2, y1, x2)
        y3 = self._decode(self.up3, y2, x1)
        y4 = self._decode(self.up4, y3, x0)

        # Convert the grid parameters once. ``Tensor.to`` avoids the
        # copy-construct warning that ``torch.tensor(existing_tensor)`` raises
        # and is a no-op when the tensors already live on the right device.
        device = y4.C.device
        grid_size = self.grid_size.to(device)
        pc_min = self.point_cloud_range[0:3].to(device)
        grid_size_f32 = grid_size.float()
        pc_min_f32 = pc_min.float()

        # (key suffix, stride relative to the input grid, decoder output)
        # NOTE(review): scaling by grid_size * stride assumes the installed
        # torchsparse stores coarse-level coordinates in units of that
        # level's stride — confirm.
        decoder_levels = ((4, 1, y4), (3, 2, y3), (2, 4, y2), (1, 8, y1))
        for level, stride, y in decoder_levels:
            batch_dict[f'spv_unet_up_bcenter{level}'] = self._voxel_to_world(
                y.C, grid_size_f32 * stride, pc_min_f32)
            batch_dict[f'spv_unet_up_feat{level}'] = y.F

        # Deliberately not cast to float32 so the output dtype stays as before.
        y4_voxel_centers = self._voxel_to_world(y4.C, grid_size, pc_min)

        # Per-voxel segmentation labels (nearest-point lookup through
        # ``self.graph`` followed by a scatter-max into
        # ``spv_segmentation_label``) are currently disabled.
        batch_dict['spv_unet_fea'] = y4.F
        batch_dict['spv_unet_indices'] = y4.C
        batch_dict['spv_point_bxyz'] = y4_voxel_centers
        return batch_dict
