# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import MinkowskiEngine as ME
import MinkowskiEngine.MinkowskiFunctional as MF


class MinkUNet(ME.MinkowskiNetwork):
    """Small three-level sparse U-Net built on MinkowskiEngine.

    Encoder: three convolution + batch-norm stages with 8, 16 and 32 output
    channels; the second and third stages downsample with stride 2.
    Decoder: two transposed-convolution + batch-norm stages, each followed by
    channel concatenation with the encoder output of matching resolution,
    then a 1x1 convolution producing ``out_nchannel`` channels.

    Args:
        in_nchannel: number of input feature channels.
        out_nchannel: number of output channels (logits).
        D: spatial dimension passed to every Minkowski layer (default 3).
    """

    def __init__(self, in_nchannel, out_nchannel, D=3):
        super().__init__(D)

        def conv_bn(layer_cls, c_in, c_out, stride):
            # Children register as "0" (conv) and "1" (batch norm), so the
            # state_dict keys are blockX.0.* / blockX.1.*.
            return torch.nn.Sequential(
                layer_cls(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=3,
                    stride=stride,
                    dimension=D,
                ),
                ME.MinkowskiBatchNorm(c_out),
            )

        # Encoder.
        self.block1 = conv_bn(ME.MinkowskiConvolution, in_nchannel, 8, 1)
        self.block2 = conv_bn(ME.MinkowskiConvolution, 8, 16, 2)
        self.block3 = conv_bn(ME.MinkowskiConvolution, 16, 32, 2)

        # Decoder. Input widths account for the concatenated skip features:
        # 16 (upsampled) + 16 (block2) = 32, and 16 (upsampled) + 8 (block1) = 24.
        self.block3_tr = conv_bn(ME.MinkowskiConvolutionTranspose, 32, 16, 2)
        self.block2_tr = conv_bn(ME.MinkowskiConvolutionTranspose, 32, 16, 2)
        self.conv1_tr = ME.MinkowskiConvolution(
            in_channels=24,
            out_channels=out_nchannel,
            kernel_size=1,
            stride=1,
            dimension=D,
        )

    def forward(self, x):
        """Run the U-Net on raw points and return logits for the input field.

        Args:
            x: indexable pair. ``x[0]`` is the per-sample coordinate sequence
                handed to ``ME.utils.batched_coordinates`` (batched as
                float32). ``x[1]`` is the feature tensor; its device is also
                used for the ``TensorField``.

        Returns:
            Feature tensor (``.F``) of the network output sliced back onto the
            input ``TensorField``.
        """
        coords_list, feats = x[0], x[1]

        field = ME.TensorField(
            features=feats,
            coordinates=ME.utils.batched_coordinates(coords_list, dtype=torch.float32),
            quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
            minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
            device=feats.device,
        )

        # Encoder: keep the pre-activation outputs for the skip connections.
        skip1 = self.block1(field.sparse())
        skip2 = self.block2(MF.relu(skip1))
        bottom = self.block3(MF.relu(skip2))

        # Decoder: upsample, activate, then concatenate the matching skip.
        up = ME.cat(MF.relu(self.block3_tr(MF.relu(bottom))), skip2)
        up = ME.cat(MF.relu(self.block2_tr(up)), skip1)

        # No activation on the final layer: these are raw logits.
        return self.conv1_tr(up).slice(field).F


# import torch
# import torch.nn as nn
# import torch.optim as optim
# import torch.nn.functional as F
# import sparseconvnet as scn
# import time
# import os, sys
# import math
# import numpy as np
# class SparseConvUNet(torch.nn.Module):
#     def __init__(self, d_in=11, d_out=11, dimension=3, spatialSize=65, reps=1, m=32):
#         torch.nn.Module.__init__(self)
#         spatialSize=[spatialSize]*dimension
#         nPlanes=[m, 2*m, 3*m, 4*m, 5*m]
#         # nPlanes=[m, 2*m]
#         # self.sparseModel = scn.Sequential().add(
#         #    scn.InputLayer(dimension, spatialSize, mode=3)).add(
#         #    scn.SubmanifoldConvolution(dimension, d_in, m, filter_size=3, bias=False)).add(
#         #    scn.UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2,2])).add(
#         #    scn.BatchNormReLU(m)).add(
#         #    scn.OutputLayer(dimension))
#         self.in_layer = scn.InputLayer(dimension, spatialSize, mode=3)
#         self.in_conv = scn.SubmanifoldConvolution(dimension, d_in, m, filter_size=3, bias=False)
#         self.unet = scn.UNet(dimension, reps, nPlanes, residual_blocks=False, downsample=[2,2])
#         self.bn = scn.BatchNormReLU(m)
#         self.out_conv = scn.OutputLayer(dimension)
#         self.linear = torch.nn.Linear(m, d_out)
#     def forward(self,x):
#         # x=self.sparseModel(x)
#         # print(x, x[0].requires_grad, x[1].requires_grad)
#         x = self.in_layer(x)
#         # print(x)
#         x = self.in_conv(x)
#         # print(x)
#         x = self.unet(x)
#         # print(x)
#         x = self.bn(x)
#         # print(x)
#         x = self.out_conv(x)
#         # print(x)
#         x=self.linear(x)
#         return x


