import torch.nn as nn
import torch
import math
import torch.nn.functional as F
import random


class GMMNLoss:
    """Kernel two-sample (MMD-style) loss between generated and real samples.

    The loss is ``sqrt(sum_v s^T K_v s)`` where ``K_v[i, j] =
    exp(-||X_i - X_j||^2 / (2 v))`` is evaluated on the row-wise concatenation
    ``X = [gen_samples; x]`` and ``s`` holds ``+1/M`` for the ``M`` generated
    rows and ``-1/N`` for the ``N`` real rows.
    """

    def __init__(self, sigma=(2, 5, 10, 20, 40, 80), cuda=False):
        """Create the loss.

        Args:
            sigma: Accepted for interface compatibility only. The original
                code overrides it with the fixed bandwidth list ``[10]``;
                that behaviour is kept so results do not change.
            cuda: Device specification handed to ``Tensor.to`` (for example
                ``"cuda:0"`` or a ``torch.device``). ``False``/``None`` means
                "stay on the device of the inputs"; ``True`` means ``"cuda"``.
        """
        self.sigma = [10]
        self.cuda = cuda

    def build_loss(self):
        """Return the bound callable that evaluates the loss."""
        return self.moment_loss

    def _resolve_device(self, fallback=None):
        """Translate ``self.cuda`` into something ``Tensor.to`` accepts.

        A bool cannot be used as a ``Tensor.to`` target, so ``True`` maps to
        the default CUDA device and ``False``/``None`` map to ``fallback``.
        Any other value is returned untouched.
        """
        target = self.cuda
        if target is None or isinstance(target, bool):
            return torch.device("cuda") if target else fallback
        return target

    def get_scale_matrix(self, M, N, device=None):
        """Build the ``(N + M, 1)`` column of signed averaging weights.

        The first ``N`` rows are ``+1/N`` and the following ``M`` rows are
        ``-1/M``.

        Args:
            M: Number of rows in the second (negative) group.
            N: Number of rows in the first (positive) group.
            device: Optional target device; when omitted the device configured
                on the instance is used, or the tensor is left where
                ``torch.ones`` created it if none is configured.
        """
        if device is None:
            device = self._resolve_device()
        positive = torch.ones((N, 1)) / N
        negative = -torch.ones((M, 1)) / M
        weights = torch.cat((positive, negative), 0)
        return weights if device is None else weights.to(device)

    def moment_loss(self, gen_samples, x):
        """Compute the kernel discrepancy between ``gen_samples`` and ``x``.

        Args:
            gen_samples: 2-D tensor with one generated sample per row.
            x: 2-D tensor with one real sample per row and the same number of
                columns as ``gen_samples``.

        Returns:
            A scalar tensor holding the square root of the summed kernel
            statistic over all bandwidths in ``self.sigma``.
        """
        X = torch.cat((gen_samples, x), 0)
        gram = torch.matmul(X, X.t())
        sq_norms = torch.sum(X * X, 1, keepdim=True)
        # Equals -0.5 * ||X_i - X_j||^2 for every pair of rows.
        neg_half_sq_dist = gram - 0.5 * sq_norms - 0.5 * sq_norms.t()

        n_gen = gen_samples.size()[0]
        n_real = x.size()[0]
        device = self._resolve_device(fallback=X.device)
        # Rows of X are ordered [generated; real], so the generated count must
        # drive the leading group of weights. Passing the counts the other way
        # round mis-assigns the weights whenever the two batch sizes differ.
        s = self.get_scale_matrix(M=n_real, N=n_gen, device=device)
        S = torch.matmul(s, s.t())

        loss = 0
        for v in self.sigma:
            kernel_val = torch.exp(neg_half_sq_dist / v).to(device)
            loss = loss + torch.sum(S * kernel_val)

        # Rounding can push the estimate marginally below zero when the two
        # sample sets (nearly) coincide; clamp so sqrt does not return NaN.
        return torch.sqrt(torch.clamp(loss, min=0))

class contrastive(nn.Module):
    """Supervised cross-entropy loss over the labeled subset of predictions.

    Only ``predictions`` and ``labels`` take part in the computation. The
    remaining ``forward`` arguments are accepted so the call signature stays
    compatible with existing callers, but they are not read.

    Earlier experimental terms (pseudo-labelling of confident unlabeled
    points, pairwise consistency terms between two halves of the masked
    predictions, and a detail-consistency term on the feature fields) were
    already disabled in the original implementation and are not computed.
    """

    # Label value marking entries that must not contribute to the loss.
    IGNORE_LABEL = -100

    def __init__(self, device):
        """Store ``device`` and register the projection head.

        Args:
            device: Kept on the instance as ``self.device``; ``forward`` does
                not use it.
        """
        super().__init__()
        self.device = device
        # Not used by ``forward``; kept so the module's parameters and
        # ``state_dict`` keys stay identical to the original definition.
        self.non_linear = nn.Sequential(
            nn.ReLU(),
            nn.Linear(96, 32),
        )

    def forward(self, coords_scannet, predictions, similarities, labels, instances_prediction, instances_labels, ofield_scannet, ofield_model, masks, config, epoch):
        """Return the mean cross-entropy over entries whose label is not -100.

        Args:
            predictions: ``(N, C)`` class logits (the original comment states
                ``C = 2``).
            labels: ``(N,)`` integer class labels; entries equal to ``-100``
                are excluded from the loss.
            coords_scannet, similarities, instances_prediction,
            instances_labels, ofield_scannet, ofield_model, masks, config,
            epoch: Unused; present for signature compatibility.

        Returns:
            A scalar loss tensor. If no entry is labeled, a zero that is still
            connected to ``predictions`` is returned instead of the NaN that a
            mean over an empty selection would produce.
        """
        labeled = labels != self.IGNORE_LABEL
        predictions_final = predictions[labeled]
        labels_final = labels[labeled]

        if labels_final.numel() == 0:
            # Sum over an empty selection is exactly 0 and keeps the autograd
            # graph intact, so a following backward() yields zero gradients.
            return predictions_final.sum()

        return F.cross_entropy(predictions_final, labels_final)