import os
import random
from tqdm import tqdm

import torch
import numpy as np
from PIL import Image

from pytorch_fid import fid_score as FID
# from pytorch_fid.src.fid_score import calculate_fid_given_paths

def compute_fid(model, test_loader, data_dir, batch_size, num_images=512, batch_gen=128):
    """Compute the FID between test images and images sampled from ``model``.

    Real images are written to ``<data_dir>/test_images``. That folder is reused
    on later calls once its last expected file exists. Generated images are
    (re)written to ``<data_dir>/generated_images`` on every call. The two
    folders are then compared with ``pytorch_fid``.

    Args:
        model: Generative model exposing ``eval()``, ``parameters()`` and
            ``sample_simple(save_dir, targets, n_samples=..., std=...)``.
        test_loader: Iterable yielding ``(batch, targets)`` pairs, where
            ``targets`` is a sequence of dicts of tensors.
        data_dir: Directory under which both image folders are created.
        batch_size: Batch size forwarded to ``calculate_fid_given_paths``.
        num_images: Number of real and of generated images to write.
        batch_gen: Number of samples requested per ``sample_simple`` call.

    Returns:
        The value returned by ``FID.calculate_fid_given_paths``.

    Raises:
        ValueError: If ``num_images`` is not positive.

    Note:
        The model is switched to eval mode and left there.
        NOTE(review): files from an earlier run with a larger ``num_images``
        are not deleted from either folder. Presumably pytorch_fid reads every
        image in a folder, so stale files would be included. Confirm.
    """
    if num_images <= 0:
        raise ValueError(f"num_images must be positive, got {num_images}")
    print(f"data_dir: {data_dir}\n batch_size: {batch_size}\n num_images: {num_images}")
    model.eval()

    test_dir = os.path.join(data_dir, "test_images")
    _save_test_images(test_loader, test_dir, num_images)

    gen_dir = os.path.join(data_dir, 'generated_images')
    _save_generated_images(model, test_loader, gen_dir, num_images, batch_gen)

    # NOTE(review): older pytorch_fid releases take a ``cuda`` bool in this
    # position. Newer ones take a ``device`` (e.g. "cuda"/"cpu") and reject a
    # bool. Confirm against the pinned pytorch_fid version.
    gpu = next(model.parameters()).is_cuda
    return FID.calculate_fid_given_paths(
        [test_dir, gen_dir], batch_size, gpu, dims=2048)


def _save_test_images(test_loader, test_dir, num_images):
    """Write the first ``num_images`` test images to ``test_dir`` unless already present."""
    # The last expected file acts as a cache marker for a previous complete dump.
    last_png = os.path.join(test_dir, str(num_images - 1).zfill(6) + '.png')
    if os.path.exists(last_png):
        return
    os.makedirs(test_dir, exist_ok=True)
    count = 0
    for batch, _ in test_loader:
        # Keep element 0 of each sample. NOTE(review): presumably the image part
        # of each sample; confirm against the dataset. Iterating the batch itself
        # (not range(batch_size)) avoids IndexError on a short final batch.
        images = torch.stack([sample[0] for sample in batch], 0).detach()
        count = tensor_to_png(images, test_dir, count, num_images)
        if count >= num_images:
            break
    print(f"{count} images are saved")
    if count < num_images:
        print(f"warning: test_loader ran out after {count} of {num_images} images")


def _save_generated_images(model, test_loader, gen_dir, num_images, batch_gen):
    """Sample images from ``model`` in chunks of ``batch_gen`` until ``num_images`` are saved."""
    os.makedirs(gen_dir, exist_ok=True)
    # The conditioning targets from the first test batch are reused for every
    # sampling call. Use next(iter(...)): the iterator ``.next()`` alias is
    # Python-2 style and missing from recent PyTorch DataLoader iterators.
    _, targets = next(iter(test_loader))
    targets = [{k: v.detach() for k, v in t.items()} for t in targets]
    count = 0
    for _ in tqdm(range(num_images // batch_gen + 1)):
        if count >= num_images:
            break
        with torch.no_grad():
            gen_img = model.sample_simple(gen_dir, targets, n_samples=batch_gen, std=1.0)
        count = tensor_to_png(gen_img, gen_dir, count, num_images)
    print(f"generated {count} images")

def np_to_png(np_images, save_dir, count, stop):
    """Save channel-first images as sequentially numbered PNG files.

    Files are named ``{count:06d}.png`` starting at ``count``. Writing stops
    once ``count`` reaches ``stop`` or the batch is exhausted.

    Args:
        np_images: 4-D array laid out as (N, C, H, W). Values are expected in
            [0, 1] and are scaled by 255. C must be 1 (grayscale), 3 or 4, as
            accepted by ``PIL.Image.fromarray``.
        save_dir: Existing directory to write the PNG files into.
        count: Index used for the first file written.
        stop: Exclusive upper bound on the file index; nothing is written once
            ``count >= stop``.

    Returns:
        The updated ``count``, i.e. the index of the next file to write.
    """
    # (N, C, H, W) -> (N, H, W, C), the channel-last layout PIL expects.
    np_images = np.moveaxis(np.asarray(np_images), 1, 3)
    # Clip before scaling so out-of-range values saturate at 0/255 instead of
    # wrapping around when cast to uint8. The cast truncates, as before.
    np_images = (255 * np.clip(np_images, 0.0, 1.0)).astype(np.uint8)
    for img in np_images:
        if count >= stop:
            break
        if img.shape[-1] == 1:
            # PIL cannot build an image from (H, W, 1); use a 2-D grayscale array.
            img = img[..., 0]
        fn = os.path.join(save_dir, str(count).zfill(6) + '.png')
        Image.fromarray(img).save(fn)
        count += 1
    return count

def tensor_to_png(tensor, save_dir, count, stop):
    """Save a (N, C, H, W) tensor of images as numbered PNGs via ``np_to_png``.

    Args:
        tensor: Image batch on any device; it is detached and copied to the CPU.
        save_dir: Existing directory to write the PNG files into.
        count: Index used for the first file written.
        stop: Exclusive upper bound on the file index.

    Returns:
        The updated ``count`` returned by ``np_to_png``.
    """
    # detach() so tensors that require grad can be converted with .numpy().
    return np_to_png(tensor.detach().cpu().numpy(), save_dir, count, stop)