import torch
import torch.nn.functional as F
from train_utils import ce_loss


class Get_Scalar:
    """Callable holder that hands back one fixed value on every query.

    The iteration argument accepted by ``get_value`` and ``__call__`` is
    ignored; both always return the value given at construction time.
    """

    def __init__(self, value):
        # Kept as-is; this class never modifies it afterwards.
        self.value = value

    def __call__(self, iter):
        """Return the stored value; ``iter`` is accepted but not used."""
        stored = self.value
        return stored

    def get_value(self, iter):
        """Return the stored value; ``iter`` is accepted but not used."""
        stored = self.value
        return stored


def consistency_loss(logits_s, logits_w, name='ce', rho=float('inf'), use_hard_labels=True, **kwargs):
    """Compute a masked consistency loss of ``logits_s`` against pseudo-labels.

    Pseudo-labels are built from ``logits_w`` (detached, so no gradient flows
    through them). The per-sample loss of ``logits_w`` against its own
    pseudo-label decides which samples are kept: a sample contributes only
    when that loss is ``<= rho``.

    NOTE(review): ``logits_s`` / ``logits_w`` are presumably predictions for
    strongly / weakly augmented views of the same batch, and ``ce_loss`` is
    assumed to return one loss value per sample when called with
    ``reduction='none'`` -- confirm against ``train_utils.ce_loss``.

    Args:
        logits_s: Logits that receive the training signal.
        logits_w: Logits used to build pseudo-labels and the selection mask.
        name: Loss type. Only ``'ce'`` is implemented; ``'L2'`` passes the
            name check but is not implemented.
        rho: Threshold on the per-sample loss of ``logits_w``; samples with
            a larger loss are masked out. The default ``inf`` keeps all
            samples whose loss is not NaN.
        use_hard_labels: If truthy, pseudo-labels are ``argmax`` class
            indices; otherwise they are ``softmax(logits_w / 0.5)``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        A tuple ``(masked_loss, mask_ratio)``: the sum of the kept
        per-sample losses divided by the number of kept samples (zero when
        no sample is kept), and the mean of the 0/1 mask.

    Raises:
        AssertionError: If ``name`` is neither ``'ce'`` nor ``'L2'``.
        NotImplementedError: If ``name`` is ``'L2'``.
    """
    assert name in ['ce', 'L2']

    if name != 'ce':
        # The original `assert Exception(...)` could never fail (an exception
        # instance is truthy), so this path silently returned None.
        raise NotImplementedError('Not Implemented consistency_loss')

    logits_w = logits_w.detach()

    hard = bool(use_hard_labels)
    if hard:
        pseudo_label = torch.argmax(logits_w, dim=-1)
    else:
        # Dividing by 0.5 before the softmax sharpens the distribution.
        pseudo_label = torch.softmax(logits_w / 0.5, dim=-1)

    loss_s = ce_loss(logits_s, pseudo_label, use_hard_labels=hard, reduction='none')
    loss_w = ce_loss(logits_w, pseudo_label, use_hard_labels=hard, reduction='none')

    mask = loss_w.le(rho).to(logits_s.dtype).detach()
    # The mask holds only 0/1, so clamping the count to >= 1 changes nothing
    # unless every sample is rejected, where it avoids a 0/0 NaN loss.
    kept = mask.sum().clamp(min=1)
    masked_loss = (loss_s * mask).sum() / kept
    return masked_loss, mask.mean()
            
