import torch.nn as nn
import MinkowskiEngine as ME
from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
from examples.resnet import ResNetBase
import torch
import math
import torch.nn.functional as F
import numpy as np
import random
import lossFunction
import torch
from torch import nn

class GMMNnetwork(nn.Module):
    """Generator mapping a (class embedding, noise) pair to a synthetic feature.

    The input is ``cat(embd, noise)`` of width
    ``noise_dim + embed_dim + embed_feature_size``. When ``hidden_size`` is
    truthy it passes through Linear -> LeakyReLU(0.2) -> Dropout(0.5) ->
    Linear; otherwise through a single Linear. All Linear layers of the main
    path get Xavier-uniform weights and a constant 0.01 bias.

    Args:
        noise_dim: width of the noise vector.
        embed_dim: width of the class embedding.
        hidden_size: hidden width; falsy (e.g. 0) means no hidden layer.
        feature_dim: width of the generated feature.
        embed_feature_size: extra width added to ``embed_dim``.
        semantic_reconstruction: if True, ``forward`` also returns a linear
            reconstruction of the generator input from the generated feature.
    """

    def __init__(
            self,
            noise_dim,
            embed_dim,
            hidden_size,
            feature_dim,
            embed_feature_size=0,
            semantic_reconstruction=False,
    ):
        super().__init__()
        input_width = noise_dim + embed_dim + embed_feature_size

        if not hidden_size:
            self.model = nn.Linear(input_width, feature_dim)
        else:
            self.model = nn.Sequential(
                nn.Linear(input_width, hidden_size),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(p=0.5),
                nn.Linear(hidden_size, feature_dim),
            )
        self.model.apply(self._init_linear)

        self.semantic_reconstruction = semantic_reconstruction
        if semantic_reconstruction:
            # Maps a generated feature back to the (noise + embedding) input width.
            self.semantic_reconstruction_layer = nn.Linear(feature_dim, input_width)

    @staticmethod
    def _init_linear(layer):
        """Xavier-uniform weights and 0.01 bias for every ``nn.Linear``."""
        if not isinstance(layer, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(layer.weight)
        layer.bias.data.fill_(0.01)

    def forward(self, embd, noise):
        """Generate features from ``embd`` and ``noise`` (concatenated on dim 1).

        Returns the generated features, or ``(features, reconstruction)`` when
        semantic reconstruction is enabled.
        """
        generated = self.model(torch.cat((embd, noise), 1))
        if not self.semantic_reconstruction:
            return generated
        return generated, self.semantic_reconstruction_layer(generated)

# NOTE: the two triple-quoted strings below are inert string literals that
# archive earlier URS experiment variants together with their recorded
# scores. They are never executed; the URS class defined after them (outside
# any string) is the live one.
'''
# 50.68%
class URS(nn.Module):
    def __init__(self):
        super(URS, self).__init__()

        self.t = 1
        # self.classifier = nn.Parameter(torch.zeros(size=(32, 5)))
        # nn.init.kaiming_uniform_(self.classifier.data, a=math.sqrt(5))
        self.classifier = nn.Linear(96, 5)

        self.v = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.v.data, a=math.sqrt(5))
        self.k = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.k.data, a=math.sqrt(5))
        self.q = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.q.data, a=math.sqrt(5))

        self.base = nn.Parameter(torch.zeros(size=(64, 96)))
        nn.init.kaiming_uniform_(self.base.data, a=math.sqrt(5))

    def forward(self, M_data, S_data):
        # S_data : B_s*N, 96
        # M_data : B_m, 512
        # output: B_s*N, 2
        # v = torch.matmul(S_data, self.v)
        # k = torch.matmul(S_data, self.k)
        # q = torch.matmul(self.base, self.q)

        predictions = self.classifier(S_data)
        attentions = torch.matmul(torch.matmul(S_data, self.k), torch.matmul(self.base, self.q).permute(1, 0))
        attentions = F.softmax(attentions / self.t, dim=1)
        representations = torch.matmul(attentions, self.base)
        similarities = torch.sigmoid(torch.sum(torch.matmul(representations, self.v) * torch.matmul(S_data, self.v), dim=1))
        # similarities = 0
        return predictions, similarities, self.classifier

#  baseline 48.70%
# 2_1 50.93%
# 2_0 add non_overlap broken model 50.07%
 

class URS(nn.Module):
    def __init__(self):
        super(URS, self).__init__()
        self.t = 1
        # self.classifier = nn.Parameter(torch.zeros(size=(32, 5)))
        # nn.init.kaiming_uniform_(self.classifier.data, a=math.sqrt(5))
        self.classifier = nn.Linear(96, 5)


        self.v = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.v.data, a=math.sqrt(5))
        self.k = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.k.data, a=math.sqrt(5))
        self.q = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.q.data, a=math.sqrt(5))

        self.base = nn.Parameter(torch.zeros(size=(128, 96)))
        nn.init.kaiming_uniform_(self.base.data, a=math.sqrt(5))

    def forward(self, M_data, S_data):
        # S_data : B_s*N, 96
        # M_data : B_m, 512
        # output: B_s*N, 2
        # v = torch.matmul(S_data, self.v)
        # k = torch.matmul(S_data, self.k)
        # q = torch.matmul(self.base, self.q)

        predictions = self.classifier(S_data)
        attentions = torch.matmul(torch.matmul(S_data, self.k), torch.matmul(self.base, self.q).permute(1, 0))
        attentions = F.softmax(attentions / self.t, dim=1)
        representations = torch.matmul(attentions, self.base)
        similarities = torch.sigmoid(torch.sum(torch.matmul(representations, self.v) * torch.matmul(S_data, self.v), dim=1))
        # similarities = 0
        return predictions, similarities, self.classifier
'''
'''
#2_3 52.08%
# 2_2 add non_overlap broken model
 
class URS(nn.Module):
    def __init__(self):
        super(URS, self).__init__()

        self.t = 1
        # self.classifier = nn.Parameter(torch.zeros(size=(32, 5)))
        # nn.init.kaiming_uniform_(self.classifier.data, a=math.sqrt(5))
        self.classifier = nn.Linear(96, 5)

        self.v = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.v.data, a=math.sqrt(5))
        self.k = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.k.data, a=math.sqrt(5))
        self.q = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.q.data, a=math.sqrt(5))
        self.base = nn.Parameter(torch.zeros(size=(128, 96)))
        nn.init.kaiming_uniform_(self.base.data, a=math.sqrt(5))

    def forward(self, M_data, S_data):
        # S_data : B_s*N, 96
        # M_data : B_m, 512
        # output: B_s*N, 2
        # v = torch.matmul(S_data, self.v)
        # k = torch.matmul(S_data, self.k)
        # q = torch.matmul(self.base, self.q)

        predictions = self.classifier(S_data)
        attentions = torch.matmul(S_data, self.base.permute(1, 0))
        attentions = F.softmax(attentions / self.t, dim=1)
        representations = torch.matmul(attentions, self.base)
        similarities = torch.sigmoid(torch.sum(representations * S_data, dim=1))
        # similarities = 0
        return predictions, similarities, self.classifier

#1_0 52.07 %
# 1_0 52.05 %
#1_3: add non_overlap broken model 53.07%
class URS(nn.Module):
    def __init__(self):
        super(URS, self).__init__()

        self.t = 1
        # self.classifier = nn.Parameter(torch.zeros(size=(32, 5)))
        # nn.init.kaiming_uniform_(self.classifier.data, a=math.sqrt(5))
        self.classifier = nn.Linear(96, 5)

        self.v = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.v.data, a=math.sqrt(5))
        self.k = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.k.data, a=math.sqrt(5))
        self.q = nn.Parameter(torch.zeros(size=(96, 32)))
        nn.init.kaiming_uniform_(self.q.data, a=math.sqrt(5))

        self.base = nn.Parameter(torch.zeros(size=(128, 96)))
        nn.init.kaiming_uniform_(self.base.data, a=math.sqrt(5))

    def forward(self, M_data, S_data):
        # S_data : B_s*N, 96
        # M_data : B_m, 512
        # output: B_s*N, 2
        # v = torch.matmul(S_data, self.v)
        # k = torch.matmul(S_data, self.k)
        # q = torch.matmul(self.base, self.q)

        # predictions = self.classifier(S_data)
        attentions = torch.matmul(S_data, self.base.permute(1, 0))
        attentions = F.softmax(attentions / self.t, dim=1)
        representations = torch.matmul(attentions, self.base)
        predictions = self.classifier(representations)
        # similarities = torch.sigmoid(torch.sum(representations * S_data, dim=1))
        similarities = 0
        return predictions, similarities, self.classifier

'''
# Experiment log for this variant:
#   1_1: 51.63% / 50.87%
#   1_2 (with the non-overlap broken model): 54.06%
class URS(nn.Module):
    """Prototype-attention point classifier.

    Each input feature attends over 128 learnable 96-d prototypes
    (``self.base``) through low-rank (16-d) key/query projections. The
    attention-weighted mixture of prototypes is classified into 5 classes.
    ``self.v`` is allocated but not used by ``forward``.
    """

    def __init__(self):
        super(URS, self).__init__()

        self.t = 1  # softmax temperature
        self.classifier = nn.Linear(96, 5)

        # Creation order matters for seeded initialization; keep v, k, q, base.
        self.v = self._kaiming_param(96, 16)
        self.k = self._kaiming_param(96, 16)
        self.q = self._kaiming_param(96, 16)
        self.base = self._kaiming_param(128, 96)

    @staticmethod
    def _kaiming_param(rows, cols):
        """Return a (rows, cols) Parameter with Kaiming-uniform init (a=sqrt(5))."""
        param = nn.Parameter(torch.zeros(size=(rows, cols)))
        nn.init.kaiming_uniform_(param.data, a=math.sqrt(5))
        return param

    def calculate_loss(self, predictions, labels):
        """Cross-entropy over entries whose label is not -100."""
        keep = labels != -100
        return F.cross_entropy(predictions[keep], labels[keep])

    def forward(self, M_data, S_data, labels):
        """Classify ``S_data`` rows via prototype attention.

        Args:
            M_data: unused; kept for interface compatibility.
            S_data: features of width 96 (rows are classified independently).
            labels: per-row class ids, -100 for ignored rows.

        Returns:
            ``(predictions, loss)`` with predictions of shape (rows, 5).
        """
        keys = S_data @ self.k
        queries = self.base @ self.q
        weights = F.softmax((keys @ queries.t()) / self.t, dim=1)
        mixed = weights @ self.base
        predictions = self.classifier(mixed)
        return predictions, self.calculate_loss(predictions, labels)

class global_encoder(ME.MinkowskiNetwork):
    """Sparse-conv encoder that scores a whole point cloud with one sigmoid output.

    Four conv + max-pool stages are followed by global max pooling and a
    linear + sigmoid head. The result has one row per batch element, as
    produced by MinkowskiEngine's global pooling.

    Args:
        in_channel: number of input feature channels.
        out_channel: accepted for interface compatibility; not used.
        channels: output widths of the four conv stages.
        D: spatial dimension of the sparse tensors.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        channels=(96, 96, 96, 96),
        D=3,
    ):
        ME.MinkowskiNetwork.__init__(self, D)

        self.network_initialization(
            in_channel,
            out_channel,
            channels=channels,
            kernel_size=3,
            D=D,
        )
        self.weight_initialization()

    def get_mlp_block(self, in_channel, out_channel):
        """Point-wise linear (no bias) -> batch norm -> LeakyReLU."""
        return nn.Sequential(
            ME.MinkowskiLinear(in_channel, out_channel, bias=False),
            ME.MinkowskiBatchNorm(out_channel),
            ME.MinkowskiLeakyReLU(),
        )

    def get_conv_block(self, in_channel, out_channel, kernel_size, stride):
        """Sparse convolution -> batch norm -> LeakyReLU."""
        return nn.Sequential(
            ME.MinkowskiConvolution(
                in_channel,
                out_channel,
                kernel_size=kernel_size,
                stride=stride,
                dimension=self.D,
            ),
            ME.MinkowskiBatchNorm(out_channel),
            ME.MinkowskiLeakyReLU(),
        )

    def network_initialization(
        self,
        in_channel,
        out_channel,
        channels,
        kernel_size,
        D=3,
    ):
        """Create all sub-modules (order kept stable for checkpoint compatibility)."""
        # Not used by forward(); kept so existing state_dicts still load.
        self.mlp1 = self.get_mlp_block(in_channel, channels[0])
        self.conv1 = self.get_conv_block(
            in_channel,
            channels[0],
            kernel_size=kernel_size,
            stride=1,
        )
        self.conv2 = self.get_conv_block(
            channels[0],
            channels[1],
            kernel_size=kernel_size,
            stride=1,
        )
        self.conv3 = self.get_conv_block(
            channels[1],
            channels[2],
            kernel_size=kernel_size,
            stride=1,
        )
        self.conv4 = self.get_conv_block(
            channels[2],
            channels[3],
            kernel_size=kernel_size,
            stride=1,
        )

        self.pool = ME.MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)

        # The head consumes the globally max-pooled output of conv4, whose
        # width is channels[3]. This was previously hard-coded to 96, which
        # only matched the default channels.
        self.classifier = nn.Sequential(
            nn.Linear(channels[3], 1),
            nn.Sigmoid(),
        )

        self.global_max_pool = ME.MinkowskiGlobalMaxPooling()
        # Not used by forward(); kept so existing state_dicts still load.
        self.global_avg_pool = ME.MinkowskiGlobalAvgPooling()
        self.final = nn.Sequential(
            self.get_mlp_block(channels[3] * 2, 96),
        )

    def weight_initialization(self):
        """Kaiming-normal (fan_out) conv kernels; batch norm weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu")

            if isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)

    def forward(self, x):
        """Return sigmoid scores of shape (batch, 1) for the input field ``x``."""
        y = x.sparse()
        # Each stage is conv followed by a stride-2 max pool.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            y = self.pool(conv(y))

        pooled = self.global_max_pool(y)
        return self.classifier(pooled.F)

