import torch 
import torch.nn as nn 
from torch.nn import functional as F


def smooth_targets(logits, targets, smoothing=0.1):
    """
    Build label-smoothed target distributions.

    The target class of each row gets probability ``1 - smoothing``; every
    other class gets ``smoothing / (num_classes - 1)``. The result is built
    without gradient tracking.

    Args:
        logits: tensor whose last dimension is the number of classes; only its
            shape/dtype/device are used (rows are indexed along dim 0, classes
            along dim 1).
        targets: integer class indices, one per row of ``logits``.
        smoothing: probability mass spread over the non-target classes.

    Returns:
        Tensor shaped like ``logits`` holding the smoothed distributions.
    """
    num_classes = logits.shape[-1]
    off_value = smoothing / (num_classes - 1)
    with torch.no_grad():
        smoothed = torch.full_like(logits, off_value)
        index = targets.data.unsqueeze(1)
        smoothed.scatter_(1, index, (1 - smoothing))
    return smoothed


def ce_loss(logits, targets, reduction='none'):
    """
    Cross-entropy wrapper supporting both hard and soft targets.

    Args:
        logits: logit values, shape=[Batch size, # of classes]
        targets: either integer class indices, shape=[Batch size], or a
            per-class target distribution (one-hot / smoothed / soft
            pseudo-label), shape=[Batch size, # of classes]
        reduction: 'none' | 'mean' | 'sum', with the same meaning as in
            ``torch.nn.functional.nll_loss``

    Returns:
        Per-sample losses of shape [Batch size] when ``reduction='none'``,
        otherwise a scalar tensor.

    Raises:
        ValueError: for an unsupported ``reduction`` with soft targets
            (``F.nll_loss`` raises the same for hard targets).
    """
    log_pred = F.log_softmax(logits, dim=-1)

    if logits.shape != targets.shape:
        # hard (class-index) targets: defer to PyTorch's NLL loss
        return F.nll_loss(log_pred, targets, reduction=reduction)

    # soft targets: cross-entropy against the full target distribution,
    # summed over the same (class) axis the log_softmax was taken over
    nll_loss = torch.sum(-targets * log_pred, dim=-1)
    if reduction == 'none':
        return nll_loss
    if reduction == 'mean':
        return nll_loss.mean()
    if reduction == 'sum':
        return nll_loss.sum()
    raise ValueError(
        f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}"
    )


def consistency_loss(logits, targets, name='ce', mask=None):
    """
    Wrapper for the consistency-regularization loss in semi-supervised learning.

    Args:
        logits: logits to compute the loss on and back-propagate through,
            usually those of the strongly-augmented unlabeled samples
        targets: pseudo-labels (either hard labels or soft labels)
        name: 'ce' for cross-entropy or 'mse' for mean-squared error between
            the softmax probabilities and ``targets``
        mask: optional per-sample weights multiplied into the per-sample
            loss, typically a confidence mask; ``None`` keeps every sample

    Returns:
        Scalar tensor: mean of the (masked) per-sample losses over the batch.

    Raises:
        ValueError: if ``name`` is not 'ce' or 'mse'.
    """
    # explicit check instead of `assert`, which is stripped under `python -O`
    if name not in ('ce', 'mse'):
        raise ValueError(f"name must be 'ce' or 'mse', got {name!r}")

    if name == 'mse':
        probs = torch.softmax(logits, dim=-1)
        loss = F.mse_loss(probs, targets, reduction='none').mean(dim=1)
    else:
        loss = ce_loss(logits, targets, reduction='none')

    if mask is not None:
        # multiplicative masking; PyTorch type promotion turns a bool mask
        # into 0/1 here, and a float mask acts as per-sample weights.
        # Masked-out samples still count in the batch mean below.
        loss = loss * mask

    return loss.mean()


# ================================================================================================================= #

