import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
class loss_boost:
    """Class-centroid-normalised alignment loss between two embedding batches.

    ``forward(q, k, y)`` sums, over every pair ``(i, j)`` with matching labels
    (``y[i] == y[j]``), the term ``2 - 2 * cos(q[i], k[j])``. That is the squared
    Euclidean distance between the L2-normalised rows. Each term is divided by:

    * the product of the L2 norms of the per-class *sums* of ``q`` and ``k``
      for that class, and
    * the number of distinct labels present in ``y``.

    Args:
        temperature: Stored on the instance; not used by ``forward``.
        cluster_centers: Sized collection whose ``len`` is the number of class
            rows in the one-hot assignment matrix. If ``None``, the number of
            classes is inferred from the labels as ``max(y) + 1``.
        weight: Stored on the instance; not used by ``forward``.
        epoch: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, temperature=0.5, cluster_centers=None, weight=None, epoch=None):
        self.temperature = temperature
        self.cluster_centers = cluster_centers
        self.weight = weight
        self.epoch = epoch

    def calculate_similarity_metric(self, q, k):
        """Return the cosine-similarity matrix between the rows of ``q`` and ``k``.

        Both inputs are L2-normalised along the last dimension, then multiplied,
        so entry ``[i, j]`` is ``cos(q[i], k[j])``.
        """
        q_norm = F.normalize(q, dim=-1)
        k_norm = F.normalize(k, dim=-1)
        similarity_matrix = torch.matmul(q_norm, k_norm.T)
        return similarity_matrix

    def _num_classes(self, y):
        """Return the number of class rows for the one-hot assignment matrix."""
        if self.cluster_centers is not None:
            return len(self.cluster_centers)
        # No explicit class list: derive it from the labels. An empty batch
        # needs no class rows (and y.max() would raise on an empty tensor).
        return int(y.max().item()) + 1 if y.numel() > 0 else 0

    def forward(self, q_, k_, y_):
        """Compute the loss for embeddings ``q_``, ``k_`` and integer labels ``y_``.

        Args:
            q_: Tensor whose first dimension is the number of samples N.
            k_: Tensor with N rows and the same feature width as ``q_``.
            y_: Integer label tensor of length N, used as an index tensor, so
                presumably ``torch.long``. Every label must be smaller than the
                number of classes.

        Returns:
            A scalar tensor (0 for an empty batch).

        NOTE(review): if the embeddings of one class sum to exactly the zero
        vector, its centroid norm is 0. The division then yields inf/NaN and
        poisons the summed loss.
        """
        q, k, y = q_, k_, y_
        n_samples = q.size(0)
        num_labels = torch.unique(y).numel()

        # One-hot assignment matrix of shape (classes, N): assign[c, i] = 1
        # iff y[i] == c. q and k share the labels, so one matrix serves both.
        assign = torch.zeros(self._num_classes(y), n_samples).to(q)
        # Build the column index on y's device so both index tensors agree.
        assign[y, torch.arange(n_samples, device=y.device)] = 1

        # assign @ x sums the rows of x per class. Take the norm of each class
        # sum, then gather it back per sample via y.
        l2_normq = torch.norm(torch.mm(assign, q), p=2, dim=1)[y]
        l2_normk = torch.norm(torch.mm(assign.to(k), k), p=2, dim=1)[y]
        l2_norm = l2_normq.unsqueeze(1) * l2_normk.unsqueeze(0)

        similarity_matrix = self.calculate_similarity_metric(q, k)
        # labels[i, j] is True only for same-class pairs; it masks other pairs.
        labels = y.unsqueeze(1) == y.unsqueeze(0)
        pos_loss = (labels * (2 - 2 * similarity_matrix)) / l2_norm / num_labels
        return pos_loss.sum()