import copy
import torchvision
import torchvision.transforms as transforms
import torch
import torch.utils.data
from torch import nn
import torchattacks
from tqdm import tqdm
import logging
import random
import os,sys
import numpy as np
import argparse
import loss_functions
from torch.optim.lr_scheduler import MultiStepLR
import time
from datetime import timedelta
from logging import getLogger
import utils
from models.resnet import resnet18
from models.resnet_g import resnet18 as resnet18_g
from models.wrn import WideResNet
import dataloader
from functorch.experimental import replace_all_batch_norm_modules_

# Command-line interface of the training script.
# Help texts use argparse's ``%(default)s`` placeholder so the documented
# default can never drift away from the real one.
parser = argparse.ArgumentParser()

# --- General / optimisation settings ---
# NOTE(review): --gpu and --arch are parsed, but --gpu is not applied anywhere
# in this script and --arch is only used to build the output directory name.
parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet',
					help='model architecture')
parser.add_argument('--dataset', default='cifar10', type=str,
					help='which dataset used to train')
parser.add_argument('--num_classes', default=10, type=int, metavar='N',
					help='number of classes')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
					help='number of total epochs to run')
parser.add_argument('-b', '--batch_size', default=128, type=int,
					metavar='N',
					help='mini-batch size (default: %(default)s), this is the total '
						 'batch size of all GPUs on the current node when '
						 'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
					metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
					help='momentum of SGD solver')
parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float,
					metavar='W', help='weight decay (default: %(default)s)',
					dest='wd')
parser.add_argument('--save', default='fgsm.pkl', type=str,
					help='model save name')
parser.add_argument('--seed', type=int,
					default=0, help='random seed')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')

# --- Attack settings ---
parser.add_argument('--eps', type=float, default=8./255., help='perturbation bound')
parser.add_argument('--ns', type=int, default=10, help='maximum perturbation step K')
parser.add_argument('--ss', type=float, default=2./255., help='step size')
parser.add_argument('--beta', type=float, default=6.0)

# --- Experiment naming / training method ---
parser.add_argument('--exp', default='fgsm', type=str,
					help='exp name')
# NOTE(review): the training loop only handles pgd, fgsm_of, te, normal and
# trades; the default value 'fgsm' is not among them.
parser.add_argument('--method', default='fgsm', type=str,
					help='AT method to use (pgd | fgsm_of | te | normal | trades)')

# --- TE settings ---
parser.add_argument('--te-alpha', default=0.9, type=float,
					help='momentum term of self-adaptive training')
parser.add_argument('--start-es', default=90, type=int,
					help='start epoch of self-adaptive training (default: %(default)s)')
parser.add_argument('--end-es', default=150, type=int,
					help='end epoch of the regularization weight ramp-up (default: %(default)s)')
parser.add_argument('--reg-weight', default=300, type=float)

# --- BN: 0 calls replace_all_batch_norm_modules_ on the model ---
parser.add_argument('--if_bn', default=0, type=int)

# --- Gaussian init: 0 builds models.resnet, anything else models.resnet_g ---
parser.add_argument('--if_g', default=0, type=int)

# --- Data augmentation switches ---
parser.add_argument('--if_aug', default=0, type=int)
parser.add_argument('--if_same', default=0, type=int)
args = parser.parse_args()

# The class count is derived from the dataset name and overrides --num_classes.
# NOTE(review): every dataset other than 'cifar10' is given 100 classes.
if args.dataset == 'cifar10':
	args.num_classes = 10
else:
	args.num_classes = 100

def sigmoid_rampup(current, start_es, end_es):
	"""Exponential rampup from https://arxiv.org/abs/1610.02242

	Args:
		current: current position on the schedule (the caller passes the
			1-based epoch number).
		start_es: position before which the ramp-up factor is 0.
		end_es: position from which the ramp-up factor is 1.

	Returns:
		A float in [0, 1]: 0.0 before ``start_es``, 1.0 from ``end_es`` on,
		and ``exp(-5 * (1 - t)**2)`` in between, where ``t`` is the fraction
		of the ``[start_es, end_es]`` interval already covered.
	"""
	import math

	if current < start_es:
		return 0.0
	# ``>=`` instead of ``>``: at ``current == end_es`` the formula below
	# already evaluates to exp(0) == 1.0, so results are unchanged, but the
	# degenerate schedule ``start_es == end_es`` no longer divides by zero.
	if current >= end_es:
		return 1.0
	phase = 1.0 - (current - start_es) / (end_es - start_es)
	return math.exp(-5.0 * phase * phase)

# NOTE(review): GPU selection through --gpu is disabled; the line is kept
# commented out exactly as before.
#os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
utils.setup_seed(args.seed)

# Every artifact of a run (log file, checkpoints, accuracy histories) is
# written below <dataset>/<arch>/<exp>/.
exp_dir = args.dataset + '/' + args.arch + '/' + args.exp
# exist_ok=True replaces the former "exists() then makedirs()" pair, which
# could fail when two runs create the same directory at the same time.
os.makedirs(exp_dir, exist_ok=True)

