import torch
import torch.nn as nn

class SelfDivLoss(nn.Module):
    """Average ``exp(cosine similarity)`` over every unordered pair in a batch.

    Each pair ``(x[i], x[j])`` with ``i < j`` contributes
    ``exp(cos_sim(x[i], x[j]))``; the contributions are summed and divided
    by the number of pairs, ``n * (n - 1) / 2``.
    """

    def __init__(self):
        # Zero-argument super(): naming another class here (the previous code
        # passed ``DivLoss``) makes construction fail, because SelfDivLoss is
        # not a subclass of it.
        super().__init__()
        # dim=0 because the similarity is taken between individual samples
        # x[i] / x[j], whose leading axis is their own first dimension.
        self.cos_sim = nn.CosineSimilarity(dim=0)

    def forward(self, x):
        """Return the mean pairwise ``exp(cosine similarity)`` of ``x``.

        Args:
            x: Tensor indexed along its first dimension; ``x[i]`` is one
                sample. Presumably 2-D ``(n, features)`` so that each pair
                yields a scalar -- TODO confirm with callers. For
                higher-rank input the similarity is reduced over dim 0 of
                each sample only, so the result keeps the remaining dims.

        Returns:
            The summed pair terms divided by the number of pairs.

        Raises:
            ZeroDivisionError: If ``x`` has fewer than two samples (there
                are no pairs, so the pair count is zero).
        """
        n = x.shape[0]
        total = n * (n - 1) / 2  # number of unordered pairs (i < j)
        loss = 0.0
        for i in range(n):
            for j in range(i + 1, n):
                # Out-of-place add keeps the accumulation autograd-friendly.
                loss = loss + torch.exp(self.cos_sim(x[i], x[j]))
        return loss / total

class DivLoss(nn.Module):
    """Mean ``exp(cosine similarity)`` between matching samples of two batches."""

    def __init__(self):
        super().__init__()

    def forward(self, new_batch, old_batch, t=1):
        """Compare ``new_batch`` and ``old_batch`` sample by sample.

        Both inputs are flattened from dim 1 onward, the cosine similarity is
        computed along that flattened axis (one value per leading index), and
        the mean of the exponentiated similarities is returned.

        Args:
            new_batch: Tensor whose dims after the first are flattened.
            old_batch: Tensor flattened the same way and paired with
                ``new_batch`` along the leading dimension.
            t: Not used by the current computation; retained so existing
                call sites that pass it keep working.

        Returns:
            Scalar tensor: ``exp(cos_sim).mean()``.
        """
        flat_new = torch.flatten(new_batch, start_dim=1)
        flat_old = torch.flatten(old_batch, start_dim=1)
        similarity = nn.functional.cosine_similarity(
            flat_new, flat_old, dim=1, eps=1e-6
        )
        return torch.mean(torch.exp(similarity))
