import numpy as np
import torch
import torch.nn.functional as F
from ignite.utils import convert_tensor


class ClassifierUpdater:
    """Supervised-only training step, invoked as ``updater(engine, batch)``.

    Keyword Args:
        classifier: module mapping an input batch to class logits.
        optimizer: optimizer that steps the classifier's parameters.
        device: device the batch is moved to before the forward pass.
        ema_model: object exposing ``update(model)``, or ``None`` to disable it.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.opt = kwargs.pop("optimizer")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.loss = F.cross_entropy

    def get_batch(self, batch, device=None, non_blocking=True):
        """Unpack ``(x, y)`` from ``batch`` and move both to ``device``."""
        x, y = batch
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
        )

    def _optimize(self, loss):
        """Backpropagate ``loss``, step the optimizer and refresh the EMA model."""
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        # The other updaters in this file refresh the EMA copy after every
        # step; here ``ema_model`` was stored but never updated.
        if self.ema_model is not None:
            self.ema_model.update(self.classifier)

    def __call__(self, engine, batch):
        """Run one optimisation step on a labeled batch.

        Returns:
            dict with ``y_pred`` (detached logits), ``y`` (targets), ``loss``
            (float cross-entropy) and ``entropy`` (float mean prediction
            entropy over the batch).
        """
        self.classifier.train()
        x, y = self.get_batch(batch, device=self.device)
        y_pred = self.classifier(x)
        loss = self.loss(y_pred, y)
        # Entropy is reported only, so it is computed on detached logits.
        logits = y_pred.detach()
        entropy = -(F.softmax(logits, dim=1) * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
        self._optimize(loss)
        return {
            "y_pred": logits,
            "y": y.detach(),
            "loss": loss.detach().item(),
            "entropy": entropy.item(),
        }


class PseudoSoftLabelClassifierUpdater:
    """Training step adding a soft pseudo-label loss on unlabeled data.

    The unlabeled loss is the KL divergence between the classifier's prediction
    and its own detached, temperature-scaled prediction. It is enabled once
    ``engine.state.epoch`` exceeds ``t1`` and weighted by :meth:`pseudo_weight`.

    Keyword Args:
        classifier: module mapping inputs to class logits; it (or its
            ``.module`` under ``DataParallel``) must expose ``num_classes``.
        optimizer: optimizer that steps the classifier's parameters.
        device: device the batch is moved to.
        ema_model: object exposing ``update(model)``, or ``None``.
        af: final weight of the unlabeled loss.
        t1: epoch after which the unlabeled loss is switched on.
        t2: epoch from which the weight stays at ``af``.
        temperature: divisor applied to the logits that form the soft targets.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.opt = kwargs.pop("optimizer")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.af = kwargs.pop("af")
        self.t1 = kwargs.pop("t1")
        self.t2 = kwargs.pop("t2")
        self.temperature = kwargs.pop("temperature")
        self.loss = F.cross_entropy
        if isinstance(self.classifier, torch.nn.DataParallel):
            self.num_classes = self.classifier.module.num_classes
        else:
            self.num_classes = self.classifier.num_classes

    def get_batch(self, batch, device=None, non_blocking=True):
        """Unpack ``((x, y), (x_p, y_p))`` and move all tensors to ``device``.

        A 5-D ``x_p`` of shape ``(b, ub, c, h, w)`` is flattened to
        ``(b * ub, c, h, w)``, and ``y_p`` to ``(b * ub,)``.
        """
        (x, y), (x_p, y_p) = batch
        if x_p.dim() == 5:
            b, ub, c, h, w = x_p.size()
            x_p = x_p.view(b * ub, c, h, w)
            y_p = y_p.view(b * ub)
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
            convert_tensor(x_p, device=device, non_blocking=non_blocking),
            convert_tensor(y_p, device=device, non_blocking=non_blocking),
        )

    def pseudo_weight(self, epoch):
        """Return the unlabeled-loss weight for ``epoch``.

        The weight is ``0`` up to ``t1``, ramps linearly to ``af`` at ``t2`` and
        stays at ``af`` afterwards.
        """
        # Test the plateau first so ``t1 == t2`` never divides by zero.
        if epoch >= self.t2:
            return self.af
        if epoch <= self.t1:
            return 0.0
        return (epoch - self.t1) / (self.t2 - self.t1) * self.af

    def _unlabeled_loss(self, logits):
        """KL divergence between ``logits`` and their detached, scaled copy."""
        scaled = (logits.detach() / self.temperature).to(self.device)
        targets = F.softmax(scaled, dim=1)
        # Default "mean" reduction is kept on purpose so the loss scale, and
        # therefore the meaning of ``af``, is unchanged.
        return F.kl_div(F.log_softmax(logits, dim=1), targets)

    def __call__(self, engine, batch):
        """Run one optimisation step.

        Returns:
            dict with ``y_pred``, ``y``, ``loss`` (supervised loss only) and
            ``loss_u`` (unlabeled loss, or the placeholder ``1000.0`` while the
            unlabeled loss is still disabled).
        """
        report = {}
        current_epoch = engine.state.epoch
        self.classifier.train()
        # The labels shipped with the unlabeled batch are not used.
        x, y, x_p, _ = self.get_batch(batch, device=self.device)
        y_pred = self.classifier(x)
        loss = self.loss(y_pred, y)
        report["y_pred"] = y_pred.detach()
        report["y"] = y.detach()
        report["loss"] = loss.detach().item()
        report["loss_u"] = 1000.0
        if self.t1 < current_epoch:
            pseudo_loss = self._unlabeled_loss(self.classifier(x_p))
            report["loss_u"] = pseudo_loss.detach().item()
            loss = loss + self.pseudo_weight(current_epoch) * pseudo_loss
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        if self.ema_model is not None:
            self.ema_model.update(self.classifier)
        return report