class MinkowskiFCNN(ME.MinkowskiNetwork):
    """Multi-scale sparse-conv encoder returning one 96-d vector per sample.

    A point-wise MLP lifts the input features. Four conv + max-pool stages then
    build a feature pyramid. Each pyramid level is sliced back onto the input
    field and concatenated. Finally ``conv5`` and global max/avg pooling reduce
    the result to one vector, which is passed through ``final``.

    ``out_channel`` is accepted for interface compatibility but not used.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        embedding_channel=512,
        channels=(32, 48, 64, 96, 128),
        D=3,
    ):
        ME.MinkowskiNetwork.__init__(self, D)
        self.network_initialization(
            in_channel,
            out_channel,
            channels=channels,
            embedding_channel=embedding_channel,
            kernel_size=3,
            D=D,
        )
        self.weight_initialization()

    def get_mlp_block(self, in_channel, out_channel):
        """Point-wise linear (no bias) -> batch norm -> LeakyReLU."""
        layers = [
            ME.MinkowskiLinear(in_channel, out_channel, bias=False),
            ME.MinkowskiBatchNorm(out_channel),
            ME.MinkowskiLeakyReLU(),
        ]
        return nn.Sequential(*layers)

    def get_conv_block(self, in_channel, out_channel, kernel_size, stride):
        """Sparse convolution -> batch norm -> LeakyReLU."""
        conv = ME.MinkowskiConvolution(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            dimension=self.D,
        )
        return nn.Sequential(conv, ME.MinkowskiBatchNorm(out_channel), ME.MinkowskiLeakyReLU())

    def network_initialization(
        self,
        in_channel,
        out_channel,
        channels,
        embedding_channel,
        kernel_size,
        D=3,
    ):
        """Create all sub-modules; creation order is kept stable."""
        self.mlp1 = self.get_mlp_block(in_channel, channels[0])

        # Pyramid stages: channels[i] -> channels[i + 1] with strides 1, 2, 2, 2.
        self.conv1, self.conv2, self.conv3, self.conv4 = (
            self.get_conv_block(c_in, c_out, kernel_size=kernel_size, stride=step)
            for c_in, c_out, step in zip(channels[:4], channels[1:5], (1, 2, 2, 2))
        )

        # conv5 consumes the concatenation of all four pyramid levels and
        # widens it to embedding_channel over three stride-2 convs.
        widths = (
            channels[1] + channels[2] + channels[3] + channels[4],
            embedding_channel // 4,
            embedding_channel // 2,
            embedding_channel,
        )
        self.conv5 = nn.Sequential(*[
            self.get_conv_block(w_in, w_out, kernel_size=3, stride=2)
            for w_in, w_out in zip(widths, widths[1:])
        ])

        self.pool = ME.MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)

        self.global_max_pool = ME.MinkowskiGlobalMaxPooling()
        self.global_avg_pool = ME.MinkowskiGlobalAvgPooling()

        # Max- and avg-pooled vectors are concatenated, hence the * 2.
        self.final = nn.Sequential(
            self.get_mlp_block(embedding_channel * 2, 96),
        )

    def weight_initialization(self):
        """Kaiming-normal (fan_out) conv kernels; batch norm weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(module.kernel, mode="fan_out", nonlinearity="relu")
            elif isinstance(module, ME.MinkowskiBatchNorm):
                nn.init.constant_(module.bn.weight, 1)
                nn.init.constant_(module.bn.bias, 0)

    def forward(self, x: ME.TensorField):
        """Encode the input field into a dense (batch, 96) feature matrix."""
        x = self.mlp1(x)
        y = x.sparse()

        levels = []
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            y = self.pool(stage(y))
            levels.append(y)

        # Bring every pyramid level back to the input points and stack channels.
        x = ME.cat(*[level.slice(x) for level in levels])

        y = self.conv5(x.sparse())
        return self.final(ME.cat(self.global_max_pool(y), self.global_avg_pool(y))).F

