import logging

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from scipy import linalg
from tqdm import tqdm

# NOTE(review): this configures the root logger at INFO level as soon as the
# module is imported, which also affects any program that imports it; consider
# moving it under the ``__main__`` guard.
logging.basicConfig(level=logging.INFO)


def convert_to_grayscale_imgs(images):
    """
    Convert a batch of RGB images to single-channel gray-scale images.

    Each image is converted individually with PIL's ``Image.convert("L")``.
    Args:
        images (ndarray): Batch of RGB images of shape (N, H, W, 3), e.g. the
            uint8 output of ``normalize_images`` (``Image.fromarray`` needs a
            dtype PIL understands; presumably uint8 here).
    Returns:
        ndarray: Batch of gray-scale images of shape (N, H, W, 1).
    """
    n, h, w, _ = images.shape
    if n == 0:
        # np.stack / np.concatenate reject an empty list; return an empty
        # batch with the same layout instead of raising.
        return np.empty((0, h, w, 1), dtype=np.uint8)
    # Each converted image is (H, W); stacking gives (N, H, W).
    gray = np.stack([np.array(Image.fromarray(img).convert("L")) for img in images])
    return gray.reshape((n, h, w, 1))


def normalize_images(images):
    """
    Min-max normalize a batch of float images and convert it to a uint8 ndarray,
    following the scaling used by torchvision's ``make_grid(normalize=True)``
    and ``save_image``. See reference at:
    https://pytorch.org/docs/stable/_modules/torchvision/utils.html#save_image

    The value range is taken from the whole batch (global min/max), not per
    image. The input tensor is not modified.
    Args:
        images (Tensor): Batch of images of shape (N, 3, H, W).
    Returns:
        ndarray: Batch of normalized uint8 images of shape (N, H, W, 3).
    """
    # Work on a copy. torchvision also clones before normalizing, and the
    # in-place ops below would otherwise overwrite the caller's tensor.
    images = images.detach().clone()

    # Map the batch's [min, max] range to [0, 1]. The small epsilon avoids a
    # division by zero for a constant batch.
    min_val = float(images.min())
    max_val = float(images.max())
    images.sub_(min_val).div_(max_val - min_val + 1e-5)

    # Scale to [0, 255]; adding 0.5 before the uint8 cast rounds to nearest.
    return (images.mul_(255).add_(0.5).clamp_(0, 255)
            .permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy())


def sample_z(gen, n_gen_samples, device):
    """Draw a float32 latent batch of shape ``(n_gen_samples, gen.dim_z)``.

    Values come from a standard normal when ``gen.distribution == 'normal'``,
    otherwise from ``uniform_()`` defaults (the range [0, 1)).
    """
    use_normal = gen.distribution == 'normal'
    noise = torch.empty(n_gen_samples, gen.dim_z, dtype=torch.float32, device=device)
    return noise.normal_() if use_normal else noise.uniform_()


def sample_categorical_labels(num_classes, n_gen_samples, device):
    """Sample ``n_gen_samples`` integer labels uniformly from ``[0, num_classes)``.

    NumPy's global RNG is used, so seeding ``np.random`` makes the draw
    reproducible. The result is an int64 (``torch.long``) tensor on ``device``.
    """
    drawn = np.random.randint(0, num_classes, size=n_gen_samples)
    return torch.from_numpy(drawn).type(torch.long).to(device)


def generate_images(gen, device, batchsize=64):
    """Sample one batch from ``gen`` without tracking gradients.

    Args:
        gen: Generator, optionally wrapped in ``torch.nn.DataParallel``. The
            unwrapped module must expose ``distribution``, ``dim_z`` and
            ``num_classes``.
        device: Device on which the latent codes and labels are created.
        batchsize (int): Number of samples to draw.

    Returns:
        Whatever ``gen(z, y)`` returns; ``y`` is ``None`` when
        ``num_classes`` is not positive (unconditional generator).
    """
    core = gen.module if isinstance(gen, torch.nn.DataParallel) else gen
    latent = sample_z(core, batchsize, device)
    labels = (sample_categorical_labels(core.num_classes, batchsize, device)
              if core.num_classes > 0 else None)
    with torch.no_grad():
        return gen(latent, labels)


def generate_mps(gen, finder, resize_fn, device, batchsize=64):
    """Sample a batch, passing the latent codes through ``finder`` before ``gen``.

    Args:
        gen: Generator, optionally wrapped in ``torch.nn.DataParallel``.
        finder: Callable applied to the sampled latent batch before generation.
        resize_fn: Callable applied to the generator output.
        device: Device on which the latent codes and labels are created.
        batchsize (int): Number of samples to draw.

    Returns:
        tuple: ``(resize_fn(gen(finder(z), y)), y)``, where ``y`` is ``None``
        when the generator has no classes.
    """
    core = gen.module if isinstance(gen, torch.nn.DataParallel) else gen
    latent = sample_z(core, batchsize, device)
    labels = None
    if core.num_classes > 0:
        labels = sample_categorical_labels(core.num_classes, batchsize, device)
    with torch.no_grad():
        mapped = finder(latent)
        images = resize_fn(gen(mapped, labels))
    return images, labels


