import torch
import torch.nn.functional as F


class InfoNCELoss(torch.nn.Module):
    """Symmetric InfoNCE loss over a batch of paired image/text features.

    Row ``i`` of ``img_feat`` and row ``i`` of ``txt_feat`` form the positive
    pair; every other row in the batch acts as a negative.
    """

    def __init__(self, temperature=0.07, eps=1e-8):
        """
        Args:
            temperature: Divisor applied to the cosine-similarity matrix.
            eps: Stored on the module; not used by ``forward`` (kept so
                existing callers that pass it keep working).
        """
        super().__init__()
        self.temperature = temperature
        self.eps = eps

    def forward(self, img_feat, txt_feat):
        """
        Compute the bidirectional InfoNCE loss.

        Args:
            img_feat: [N, D] - Image features
            txt_feat: [N, D] - Text features

        Returns:
            Scalar tensor: mean of the image->text and text->image
            cross-entropy losses.
        """
        # L2 normalization so the dot product is a cosine similarity
        img = F.normalize(img_feat, dim=-1, p=2)
        txt = F.normalize(txt_feat, dim=-1, p=2)

        # Similarity matrix, [N, N]; rows index images, columns index texts
        logits = torch.matmul(img, txt.T) / self.temperature

        # NOTE: no manual max-subtraction here. F.cross_entropy is already
        # numerically stable, and subtracting the per-row max before reusing
        # the matrix transposed would shift every entry of a text->image row
        # by a different amount, which changes that softmax distribution.

        # Labels (positive samples on the diagonal)
        labels = torch.arange(img.size(0), device=img.device, dtype=torch.long)

        # Bidirectional loss
        loss_img2txt = F.cross_entropy(logits, labels)
        loss_txt2img = F.cross_entropy(logits.T, labels)

        return (loss_img2txt + loss_txt2img) / 2


class SingleROIContrastiveLoss(torch.nn.Module):
    """
    Single ROI contrastive loss.

    Fuses a global image feature with one quality-weighted ROI feature and
    contrasts the fused feature against the region text feature, using the
    other rows of the batch as negatives.
    """

    def __init__(self, temperature=0.07):
        """
        Args:
            temperature: Divisor applied to the cosine-similarity matrix.
        """
        super().__init__()
        self.temperature = temperature

    def forward(self, global_img_feat, roi_img_feat, region_txt_feat, roi_quality_weights):
        """
        Single ROI contrastive loss
        Args:
            global_img_feat: [B, D] Global image features
            roi_img_feat: [B, D] ROI image features
            region_txt_feat: [B, D] Region text features
            roi_quality_weights: [B] ROI quality weights

        Returns:
            Scalar tensor: mean of the image->text and text->image
            cross-entropy losses.
        """
        # Fusion of global and ROI features; the weight is broadcast over D
        roi_quality_weights = roi_quality_weights.unsqueeze(1)  # [B, 1]
        fused_img_feat = (global_img_feat + roi_img_feat * roi_quality_weights) / 2

        # Normalization
        img_norm = F.normalize(fused_img_feat, dim=-1)
        txt_norm = F.normalize(region_txt_feat, dim=-1)

        # Similarity matrix, [B, B]. No manual max-subtraction: cross_entropy
        # is already numerically stable, and a per-row shift would corrupt the
        # transposed (text->image) direction, where the shift is no longer
        # constant along the softmax axis.
        logits = torch.matmul(img_norm, txt_norm.T) / self.temperature

        # Positives sit on the diagonal
        labels = torch.arange(img_norm.size(0), device=img_norm.device)

        # Bidirectional loss
        loss_img2txt = F.cross_entropy(logits, labels)
        loss_txt2img = F.cross_entropy(logits.T, labels)

        return (loss_img2txt + loss_txt2img) / 2


class Fixed5NegativeLoss(torch.nn.Module):
    """
    Contrastive loss that keeps only the hardest of a fixed set of negatives.

    For each image, the positive similarity is placed at index 0 and the
    top-scoring negatives follow; the loss is a cross-entropy towards index 0.
    """

    def __init__(self, temperature=0.07, hard_ratio=0.8):
        """
        Args:
            temperature: Divisor applied to the cosine similarities.
            hard_ratio: Fraction of the negatives kept as hard negatives
                (at least one is always kept).
        """
        super().__init__()
        self.temperature = temperature
        self.hard_ratio = hard_ratio

    def forward(self, img_feat, pos_feat, neg_feat_5):
        """
        Hard negative mining over the supplied negatives.

        Args:
            img_feat: [B, D] Image features
            pos_feat: [B, D] Positive text features
            neg_feat_5: [B, 5, D] Fixed 5 negative sample features

        Returns:
            Scalar cross-entropy loss with the positive as the target class.
        """
        batch_size, num_neg, _ = neg_feat_5.shape

        # Unit-length features so dot products are cosine similarities
        anchor = F.normalize(img_feat, dim=-1)
        positive = F.normalize(pos_feat, dim=-1)
        negatives = F.normalize(neg_feat_5, dim=-1)

        # [B, 1] similarity between each image and its positive text
        pos_sim = (anchor * positive).sum(dim=-1, keepdim=True) / self.temperature

        # [B, N] similarity between each image and each of its negatives
        neg_sim = (anchor.unsqueeze(1) @ negatives.transpose(1, 2)).squeeze(1) / self.temperature

        # Keep the highest-scoring (hardest) negatives, never fewer than one
        keep = max(1, int(num_neg * self.hard_ratio))
        hard_neg_sim = neg_sim.topk(keep, dim=-1)[0]

        # Positive first, then the hard negatives; shift rows by their max
        logits = torch.cat([pos_sim, hard_neg_sim], dim=-1)
        row_max = logits.max(dim=-1, keepdim=True)[0].detach()
        logits = logits - row_max

        # The positive always sits at column 0
        targets = torch.zeros(batch_size, dtype=torch.long, device=img_feat.device)

        return F.cross_entropy(logits, targets)