def ova_loss(logits_out_w, logits_out_s, label, w_neg_ratio, use_hard_negative=False):
    """
    One-vs-all (OVA) loss on labeled data with optional hard negative mining.

    Each logit tensor is viewed as [B, 2, C] and softmax-ed over dim 1, i.e.
    one binary classifier per class. Index 0 of dim 1 is treated as the
    "outlier" score and index 1 as the "inlier" score. The positive term pushes
    the GT class's classifier towards inlier (weak view only). The negative term
    pushes non-GT classifiers towards outlier, mixing the weak and strong views.

    Args:
        logits_out_w: weak-view OVA logits, shape [B, 2*C]
        logits_out_s: strong-view OVA logits, shape [B, 2*C]
        label: ground-truth class indices, shape [B] (int64)
        w_neg_ratio: weight of the weak-view negative term; the strong-view
            negative term gets ``1 - w_neg_ratio``
        use_hard_negative: if True, apply the negative term only to the
            non-GT class with the highest weak-view inlier log-probability

    Returns:
        Scalar tensor: batch-mean positive loss + batch-mean negative loss.
    """
    # reshape to [B, 2, C]
    logits_out_w = logits_out_w.view(logits_out_w.size(0), 2, -1)
    logits_out_s = logits_out_s.view(logits_out_s.size(0), 2, -1)
    B, _, C = logits_out_w.shape

    # one-hot label built directly with the right dtype/device
    label_ = torch.zeros((B, C), dtype=torch.long, device=label.device)
    label_.scatter_(1, label.view(-1, 1), 1)  # GT class = 1
    label_neg = 1 - label_

    log_probs_w = F.log_softmax(logits_out_w, dim=1)  # [B, 2, C]
    log_probs_s = F.log_softmax(logits_out_s, dim=1)  # [B, 2, C]

    # positive loss: GT class -> inlier (weak view only)
    loss_pos = torch.mean((-log_probs_w[:, 1, :] * label_).sum(dim=1))

    # negative loss mask
    if use_hard_negative:
        # hardest negative = non-GT class with highest weak inlier log-prob
        log_prob_inlier = log_probs_w[:, 1, :].masked_fill(label_ == 1, float('-inf'))
        hard_neg_idx = log_prob_inlier.argmax(dim=1, keepdim=True)  # [B, 1]
        # intersect with label_neg: when no non-GT class exists (C == 1) every
        # entry is -inf and argmax returns index 0, i.e. the GT class itself,
        # which must never be penalised as a negative
        neg_mask = torch.zeros_like(label_).scatter_(1, hard_neg_idx, 1) * label_neg
    else:
        # use all non-GT classes as negatives
        neg_mask = label_neg

    loss_neg_w = -log_probs_w[:, 0, :] * neg_mask
    loss_neg_s = -log_probs_s[:, 0, :] * neg_mask

    # weighted average of weak / strong views
    loss_neg_ws = (w_neg_ratio * loss_neg_w) + ((1 - w_neg_ratio) * loss_neg_s)
    loss_neg = torch.mean(loss_neg_ws.sum(dim=1))

    return loss_pos + loss_neg


def ova_ulb(logits_out_u, neg_mask):
    """
    OVA negative loss for unlabeled samples.

    The logits are viewed as [B, 2, C] and softmax-ed over dim 1. The outlier
    negative log-likelihood (index 0 of dim 1) is weighted by ``neg_mask``,
    summed over classes and averaged over the batch.

    Args:
        logits_out_u: OVA logits, shape [B, 2*C]
        neg_mask: weights for each per-class outlier NLL; must broadcast
            against [B, C] (presumably a 0/1 mask of classes to treat as
            negatives -- confirm with the caller)

    Returns:
        Scalar tensor.
    """
    batch_size = logits_out_u.size(0)
    ova_log_probs = F.log_softmax(logits_out_u.view(batch_size, 2, -1), dim=1)
    outlier_nll = -ova_log_probs[:, 0, :]
    per_sample = (outlier_nll * neg_mask).sum(dim=1)
    return per_sample.mean()

    
def ova_ent(logits_out):
    """
    Mean entropy of the per-class OVA (two-way) distributions.

    The logits are viewed as [B, 2, C] and softmax-ed over dim 1. The binary
    entropy of each class's distribution (with 1e-8 added inside the log) is
    averaged over classes and then over the batch.

    Args:
        logits_out: OVA logits, shape [B, 2*C]

    Returns:
        Scalar tensor.
    """
    probs = F.softmax(logits_out.view(logits_out.size(0), 2, -1), 1)
    plogp = -probs * torch.log(probs + 1e-8)
    entropy_per_class = plogp.sum(1)  # [B, C]
    return entropy_per_class.mean(1).mean()


def ova_socr(logits_out_w0, logits_out_w1):
    """
    Consistency between two OVA predictions (squared difference of probabilities).

    Both inputs are viewed as [B, 2, C] and softmax-ed over dim 1. The squared
    absolute difference is summed over the two-way axis and over classes, then
    averaged over the batch.

    Args:
        logits_out_w0: OVA logits of the first view, shape [B, 2*C]
        logits_out_w1: OVA logits of the second view, shape [B, 2*C]

    Returns:
        Scalar tensor.
    """
    def _two_way_probs(raw):
        return F.softmax(raw.view(raw.size(0), 2, -1), 1)

    gap = _two_way_probs(logits_out_w0) - _two_way_probs(logits_out_w1)
    squared = torch.abs(gap) ** 2
    return squared.sum(1).sum(1).mean()


# ================================================================================================================= #

