import pdb
from collections import OrderedDict
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
import util.gan_util as gan_util
from ignite.utils import convert_tensor
from kornia.augmentation import Resize
from loss.classification_loss import naive_cross_entropy_loss
from torch.nn import DataParallel


class ClassifierUpdater:
    """Update step that trains a classifier on real data plus generated data.

    Each call takes a real labelled batch and a batch of class-conditional
    samples drawn from ``generator``, feeds both through ``classifier`` in a
    single forward pass, and minimizes::

        loss_supervised + lambda_p * loss_pseudo

    ``loss_supervised`` is the cross-entropy of the classifier logits on the
    real batch. ``loss_pseudo`` is the cross-entropy of the auxiliary head
    ``classifier_p`` (applied to the classifier features of the generated
    samples) against the labels used to condition the generator.
    ``lambda_p`` is forced to 0 until ``engine.state.epoch`` exceeds
    ``warmup_epoch``.
    """

    def __init__(self, *args, **kwargs):
        """Pop the training components from ``kwargs``.

        Required keys: ``classifier``, ``classifier_p``, ``generator``,
        ``optimizer_c``, ``optimizer_cp``, ``device``, ``ema_model`` (may be
        ``None``), ``lambda_p``, ``ubatch_ratio``, ``batchsize_p``,
        ``warmup_epoch``, ``resolution`` and ``val_loader``.

        Optional key: ``alpha_mixup``, the Beta-distribution parameter used
        by the mixup helpers. If it is absent, any pre-existing attribute
        value (e.g. a subclass class attribute) is kept. Otherwise the mixup
        helpers raise ``ValueError``.

        Positional ``args`` and unrecognised keyword arguments are ignored.
        """
        self.classifier = kwargs.pop("classifier")
        self.classifier_p = kwargs.pop("classifier_p")
        self.generator = kwargs.pop("generator")
        self.optimizer_c = kwargs.pop("optimizer_c")
        self.optimizer_cp = kwargs.pop("optimizer_cp")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        # Not read inside this class; kept for callers/subclasses.
        self.u_accum_count = kwargs.pop("ubatch_ratio")
        self.batchsize_p = kwargs.pop("batchsize_p")
        self.warmup_epoch = kwargs.pop("warmup_epoch")
        self.resolution = kwargs.pop("resolution")
        # Generated images are resized to `resolution` before classification.
        self.resizer = Resize(size=self.resolution)
        self.val_loader = kwargs.pop("val_loader")
        self.val_loader_iter = iter(self.val_loader)
        self.alpha_mixup = kwargs.pop("alpha_mixup", getattr(self, "alpha_mixup", None))
        # Look through a DataParallel wrapper: it does not forward custom attributes.
        self.num_classes = self._unwrap(self.generator).num_classes
        self.loss = F.cross_entropy
        # `eps` and `last_loss_mps` are not read in this class; kept for callers/subclasses.
        self.eps = torch.tensor(1e-6)
        self.last_loss_mps = 0

    @staticmethod
    def _unwrap(module):
        """Return ``module.module`` for a ``DataParallel`` wrapper, else ``module`` itself."""
        return module.module if isinstance(module, DataParallel) else module

    def _mixup_gamma(self) -> float:
        """Draw a mixing coefficient from Beta(alpha_mixup, alpha_mixup)."""
        alpha = getattr(self, "alpha_mixup", None)
        if alpha is None:
            raise ValueError("alpha_mixup must be set to use mixup")
        return np.random.beta(alpha, alpha)

    def mixup_label(self, label: torch.Tensor) -> torch.Tensor:
        """Mix ``label`` with a random permutation of itself along dim 0.

        Returns ``gamma * label + (1 - gamma) * label[perm]``, where
        ``gamma ~ Beta(alpha_mixup, alpha_mixup)``.
        """
        gamma = self._mixup_gamma()
        indices = torch.randperm(label.size(0), device=label.device, dtype=torch.long)
        return self.partial_mixup(label, gamma, indices)

    def partial_mixup(self, input: torch.Tensor, gamma: float, indices: torch.Tensor) -> torch.Tensor:
        """Return ``gamma * input + (1 - gamma) * input[indices]``.

        Raises:
            RuntimeError: if ``indices`` and ``input`` differ in size along dim 0.
        """
        if input.size(0) != indices.size(0):
            raise RuntimeError("Size mismatch!")
        perm_input = input[indices]
        return input.mul(gamma).add(perm_input, alpha=1 - gamma)

    def mixup_trans(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mix ``input`` and ``target`` with the same coefficient and permutation."""
        gamma = self._mixup_gamma()
        indices = torch.randperm(input.size(0), device=input.device, dtype=torch.long)
        return self.partial_mixup(input, gamma, indices), self.partial_mixup(target, gamma, indices)

    def sample_val_batch(self):
        """Return the next ``(x, y)`` validation batch on ``self.device``.

        The validation iterator is restarted transparently when exhausted.
        """
        try:
            batch = next(self.val_loader_iter)
        except StopIteration:
            self.val_loader_iter = iter(self.val_loader)
            batch = next(self.val_loader_iter)
        return self.get_batch(batch)

    def get_batch(self, batch):
        """Move an ``(x, y)`` batch to ``self.device`` (non-blocking)."""
        x, y = batch
        return (
            convert_tensor(x, device=self.device, non_blocking=True),
            convert_tensor(y, device=self.device, non_blocking=True),
        )

    def _sample_noize_and_label(self, n_gen_samples=None):
        """Sample latent codes and class labels for conditioning the generator.

        Args:
            n_gen_samples: number of samples. Defaults to ``self.n_gen_samples``.
                NOTE(review): that attribute is not set in ``__init__``, so
                callers must pass the count or set it themselves.

        Returns:
            ``(z, y)`` from ``gan_util.sample_z`` and
            ``gan_util.sample_categorical_labels``, created on ``self.device``.
        """
        if n_gen_samples is None:
            n_gen_samples = self.n_gen_samples
        # Decide on the wrapper type rather than the GPU count, consistent with __init__.
        gen = self._unwrap(self.generator)
        z = gan_util.sample_z(gen, n_gen_samples, self.device)
        y = gan_util.sample_categorical_labels(gen.num_classes, n_gen_samples, self.device)
        return z, y

    def __call__(self, engine, batch):
        """Run one optimization step of the classifier and its pseudo head.

        Args:
            engine: ignite engine. Only ``engine.state.epoch`` is read, to gate
                ``lambda_p`` behind the warm-up.
            batch: ``(x, y)`` pair of real inputs and labels.

        Returns:
            dict containing:
            - ``"y_pred"``: detached real-batch logits
            - ``"y"``: detached labels
            - ``"loss"``: detached supervised-loss tensor
            - ``"loss_pseudo"``: pseudo loss as a Python float
        """
        self.classifier.train()
        # classifier_p is optimized below as well, so keep it in training mode
        # even if an evaluation pass switched it to eval().
        self.classifier_p.train()

        # Real batch plus generated batch. The generator is not trained here,
        # so don't build an autograd graph through it.
        x, y = self.get_batch(batch)
        batchsize = x.shape[0]
        z_p, y_p = self._sample_noize_and_label(n_gen_samples=self.batchsize_p)
        with torch.no_grad():
            x_p = self.resizer(self.generator(z_p, y_p))

        # A single classifier pass over real + generated images; split the
        # outputs back into their real / generated parts.
        images = torch.cat([x, x_p], dim=0)
        logit_all, feat_all = self.classifier(images)
        logit_real, _ = torch.split(logit_all, [batchsize, self.batchsize_p], dim=0)
        _, feat_p = torch.split(feat_all, [batchsize, self.batchsize_p], dim=0)
        logit_p = self.classifier_p(feat_p)

        loss_supervised = self.loss(logit_real, y)
        loss_pseudo = self.loss(logit_p, y_p)
        report = {
            "y_pred": logit_real.detach(),
            "y": y.detach(),
            "loss": loss_supervised.detach(),
            "loss_pseudo": loss_pseudo.detach().item(),
        }

        # During warm-up the pseudo term is weighted by 0. optimizer_cp still
        # steps, so momentum / weight decay (if configured) may still move
        # classifier_p.
        lambda_p = self.lambda_p if engine.state.epoch > self.warmup_epoch else 0.0
        loss_target = loss_supervised + lambda_p * loss_pseudo

        self.optimizer_c.zero_grad()
        self.optimizer_cp.zero_grad()
        loss_target.backward()
        self.optimizer_c.step()
        self.optimizer_cp.step()

        if self.ema_model is not None:
            self.ema_model.update(self.classifier)

        return report
