import argparse
import torch
import numpy as np
from PIL import Image
from PIL import ImageFilter
from pgd_attack import perturb
from pathlib import Path
from models import *
from eval_utils import load_data
import torch.nn as nn
import os
import torchvision.transforms as transforms
from dataset import get_imagenet128_val_dataset
from tqdm import tqdm

# Command-line interface: which dataset to generate images for, and how many
# seed images are pushed through PGD per forward/backward batch.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True, type=str,
                    choices=['cifar10', 'celebahq', 'bedroom'])
parser.add_argument('--batch_size', default=200, type=int)
args = parser.parse_args()

# Number of images to generate: 10k for CIFAR-10, 3k for the 128px datasets.
if args.dataset == 'cifar10':
    samples = 10000
else:
    samples = 3000

# All tensors and the model are moved here; a CUDA device is required.
device = 'cuda'



# Destination directory for the generated PNGs: FID/<dataset>-gen.
outdir = 'FID/{}-gen'.format(args.dataset)

if args.dataset == 'celebahq':
    # Pre-computed seed images, presumably stored as (N, H, W, C) with values
    # in 0..255 (inferred from the transpose and /255 below) — TODO confirm.
    seed_file = 'img_seed_celebahq128_betavae.npy'
    raw_seeds = np.load(seed_file)[:samples]
    # Reorder to (N, C, H, W) float32 in [0, 1].
    seed = torch.from_numpy(
        np.transpose(raw_seeds, (0, 3, 1, 2)).astype(np.float32) / 255)
else:
    # Seeds are heavily blurred validation images; the blur radius is scaled
    # with the target resolution.
    side = 32 if args.dataset == 'cifar10' else 128
    blur_radius = 3 if args.dataset == 'cifar10' else 13

    val_set = get_imagenet128_val_dataset(
        './datasets',
        transforms.Compose([transforms.Resize(side), transforms.ToTensor()]))
    clean_seeds = load_data(val_set, samples)

    # Round-trip each tensor through PIL to apply the Gaussian blur.
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    blur = ImageFilter.GaussianBlur(radius=blur_radius)
    seed = torch.stack(
        [to_tensor(to_pil(image).filter(blur)) for image in clean_seeds],
        dim=0)


# Per-dataset keyword arguments forwarded verbatim to `perturb(...)`.
# The keys mirror its parameters: norm type, eps (presumably the perturbation
# budget), number of steps, and step size.
celebahq128_gen_config = {'norm': 'L2', 'eps': 40, 'steps': 100, 'step_size': 1.2}
bedroom128_gen_config = {'norm': 'L2', 'eps': 70, 'steps': 400, 'step_size': 0.8}
cifar10_gen_config = {'norm': 'L2', 'eps': 15, 'steps': 200, 'step_size': 0.15}

# Select the configuration for the requested dataset.
configs = dict(
    cifar10=cifar10_gen_config,
    bedroom=bedroom128_gen_config,
    celebahq=celebahq128_gen_config,
)
config = configs[args.dataset]

# Pick the classifier architecture and the input-normalization scheme that
# `perturb` should apply: CIFAR-10 uses its own 10-class model, the 128px
# datasets share a 1000-class model with ImageNet normalization.
if args.dataset == 'cifar10':
    normalization = 'cifar10'
    model = ResNet3BN(num_classes=10)
else:
    normalization = 'imagenet'
    model = ResNet5BN(num_classes=1000)

# Trained weights for each dataset.
checkpoint = dict(
    cifar10='experiments/cifar10/steps00040/model.pth',
    celebahq='experiments/celebahq128/steps00080/model.pth',
    bedroom='experiments/bedroom128/steps00055/model.pth',
)[args.dataset]

# Wrap before loading: presumably the checkpoint was saved from a
# DataParallel model (keys prefixed with 'module.') — TODO confirm.
model = nn.DataParallel(model)
print(checkpoint)
model.load_state_dict(torch.load(checkpoint))
# nn.Module.to / .eval return the module itself, so this is an in-place setup.
model = model.to(device).eval()

# Run PGD starting from each batch of seed images and collect the results.
generated = []
for i in tqdm(range(0, seed.shape[0], args.batch_size)):
    batch = seed[i:i + args.batch_size].to(device)
    out = perturb(model, batch, normalization=normalization, **config)
    # detach(): .numpy() refuses tensors that still require grad.
    generated.append(out.detach().cpu().numpy())

# Convert (N, C, H, W) floats to (N, H, W, C) uint8 for PIL. Clip first:
# casting floats outside [0, 255] to uint8 wraps around / is undefined
# instead of saturating.
generated = np.concatenate(generated)
generated = (np.clip(generated, 0, 1) * 255).astype(np.uint8)
generated = generated.transpose([0, 2, 3, 1])

# The output directory may not exist yet on a fresh checkout.
os.makedirs(outdir, exist_ok=True)
for i, image in enumerate(generated):
    save_filename = os.path.join(outdir, '{:05d}.png'.format(i))
    Image.fromarray(image).save(save_filename)
    # Log once every 100 images; a bare `i % 100` is truthy for the other 99.
    if i % 100 == 0:
        print('saved {}'.format(save_filename))


