"""
driver script for mixup training
"""

from __future__ import print_function

import sys
#print(sys.version)
#print(sys.path)

import argparse

# Command-line interface. Arguments are parsed right here, before torch and the
# project modules are imported further down the file.
parser = argparse.ArgumentParser(description='PyTorch training for deep abstaining classifiers',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# optimisation / regularisation
parser.add_argument('--lr', default=0.1, type=float, help='learning_rate')
parser.add_argument('--dropout', default=0.2, type=float, help='dropout_rate')

# in-distribution data; the list in the help text mirrors the dataset branches
# handled later in this script
parser.add_argument('--dataset', default='mnist', type=str,
    help='dataset = [mnist/fashion/cifar10/cifar100/stl10-labeled/stl10-c/tiny-imagenet]')

parser.add_argument('--datadir',  type=str, required=True, help='data directory')
parser.add_argument('--train_x', default=None, type=str, help='train features. will default to the dataset default')
parser.add_argument('--train_y', default=None, type=str, help='train labels.  will default to the dataset default')
parser.add_argument('--test_x', default=None, type=str, help='test features. will default to the dataset default')
parser.add_argument('--test_y', default=None, type=str, help='test labels. will default to the dataset default')

# out-of-distribution data
parser.add_argument('--ood_dataset', default=None, type=str,
    help='Train with out-of-distribution data [tiny-images/stl10-unlabeled]. This will be assigned to an extra class')
parser.add_argument('--ood_train_x', default=None, type=str, help='Location for OoD data set')

parser.add_argument('--output_path', default="./", type=str, help='output path')

# run control
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--testOnly', '-t', action='store_true', help='Test mode with the saved model')
parser.add_argument('--nesterov', dest='nesterov', action='store_true', default=False, help="Use Nesterov acceleration with SGD")
parser.add_argument('--batch_size', dest='batch_size', default=128, type=int, help='batch size for training')
parser.add_argument('--test_batch_size', dest='test_batch_size', default=128, type=int, help='batch size for testing')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')

# The historical single-dash spellings are kept; the double-dash aliases make
# these two flags consistent with every other option.
parser.add_argument('-cuda_device', '--cuda_device', dest='cuda_device', type=str, default='auto',
    help='GPU device id to use. If not specified will automatically try to use a free GPU')
parser.add_argument('-use-gpu', '--use-gpu', '--use_gpu', action='store_true', dest='use_gpu', default=False,
    help='Use GPU if available')

# network selection; when --net_type is omitted the script builds its small
# built-in ConvNet
parser.add_argument('--net_type', default=None, type=str,
    help='model to use; one of [lenet/vggnet/resnet/wide-resnet/resnet2]')
parser.add_argument('--depth', default=16, type=int, help='depth of model')

#for wide residual networks
parser.add_argument('--widen_factor', default=10, type=int, help='width of model')

# logging and saved artefacts
parser.add_argument('--log_file', default=None, type=str, help='logfile name')
parser.add_argument('--save_val_scores', action='store_true', default=False, help='writes validation set probabilities to file after each epoch')

parser.add_argument('--save_epoch_model', type=int, default=None, metavar='N',
                    help='save model at specified epoch')
parser.add_argument('--expt_name', default="", type=str, help='experiment name')
parser.add_argument('--save_train_scores', action='store_true', default=False, help='writes train set probabilities to file after each epoch')
parser.add_argument('--eval_model', type=str, default=None, help='evaluate model on data set. Output will be probabilities on the train and test splits of the dataset')

parser.add_argument('--save_best_model', action='store_true', default=False, help='saves best performing model')

parser.add_argument('--no_overwrite', action='store_true', default=False, help='will not overwrite previous best models')

parser.add_argument('--save_features', action='store_true', default=False, help='will save features (only VGG for now) at the final layer (before linear layer). Meant to be used when eval_model is true but can be used during training as well (might run out of memory unless the buffer is periodically flushed)')

#mixup specific parameters
parser.add_argument('--mixup', action='store_true', default=False, help='use mixup based training')
parser.add_argument('--mx_alpha', default=1., type=float, help='for mixup training: interpolation strength (uniform=1., ERM=0.)')

parser.add_argument('--train_viz_only', action='store_true', help='Just create a few mixup samples for visualizing and exit')

parser.add_argument('--mix_feat_only', action='store_true', default=False, help='Will only mix features; not labels.')

#abstention specific parameters
parser.add_argument('--abstention', action='store_true', default=False, help='use abstention while training')
parser.add_argument('--abst_threshold', default=0.1, type=float, help='abstention threshold.')

#label smoothing
parser.add_argument('--label_smoothing', action='store_true', default=False, help='use label smoothing while training')
parser.add_argument('--label_smoothing_eps', default=0.1, type=float, help='label smoothing epsilon')

#entropy regularized loss
parser.add_argument('--erl', action='store_true', default=False, help='use entropy-regularized loss while training')
parser.add_argument('--erl_kappa', default=0.1, type=float, help='entropy regularizer strength')

#adversarial training
parser.add_argument('--adv_training', action='store_true', default=False, help='Augment training with adversarial loss (uses FGSM method).')
parser.add_argument('--adv_eps', default=0.01, type=float, help='magnitude for adversarial perturbation')

args = parser.parse_args()


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.nn.modules.loss import _Loss
import torchvision
import torchvision.transforms as transforms


import os
import sys
import time
import datetime

from torch.autograd import Variable
import gpu_utils
import pdb
import numpy as np
from networks import wide_resnet,lenet,vggnet, resnet, resnet2
from networks import config as cf
import math

#from tensorboardX import SummaryWriter
#writer = SummaryWriter('test_log')

# When a log file is requested, both stdout and stderr are redirected to it
# for the rest of the run.
if args.log_file is not None:
    sys.stdout = sys.stderr = open(args.log_file, 'w')

# Seed the torch and numpy random number generators from the command line.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Bookkeeping read by the training / evaluation code further down.
start_epoch = 0
num_epochs = args.epochs
batch_size = args.batch_size
best_acc = 0.

def get_stl10_num_classes(train_y):
    """Return the class count implied by a binary label file.

    ``train_y`` is read as a flat array of unsigned bytes; the result is the
    largest label value found plus one.
    """
    labels = np.fromfile(train_y, dtype=np.uint8)
    return labels.max() + 1

print('\n[Phase 1] : Data Preparation')

### CIFAR-10/100 Transforms
# Only defined when a CIFAR dataset is selected; the per-channel statistics
# are looked up by dataset name.
if args.dataset in ('cifar10', 'cifar100'):
    cifar_mean = dict(
        cifar10=(0.4914, 0.4822, 0.4465),
        cifar100=(0.5071, 0.4867, 0.4408),
    )

    # The cifar10 std below is the one used for evaluating manifold mixup
    # models that were trained with this std vector; the previous value was
    # (0.2023, 0.1994, 0.2010).
    cifar_std = dict(
        cifar10=(0.24705882352941178, 0.24352941176470588, 0.2615686274509804),
        cifar100=(0.2675, 0.2565, 0.2761),
    )

    # train: random crop + flip augmentation, then mean/std normalisation
    cifar_transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean[args.dataset], cifar_std[args.dataset]),
        ]
    )

    # test: normalisation only
    cifar_transform_test = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean[args.dataset], cifar_std[args.dataset]),
        ]
    )