class ROIQualityLoss(torch.nn.Module):
    """
    ROI quality assessment loss: a scaled 3-way cross-entropy between the
    predicted quality scores and a class index derived from the ROI type.
    """

    # ROI type string -> target class index used by the cross-entropy
    _TYPE_TO_LABEL = {
        'abnormal_region_detection': 0,  # Highest quality: precise anomaly detection
        'abnormal_global_detection': 1,  # Medium quality: global anomaly detection
        'normal_global': 2,  # Lower quality: normal full image
        'no_roi': 2,  # Treat as normal_global
        'failed': 2  # Treat as normal_global when failed
    }

    def __init__(self, alpha=0.5):
        """
        Args:
            alpha: Multiplier applied to the cross-entropy value.
        """
        super().__init__()
        self.alpha = alpha

    def forward(self, quality_scores, roi_types):
        """
        ROI quality assessment loss

        Args:
            quality_scores: [B, 3] Quality logits; column k scores class k of
                the type-to-label mapping (0 = abnormal_region_detection,
                1 = abnormal_global_detection, 2 = normal_global / no_roi /
                failed).
            roi_types: [B] List of ROI type strings. Strings missing from the
                mapping are assigned class 1.

        Returns:
            Scalar tensor: cross-entropy multiplied by ``alpha``.
        """
        lookup = self._TYPE_TO_LABEL

        # Translate each type string into its class index (unknown -> 1)
        class_ids = [lookup.get(name, 1) for name in roi_types]
        targets = torch.tensor(class_ids, device=quality_scores.device)

        ce = F.cross_entropy(quality_scores, targets)
        return ce * self.alpha


class RegionContrastiveLoss(torch.nn.Module):
    """Symmetric contrastive loss over (region image, region text) pairs.

    All regions of all samples are flattened into one set of pairs; pairs
    whose image or text feature is (near) zero are treated as invalid and
    dropped before the loss is computed.
    """

    def __init__(self, temperature=0.07):
        """
        Args:
            temperature: Divisor applied to the cosine-similarity matrix.
        """
        super().__init__()
        self.temperature = temperature

    def forward(self, roi_feats, region_txt_feats):
        """
        Region contrastive loss

        Args:
            roi_feats: [B, R, D] - Region image features
            region_txt_feats: [B, R, D] - Region text features

        Returns:
            Scalar tensor: mean of the image->text and text->image
            cross-entropy losses over the valid pairs, or a constant 0.0
            (not connected to the inputs) when fewer than two pairs are valid.
        """
        # Unpacking also enforces that roi_feats is 3-dimensional
        _, _, D = roi_feats.shape

        # Flatten to [B*R, D]
        roi_feats_flat = roi_feats.reshape(-1, D)
        txt_feats_flat = region_txt_feats.reshape(-1, D)

        # Filter out zero vectors (invalid regions)
        roi_norm = torch.norm(roi_feats_flat, dim=-1)
        txt_norm = torch.norm(txt_feats_flat, dim=-1)
        valid_mask = (roi_norm > 1e-6) & (txt_norm > 1e-6)

        if valid_mask.sum() < 2:  # Need at least 2 valid samples
            return torch.tensor(0.0, device=roi_feats.device, requires_grad=True)

        roi_feats_valid = roi_feats_flat[valid_mask]
        txt_feats_valid = txt_feats_flat[valid_mask]

        # Normalization
        roi_feats_norm = F.normalize(roi_feats_valid, dim=-1)
        txt_feats_norm = F.normalize(txt_feats_valid, dim=-1)

        # Similarity matrix, [V, V] for V valid pairs. No manual
        # max-subtraction: cross_entropy is already numerically stable, and a
        # per-row shift would corrupt the transposed (text->image) direction,
        # where the shift is no longer constant along the softmax axis.
        logits = torch.matmul(roi_feats_norm, txt_feats_norm.T) / self.temperature

        # Labels (matching pairs on the diagonal)
        labels = torch.arange(roi_feats_valid.size(0), device=roi_feats.device)

        # Bidirectional loss
        loss_img2txt = F.cross_entropy(logits, labels)
        loss_txt2img = F.cross_entropy(logits.T, labels)

        return (loss_img2txt + loss_txt2img) / 2