def generate_batches(gen, device, batchsize, img_num):
    """Generate ``img_num`` images from ``gen`` in batches and stack them on the CPU.

    ``gen`` is switched to eval mode (and left there). If the generator returns
    a tuple, only its first element is kept.

    Args:
        gen: Generator accepted by ``generate_images``.
        device: Device used for sampling.
        batchsize (int): Maximum number of images generated per call.
        img_num (int): Total number of images to generate; must be positive.

    Returns:
        Tensor: Concatenation of all generated batches along dim 0, on the CPU,
        with exactly ``img_num`` entries.
    """
    X_fake = []
    with torch.no_grad():
        gen.eval()
        # Ceil division so a trailing partial batch is generated as well,
        # rather than silently dropping ``img_num % batchsize`` images (or
        # generating nothing at all when img_num < batchsize).
        n_batches = -(-img_num // batchsize)
        for i in tqdm(range(n_batches)):
            cur_size = min(batchsize, img_num - i * batchsize)
            generated = generate_images(gen, device, cur_size)
            if isinstance(generated, tuple):
                generated = generated[0]
            X_fake.append(generated.detach().cpu())
        X_fake = torch.cat(X_fake, dim=0)
    return X_fake


def get_dataset_images(loader, num_samples=50000):
    """Collect up to ``num_samples`` images from ``loader``.

    Only the first element of each batch (``batch[0]``) is kept. Loading stops
    as soon as enough images have been gathered.

    Args:
        loader: Iterable of batches whose first element is an image tensor.
        num_samples (int): Maximum number of images to return.

    Returns:
        Tensor: The first ``min(num_samples, dataset size)`` images,
        concatenated along dim 0.
    """
    imgs = []
    img_cnt = 0
    for batch in tqdm(loader):
        imgs.append(batch[0])
        img_cnt += len(batch[0])
        # ``>=``: stop as soon as we have enough, instead of reading one more batch.
        if img_cnt >= num_samples:
            break
    imgs = torch.cat(imgs, dim=0)
    # The final batch may overshoot; trim to the requested count.
    return imgs[:num_samples]


def _get_activations(images, model, batch_size=64, dims=2048, device=None):
    """Run ``model`` over ``images`` in mini-batches and collect its two outputs.

    Args:
        images: ndarray (or torch tensor) of images indexed along the first
            axis. Each slice is converted to a float32 tensor before being fed
            to ``model``.
        model: Module returning a ``(features, predictions)`` pair per batch.
            Features whose spatial size is not 1x1 are global-average pooled.
        batch_size (int): Number of images per forward pass.
        dims (int): Feature dimensionality produced by ``model``.
        device: Optional device each batch is moved to.

    Returns:
        tuple: ``(feat_arr, pred_arr)`` float64 ndarrays of shape ``(N, dims)``
        and ``(N, 1000)``. Every image is used, including a trailing partial batch.
    """
    model.eval()

    n_imgs = images.shape[0]
    if batch_size > n_imgs:
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
        batch_size = n_imgs

    feat_arr = np.empty((n_imgs, dims))
    pred_arr = np.empty((n_imgs, 1000))  # assumes the model predicts 1000 classes
    if n_imgs == 0:
        # Nothing to process (batch_size is 0 here); avoid dividing by zero.
        return feat_arr, pred_arr

    # Ceil division: the last, possibly smaller, batch is processed too,
    # instead of silently dropping ``n_imgs % batch_size`` images.
    n_batches = -(-n_imgs // batch_size)
    for i in tqdm(range(n_batches)):
        start = i * batch_size
        end = min(start + batch_size, n_imgs)

        # as_tensor accepts numpy arrays (as before) as well as torch tensors.
        batch = torch.as_tensor(images[start:end]).float()
        if device is not None:
            batch = batch.to(device)

        with torch.no_grad():
            feat, pred = model(batch)

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if feat.shape[2] != 1 or feat.shape[3] != 1:
            feat = F.adaptive_avg_pool2d(feat, output_size=(1, 1))

        cur_size = end - start
        feat_arr[start:end] = feat.cpu().numpy().reshape(cur_size, -1)
        pred_arr[start:end] = pred.cpu().numpy().reshape(cur_size, -1)

    return feat_arr, pred_arr


def calc_activation_stats(images, model, batch_size=64, dims=2048, device=None):
    """Compute Gaussian statistics of ``model`` features for FID.

    Returns:
        tuple: ``(mu, sigma, pred)``: the per-dimension mean of the features,
        their covariance matrix (one row per image), and the raw second model
        output as returned by ``_get_activations``.
    """
    activations, predictions = _get_activations(images, model, batch_size, dims, device)
    return activations.mean(axis=0), np.cov(activations, rowvar=False), predictions


def calc_fid_score(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2).

        d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))

    If the matrix square root is not finite, ``eps`` is added to both
    covariance diagonals and the square root is retried.

    Raises:
        AssertionError: If the means or covariances have mismatched shapes.
        ValueError: If the square root keeps a significant imaginary part.
    """
    mean_a, mean_b = np.atleast_1d(mu1), np.atleast_1d(mu2)
    cov_a, cov_b = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mean_a.shape == mean_b.shape, \
        'Training and test mean vectors have different lengths'
    assert cov_a.shape == cov_b.shape, \
        'Training and test covariances have different dimensions'

    delta = mean_a - mean_b

    # The product of the covariances may be (nearly) singular.
    sqrt_prod, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
    if not np.isfinite(sqrt_prod).all():
        print(('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps)
        jitter = np.eye(cov_a.shape[0]) * eps
        sqrt_prod = linalg.sqrtm((cov_a + jitter).dot(cov_b + jitter))

    # Round-off can leave a tiny imaginary part; only a large one is an error.
    if np.iscomplexobj(sqrt_prod):
        if not np.allclose(np.diagonal(sqrt_prod).imag, 0, atol=1e-3):
            worst = np.max(np.abs(sqrt_prod.imag))
            raise ValueError('Imaginary component {}'.format(worst))
        sqrt_prod = sqrt_prod.real

    return delta.dot(delta) + np.trace(cov_a) + np.trace(cov_b) - 2 * np.trace(sqrt_prod)


def calc_inception_score(pred, splits=10):
    """Compute the Inception Score, exp(E_x[KL(p(y|x) || p(y))]), over folds.

    Args:
        pred: Array of shape (N, C) whose rows are treated as class
            probability vectors p(y|x). Presumably softmax outputs; negative
            values such as raw logits yield nan.
        splits (int): Number of contiguous folds the rows are divided into.

    Returns:
        tuple: ``(mean, std)`` of the per-fold scores.
    """
    pred = np.asarray(pred)
    scores = np.empty((splits), dtype=np.float32)
    n = len(pred)
    for i in range(splits):
        part = pred[(i * n // splits):((i + 1) * n // splits), :]
        marginal = np.expand_dims(np.mean(part, 0), 0)
        # Exact zero probabilities (one-hot rows, float underflow) would give
        # 0 * log(0) = nan; use the standard 0 * log 0 = 0 convention instead.
        with np.errstate(divide='ignore', invalid='ignore'):
            kl = part * (np.log(part) - np.log(marginal))
        kl = np.where(part == 0, 0.0, kl)
        kl = np.mean(np.sum(kl, 1))
        scores[i] = np.exp(kl)

    return np.mean(scores), np.std(scores)


if __name__ == "__main__":
    # Smoke test for FID/IS: compute Inception statistics of the training set
    # and compare them against precomputed reference statistics.
    import argparse
    import os
    import sys

    import yaml
    from torch.utils.data import DataLoader

    import yaml_utils
    sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
    from model.inception import InceptionV3

    parser = argparse.ArgumentParser()
    # Both paths are used unconditionally below; requiring them gives a clear
    # argparse error instead of a TypeError from open(None) / np.load(None).
    parser.add_argument('--config_path', type=str, required=True,
                        help='path to config file')
    parser.add_argument('--static_stat', type=str, required=True,
                        help='path to stat file')
    args = parser.parse_args()
    with open(args.config_path) as config_file:
        config = yaml_utils.Config(yaml.load(config_file, Loader=yaml.SafeLoader))
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    train = yaml_utils.load_dataset(config)
    loader = DataLoader(train, config.batchsize, shuffle=True)
    imgs = np.concatenate([batch[0].numpy() for batch in loader])
    print(f'dataset shape: {imgs.shape}')
    print(f'min: {np.min(imgs)}, max: {np.max(imgs)}')
    model = InceptionV3().to(device)
    mu1, sigma1, pred = calc_activation_stats(imgs, model, 128, device=device)
    # Assumes static_stat is an .npz archive holding 'mean' and 'cov' arrays;
    # the context manager closes the archive once they are read.
    with np.load(args.static_stat) as fixed:
        mu2, sigma2 = fixed['mean'], fixed['cov']
    fid_score = calc_fid_score(mu1, sigma1, mu2, sigma2)
    is_score = calc_inception_score(pred)
    print(f'FID: {fid_score}')
    print(f'IS: {is_score}')
