from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class SimpleContrastiveLoss:
    """Cross-entropy contrastive loss over query/candidate dot products.

    Every row of ``x`` is scored against every row of ``y``; the scores are
    divided by ``temperature`` and fed to ``F.cross_entropy``.
    """

    def __init__(self, temperature: float = 0.02):
        # Divisor applied to the similarity logits before the softmax.
        self.temperature = temperature

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean') -> Tensor:
        """Compute the loss for queries ``x`` against candidates ``y``.

        Args:
            x: Query representations, one row per query.
            y: Candidate representations, one row per candidate.
            target: Index into the rows of ``y`` of the correct candidate for
                each query. When omitted, ``y`` is treated as holding
                ``y.size(0) // x.size(0)`` consecutive candidates per query
                with the first one of each group being the correct one.
            reduction: Forwarded unchanged to ``F.cross_entropy``.

        Returns:
            The cross-entropy loss as produced by ``F.cross_entropy``.
        """
        n_queries = x.size(0)
        if target is None:
            # Query i points at row i * group_size of y.
            group_size = y.size(0) // n_queries
            target = torch.arange(
                0,
                n_queries * group_size,
                group_size,
                device=x.device,
                dtype=torch.long,
            )
        scores = x @ y.transpose(0, 1)
        scaled_scores = scores / self.temperature
        return F.cross_entropy(scaled_scores, target, reduction=reduction)


