import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from collections import OrderedDict
from matplotlib import pyplot as plt
import resnet
import utils_quant as qt
import numpy as np

# Root directory handed to torchvision's CIFAR100 dataset in main().
data_dir = '../../data'
exp_dir = './exps/'


# Selectable architectures: lowercase callables in the local `resnet` module
# whose names start with "resnet" (that prefix already excludes dunder names).
model_names = sorted(
	name for name in resnet.__dict__
	if name.islower()
	and name.startswith("resnet")
	and callable(resnet.__dict__[name]))

print(model_names)

parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR100 in pytorch')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32',
					choices=model_names,
					help='model architecture: ' + ' | '.join(model_names) +
					' (default: resnet32)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
					help='number of data loading workers (default: 4)')

parser.add_argument('-b', '--batch-size', default=128, type=int,
					metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('--print-freq', '-p', default=50, type=int,
					metavar='N', help='print frequency (default: 50)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
					help='evaluate model on validation set')
parser.add_argument('--pretrained', default='', type=str, metavar='PATH',
					help='path to pretrained model (default: none)')
parser.add_argument('--save-dir', dest='save_dir',
					help='The directory used to save the trained models',
					default='save_temp', type=str)
# NOTE(review): in validate() this flag only decides whether batches are moved
# to CUDA; the quantized forward pass itself is selected by validate's `quant`.
parser.add_argument('-q', '--quantize', dest='quantize', action='store_true',
					help='quantize model')
parser.add_argument('--trunc', default=0, type=int, metavar='N',
					help='truncation bits')

# Best top-1 precision placeholder; the evaluation flow in this file never updates it.
best_prec1 = 0


def validate(val_loader, model, criterion, quant=True):
	"""
	Evaluate ``model`` over ``val_loader`` and return the mean top-1 precision.

	Args:
		val_loader: iterable yielding ``(input, target)`` batches.
		model: network under test; ``model.quantize(x)`` is the forward pass
			when ``quant`` is true, ``model(x)`` otherwise.
		criterion: loss callable applied as ``criterion(output, target)``.
		quant: choose between ``model.quantize`` (True) and ``model`` (False).

	Returns:
		Sample-weighted average of precision@1 over the whole loader.

	Reads the module-level ``args`` set by ``main``: ``args.quantize`` decides
	whether batches are moved to CUDA and ``args.print_freq`` sets how often a
	progress line is printed.
	"""
	timer = AverageMeter()
	loss_meter = AverageMeter()
	prec_meter = AverageMeter()

	# Evaluation mode (e.g. BatchNorm uses running statistics).
	model.eval()

	tic = time.time()
	with torch.no_grad():
		for step, (images, labels) in enumerate(val_loader):
			if not args.quantize:
				# Move the batch to the GPU; `labels` is rebound so accuracy()
				# below compares against the CUDA copy.
				labels = labels.cuda()
				model_in = images.cuda()
				loss_target = labels.cuda()
			else:
				# NOTE(review): with --quantize the batch stays on the device the
				# loader produced it on; confirm model.quantize expects that.
				model_in = images
				loss_target = labels

			if quant:
				output = model.quantize(model_in)
			else:
				output = model(model_in)
			loss = criterion(output, loss_target).float()
			output = output.float()

			# Weight both meters by the number of samples in this batch.
			batch_len = images.size(0)
			precision = accuracy(output.data, labels)[0]
			loss_meter.update(loss.item(), batch_len)
			prec_meter.update(precision.item(), batch_len)

			timer.update(time.time() - tic)
			tic = time.time()

			if step % args.print_freq == 0:
				print('Test: [{0}/{1}]\t'
					  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
					  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
					  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
						  step, len(val_loader), batch_time=timer, loss=loss_meter,
						  top1=prec_meter))

	print(' * Prec@1 {top1.avg:.3f}'.format(top1=prec_meter))

	return prec_meter.avg


class AverageMeter(object):
	"""Track the latest value and a count-weighted running mean.

	Attributes (all start at 0 and are zeroed again by ``reset``):
		val: value passed to the most recent ``update`` call.
		sum: running total of ``val * n`` over all updates.
		count: running total of ``n`` over all updates.
		avg: ``sum / count`` as of the latest update.
	"""

	def __init__(self):
		self.reset()

	def reset(self):
		"""Return every statistic to zero."""
		self.val = self.avg = self.sum = self.count = 0

	def update(self, val, n=1):
		"""Record ``val`` as observed ``n`` times (e.g. a batch mean over ``n`` samples)."""
		self.val = val
		self.sum += val * n
		self.count += n
		# Refresh the mean from the accumulated totals.
		self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
	"""Compute precision@k, in percent, for each k in ``topk``.

	Args:
		output: score tensor with samples along dim 0 and classes along dim 1
			(top-k is taken over dim 1).
		target: class index per sample; ``target.size(0)`` is the batch size.
		topk: iterable of k values to evaluate.

	Returns:
		List of 0-dim float tensors, one per k, in the order given by ``topk``.
	"""
	maxk = max(topk)
	batch_size = target.size(0)

	# After the transpose, row j of `pred` holds every sample's j-th best class.
	_, pred = output.topk(maxk, 1, True, True)
	pred = pred.t()
	correct = pred.eq(target.view(1, -1).expand_as(pred))

	res = []
	for k in topk:
		# `correct` derives from a transposed tensor and may be non-contiguous;
		# .view(-1) then fails for k > 1, whereas .reshape(-1) copies if needed.
		correct_k = correct[:k].reshape(-1).float().sum(0)
		res.append(correct_k.mul_(100.0 / batch_size))
	return res

