import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import numpy as np

class Loss(nn.Module):
    """Loss terms built from paired feature batches and class prototypes.

    ``forward_feature`` scores each row of one batch against its counterpart
    in the other batch (the positive) and against every remaining row (the
    negatives) with a cross-entropy objective.  ``prototype_dif2`` penalises
    pairs of prototypes whose similarity profiles towards the remaining
    prototypes are close to each other.

    Args:
        batch_size: Nominal number of samples per batch.  Only used to build
            the ``mask`` attribute; ``forward_feature`` reads the real batch
            size from its inputs.
        class_num: Number of prototypes/classes handled by
            ``prototype_dif2``.
        temperature_f: Divisor applied to the feature similarities.
        temperature_l: Temperature used in both exponentials of
            ``prototype_dif2``.
        device: Stored as ``self.device``; not otherwise used by this class.
    """

    def __init__(self, batch_size, class_num, temperature_f, temperature_l, device):
        super().__init__()
        self.batch_size = batch_size
        self.class_num = class_num
        self.temperature_f = temperature_f
        self.temperature_l = temperature_l
        self.device = device
        # Every unordered pair (a, b), a < b, of class indices: shape (P, 2).
        self.index_dif = torch.combinations(torch.arange(self.class_num))
        # For pair p, True on every class column except the two in the pair.
        self.mask_C = self.remove_index()

        self.mask = self.mask_correlated_samples(batch_size)
        self.similarity = nn.CosineSimilarity(dim=2)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, N: int) -> torch.Tensor:
        """Return an ``(N, N)`` boolean mask selecting negative pairs.

        Entries are ``False`` on the main diagonal and at ``(i, N // 2 + i)``
        and ``(N // 2 + i, i)`` for ``i < N // 2``; everything else is
        ``True``.  The mask is created on the CPU.
        """
        half = N // 2
        mask = torch.ones((N, N), dtype=torch.bool)
        diagonal = torch.arange(N)
        mask[diagonal, diagonal] = False
        first_half = torch.arange(half)
        mask[first_half, first_half + half] = False
        mask[first_half + half, first_half] = False
        return mask

    def remove_index(self) -> torch.Tensor:
        """Return a ``(P, class_num)`` boolean mask derived from ``index_dif``.

        Row ``p`` is ``False`` at the two class indices that form pair ``p``
        and ``True`` elsewhere.
        """
        pair_count = self.index_dif.size(0)
        mask = torch.ones((pair_count, self.class_num), dtype=torch.bool)
        if pair_count:
            rows = torch.arange(pair_count).unsqueeze(1)
            mask[rows, self.index_dif] = False
        return mask

    def prototype_dif2(self, C):
        """Compute the prototype-difference penalty for the rows of ``C``.

        Args:
            C: 2-D tensor with one row per class; the row count must equal
                ``class_num`` for the pair mask to apply.

        Returns:
            Scalar tensor: the sum of ``exp(-d ** 2 / temperature_l)`` over
            all pairs and remaining classes, divided by
            ``index_dif.size(0) * class_num``.  ``d`` is the difference
            between the two prototypes' exponentiated similarities to a
            third class.
        """
        C = F.normalize(C)
        sim = torch.exp(torch.mm(C, C.t()) / self.temperature_l)

        pairs = self.index_dif.to(sim.device)
        keep = self.mask_C.to(sim.device)

        diff = sim[pairs[:, 0]] - sim[pairs[:, 1]]
        # Each row keeps exactly class_num - 2 columns.  The explicit width
        # (instead of -1) keeps the reshape valid when nothing is kept,
        # i.e. class_num == 2.
        diff = diff[keep].reshape(diff.size(0), self.class_num - 2)

        penalty = torch.exp(-(diff ** 2) / self.temperature_l)
        return torch.sum(penalty) / (self.index_dif.size(0) * self.class_num)

    def forward_feature(self, h_i, h_j):
        """Cross-entropy loss over paired rows of ``h_i`` and ``h_j``.

        Both inputs are passed through ``activate_and_normalize`` and stacked.
        For each of the ``2 * B`` rows, the logit at index 0 is its dot
        product with the counterpart row in the other input; the remaining
        logits are its dot products with all other rows.  All are divided by
        ``temperature_f``.

        Args:
            h_i: 2-D tensor of shape ``(B, D)``.
            h_j: 2-D tensor with the same number of rows as ``h_i``.

        Returns:
            Scalar tensor: summed cross-entropy divided by ``2 * B``.

        Raises:
            ValueError: If the inputs have different row counts or are empty.
        """
        # Use the actual batch size so that batches smaller than the
        # configured ``batch_size`` (e.g. the last one of an epoch) work.
        batch = h_i.size(0)
        if h_j.size(0) != batch:
            raise ValueError(
                "h_i and h_j must have the same number of rows, got "
                f"{batch} and {h_j.size(0)}"
            )
        if batch == 0:
            raise ValueError("forward_feature requires a non-empty batch")

        h_i = self.activate_and_normalize(h_i)
        h_j = self.activate_and_normalize(h_j)
        N = 2 * batch
        h = torch.cat((h_i, h_j), dim=0)

        sim = torch.matmul(h, h.T) / self.temperature_f
        sim_i_j = torch.diag(sim, batch)
        sim_j_i = torch.diag(sim, -batch)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        mask = self.mask_correlated_samples(N).to(sim.device)
        # Every row has exactly N - 2 negatives; the explicit width keeps the
        # reshape valid for a single pair, where there are no negatives.
        negative_samples = sim[mask].reshape(N, N - 2)

        # The positive logit sits in column 0, so the target class is 0.
        labels = torch.zeros(N, dtype=torch.long, device=sim.device)
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        return self.criterion(logits, labels) / N

    def activate_and_normalize(self, tensor):
        """Clamp negatives to zero and scale each row to sum to one.

        Rows that are all zero after clamping are left as zeros: their
        divisor is replaced by 1 to avoid a division by zero.
        """
        tensor = torch.clamp(tensor, min=0)
        row_sums = tensor.sum(dim=1, keepdim=True)
        row_sums = row_sums + (row_sums == 0).float()
        return tensor / row_sums
