
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch 
import torch.nn as nn 
from torch.nn import functional as F

from .cross_entropy import ce_loss


def one_hot(targets, nClass, gpu=None):
    """
    Encode integer class labels as float one-hot rows.

    Args:
        targets: 1-D integer tensor of class indices in ``[0, nClass)``.
        nClass: number of classes (width of the output).
        gpu: CUDA device to move the result to (anything ``Tensor.cuda`` accepts).
            If None (the default), the result stays on ``targets``' device, so CPU
            tensors work as well.

    Returns:
        Tensor of shape ``(targets.size(0), nClass)`` in the default float dtype with a
        1 in each row's target column and 0 elsewhere.
    """
    # Build next to the index tensor: scatter_ needs its index and output on the same
    # device, so allocating directly on a fixed GPU breaks for CPU / other-GPU labels.
    encoded = torch.zeros(targets.size(0), nClass, device=targets.device)
    encoded.scatter_(1, targets.long().unsqueeze(1), 1)
    return encoded if gpu is None else encoded.cuda(gpu)



def _masked_one_hot(labels, sample_mask, n_class, like):
    """One-hot encode ``labels`` on ``like``'s device and zero rows where ``sample_mask`` is 0."""
    # Default float dtype matches what torch.zeros-based encoding produced; no device
    # index is hard-coded, so this works on CPU and on any GPU.
    hard = F.one_hot(labels.long(), n_class).to(device=like.device, dtype=torch.get_default_dtype())
    return hard * sample_mask.view(-1, 1)


def dul_consistency_loss(logits, c_logits, targets, c_targets, weight, name='ce', mask=None, c_mask=None,
                         is_mm=False, index=None, lam=None, it=None, switch_it=30000, early_weight=0.3):
    """
    Dual consistency loss: supervise two sets of logits with one blended soft target.

    The hard labels ``targets`` (weighted by ``mask``) and ``c_targets`` (weighted by
    ``c_mask``) are one-hot encoded and blended. While ``it < switch_it`` the blend is
    ``early_weight * targets + (1 - early_weight) * c_targets``; afterwards the weights are
    swapped. The blended targets supervise ``logits`` directly. They also supervise
    ``c_logits``, optionally after mixup among masked-in samples when ``is_mm`` is set.

    Args:
        logits: (batch, n_class) logits for the first loss.
        c_logits: logits for the second loss; same shape as ``logits``.
        targets: integer class labels, shape (batch,).
        c_targets: second set of integer class labels, shape (batch,).
        weight: unused (entropy-based weighting is disabled); kept for call-site compatibility.
        name: only 'ce' is implemented. 'mse' and 'kl' pass the assert but raise
            NotImplementedError, since no dual formulation exists for them here.
        mask, c_mask: per-sample weights (must not be boolean) for ``targets`` / ``c_targets``.
            None means every sample is kept.
        is_mm: if True, mix up the soft targets used for ``c_logits``.
        index: row indices into the masked-in samples, giving each one its mixup partner.
        lam: mixup coefficient.
        it: iteration counter compared against ``switch_it`` (presumably the training step); required.
        switch_it: value of ``it`` at which the blend weights swap (default 30000).
        early_weight: weight of ``targets`` before ``switch_it`` (default 0.3).

    Returns:
        ``(loss, c_loss)``: batch means of the per-sample cross-entropy (via ``ce_loss``) of
        ``logits`` and ``c_logits`` against the blended targets. Each per-sample term is
        weighted by ``max(mask, c_mask)``.

    Raises:
        NotImplementedError: for ``name`` other than 'ce'.
        ValueError: if ``it`` is None.
    """
    assert name in ['ce', 'mse', 'kl']
    if name != 'ce':
        raise NotImplementedError("dul_consistency_loss only supports name='ce', got %r" % (name,))
    if it is None:
        raise ValueError("dul_consistency_loss requires the iteration counter `it`")

    batch, n_class = logits.shape[0], logits.shape[1]
    if mask is None:
        mask = torch.ones(batch, device=logits.device)
    if c_mask is None:
        c_mask = torch.ones(batch, device=logits.device)

    hard = _masked_one_hot(targets, mask, n_class, logits)
    c_hard = _masked_one_hot(c_targets, c_mask, n_class, logits)

    # Fixed blending schedule (the entropy-based `weight` variant is disabled).
    if it < switch_it:
        t_w, c_w = early_weight, 1 - early_weight
    else:
        t_w, c_w = 1 - early_weight, early_weight
    dul_targets = t_w * hard + c_w * c_hard

    # A sample contributes if either mask keeps it.
    dul_mask = torch.max(mask, c_mask)
    loss = ce_loss(logits, dul_targets, reduction='none')

    c_dul_targets = dul_targets
    if is_mm:
        keep = dul_mask.bool()
        c_dul_targets = dul_targets.clone()
        selected = dul_targets[keep]
        if selected.size(0) != 0:
            c_dul_targets[keep] = lam * selected + (1 - lam) * selected[index, :]
    c_loss = ce_loss(c_logits, c_dul_targets, reduction='none')

    # mask must not be boolean type
    loss = loss * dul_mask
    c_loss = c_loss * dul_mask
    return loss.mean(), c_loss.mean()



def consistency_loss(logits, targets, name='ce', mask=None):
    """
    Wrapper for consistency-regularization loss in semi-supervised learning.

    Args:
        logits: logits to compute the loss on and back-propagate through, usually from the
            strongly-augmented unlabeled samples; (batch, n_class).
        targets: pseudo-labels. For 'ce', hard or soft labels as accepted by ``ce_loss``.
            For 'mse', target probabilities with the same shape as ``logits``. For 'kl',
            target *logits*; a softmax with temperature 0.5 is applied here.
        name: 'ce' (cross-entropy), 'mse' (mean-squared error on softmax probabilities)
            or 'kl' (temperature-scaled KL divergence).
        mask: per-sample weights (must not be boolean type). For 'ce'/'mse' the per-sample
            loss is multiplied by ``mask``. For 'kl' it is multiplied by ``1 - mask``, so the
            KL term acts on the samples the mask excludes. Required for 'kl'.

    Returns:
        Scalar tensor: the batch mean of the (weighted) per-sample losses.

    Raises:
        ValueError: if ``name == 'kl'`` and ``mask`` is None.
    """
    assert name in ['ce', 'mse', 'kl']
    if name == 'mse':
        probs = torch.softmax(logits, dim=-1)
        loss = F.mse_loss(probs, targets, reduction='none').mean(dim=1)
    elif name == 'kl':
        if mask is None:
            raise ValueError("consistency_loss(name='kl') requires `mask`")
        # Dividing by 0.5 sharpens both distributions before comparing them.
        loss = F.kl_div(F.log_softmax(logits / 0.5, dim=-1), F.softmax(targets / 0.5, dim=-1), reduction='none')
        # Broadcast the per-sample weight across classes, then reduce to one value per sample.
        loss = torch.sum(loss * (1.0 - mask).reshape(-1, 1), dim=1)
    else:
        loss = ce_loss(logits, targets, reduction='none')

    if mask is not None and name != 'kl':
        # mask must not be boolean type
        loss = loss * mask

    return loss.mean()



class ConsistencyLoss(nn.Module):
    """
    ``nn.Module`` front-end for ``consistency_loss``.

    Defines no parameters or buffers of its own; ``forward`` hands its arguments
    straight to ``consistency_loss`` and returns the resulting loss.
    """

    def forward(self, logits, targets, name='ce', mask=None):
        """Compute ``consistency_loss(logits, targets, name, mask)``."""
        result = consistency_loss(logits, targets, name=name, mask=mask)
        return result