import os
import argparse
import types
from tqdm import tqdm, trange
import numpy as np
import torch
from metrics.inception import InceptionV3
from torchvision.datasets import CIFAR10
from datasets.lmdb_datasets import LMDBDataset, CropCelebA
from torchvision import datasets, transforms


def main(args):
    device = torch.device('cuda:0')
    
    if args.dataset == 'celeba_64':
        train_transform = transforms.Compose([
            CropCelebA(),
            transforms.Resize(64),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = LMDBDataset(root=args.data_dir, name='celeba',
                            train=True, transform=train_transform, is_encoded=True)
    elif args.dataset == 'celeba_128':
        train_transform = transforms.Compose([
            CropCelebA(),
            transforms.Resize(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = LMDBDataset(root=args.data_dir, name='celeba',
                            train=True, transform=train_transform, is_encoded=True)
    elif args.dataset == 'celebahq':
        train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
        dataset = LMDBDataset(root=args.data_dir, name='celebahq',
                          train=True, transform=train_transform)
    elif args.dataset == 'cifar10':
        dataset = CIFAR10(
        root=args.data_dir, train=True, download=True,
        transform=transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    else:
        raise NotImplementedError

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=256, shuffle=False,
        drop_last=False)
    def infiniteloop(dataloader):
        while True:
            for x, _ in iter(dataloader):
                yield x
    data_looper = infiniteloop(data_loader)
    images = []
    for _ in trange(len(data_loader)):
        batch_images = next(data_looper)
        images.append((batch_images + 1) / 2)
    images = torch.cat(images, dim=0).numpy()

    if args.num_images is None and isinstance(images, types.GeneratorType):
        raise ValueError(
            "when `images` is a python generator, "
            "`num_images` should be given")

    if args.num_images is None:
        num_images = len(images)
    else:
        num_images = args.num_images

    block_idx1 = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
    block_idx2 = InceptionV3.BLOCK_INDEX_BY_DIM['prob']
    model = InceptionV3([block_idx1, block_idx2]).to(device)
    model.eval()

    fid_acts = np.empty((num_images, 2048))
    is_probs = np.empty((num_images, 1008))

    iterator = iter(tqdm(
        images, total=num_images,
        dynamic_ncols=True, leave=False, disable=False))
    
    start = 0
    while True:
        batch_images = []
        # get a batch of images from iterator
        try:
            for _ in range(args.batch_size):
                batch_images.append(next(iterator))
        except StopIteration:
            if len(batch_images) == 0:
                break
            pass
        batch_images = np.stack(batch_images, axis=0)
        end = start + len(batch_images)

        # calculate inception feature
        batch_images = torch.from_numpy(batch_images).type(torch.FloatTensor)
        batch_images = batch_images.to(device)
        with torch.no_grad():
            pred = model(batch_images)
            fid_acts[start: end] = pred[0].view(-1, 2048).cpu().numpy()
            is_probs[start: end] = pred[1].cpu().numpy()
        start = end

    m = np.mean(fid_acts, axis=0)
    s = np.cov(fid_acts, rowvar=False)
    file_path = os.path.join(args.save_dir, args.dataset + '.npz')
    np.savez(file_path, mu=m, sigma=s)
    

if __name__ == '__main__':
    # Example invocation:
    #   python precompute_fid_statistics.py --dataset cifar10
    cli = argparse.ArgumentParser('')

    # One (flag, keyword-arguments) entry per command-line option.
    option_specs = (
        ('--dataset', dict(
            type=str, default='cifar10',
            choices=['cifar10', 'celeba_64', 'celeba_128', 'celebahq'])),
        ('--batch_size', dict(type=int, default=50)),
        ('--num_images', dict(type=int, default=None)),
        ('--data_dir', dict(type=str, default='./data/cifar/')),
        ('--save_dir', dict(type=str, default='./data/stats/')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    main(cli.parse_args())