### STL-10/STL-labeled transforms
# NOTE(review): the train and test normalisation constants differ slightly
# (0.4 vs 0.405, 0.256 vs 0.257, 0.271 vs 0.27) -- confirm this is intended.
transform_train_stl10 = transforms.Compose(
    [
        transforms.RandomCrop(96, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.447, 0.44, 0.4), (0.26, 0.256, 0.271)),
    ]
)

transform_test_stl10 = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.447, 0.44, 0.405), (0.26, 0.257, 0.27)),
    ]
)

### Tiny Imagenet 200 # copied from STL-10 for now.
# Images are resized to 32 before the crop/flip augmentation.
transform_train_tin200 = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.447, 0.44, 0.4), (0.26, 0.256, 0.271)),
    ]
)

transform_test_tin200 = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.447, 0.44, 0.405), (0.26, 0.257, 0.27)),
    ]
)

#### MNIST Transforms
#from https://github.com/pytorch/examples/blob/master/mnist/main.py
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

#### Fashion-MNIST Transforms
fashion_mnist_train_transform = transforms.Compose(
    [
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.2868,), (0.3524,)),
    ]
)

fashion_mnist_test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,)),
    ]
)

# Build `trainset`, `testset` and `num_classes` for the requested dataset.
# NOTE(review): several roots below are hard-coded paths rather than
# args.datadir -- confirm they are valid on the target machine.
if args.dataset == 'cifar10':
    print("| Preparing CIFAR-10 dataset...")
    sys.stdout.write("| ")
    trainset = torchvision.datasets.CIFAR10(
        root='/home//ssl/data/cifar',
        train=True,
        download=False,
        transform=cifar_transform_train,
    )
    testset = torchvision.datasets.CIFAR10(
        root='/home//ssl/data/cifar',
        train=False,
        download=False,
        transform=cifar_transform_test,
    )
    num_classes = 10

