from eval import load_params
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import utils as vutils
from torchvision import transforms
import os
import random
import argparse
from tqdm import tqdm

from models import Generator
from operation import load_params, InfiniteSamplerWrapper

noise_dim = 256  # latent vector size, passed to the Generator as nz
device = torch.device('cuda:0')

im_size = 1024
net_ig = Generator(ngf=64, nz=noise_dim, nc=3, im_size=im_size)
net_ig.to(device)
# NOTE(review): net_ig is never switched to eval mode; if Generator contains
# normalisation layers with running statistics, inference here uses batch
# statistics — confirm this matches how samples were produced in training.

# Load the checkpoint for this training iteration. map_location returns each
# storage unchanged, i.e. tensors are deserialised onto the CPU and then
# copied into the on-device model by load_state_dict.
epoch = 50000
ckpt = './models/all_%d.pth' % epoch
checkpoint = torch.load(ckpt, map_location=lambda storage, loc: storage)
net_ig.load_state_dict(checkpoint['g'])
# Then overwrite with the parameters stored under 'g_ema' (presumably the
# exponential-moving-average generator weights; load_params comes from
# the `operation` module — verify its semantics there).
load_params(net_ig, checkpoint['g_ema'])

# Quick sanity sample: generate one batch and save it as ./2.png.
batch = 8
with torch.no_grad():  # inference only — don't build/keep an autograd graph
    noise = torch.randn(batch, noise_dim).to(device)
    g_imgs = net_ig(noise)[0]  # first element of the generator's output

# add(1).mul(0.5) maps [-1, 1] to [0, 1] for saving.
vutils.save_image(g_imgs.add(1).mul(0.5),
                  os.path.join('./', '%d.png' % 2))


# Real-image side of the comparison. Every training image is resized to
# 256x256, converted to a tensor, and normalised with mean=std=0.5 per
# channel. That maps ToTensor's [0, 1] range to [-1, 1], presumably to match
# the generator's output range (it is mapped back with add(1).mul(0.5)).
data_root = '../../images/skulls'
transform_list = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
trans = transforms.Compose(transform_list)
dataset = ImageFolder(root=data_root, transform=trans)

import lpips
# LPIPS perceptual distance ('net-lin' head over a VGG backbone, use_gpu=True).
percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True)

# First image of the sanity batch with a leading batch dim of 1.
# NOTE(review): not read by the visible code — find_closest takes its own
# argument and the tracking loop rebinds this name.
the_image = g_imgs[0].unsqueeze(0)
def find_closest(the_image, *, candidates=None, metric=None, size=256, progress=True):
    """Return the candidate image closest to ``the_image`` under ``metric``.

    ``the_image`` is resized to ``size`` x ``size`` with ``F.interpolate``.
    It is then compared against every entry of ``candidates``, and the entry
    with the smallest summed distance wins. Ties keep the earliest candidate.

    Args:
        the_image: image tensor accepted by ``F.interpolate``; the call sites
            pass one generated image with a leading batch dim of 1.
        candidates: indexable sequence of ``(image, label)`` items, such as a
            torchvision ``ImageFolder``. Defaults to the module-level
            ``dataset``.
        metric: callable ``metric(a, b)`` returning a distance tensor.
            Defaults to the module-level LPIPS model ``percept``.
        size: spatial size ``the_image`` is resized to. It should match the
            size the candidates were resized to.
        progress: wrap the scan in a tqdm progress bar.

    Returns:
        ``(closest_image, distance)``. ``closest_image`` has a leading batch
        dim of 1 and lives on ``the_image``'s device. ``distance`` is the
        summed metric output for that image.

    Raises:
        ValueError: if ``candidates`` is empty.
    """
    if candidates is None:
        candidates = dataset
    if metric is None:
        metric = percept
    if len(candidates) == 0:
        raise ValueError("find_closest: no candidate images to compare against")

    indices = range(len(candidates))
    if progress:
        indices = tqdm(indices)

    # Any finite distance beats +inf. The old sentinel of 100 could be
    # exceeded, which left the result as (None, 100) and crashed the caller.
    best_dist = float('inf')
    best_image = None
    # Pure inference: skip autograd so no graph is built per comparison or
    # kept alive through best_dist.
    with torch.no_grad():
        # F.interpolate defaults to mode='nearest'. NOTE(review): the dataset
        # side uses transforms.Resize (bilinear by default), so the two sides
        # are downsampled differently — confirm this is intended.
        the_image = F.interpolate(the_image, size=size)
        for i in indices:
            real_image = candidates[i][0].unsqueeze(0).to(the_image.device)
            dist = metric(the_image, real_image).sum()
            if dist < best_dist:
                best_dist = dist
                best_image = real_image
    return best_image, best_dist

# Nearest-neighbour tracking. For each generated image, find the closest
# training image, save the pair side by side, then print the mean
# nearest-neighbour distance over all generated samples.
all_dist = []  # one (1,)-shaped distance tensor per generated image
batch = 8
num_rounds = 4
result_path = 'nn_track'
os.makedirs(result_path, exist_ok=True)
for j in range(num_rounds):
    with torch.no_grad():
        noise = torch.randn(batch, noise_dim).to(device)
        g_imgs = net_ig(noise)[0]

    for n in range(batch):
        the_image = g_imgs[n].unsqueeze(0)
        close_0, dis = find_closest(the_image)

        # Left: generated image downsampled to 256 (F.interpolate default
        # 'nearest' mode). Right: its nearest real image.
        # add(1).mul(0.5) maps [-1, 1] back to [0, 1] for saving.
        pair = torch.cat([F.interpolate(the_image, 256), close_0])
        out_file = os.path.join(result_path, 'nn_%d.jpg' % (j * batch + n))
        vutils.save_image(pair.add(1).mul(0.5), out_file)
        all_dist.append(dis.view(1))

print(torch.cat(all_dist).mean())