logger = utils.create_logger(
	os.path.join(exp_dir + '/', args.exp + ".log"), rank=0
)
logger.info("============ Initialized logger ============")
# Dump the parsed arguments, one "name: value" per line, sorted by name.
logger.info(
	"\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
)
# Only after logging the raw value, turn --save into a path prefix inside the
# experiment directory; later code appends suffixes such as 'best.pkl' to it.
args.save = exp_dir + '/' + args.save


# Short aliases for the hyper-parameters used further down in the script.
wd, learning_rate = args.wd, args.lr
epochs, batch_size = args.epochs, args.batch_size
torch.backends.cudnn.benchmark = True

# Training pipeline: always convert to tensor; with --if_aug != 0 a random
# crop (32x32, padding 4) and a random horizontal flip are applied first.
train_ops = [torchvision.transforms.ToTensor()]
if args.if_aug != 0:
	train_ops = [
		transforms.RandomCrop(32, padding=4),
		transforms.RandomHorizontalFlip(),
	] + train_ops
transform = transforms.Compose(train_ops)

# Test pipeline: resize to 32x32, then convert to tensor (no augmentation).
transform_test = transforms.Compose([
	torchvision.transforms.Resize((32, 32)),
	transforms.ToTensor(),
])

def data_aug(x):
	"""Return ``x`` after a random crop and a random horizontal flip.

	The crop is 32x32 taken with 4 pixels of padding, followed by
	``RandomHorizontalFlip`` with its default settings. In this script the
	function is called on a whole training batch.
	NOTE(review): presumably the same random crop/flip is then shared by
	every image of the batch -- confirm against torchvision's tensor
	semantics.
	"""
	augment = transforms.Compose([
		transforms.RandomCrop(32, padding=4),
		transforms.RandomHorizontalFlip(),
	])
	return augment(x)

# Build the train/test datasets for the selected dataset under ./data, using
# the transforms defined above.
data = dataloader.Data(args.dataset, './data')
trainset, testset = data.data_loader(transform, transform_test)

# Options shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, drop_last=False, num_workers=0)
train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)


# Select the ResNet-18 variant: the plain one by default, the ``resnet_g``
# one when --if_g is non-zero. The model is moved to the GPU right away.
build_model = resnet18_g if args.if_g else resnet18
n = build_model(num_classes=args.num_classes).cuda()

# With --if_bn 0 all batch-norm modules of the model are replaced in place;
# per the message below, this stops them from recording running mean/std.
if not args.if_bn:
	replace_all_batch_norm_modules_(n)
	print('BN is disable to record running mean and std')


# SGD with momentum and weight decay over all model parameters.
optimizer = torch.optim.SGD(
	n.parameters(),
	lr=learning_rate,
	momentum=args.momentum,
	weight_decay=wd,
)

# The learning rate is multiplied by --gamma at 50% and 75% of training.
milestones = [int(args.epochs * fraction) for fraction in (0.5, 0.75)]
scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=args.gamma)

# Basic facts about the training set.
data_size = len(trainset.data)
targets = np.asarray(trainset.targets)
targets_np = np.asarray(targets)
num_classes = args.num_classes

# The TE loss object is only built when that method is selected; it is
# configured with the training-set size and the attack settings from the CLI.
if args.method == 'te':
	te_config = dict(
		num_samples=data_size,
		num_classes=args.num_classes,
		momentum=args.te_alpha,
		step_size=args.ss,
		epsilon=args.eps,
		perturb_steps=args.ns,
		norm='linf',
		es=args.start_es,
	)
	pgd_te = loss_functions.PGD_TE(**te_config)

# Per-epoch accuracy histories (in percent); one value is appended per epoch.
train_clean_acc = []
train_adv_acc = []
test_clean_acc = []
test_adv_acc = []
best_eval_acc = 0.0

# Fail fast on configuration errors. Previously an unsupported --method only
# surfaced as a NameError on ``loss`` in the first batch, and the
# --if_same/--if_aug conflict built a ValueError without ever raising it.
SUPPORTED_METHODS = ('pgd', 'fgsm_of', 'te', 'normal', 'trades')
if args.method not in SUPPORTED_METHODS:
	raise ValueError('Unknown --method %r; expected one of: %s'
					 % (args.method, ', '.join(SUPPORTED_METHODS)))
if args.if_same == 1 and args.if_aug != 0:
	raise ValueError('Wrong args: --if_same 1 requires --if_aug 0')

Loss_test = nn.CrossEntropyLoss().cuda()
fmt = '{:.4f}'.format