class SimpleContrastiveLoss_InfoTN:
    """Contrastive loss combining a dot-product term with a norm-aware term.

    Two similarity matrices are built between queries and candidates: a plain
    dot product, and an "InfoTN" similarity ``1 - ||a - b|| / (||a|| + ||b||)``
    computed on a second pair of representations. The final loss is the
    equal-weight (0.5 / 0.5) average of the cross-entropy over each matrix.
    """

    def __init__(self, temperature: float, tau_infotn: Optional[float] = None):
        """
        Args:
            temperature: Divisor applied to the dot-product logits.
            tau_infotn: Divisor applied to the InfoTN similarity logits. The
                original code left this value unspecified ("please specify
                tau_infotn according to the paper"); when it is ``None`` the
                loss falls back to ``temperature``. Set it explicitly to the
                value your method prescribes.
        """
        self.temperature = temperature
        self.tau_infotn = tau_infotn

    def fused_similarity(self, h: torch.Tensor, h_positive: torch.Tensor,
                         h_PO: torch.Tensor, h_positive_PO: torch.Tensor,
                         eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the two similarity matrices used by the loss.

        Args:
            h: Query representations of shape [B_h, dim], used for the
                dot-product (directional) term.
            h_positive: Candidate representations of shape [B_positive, dim],
                used for the dot-product (directional) term.
            h_PO: Second set of query representations of shape [B_h, dim],
                used for the magnitude-aware term.
            h_positive_PO: Second set of candidate representations of shape
                [B_positive, dim], used for the magnitude-aware term.
            eps: Lower bound on the ``||a|| + ||b||`` denominator so that
                pairs of zero vectors do not produce NaN.

        Returns:
            A tuple ``(InfoTN_sim, cos_sim)``, each of shape
            [B_h, B_positive]. ``cos_sim`` is a plain dot product; it equals
            cosine similarity only if the inputs are already L2-normalised.
        """
        # 1. Directional term: pairwise dot products.
        cos_sim = torch.matmul(h, h_positive.transpose(0, 1))  # [B_h, B_positive]

        # 2. Magnitude-aware term: pairwise L2 distance between the *_PO
        #    vectors. Broadcasting [B_h, 1, dim] - [1, B_positive, dim]
        #    materialises a [B_h, B_positive, dim] tensor.
        diff = h_PO.unsqueeze(1) - h_positive_PO.unsqueeze(0)
        diff_norm = torch.norm(diff, dim=2, p=2)  # [B_h, B_positive]

        norm_h = torch.norm(h_PO, dim=1, p=2).unsqueeze(1)  # [B_h, 1]
        norm_h_positive = torch.norm(h_positive_PO, dim=1, p=2).unsqueeze(0)  # [1, B_positive]
        # Clamp so that 0 / 0 cannot occur when both vectors are zero.
        sum_norm = torch.clamp(norm_h + norm_h_positive, min=eps)  # [B_h, B_positive]

        # Normalised distance, then flipped into a similarity.
        InfoTN_dist = diff_norm / sum_norm
        InfoTN_sim = 1.0 - InfoTN_dist

        return InfoTN_sim, cos_sim

    def __call__(self, x: Tensor, y: Tensor, x_PO: Tensor, y_PO: Tensor, target: Optional[Tensor] = None,
                 reduction: str = 'mean') -> Tensor:
        """Compute the combined loss.

        Args:
            x: Query representations for the dot-product term.
            y: Candidate representations for the dot-product term.
            x_PO: Query representations for the InfoTN term.
            y_PO: Candidate representations for the InfoTN term.
            target: Index into the rows of ``y`` of the correct candidate for
                each query. When omitted, ``y`` is treated as holding
                ``y.size(0) // x.size(0)`` consecutive candidates per query
                with the first one of each group being the correct one.
            reduction: Forwarded unchanged to ``F.cross_entropy``.

        Returns:
            ``0.5 * loss_cos + 0.5 * loss_InfoTN``.
        """
        if target is None:
            target_per_qry = y.size(0) // x.size(0)
            target = torch.arange(
                0, x.size(0) * target_per_qry, target_per_qry, device=x.device, dtype=torch.long)

        InfoTN_sim, cos_sim = self.fused_similarity(x, y, x_PO, y_PO)

        # Fall back to the shared temperature when no dedicated one was given.
        tau_infotn = self.temperature if self.tau_infotn is None else self.tau_infotn

        loss_cos = F.cross_entropy(cos_sim / self.temperature, target, reduction=reduction)
        loss_InfoTN = F.cross_entropy(InfoTN_sim / tau_infotn, target, reduction=reduction)

        return 0.5 * loss_cos + 0.5 * loss_InfoTN


import torch
import torch.nn.functional as F


class DistributedContrastiveLoss(SimpleContrastiveLoss):
    """``SimpleContrastiveLoss`` evaluated on tensors gathered from all ranks.

    Both inputs are all-gathered across the process group before the parent
    loss is computed, so every rank scores against the candidates of every
    other rank.
    """

    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02):
        """
        Args:
            n_target: Accepted but not used by this class.
            scale_loss: When true, the loss is multiplied by the world size.
            temperature: Divisor applied to the similarity logits.
        """
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature)
        self.scale_loss = scale_loss
        self.word_size = dist.get_world_size()
        self.rank = dist.get_rank()

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        """Gather ``x`` and ``y`` from every rank and compute the parent loss.

        Extra keyword arguments are forwarded to the parent ``__call__``.
        """
        all_x = self.gather_tensor(x)
        all_y = self.gather_tensor(y)
        total = super().__call__(all_x, all_y, **kwargs)
        # Presumably compensates for gradient averaging across ranks - confirm
        # against the training setup.
        return total * self.word_size if self.scale_loss else total

    def gather_tensor(self, t):
        """Concatenate ``t`` from every rank along dim 0, in rank order."""
        shards = [torch.empty_like(t) for _ in range(self.word_size)]
        dist.all_gather(shards, t)
        # Put the original local tensor back in this rank's slot; presumably
        # this keeps the autograd graph attached for the local shard.
        shards[self.rank] = t
        return torch.cat(shards, dim=0)


class DistributedContrastiveLoss_InfoTN(SimpleContrastiveLoss_InfoTN):
    """``SimpleContrastiveLoss_InfoTN`` evaluated on tensors gathered from all ranks.

    All four inputs are all-gathered across the process group before the
    parent loss is computed.
    """

    def __init__(self, n_target: int = 0, scale_loss: bool = True, temperature: float = 0.02):
        """
        Args:
            n_target: Accepted but not used by this class.
            scale_loss: When true, the loss is multiplied by the world size.
            temperature: Divisor applied to the similarity logits.
        """
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        # The parent constructor requires `temperature`; calling it without
        # arguments raised TypeError before any instance could be created.
        super().__init__(temperature)
        self.word_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, x_PO: Tensor, y_PO: Tensor, **kwargs):
        """Gather all inputs from every rank and compute the parent loss.

        Extra keyword arguments are forwarded to the parent ``__call__``.
        """
        # Collective calls must be issued in the same order on every rank.
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        dist_x_PO = self.gather_tensor(x_PO)
        dist_y_PO = self.gather_tensor(y_PO)

        loss = super().__call__(dist_x, dist_y, dist_x_PO, dist_y_PO, **kwargs)
        if self.scale_loss:
            # Presumably compensates for gradient averaging across ranks -
            # confirm against the training setup.
            loss = loss * self.word_size
        return loss

    def gather_tensor(self, t):
        """Concatenate ``t`` from every rank along dim 0, in rank order."""
        gathered = [torch.empty_like(t) for _ in range(self.word_size)]
        dist.all_gather(gathered, t)
        # Put the original local tensor back in this rank's slot; presumably
        # this keeps the autograd graph attached for the local shard.
        gathered[self.rank] = t
        return torch.cat(gathered, dim=0)


class InExampleContrastiveLoss:
    """
    Categorization loss: cross_entropy of 1 out of K classes (target labels)
    x.shape=[bsz, hdim], y.shape=[bsz, num_label, hdim]

    Each query is scored only against its own ``num_label`` candidates, and
    the correct candidate is always the first one (label 0). ``y`` may also be
    passed flattened as ``[bsz * num_label, hdim]``.
    """

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, ndim: int = None, *args, **kwargs):
        """
        Args:
            n_hard_negatives: Stored as ``target_per_qry = n_hard_negatives + 1``;
                not read by ``__call__``.
            temperature: Factor the logits are *multiplied* by (it acts as a
                scale, not a divisor).
            ndim: When set to a non-zero value, only the first ``ndim``
                feature dimensions of ``x`` and ``y`` are used.
            *args, **kwargs: Accepted and ignored.
        """
        self.target_per_qry = n_hard_negatives + 1
        self.temperature = temperature
        self.ndim = ndim

    def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
        """Compute the loss and return it with the logits, labels and predictions.

        Args:
            x: Query representations of shape [bsz, hdim].
            y: Candidate representations of shape [bsz, num_label, hdim] or
                [bsz * num_label, hdim].
            reduction: Forwarded unchanged to ``F.cross_entropy``.

        Returns:
            ``(loss, loss_detail)`` where ``loss_detail`` holds ``"logits"``
            ([bsz, num_label]), ``"labels"`` (all zeros) and ``"preds"``
            (argmax over the label axis).
        """
        if torch.distributed.is_initialized():
            # NOTE(review): `dist_utils` is not imported in the visible part of
            # this file; this branch raises NameError unless it is provided
            # elsewhere - confirm and add the import.
            x = dist_utils.dist_gather(x)
            y = dist_utils.dist_gather(y)
        bsz, ndim = x.size(0), x.size(1)
        target = torch.zeros(bsz, dtype=torch.long, device=x.device)
        if self.ndim:
            ndim = self.ndim
            # Truncate the trailing (feature) axis. The ellipsis keeps this
            # correct for 3-D `y`, where `y[:, :ndim]` would cut the label
            # axis instead.
            x = x[..., :ndim]
            y = y[..., :ndim]
        # reshape (not view) so sliced or otherwise non-contiguous inputs work.
        queries = x.reshape(bsz, 1, ndim)
        candidates = y.reshape(bsz, -1, ndim)
        logits = torch.einsum('bod,bsd->bs', queries, candidates) * self.temperature
        preds = torch.argmax(logits, dim=-1)
        loss = F.cross_entropy(logits, target, reduction=reduction)
        loss_detail = {"logits": logits, "labels": target, "preds": preds}
        return loss, loss_detail
