import torchsparse
import torch
import torchsparse.nn as spnn
from torch import nn
from torchsparse import PointTensor, SparseTensor
import numpy as np 
from torch_scatter import scatter

from pcdet.models.model_utils.spvnas_utils import initial_voxelize, point_to_voxel, voxel_to_point
from pcdet.models.model_utils import graph_utils


# Only the backbone is exported; the conv / residual blocks defined below are
# internal building blocks of SPVCNN.
__all__ = ['SPVCNN']


class BasicConvolutionBlock(nn.Module):
    """Sparse 3D convolution -> batch norm -> in-place ReLU.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: convolution kernel size.
        stride: convolution stride (used with 2 in SPVCNN to downsample).
        dilation: convolution dilation.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        # Built in the same order as the Sequential indices (net.0/1/2), so
        # state_dict keys and parameter-init order stay stable.
        conv = spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride)
        norm = spnn.BatchNorm(outc)
        act = spnn.ReLU(True)
        self.net = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Apply conv -> BN -> ReLU to the sparse tensor ``x``."""
        return self.net(x)


class BasicDeconvolutionBlock(nn.Module):
    """Transposed sparse 3D convolution -> batch norm -> in-place ReLU.

    SPVCNN uses it with stride 2 as the first layer of each decoder stage.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self, inc, outc, ks=3, stride=1):
        super().__init__()
        layers = [
            spnn.Conv3d(inc, outc, kernel_size=ks, stride=stride, transposed=True),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply transposed conv -> BN -> ReLU to the sparse tensor ``x``."""
        upsampled = self.net(x)
        return upsampled


