import torch
from fid_util import inception_score
from fid_util.fid_score import calculate_fid_given_images, calculate_fid_given_real_ms
import numpy as np
import torch
from torch import distributions
import torchvision


def add_imgs(imgs, outdir, nrow=8, pad_value=0.8):
    """Save a batch of images as a single grid image file.

    Pixel values are mapped with ``x / 2 + 0.5`` (so an input assumed to lie in
    [-1, 1] ends up in [0, 1]) and then clipped to [0, 1].

    Args:
        imgs: image tensor accepted by ``torchvision.utils.make_grid``
            (presumably (B, C, H, W) -- confirm with callers).
        outdir: destination passed to ``torchvision.utils.save_image``; despite
            the name this is a file path (or file object), not a directory.
        nrow: number of images per grid row.
        pad_value: fill value for the padding between grid cells.
    """
    # Division creates a new tensor, so the in-place clip below never touches
    # the caller's tensor.
    imgs = imgs / 2 + 0.5
    imgs.clip_(0.0, 1.0)
    # save_image builds the grid itself. Building it here first as well was
    # redundant, and the pad_value given to the second call had no effect.
    torchvision.utils.save_image(imgs, outdir, nrow=nrow, pad_value=pad_value)


def imgs_grid(imgs, nrow=8):
    """Tile a batch of images into one grid tensor for display.

    Pixel values are mapped with ``x / 2 + 0.5`` (i.e. [-1, 1] -> [0, 1]) and
    clipped to [0, 1] before tiling; padding between cells is filled with 1.

    Args:
        imgs: image tensor accepted by ``torchvision.utils.make_grid``.
        nrow: number of images per grid row.

    Returns:
        The grid tensor produced by ``torchvision.utils.make_grid``.
    """
    rescaled = (imgs / 2 + 0.5).clip(0.0, 1.0)
    return torchvision.utils.make_grid(rescaled, nrow=nrow, pad_value=1)


def get_zdist(dist_name, dim, device=None):
    """Build the latent distribution used to sample ``z``.

    Args:
        dist_name: ``'uniform'`` for an independent U(-1, 1) per component, or
            ``'gauss'`` for an independent N(0, 1) per component.
        dim: latent size (any size spec ``torch.ones`` accepts).
        device: device holding the distribution parameters.

    Returns:
        A ``torch.distributions`` object carrying an extra ``dim`` attribute.

    Raises:
        NotImplementedError: if ``dist_name`` is not one of the above.
    """
    if dist_name == 'uniform':
        zdist = distributions.Uniform(-torch.ones(dim, device=device),
                                      torch.ones(dim, device=device))
    elif dist_name == 'gauss':
        zdist = distributions.Normal(torch.zeros(dim, device=device),
                                     torch.ones(dim, device=device))
    else:
        # Name the offending value so misconfigurations are easy to spot.
        raise NotImplementedError(
            "Unsupported latent distribution {!r}; expected 'uniform' or "
            "'gauss'".format(dist_name))
    # Attach the latent size so it travels with the distribution object.
    zdist.dim = dim
    return zdist


def get_nsamples(data_loader, N):
    """Collect exactly ``N`` (input, label) samples from ``data_loader``.

    Batches are drawn from a single iterator, so consecutive batches are
    distinct. If the loader is exhausted before ``N`` samples are gathered,
    iteration restarts from the beginning.

    Args:
        data_loader: iterable yielding batches whose items ``[0]`` and ``[1]``
            are the inputs and labels (e.g. a ``torch.utils.data.DataLoader``).
        N: number of samples to return.

    Returns:
        Tuple ``(x, y)`` with the first ``N`` inputs and labels, concatenated
        along dim 0.

    Raises:
        ValueError: if ``data_loader`` yields no batches at all.
    """
    xs = []
    ys = []
    n = 0
    # Create the iterator once. Calling iter() on every step restarts the
    # loader, which keeps returning its first batch when it is unshuffled
    # (and starts new workers each time for a multi-worker DataLoader).
    batches = iter(data_loader)
    fresh = True  # True while nothing has been drawn from `batches` yet
    while n < N:
        try:
            batch = next(batches)
        except StopIteration:
            if fresh:
                # A brand-new iterator was empty: restarting would loop forever.
                raise ValueError('data_loader yielded no batches') from None
            batches = iter(data_loader)
            fresh = True
            continue
        fresh = False
        x_next, y_next = batch[0], batch[1]
        xs.append(x_next)
        ys.append(y_next)
        n += x_next.size(0)
    x = torch.cat(xs, dim=0)[:N]
    y = torch.cat(ys, dim=0)[:N]
    return x, y


def get_ydist(nlabels, device=None):
    """Return a uniform categorical distribution over ``nlabels`` classes.

    All logits are zero, so every label is equally likely. The returned object
    carries an extra ``nlabels`` attribute.

    Args:
        nlabels: number of classes.
        device: device holding the logits.
    """
    uniform_logits = torch.zeros(nlabels, device=device)
    label_dist = distributions.Categorical(logits=uniform_logits)
    label_dist.nlabels = nlabels
    return label_dist


