from __future__ import print_function
from __future__ import division
from builtins import range
from builtins import int
from builtins import dict


import argparse
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torch.backends.cudnn as cudnn

import torch.nn.functional as F

import torchvision.datasets as dset
import torchvision.transforms as T

from model import ConvNet
from AM import pytorchAA
import numpy as np

# import matplotlib.pyplot as plt

from lossfns import *
from adversary import *
from dataprocess import *


def combine(net):
    """Flatten every tensor in ``net.state_dict()`` into a single 1-D tensor.

    The entries are concatenated in the iteration order of the state dict,
    so the result can be split back by walking the same state dict and
    consuming ``numel()`` elements per entry.

    Args:
        net: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).

    Returns:
        A 1-D tensor whose length is the total number of elements across
        all state_dict entries.
    """
    # reshape(-1) instead of view(-1): view raises on tensors whose memory
    # layout is not contiguous, reshape falls back to a copy in that case.
    flat_parts = [tensor.reshape(-1) for tensor in net.state_dict().values()]
    return torch.cat(flat_parts)

def main(args):
	"""Load the saved model and evaluate it under PGD and FGSM attacks.

	Args:
		args: Parsed command-line options; forwarded to ``loadData``.

	The checkpoint path is fixed to ``model/model_am.pth``. Training is
	currently disabled; see the commented recipe below to regenerate the
	checkpoint before evaluating.
	"""
	loader_train, loader_test = loadData(args)
	dtype = torch.cuda.FloatTensor

	# Must be bound before anything reads it: the previous version printed
	# `fname` ahead of this assignment and crashed with UnboundLocalError.
	fname = "model/model_am.pth"

	# To retrain instead of reusing the saved checkpoint, re-enable:
	# model = unrolled(args, loader_train, loader_test, dtype)
	# torch.save(model, fname)
	# print("Training done, model save to %s :)" % fname)

	# NOTE(review): torch.load unpickles the file, so only load checkpoints
	# you trust. The file holds a whole pickled model object (see torch.save
	# above); newer PyTorch releases may need weights_only=False for that --
	# confirm against the installed version.
	model = torch.load(fname)

	pgdAttackTest(model, loader_test, dtype)
	fgsmAttackTest(model, loader_test, dtype)


def unrolled(args, loader_train, loader_test, dtype):
	"""Train a fresh ``ConvNet`` on the ``cw_train_unrolled`` loss, post-processing every Adam step with ``pytorchAA``.

	After each optimizer step the model's state_dict is flattened with
	``combine``, passed together with the previous flat vector to
	``pytorchAA.apply``, and the returned vector is written back into the
	model. (``pytorchAA`` comes from the project module ``AM``; presumably an
	Anderson-acceleration/mixing scheme -- confirm against that module.)

	Args:
		args: Parsed options; only ``args.print_every`` is read here.
		loader_train: Iterable yielding ``(X_, y_)`` training batches.
		loader_test: Iterable passed to ``test`` for periodic evaluation.
		dtype: Tensor type handed to ``.type()`` for the model and batches.

	Returns:
		The trained model.

	Side effects:
		Prints progress, and pickles the list of logged losses to
		``loss_gdaam.txt`` (a binary pickle despite the ``.txt`` suffix).
	"""


	model = ConvNet()
	model = model.type(dtype)
	model.train()
		
	# Epochs per training stage; the learning rate is multiplied by 0.1 after
	# each stage (see the end of the outer loop).
	SCHEDULE_EPOCHS = [50,50]
	learning_rate = 5e-4
	# Snapshot of the state_dict entries plus the shape of each one; used to
	# turn a flat vector back into per-entry tensors.
	p_dict = dict(model.state_dict())
	sizep = {}
	for key in p_dict:
		sizep[key] = p_dict[key].shape
	def divide(fp):
		# Inverse of combine(): copy consecutive slices of the flat vector
		# `fp` into the tensors held in p_dict (walking keys in the same
		# order they were recorded in sizep), then load them into the model.
		# NOTE(review): state_dict tensors normally alias the live
		# parameters, in which case load_state_dict is redundant -- confirm.
		offset = 0
		for k, v in sizep.items():
			p_dict[k].data.copy_(fp[offset: offset + v.numel()].view(v))
			offset = offset + v.numel()
		model.load_state_dict(p_dict) 
	# Total number of scalars across all state_dict entries.
	sum_p = sum(v.numel() for _, v in p_dict.items())
	# Flat vector from the previous iteration; fed to aa_wrk.apply below.
	fpprev = combine(model)
	assert(len(fpprev) ==sum_p)
	
	# Losses recorded every args.print_every batches.
	loss_list = []
	for num_epochs in SCHEDULE_EPOCHS:
		
		print('\nTraining %d epochs with learning rate %.7f' % (num_epochs, learning_rate))
		
		# A fresh optimizer and a fresh pytorchAA worker are created for each
		# stage; fpprev is carried over from the previous stage.
		# NOTE(review): the meaning of (50, type2=True, reg=1) is defined in
		# AM.pytorchAA and is not visible here.
		optimizer = optim.Adam(model.parameters(), lr=learning_rate)
		aa_wrk = pytorchAA(sum_p, 50, type2=True, reg=1) 
		for epoch in range(num_epochs):
			
			print('\nTraining epoch %d / %d ...\n' % (epoch + 1, num_epochs))
			# print(model.training)
			
			for i, (X_, y_) in enumerate(loader_train):

				X = Variable(X_.type(dtype), requires_grad=False)
				# Labels are converted to `dtype` here and not cast to long
				# (normal_train does cast); cw_train_unrolled receives them
				# in this form.
				y = Variable(y_.type(dtype), requires_grad=False)

				loss = cw_train_unrolled(model, X, y, dtype)
				
				optimizer.zero_grad()
				loss.backward()
				optimizer.step()
				# Flatten the post-step state, let aa_wrk combine it with the
				# previous vector, remember the result, and write it back.
				fp = combine(model)
				fp = aa_wrk.apply(fpprev, fp)
				fpprev = fp.detach().clone()
				divide(fp)
				if (i + 1) % args.print_every == 0:
					print('Batch %d done, loss = %.7f' % (i + 1, loss.item()))
					test(model, loader_test, dtype)
					loss_list.append(loss.item())
			# Reports the last batch of the epoch.
			# NOTE(review): `i` and `loss` are only bound by the loop above,
			# so this line fails if loader_train yields no batches.
			print('Batch %d done, loss = %.7f' % (i + 1, loss.item()))
		learning_rate *= 0.1
	# `fp` is reused here as the file handle; the flat vector is no longer
	# needed at this point.
	with open("loss_gdaam.txt", "wb") as fp:   # pickle the logged losses
		pickle.dump(loss_list, fp)
	return model