class MinkUNetBase(ResNetBase):
    """MinkUNet backbone with word-embedding anchors and a GMMN feature generator.

    The encoder/decoder layers are created in ``network_initialization``.
    NOTE(review): this is presumably invoked from ``ResNetBase.__init__``
    (defined in ``examples.resnet``, not visible here); confirm.

    On top of the backbone this class adds:

    * ``generator_b``: projects class word embeddings (width 300) to 96-d
      anchors that serve as classifier weights (see ``forward_classifier``).
    * ``generator``: a ``GMMNnetwork`` trained with a GMMN loss to imitate
      backbone features per class. During training it synthesizes features
      for ``UNSEEN_CLASSES``.

    Subclasses must set ``BLOCK`` and usually ``LAYERS`` / ``PLANES``.
    """

    BLOCK = None
    DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1)
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
    PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
    INIT_DIM = 32
    OUT_TENSOR_STRIDE = 1

    # Label value marking points that are dropped during training.
    IGNORE_LABEL = -100
    # Class id whose real features are classified but not used to train the
    # generator. NOTE(review): presumably the extra class scored by
    # ``self.anchors``; confirm.
    GENERATOR_SKIP_CLASS = 20
    # Class ids that get synthetic features every training step.
    UNSEEN_CLASSES = tuple(range(11, 20))
    # Number of synthetic features generated per unseen class per step.
    SYNTHETIC_PER_CLASS = 10000

    def __init__(self, in_channels, out_channels, config, D=3):
        """Build the backbone plus the anchor / generator heads.

        ``config`` must provide ``noise_dim``, ``embed_dim``, ``hidden_size``,
        ``feature_dim``, ``device``, ``batch_size_generator`` and
        ``batch_size_scannet``.
        """
        ResNetBase.__init__(self, in_channels, out_channels, D)

        self.t = 0.5  # softmax temperature for convex_regularization
        # Not used by forward(); kept so existing state_dicts still load.
        self.classifier = nn.Linear(96, 12)
        self.config = config

        # Extra learnable anchor column appended after the word-embedding anchors.
        self.anchors = nn.Parameter(torch.zeros(size=(96, 1)))
        nn.init.kaiming_uniform_(self.anchors.data, a=math.sqrt(5))

        self.generator = GMMNnetwork(config.noise_dim, config.embed_dim, \
                                config.hidden_size, config.feature_dim, embed_feature_size=0).to(config.device)

        self.criterion_generator = lossFunction.GMMNLoss(sigma=[2, 5, 10, 20, 40, 80], cuda=config.device).build_loss()

        self.generator_b = nn.Sequential(
            nn.Linear(300, 96),
        )

        # Projection / prototype parameters used by convex_regularization.
        self.v = nn.Parameter(torch.zeros(size=(96, 16)))
        nn.init.kaiming_uniform_(self.v.data, a=math.sqrt(5))
        self.k = nn.Parameter(torch.zeros(size=(96, 16)))
        nn.init.kaiming_uniform_(self.k.data, a=math.sqrt(5))
        self.k1 = nn.Parameter(torch.zeros(size=(96, 16)))
        nn.init.kaiming_uniform_(self.k1.data, a=math.sqrt(5))

        self.q = nn.Parameter(torch.zeros(size=(96, 16)))
        nn.init.kaiming_uniform_(self.q.data, a=math.sqrt(5))
        self.q1 = nn.Parameter(torch.zeros(size=(96, 16)))
        nn.init.kaiming_uniform_(self.q1.data, a=math.sqrt(5))

        self.base = nn.Parameter(torch.zeros(size=(128, 96)))
        nn.init.kaiming_uniform_(self.base.data, a=math.sqrt(5))
        self.base1 = nn.Parameter(torch.zeros(size=(128, 96)))
        nn.init.kaiming_uniform_(self.base1.data, a=math.sqrt(5))

        # Points sampled per scan by local_regularization.
        self.sample_points = 1024

    def network_initialization(self, in_channels, out_channels, D):
        """Create the U-Net encoder/decoder layers.

        ``self.inplanes`` tracks the running channel count used by
        ``_make_layer``. Each decoder stage's input width is its transposed-
        conv output plus the skip connection it is concatenated with.
        """
        # Output of the first conv concated to conv6
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=5, dimension=D)

        self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)

        self.conv1p1s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)

        self.block1 = self._make_layer(self.BLOCK, self.PLANES[0],
                                       self.LAYERS[0])

        self.conv2p2s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn2 = ME.MinkowskiBatchNorm(self.inplanes)

        self.block2 = self._make_layer(self.BLOCK, self.PLANES[1],
                                       self.LAYERS[1])

        self.conv3p4s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)

        self.bn3 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block3 = self._make_layer(self.BLOCK, self.PLANES[2],
                                       self.LAYERS[2])

        self.conv4p8s2 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
        self.bn4 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block4 = self._make_layer(self.BLOCK, self.PLANES[3],
                                       self.LAYERS[3])

        self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D)
        self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4])

        self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
        self.block5 = self._make_layer(self.BLOCK, self.PLANES[4],
                                       self.LAYERS[4])
        self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D)
        self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5])

        self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
        self.block6 = self._make_layer(self.BLOCK, self.PLANES[5],
                                       self.LAYERS[5])
        self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D)
        self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6])

        self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
        self.block7 = self._make_layer(self.BLOCK, self.PLANES[6],
                                       self.LAYERS[6])
        self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(
            self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D)
        self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7])

        self.inplanes = self.PLANES[7] + self.INIT_DIM
        self.block8 = self._make_layer(self.BLOCK, self.PLANES[7],
                                       self.LAYERS[7])

        self.relu = ME.MinkowskiReLU(inplace=True)

    def calculate_loss(self, predictions, labels):
        """Plain cross-entropy (callers filter ignored labels beforehand)."""
        loss = F.cross_entropy(predictions, labels)
        return loss

    def local_regularization(self, coords_scannet, features, labels):
        """Contrastive loss pairing scan ``bs`` with scan ``bs + batch_size_scannet``.

        For every ``bs`` in ``range(config.batch_size_scannet)``, up to
        ``self.sample_points`` shared point indices are drawn. Row ``i`` of
        scan a is treated as matching row ``i`` of scan b under a
        cross-entropy over their dot-product similarities.
        NOTE(review): this assumes the two scans are point-aligned by index;
        confirm against the data loader.

        Returns 0 when no pair contributes.
        """
        loss = 0
        for bs in range(self.config.batch_size_scannet):

            index_a = coords_scannet[:, 0] == bs
            index_b = coords_scannet[:, 0] == (bs + self.config.batch_size_scannet)
            scan_a = features[index_a]
            scan_b = features[index_b]

            # Random subset of scan_b's indices, kept sorted so that the
            # entries valid for scan_a form a prefix aligned with scan_b's rows.
            index = list(range(scan_b.size()[0]))
            random.shuffle(index)
            index = index[:self.sample_points]
            index.sort()
            # dtype must be explicit: torch.tensor([]) would be float32 and
            # make the indexing below raise when scan_b is empty.
            index_tensor = torch.tensor(index, dtype=torch.long)

            scan_a = scan_a[index_tensor[index_tensor < scan_a.size()[0]], :]
            scan_b = scan_b[index_tensor, :]

            if scan_a.size()[0] == 0:
                continue

            f_labels = torch.arange(scan_a.size()[0], device=self.config.device)
            similarities = torch.matmul(scan_a, scan_b.permute(1, 0))

            loss += F.cross_entropy(similarities, f_labels)

        return loss

    def forward_backbone(self, in_field_scannet):
        """Run the U-Net on the input field and return per-point features.

        The output is sliced back to ``in_field_scannet``'s points and
        returned as a dense feature matrix (``.F``).
        """
        x = in_field_scannet.sparse()
        out = self.conv0p1s1(x)
        out = self.bn0(out)
        out_p1 = self.relu(out)

        out = self.conv1p1s2(out_p1)
        out = self.bn1(out)
        out = self.relu(out)
        out_b1p2 = self.block1(out)

        out = self.conv2p2s2(out_b1p2)
        out = self.bn2(out)
        out = self.relu(out)
        out_b2p4 = self.block2(out)

        out = self.conv3p4s2(out_b2p4)
        out = self.bn3(out)
        out = self.relu(out)
        out_b3p8 = self.block3(out)

        # tensor_stride=16
        out = self.conv4p8s2(out_b3p8)
        out = self.bn4(out)
        out = self.relu(out)
        out = self.block4(out)

        # tensor_stride=8
        out = self.convtr4p16s2(out)
        out = self.bntr4(out)
        out = self.relu(out)

        out = ME.cat(out, out_b3p8)
        out = self.block5(out)

        # tensor_stride=4
        out = self.convtr5p8s2(out)
        out = self.bntr5(out)
        out = self.relu(out)

        out = ME.cat(out, out_b2p4)
        out = self.block6(out)

        # tensor_stride=2
        out = self.convtr6p4s2(out)
        out = self.bntr6(out)
        out = self.relu(out)

        out = ME.cat(out, out_b1p2)
        out = self.block7(out)

        # tensor_stride=1
        out = self.convtr7p2s2(out)
        out = self.bntr7(out)
        out = self.relu(out)
        out = ME.cat(out, out_p1)
        out = self.block8(out)
        return out.slice(in_field_scannet).F

    def convex_regularization(self, feature):
        """Replace each feature by an attention-weighted mix of ``self.base`` prototypes."""
        attentions = torch.matmul(torch.matmul(feature, self.k), torch.matmul(self.base, self.q).permute(1, 0))
        attentions = F.softmax(attentions / self.t, dim=1)
        representations = torch.matmul(attentions, self.base)

        return representations

    def forward_classifier(self, feature, word_embeddings):
        """Score features against anchors derived from class word embeddings.

        Args:
            feature: (N, 96) point features.
            word_embeddings: (C, 300) class word embeddings.

        Returns:
            (N, C + 1) logits. The last column comes from ``self.anchors``.
        """
        anchors = self.generator_b(word_embeddings)                          # (C, 96)
        anchors = torch.cat((anchors.permute(1, 0), self.anchors), dim=1)    # (96, C + 1)
        return torch.matmul(feature, anchors)

    def _generator_loss(self, features, labels, word_embeddings):
        """GMMN loss between generated and real features, summed over seen classes.

        Classes are visited in ``torch.unique`` order, skipping
        ``GENERATOR_SKIP_CLASS``. Per class, ``config.batch_size_generator``
        rows are drawn with replacement to bound memory use.
        """
        loss = 0
        for class_id in torch.unique(labels):
            if class_id == self.GENERATOR_SKIP_CLASS:
                continue
            real_features_class = features[labels == class_id]
            num_real = real_features_class.size()[0]
            embedding_class = word_embeddings[class_id, :].unsqueeze(0).repeat(num_real, 1).to(self.config.device)

            # Noise must match the width the generator was built with.
            z = torch.rand((num_real, self.config.noise_dim)).to(self.config.device)

            # Avoid CUDA out of memory
            random_idx = torch.randint(low=0, high=num_real,
                                       size=(self.config.batch_size_generator,))

            fake_features_class = self.generator(embedding_class[random_idx], z[random_idx].float())
            loss += self.criterion_generator(
                fake_features_class,
                real_features_class[random_idx],
            )
        return loss

    def _append_unseen_features(self, features, labels, word_embeddings):
        """Append ``SYNTHETIC_PER_CLASS`` generated features per unseen class.

        Returns the extended ``(features, labels)``. Synthetic labels are long
        tensors on ``config.device``.
        """
        count = self.SYNTHETIC_PER_CLASS
        all_features = [features]
        all_labels = [labels]
        for class_id in self.UNSEEN_CLASSES:
            embedding_class = word_embeddings[class_id, :].unsqueeze(0).repeat(count, 1).to(self.config.device)
            z = torch.rand((count, self.config.noise_dim)).to(self.config.device)
            all_features.append(self.generator(embedding_class, z.float()))
            all_labels.append(torch.full((count,), class_id, dtype=torch.long, device=self.config.device))
        # A single concatenation avoids re-copying the growing tensors each iteration.
        return torch.cat(all_features, dim=0), torch.cat(all_labels, dim=0)

    def forward(self, coords_scannet, in_field_scannet, labels, word_embeddings, phase):
        """Run the backbone and classify points; in training also add losses.

        Args:
            coords_scannet: point coordinates; only needed by the (disabled)
                local regularizer.
            in_field_scannet: input ``ME.TensorField``.
            labels: per-point labels; ``IGNORE_LABEL`` entries are dropped in training.
            word_embeddings: (C, 300) class word embeddings.
            phase: ``'train'`` enables the generator and classification losses.
                Any other value runs inference only.

        Returns:
            ``(predictions, loss)``. In training, predictions also contain rows
            for the synthetic unseen-class features appended after the real
            points, and loss is the GMMN loss plus cross-entropy. Otherwise
            loss is 0.
        """
        features = self.forward_backbone(in_field_scannet)

        if phase != 'train':
            return self.forward_classifier(features, word_embeddings), 0

        valid = labels != self.IGNORE_LABEL
        features = features[valid]
        labels = labels[valid]

        loss = self._generator_loss(features, labels, word_embeddings)

        # Real features of every (seen) point are classified as-is, followed by
        # generated features for the unseen classes.
        features, labels = self._append_unseen_features(features, labels, word_embeddings)
        predictions = self.forward_classifier(features, word_embeddings)

        loss += self.calculate_loss(predictions, labels.long())
        return predictions, loss