elif args.dataset == 'cifar100':
    print("| Preparing CIFAR-100 dataset...")
    sys.stdout.write("| ")
    trainset = torchvision.datasets.CIFAR100(
        root='/home//ssl/data/cifar',
        train=True,
        download=False,
        transform=cifar_transform_train,
    )
    testset = torchvision.datasets.CIFAR100(
        root='/home//ssl/data/cifar',
        train=False,
        download=False,
        transform=cifar_transform_test,
    )
    num_classes = 100

elif args.dataset == 'mnist':
    # MNIST is the only dataset that is downloaded on demand.
    print("| Preparing MNIST dataset...")
    sys.stdout.write("| ")
    trainset = torchvision.datasets.MNIST(
        root='/home//ssl/data/mnist',
        train=True,
        download=True,
        transform=mnist_transform,
    )
    testset = torchvision.datasets.MNIST(
        root='/home//ssl/data/mnist',
        train=False,
        download=True,
        transform=mnist_transform,
    )
    num_classes = 10

elif args.dataset == 'fashion':
    print("| Preparing Fashion MNIST dataset...")
    sys.stdout.write("| ")
    trainset = torchvision.datasets.FashionMNIST(
        root='/home//ssl/data/fashion-mnist',
        train=True,
        download=False,
        transform=fashion_mnist_train_transform,
    )
    testset = torchvision.datasets.FashionMNIST(
        root='/home//ssl/data/fashion-mnist',
        train=False,
        download=False,
        transform=fashion_mnist_test_transform,
    )
    num_classes = 10

elif args.dataset == 'stl10-labeled':
    print("| Preparing STL10-labeled dataset...")

    # train split comes from --datadir, test split from a fixed location
    trainset = torchvision.datasets.STL10(
        root=args.datadir,
        split='train',
        download=False,
        transform=transform_train_stl10,
    )
    testset = torchvision.datasets.STL10(
        root='/home//ssl/data/stl-10-orig',
        split='test',
        download=False,
        transform=transform_test_stl10,
    )
    num_classes = 10

elif args.dataset == 'stl10-c':
    print("| Preparing STL10-C dataset...")
    import stl10_c
    trainset = stl10_c.STL10_C(
        root=args.datadir,
        split='train',
        transform=transform_train_stl10,
        train_list=[[args.train_x, ''], [args.train_y, '']],
    )
    testset = stl10_c.STL10_C(
        root='/home//ssl/data/stl-10',
        split='test',
        transform=transform_test_stl10,
        test_list=[[args.test_x, ''], [args.test_y, '']],
    )
    # With a custom label file the class count is read from the labels.
    if args.train_y:
        num_classes = get_stl10_num_classes(args.train_y)
    else:
        num_classes = 10

elif args.dataset == 'tiny-imagenet':
    import TinyImageNet as tin
    trainset = tin.TinyImageNet(
        root=args.datadir,
        split='train',
        transform=transform_train_tin200,
        in_memory=False,
        download=False,
    )
    testset = tin.TinyImageNet(
        root=args.datadir,
        split='val',
        transform=transform_test_tin200,
        in_memory=False,
        download=False,
    )
    num_classes = 200

else:
    print("Unknown data set")
    sys.exit(0)

sys.stdout.flush()

# Loaders for the main train / test splits.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=4)

# Unshuffled pass over the training set, so saved per-sample scores stay in
# dataset order.
if args.save_train_scores or args.eval_model:
    train_perf_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=False, num_workers=2)


### Load OoD Dataset if it has been specified
if args.ood_dataset is None:
    ood_trainloader = None

else:
    if args.ood_dataset == "tiny-images":
        print("loading Tiny Images")
        import tiny_image_data

        x = np.load(args.ood_train_x)
        #labels for OoD dataset will be assigned to class K. (In-distribution labels go from 0 to K-1)
        y = num_classes * np.ones(x.shape[0], dtype=np.int64)

        # Only the CIFAR transforms are wired up for this OoD source.
        if args.dataset == 'cifar10' or args.dataset == 'cifar100':
            train_transform = cifar_transform_train
        else:
            print("Unsupported in-distribution/out-of-distribution pairing")
            sys.exit(0)

        ood_trainset = tiny_image_data.TinyImage(x, y, train=True,
            transform=train_transform)

    elif args.ood_dataset == 'stl10-unlabeled':
        print("loading STL-10 unlabeled data")
        # stl10_c is otherwise imported only inside the 'stl10-c' dataset
        # branch; import it here as well so this path cannot hit a NameError
        # when a different in-distribution dataset is selected.
        import stl10_c
        ood_trainset = stl10_c.STL10_C(root=args.datadir,
            split='unlabeled', transform=transform_train_stl10)
        #the PyTorch STL_10 loader assigns labels of -1 to the unlabeled split. So change that here.
        ood_trainset.labels = num_classes * np.ones(ood_trainset.data.shape[0], dtype=np.int64)

    else:
        print("Unsupported OoD Dataset")
        sys.exit(0)

    # OoD samples per step: ceil(batch_size / num_classes). The explicit
    # float division and int() cast give the same integer on Python 2 and 3
    # (Python 2's `/` truncates and its math.ceil returns a float).
    ood_batch_size = int(math.ceil(args.batch_size / float(num_classes)))
    ood_trainloader = torch.utils.data.DataLoader(ood_trainset, batch_size=ood_batch_size,
        shuffle=True, num_workers=1)