class Evaluator(object):
    """Score a generator with Inception Score and FID.

    The generator is driven with a fixed batch of real images (the first
    ``batch_size`` entries of ``fid_real_samples``), latent codes drawn from
    ``zdist`` and labels drawn from ``ydist``.
    """

    def __init__(self, zdist, ydist, fid_real_samples=None, batch_size=64,
                 inception_nsamples=10000, device=None, fid_sample_size=10000):
        """Store the sampling distributions, real data and metric sizes.

        Args:
            zdist: object with ``sample(shape)`` used for latent codes.
            ydist: object with ``sample(shape)`` used for labels.
            fid_real_samples: tensor of real images, stored as a NumPy array.
                Required by ``compute_inception_score``, where it is both the
                generator input and (unless precomputed statistics are used)
                the FID reference set.
            batch_size: batch size for generation and scoring.
            inception_nsamples: number of generated images used for the IS.
            device: device the generator input is moved to.
            fid_sample_size: number of generated images used for the FID.
        """
        self.zdist = zdist
        self.ydist = ydist
        self.inception_nsamples = inception_nsamples
        self.batch_size = batch_size
        self.device = device
        self.fid_sample_size = fid_sample_size
        # Always define the attribute so a missing reference set is reported
        # explicitly instead of surfacing later as an AttributeError.
        self.fid_real_samples = None
        if fid_real_samples is not None:
            # detach/cpu so GPU tensors or tensors requiring grad also convert.
            self.fid_real_samples = fid_real_samples.detach().cpu().numpy()

    def compute_inception_score(self, task_id=-1, exit_real_ms=False, mask_ratio=None, real_m=0, real_s=0,
                                generator=None):
        """Generate images and return ``(score, score_std, fid)``.

        Args:
            task_id: not used by this method; kept for interface compatibility.
            exit_real_ms: if True, compute FID from the precomputed real
                statistics ``real_m``/``real_s`` instead of ``fid_real_samples``.
            mask_ratio: forwarded to the generator.
            real_m: precomputed real statistic passed to
                ``calculate_fid_given_real_ms`` (presumably the feature mean).
            real_s: precomputed real statistic passed to
                ``calculate_fid_given_real_ms`` (presumably the covariance).
            generator: model called as
                ``generator(x, z, label=y, mask_ratio=..., maskType='mae')``
                and returning a 4-tuple whose first item is the generated batch.

        Returns:
            ``(score, score_std)`` from ``inception_score`` and the FID value.

        Raises:
            ValueError: if ``generator`` is None or the evaluator was built
                without ``fid_real_samples``.
        """
        if generator is None:
            raise ValueError('compute_inception_score requires a generator')
        if self.fid_real_samples is None:
            raise ValueError('compute_inception_score requires fid_real_samples '
                             '(used as generator input and FID reference)')
        generator.eval()
        # The same real batch conditions every generated batch.
        xtest = torch.tensor(self.fid_real_samples[:self.batch_size]).to(self.device)

        # Generate enough images for whichever metric needs more of them.
        n_needed = max(self.inception_nsamples, self.fid_sample_size)
        imgs = []
        # Evaluation only: no autograd graph is needed for the samples.
        with torch.no_grad():
            while len(imgs) < n_needed:
                ztest = self.zdist.sample((self.batch_size, 1,))
                ytest = self.ydist.sample((self.batch_size,))
                samples, _, _, _ = generator(xtest, ztest, label=ytest,
                                             mask_ratio=mask_ratio, maskType='mae')
                imgs.extend(s.detach().cpu().numpy() for s in samples)

        real_samples = self.fid_real_samples
        if real_samples.shape[1] == 1:
            # Single-channel data: replicate to 3 channels for the metrics.
            # Local copies only, so the generator input (xtest) keeps its
            # original channel count on later calls. Assumes each generated
            # image array is (C, H, W) like the real samples -- confirm.
            real_samples = real_samples.repeat(3, axis=1)
            imgs = [img.repeat(3, axis=0) for img in imgs]

        inception_imgs = imgs[:self.inception_nsamples]
        score, score_std = inception_score(
            inception_imgs, device=self.device, resize=True, splits=10,
            batch_size=self.batch_size)

        fid_imgs = np.array(imgs[:self.fid_sample_size])
        if exit_real_ms:
            fid = calculate_fid_given_real_ms(real_m, real_s, fid_imgs,
                                              batch_size=self.batch_size,
                                              cuda=True)
        else:
            fid = calculate_fid_given_images(
                real_samples,
                fid_imgs,
                batch_size=self.batch_size,
                cuda=True)

        return score, score_std, fid

    def create_samples(self, z, y=None, generator=None):
        """Run ``generator(z)`` in eval mode without gradients and return it.

        Args:
            z: latent batch; ``z.size(0)`` is taken as the batch size.
            y: None (labels sampled from ``ydist``), an int (that label for the
                whole batch) or a label tensor.
            generator: model to evaluate.
        """
        generator.eval()
        batch_size = z.size(0)
        # Parse y
        # NOTE(review): y is resolved here but never passed to the generator
        # below, so it does not affect the output (sampling it only consumes
        # RNG state). Confirm whether the generator should receive it.
        if y is None:
            y = self.ydist.sample((batch_size,))
        elif isinstance(y, int):
            y = torch.full((batch_size,), y,
                           device=self.device, dtype=torch.int64)
        # Sample x
        with torch.no_grad():
            x = generator(z)
        return x