class PseudoHardLabelClassifierUpdater:
    """Training step adding a hard pseudo-label loss on unlabeled data.

    The classifier's own arg-max prediction on the unlabeled batch is used as
    the cross-entropy target. The term is enabled once ``engine.state.epoch``
    exceeds ``t1`` and weighted by :meth:`pseudo_weight`.

    Keyword Args:
        classifier: module mapping inputs to class logits; it (or its
            ``.module`` under ``DataParallel``) must expose ``num_classes``.
        optimizer: optimizer that steps the classifier's parameters.
        device: device the batch is moved to.
        ema_model: object exposing ``update(model)``, or ``None``.
        af: final weight of the unlabeled loss.
        t1: epoch after which the unlabeled loss is switched on.
        t2: epoch from which the weight stays at ``af``.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.opt = kwargs.pop("optimizer")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.af = kwargs.pop("af")
        self.t1 = kwargs.pop("t1")
        self.t2 = kwargs.pop("t2")
        self.loss = F.cross_entropy
        if isinstance(self.classifier, torch.nn.DataParallel):
            self.num_classes = self.classifier.module.num_classes
        else:
            self.num_classes = self.classifier.num_classes

    def get_batch(self, batch, device=None, non_blocking=True):
        """Unpack ``((x, y), (x_p, y_p))`` and move all tensors to ``device``.

        A 5-D ``x_p`` of shape ``(b, ub, c, h, w)`` is flattened to
        ``(b * ub, c, h, w)``, and ``y_p`` to ``(b * ub,)``.
        """
        (x, y), (x_p, y_p) = batch
        if x_p.dim() == 5:
            b, ub, c, h, w = x_p.size()
            x_p = x_p.view(b * ub, c, h, w)
            y_p = y_p.view(b * ub)
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
            convert_tensor(x_p, device=device, non_blocking=non_blocking),
            convert_tensor(y_p, device=device, non_blocking=non_blocking),
        )

    def pseudo_weight(self, epoch):
        """Return the unlabeled-loss weight for ``epoch``.

        The weight is ``0`` up to ``t1``, ramps linearly to ``af`` at ``t2`` and
        stays at ``af`` afterwards.
        """
        # Test the plateau first so ``t1 == t2`` never divides by zero.
        if epoch >= self.t2:
            return self.af
        if epoch <= self.t1:
            return 0.0
        return (epoch - self.t1) / (self.t2 - self.t1) * self.af

    def _unlabeled_loss(self, logits):
        """Cross-entropy of ``logits`` against their own arg-max class."""
        _, targets = logits.detach().max(dim=-1)
        return self.loss(logits, targets.to(self.device))

    def __call__(self, engine, batch):
        """Run one optimisation step.

        Returns:
            dict with ``y_pred``, ``y``, ``loss`` (supervised loss only) and
            ``loss_u`` (unlabeled loss, or the placeholder ``1000.0`` while the
            unlabeled loss is still disabled).
        """
        report = {}
        current_epoch = engine.state.epoch
        self.classifier.train()
        # The labels shipped with the unlabeled batch are not used.
        x, y, x_p, _ = self.get_batch(batch, device=self.device)
        y_pred = self.classifier(x)
        loss = self.loss(y_pred, y)
        report["y_pred"] = y_pred.detach()
        report["y"] = y.detach()
        report["loss"] = loss.detach().item()
        report["loss_u"] = 1000.0
        if self.t1 < current_epoch:
            pseudo_loss = self._unlabeled_loss(self.classifier(x_p))
            report["loss_u"] = pseudo_loss.detach().item()
            loss = loss + self.pseudo_weight(current_epoch) * pseudo_loss
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        if self.ema_model is not None:
            self.ema_model.update(self.classifier)
        return report


class EntropyMinimizerClassifierUpdater:
    """Training step adding an entropy-minimisation loss on unlabeled data.

    Keyword Args:
        classifier: module mapping inputs to class logits.
        optimizer: optimizer that steps the classifier's parameters.
        device: device the batch is moved to.
        ema_model: object exposing ``update(model)``, or ``None``.
        lambda_p: weight of the entropy term.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.opt = kwargs.pop("optimizer")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        self.loss = F.cross_entropy

    def get_batch(self, batch, device=None, non_blocking=True):
        """Unpack ``((x, y), (x_p, y_p))`` and move all tensors to ``device``.

        A 5-D ``x_p`` of shape ``(b, ub, c, h, w)`` is flattened to
        ``(b * ub, c, h, w)``, and ``y_p`` to ``(b * ub,)``.
        """
        (x, y), (x_p, y_p) = batch
        if x_p.dim() == 5:
            b, ub, c, h, w = x_p.size()
            x_p = x_p.view(b * ub, c, h, w)
            y_p = y_p.view(b * ub)
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
            convert_tensor(x_p, device=device, non_blocking=non_blocking),
            convert_tensor(y_p, device=device, non_blocking=non_blocking),
        )

    def _unlabeled_loss(self, logits):
        """Mean Shannon entropy of the softmax over ``logits`` (classes on dim 1)."""
        # ``dim`` is explicit: the implicit choice is deprecated and must agree
        # with the ``sum(dim=1)`` that follows.
        probs = F.softmax(logits, dim=1)
        log_probs = F.log_softmax(logits, dim=1)
        return -(probs * log_probs).sum(dim=1).mean()

    def __call__(self, engine, batch):
        """Run one optimisation step.

        Returns:
            dict with ``y_pred``, ``y``, ``loss`` (supervised loss only) and
            ``loss_u`` (entropy of the predictions on the unlabeled batch).
        """
        report = {}
        self.classifier.train()
        x, y, x_p, _ = self.get_batch(batch, device=self.device)
        y_pred = self.classifier(x)
        loss = self.loss(y_pred, y)
        report["y_pred"] = y_pred.detach()
        report["y"] = y.detach()
        report["loss"] = loss.detach().item()
        pseudo_entropy_loss = self._unlabeled_loss(self.classifier(x_p))
        report["loss_u"] = pseudo_entropy_loss.detach().item()
        loss = loss + self.lambda_p * pseudo_entropy_loss
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        if self.ema_model is not None:
            self.ema_model.update(self.classifier)
        return report