# GPU selection: when --use-gpu is given, keep asking gpu_utils for a device
# until it returns one, then mark CUDA as enabled and seed its RNG.
use_cuda = False
if args.use_gpu:
    cuda_device = gpu_utils.get_cuda_device(args)
    while cuda_device is None:
        cuda_device = gpu_utils.get_cuda_device(args)
    use_cuda = True

if use_cuda:
    torch.cuda.manual_seed(args.seed)


#the default network to use if no network is specified at command line.
class ConvNet(nn.Module):
    """Small convolutional classifier: two conv layers followed by two linear layers.

    The first convolution takes a single input channel, and the flatten step
    assumes 320 features per sample (20 channels x 4 x 4, which is what
    28x28 inputs produce). The 2-D dropout rate after the second convolution
    comes from the global ``args.dropout``.
    """

    def __init__(self, n_classes):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d(p=args.dropout)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, n_classes)

    def forward(self, x):
        """Return the unnormalised class scores (logits) for batch ``x``."""
        # conv -> 2x2 max-pool -> relu
        hidden = self.conv1(x)
        hidden = F.relu(F.max_pool2d(hidden, 2))

        # conv -> channel dropout -> 2x2 max-pool -> relu
        hidden = self.conv2_drop(self.conv2(hidden))
        hidden = F.relu(F.max_pool2d(hidden, 2))

        # flatten and classify
        hidden = hidden.view(-1, 320)
        hidden = F.relu(self.fc1(hidden))
        # this dropout uses F.dropout's default probability, not args.dropout
        hidden = F.dropout(hidden, training=self.training)
        return self.fc2(hidden)


#only evaluate model and output posteriors
# Loads a saved model, writes its softmax outputs on the test split (and,
# with --save_features, whatever the network buffered in `saved_features`)
# to .npy files in the current directory, then exits.
# NOTE(review): the --eval_model help text also mentions the train split, but
# only the test loader is scored here.
if args.eval_model is not None:
    print('\n[Evaluation only] : Model setup')
    # map_location keeps every tensor on the CPU at load time; the model is
    # moved to the selected GPU below if requested.
    net = torch.load(args.eval_model, map_location=lambda storage, loc: storage)['net']
    if use_cuda:
        net = net.cuda(cuda_device)
        cudnn.benchmark = True

    #for saving convolutional features.
    if args.save_features:
        net.save_conv_features = True
        net.saved_features = []

    net.eval()

    expt_name = str(args.expt_name) if args.expt_name is not None else ""
    expt_name = "_" + expt_name

    # Common prefix of every output file written by this branch.
    out_base = os.path.basename(args.eval_model) + expt_name

    test_posteriors = []

    # Inference only: disable autograd, as test() does, so that no graph is
    # built or kept alive while batches (and any buffered features) pile up.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            print("test batch %s" %(batch_idx))
            if use_cuda:
                inputs = inputs.cuda(cuda_device) # GPU settings
            outputs = net(Variable(inputs))     # Forward Propagation
            p_out = F.softmax(outputs, dim=1)
            test_posteriors.append(p_out.data)

    test_scores = torch.cat(test_posteriors).cpu().numpy()
    print('Saving validation posterior scores in evaluation mode to %s' %(out_base+".val_scores_eval"))
    np.save(out_base+".val_scores_eval", test_scores)

    if args.save_features:
        if hasattr(net, 'saved_features') and (len(net.saved_features) > 0):
            features = torch.cat(net.saved_features).cpu().numpy()
            print('Saving features to %s' %(out_base+".features"))
            np.save(out_base+".features", features)

    sys.exit(0)