def supervised_contrastive_loss(features, labels, T=0.5, eps=1e-8):
    """
    Supervised contrastive loss over a batch of embeddings.

    Let s_ij be the cosine similarity of the L2-normalised features divided
    by ``T``. Let P(i) be the *other* samples sharing anchor i's label and
    A(i) all other samples. Then

        loss_i = -log( sum_{p in P(i)} exp(s_ip) / sum_{a in A(i)} exp(s_ia) )
                 / max(|P(i)|, 1)

    and the result is the mean over anchors. The positive sum is aggregated
    inside a single log per anchor, which differs from the per-positive
    formulation of the SupCon paper (Khosla et al., 2020).

    Args:
        features: embeddings, shape [B, D]
        labels: class labels, shape [B]
        T: temperature
        eps: lower bound on the positive sum, so anchors with no positive
            in the batch still give a finite loss

    Returns:
        Scalar tensor.
    """
    features = F.normalize(features, dim=1)
    similarities = torch.matmul(features, features.T) / T

    batch_size = labels.size(0)
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=similarities.device)
    labels = labels.unsqueeze(1)  # (batch_size, 1)
    same_label = torch.eq(labels, labels.T)
    # positives exclude the anchor itself; negatives are different-label samples
    pos_mask = (same_label & ~self_mask).float()
    neg_mask = (~same_label).float()

    # exp-similarities with self-similarity zeroed (masked_fill avoids inf * 0)
    logits = torch.exp(similarities).masked_fill(self_mask, 0.0)

    pos_sim = torch.clamp((logits * pos_mask).sum(dim=1), min=eps)
    neg_sim = (logits * neg_mask).sum(dim=1)

    # number of positives per anchor (self excluded); clamp to 1 so anchors
    # without positives do not divide by zero
    pos_pair_count = pos_mask.sum(dim=1).clamp(min=1.0)

    loss = -torch.log(pos_sim / (pos_sim + neg_sim)) / pos_pair_count
    return loss.mean()


# ================================================================================================================= #

def proto_contrastive_lb_loss(prototypes, feat_w, labels, T=0.5):
    """
    Prototype-contrastive (InfoNCE) loss for labeled samples.

    Features and prototypes are L2-normalised, and their dot products are
    divided by ``T``. Each sample's loss is
    ``logsumexp_k s_ik - s_{i, labels_i}``: the cross-entropy of the
    similarity logits against the labeled prototype, averaged over the batch.

    Args:
        prototypes: one prototype per class, shape [K, D]
        feat_w: features, shape [B, D]
        labels: class indices into ``prototypes``, shape [B]
        T: temperature

    Returns:
        Scalar tensor.
    """
    feats_unit = F.normalize(feat_w, dim=1)
    protos_unit = F.normalize(prototypes, dim=1)
    logits = torch.matmul(feats_unit, protos_unit.T) / T

    target_logit = torch.gather(logits, 1, labels.unsqueeze(1))
    log_partition = torch.logsumexp(logits, dim=1, keepdim=True)
    return (log_partition - target_logit).mean()


def proto_contrastive_ulb_loss(prototypes, feat_w, 
                               p, in_p, 
                               p_th=0.99, od_in_th=0.5, T=0.5):
    """
    Prototype-contrastive loss for unlabeled samples with thresholded hard pseudo-labels.

    The pseudo-label of each sample is ``argmax(p)``. Its prototype similarity
    is used as the positive only where ``p > p_th`` and ``in_p > od_in_th``.
    Otherwise the positive term is 0. Each sample's loss is
    ``logsumexp_k s_ik - positive``, averaged over the whole batch.

    NOTE(review): samples that fail the thresholds are not dropped. They still
    contribute ``logsumexp_k s_ik`` to the mean, which pushes them away from
    every prototype. Confirm this is intended.

    Args:
        prototypes: one prototype per class, shape [K, D]
        feat_w: features, shape [B, D]
        p: class probabilities, shape [B, K]
        in_p: inlier scores compared against ``od_in_th``; must broadcast
            against [B, K] (exact shape depends on the caller -- confirm)
        p_th: confidence threshold on ``p``
        od_in_th: threshold on ``in_p``
        T: temperature

    Returns:
        Scalar tensor.
    """
    feats_unit = F.normalize(feat_w, dim=1)
    protos_unit = F.normalize(prototypes, dim=1)
    sim = torch.matmul(feats_unit, protos_unit.T) / T

    num_classes = p.shape[1]
    is_pseudo = F.one_hot(p.argmax(dim=1), num_classes=num_classes).bool()
    confident = p > p_th
    inlier = in_p > od_in_th
    keep = is_pseudo * confident * inlier

    positive = (sim * keep.float()).sum(dim=1, keepdim=True)
    log_partition = torch.logsumexp(sim, dim=1, keepdim=True)
    return (log_partition - positive).mean()


# ================================================================================================================= #

def BYOL_loss(p, z):
    """
    BYOL-style regression loss: ``2 - 2 * mean cosine similarity(p, z)``.

    Both inputs are L2-normalised along dim 1 before the row-wise dot product.
    ``z`` is not detached here, so any stop-gradient must be applied by the
    caller.

    Args:
        p: predictions, shape [B, D]
        z: targets, shape [B, D]

    Returns:
        Scalar tensor in [0, 4].
    """
    pred_unit = F.normalize(p, dim=1, p=2)
    target_unit = F.normalize(z, dim=1, p=2)
    cosine = (pred_unit * target_unit).sum(dim=1)
    return 2 - 2 * cosine.mean()

# ================================================================================================================= #