class ConsistencyRegularizationClassifierUpdater:
    """Training step adding an MSE consistency loss between two unlabeled views.

    The detached, temperature-scaled prediction on the first view ``x_p_w`` is
    the target for the prediction on the second view ``x_p_s``.

    Keyword Args:
        classifier: module mapping inputs to class logits; it (or its
            ``.module`` under ``DataParallel``) must expose ``num_classes``.
        optimizer: optimizer that steps the classifier's parameters.
        device: device the batch is moved to.
        ema_model: object exposing ``update(model)``, or ``None``.
        lambda_p: weight of the consistency term.
        temperature: divisor applied to the target logits.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.opt = kwargs.pop("optimizer")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        self.temperature = kwargs.pop("temperature")
        self.loss = F.cross_entropy
        if isinstance(self.classifier, torch.nn.DataParallel):
            self.num_classes = self.classifier.module.num_classes
        else:
            self.num_classes = self.classifier.num_classes

    def get_batch(self, batch, device=None, non_blocking=True):
        """Unpack ``((x, y), (x_p_w, x_p_s))`` and move all tensors to ``device``.

        5-D unlabeled views of shape ``(b, ub, c, h, w)`` are flattened to
        ``(b * ub, c, h, w)``.
        """
        (x, y), (x_p_w, x_p_s) = batch
        if x_p_w.dim() == 5:
            b, ub, c, h, w = x_p_w.size()
            x_p_w = x_p_w.view(b * ub, c, h, w)
            x_p_s = x_p_s.view(b * ub, c, h, w)
            assert x_p_w.size() == x_p_s.size()
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_w, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_s, device=device, non_blocking=non_blocking),
        )

    def _unlabeled_loss(self, logits_weak, logits_strong):
        """MSE between the softmax of both views; the first view is the target."""
        scaled = (logits_weak.detach() / self.temperature).to(self.device)
        # ``dim`` is explicit because the implicit choice is deprecated.
        targets = F.softmax(scaled, dim=1)
        preds = F.softmax(logits_strong, dim=1)
        return F.mse_loss(preds, targets)

    def __call__(self, engine, batch):
        """Run one optimisation step.

        Returns:
            dict with ``y_pred``, ``y``, ``loss`` (supervised loss only) and
            ``loss_u`` (consistency loss).
        """
        report = {}
        self.classifier.train()
        x, y, x_p_w, x_p_s = self.get_batch(batch, device=self.device)
        y_pred = self.classifier(x)
        loss = self.loss(y_pred, y)
        report["y_pred"] = y_pred.detach()
        report["y"] = y.detach()
        report["loss"] = loss.detach().item()
        # Forward passes stay in the original order: first view, then second.
        logits_p_w = self.classifier(x_p_w)
        logits_p_s = self.classifier(x_p_s)
        pseudo_loss = self._unlabeled_loss(logits_p_w, logits_p_s)
        report["loss_u"] = pseudo_loss.detach().item()
        loss = loss + self.lambda_p * pseudo_loss
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        if self.ema_model is not None:
            self.ema_model.update(self.classifier)
        return report


class FixMatchClassifierUpdater:
    """Training step adding a confidence-thresholded pseudo-label loss.

    The arg-max class predicted for the first unlabeled view ``x_p_w`` is the
    cross-entropy target for the second view ``x_p_s``. Samples whose maximum
    probability is below ``threshold`` contribute zero.

    Keyword Args:
        classifier: module mapping inputs to class logits; it (or its
            ``.module`` under ``DataParallel``) must expose ``num_classes``.
        optimizer: optimizer that steps the classifier's parameters.
        device: device the batch is moved to.
        ema_model: object exposing ``update(model)``, or ``None``.
        lambda_p: weight of the unlabeled loss.
        temperature: divisor applied to the target logits.
        threshold: minimum max-probability for a sample to be used.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.opt = kwargs.pop("optimizer")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        self.temperature = kwargs.pop("temperature")
        self.threshold = kwargs.pop("threshold")
        self.loss = F.cross_entropy
        if isinstance(self.classifier, torch.nn.DataParallel):
            self.num_classes = self.classifier.module.num_classes
        else:
            self.num_classes = self.classifier.num_classes

    def get_batch(self, batch, device=None, non_blocking=True):
        """Unpack ``((x, y), (x_p_w, x_p_s))`` and move all tensors to ``device``.

        5-D unlabeled views of shape ``(b, ub, c, h, w)`` are flattened to
        ``(b * ub, c, h, w)``.
        """
        (x, y), (x_p_w, x_p_s) = batch
        if x_p_w.dim() == 5:
            b, ub, c, h, w = x_p_w.size()
            x_p_w = x_p_w.view(b * ub, c, h, w)
            x_p_s = x_p_s.view(b * ub, c, h, w)
            assert x_p_w.size() == x_p_s.size()
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_w, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_s, device=device, non_blocking=non_blocking),
        )

    def _unlabeled_loss(self, logits_weak, logits_strong):
        """Masked cross-entropy of ``logits_strong`` against weak-view arg-max."""
        scaled = (logits_weak.detach() / self.temperature).to(self.device)
        # ``dim`` is explicit because the implicit choice is deprecated.
        probs = F.softmax(scaled, dim=1)
        max_probs, targets = torch.max(probs, dim=-1)
        mask = max_probs.ge(self.threshold).float()
        per_sample = F.cross_entropy(logits_strong, targets, reduction="none")
        # Mean over the whole batch, masked samples included as zeros.
        return (per_sample * mask).mean()

    def __call__(self, engine, batch):
        """Run one optimisation step.

        Returns:
            dict with ``y_pred``, ``y``, ``loss`` (supervised loss only) and
            ``loss_u`` (masked pseudo-label loss).
        """
        report = {}
        self.classifier.train()
        x, y, x_p_w, x_p_s = self.get_batch(batch, device=self.device)
        y_pred = self.classifier(x)
        loss = self.loss(y_pred, y)
        report["y_pred"] = y_pred.detach()
        report["y"] = y.detach()
        report["loss"] = loss.detach().item()
        # Forward passes stay in the original order: first view, then second.
        logits_p_w = self.classifier(x_p_w)
        logits_p_s = self.classifier(x_p_s)
        pseudo_loss = self._unlabeled_loss(logits_p_w, logits_p_s)
        report["loss_u"] = pseudo_loss.detach().item()
        loss = loss + self.lambda_p * pseudo_loss
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        if self.ema_model is not None:
            self.ema_model.update(self.classifier)
        return report