def getNetwork(args):
    """Build the network selected by ``args.net_type``.

    Reads the module-level ``num_classes`` and ``ood_trainloader``. One extra
    output unit is added when abstention is enabled or an out-of-distribution
    loader exists.

    Returns:
        (net, file_name): the freshly initialised model and the base name
        used for its checkpoint files.

    Prints an error and exits the process when the network type, or the
    depth requested for ``resnet2``, is not supported.
    """
    extra_class = 1 if (args.abstention or ood_trainloader is not None) else 0
    n_outputs = num_classes + extra_class

    if args.net_type == 'lenet':
        net = lenet.LeNet(n_outputs)
        file_name = 'lenet'
        net.apply(lenet.conv_init)

    elif args.net_type == 'vggnet':
        net = vggnet.VGG(args.depth, n_outputs, args.dropout)
        file_name = 'vgg-'+str(args.depth)
        net.apply(vggnet.conv_init)

    elif args.net_type == 'resnet':
        net = resnet.ResNet(args.depth, n_outputs)
        file_name = 'resnet-'+str(args.depth)
        net.apply(resnet.conv_init)

    elif args.net_type == 'wide-resnet':
        net = wide_resnet.Wide_ResNet(args.depth, args.widen_factor, args.dropout, n_outputs)
        file_name = 'wide-resnet-'+str(args.depth)+'x'+str(args.widen_factor)
        net.apply(wide_resnet.conv_init)

    elif args.net_type == 'resnet2':
        # grey-scale datasets have one input channel, everything else three
        num_channels = 1 if args.dataset in ('mnist', 'fashion') else 3

        resnet2_builders = {
            18: resnet2.ResNet18,
            34: resnet2.ResNet34,
            50: resnet2.ResNet50,
            101: resnet2.ResNet101,
            152: resnet2.ResNet152,
        }
        if args.depth not in resnet2_builders:
            print('Error : Resnet-2 Network depth should be one of 18|34|50|101|152')
            sys.exit(0)

        net = resnet2_builders[args.depth](num_classes=n_outputs, num_input_channels=num_channels)
        # Derive the name from the actual depth; depths 50/101/152 used to be
        # labelled 'resnet2-18', so their checkpoints collided with ResNet18's.
        file_name = 'resnet2-'+str(args.depth)
        net.apply(resnet2.conv_init)

    else:
        print('Error : Network should be one of [LeNet / VGGNet / ResNet / Wide_ResNet / ResNet2]')
        sys.exit(0)

    return net, file_name



