import argparse
from tqdm import trange

import torch

from data import get_cl_dataset, get_dataset, TensorDataset
from generator import SyntheticImageGenerator
from utils import default_args

if __name__ == '__main__':
    # Pretrain the autoencoder of a SyntheticImageGenerator on the real images
    # of a dataset, then save the generator's state_dict under ./pretrained_ae/.
    import os

    parser = argparse.ArgumentParser(description='Parameter Processing')

    # data
    parser.add_argument('--data_path', type=str, default='ANONYMIZED')
    parser.add_argument('--dataset', type=str, default='ImageNet10')
    # Phase index; only used for the class-incremental "CIFAR100_cl" split,
    # where each phase covers 20 classes of the stored class order.
    parser.add_argument('--phase', type=int, default=0)

    # hparams for ours
    parser.add_argument('--ae_iteration', type=int, default=2000)
    parser.add_argument('--lr_ae', type=float, default=1e-2)
    parser.add_argument('--ipc', type=int, default=1)
    # `type=list` would split the raw CLI string into single characters
    # ("64" -> ['6', '4']); accept space-separated ints instead:
    #   --hdims 64 128
    parser.add_argument('--hdims', type=int, nargs='*', default=[])
    parser.add_argument('--num_seed_vec', type=int, default=16)
    parser.add_argument('--num_decoder', type=int, default=8)

    parser.add_argument('--gpu_id', type=int, default=0)

    args = parser.parse_args()

    args.device = torch.device(f"cuda:{args.gpu_id}")

    # NOTE(review): args.kernel_size / args.stride / args.padding are read
    # below but never declared by this parser; presumably default_args fills
    # them in (and possibly other fields) - confirm in utils.
    default_args(args)

    # ---- dataset ----
    if args.dataset == "CIFAR100_cl":
        import pickle
        # Local, trusted file holding the CIFAR-100 class order.
        # Never point this at untrusted data: unpickling can execute code.
        with open("./cifar100_order.pkl", "rb") as f:
            order = pickle.load(f)
        num_classes = 20  # classes per phase
        class_map = order[args.phase*num_classes:(args.phase+1)*num_classes].tolist()
        channel, im_size, _, normalize, images_all, indices_class, testloader = get_cl_dataset(class_map, args.data_path)
    else:
        channel, im_size, num_classes, normalize, images_all, indices_class, testloader = get_dataset(args.dataset, args.data_path)

    # ---- initialize ----
    generator = SyntheticImageGenerator(
            num_classes, im_size, args.num_seed_vec, args.num_decoder, args.hdims,
            args.kernel_size, args.stride, args.padding).to(args.device)

    optimizer_ae = torch.optim.Adam(generator.parameters(), lr=args.lr_ae)
    # Scale the learning rate by 0.1 once, halfway through pretraining.
    scheduler_ae = torch.optim.lr_scheduler.MultiStepLR(
        optimizer_ae, milestones=[int(0.5*args.ae_iteration)], gamma=0.1)

    batch_size = 256
    img_real_dataloader = torch.utils.data.DataLoader(
        TensorDataset(images_all.detach().cpu()), batch_size=batch_size,
        shuffle=True, num_workers=8, drop_last=True)
    # With drop_last=True a dataset smaller than one batch yields no batches.
    # The retry below would then raise a bare, unexplained StopIteration.
    if len(img_real_dataloader) == 0:
        raise ValueError(
            f'need at least {batch_size} real images for AE pretraining, '
            f'got {len(img_real_dataloader.dataset)}')

    # ---- pretrain ----
    img_real_iter = iter(img_real_dataloader)
    for i in trange(1, args.ae_iteration+1):
        # Step-based training: restart the iterator whenever an epoch ends.
        try:
            img_real = next(img_real_iter)
        except StopIteration:
            img_real_iter = iter(img_real_dataloader)
            img_real = next(img_real_iter)

        img_real = img_real.to(args.device)
        loss = generator.autoencoder(img_real)
        optimizer_ae.zero_grad()
        loss.backward()
        optimizer_ae.step()
        scheduler_ae.step()

        if i % 1000 == 0:
            print(f'pretrain step {i}: {loss.item()}')

    # ---- save ----
    if args.dataset == "CIFAR100_cl":
        save_name = f'./pretrained_ae/CIFAR100_cl_{args.phase}_default.pth'
    else:
        save_name = f'./pretrained_ae/{args.dataset}_{args.ipc}_default.pth'
    # Create the output directory so a finished run cannot fail at the save.
    os.makedirs(os.path.dirname(save_name), exist_ok=True)
    torch.save(generator.state_dict(), save_name)