class GeneralizedFixMatchClassifierUpdater:
    """Training step adding a top-K-thresholded soft pseudo-label loss.

    The detached, temperature-scaled softmax of the first unlabeled view
    ``x_p_w`` is the KL target for the second view ``x_p_s``. A sample is kept
    only when its ``K`` largest target probabilities sum to at least
    ``threshold``.

    Keyword Args:
        classifier: module mapping inputs to class logits; it (or its
            ``.module`` under ``DataParallel``) must expose ``num_classes``.
        optimizer: optimizer that steps the classifier's parameters.
        device: device the batch is moved to.
        ema_model: object exposing ``update(model)``, or ``None``.
        lambda_p: weight of the unlabeled loss.
        temperature: divisor applied to the target logits.
        threshold: minimum top-K probability mass for a sample to be used.
        K: number of top classes whose probabilities are summed.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.opt = kwargs.pop("optimizer")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        self.temperature = kwargs.pop("temperature")
        self.threshold = kwargs.pop("threshold")
        self.K = kwargs.pop("K")
        self.loss = F.cross_entropy
        if isinstance(self.classifier, torch.nn.DataParallel):
            self.num_classes = self.classifier.module.num_classes
        else:
            self.num_classes = self.classifier.num_classes

    def get_batch(self, batch, device=None, non_blocking=True):
        """Unpack ``((x, y), (x_p_w, x_p_s))`` and move all tensors to ``device``.

        5-D unlabeled views of shape ``(b, ub, c, h, w)`` are flattened to
        ``(b * ub, c, h, w)``.
        """
        (x, y), (x_p_w, x_p_s) = batch
        if x_p_w.dim() == 5:
            b, ub, c, h, w = x_p_w.size()
            x_p_w = x_p_w.view(b * ub, c, h, w)
            x_p_s = x_p_s.view(b * ub, c, h, w)
            assert x_p_w.size() == x_p_s.size()
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_w, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_s, device=device, non_blocking=non_blocking),
        )

    def _unlabeled_loss(self, logits_weak, logits_strong):
        """Masked element-wise KL of ``logits_strong`` against weak-view targets."""
        scaled = (logits_weak.detach() / self.temperature).to(self.device)
        # ``dim`` is explicit because the implicit choice is deprecated.
        targets = F.softmax(scaled, dim=1)
        topk_probs, _ = torch.topk(targets, k=self.K, dim=-1)
        # The per-sample mask needs a trailing axis to broadcast over classes.
        # Without it a (B,) mask is aligned with the class axis of the (B, C)
        # KL tensor, which raises unless B == C and masks classes when B == C.
        mask = topk_probs.sum(dim=-1).ge(self.threshold).float()[:, None]
        log_preds = F.log_softmax(logits_strong, dim=1)
        return (F.kl_div(log_preds, targets, reduction="none") * mask).mean()

    def __call__(self, engine, batch):
        """Run one optimisation step.

        Returns:
            dict with ``y_pred``, ``y``, ``loss`` (supervised loss only) and
            ``loss_u`` (masked KL loss).
        """
        report = {}
        self.classifier.train()
        x, y, x_p_w, x_p_s = self.get_batch(batch, device=self.device)
        y_pred = self.classifier(x)
        loss = self.loss(y_pred, y)
        report["y_pred"] = y_pred.detach()
        report["y"] = y.detach()
        report["loss"] = loss.detach().item()
        # Forward passes stay in the original order: first view, then second.
        logits_p_w = self.classifier(x_p_w)
        logits_p_s = self.classifier(x_p_s)
        pseudo_loss = self._unlabeled_loss(logits_p_w, logits_p_s)
        report["loss_u"] = pseudo_loss.detach().item()
        loss = loss + self.lambda_p * pseudo_loss
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        if self.ema_model is not None:
            self.ema_model.update(self.classifier)
        return report


class UDAClassifierUpdater:
    """Training step with a thresholded KL consistency loss and optional TSA.

    The detached, temperature-scaled softmax of the first unlabeled view
    ``x_p_w`` is the KL target for the second view ``x_p_s``. Samples whose
    maximum target probability is below ``threshold`` are masked out. With
    ``tsa`` enabled, labeled samples the classifier is already confident about
    are dropped from the supervised loss.

    Keyword Args:
        classifier: module mapping inputs to class logits; it (or its
            ``.module`` under ``DataParallel``) must expose ``num_classes``.
        optimizer: optimizer that steps the classifier's parameters.
        device: device the batch is moved to.
        ema_model: object exposing ``update(model)``, or ``None``.
        lambda_p: weight of the unlabeled loss.
        temperature: divisor applied to the target logits.
        threshold: minimum max-probability for an unlabeled sample to be used.
        tsa: enable training signal annealing (default ``False``).
        warmup_epoch: the unlabeled loss is used only after this epoch
            (default ``0``).
        max_iter: iteration count at which the TSA threshold reaches 1.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.opt = kwargs.pop("optimizer")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        self.temperature = kwargs.pop("temperature")
        self.threshold = kwargs.pop("threshold")
        self.tsa = kwargs.pop("tsa", False)
        self.warmup_epoch = kwargs.pop("warmup_epoch", 0)
        self.max_iter = kwargs.pop("max_iter")
        self.loss = F.cross_entropy
        if isinstance(self.classifier, torch.nn.DataParallel):
            self.num_classes = self.classifier.module.num_classes
        else:
            self.num_classes = self.classifier.num_classes

    def training_signal_annealing(self, current_iter, method="exp"):
        """Return the TSA threshold as a 0-dim tensor.

        The threshold moves from ``1 / num_classes`` towards ``1`` as
        ``current_iter`` approaches ``max_iter``.

        Args:
            current_iter: current training iteration.
            method: schedule shape, one of ``"exp"``, ``"log"`` or ``"linear"``.

        Raises:
            ValueError: if ``method`` is not a known schedule.
        """
        step_ratio = torch.tensor(float(current_iter) / float(self.max_iter))
        scale = 5
        if method == "exp":
            coeff = torch.exp((step_ratio - 1) * scale)
        elif method == "log":
            coeff = 1 - torch.exp(-step_ratio * scale)
        elif method == "linear":
            coeff = step_ratio
        else:
            # Previously any other value left ``coeff`` unbound.
            raise ValueError(f"unknown TSA schedule: {method!r}")
        start = 1.0 / self.num_classes
        end = 1.0
        return coeff * (end - start) + start

    def get_batch(self, batch, device=None, non_blocking=True):
        """Unpack ``((x, y), (x_p_w, x_p_s))`` and move all tensors to ``device``.

        5-D unlabeled views of shape ``(b, ub, c, h, w)`` are flattened to
        ``(b * ub, c, h, w)``.
        """
        (x, y), (x_p_w, x_p_s) = batch
        if x_p_w.dim() == 5:
            b, ub, c, h, w = x_p_w.size()
            x_p_w = x_p_w.view(b * ub, c, h, w)
            x_p_s = x_p_s.view(b * ub, c, h, w)
            assert x_p_w.size() == x_p_s.size()
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_w, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_s, device=device, non_blocking=non_blocking),
        )

    def _tsa_mask(self, logits, threshold):
        """Return 1.0 for samples whose max class probability is <= ``threshold``."""
        # The threshold lives in [1 / num_classes, 1], so it must be compared
        # with probabilities; the raw logits were compared before.
        # NOTE(review): the UDA paper gates on the probability of the true
        # class; the original max-probability gate is kept here.
        probs = F.softmax(logits.detach(), dim=-1)
        max_probs, _ = torch.max(probs, dim=-1)
        return max_probs.le(threshold).float()

    def _unlabeled_loss(self, logits_weak, logits_strong):
        """Masked element-wise KL of ``logits_strong`` against weak-view targets."""
        scaled = (logits_weak.detach() / self.temperature).to(self.device)
        # ``dim`` is explicit because the implicit choice is deprecated.
        targets = F.softmax(scaled, dim=1)
        max_probs, _ = torch.max(targets, dim=-1)
        mask = max_probs.ge(self.threshold).float()[:, None]
        log_preds = F.log_softmax(logits_strong, dim=1)
        return (F.kl_div(log_preds, targets, reduction="none") * mask).mean()

    def __call__(self, engine, batch):
        """Run one optimisation step.

        Returns:
            dict with ``tsa_threshold``, ``y_pred``, ``y``, ``loss`` (supervised
            loss only) and ``loss_u`` (unlabeled loss, or the placeholder
            ``10000`` during warm-up).
        """
        report = {}
        self.classifier.train()
        x, y, x_p_w, x_p_s = self.get_batch(batch, device=self.device)
        y_pred = self.classifier(x)
        loss = self.loss(y_pred, y, reduction="none")
        report["tsa_threshold"] = 1.0
        if self.tsa:
            sup_threshold = self.training_signal_annealing(engine.state.iteration)
            assert 0 <= sup_threshold and sup_threshold <= 1
            report["tsa_threshold"] = sup_threshold.detach().item()
            loss = loss * self._tsa_mask(y_pred, sup_threshold)
        loss = loss.mean()
        report["y_pred"] = y_pred.detach()
        report["y"] = y.detach()
        report["loss"] = loss.detach().item()
        report["loss_u"] = 10000
        if self.warmup_epoch < engine.state.epoch:
            # Forward passes stay in the original order: first view, then second.
            logits_p_w = self.classifier(x_p_w)
            logits_p_s = self.classifier(x_p_s)
            pseudo_loss = self._unlabeled_loss(logits_p_w, logits_p_s)
            report["loss_u"] = pseudo_loss.detach().item()
            loss = loss + self.lambda_p * pseudo_loss
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        if self.ema_model is not None:
            self.ema_model.update(self.classifier)
        return report


