import os
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from imageio import imsave
from tqdm import tqdm
from copy import deepcopy
import logging
from itertools import chain
from SNGAN.fid_score_pytorch import calculate_fid
from pathlib import Path


def validate(args, fixed_z, gen_net: nn.Module, writer_dict, train_loader, epoch):
    """Evaluate the generator: compute FID and save a grid of sample images.

    Generates ``args.num_eval_imgs // args.eval_batch_size`` batches of fake
    images from Gaussian noise, gathers at most as many real images from
    ``train_loader``, and passes both arrays to ``calculate_fid``. A grid
    rendered from ``fixed_z`` is then written to
    ``<args.path_helper['sample_path']>/sample_dir/epoch_<epoch>_fid_<fid>.png``.

    Args:
        args: Config object; reads ``num_eval_imgs``, ``eval_batch_size``,
            ``latent_dim`` and ``path_helper['sample_path']``.
        fixed_z: Latent batch fed to the generator for the saved image grid.
        gen_net: Generator network. It is switched to eval mode and is NOT
            switched back to train mode by this function.
        writer_dict: Dict whose ``'valid_global_steps'`` entry is incremented.
        train_loader: Iterable yielding ``(images, labels)`` batches of real
            images as tensors.
        epoch: Value embedded in the saved file name.

    Returns:
        Tuple ``(inception_score, fid)``. The inception score is not computed
        and is always ``0``; ``fid`` is ``calculate_fid(...)[0][1]`` rounded
        to three decimals.

    Raises:
        ValueError: If no fake batches would be generated or ``train_loader``
            yields no batches.
    """
    gen_net = gen_net.eval()
    global_steps = writer_dict['valid_global_steps']
    eval_iter = args.num_eval_imgs // args.eval_batch_size
    if eval_iter <= 0:
        raise ValueError(
            'num_eval_imgs must be at least eval_batch_size to compute FID '
            f'(got num_eval_imgs={args.num_eval_imgs}, '
            f'eval_batch_size={args.eval_batch_size})'
        )

    # skip IS
    inception_score = 0

    # Sample noise on the generator's own device instead of assuming the
    # current CUDA device; fall back to CUDA for parameter-less modules.
    first_param = next(gen_net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    # compute FID
    fake_batches = []
    # Evaluation only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for _ in range(eval_iter):
            noise = np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim))
            z = torch.as_tensor(noise, dtype=torch.float32, device=device)
            fake_batches.append(gen_net(z).cpu().numpy())
    fake_image_np = np.concatenate(fake_batches, 0)
    num_fake = fake_image_np.shape[0]

    # Gather real images until there are at least as many as fake ones
    # (or the loader is exhausted), then trim to the fake count.
    real_batches = []
    num_real = 0
    for images, _ in train_loader:
        batch_np = images.detach().cpu().numpy()
        real_batches.append(batch_np)
        num_real += batch_np.shape[0]
        if num_real >= num_fake:
            break
    if not real_batches:
        raise ValueError('train_loader yielded no batches; cannot compute FID')
    real_image_np = np.concatenate(real_batches, 0)[:num_fake]

    fid_score = calculate_fid(real_image_np, fake_image_np, batch_size=300)
    fid = round(fid_score[0][1], 3)
    print('------------------------fid_score--------------------------')
    print(fid_score)

    # Generate a batch of images
    sample_dir = os.path.join(args.path_helper['sample_path'], 'sample_dir')
    Path(sample_dir).mkdir(parents=True, exist_ok=True)

    with torch.no_grad():
        # Rescale to the 0..255 pixel range (presumably the generator
        # outputs values in [-1, 1] -- confirm against the model).
        sample_imgs = gen_net(fixed_z).mul_(127.5).add_(127.5).clamp_(0.0, 255.0)
        img_grid = make_grid(sample_imgs, nrow=20).to('cpu', torch.uint8).numpy()
    file_name = os.path.join(sample_dir, f'epoch_{epoch}_fid_{fid}.png')
    # Move the first axis of the grid to the end before saving
    # (equivalent to swapaxes(0, 1).swapaxes(1, 2)).
    imsave(file_name, img_grid.transpose(1, 2, 0))

    writer_dict['valid_global_steps'] = global_steps + 1
    return inception_score, fid
