import logging
import math
import os
import pdb
import sys

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import umap
from ignite.utils import convert_tensor
from kornia.augmentation import Resize
from tqdm import tqdm

import metric.features as features
import metric.fid as fid
import util.gan_util as gan_util
from metric.preparation import LoadEvalModel
from model.ema import ModelEMA
from util.train_util import accuracy


class ClassifierEvaluator:
    """Evaluation step, callable as ``(engine, batch)``, that runs a classifier.

    Keyword Args:
        classifier: Model invoked as ``classifier(x)``.
        device: Device the batch tensors are moved to before the forward pass.
        ema_model: Optional object exposing an averaged model as ``.ema``.
            When it is not ``None`` the ``.ema`` model is evaluated instead of
            ``classifier``. Defaults to ``None`` when omitted.

    Positional arguments and any other keyword arguments are accepted and
    ignored.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.device = kwargs.pop("device")
        # ``__call__`` already handles a missing EMA model, so the key is
        # optional here as well instead of raising ``KeyError``.
        self.ema_model = kwargs.pop("ema_model", None)

    def get_batch(self, batch, device=None):
        """Split ``batch`` into ``(x, y)`` and move both to ``device``."""
        x, y = batch
        return (
            convert_tensor(x, device=device, non_blocking=True),
            convert_tensor(y, device=device, non_blocking=True),
        )

    def __call__(self, engine, batch):
        """Run the (EMA) classifier on ``batch`` without tracking gradients.

        Args:
            engine: Unused; present to match the ``(engine, batch)`` signature.
            batch: Pair ``(x, y)`` of inputs and targets.

        Returns:
            Tuple ``(y_pred, y)`` of detached tensors. If the model returns a
            tuple, only its first element is used as ``y_pred``.
        """
        classifier = self.classifier if self.ema_model is None else self.ema_model.ema
        classifier.eval()
        x, y = self.get_batch(batch, device=self.device)
        with torch.no_grad():
            y_pred = classifier(x)
        if isinstance(y_pred, tuple):
            y_pred = y_pred[0]

        return (y_pred.detach(), y.detach())


class MPSEvaluator:
    """Evaluate generated "meta pseudo samples" against a batch of real images.

    Every call generates ``eval_size // batchsize`` batches of fake images,
    writes a UMAP scatter plot of classifier features (real vs. fake) under
    ``<dist>/umap`` and computes an FID score from ``eval_model`` features.
    Every ``print_epoch``-th call also saves a grid of generated images under
    ``<dist>/generated``.

    Args:
        print_epoch: Save an image grid whenever the call counter is a
            multiple of this value.
        num_print_image: Maximum number of images in the saved grid.
        resolution: Size passed to the kornia ``Resize`` used as ``resizer``.

    Keyword Args (all required):
        classifier: Model returning a pair whose second element is used as the
            feature tensor.
        generator, finder: Modules handed to ``gan_util.generate_mps``.
        resize_fn, eval_backbone, DDP: Forwarded to ``LoadEvalModel``.
        device: Device used for generation and feature extraction.
        batchsize: Batch size for generation and feature extraction.
        eval_size: Number of fake samples requested per evaluation.
        dataset_name: Used in the name of the cached real-statistics file.
        dist: Output directory for plots and image grids.

    Extra positional arguments are accepted and ignored.
    """

    def __init__(self, print_epoch=10, num_print_image=64, *args, resolution=224, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.generator = kwargs.pop("generator")
        self.finder = kwargs.pop("finder")
        self.resize_fn = kwargs.pop("resize_fn")
        self.device = kwargs.pop("device")
        self.batchsize = kwargs.pop("batchsize")
        self.eval_backbone = kwargs.pop("eval_backbone")
        self.eval_size = kwargs.pop("eval_size")
        self.dataset_name = kwargs.pop("dataset_name")
        self.world_size = torch.cuda.device_count()
        self.DDP = kwargs.pop("DDP")
        self.dist = kwargs.pop("dist")
        # 1-based counter of evaluations; used in output file names.
        self.eval_count = 1
        self.print_epoch = print_epoch
        self.num_print_image = num_print_image
        self.eval_model = self._setup_eval_model()
        self.resizer = Resize(size=resolution)
        # Path of the cached real-image statistics; resolved on first call.
        self.stats_file = None

    def extract_feature(self, images):
        """Return classifier features for ``images``, computed in mini-batches.

        Args:
            images: Tensor of images; sliced along dim 0 in chunks of
                ``batchsize`` and moved to ``device`` chunk by chunk.

        Returns:
            Tensor of the per-chunk features concatenated along dim 0.
        """
        feature_holder = []
        with torch.no_grad():
            for start in tqdm(range(0, len(images), self.batchsize)):
                chunk = images[start: start + self.batchsize].to(self.device)
                # Named ``chunk_feats`` so the module-level ``features``
                # import is not shadowed.
                _, chunk_feats = self.classifier(chunk)
                feature_holder.append(chunk_feats)
        return torch.cat(feature_holder, 0)

    def _setup_eval_model(self):
        """Build the evaluation backbone wrapper used for FID features."""
        return LoadEvalModel(
            eval_backbone=self.eval_backbone,
            resize_fn=self.resize_fn,
            world_size=self.world_size,
            distributed_data_parallel=self.DDP,
            device=self.device,
        )

    def get_batch(self, batch, device=None):
        """Split ``batch`` into ``(x, y)`` and move both to ``device``."""
        x, y = batch
        return (
            convert_tensor(x, device=device, non_blocking=True),
            convert_tensor(y, device=device, non_blocking=True),
        )

    @staticmethod
    def _feature_statistics(feats, limit):
        """Return ``(mean, covariance)`` of the first ``limit`` rows of ``feats``.

        ``feats`` is a tensor; statistics are computed in float64 with rows as
        observations (``rowvar=False``).
        """
        rows = feats.detach().cpu().numpy().astype(np.float64)[:limit]
        return np.mean(rows, axis=0), np.cov(rows, rowvar=False)

    def _save_generated_images(self, x_fake):
        """Write a grid of up to ``num_print_image`` images from ``x_fake``."""
        img_path = os.path.join(self.dist, "generated")
        os.makedirs(img_path, exist_ok=True)
        file_path = os.path.join(img_path, f"image_iter_{self.eval_count}.png")
        row_num = math.ceil(pow(self.num_print_image, 0.5))
        torchvision.utils.save_image(
            x_fake[: self.num_print_image], file_path, nrow=row_num, normalize=True, scale_each=True
        )

    def _save_umap_plot(self, feats, labels):
        """Fit a 2-D UMAP embedding of ``feats`` and save it colored by ``labels``."""
        umap_emb = umap.UMAP().fit_transform(feats)
        img_path = os.path.join(self.dist, "umap")
        os.makedirs(img_path, exist_ok=True)
        file_path = os.path.join(img_path, f"umap_iter_{self.eval_count}.pdf")
        fig = plt.figure(figsize=(6, 3), dpi=300)
        try:
            sc = plt.scatter(umap_emb[:, 0], umap_emb[:, 1], c=labels, s=10)
            sc.set_rasterized(True)
            plt.colorbar()
            plt.savefig(file_path, format='pdf', dpi=300, bbox_inches='tight', pad_inches=0)
        finally:
            # Without this, one figure per evaluation stays alive in pyplot.
            plt.close(fig)

    def _load_or_compute_real_features(self, real_images):
        """Return ``(feats, probs)`` for the real images, using the on-disk cache.

        NOTE(review): the cache file name depends only on ``dataset_name`` and
        ``len(real_images) // 1000``; an existing file is reused as-is.
        """
        if self.stats_file is None:
            stats_dir = os.path.join("metric", "pre-computed")
            os.makedirs(stats_dir, exist_ok=True)
            self.stats_file = os.path.join(
                stats_dir, f"stats_{self.dataset_name}_{len(real_images) // 1000}k.npz"
            )
        if self.stats_file and os.path.exists(self.stats_file):
            logging.info("Loading existing statistics for real images...")
            # Context manager closes the archive even if a key is missing.
            with np.load(self.stats_file) as f:
                return torch.from_numpy(f["feats"][:]), torch.from_numpy(f["probs"][:])

        logging.info("Calculating statistics for real images...")
        real_feats, real_probs = features.stack_features(
            real_images,
            self.eval_model,
            self.batchsize,
            self.world_size,
            self.DDP,
            self.device,
        )
        np.savez(self.stats_file, feats=real_feats.detach().cpu().numpy(), probs=real_probs.detach().cpu().numpy())
        return real_feats, real_probs

    def __call__(self, engine, batch):
        """Generate fake samples, plot a UMAP of features and compute FID.

        Args:
            engine: Unused; present to match the ``(engine, batch)`` signature.
            batch: Sequence whose first element is the tensor of real images.

        Returns:
            ``{"FID": fid_score}``.
        """
        # Evaluation of Meta Pseudo Samples
        self.generator.eval()
        self.classifier.eval()
        self.finder.eval()
        fake_list = []
        main_feat_fakes = []

        for i in range(self.eval_size // self.batchsize):
            with torch.no_grad():
                x_fake, _ = gan_util.generate_mps(
                    self.generator, self.finder, self.resizer, self.device, batchsize=self.batchsize
                )
                if isinstance(x_fake, tuple):
                    x_fake = x_fake[0]
                _, main_feat_fake = self.classifier(x_fake)
            fake_list.append(x_fake)
            main_feat_fakes.append(main_feat_fake)
            # Print generated images (first batch only, every print_epoch calls)
            if i == 0 and (self.eval_count % self.print_epoch == 0):
                self._save_generated_images(x_fake)

        real_images = batch[0]
        fake_images = torch.cat(fake_list, dim=0).detach()
        # At most as many fake features as real images enter the UMAP plot.
        main_feat_fakes = torch.cat(main_feat_fakes, dim=0).detach().cpu().numpy()[: len(real_images)]
        main_feat_reals = self.extract_feature(real_images).cpu().numpy()
        real_labels = np.zeros((len(real_images)))
        fake_labels = np.ones((len(main_feat_fakes)))

        # Create UMAP (label 0 = real, 1 = fake)
        feats = np.concatenate((main_feat_reals, main_feat_fakes), axis=0)
        feats = feats.reshape(feats.shape[0], feats.shape[1])
        labels = np.concatenate((real_labels, fake_labels), axis=0)
        self._save_umap_plot(feats, labels)

        # Calc FID
        real_feats, _ = self._load_or_compute_real_features(real_images)
        fake_feats, _ = features.stack_features(
            fake_images,
            self.eval_model,
            self.batchsize,
            self.world_size,
            self.DDP,
            self.device,
        )
        mu1, sigma1 = self._feature_statistics(fake_feats, len(fake_images))
        mu2, sigma2 = self._feature_statistics(real_feats, len(real_images))
        fid_score = fid.frechet_inception_distance(mu1, sigma1, mu2, sigma2)
        logging.info(f"FID: {fid_score}")
        self.eval_count += 1
        return {"FID": fid_score}


class TrailEvaluator(MPSEvaluator):
    """Class-conditional variant that plots real vs. generated features per class.

    Fake samples are generated for a fixed, hard-coded set of class ids
    (``self.classes``); real samples of the batch are filtered to those same
    classes. Only a UMAP plot (and periodically an image grid) is produced;
    no FID is computed.

    Args:
        print_epoch, num_print_image, resolution, **kwargs: Forwarded to
            ``MPSEvaluator.__init__``.
    """

    def __init__(self, print_epoch=10, num_print_image=64, *args, resolution=224, **kwargs):
        # ``resolution`` is keyword-only here, so it must be forwarded
        # explicitly; otherwise the parent silently falls back to its default.
        super().__init__(print_epoch, num_print_image, *args, resolution=resolution, **kwargs)
        self.classes = np.array([0, 7, 64, 98, 184])
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # dtype it used to alias.
        self.fixed_y = np.repeat(self.classes, self.eval_size // len(self.classes)).astype(int)

    def __call__(self, engine, batch):
        """Generate samples for the fixed classes and save a UMAP plot.

        Args:
            engine: Unused; present to match the ``(engine, batch)`` signature.
            batch: Pair ``(images, labels)`` of real data tensors.

        Returns:
            ``{"FID": 1000000}`` — a constant placeholder; this evaluator does
            not compute an FID.
        """
        # Evaluation of Meta Pseudo Samples
        self.generator.eval()
        self.classifier.eval()
        self.finder.eval()
        main_feat_fakes = []
        # Use the wrapped module when the generator is wrapped (exposes
        # ``.module``), otherwise the generator itself.
        generator_core = getattr(self.generator, "module", self.generator)

        # Iterate over ``fixed_y`` itself so the number of generated features
        # always equals the number of fake labels and no batch is empty.
        for i, start in enumerate(range(0, len(self.fixed_y), self.batchsize)):
            with torch.no_grad():
                y_fake = torch.from_numpy(self.fixed_y[start: start + self.batchsize]).long()
                # Keep the labels on the same device as the sampled latents.
                y_fake = y_fake.to(self.device)
                z = gan_util.sample_z(generator_core, len(y_fake), self.device)
                x_fake = self.resizer(self.generator(self.finder(z), y_fake))
                if isinstance(x_fake, tuple):
                    x_fake = x_fake[0]
                _, main_feat_fake = self.classifier(x_fake)
                main_feat_fakes.append(main_feat_fake)
            # Print generated images (first batch only, every print_epoch calls)
            if i == 0 and (self.eval_count % self.print_epoch == 0):
                img_path = os.path.join(self.dist, "generated")
                os.makedirs(img_path, exist_ok=True)
                file_path = os.path.join(img_path, f"image_iter_{self.eval_count}.png")
                row_num = math.ceil(pow(self.num_print_image, 0.5))
                torchvision.utils.save_image(
                    x_fake[: self.num_print_image], file_path, nrow=row_num, normalize=True, scale_each=True
                )

        # Keep only real samples whose label is one of the fixed classes.
        real_labels = batch[1].detach().cpu().numpy()
        real_index = np.isin(real_labels, self.classes)
        real_labels = real_labels[real_index]
        real_images = batch[0][real_index]
        main_feat_fakes = torch.cat(main_feat_fakes, dim=0).detach().cpu().numpy()
        # Offset fake labels so they can be told apart from real ones below.
        # NOTE(review): assumes real class ids are all < 200 — confirm.
        fake_labels = self.fixed_y + 200
        main_feat_reals = self.extract_feature(real_images).cpu().numpy()

        # Create UMAP
        feats = np.concatenate((main_feat_reals, main_feat_fakes), axis=0)
        feats = feats.reshape(feats.shape[0], feats.shape[1])
        labels = np.concatenate((real_labels, fake_labels), axis=0)
        umap_emb = umap.UMAP().fit_transform(feats)
        img_path = os.path.join(self.dist, "umap")
        os.makedirs(img_path, exist_ok=True)
        file_path = os.path.join(img_path, f"umap_iter_{self.eval_count}.pdf")
        is_real = labels < 200
        is_fake = labels >= 200
        fig = plt.figure(figsize=(6, 3), dpi=300)
        try:
            sc = plt.scatter(umap_emb[:, 0][is_real], umap_emb[:, 1][is_real], c=real_labels, s=10, alpha=0.5)
            sc.set_rasterized(True)
            sc = plt.scatter(
                umap_emb[:, 0][is_fake], umap_emb[:, 1][is_fake], c=fake_labels, s=10, marker="^", alpha=0.5
            )
            sc.set_rasterized(True)
            plt.savefig(file_path, format='pdf', dpi=300, bbox_inches='tight', pad_inches=0)
        finally:
            # Without this, one figure per evaluation stays alive in pyplot.
            plt.close(fig)
        self.eval_count += 1
        return {"FID": 1000000}
