import torch
import torch.nn as nn
import torch.nn.functional as F

class KLDivergenceLoss(nn.Module):
    """Temperature-scaled KL divergence between two sets of logits.

    Both logit tensors are divided by ``temperature`` and normalised with a
    softmax along ``dim=1``. The result of
    ``F.kl_div(log_softmax(logits1 / T), softmax(logits2 / T))`` is then
    multiplied by ``temperature ** 2``. That product is
    ``KL(softmax(logits2 / T) || softmax(logits1 / T))`` under the chosen
    reduction.

    Args:
        temperature: Softening temperature ``T`` applied to both inputs.
        reduction: Forwarded unchanged to ``torch.nn.functional.kl_div``
            (e.g. ``"batchmean"`` divides the summed divergence by the size
            of dim 0).
    """

    def __init__(self, temperature=1.0, reduction="batchmean"):
        super(KLDivergenceLoss, self).__init__()
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, logits1, logits2):
        """Return the temperature-scaled KL divergence of the two distributions.

        Args:
            logits1: Logits whose log-probabilities form the ``input`` of
                ``F.kl_div``. The class dimension is assumed to be dim 1.
            logits2: Logits whose probabilities form the ``target`` of
                ``F.kl_div``, with the same shape as ``logits1``.

        Returns:
            The divergence from ``F.kl_div`` with ``self.reduction``,
            multiplied by ``temperature ** 2``.
        """
        # log_softmax is evaluated stably (log-sum-exp). The previous
        # softmax(...).log() turned underflowed probabilities into -inf,
        # which made the loss inf/NaN for large logit gaps.
        log_probs1 = F.log_softmax(logits1 / self.temperature, dim=1)
        probs2 = F.softmax(logits2 / self.temperature, dim=1)

        kl_div = F.kl_div(log_probs1, probs2, reduction=self.reduction)
        # Scale by T^2, presumably following the usual distillation convention
        # of keeping gradient magnitudes comparable across temperatures.
        # Out-of-place to avoid mutating the autograd output.
        return kl_div * (self.temperature ** 2)



def logits_smooth(logits, temperature):
    """Return ``logits`` divided by ``temperature``.

    Any operand pair that supports ``/`` (tensors, floats, ints) is accepted.
    A temperature above 1 flattens the distribution a subsequent softmax would
    produce, and a temperature below 1 sharpens it.
    """
    smoothed = logits / temperature
    return smoothed


def calculate_communication_cost(tensor: torch.Tensor) -> float:
    """Return the size of ``tensor``'s elements in mebibytes (MiB).

    The size is computed as ``numel * element_size`` bytes divided by
    ``1024 ** 2``. Only the tensor's own elements are counted, not any larger
    underlying storage it may be a view of.

    Args:
        tensor: The tensor whose payload size is measured.

    Returns:
        The payload size in MiB as a float. It is 0.0 for an empty tensor.
    """
    bytes_per_mib = 1024 ** 2
    # Tensor.element_size() gives the same per-element byte count as
    # storage().element_size(), without the deprecated TypedStorage path
    # (which emits a UserWarning in recent PyTorch).
    num_bytes = tensor.numel() * tensor.element_size()
    return num_bytes / bytes_per_mib