import torch
import torch.nn as nn
import spconv
import spconv.pytorch
import numpy as np
from spconv.pytorch import SparseSequential, SparseConv2d
from spconv.core import ConvAlgo

def replace_feature(out, new_features):
    """Attach ``new_features`` to sparse tensor ``out`` and return the result.

    When ``out`` exposes a ``replace_feature`` method (as spconv 2.x
    ``SparseConvTensor`` does), that method is used and its return value is
    passed through. Otherwise, presumably for older spconv versions, the
    ``features`` attribute of ``out`` is overwritten in place and ``out``
    itself is returned.

    Args:
        out: Sparse tensor-like object with a ``features`` attribute and,
            optionally, a ``replace_feature`` method.
        new_features: Feature tensor to attach.

    Returns:
        ``out.replace_feature(new_features)`` if available, else the mutated
        ``out``.
    """
    # hasattr is the idiomatic capability check; the previous
    # ``in out.__dir__()`` built a full attribute list on every call.
    if hasattr(out, "replace_feature"):
        return out.replace_feature(new_features)
    out.features = new_features
    return out

class SparseConvBlock(spconv.pytorch.SparseModule):
    """Sparse 2D convolution followed by BatchNorm1d and ReLU on the features.

    A submanifold convolution (``SubMConv2d``, stride fixed to 1) is used when
    ``stride == 1`` and ``use_subm`` is true; in every other case a regular
    ``SparseConv2d`` with the requested ``stride`` is used. Padding is
    ``kernel_size // 2`` in both cases.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        use_subm: Allow a submanifold conv when ``stride == 1``.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, use_subm=True, bias=False):
        super().__init__()
        pad = kernel_size // 2
        submanifold = use_subm and stride == 1
        if submanifold:
            conv = spconv.pytorch.SubMConv2d(
                in_channels, out_channels, kernel_size,
                padding=pad, stride=1, bias=bias,
            )
        else:
            conv = spconv.pytorch.SparseConv2d(
                in_channels, out_channels, kernel_size,
                padding=pad, stride=stride, bias=bias,
            )
        self.conv = conv

        # Normalization and activation operate on the (nnz, C) feature matrix.
        self.norm = nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01)
        self.act = nn.ReLU()

    def forward(self, x):
        """Apply conv, then BatchNorm1d, then ReLU to the sparse tensor ``x``."""
        result = self.conv(x)
        for layer in (self.norm, self.act):
            result = replace_feature(result, layer(result.features))
        return result
    

class SparseBasicBlock(spconv.pytorch.SparseModule):
    """Residual block built from two stride-1 submanifold sparse convolutions.

    Computes ``relu(bn2(conv2(block1(x))) + x)`` on the sparse features.
    ``block1`` is a ``SparseConvBlock`` with stride 1 and the default
    ``use_subm=True``, and ``conv2`` is a ``SubMConv2d``. The output therefore
    shares the active sites of ``x``, which the element-wise residual sum of
    ``features`` relies on.

    Args:
        channels: Number of input and output feature channels.
        kernel_size: Kernel size of both convolutions.
    """

    def __init__(self, channels, kernel_size):
        super().__init__()
        pad = kernel_size // 2
        # conv -> BN -> ReLU (submanifold, stride 1).
        self.block1 = SparseConvBlock(channels, channels, kernel_size, 1)
        # Second conv is explicitly pinned to ConvAlgo.Native (block1 uses
        # spconv's default algorithm). BN, residual add and ReLU follow in forward().
        self.conv2 = spconv.pytorch.SubMConv2d(
            channels, channels, kernel_size,
            padding=pad, stride=1, bias=False, algo=ConvAlgo.Native,
        )
        self.norm2 = nn.BatchNorm1d(channels, eps=1e-3, momentum=0.01)
        self.act2 = nn.ReLU()

    def forward(self, x):
        """Run the residual block on sparse tensor ``x``."""
        out = self.conv2(self.block1(x))
        summed = self.norm2(out.features) + x.features
        out = replace_feature(out, summed)
        return replace_feature(out, self.act2(out.features))
    
class SparseResNet(spconv.pytorch.SparseModule):
    """Sparse 2D ResNet-style backbone applied to a dense feature map.

    ``batch_dict['spatial_features']`` is converted into an spconv
    ``SparseConvTensor``. Only locations whose channel vector is not entirely
    zero are kept. The tensor then passes through one stage per entry of
    ``layer_nums``: a strided regular sparse conv block followed by
    ``layer_nums[i]`` residual ``SparseBasicBlock``s. A final 1x1 sparse
    conv + BN + ReLU mapping is applied, and the result is densified into
    ``batch_dict['spatial_features_2d']``.

    Args:
        model_cfg: Subscriptable config with keys ``'layer_strides'``,
            ``'num_filters'`` and ``'layer_nums'`` (equal-length sequences,
            one entry per stage) and ``'grid_size'`` (stored, not used here).
        input_channels: Channel count of the incoming feature map.
        output_channels: Channel count produced by the final mapping.
    """

    def __init__(self, model_cfg, input_channels, output_channels):
        super(SparseResNet, self).__init__()
        self.model_cfg = model_cfg
        self._layer_strides = model_cfg['layer_strides']
        self._num_filters = model_cfg['num_filters']
        self._layer_nums = model_cfg['layer_nums']
        self._grid_size = model_cfg['grid_size']
        self._input_channels = input_channels

        assert len(self._layer_strides) == len(self._layer_nums), \
            "layer_strides and layer_nums must have the same length"
        assert len(self._num_filters) == len(self._layer_nums), \
            "num_filters and layer_nums must have the same length"

        # One kernel size per stage. The previous fixed 4-element list raised
        # IndexError for configs with more than four stages.
        kernel_sizes = [3] * len(self._layer_nums)
        # Stage i consumes the output channels of stage i - 1.
        in_filters = [self._input_channels, *self._num_filters[:-1]]

        self.blocks = nn.ModuleList(
            self._make_layer(in_planes, planes, kernel, stride, num_blocks)
            for in_planes, planes, kernel, stride, num_blocks in zip(
                in_filters, self._num_filters, kernel_sizes,
                self._layer_strides, self._layer_nums)
        )

        # 1x1 projection from the last stage's width to output_channels.
        self.mapping = SparseSequential(
            SparseConv2d(self._num_filters[-1],
                         output_channels, 1, 1, bias=False),
            nn.BatchNorm1d(output_channels, eps=1e-3, momentum=0.01),
            nn.ReLU()
        )

    def _make_layer(self, inplanes, planes, kernel_size, stride, num_blocks):
        """Build one stage: a regular (possibly strided) sparse conv block
        mapping ``inplanes`` -> ``planes``, followed by ``num_blocks``
        submanifold residual blocks at width ``planes``."""
        layers = [SparseConvBlock(inplanes, planes, kernel_size=kernel_size,
                                  stride=stride, use_subm=False)]
        layers.extend(SparseBasicBlock(planes, kernel_size=kernel_size)
                      for _ in range(num_blocks))
        return spconv.pytorch.SparseSequential(*layers)

    def forward(self, batch_dict):
        """Run the backbone on ``batch_dict['spatial_features']``.

        Assumes ``spatial_features`` is laid out (B, C, H, W), as implied by
        the channels-last permute below. Writes the dense output of the
        final mapping to ``batch_dict['spatial_features_2d']`` and returns
        ``batch_dict``.
        """
        # Channels-last, so to_sparse(3) treats (batch, y, x) as the sparse
        # dimensions and each channel vector as the dense value at a site.
        spatial_features = batch_dict['spatial_features'].permute(0, 2, 3, 1)
        sparse_features = spatial_features.to_sparse(3)

        # (3, nnz) -> (nnz, 3) rows of [batch_idx, y, x] as int32 for spconv.
        indices = sparse_features.indices().permute(1, 0).contiguous().int()
        # values() is already (nnz, C). The former `.view(-1, 64)` silently
        # assumed C == 64 and broke for any other channel count.
        features = sparse_features.values()
        batch_size = spatial_features.shape[0]

        x = spconv.pytorch.SparseConvTensor(
            features, indices, spatial_features.shape[1:3], batch_size)
        for block in self.blocks:
            x = block(x)
        x = self.mapping(x)

        batch_dict['spatial_features_2d'] = x.dense()
        return batch_dict
        
