from attack import *
from models import PytorchModel
import torch
#from wideresnet import *
import os, argparse
import numpy as np
import json

from allmodels import MNIST, load_model, load_mnist_data, load_cifar10_data, CIFAR10, load_imagenet_train, load_imagenet_test
from allmodels import train_mnist, test_mnist, train_cifar10, test_cifar10, train_imagenet, test_imagenet

from torchvision.models import resnet50
from utils import get_time_stamp, report

# Command-line interface:
#   --config     JSON file with the training state (seed, paths, dataset, ...).
#   --gpuid      one or more GPU ids, exported via CUDA_VISIBLE_DEVICES.
#   --adversary  build loaders in 'generator' mode instead of 'victim' mode.
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='victim_models/config/train_Imagenet_Resnet50.json',
                    help='config file')
# Default is a list so it has the same type as a parsed `--gpuid 0 1` value.
parser.add_argument('--gpuid', nargs='+', type=str, default=["0"],
                    help='GPU id(s) to expose through CUDA_VISIBLE_DEVICES')
parser.add_argument('--adversary', action='store_true',
                    help="load data in 'generator' mode instead of 'victim' mode")

args = vars(parser.parse_args())

# Load the run configuration from the JSON file given by --config.
with open(args['config'], encoding='utf-8') as config_file:
    state = json.load(config_file)

# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
# initialized, so export it as early as this script allows (before any
# torch seeding). --gpuid is normally a list; a plain string is passed through.
gpu_ids = args['gpuid']
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(gpu_ids) if isinstance(gpu_ids, list) else gpu_ids

np.random.seed(state['seed'])
torch.manual_seed(state['seed'])

# A missing or empty reload path both mean "train from scratch".
state["reload_path"] = state.get("reload_path") or None

# Each run writes to its own time-stamped directory:
#   <ckpt_path>/<dataset>/<timestamp>[_reload]
run_stamp = get_time_stamp()
if state["reload_path"]:
    run_stamp = f"{run_stamp}_reload"

state['ckpt_path'] = os.path.join(state['ckpt_path'], state['dataset'], run_stamp)
state['report_path'] = os.path.join(state['ckpt_path'], state['report_name'])

os.makedirs(state['ckpt_path'], exist_ok=True)

report(state['report_path'], f"Initialized report at {state['report_path']}.")
report(state["report_path"], f"GPU ID list: {os.environ['CUDA_VISIBLE_DEVICES']}")

# Build, train and evaluate a model for the dataset named in the config.
# With --adversary, data loaders are built in 'generator' mode; otherwise in
# 'victim' mode.
# NOTE(review): only the Imagenet branch honours state['reload_path'];
# MNIST and CIFAR10 always start from freshly constructed models.
if state['dataset'] == "MNIST":
    hyperparams = {
        'num_epochs': 50,
        'batch_size': 128,
        'lr': 0.001,
        'momentum': 0.9,
        'save_freq': 25,
        'ckpt_dir': state['ckpt_path'],
        'seed': state['seed']
    }

    net = MNIST()
    net = torch.nn.DataParallel(net).cuda()
    if args['adversary']:
        train_loader, test_loader, _, _ = load_mnist_data(state, mode='generator')
    else:
        train_loader, test_loader, train_dataset, test_dataset = load_mnist_data(state, mode='victim')

    train_mnist(net, train_loader, hyperparams)
    test_mnist(net, test_loader)

elif state['dataset'] == 'CIFAR10':
    net = CIFAR10()
    net = torch.nn.DataParallel(net).cuda()
    if args['adversary']:
        train_loader, test_loader, _, _ = load_cifar10_data(state, mode='generator')
    else:
        train_loader, test_loader, train_dataset, test_dataset = load_cifar10_data(state, mode='victim')

    hyperparams = {
        'num_epochs': 50,
        'batch_size': 128,
        'lr': 0.001,
        'momentum': 0.9,
        'save_freq': 25,
        'test_freq': 5,
        'ckpt_dir': state['ckpt_path'],
        'seed': state['seed']
    }

    # train_cifar10 receives the test loader as well (evaluation cadence
    # presumably governed by 'test_freq' — confirm in allmodels).
    train_cifar10(net, train_loader, test_loader, hyperparams)
    # test_cifar10(net, test_loader)

elif state['dataset'] == 'Imagenet':
    if not os.path.exists(state['dataset_path']):
        # FileNotFoundError is an OSError (== IOError) subclass, so existing
        # `except IOError` handlers still catch it.
        raise FileNotFoundError(f"Dataset folder does not exist: {state['dataset_path']}")

    # net = VGG_plain('VGG16', 1000)
    net = resnet50(pretrained=False)
    if state['reload_path']:
        # nn.Module has no `.load` method; restore weights with load_state_dict.
        # map_location='cpu' avoids restoring tensors onto the device they were
        # saved from; the model is moved to the GPU by .cuda() below.
        # NOTE(review): loading happens before DataParallel wrapping, so this
        # assumes the checkpoint keys carry no 'module.' prefix — confirm how
        # checkpoints are saved by train_imagenet.
        net.load_state_dict(torch.load(state['reload_path'], map_location='cpu'))
        report(state['report_path'], f"Reloaded from {state['reload_path']}.")

    net = torch.nn.DataParallel(net).cuda()
    if args['adversary']:
        # NOTE(review): load_imagenet_generator is not in the explicit
        # allmodels imports of this file; presumably it arrives through
        # `from attack import *` — confirm, otherwise this raises NameError.
        train_loader = load_imagenet_generator(state, normalize=False)
    else:
        train_loader = load_imagenet_train(state, mode='victim', normalize=True)

    test_loader = load_imagenet_test(state, normalize=True)

    train_imagenet(net, train_loader, test_loader, state)
    # test_imagenet(net, test_loader)

else:
    raise ValueError(f"Unsupported dataset: {state['dataset']}")