class GUDAClassifierUpdater:
    """Training step with a top-K-thresholded KL consistency loss and optional TSA.

    The detached, temperature-scaled softmax of the first unlabeled view
    ``x_p_w`` is the KL target for the second view ``x_p_s``. A sample is kept
    only when its ``K`` largest target probabilities sum to at least
    ``threshold``. With ``tsa`` enabled, labeled samples the classifier is
    already confident about are dropped from the supervised loss.

    Keyword Args:
        classifier: module mapping inputs to class logits; it (or its
            ``.module`` under ``DataParallel``) must expose ``num_classes``.
        optimizer: optimizer that steps the classifier's parameters.
        device: device the batch is moved to.
        ema_model: object exposing ``update(model)``, or ``None``.
        lambda_p: weight of the unlabeled loss.
        temperature: divisor applied to the target logits.
        threshold: minimum top-K probability mass for a sample to be used.
        K: number of top classes whose probabilities are summed.
        tsa: enable training signal annealing (default ``False``).
        max_iter: iteration count at which the TSA threshold reaches 1.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.opt = kwargs.pop("optimizer")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        self.temperature = kwargs.pop("temperature")
        self.threshold = kwargs.pop("threshold")
        self.K = kwargs.pop("K")
        self.tsa = kwargs.pop("tsa", False)
        self.max_iter = kwargs.pop("max_iter")
        self.loss = F.cross_entropy
        if isinstance(self.classifier, torch.nn.DataParallel):
            self.num_classes = self.classifier.module.num_classes
        else:
            self.num_classes = self.classifier.num_classes

    def training_signal_annealing(self, current_iter, method="exp"):
        """Return the TSA threshold as a 0-dim tensor.

        The threshold moves from ``1 / num_classes`` towards ``1`` as
        ``current_iter`` approaches ``max_iter``.

        Args:
            current_iter: current training iteration.
            method: schedule shape, one of ``"exp"``, ``"log"`` or ``"linear"``.

        Raises:
            ValueError: if ``method`` is not a known schedule.
        """
        step_ratio = torch.tensor(float(current_iter) / float(self.max_iter))
        scale = 5
        if method == "exp":
            coeff = torch.exp((step_ratio - 1) * scale)
        elif method == "log":
            coeff = 1 - torch.exp(-step_ratio * scale)
        elif method == "linear":
            coeff = step_ratio
        else:
            # Previously any other value left ``coeff`` unbound.
            raise ValueError(f"unknown TSA schedule: {method!r}")
        start = 1.0 / self.num_classes
        end = 1.0
        return coeff * (end - start) + start

    def get_batch(self, batch, device=None, non_blocking=True):
        """Unpack ``((x, y), (x_p_w, x_p_s))`` and move all tensors to ``device``.

        5-D unlabeled views of shape ``(b, ub, c, h, w)`` are flattened to
        ``(b * ub, c, h, w)``.
        """
        (x, y), (x_p_w, x_p_s) = batch
        if x_p_w.dim() == 5:
            b, ub, c, h, w = x_p_w.size()
            x_p_w = x_p_w.view(b * ub, c, h, w)
            x_p_s = x_p_s.view(b * ub, c, h, w)
            assert x_p_w.size() == x_p_s.size()
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_w, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_s, device=device, non_blocking=non_blocking),
        )

    def _tsa_mask(self, logits, threshold):
        """Return 1.0 for samples whose max class probability is <= ``threshold``."""
        # The threshold lives in [1 / num_classes, 1], so it must be compared
        # with probabilities; the raw logits were compared before.
        # NOTE(review): the UDA paper gates on the probability of the true
        # class; the original max-probability gate is kept here.
        probs = F.softmax(logits.detach(), dim=-1)
        max_probs, _ = torch.max(probs, dim=-1)
        return max_probs.le(threshold).float()

    def _unlabeled_loss(self, logits_weak, logits_strong):
        """Masked element-wise KL of ``logits_strong`` against weak-view targets."""
        scaled = (logits_weak.detach() / self.temperature).to(self.device)
        # ``dim`` is explicit because the implicit choice is deprecated.
        targets = F.softmax(scaled, dim=1)
        topk_probs, _ = torch.topk(targets, k=self.K, dim=-1)
        mask = topk_probs.sum(dim=-1).ge(self.threshold).float()[:, None]
        log_preds = F.log_softmax(logits_strong, dim=1)
        return (F.kl_div(log_preds, targets, reduction="none") * mask).mean()

    def __call__(self, engine, batch):
        """Run one optimisation step.

        Returns:
            dict with ``tsa_threshold``, ``y_pred``, ``y``, ``loss`` (supervised
            loss only) and ``loss_u`` (masked KL loss).
        """
        report = {}
        self.classifier.train()
        x, y, x_p_w, x_p_s = self.get_batch(batch, device=self.device)
        y_pred = self.classifier(x)
        loss = self.loss(y_pred, y, reduction="none")
        report["tsa_threshold"] = 1.0
        if self.tsa:
            sup_threshold = self.training_signal_annealing(engine.state.iteration)
            assert 0 <= sup_threshold and sup_threshold <= 1
            report["tsa_threshold"] = sup_threshold.detach().item()
            loss = loss * self._tsa_mask(y_pred, sup_threshold)
        loss = loss.mean()
        report["y_pred"] = y_pred.detach()
        report["y"] = y.detach()
        report["loss"] = loss.detach().item()
        # Forward passes stay in the original order: first view, then second.
        logits_p_w = self.classifier(x_p_w)
        logits_p_s = self.classifier(x_p_s)
        pseudo_loss = self._unlabeled_loss(logits_p_w, logits_p_s)
        report["loss_u"] = pseudo_loss.detach().item()
        loss = loss + self.lambda_p * pseudo_loss
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        if self.ema_model is not None:
            self.ema_model.update(self.classifier)
        return report


class EUDAClassifierUpdater:
    """Training step with an entropy-thresholded KL consistency loss and optional TSA.

    The detached, temperature-scaled softmax of the first unlabeled view
    ``x_p_w`` is the KL target for the second view ``x_p_s``. A sample is kept
    only when the entropy of its target distribution is at most
    ``log(num_classes) * threshold``. With ``tsa`` enabled, labeled samples the
    classifier is already confident about are dropped from the supervised loss.

    Keyword Args:
        classifier: module mapping inputs to class logits; it (or its
            ``.module`` under ``DataParallel``) must expose ``num_classes``.
        optimizer: optimizer that steps the classifier's parameters.
        device: device the batch is moved to.
        ema_model: object exposing ``update(model)``, or ``None``.
        lambda_p: weight of the unlabeled loss.
        temperature: divisor applied to the target logits.
        threshold: fraction of the maximum entropy used as the cut-off.
        tsa: enable training signal annealing (default ``False``).
        warmup_epoch: the unlabeled loss is used only after this epoch
            (default ``0``).
        max_iter: iteration count at which the TSA threshold reaches 1.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.opt = kwargs.pop("optimizer")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        self.temperature = kwargs.pop("temperature")
        self.threshold = kwargs.pop("threshold")
        self.tsa = kwargs.pop("tsa", False)
        self.warmup_epoch = kwargs.pop("warmup_epoch", 0)
        self.max_iter = kwargs.pop("max_iter")
        self.loss = F.cross_entropy
        if isinstance(self.classifier, torch.nn.DataParallel):
            self.num_classes = self.classifier.module.num_classes
        else:
            self.num_classes = self.classifier.num_classes
        # log(num_classes) is the entropy of the uniform distribution.
        self.ent_threshold = torch.log(torch.tensor(self.num_classes).float()) * self.threshold

    def training_signal_annealing(self, current_iter, method="exp"):
        """Return the TSA threshold as a 0-dim tensor.

        The threshold moves from ``1 / num_classes`` towards ``1`` as
        ``current_iter`` approaches ``max_iter``.

        Args:
            current_iter: current training iteration.
            method: schedule shape, one of ``"exp"``, ``"log"`` or ``"linear"``.

        Raises:
            ValueError: if ``method`` is not a known schedule.
        """
        step_ratio = torch.tensor(float(current_iter) / float(self.max_iter))
        scale = 5
        if method == "exp":
            coeff = torch.exp((step_ratio - 1) * scale)
        elif method == "log":
            coeff = 1 - torch.exp(-step_ratio * scale)
        elif method == "linear":
            coeff = step_ratio
        else:
            # Previously any other value left ``coeff`` unbound.
            raise ValueError(f"unknown TSA schedule: {method!r}")
        start = 1.0 / self.num_classes
        end = 1.0
        return coeff * (end - start) + start

    def get_batch(self, batch, device=None, non_blocking=True):
        """Unpack ``((x, y), (x_p_w, x_p_s))`` and move all tensors to ``device``.

        5-D unlabeled views of shape ``(b, ub, c, h, w)`` are flattened to
        ``(b * ub, c, h, w)``.
        """
        (x, y), (x_p_w, x_p_s) = batch
        if x_p_w.dim() == 5:
            b, ub, c, h, w = x_p_w.size()
            x_p_w = x_p_w.view(b * ub, c, h, w)
            x_p_s = x_p_s.view(b * ub, c, h, w)
            assert x_p_w.size() == x_p_s.size()
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_w, device=device, non_blocking=non_blocking),
            convert_tensor(x_p_s, device=device, non_blocking=non_blocking),
        )

    def _tsa_mask(self, logits, threshold):
        """Return 1.0 for samples whose max class probability is <= ``threshold``."""
        # The threshold lives in [1 / num_classes, 1], so it must be compared
        # with probabilities; the raw logits were compared before.
        # NOTE(review): the UDA paper gates on the probability of the true
        # class; the original max-probability gate is kept here.
        probs = F.softmax(logits.detach(), dim=-1)
        max_probs, _ = torch.max(probs, dim=-1)
        return max_probs.le(threshold).float()

    def _unlabeled_loss(self, logits_weak, logits_strong):
        """Masked element-wise KL of ``logits_strong`` against weak-view targets."""
        scaled = (logits_weak.detach() / self.temperature).to(self.device)
        # ``dim`` is explicit: the implicit choice is deprecated and must agree
        # with the ``sum(dim=1)`` below.
        targets = F.softmax(scaled, dim=1)
        entropy = -(targets * F.log_softmax(scaled, dim=1)).sum(dim=1)
        mask = entropy.le(self.ent_threshold).float()[:, None]
        log_preds = F.log_softmax(logits_strong, dim=1)
        return (F.kl_div(log_preds, targets, reduction="none") * mask).mean()

    def __call__(self, engine, batch):
        """Run one optimisation step.

        Returns:
            dict with ``tsa_threshold``, ``y_pred``, ``y``, ``loss`` (supervised
            loss only) and ``loss_u`` (unlabeled loss, or the placeholder
            ``10000`` during warm-up).
        """
        report = {}
        self.classifier.train()
        x, y, x_p_w, x_p_s = self.get_batch(batch, device=self.device)
        y_pred = self.classifier(x)
        loss = self.loss(y_pred, y, reduction="none")
        report["tsa_threshold"] = 1.0
        if self.tsa:
            sup_threshold = self.training_signal_annealing(engine.state.iteration)
            assert 0 <= sup_threshold and sup_threshold <= 1
            report["tsa_threshold"] = sup_threshold.detach().item()
            loss = loss * self._tsa_mask(y_pred, sup_threshold)
        loss = loss.mean()
        report["y_pred"] = y_pred.detach()
        report["y"] = y.detach()
        report["loss"] = loss.detach().item()
        report["loss_u"] = 10000
        if self.warmup_epoch < engine.state.epoch:
            # Forward passes stay in the original order: first view, then second.
            logits_p_w = self.classifier(x_p_w)
            logits_p_s = self.classifier(x_p_s)
            pseudo_loss = self._unlabeled_loss(logits_p_w, logits_p_s)
            report["loss_u"] = pseudo_loss.detach().item()
            loss = loss + self.lambda_p * pseudo_loss
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        if self.ema_model is not None:
            self.ema_model.update(self.classifier)
        return report
