import pdb

import torch
import torch.nn.functional as F
import util.gan_util as gan_util
from data.transform import DiffRandAug
from ignite.utils import convert_tensor
from kornia.augmentation import Resize


class ClassifierUpdater:
    """Per-iteration update step for a classifier trained on real and generated images.

    Each call combines a supervised cross-entropy loss on the real batch with
    a consistency loss on generator samples: the classifier's
    temperature-scaled predictions on the un-augmented generated images are
    used as soft targets for its predictions on augmented copies of the same
    images, restricted to samples whose top probability reaches ``threshold``.

    The ``(engine, batch)`` call signature and the use of
    ``engine.state.epoch`` suggest this is meant to be registered as an ignite
    ``Engine`` process function -- confirm against the caller.

    All configuration is passed as keyword arguments; positional arguments are
    accepted but ignored.

    Required keyword arguments:
        classifier: model being trained.
        generator: callable ``generator(z, y)`` producing images.
        finder, optimizer_f: stored on the instance but not used by this class.
        optimizer_c: optimizer stepping the classifier's parameters.
        device: device that batches and sampled noise/labels are placed on.
        ema_model: object exposing ``update(model)``, or ``None`` to disable.
        lambda_p: weight of the consistency loss once warm-up is over.
        batchsize_p: number of generated samples per iteration.
        temperature: divisor applied to logits before the soft-target softmax.
        threshold: minimum top-class probability for a sample to contribute.
        warmup_epoch: the consistency loss is weighted 0 until
            ``engine.state.epoch`` exceeds this value.
        resolution: size passed to ``Resize`` for generated images.

    Optional keyword arguments:
        n_gen_samples: default sample count for ``_sample_noize_and_label``;
            defaults to ``batchsize_p``.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.generator = kwargs.pop("generator")
        self.finder = kwargs.pop("finder")
        self.optimizer_c = kwargs.pop("optimizer_c")
        self.optimizer_f = kwargs.pop("optimizer_f")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        self.batchsize_p = kwargs.pop("batchsize_p")
        self.temperature = kwargs.pop("temperature")
        self.threshold = kwargs.pop("threshold")
        self.warmup_epoch = kwargs.pop("warmup_epoch")
        self.resolution = kwargs.pop("resolution")
        # Fallback sample count for _sample_noize_and_label(); previously this
        # attribute was read there but never assigned anywhere.
        self.n_gen_samples = kwargs.pop("n_gen_samples", self.batchsize_p)
        self.resizer = Resize(size=self.resolution)
        self.loss = F.cross_entropy
        self.augment = DiffRandAug(num_ops=2, normalized=True)

    def get_batch(self, batch):
        """Unpack ``batch`` into ``(x, y)`` and move both to ``self.device``."""
        x, y = batch
        return (
            convert_tensor(x, device=self.device, non_blocking=True),
            convert_tensor(y, device=self.device, non_blocking=True),
        )

    def _unwrapped_generator(self):
        """Return the generator without its data-parallel wrapper, if it has one."""
        gen = self.generator
        parallel_wrappers = (
            torch.nn.DataParallel,
            torch.nn.parallel.DistributedDataParallel,
        )
        if isinstance(gen, parallel_wrappers):
            return gen.module
        # Keep the historical behaviour on multi-GPU hosts (use ``.module``),
        # but only when that attribute actually exists.
        if torch.cuda.device_count() >= 2 and hasattr(gen, "module"):
            return gen.module
        return gen

    def _sample_noize_and_label(self, n_gen_samples=None):
        """Sample generator inputs.

        Args:
            n_gen_samples: number of samples to draw; ``None`` uses
                ``self.n_gen_samples``.

        Returns:
            Tuple ``(z, y)`` as produced by ``gan_util.sample_z`` and
            ``gan_util.sample_categorical_labels``.
        """
        if n_gen_samples is None:
            n_gen_samples = self.n_gen_samples
        gen = self._unwrapped_generator()
        z = gan_util.sample_z(gen, n_gen_samples, self.device)
        y = gan_util.sample_categorical_labels(gen.num_classes, n_gen_samples, self.device)
        return z, y

    def __call__(self, engine, batch):
        """Run one optimisation step of the classifier.

        Args:
            engine: object exposing ``state.epoch``, used for the warm-up check.
            batch: ``(x, y)`` pair of real inputs and their labels.

        Returns:
            dict with keys ``"y_pred"`` (detached logits for the real batch),
            ``"y"`` (labels), ``"loss"`` (supervised loss as a float),
            ``"u_acceptance_rate"`` (fraction of generated samples passing the
            confidence threshold) and ``"loss_pseudo"`` (consistency loss as a
            float, reported before ``lambda_p`` weighting).
        """
        report = {}
        self.classifier.train()

        # Generate pseudo samples and all logits
        x, y = self.get_batch(batch)
        batchsize = x.shape[0]
        z_p, y_p = self._sample_noize_and_label(n_gen_samples=self.batchsize_p)
        # The generated images are only ever used as constant inputs to the
        # classifier, so skip building an autograd graph through the
        # generator, resizer and augmentation.
        with torch.no_grad():
            x_p_w = self.resizer(self.generator(z_p, y_p))
            x_p_s = self.augment(x_p_w)
        # A single forward pass over real, plain-generated and
        # augmented-generated images; the logits are split back afterwards.
        images = torch.cat([x, x_p_w, x_p_s], dim=0)
        logit_all = self.classifier(images)
        logit_real, logit_p_w, logit_p_s = torch.split(
            logit_all, [batchsize, self.batchsize_p, self.batchsize_p], dim=0
        )

        # Calculate supervised loss
        loss_supervised = self.loss(logit_real, y)
        report.update({"y_pred": logit_real.detach()})
        report.update({"y": y.detach()})
        report.update({"loss": loss_supervised.detach().item()})

        # Calculate unsupervised loss
        # Soft targets come from the un-augmented branch and are detached so
        # gradients only flow through the augmented branch.
        pseudo_labels = F.softmax(logit_p_w.detach() / self.temperature, dim=-1)
        max_probs, _ = torch.max(pseudo_labels, dim=-1)
        # Column mask of shape (batchsize_p, 1); broadcasts over the class axis.
        mask = max_probs.ge(self.threshold).float()[:, None]
        report.update({"u_acceptance_rate": mask.count_nonzero().item() / len(mask)})
        # Element-wise KL terms are masked, then averaged over every element
        # (samples x classes), including the masked-out ones.
        kl_terms = F.kl_div(F.log_softmax(logit_p_s, dim=-1), pseudo_labels, reduction="none")
        loss_u = (kl_terms * mask).mean()
        # The consistency term is switched off until the warm-up epochs have passed.
        lambda_p = self.lambda_p if self.warmup_epoch < engine.state.epoch else 0.0
        report.update({"loss_pseudo": loss_u.detach().item()})

        loss_target = loss_supervised + lambda_p * loss_u
        self.optimizer_c.zero_grad()
        loss_target.backward()
        self.optimizer_c.step()

        # Drop references to the large intermediates before the EMA update.
        del x_p_w, x_p_s, logit_p_w, logit_p_s, logit_all, pseudo_labels, kl_terms

        if self.ema_model is not None:
            self.ema_model.update(self.classifier)
        return report
