
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import argparse
from tensorboardX import SummaryWriter


import sys
import os

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(os.path.dirname(currentdir))
sys.path.append(parentdir)

from utils.Certified.architectures import ARCHITECTURES
from utils.Certified.datasets import DATASETS

from utils.Certified.utils_ensemble import AverageMeter, accuracy, test, copy_code, requires_grad_
from utils.Certified.datasets import get_dataset
from utils.Certified.architectures import get_architecture
from train.Certified.third_party.smoothadv import SmoothAdv_PGD
from train.Certified.trainer import DRT_Trainer

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Command-line interface. Arguments are parsed at import time and exposed
# through the module-level ``args`` namespace that ``main`` reads.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Positional arguments: which dataset and which architecture to train.
parser.add_argument('dataset', type=str, choices=DATASETS)
parser.add_argument('arch', type=str, choices=ARCHITECTURES)

# Generic optimisation / data-loading options.
parser.add_argument('--workers', default=4, type=int, metavar='N',
					help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
					help='number of total epochs to run')
# The help text previously advertised a default of 256 although the actual
# default is 100; the message now matches the real default.
parser.add_argument('--batch', default=100, type=int, metavar='N',
					help='batchsize (default: 100)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
					help='initial learning rate', dest='lr')
parser.add_argument('--lr_step_size', type=int, default=30,
					help='How often to decrease learning by gamma.')
parser.add_argument('--gamma', type=float, default=0.1,
					help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
					help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
					metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--noise_sd', default=0.0, type=float,
					help="standard deviation of Gaussian noise for data augmentation")
parser.add_argument('--print-freq', default=10, type=int,
					metavar='N', help='print frequency (default: 10)')
# Number of sub-models in the ensemble (mandatory).
parser.add_argument('--num-models', type=int, required=True)

# SmoothAdv Training params
parser.add_argument('--resume', action='store_true',
					help='if true, tries to resume training from existing checkpoint')
parser.add_argument('--pretrained-model', type=str, default='',
					help='Path to a pretrained model')
parser.add_argument('--num-noise-vec', default=1, type=int,
					help="number of noise vectors. `m` in the paper.")
parser.add_argument('--adv-training', action='store_true')
# NOTE(review): epsilon is divided by 256 right after parsing, so it is
# presumably expressed on a 0-256 pixel scale here -- confirm.
parser.add_argument('--epsilon', default=512, type=float)
parser.add_argument('--num-steps', default=4, type=int)
parser.add_argument('--warmup', default=10, type=int, help="Number of epochs over which "
														   "the maximum allowed perturbation increases linearly "
														   "from zero to args.epsilon.")


# DRT Training params
parser.add_argument('--lhs-weights', type=float, required=True)
parser.add_argument('--rhs-weights', type=float, required=True)

args = parser.parse_args()

# Build the run tag. This happens *before* epsilon is rescaled below, so the
# tag carries the epsilon value exactly as it was given on the command line.
_weights_tag = f"{args.lhs_weights}_{args.rhs_weights}"
mode = (
	f"salman_{args.epsilon}_{args.num_steps}_{args.warmup}_{_weights_tag}"
	if args.adv_training
	else f"cohen_{_weights_tag}"
)

# Divide epsilon by 256 (presumably converting from a 0-256 pixel scale to the
# [0, 1] image range -- confirm against the attacker implementation).
args.epsilon = args.epsilon / 256.0

# Output layout:
#   logs/Certified/<resume|scratch>/<dataset>/drt/<mode>/num_<m>/noise_<sd>
_run_kind = "resume" if args.resume else "scratch"
args.outdir = (
	f"logs/Certified/{_run_kind}/{args.dataset}/drt/{mode}"
	f"/num_{args.num_noise_vec}/noise_{args.noise_sd}"
)