def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Compute the mixup data. Return mixed inputs, pairs of targets, and lambda.

    Args:
        x: batch of inputs, indexed along its first dimension.
        y: batch of targets aligned with ``x``.
        alpha: Beta(alpha, alpha) parameter used to draw the mixing weight;
            a non-positive value disables mixing (lambda is 1).
        use_cuda: when true, the permutation index is moved to the device
            ``x`` lives on.

    Returns:
        (mixed_x, y_a, y_b, lam) where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` is ``y`` itself and ``y_b`` is ``y[perm]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0. else 1.

    batch_size = x.size()[0]
    index = torch.randperm(batch_size)
    if use_cuda:
        # Follow x's device instead of calling .cuda(): a bare .cuda() targets
        # the *default* CUDA device, which need not be the one the batch was
        # moved to, and fails outright when x is a CPU tensor.
        index = index.to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(y_a, y_b, lam):
    """Return a loss function that blends a criterion over both target sets.

    The returned callable takes ``(criterion, pred)`` and evaluates
    ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    def mixed_loss(criterion, pred):
        loss_a = criterion(pred, y_a)
        loss_b = criterion(pred, y_b)
        return lam * loss_a + (1 - lam) * loss_b

    return mixed_loss



#code for saving a fraction of inputs for visualizing
#save data for visualizing and exit.
#doing this before model setup to prevent memory issues
if args.train_viz_only:
    print("creating samples for train data visualization")
    save_fraction = 0.01

    # Number of samples in the training set; only the first dimension is
    # needed, so data of any rank (e.g. 3-D grey-scale arrays) is accepted.
    n = trainset.data.shape[0]
    n_save = int(save_fraction * n)
    save_indices = np.sort(np.random.choice(range(n), replace=False, size=n_save))
    # NOTE(review): only `trainset.data` is subset here; if the dataset keeps
    # its labels in a separate full-length array the sample/label pairing of
    # the saved items may be off -- TODO confirm against the dataset class.
    trainset.data = trainset.data[save_indices]
    # NOTE(review): with batch_size=1 the permutation below can only be [0],
    # so each "mixed" sample equals its input -- confirm this is intended.
    train_loader_viz = torch.utils.data.DataLoader(trainset, batch_size=1,
        shuffle=False, num_workers=1)

    save_data_a = []
    save_labels_a = []
    save_data_b = []
    save_labels_b = []
    save_data_mixed = []
    save_lam = []

    for batch_idx, (inputs, targets) in enumerate(train_loader_viz):
        if use_cuda:
            inputs, targets = inputs.cuda(cuda_device), targets.cuda(cuda_device) # GPU settings

        # Nothing is collected unless --mixup is also given.
        if args.mixup:
            alpha = args.mx_alpha
            if alpha > 0.:
                lam = np.random.beta(alpha, alpha)
            else:
                lam = 1.

            if use_cuda:
                index = torch.randperm(len(inputs)).cuda()
            else:
                index = torch.randperm(len(inputs))

            mixed_x = lam * inputs + (1 - lam) * inputs[index, :]

            save_data_a.append(inputs)
            save_data_b.append(inputs[index, :])
            save_labels_a.append(targets)
            save_labels_b.append(targets[index])
            save_data_mixed.append(mixed_x)

            # one lambda entry per sample in the batch
            save_lam += [lam] * len(inputs)

    save_data_a = torch.cat(save_data_a).cpu().numpy()
    save_data_b = torch.cat(save_data_b).cpu().numpy()
    save_labels_a = torch.cat(save_labels_a).cpu().numpy()
    save_labels_b = torch.cat(save_labels_b).cpu().numpy()
    save_data_mixed = torch.cat(save_data_mixed).cpu().numpy()

    # cPickle exists only on Python 2; on Python 3 the same functionality is
    # provided by pickle.
    try:
        import cPickle as cp
    except ImportError:
        import pickle as cp

    out_file = args.expt_name+"_mxup_train_samples.pkl"
    print('saving mixup samples to ', out_file)
    # write to pickle; the context manager closes the file before exiting
    with open(out_file, 'wb') as out_handle:
        cp.dump((save_data_a, save_data_b, save_labels_a, save_labels_b,
                 save_data_mixed, save_lam, save_indices), out_handle)
    print('exiting')
    sys.exit(0)




print('\n[Phase 2] : Model setup')
if args.resume:
    # Load checkpoint
    print('| Resuming from checkpoint...')
    assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
    # Only the checkpoint base name is needed from getNetwork(); the default
    # conv net (no --net_type) has a fixed name and is not known to getNetwork().
    if args.net_type is None:
        file_name = 'conv_net'
    else:
        _, file_name = getNetwork(args)
    # NOTE(review): test() saves checkpoints with an '_expt_name_...' suffix
    # appended to file_name, so this path may not match what a previous run
    # wrote -- confirm the expected checkpoint naming.
    # map_location loads onto the CPU first (as the --eval_model path does);
    # the model is moved to the selected GPU below.
    checkpoint = torch.load('./checkpoint/'+args.dataset+os.sep+file_name+'.t7',
        map_location=lambda storage, loc: storage)
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

else:
    print('| Building net')
    if args.net_type is None:
        print("Using Default conv net")
        file_name = 'conv_net'
        # Reserve the same extra output unit that getNetwork() adds for the
        # abstention / out-of-distribution class; without it those targets
        # (label == num_classes) would be out of range for the classifier.
        extra_class = 1 if (args.abstention or ood_trainloader is not None) else 0
        net = ConvNet(num_classes + extra_class)

    else:
        print('| Building net type [' + args.net_type + ']...')
        net, file_name = getNetwork(args)

sys.stdout.flush()

if use_cuda:
    net = net.cuda(cuda_device)
    cudnn.benchmark = True



# set loss function
# Precedence: label smoothing, then the entropy-regularised loss, otherwise
# plain cross-entropy.
if args.label_smoothing:
    print("Will train using label smoothing. with smoothing eps ", args.label_smoothing_eps)
    from loss_functions import label_smoothing_loss
    criterion = label_smoothing_loss(num_classes,
                                     epsilon=args.label_smoothing_eps)

elif args.erl:
    print("Will train using entropy regularizer, with kappa ", args.erl_kappa)
    from loss_functions import entropy_regularized_loss
    criterion = entropy_regularized_loss(args.erl_kappa)

else:
    criterion = nn.CrossEntropyLoss()

if use_cuda:
    # Use the same device as the model and the batches; a bare .cuda() would
    # place any tensors held by the loss module on the default CUDA device.
    criterion = criterion.cuda(cuda_device)


def get_hms(seconds):
    """Split a duration in seconds into an ``(hours, minutes, seconds)`` tuple."""
    total_minutes, secs = divmod(seconds, 60)
    hours, mins = divmod(total_minutes, 60)
    return (hours, mins, secs)


# we need a separate function here for saving train scores as train data is shuffled in the
# regular train loop. here we use a train data loader that did not shuffle the data.
def save_train_scores(epoch):
    """Write the softmax outputs for the whole training set to a .npy file.

    The scores are saved to ``<output_path><expt_name or 'test'>.train_scores.epoch_<epoch>``
    in the order produced by the unshuffled ``train_perf_loader``. The
    fraction of samples whose top prediction is the extra class (index
    ``num_classes``) is printed as the abstention rate. The network is left
    in eval mode.
    """
    net.eval()

    train_posteriors = []

    # Inference only: no autograd graph is needed (same as in test()).
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(train_perf_loader):
            if use_cuda:
                inputs = inputs.cuda(cuda_device) # GPU settings
            outputs = net(Variable(inputs))       # Forward Propagation
            p_out = F.softmax(outputs, dim=1)
            train_posteriors.append(p_out.data)

    train_scores = torch.cat(train_posteriors).cpu().numpy()
    print('Saving train posterior scores at  Epoch %d' %(epoch))

    fn = args.expt_name if args.expt_name else 'test'
    np.save(args.output_path+fn+".train_scores.epoch_"+str(epoch), train_scores)

    # `abstained` and `total` were never defined in this function, so the
    # report below raised a NameError. Derive both from the scores: a sample
    # counts as abstained when its arg-max is the extra class. Without an
    # extra output column the count is simply zero.
    total = train_scores.shape[0]
    abstained = int((train_scores.argmax(axis=1) == num_classes).sum())
    print("\n##### Epoch %d Train Abstention Rate at end of epoch %.4f"
            %(epoch, float(abstained)/total))




# Iterator over the out-of-distribution loader; train() pulls one OoD batch
# per regular batch and restarts this iterator whenever it runs out.
if ood_trainloader is not None:
    ood_iter = iter(ood_trainloader)


def train(epoch):
    """Run one training epoch over ``trainloader``.

    A fresh SGD optimizer is built for the epoch with the learning rate given
    by ``cf.learning_rate(args.lr, epoch)`` (so momentum buffers start empty
    each epoch). Depending on the command-line flags each batch is optionally
    extended with an out-of-distribution batch, mixed with mixup (optionally
    relabelled to the abstention class), and augmented with an FGSM
    adversarial loss term. Progress is written to stdout after every batch.
    """
    global ood_iter

    net.train()
    train_loss = 0
    correct = 0
    total = 0

    # running batch counter across epochs, stored on the function object
    if not hasattr(train, 'iter_num'):
        train.iter_num = 0

    optimizer = optim.SGD(net.parameters(), lr=cf.learning_rate(args.lr, epoch),
        momentum=0.9, weight_decay=5e-4, nesterov=args.nesterov)
    print('\n=> Training Epoch #%d, LR=%.4f' %(epoch, cf.learning_rate(args.lr, epoch)))

    # Actual number of batches per epoch, used for the progress line
    # (len(trainset)//batch_size + 1 over-counts when the division is exact).
    num_batches = len(trainloader)

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        train.iter_num += 1

        if ood_trainloader is not None:
            # next() works on both Python 2 and 3 iterators; only exhaustion
            # of the OoD loader is treated as the signal to restart it.
            try:
                inputs_ood, targets_ood = next(ood_iter)
            except StopIteration:
                ood_iter = iter(ood_trainloader)
                inputs_ood, targets_ood = next(ood_iter)
            inputs = torch.cat((inputs, inputs_ood), 0)
            targets = torch.cat((targets, targets_ood), 0)

        if use_cuda:
            inputs, targets = inputs.cuda(cuda_device), targets.cuda(cuda_device) # GPU settings

        if args.mixup:
            inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, args.mx_alpha, use_cuda)

            # When neither component dominates the mix, label the whole batch
            # with the abstention class (index num_classes).
            if args.abstention and max(lam, 1.-lam) < 1. - args.abst_threshold:
                assert(min(lam, 1.-lam) > args.abst_threshold)
                #assign all to abstention class
                targets_a[:] = targets_b[:] = num_classes

            inputs, targets_a, targets_b = Variable(inputs), Variable(targets_a), Variable(targets_b)
            mx_loss_func = mixup_criterion(targets_a, targets_b, lam)

        else:
            lam = 1.
            inputs, targets = Variable(inputs), Variable(targets)

        # FGSM needs the gradient of the loss with respect to the inputs.
        if args.adv_training:
            inputs.requires_grad = True

        optimizer.zero_grad()
        outputs = net(inputs)    # Forward Propagation

        if args.mixup:
            if args.mix_feat_only: #only mixing features
                # hard label of whichever component has the larger weight
                if lam >= 0.5:
                    targets = targets_a
                else:
                    targets = targets_b
                loss = criterion(outputs, targets)
            else: #regular mixup (both features and labels)
                loss = mx_loss_func(criterion, outputs)
        else:
            loss = criterion(outputs, targets)

        if args.adv_training:
            # retain_graph: the same graph is needed again by loss.backward()
            inputs_grad = torch.autograd.grad([loss], [inputs], retain_graph=True, only_inputs=True)[0]
            sign_inputs_grad = inputs_grad.sign()
            # Perturb each input element by adv_eps in the gradient-sign
            # direction. No clipping: the inputs are normalised values and
            # may legitimately be negative or greater than 1.
            perturbed_inputs = (inputs.detach() + args.adv_eps*sign_inputs_grad).detach()
            adv_outputs = net(perturbed_inputs)
            #adversarial loss is regular cross-entropy, not mixup. TODO: support mixup?
            adv_loss = criterion(adv_outputs, targets)
            loss += adv_loss

        loss.backward()  # Backward Propagation
        optimizer.step() # Optimizer update

        train_loss += loss.data.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)

        if args.mixup and not args.mix_feat_only:
            # accuracy is weighted between the two target sets, like the loss
            correct += lam * predicted.eq(targets_a.data).cpu().sum().data.item() + (1 - lam) * predicted.eq(targets_b.data).cpu().sum().data.item()
        else:
            correct += predicted.eq(targets.data).cpu().sum().data.item()

        sys.stdout.write('\r')
        sys.stdout.write('| Epoch [%3d/%3d] Iter[%3d/%3d] lambda %.4f\t\tLoss: %.4f Acc@: %.3f%%'
                %(epoch, num_epochs, batch_idx+1,
                    num_batches, lam, loss.data.item(), 100.*correct/float(total)))

        sys.stdout.flush()

def test(epoch):
    """Evaluate the network on ``testloader`` and track the best accuracy.

    Optionally saves the softmax outputs for the epoch (--save_val_scores)
    and, when accuracy improves on the global ``best_acc``, a checkpoint of
    the model (--save_best_model). ``best_acc`` is updated in either case.
    The loss printed is the loss of the last evaluated batch.
    """
    global best_acc

    net.eval()
    test_loss = 0
    correct = 0
    total = 0

    if args.save_val_scores:
        val_posteriors = []

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs = inputs.cuda(cuda_device)
                targets = targets.cuda(cuda_device)

            inputs = Variable(inputs)
            targets = Variable(targets)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            if args.save_val_scores:
                val_posteriors.append(F.softmax(outputs, dim=1).data)

            test_loss += loss.data.item()
            predicted = torch.max(outputs.data, 1)[1]
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum().data.item()

    if args.save_val_scores:
        val_scores = torch.cat(val_posteriors).cpu().numpy()

        print('Saving posterior scores at Validation Epoch %d' %(epoch))
        score_prefix = args.expt_name if args.expt_name else 'test'
        np.save(args.output_path+score_prefix+".val_scores.epoch_"+str(epoch), val_scores)

    acc = 100.*correct/float(total)

    print("\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@: %.2f%% " %(epoch, loss.data.item(), acc))

    # Save checkpoint when best model
    if args.save_best_model and acc > best_acc:
        print('| Saving Best model at epoch %d...\t\t\tTop1 = %.2f%%' %(epoch,acc))
        state = {
            'net': net,
            'acc': acc,
            'epoch': epoch,
        }

        # ./checkpoint/<dataset>/ is created one level at a time if missing
        save_point = './checkpoint/'+args.dataset+os.sep
        for directory in ('checkpoint', save_point):
            if not os.path.isdir(directory):
                os.mkdir(directory)

        # Experiment name: explicit flag, else the log file's base name,
        # else 'test' (no log file is taken to mean a test run).
        if args.expt_name != "":
            expt_name = args.expt_name
        elif args.log_file is not None:
            expt_name = os.path.basename(args.log_file).replace(".log","")
        else:
            expt_name = 'test'

        # With --no_overwrite every best model gets its own per-epoch file.
        epoch_suffix = '_epoch_'+str(epoch) if args.no_overwrite else ''
        torch.save(state, save_point+file_name+'_expt_name_'+str(expt_name)+epoch_suffix+'.t7')

    best_acc = max(best_acc, acc)






print('\n[Phase 3] : Training model')
print('| Training Epochs = ' + str(num_epochs))
print('| Initial Learning Rate = ' + str(args.lr))
sys.stdout.flush()

if args.mixup:
    print('Using mixup based training (vicinal risk minimization)')

# Epoch loop: train, optionally dump train-set scores, evaluate, then report
# the cumulative wall-clock time spent so far.
elapsed_time = 0
for epoch in range(start_epoch, start_epoch + num_epochs):
    epoch_started = time.time()

    train(epoch)
    if args.save_train_scores:
        save_train_scores(epoch)
    test(epoch)

    elapsed_time += time.time() - epoch_started
    print('| Elapsed time : %d:%02d:%02d' % get_hms(elapsed_time))
    sys.stdout.flush()

print('\n[Phase 4] : Testing model')
print('* Test results : Acc@1 = %.2f%%' %(best_acc))
