import pdb
from collections import OrderedDict
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
import util.gan_util as gan_util
from data.transform import DiffRandAug
from ignite.utils import convert_tensor
from kornia.augmentation import Resize
from loss.classification_loss import naive_cross_entropy_loss
from torch.nn import DataParallel
from util.meta_util import bidirectional_gradient_updated_parameters


class MPSClassifierUpdater:
    """Ignite-style update function that alternates finder and classifier steps.

    Each call performs two phases on one training batch:

    1. ``n_meta_train`` meta-training steps of ``finder`` (a module applied to
       latent vectors before they are fed to ``generator``), driven by a
       finite-difference loss computed with two perturbed classifier
       parameter sets.
    2. One classifier step on a supervised cross-entropy loss over the real
       batch plus a feature-consistency loss (MSE between features of the
       un-augmented and augmented generated images), weighted by ``lambda_p``.

    The classifier is expected to return a ``(logits, features)`` pair and to
    accept a ``params=`` keyword, as that is how it is called below.

    All configuration is passed as keyword arguments; positional arguments
    are accepted but ignored, and unknown keywords are left untouched.

    Required keywords: ``classifier``, ``generator``, ``finder``,
    ``optimizer_c``, ``optimizer_f``, ``device``, ``ema_model`` (may be
    ``None``), ``lambda_p``, ``ubatch_ratio``, ``batchsize_p``,
    ``warmup_epoch``, ``resolution``, ``val_loader``.

    Optional keywords (default): ``lambda_latent`` (0.0), ``latent_reg``
    ("norm"), ``r`` (1e-2), ``lambda_inner_lr`` (1.0), ``fixed_z_size``
    (None), ``n_meta_train`` (1).
    """

    def __init__(self, *args, **kwargs):
        """Read the configuration from ``kwargs``.

        Raises:
            KeyError: if a required keyword is missing.
        """
        self.classifier = kwargs.pop("classifier")
        self.generator = kwargs.pop("generator")
        self.finder = kwargs.pop("finder")
        self.optimizer_c = kwargs.pop("optimizer_c")
        self.optimizer_f = kwargs.pop("optimizer_f")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        self.lambda_latent = kwargs.pop("lambda_latent", 0.0)
        self.latent_reg = kwargs.pop("latent_reg", "norm")
        self.r = kwargs.pop("r", 1e-2)
        self.lambda_inner_lr = kwargs.pop("lambda_inner_lr", 1.0)
        self.fixed_z_size = kwargs.pop("fixed_z_size", None)
        # The pools depend on generator/classifier/fixed_z_size set above.
        # z is drawn before y so the RNG consumption order stays stable.
        self.z_pool = self.init_z_pool()
        self.y_pool = self.init_y_pool()
        # NOTE(review): u_accum_count is stored but not read anywhere in this
        # class; kept because external code may rely on the attribute.
        self.u_accum_count = kwargs.pop("ubatch_ratio")
        self.batchsize_p = kwargs.pop("batchsize_p")
        self.warmup_epoch = kwargs.pop("warmup_epoch")
        self.n_meta_train = kwargs.pop("n_meta_train", 1)
        self.resolution = kwargs.pop("resolution")
        self.resizer = Resize(size=self.resolution)
        self.augment = DiffRandAug(num_ops=2, normalized=True)
        self.val_loader = kwargs.pop("val_loader")
        self.val_loader_iter = iter(self.val_loader)
        self.loss = F.cross_entropy
        self.last_loss_mps = 0

    @staticmethod
    def _unwrap(model):
        """Return the underlying module of a (Distributed)DataParallel wrapper.

        Detection is based on the wrapper type, not on the number of visible
        GPUs: a model may be unwrapped on a multi-GPU host or wrapped on a
        single-device host.
        """
        wrappers = (DataParallel, torch.nn.parallel.DistributedDataParallel)
        return model.module if isinstance(model, wrappers) else model

    def init_y_pool(self):
        """Create the fixed label pool, or ``None`` when no pool is requested.

        Returns:
            ``None`` if ``fixed_z_size`` is ``None``; otherwise an integer
            tensor of shape ``(fixed_z_size,)`` with values drawn uniformly
            from ``[0, classifier.num_classes)``.
        """
        if self.fixed_z_size is None:
            return None
        num_classes = self._unwrap(self.classifier).num_classes
        return torch.randint(0, num_classes, (self.fixed_z_size,))

    def init_z_pool(self):
        """Create the fixed latent pool, or ``None`` when no pool is requested.

        Returns:
            ``None`` if ``fixed_z_size`` is ``None``; otherwise a float32
            tensor of shape ``(fixed_z_size, generator.dim_z)`` filled with
            standard-normal samples.
        """
        if self.fixed_z_size is None:
            return None
        dim_z = self._unwrap(self.generator).dim_z
        return torch.empty(self.fixed_z_size, dim_z, dtype=torch.float32).normal_()

    def get_fixed_z(self, n_gen_samples):
        """Draw ``n_gen_samples`` (z, y) pairs from the fixed pools.

        Indices are sampled uniformly with replacement; the same index is
        used for ``z_pool`` and ``y_pool`` so each latent keeps its label.
        Requires the pools to have been initialised (``fixed_z_size`` set).

        Returns:
            Tuple ``(zs, ys)`` moved to ``self.device``.
        """
        pool_size = self.z_pool.shape[0]
        index = torch.ones(pool_size).multinomial(n_gen_samples, replacement=True)
        return self.z_pool[index].to(self.device), self.y_pool[index].to(self.device)

    def latent_regularization(self, fz, eps=1e-7):
        """Compute the regularisation term on the finder output ``fz``.

        Args:
            fz: tensor reduced along ``dim=1`` below, so it needs at least
                two dimensions.
            eps: added inside the logarithm of the "kl" branch for stability.

        Returns:
            A scalar tensor. For ``latent_reg == "norm"`` the mean of the
            norms taken along dim 1; for ``"kl"`` the expression
            ``-0.5 * (1 + log(var + eps) - mean**2 - var)`` where ``mean`` and
            ``var`` are the dim-1 mean / variance averaged over the remaining
            entries.

        Raises:
            NotImplementedError: for any other ``latent_reg`` value.
        """
        if self.latent_reg == "norm":
            return torch.norm(fz, dim=1).mean()
        if self.latent_reg == "kl":
            approx_mean = torch.mean(fz, dim=1).mean()
            approx_var = torch.var(fz, dim=1).mean()
            return -0.5 * (1 + torch.log(approx_var + eps) - approx_mean.pow(2) - approx_var)
        raise NotImplementedError

    def sample_val_batch(self):
        """Return the next validation batch on ``self.device``.

        The validation iterator is restarted transparently when exhausted.
        """
        try:
            batch = next(self.val_loader_iter)
        except StopIteration:
            # Loader exhausted: start a new pass over the validation data.
            self.val_loader_iter = iter(self.val_loader)
            batch = next(self.val_loader_iter)
        return self.get_batch(batch)

    def get_batch(self, batch):
        """Unpack an ``(x, y)`` batch and move both parts to ``self.device``."""
        x, y = batch
        return (
            convert_tensor(x, device=self.device, non_blocking=True),
            convert_tensor(y, device=self.device, non_blocking=True),
        )

    def _sample_noize_and_label(self, n_gen_samples):
        """Sample ``n_gen_samples`` latent vectors and labels.

        Uses the fixed pools when they exist; otherwise samples fresh values
        through ``gan_util`` using the (unwrapped) generator.
        """
        if self.z_pool is not None:
            return self.get_fixed_z(n_gen_samples)
        gen = self._unwrap(self.generator)
        z = gan_util.sample_z(gen, n_gen_samples, self.device)
        y = gan_util.sample_categorical_labels(gen.num_classes, n_gen_samples, self.device)
        return z, y

    def _meta_train_finder(self, z_p, y_p, update_finder):
        """Run one meta-training step of the finder.

        Args:
            z_p: latent vectors fed to the finder.
            y_p: labels passed to the generator alongside the finder output.
            update_finder: when false the losses are still computed and
                reported, but no backward pass / optimizer step is taken.

        Returns:
            Dict with ``"loss_mps"`` and ``"latent_norm"`` as Python floats.
        """
        self.classifier.eval()
        x_v, y_v = self.sample_val_batch()
        logit_v, _ = self.classifier(x_v)
        fz_p = self.finder(z_p)
        x_p = self.resizer(self.generator(fz_p, y_p))
        # Only the augmented view keeps the graph back to the finder.
        x_p_w, x_p_s = x_p.detach(), self.augment(x_p)

        # Approximated MPS loss.
        # NOTE(review): theta_plus / theta_minus are presumably classifier
        # parameters perturbed by +/- epsilon along the validation-loss
        # gradient - confirm in util.meta_util.
        loss_val = F.cross_entropy(logit_v, y_v)
        theta_plus, theta_minus, epsilon = bidirectional_gradient_updated_parameters(
            self.classifier, loss_val, first_order=True, r=self.r
        )
        images = torch.cat([x_p_w, x_p_s], dim=0)
        _, feat_plus = self.classifier(images, params=theta_plus)
        _, feat_minus = self.classifier(images, params=theta_minus)
        halves = [self.batchsize_p, self.batchsize_p]
        feat_p_w_plus, feat_p_s_plus = torch.split(feat_plus, halves, dim=0)
        feat_p_w_minus, feat_p_s_minus = torch.split(feat_minus, halves, dim=0)
        loss_plus = F.mse_loss(feat_p_w_plus, feat_p_s_plus)
        loss_minus = F.mse_loss(feat_p_w_minus, feat_p_s_minus)
        # Central difference of the consistency loss, scaled by the
        # classifier learning rate (first param group) times lambda_inner_lr.
        inner_lr = self.lambda_inner_lr * self.optimizer_c.param_groups[0]['lr']
        loss_mps = (loss_minus - loss_plus).div(2 * epsilon).mul(inner_lr)

        # Latent regularization loss.
        loss_latent_reg = self.latent_regularization(fz_p)
        loss_all = loss_mps + self.lambda_latent * loss_latent_reg

        self.optimizer_f.zero_grad()
        if update_finder:
            loss_all.backward()
            self.optimizer_f.step()

        self.last_loss_mps = loss_mps.detach().item()
        step_report = {
            "loss_mps": self.last_loss_mps,
            "latent_norm": loss_latent_reg.detach().item(),
        }
        self.classifier.train()
        return step_report

    def _train_classifier(self, x, y, z_p, y_p, lambda_p):
        """Run one classifier update on real data plus generated samples.

        Args:
            x, y: real inputs and labels, already on ``self.device``.
            z_p, y_p: latent vectors / labels used to generate pseudo samples.
            lambda_p: weight of the feature-consistency loss.

        Returns:
            Dict with ``"y_pred"``, ``"y"``, ``"loss"`` (detached tensors),
            ``"entropy"`` and ``"loss_pseudo"`` (Python floats).
        """
        batchsize = x.shape[0]
        # The generated images are treated as constants in this phase.
        with torch.no_grad():
            x_p = self.resizer(self.generator(self.finder(z_p), y_p))
        x_p_w, x_p_s = x_p.detach(), self.augment(x_p).detach()

        # One forward pass over real + both generated views.
        images = torch.cat([x, x_p_w, x_p_s], dim=0)
        logit_all, feat_all = self.classifier(images)
        sections = [batchsize, self.batchsize_p, self.batchsize_p]
        logit_real, _, _ = torch.split(logit_all, sections, dim=0)
        _, feat_p_w, feat_p_s = torch.split(feat_all, sections, dim=0)

        # Supervised loss and prediction entropy (logging only) on real data.
        loss_supervised = self.loss(logit_real, y)
        logit_detached = logit_real.detach()
        entropy = (
            -F.softmax(logit_detached, dim=1) * F.log_softmax(logit_detached, dim=1)
        ).sum(dim=1).mean()

        # Unsupervised consistency loss between the two generated views.
        loss_pseudo = F.mse_loss(feat_p_w, feat_p_s)

        step_report = {
            "y_pred": logit_real.detach(),
            "y": y.detach(),
            "loss": loss_supervised.detach(),
            "entropy": entropy.detach().item(),
            "loss_pseudo": loss_pseudo.detach().item(),
        }

        loss_target = loss_supervised + lambda_p * loss_pseudo
        self.optimizer_c.zero_grad()
        loss_target.backward()
        self.optimizer_c.step()
        return step_report

    def __call__(self, engine, batch):
        """Process one training batch and return the metrics dict.

        Args:
            engine: object exposing ``engine.state.epoch`` (an Ignite engine).
            batch: ``(x, y)`` pair of real training data.

        Returns:
            Dict with keys ``loss_mps`` and ``latent_norm`` (present only
            when ``n_meta_train > 0``), plus ``y_pred``, ``y``, ``loss``,
            ``entropy`` and ``loss_pseudo``.
        """
        report = {}

        # Until the epoch exceeds warmup_epoch the finder is not updated and
        # the pseudo-sample loss has zero weight.
        past_warmup = self.warmup_epoch < engine.state.epoch
        lambda_p = self.lambda_p if past_warmup else 0.0

        self.classifier.train()
        self.finder.train()

        # Get train samples, sample noises and labels.
        x, y = self.get_batch(batch)
        z_p, y_p = self._sample_noize_and_label(n_gen_samples=self.batchsize_p)

        # 1. Meta-train finder to generate useful samples for classifier.
        for _ in range(self.n_meta_train):
            report.update(self._meta_train_finder(z_p, y_p, update_finder=past_warmup))

        # 2. Train classifier with pseudo semi-supervised learning.
        report.update(self._train_classifier(x, y, z_p, y_p, lambda_p))

        if self.ema_model is not None:
            self.ema_model.update(self.classifier)

        return report