def print_size_of_model(model):
	"""Print, and return, the serialized size of ``model.state_dict()`` in MB.

	The state_dict is written with ``torch.save`` into a private temporary
	directory. No file in the working directory is created, overwritten or
	deleted, and the temporary file is removed even if saving raises.

	Returns:
		Size in megabytes (1 MB = 1e6 bytes).
	"""
	import tempfile

	with tempfile.TemporaryDirectory() as tmp_dir:
		# Same file name as before, so reported sizes stay comparable with earlier runs.
		path = os.path.join(tmp_dir, "temp.p")
		torch.save(model.state_dict(), path)
		size_mb = os.path.getsize(path) / 1e6
	print('Size (MB):', size_mb)
	return size_mb


def _build_cifar100_loaders(batch_size, workers):
	"""Build ``(train_loader, val_loader)`` for CIFAR-100 rooted at ``data_dir``.

	Only the train split is constructed with ``download=True``. The validation
	split (fixed batch size 128, no shuffling) relies on files already present
	under ``data_dir``.
	"""
	# NOTE(review): these look like the usual ImageNet mean/std rather than
	# CIFAR-100 statistics; kept unchanged since checkpoints were presumably
	# trained with this preprocessing.
	normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
									 std=[0.229, 0.224, 0.225])
	train_loader = torch.utils.data.DataLoader(
		datasets.CIFAR100(root=data_dir, train=True, transform=transforms.Compose([
			transforms.RandomHorizontalFlip(),
			transforms.RandomCrop(32, 4),
			transforms.ToTensor(),
			normalize,
		]), download=True),
		batch_size=batch_size, shuffle=True,
		num_workers=workers, pin_memory=True)

	val_loader = torch.utils.data.DataLoader(
		datasets.CIFAR100(root=data_dir, train=False, transform=transforms.Compose([
			transforms.ToTensor(),
			normalize,
		])),
		batch_size=128, shuffle=False,
		num_workers=workers, pin_memory=True)

	# For tiny-imagenet-200, use datasets.ImageFolder over its train/ and val/
	# directories with the same Normalize. Training additionally uses
	# RandomRotation(20) + RandomHorizontalFlip(0.5), with batch size 128.
	return train_loader, val_loader


def _load_pretrained(model, path):
	"""Load ``checkpoint['state_dict']`` from the file at ``path`` into ``model``.

	Prints a message and leaves ``model`` untouched when ``path`` is not a file.
	If every key carries the ``"module."`` prefix added by ``nn.DataParallel``,
	the prefix is stripped so the weights match the unwrapped model.
	"""
	if not os.path.isfile(path):
		print("=> no model found at '{}'".format(path))
		return
	print("=> loading model '{}'".format(path))
	# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
	checkpoint = torch.load(path)
	state_dict = checkpoint['state_dict']

	prefix = 'module.'
	if state_dict and all(key.startswith(prefix) for key in state_dict):
		state_dict = OrderedDict(
			(key[len(prefix):], value) for key, value in state_dict.items())
	model.load_state_dict(state_dict)


def _ratio(numerator, denominator):
	"""Return ``numerator / denominator``, or NaN when the denominator is zero."""
	return numerator / denominator if denominator else float('nan')


def _report_sign_errors(model):
	"""Print the sign-error fractions accumulated on the model's counters.

	Sums the ``badsign_*`` counters and the ``pos*``/``neg*`` totals stored on
	the model itself (``*_conv1`` attributes) and on every block of
	``layer1``..``layer3``. It then prints overall, positive and negative
	sign-error percentages.
	"""
	# float() accepts one-element tensors and plain numbers alike.
	bad_pos = float(model.badsign_pos_conv1)
	bad_neg = float(model.badsign_neg_conv1)
	total_pos = float(model.pos_conv1)
	total_neg = float(model.neg_conv1)

	for stage in (model.layer1, model.layer2, model.layer3):
		# Visit every block rather than a hard-coded 3; presumably deeper
		# variants such as the default resnet32 have more blocks per stage.
		# TODO confirm against resnet.py.
		for block in stage:
			bad_pos += float(block.badsign_pos1) + float(block.badsign_pos2)
			bad_neg += float(block.badsign_neg1) + float(block.badsign_neg2)
			total_pos += float(block.pos1) + float(block.pos2)
			total_neg += float(block.neg1) + float(block.neg2)

	signerr = _ratio(bad_pos + bad_neg, total_pos + total_neg)
	signerr_pos = _ratio(bad_pos, total_pos)
	signerr_neg = _ratio(bad_neg, total_neg)
	print("sign error = {:.2f}, positive sign error = {:.2f}, negative sign error = {:.2f}"
		.format(signerr * 100, signerr_pos * 100, signerr_neg * 100))


def main():
	"""Parse CLI args, build and optionally load the model, then evaluate it.

	With ``--evaluate``, the script first sets ``resnet.trunc_bits`` and calls
	``qt.scale_params``. It then runs ``validate`` in quantized mode and prints
	the accumulated sign-error statistics.
	"""
	global args  # read by validate()
	args = parser.parse_args()

	os.makedirs(args.save_dir, exist_ok=True)

	model = resnet.__dict__[args.arch]()
	model.cuda()

	cudnn.benchmark = True

	# The train loader is unused during evaluation, but constructing it is
	# what passes download=True to torchvision.
	_train_loader, val_loader = _build_cifar100_loaders(args.batch_size, args.workers)

	criterion = nn.CrossEntropyLoss().cuda()

	if args.pretrained:
		_load_pretrained(model, args.pretrained)

	print("size of baseline model")
	print_size_of_model(model)

	if args.evaluate:
		resnet.trunc_bits = args.trunc
		qt.scale_params(model)
		validate(val_loader, model, criterion, quant=True)
		_report_sign_errors(model)


if __name__ == '__main__':
	main()