def main():
	"""Train an ensemble with ``DRT_Trainer`` and checkpoint it every epoch.

	All configuration is read from the module-level ``args`` namespace and
	the module-level ``device``. The function:

	1. creates ``args.outdir`` and snapshots the code there via ``copy_code``;
	2. builds the train/test data loaders;
	3. instantiates ``args.num_models`` sub-models, each wrapped in
	   ``nn.DataParallel``, sharing a single SGD optimizer and StepLR schedule;
	4. optionally (``--resume``) loads one checkpoint per sub-model from
	   ``logs/Certified/scratch/<dataset>/cohen/noise_<sd>/checkpoint.pth.tar.<i>``;
	5. for every epoch runs training, evaluation, steps the scheduler and
	   writes ``checkpoint.pth.tar.<i>`` for each sub-model into ``args.outdir``.

	Raises:
		ValueError: if ``args.num_models`` is smaller than 1.
	"""
	if args.num_models < 1:
		raise ValueError("--num-models must be at least 1, got %d" % args.num_models)

	# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
	os.makedirs(args.outdir, exist_ok=True)

	copy_code(args.outdir)

	train_dataset = get_dataset(args.dataset, 'train')
	test_dataset = get_dataset(args.dataset, 'test')
	pin_memory = (args.dataset == "imagenet")
	train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch,
							  num_workers=args.workers, pin_memory=pin_memory)
	test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch,
							 num_workers=args.workers, pin_memory=pin_memory)

	# One independently initialised sub-model per ensemble member.
	model = [nn.DataParallel(get_architecture(args.arch, args.dataset))
			 for _ in range(args.num_models)]
	print("Model loaded")

	# Use the module-level device rather than an unconditional .cuda(), which
	# would fail on hosts where `device` fell back to the CPU.
	criterion = nn.CrossEntropyLoss().to(device)

	# A single optimizer updates the parameters of every sub-model jointly.
	param = [p for submodel in model for p in submodel.parameters()]

	optimizer = optim.SGD(param, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
	scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)

	model_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
	writer = SummaryWriter(args.outdir)

	try:
		if args.adv_training:
			attacker = SmoothAdv_PGD(steps=args.num_steps, device=device, max_norm=args.epsilon)
		else:
			attacker = None

		if args.resume:
			base_classifier = "logs/Certified/scratch/" + args.dataset + "/cohen/noise_" + \
							  str(args.noise_sd) + "/checkpoint.pth.tar"
			print(base_classifier)
			# Load one checkpoint per ensemble member. The loop used to be
			# hard-coded to 3 models, which broke for any other --num-models.
			for i, submodel in enumerate(model):
				checkpoint_path = base_classifier + ".%d" % (i)
				# map_location lets checkpoints saved on a GPU load on CPU-only hosts.
				checkpoint = torch.load(checkpoint_path, map_location=device)
				print("Load " + checkpoint_path)
				submodel.load_state_dict(checkpoint['state_dict'])
				submodel.train()
			print("Loaded...")

		for epoch in range(args.epochs):
			if args.adv_training:
				if args.warmup > 0:
					# Linearly ramp the allowed perturbation up to args.epsilon
					# over the first `warmup` epochs, then hold it there.
					attacker.max_norm = min(args.epsilon, (epoch + 1) * args.epsilon / args.warmup)
				else:
					# No warm-up requested: use the full budget from the start
					# (also avoids a division by zero).
					attacker.max_norm = args.epsilon

			DRT_Trainer(args, train_loader, model, criterion, optimizer, epoch,
						args.noise_sd, attacker, device, writer)
			test(test_loader, model, criterion, epoch, args.noise_sd, device, writer, args.print_freq)

			# NOTE(review): recent PyTorch releases warn that the `epoch`
			# argument is deprecated. Switching to `scheduler.step()` would move
			# each LR decay one epoch earlier, so it is kept as-is to preserve
			# the existing schedule -- confirm before changing.
			scheduler.step(epoch)

			# The optimizer is shared, so its state is identical in every file.
			optimizer_state = optimizer.state_dict()
			for i, submodel in enumerate(model):
				model_path_i = model_path + ".%d" % (i)
				torch.save({
					'epoch': epoch + 1,
					'arch': args.arch,
					'state_dict': submodel.state_dict(),
					'optimizer': optimizer_state,
				}, model_path_i)
	finally:
		# Flush and release the TensorBoard event file even if training fails.
		writer.close()



# Start training only when this file is executed as a script.
if __name__ == "__main__":
	main()
