import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils
from torch.autograd import Variable

from vae import VAE
from utils import get_dataset

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


def main(args):
    """Train a VAE on the selected dataset, save its weights, and dump images.

    Runs ``args.Iteration + 1`` passes over the training split. It then saves the
    state dict to ``checkpoints/vae_epoch_30.pth``. Finally it writes up to 20
    reconstruction/original pairs from the last training batch, plus 10 images
    from ``model.sample(10)``, into ``data/``.

    Args:
        args: parsed command-line namespace. Reads ``dataset``, ``data_path``,
            ``batch_real``, ``subset``, ``lr_aug`` and ``Iteration``. It is also
            forwarded to ``get_dataset``. ``im_size`` and ``device`` are set on it.

    Raises:
        ValueError: if the training split yields no batches.
    """
    (channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test,
     testloader, loader_train_dict, class_map, class_map_inv) = get_dataset(
        args.dataset, args.data_path, args.batch_real, args.subset, args=args)

    args.im_size = im_size
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = VAE("test", im_size[0], channel, kernel_num=128, z_size=128)
    model.to(args.device)

    optimizer_aug = torch.optim.Adam(model.parameters(), lr=args.lr_aug, weight_decay=1e-5)

    # Train on the *training* split (the original used dst_test, leaking the test set).
    # Built once: shuffle=True already reshuffles at the start of every epoch.
    trainloader = torch.utils.data.DataLoader(
        dst_train, batch_size=args.batch_real, shuffle=True, num_workers=1)
    if len(trainloader) == 0:
        raise ValueError('training set is empty; nothing to train the VAE on')

    for it in range(args.Iteration + 1):
        for img, _ in trainloader:
            optimizer_aug.zero_grad()
            img = img.to(args.device)
            # Named `mu` so the dataset normalisation `mean` above is not shadowed.
            (mu, logvar), img_recon = model(img)
            recon_loss = model.reconstruction_loss(img_recon, img)
            kl_loss = model.kl_divergence_loss(mu, logvar)
            loss = recon_loss + kl_loss
            loss.backward()
            optimizer_aug.step()

        # Loss of the final batch of this epoch.
        print(it, loss.item())

    os.makedirs('checkpoints', exist_ok=True)
    torch.save(model.state_dict(), 'checkpoints/vae_epoch_30.pth')

    # NOTE(review): images are written as produced; if get_dataset normalises
    # with mean/std they will need de-normalising to look right — confirm.
    # NOTE(review): if VAE contains BatchNorm/Dropout, model.eval() may be
    # wanted before sampling — confirm against vae.VAE.
    os.makedirs('data', exist_ok=True)
    with torch.no_grad():
        # Distinct indices (up to 20) into the last training batch; plain ints so
        # filenames read "data/5_recon.png" rather than "data/tensor(5)_recon.png".
        pick = torch.randperm(img.shape[0])[:20].tolist()
        for i in pick:
            torchvision.utils.save_image(img_recon[i], f'data/{i}_recon.png')
            torchvision.utils.save_image(img[i], f'data/{i}.png')

        sample_img = model.sample(10)
        for i in range(10):
            torchvision.utils.save_image(sample_img[i], f'data/sample_{i}.png')


if __name__ == '__main__':
    # Command-line entry point: parse options and hand them to main().
    parser = argparse.ArgumentParser(description='Parameter Processing')

    # Learning rate of the Adam optimizer that trains the VAE.
    parser.add_argument('--lr_aug', type=float, default=3e-04, help='learning rate for updating augmentation')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    parser.add_argument('--subset', type=str, default='imagenette', help='ImageNet subset. This only does anything when --dataset=ImageNet')
    # Not read in this file; presumably consumed by get_dataset via `args` — confirm.
    parser.add_argument('--zca', action='store_true', help="do ZCA whitening")
    # main() loops over range(Iteration + 1), i.e. Iteration + 1 passes over the data.
    parser.add_argument('--Iteration', type=int, default=5,
                        help='number of training epochs minus one (training runs Iteration+1 passes over the data)')

    args = parser.parse_args()

    main(args)