class FocalLoss(torch.nn.Module):
    """Focal-weighted image->text contrastive loss.

    Each sample's cross-entropy is scaled by ``alpha * (1 - p_t) ** gamma``,
    where ``p_t`` is the softmax probability of its matching text.
    """

    def __init__(self, alpha=1.0, gamma=2.0, temperature=0.07):
        """
        Args:
            alpha: Constant factor on the focal weight.
            gamma: Exponent applied to ``1 - p_t``.
            temperature: Divisor applied to the cosine-similarity matrix.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.temperature = temperature

    def forward(self, img_feat, txt_feat):
        """
        Args:
            img_feat: Image features; row i is paired with row i of txt_feat.
            txt_feat: Text features.

        Returns:
            Scalar tensor: batch mean of the focal-weighted cross-entropy.
        """
        img_unit = F.normalize(img_feat, dim=-1)
        txt_unit = F.normalize(txt_feat, dim=-1)

        # Temperature-scaled cosine similarities, shifted by the row maximum
        sim = torch.matmul(img_unit, txt_unit.T) / self.temperature
        row_max = sim.max(dim=-1, keepdim=True)[0].detach()
        sim = sim - row_max

        # Matching pairs lie on the diagonal
        targets = torch.arange(img_unit.size(0), device=img_unit.device)

        # Per-sample cross-entropy (no reduction yet)
        per_sample_ce = F.cross_entropy(sim, targets, reduction='none')

        # exp(-CE) recovers the probability assigned to the correct class
        prob_true = torch.exp(-per_sample_ce)
        weight = self.alpha * (1 - prob_true) ** self.gamma

        return (weight * per_sample_ce).mean()


class AdaptiveTemperatureLoss(torch.nn.Module):
    """Symmetric InfoNCE loss whose temperature is stored in log-space and
    can optionally be learned."""

    def __init__(self, init_temperature=0.07, learnable=True):
        """
        Args:
            init_temperature: Initial temperature value.
            learnable: If True the log-temperature is an ``nn.Parameter``;
                otherwise it is registered as a (non-trainable) buffer.
        """
        super().__init__()
        if learnable:
            self.log_temperature = torch.nn.Parameter(torch.log(torch.tensor(init_temperature)))
        else:
            self.register_buffer('log_temperature', torch.log(torch.tensor(init_temperature)))

    @property
    def temperature(self):
        """Current temperature as a tensor, clamped to [0.01, 1.0]."""
        return torch.exp(self.log_temperature).clamp(min=0.01, max=1.0)

    def forward(self, img_feat, txt_feat):
        """
        Args:
            img_feat: Image features; row i is paired with row i of txt_feat.
            txt_feat: Text features.

        Returns:
            Scalar tensor: mean of the image->text and text->image
            cross-entropy losses.
        """
        img = F.normalize(img_feat, dim=-1)
        txt = F.normalize(txt_feat, dim=-1)

        # Similarity matrix scaled by the (possibly learned) temperature.
        # No manual max-subtraction: cross_entropy is already numerically
        # stable, and a per-row shift would corrupt the transposed
        # (text->image) direction, where the shift is no longer constant
        # along the softmax axis -- distorting both the loss value and the
        # gradient that reaches the temperature.
        logits = torch.matmul(img, txt.T) / self.temperature

        # Positives sit on the diagonal
        labels = torch.arange(img.size(0), device=img.device)

        loss_img2txt = F.cross_entropy(logits, labels)
        loss_txt2img = F.cross_entropy(logits.T, labels)

        return (loss_img2txt + loss_txt2img) / 2


class HardNegativeMiningLoss(torch.nn.Module):
    def __init__(self, temperature=0.07, hard_ratio=0.8):
        super().__init__()
        self.temperature = temperature
        self.hard_ratio = hard_ratio

    def forward(self, img_feat, pos_feat, neg_feat):
        """
        Contrastive loss with hard negative mining
        Args:
            img_feat: [B, D] - Image features
            pos_feat: [B, D] - Positive text features
            neg_feat: [B, N, D] - Negative text features
        """
        B, N, D = neg_feat.shape

        # Normalize features
        img_norm = F.normalize(img_feat, dim=-1)
        pos_norm = F.normalize(pos_feat, dim=-1)
        neg_norm = F.normalize(neg_feat, dim=-1)

        # Calculate positive sample similarity
        pos_sim = torch.sum(img_norm * pos_norm, dim=-1, keepdim=True) / self.temperature

        # Calculate negative sample similarity
        neg_sim = torch.bmm(img_norm.unsqueeze(1), neg_norm.transpose(1, 2)).squeeze(1) / self.temperature