def test(model, loader_test, dtype):
	"""Print the top-1 classification accuracy of ``model`` on ``loader_test``.

	Args:
		model: Callable module returning per-class logits for a batch.
		loader_test: Iterable yielding ``(X_, y_)`` batches.
		dtype: Tensor type handed to ``.type()`` for inputs and labels.

	Returns:
		None. The accuracy (in percent) is printed; ``nan`` is printed when
		the loader yields no samples.

	The model is switched to eval mode for the pass and is always put back
	into train mode afterwards, including when evaluation raises.
	"""
	num_correct = 0
	num_samples = 0
	model.eval()
	try:
		# Evaluation needs no gradients; without no_grad every forward pass
		# would build and hold an autograd graph.
		with torch.no_grad():
			for X_, y_ in loader_test:
				X = Variable(X_.type(dtype), requires_grad=False)
				y = Variable(y_.type(dtype), requires_grad=False).long()
				logits = model(X)
				_, preds = logits.max(1)
				# .item() keeps the running count a plain Python int
				# instead of a (possibly device-resident) tensor.
				num_correct += (preds == y).sum().item()
				num_samples += preds.size(0)
		if num_samples:
			accuracy = float(num_correct) / num_samples * 100
		else:
			# Empty loader: report nan rather than dividing by zero.
			accuracy = float('nan')
		print('\nAccuracy = %.2f%%' % accuracy)
	finally:
		# Callers in this file invoke test() mid-training and rely on the
		# model being back in train mode afterwards.
		model.train()

def normal_train(args, loader_train, loader_test, dtype):
	"""Train a fresh ``ConvNet`` with plain cross-entropy and Adam.

	Args:
		args: Parsed options; only ``args.print_every`` is read here.
		loader_train: Iterable yielding ``(X_, y_)`` training batches.
		loader_test: Iterable passed to ``test`` after every epoch.
		dtype: Tensor type handed to ``.type()`` for the model and batches.

	Returns:
		The trained model.
	"""
	net = ConvNet().type(dtype)
	net.train()

	criterion = nn.CrossEntropyLoss()
	lr = 0.01

	# One entry per training stage: number of epochs to run at the current
	# learning rate, which is scaled by 0.1 once the stage finishes.
	for stage_epochs in (15,):
		print('\nTraining %d epochs with learning rate %.4f' % (stage_epochs, lr))
		adam = optim.Adam(net.parameters(), lr=lr)

		for epoch_no in range(1, stage_epochs + 1):
			print('\nTraining epoch %d / %d ...\n' % (epoch_no, stage_epochs))

			for batch_no, (inputs, labels) in enumerate(loader_train, start=1):
				X = Variable(inputs.type(dtype), requires_grad=False)
				y = Variable(labels.type(dtype), requires_grad=False).long()

				loss = criterion(net(X), y)

				if batch_no % args.print_every == 0:
					print('Batch %d done, loss = %.7f' % (batch_no, loss.item()))

				adam.zero_grad()
				loss.backward()
				adam.step()

			# Report the final batch of the epoch, then evaluate.
			print('Batch %d done, loss = %.7f' % (batch_no, loss.item()))
			test(net, loader_test, dtype)

		lr *= 0.1

	return net

def parse_arguments():
	"""Parse the script's command-line options from ``sys.argv``.

	Returns:
		argparse.Namespace with ``data_dir`` (str), ``batch_size`` (int)
		and ``print_every`` (int).
	"""
	# (flag, default, type, help) for every supported option.
	option_table = (
		('--data-dir', './dataset', str,
			'path to dataset'),
		('--batch-size', 64, int,
			'size of each batch of cifar-10 training images'),
		('--print-every', 200, int,
			'number of iterations to wait before printing'),
	)

	cli = argparse.ArgumentParser()
	for flag, fallback, caster, description in option_table:
		cli.add_argument(flag, default=fallback, type=caster, help=description)
	return cli.parse_args()

if __name__ == '__main__':
	# Script entry point: parse the CLI options and hand them to main().
	main(parse_arguments())