# Concrete MinkUNet depths / widths. LAYERS gives residual blocks per stage
# (4 encoder + 4 decoder stages); PLANES gives the per-stage widths.
# NOTE(review): MinkUNetBase's anchors / generator_b are 96-d, so presumably
# only variants whose final feature width is 96 work with its classifier;
# confirm before using the others.
_ENCODER_PLANES = (32, 64, 128, 256)


class MinkUNet14(MinkUNetBase):
    """BasicBlock MinkUNet with one block per stage."""
    BLOCK = BasicBlock
    LAYERS = (1,) * 8


class MinkUNet18(MinkUNetBase):
    """BasicBlock MinkUNet with two blocks per stage."""
    BLOCK = BasicBlock
    LAYERS = (2,) * 8


class MinkUNet34(MinkUNetBase):
    """BasicBlock MinkUNet with a deeper encoder (2, 3, 4, 6 blocks)."""
    BLOCK = BasicBlock
    LAYERS = (2, 3, 4, 6) + (2,) * 4


class MinkUNet50(MinkUNetBase):
    """Bottleneck MinkUNet using the same stage depths as MinkUNet34."""
    BLOCK = Bottleneck
    LAYERS = (2, 3, 4, 6) + (2,) * 4


class MinkUNet101(MinkUNetBase):
    """Bottleneck MinkUNet with a 23-block fourth encoder stage."""
    BLOCK = Bottleneck
    LAYERS = (2, 3, 4, 23) + (2,) * 4


class MinkUNet14A(MinkUNet14):
    PLANES = _ENCODER_PLANES + (128, 128, 96, 96)


class MinkUNet14B(MinkUNet14):
    PLANES = _ENCODER_PLANES + (128, 128, 128, 128)


class MinkUNet14C(MinkUNet14):
    PLANES = _ENCODER_PLANES + (192, 192, 128, 128)


class MinkUNet14D(MinkUNet14):
    PLANES = _ENCODER_PLANES + (384,) * 4


class MinkUNet18A(MinkUNet18):
    PLANES = _ENCODER_PLANES + (128, 128, 96, 96)


class MinkUNet18B(MinkUNet18):
    PLANES = _ENCODER_PLANES + (128, 128, 128, 128)


class MinkUNet18D(MinkUNet18):
    PLANES = _ENCODER_PLANES + (384,) * 4


class MinkUNet34A(MinkUNet34):
    PLANES = _ENCODER_PLANES + (256, 128, 64, 64)


class MinkUNet34B(MinkUNet34):
    PLANES = _ENCODER_PLANES + (256, 128, 64, 32)


class MinkUNet34C(MinkUNet34):
    PLANES = _ENCODER_PLANES + (256, 128, 96, 96)