for epoch in range(epochs):
	# Regularisation weight handed to the TE loss, ramped up between
	# --start-es and --end-es (epochs are counted from 1 here).
	rampup_rate = sigmoid_rampup(epoch+1, args.start_es, args.end_es)
	weight = rampup_rate * args.reg_weight

	# ---------------------------- training ----------------------------
	loadertrain = tqdm(train_loader, desc='{} E{:03d}'.format('train', epoch), ncols=0)
	epoch_loss = 0.0
	total = 0
	# Plain Python ints: exact counts, no accumulation inside a float tensor.
	clean_correct = 0
	adv_correct = 0
	for x_train, y_train, idx in loadertrain:
		n.eval()
		if args.if_same == 1:
			# Augment the batch as a whole before it is moved to the GPU.
			x_train = data_aug(x_train)
		x_train, y_train = x_train.cuda(), y_train.cuda()

		# Clean predictions are only used for the accuracy read-out, so no
		# autograd graph is needed for this forward pass.
		with torch.no_grad():
			y_pre = n(x_train)

		# Every branch yields the logits used for the "adversarial" accuracy
		# and the loss to back-propagate.
		# NOTE(review): the loss_functions helpers receive the optimizer and
		# presumably zero the gradients themselves -- confirm.
		if args.method == 'pgd':
			logits_adv, loss = loss_functions.PGD(n, x_train, y_train, optimizer, args)
		elif args.method == 'fgsm_of':
			logits_adv, loss = loss_functions.FGSM_overfitting(n, x_train, y_train, optimizer, args)
		elif args.method == 'te':
			logits_adv, loss = pgd_te(x_train, y_train, idx, epoch+1, n, optimizer, weight)
		elif args.method == 'normal':
			# Standard training: the "adversarial" logits are the clean
			# logits computed in train mode.
			n.train()
			logits_adv = n(x_train)
			optimizer.zero_grad()
			loss = torch.nn.functional.cross_entropy(logits_adv, y_train)
		elif args.method == 'trades':
			logits_adv, loss = loss_functions.TRADES(n, x_train, y_train, optimizer, args)
		loss.backward()
		optimizer.step()

		batch_loss = loss.item()
		epoch_loss += batch_loss
		_, predicted = torch.max(y_pre, 1)
		_, predictedadv = torch.max(logits_adv.detach(), 1)
		total += y_train.size(0)
		clean_correct += predicted.eq(y_train).sum().item()
		adv_correct += predictedadv.eq(y_train).sum().item()
		loadertrain.set_postfix(loss=fmt(batch_loss),
								acc_cl=fmt(clean_correct / total * 100),
								acc_adv=fmt(adv_correct / total * 100))
	train_clean_acc.append(clean_correct / total * 100)
	train_adv_acc.append(adv_correct / total * 100)
	scheduler.step()

	# ------------------- evaluation (after every epoch) -------------------
	test_loss_cl = 0.0
	test_loss_adv = 0.0
	correct_cl = 0
	correct_adv = 0
	total = 0
	n.eval()
	# Evaluation attack: PGD with eps = 8/255 and 20 steps, independent of
	# the training attack settings.
	pgd_eval = torchattacks.PGD(n, eps=8.0/255.0, steps=20)
	loadertest = tqdm(test_loader, desc='{} E{:03d}'.format('test', epoch), ncols=0)
	for x_test, y_test, idx in loadertest:
		x_test, y_test = x_test.cuda(), y_test.cuda()
		# Gradients stay enabled while the adversarial examples are crafted.
		with torch.enable_grad():
			x_adv = pgd_eval(x_test, y_test)
		n.eval()
		# The forward passes used for scoring do not need gradients.
		with torch.no_grad():
			y_pre = n(x_test)
			y_adv = n(x_adv)
			loss_cl = Loss_test(y_pre, y_test)
			loss_adv = Loss_test(y_adv, y_test)
		test_loss_cl += loss_cl.item()
		test_loss_adv += loss_adv.item()
		_, predicted = torch.max(y_pre, 1)
		_, predicted_adv = torch.max(y_adv, 1)
		total += y_test.size(0)
		correct_cl += predicted.eq(y_test).sum().item()
		correct_adv += predicted_adv.eq(y_test).sum().item()
		loadertest.set_postfix(loss_cl=fmt(loss_cl.item()),
							   loss_adv=fmt(loss_adv.item()),
							   acc_cl=fmt(correct_cl / total * 100),
							   acc_adv=fmt(correct_adv / total * 100))
	eval_adv_acc = correct_adv / total * 100
	test_clean_acc.append(correct_cl / total * 100)
	test_adv_acc.append(eval_adv_acc)

	# ----------------------------- checkpoints -----------------------------
	checkpoint = {
			'state_dict': n.state_dict(),
			'epoch': epoch
		}
	# Keep the weights with the best adversarial test accuracy seen so far ...
	if eval_adv_acc > best_eval_acc:
		best_eval_acc = eval_adv_acc
		torch.save(checkpoint, args.save+ 'best.pkl')
	# ... and additionally a numbered checkpoint after every epoch.
	torch.save(checkpoint, args.save + '%d.pkl'%(epoch+1))
# Final checkpoint: the weights after the last epoch, whatever their accuracy.
checkpoint = dict(state_dict=n.state_dict(), epoch=epoch)
torch.save(checkpoint, args.save + 'last.pkl')

# Dump the per-epoch accuracy histories next to the checkpoints.
for suffix, history in (
	('_train_acc_cl.npy', train_clean_acc),
	('_train_acc_adv.npy', train_adv_acc),
	('_test_acc_cl.npy', test_clean_acc),
	('_test_acc_adv.npy', test_adv_acc),
):
	np.save(args.save + suffix, history)