class ResidualBlock(nn.Module):
    """Two sparse convolutions with batch norm plus a skip connection.

    The main branch is conv(stride) -> BN -> ReLU -> conv(stride 1) -> BN.
    The shortcut is the identity when the channel count and stride are
    unchanged; otherwise it is a kernel-size-1 conv + BN projection. The sum
    of both branches goes through an in-place ReLU.

    Args:
        inc: number of input feature channels.
        outc: number of output feature channels.
        ks: kernel size of both main-branch convolutions.
        stride: stride of the first main-branch conv and of the projection.
        dilation: dilation of both main-branch convolutions.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        main_branch = [
            spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation, stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
            spnn.Conv3d(outc, outc, kernel_size=ks, dilation=dilation, stride=1),
            spnn.BatchNorm(outc),
        ]
        self.net = nn.Sequential(*main_branch)

        # The shapes of the two branches only differ when channels or stride change.
        needs_projection = inc != outc or stride != 1
        self.downsample = (
            nn.Sequential(
                spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride),
                spnn.BatchNorm(outc),
            )
            if needs_projection
            else nn.Identity()
        )

        self.relu = spnn.ReLU(True)

    def forward(self, x):
        """Return ReLU(main_branch(x) + shortcut(x))."""
        main = self.net(x)
        shortcut = self.downsample(x)
        return self.relu(main + shortcut)


class SPVCNN(nn.Module):
    """Sparse point-voxel U-Net backbone (SPVCNN) built on torchsparse.

    The input points are voxelized. A 4-stage sparse convolutional encoder and
    a 4-stage decoder with skip connections follow, and point-branch MLP
    features are fused in at three depths. ``forward`` writes the per-scale
    decoder features, the per-point features and the per-voxel labels into
    ``batch_dict``.

    Config keys read from ``model_cfg`` (defaults in parentheses):
        cr (1.0): channel-width multiplier applied to every stage.
        INPUT_CHANNELS (6): channels of the stem input; presumably
            3 (xyz) + the width of ``point_feat``, because forward
            concatenates the two.
        pres / vres (0.05 / 0.05): resolutions passed to ``initial_voxelize``.
        GRID_SIZE ([0.1, 0.1, 0.15]): voxel size used to convert integer
            voxel coordinates back to xyz positions.
        POINT_CLOUD_RANGE ([-75.2, -75.2, -2, 75.2, 75.2, 4]): the first
            three values are the xyz origin added after that conversion.
        DEBUG_VIS_DIR (None): if set, ground-truth labels are dumped as .ply
            files into this directory on every forward (debug only).
    """

    def __init__(self, model_cfg, runtime_cfg, **kwargs):
        super().__init__()
        self.model_cfg = model_cfg

        cr = self.model_cfg.get('cr', 1.0)
        cs = [int(cr * x) for x in [32, 32, 64, 128, 256, 256, 128, 96, 96]]
        input_channels = model_cfg.get('INPUT_CHANNELS', 6)

        self.pres = self.model_cfg.get('pres', 0.05)
        self.vres = self.model_cfg.get('vres', 0.05)

        # float64 tensors registered as non-persistent buffers. They follow
        # model.to(device)/.cuda() without being hard-wired to a GPU at
        # construction, and they stay out of state_dict, so existing
        # checkpoints load unchanged.
        grid_size = self.model_cfg.get('GRID_SIZE', [0.1, 0.1, 0.15])
        point_cloud_range = self.model_cfg.get('POINT_CLOUD_RANGE', [-75.2, -75.2, -2, 75.2, 75.2, 4])
        self.register_buffer(
            'grid_size', torch.tensor(np.asarray(grid_size, dtype=np.float64)), persistent=False)
        self.register_buffer(
            'point_cloud_range', torch.tensor(np.asarray(point_cloud_range, dtype=np.float64)),
            persistent=False)

        # 1-nearest-neighbour graph used to transfer point labels to voxels.
        self.graph = graph_utils.KNNGraph({}, dict(NUM_NEIGHBORS=1))

        self.stem = nn.Sequential(
            spnn.Conv3d(input_channels, cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]), spnn.ReLU(True),
            spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]), spnn.ReLU(True)
        )

        # Encoder: each stage halves resolution (ks=2, stride=2) and refines.
        self.stage1 = nn.Sequential(
            BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1),
            ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1),
            ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1),
        )

        self.stage2 = nn.Sequential(
            BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1),
            ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1),
            ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1),
        )

        self.stage3 = nn.Sequential(
            BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1),
            ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1),
            ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1),
        )

        self.stage4 = nn.Sequential(
            BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1),
            ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1),
            ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1),
        )

        # Decoder: [0] upsamples, [1] fuses with the concatenated encoder skip.
        self.up1 = nn.ModuleList([
            BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2),
            nn.Sequential(
                ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1, dilation=1),
                ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1),
            )
        ])

        self.up2 = nn.ModuleList([
            BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2),
            nn.Sequential(
                ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1, dilation=1),
                ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1),
            )
        ])

        self.up3 = nn.ModuleList([
            BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2),
            nn.Sequential(
                ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1, dilation=1),
                ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1),
            )
        ])

        self.up4 = nn.ModuleList([
            BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2),
            nn.Sequential(
                ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1, dilation=1),
                ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1),
            )
        ])

        # Point-branch MLPs added to the voxel->point features at three depths.
        self.point_transforms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(cs[0], cs[4]),
                nn.BatchNorm1d(cs[4]),
                nn.ReLU(True),
            ),
            nn.Sequential(
                nn.Linear(cs[4], cs[6]),
                nn.BatchNorm1d(cs[6]),
                nn.ReLU(True),
            ),
            nn.Sequential(
                nn.Linear(cs[6], cs[8]),
                nn.BatchNorm1d(cs[8]),
                nn.ReLU(True),
            )
        ])

        self.weight_initialization()
        self.dropout = nn.Dropout(0.3, True)
        self.num_point_features = cs[8]

    def weight_initialization(self):
        """Reset every BatchNorm1d to weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, batch_dict):
        """Run the backbone and store its outputs in ``batch_dict``.

        Reads:
            point_bxyz: column 0 is the batch index, columns 1: are xyz.
            point_feat: per-point features, concatenated after xyz.
            spv_coord: per-point voxel coordinates; the batch index is
                appended as the last column before building the SparseTensor.
            segmentation_label, instance_label: per-point labels that are
                transferred to the voxels of the finest decoder output.

        Writes:
            spv_unet_up_bcenter{1..4}: [batch_idx, xyz] per decoder voxel
                (4 = finest output y4, 1 = coarsest y1).
            spv_unet_up_feat{1..4}: the matching decoder voxel features.
            spv_point_fea: final per-point features.
            spv_unet_fea / spv_unet_indices: features / coordinates of y4.
            spv_point_bxyz: [batch_idx, xyz] of the y4 voxels.
            spv_segmentation_label / spv_instance_label: per-voxel labels
                (max over the linked input points), as int64.

        Returns:
            The same ``batch_dict``, updated in place.
        """
        point_bxyz = batch_dict['point_bxyz']
        point_xyz = point_bxyz[:, 1:]
        point_feat = batch_dict['point_feat']

        # Coordinates are laid out as (x, y, z, batch_idx).
        spv_coord = torch.cat((batch_dict['spv_coord'], point_bxyz[:, 0].unsqueeze(1)), 1)
        x = SparseTensor(torch.cat((point_xyz, point_feat), 1), spv_coord)
        # x: SparseTensor, z: PointTensor
        z = PointTensor(x.F, x.C.float())

        # Stem + encoder.
        x0 = initial_voxelize(z, self.pres, self.vres)
        x0 = self.stem(x0)
        z0 = voxel_to_point(x0, z, nearest=False)

        x1 = point_to_voxel(x0, z0)
        x1 = self.stage1(x1)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)
        z1 = voxel_to_point(x4, z0)
        z1.F = z1.F + self.point_transforms[0](z0.F)

        # Decoder, coarse half.
        y1 = point_to_voxel(x4, z1)
        y1.F = self.dropout(y1.F)
        y1 = self.up1[0](y1)
        y1 = torchsparse.cat([y1, x3])
        y1 = self.up1[1](y1)

        y2 = self.up2[0](y1)
        y2 = torchsparse.cat([y2, x2])
        y2 = self.up2[1](y2)
        z2 = voxel_to_point(y2, z1)
        z2.F = z2.F + self.point_transforms[1](z1.F)

        # Decoder, fine half.
        y3 = point_to_voxel(y2, z2)
        y3.F = self.dropout(y3.F)
        y3 = self.up3[0](y3)
        y3 = torchsparse.cat([y3, x1])
        y3 = self.up3[1](y3)

        y4 = self.up4[0](y3)
        y4 = torchsparse.cat([y4, x0])
        y4 = self.up4[1](y4)
        z3 = voxel_to_point(y4, z2)
        z3.F = z3.F + self.point_transforms[2](z2.F)

        self._export_multiscale_features(batch_dict, [y4, y3, y2, y1])
        voxel_bxyz, voxel_seg, voxel_ins = self._assign_voxel_labels(batch_dict, y4)

        batch_dict['spv_point_fea'] = z3.F
        batch_dict['spv_unet_fea'] = y4.F
        batch_dict['spv_unet_indices'] = y4.C

        batch_dict['spv_point_bxyz'] = voxel_bxyz
        batch_dict['spv_segmentation_label'] = voxel_seg
        batch_dict['spv_instance_label'] = voxel_ins

        debug_dir = self.model_cfg.get('DEBUG_VIS_DIR', None)
        if debug_dir:
            self._dump_debug_ply(batch_dict, debug_dir)

        return batch_dict

    # Tensor strides of the decoder outputs, finest (y4) to coarsest (y1).
    _DECODER_STRIDES = (1, 2, 4, 8)

    def _export_multiscale_features(self, batch_dict, x_convs):
        """Store [batch_idx, xyz] and features of each decoder output.

        ``x_convs`` is ordered finest to coarsest; entry i is written under the
        suffix ``4 - i`` and scaled by ``_DECODER_STRIDES[i]``.
        """
        device = x_convs[0].C.device
        base_voxel_size = self.grid_size.to(device=device, dtype=torch.float32)
        pc_min = self.point_cloud_range[0:3].to(device=device, dtype=torch.float32)

        for i, (x_conv, stride) in enumerate(zip(x_convs, self._DECODER_STRIDES)):
            # NOTE(review): if the installed torchsparse keeps strided
            # coordinates in input-grid units (multiples of the tensor
            # stride), multiplying by `stride` again over-scales y3/y2/y1
            # positions — confirm against the torchsparse version in use.
            voxel_size = base_voxel_size * float(stride)
            # No half-voxel offset is added, so for floor-based voxel indices
            # these are voxel corners rather than centers.
            voxel_corners = x_conv.C[:, 0:3] * voxel_size + pc_min
            batch_idx = x_conv.C[:, 3:4].to(voxel_corners.dtype)
            batch_dict[f'spv_unet_up_bcenter{4 - i}'] = torch.cat([batch_idx, voxel_corners], dim=-1)
            batch_dict[f'spv_unet_up_feat{4 - i}'] = x_conv.F

    def _assign_voxel_labels(self, batch_dict, y4):
        """Return [batch_idx, xyz] of the y4 voxels and their labels.

        Each voxel's segmentation / instance label is the max over the input
        points linked to it by ``self.graph``. Labels are returned as int64.
        """
        device = y4.C.device
        grid_size = self.grid_size.to(device)
        pc_min = self.point_cloud_range[:3].to(device)

        # float64, because the grid / range buffers are float64.
        voxel_xyz = y4.C[:, 0:3] * grid_size.unsqueeze(0) + pc_min.unsqueeze(0)
        voxel_bxyz = torch.cat((y4.C[:, 3:4].to(voxel_xyz.dtype), voxel_xyz), 1)

        # NOTE(review): only 'instance_label' is used here, even when an
        # 'instance_label_back' key is present — confirm this is intended.
        e_ref, e_query = self.graph(batch_dict['point_bxyz'], voxel_bxyz)
        num_voxels = voxel_bxyz.shape[0]
        voxel_seg = scatter(batch_dict['segmentation_label'][e_ref], e_query, dim=0,
                            dim_size=num_voxels, reduce='max')
        voxel_ins = scatter(batch_dict['instance_label'][e_ref], e_query, dim=0,
                            dim_size=num_voxels, reduce='max')
        return voxel_bxyz, voxel_seg.long(), voxel_ins.long()

    @staticmethod
    def _dump_debug_ply(batch_dict, out_path):
        """Write per-sample voxel- and point-level GT labels as .ply files (debug only)."""
        import os
        from ...utils.vis_utils import write_ply_color

        os.makedirs(out_path, exist_ok=True)
        # (bxyz key, seg key, instance key, filename suffix)
        sources = (
            ('spv_point_bxyz', 'spv_segmentation_label', 'spv_instance_label', ''),
            ('point_bxyz', 'segmentation_label', 'instance_label', '1'),
        )
        for i in range(batch_dict['batch_size']):
            name = batch_dict['frame_id'][i]
            for bxyz_key, seg_key, ins_key, suffix in sources:
                bs_mask = batch_dict[bxyz_key][:, 0] == i
                points_xyz = batch_dict[bxyz_key][:, 1:][bs_mask].detach().cpu().numpy()
                seg = batch_dict[seg_key].int()[bs_mask].detach().cpu().numpy()
                write_ply_color(points_xyz, seg, os.path.join(out_path, f'{name}_gt_seg{suffix}.ply'))
                # Instance ids are folded mod 22, presumably to fit the colour palette.
                ins = batch_dict[ins_key].int()[bs_mask].detach().cpu().numpy() % 22
                write_ply_color(points_xyz, ins, os.path.join(out_path, f'{name}_gt_ins{suffix}.ply'))
        print('saved')

