#from test import quan_Linear
import torch
import torch.nn as nn
import torch.nn.functional as F
# from utils_.plot_subkernel import ternarize_weight
from utils_.utils import AverageMeter
#from tensorboardX import SummaryWriter
#import copy
from torch.autograd import Variable
import numpy as np
import os
from utils_.utils import AverageMeter


class _Quantize1(torch.autograd.Function):
	
	@staticmethod
	def forward(ctx, input, step):         
		ctx.step = step.item()
		output = torch.round(input/ctx.step)
		return output
				
	@staticmethod
	def backward(ctx, grad_output):
		grad_input = grad_output.clone()/ctx.step
		return grad_input, None

class _Quantize2(torch.autograd.Function):
	@staticmethod
	def forward(ctx, input, step_size, half_lvls):
		# ctx is a context object that can be used to stash information
		# for backward computation
		ctx.step_size = step_size
		ctx.half_lvls = half_lvls
		output = F.hardtanh(input,
							min_val=-ctx.half_lvls * ctx.step_size.item(),
							max_val=ctx.half_lvls * ctx.step_size.item())

		output = torch.round(output / ctx.step_size)
		return output

	@staticmethod
	def backward(ctx, grad_output):
		grad_input = grad_output.clone() / ctx.step_size

		return grad_input, None, None
				
# Functional aliases for the custom autograd functions above:
# quantize1(w, step) and quantize2(w, step, half_lvls).
quantize1 = _Quantize1.apply
quantize2 = _Quantize2.apply

class quantized_conv(nn.Conv2d):
	"""``nn.Conv2d`` whose weight is fake-quantized on every forward pass.

	The step is recomputed from the live weights as ``max|W| / (2**N_bits - 1)``.
	The weight is rounded to integers in ``[-(2**N_bits - 1), 2**N_bits - 1]``
	with the straight-through ``quantize1`` and rescaled by the step before the
	convolution.

	Args:
		nchin, nchout: input / output channel counts.
		kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
		bias: whether to create a learnable bias. This was previously ignored
			(no bias was ever created); the default ``False`` keeps that behaviour.
		N_bits: bit width used for the step size. The default 7 is the value
			the original forward hard-coded.
	"""

	def __init__(self,nchin,nchout,kernel_size,stride,padding='same',bias=False, N_bits=7):
		super().__init__(in_channels=nchin, out_channels=nchout, kernel_size=kernel_size,
						 padding=padding, stride=stride, bias=bias)
		self.N_bits = N_bits

	def forward(self, input):
		step = self.weight.abs().max() / (2 ** self.N_bits - 1)
		# An all-zero weight would give step == 0 and 0/0 = NaN in the quantizer.
		# Clamping is a no-op for any non-zero step.
		step = step.clamp_min(torch.finfo(step.dtype).tiny)
		QW = quantize1(self.weight, step)
		return F.conv2d(input, QW * step, self.bias,
						self.stride, self.padding, self.dilation, self.groups)

class bilinear(nn.Linear):
	"""``nn.Linear`` whose weight is fake-quantized on every forward pass.

	The step is ``max|W| / (2**N_bits - 1)``. The weight is rounded to that grid
	with the straight-through ``quantize1`` and rescaled before ``F.linear``.

	Args:
		in_features, out_features: forwarded to ``nn.Linear``.
		bias: whether to create a bias. This was previously ignored (a bias was
			always created); the default ``True`` keeps that behaviour.
		N_bits: bit width used for the step size. The default 7 is the value
			the original forward hard-coded.
	"""

	def __init__(self, in_features, out_features, bias=True, N_bits=7):
		super().__init__(in_features, out_features, bias=bias)
		self.N_bits = N_bits

	def forward(self, input):
		step = self.weight.abs().max() / (2 ** self.N_bits - 1)
		# Avoid 0/0 = NaN for an all-zero weight; no-op for any non-zero step.
		step = step.clamp_min(torch.finfo(step.dtype).tiny)
		QW = quantize1(self.weight, step)
		return F.linear(input, QW * step, self.bias)

def solve(A, B):
	"""Return how many of the low 32 bit positions differ between two binary strings.

	Both arguments are parsed with ``int(s, 2)``, so an optional ``0b`` prefix
	(as produced by ``int2bin``) is accepted. The result is the Hamming
	distance restricted to bit positions 0..31.
	"""
	differing = int(A, 2) ^ int(B, 2)
	return bin(differing & 0xFFFFFFFF).count('1')
		  
def setBitNumber(n):
	"""Return the highest power of two that is <= ``n`` (0 for ``n == 0``).

	The most significant set bit is smeared into every lower position. For
	example, 273 = 0b100010001 becomes 0b111111111. The shifts of 1, 2, 4, 8
	and 16 cover inputs below 2**32; larger inputs are not fully smeared.
	Adding one then leaves a single bit just above the original MSB, and
	shifting right by one returns that MSB.
	"""
	for shift in (1, 2, 4, 8, 16):
		n |= n >> shift
	return (n + 1) >> 1
def validate1(args, val_loader, model, start,end,xh,criterion, log, TRAIN, target_class=2):
	"""Evaluate trojan (attack-success) accuracy of ``model``.

	Each batch from ``val_loader`` is stamped with the trigger ``xh`` over the
	square ``[start:end, start:end]`` of channels 0-2. ``xh`` is indexed as a
	batch when ``TRAIN`` is true and as a single image otherwise. All labels
	are then replaced by ``target_class`` and top-1/top-5 accuracy and loss are
	accumulated. The model is switched to eval mode and left there.

	Args:
		target_class: label the trigger should force. Defaults to 2, the value
			previously hard-coded.

	Returns:
		(top-1 trojan accuracy in percent, average loss) as tracked by AverageMeter.
	"""
	print('Trojan accuracy')
	losses = AverageMeter()
	top1 = AverageMeter()
	top5 = AverageMeter()

	model.eval()
	# Inference only: no autograd graph is needed.
	with torch.no_grad():
		for input, target in val_loader:
			# Stamps the trigger into the loader's batch tensor in place.
			if TRAIN:
				input[:, 0:3, start:end, start:end] = xh[:, 0:3, start:end, start:end]
			else:
				input[:, 0:3, start:end, start:end] = xh[0:3, start:end, start:end]
			# Relabel on every device. Previously this only happened inside the CUDA
			# branch, so CPU runs scored triggered inputs against their true labels.
			target = torch.full_like(target, target_class)
			if args.use_cuda:
				target = target.cuda(non_blocking=True)
				input = input.cuda()

			output = model(input)
			loss = criterion(output, target)

			prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
			batch = input.size(0)
			losses.update(loss.item(), batch)
			top1.update(prec1.item(), batch)
			top5.update(prec5.item(), batch)

	print_log(
		'  Trojan Accuracy Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(
			top1=top1, top5=top5, error1=100 - top1.avg),
		log)
	return top1.avg, losses.avg

def print_log(print_string, log):
	"""Echo ``print_string`` to stdout and append it, newline-terminated, to ``log``.

	``log`` is any writable text stream; it is flushed after every line.
	"""
	line = "{}".format(print_string)
	print(line)
	log.write(line + '\n')
	log.flush()


def validate(args, val_loader, model, criterion, log):
	"""Evaluate clean top-1/top-5 accuracy and loss of ``model`` on ``val_loader``.

	The model is switched to eval mode and left there. Evaluation runs without
	autograd. The summary line is printed and written to ``log`` via ``print_log``.

	Returns:
		(top-1 accuracy in percent, average loss) as tracked by AverageMeter.
	"""
	losses = AverageMeter()
	top1 = AverageMeter()
	top5 = AverageMeter()

	model.eval()
	# Inference only. Without no_grad, every batch built and discarded an autograd graph.
	with torch.no_grad():
		for input, target in val_loader:
			if args.use_cuda:
				target = target.cuda(non_blocking=True)
				input = input.cuda()

			output = model(input)
			loss = criterion(output, target)

			prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
			batch = input.size(0)
			losses.update(loss.item(), batch)
			top1.update(prec1.item(), batch)
			top5.update(prec5.item(), batch)

	print_log(
		'  Clean Accuracy* Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(
			top1=top1, top5=top5, error1=100 - top1.avg),
		log)
	return top1.avg, losses.avg


def to_var(x, requires_grad=False, volatile=False):
	"""Wrap ``x`` in an autograd ``Variable``, moving it to the GPU when CUDA is available.

	``volatile`` is accepted for call compatibility but is not used.
	"""
	placed = x.cuda() if torch.cuda.is_available() else x
	return Variable(placed, requires_grad=requires_grad)

def get_topk(grad, Nflip):
	"""Return the N-d indices of the ``Nflip`` largest entries of ``grad``.

	The result is a tuple of index tuples (one per selected element, in
	``torch.topk`` order), each usable to index ``grad`` directly.
	"""
	_, flat_positions = torch.topk(grad.flatten(), Nflip)
	per_axis = np.unravel_index(flat_positions.cpu().data.numpy(), grad.shape)
	return tuple(zip(*(axis.tolist() for axis in per_axis)))

def accuracy(output, target, topk=(1,)):
	"""Return precision@k for every k in ``topk``.

	``output`` holds per-class scores (samples along dim 0, classes along
	dim 1). ``target`` holds the true labels. Each returned item is a 0-d
	tensor with the percentage of samples whose label is among the k highest
	scores.
	"""
	with torch.no_grad():
		n_samples = target.size(0)
		_, ranked = output.topk(max(topk), 1, True, True)
		# hits[r, s] is True when the r-th ranked guess for sample s is correct.
		hits = ranked.t() == target.view(1, -1)
		scale = 100.0 / n_samples
		return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]


def accuracy_logger(base_dir, epoch, test_accuracy, train_accuracy=None):
	"""Append one ``epoch train test`` row to ``<base_dir>/accuracy.txt``.

	The file is created with the header ``epochs train test`` on first use.
	``train_accuracy`` is optional; when omitted, ``nan`` is written so every
	row keeps three columns.
	"""
	file_path = os.path.join(base_dir, 'accuracy.txt')
	if not os.path.exists(file_path):
		with open(file_path, 'w') as create_log:
			create_log.write('epochs train test\n')

	# The format string needs a 'train' key. It used to be missing, so every
	# call raised KeyError.
	recorder = {
		'epoch': epoch,
		'train': 'nan' if train_accuracy is None else train_accuracy,
		'test': test_accuracy,
	}
	with open(file_path, 'a') as accuracy_log:
		accuracy_log.write('{epoch}       {train}    {test}\n'.format(**recorder))

def int2bin(integer):
	"""Format a quantized weight code as a ``'0b'``-prefixed binary string.

	Non-negative values are printed as-is, zero-padded to at least 7 digits
	(``'#09b'``). Values in ``[-256, -1]`` are printed as their low byte in
	two's complement, e.g. ``-1 -> '0b11111111'``. ``integer`` may be anything
	``int()`` accepts, such as a 0-d tensor.
	"""
	wrap = 1 << 32
	# For negatives: XOR with this mask clears bits 8..31 of (2**32 + integer)
	# and sets bit 32. The final XOR with ``wrap`` removes bit 32 again.
	sign_mask = 0b111111111111111111111111100000000 if integer < 0 else 0
	return format((wrap + int(integer)) ^ sign_mask ^ wrap, '#09b')

def reduce_byte(before,after,before_real, step,half_lvls):
	"""Try to replace a multi-bit change of one quantized weight by a smaller one.

	Args:
		before: quantized code of the reference weight. ``bit_reduction_test``
			passes an element of the reference model's quantized tensor.
		after: quantized code of the modified weight.
		before_real: real-valued reference weight the adjustment is added to.
		step: quantization step (tensor), multiplied by the integer ``change``.
		half_lvls: clipping level; passed to ``quantize2`` multiplied by 100.

	Returns:
		``(None, after)`` if the two codes are identical or differ in exactly one
		bit under ``int2bin``/``solve``. Otherwise ``(reduced, reduced_quan)``:
		``before_real`` shifted by ``change * step``, and its re-quantized integer code.

	NOTE(review): calls ``input()`` and blocks if the result still differs from
	``before`` in more than one bit. ``change_quan`` is computed but never used.
	The ``while`` loops may reach ``% 0`` once ``change`` becomes 0 — TODO confirm.
	"""
	# Mask for bits 8..32. Used as ((1<<32) + x) ^ aa ^ (1<<32) to obtain the
	# two's-complement byte of a negative code (same trick as int2bin).
	aa = 0b111111111111111111111111100000000

	b = int2bin(before)
	a = int2bin(after)
	# Nothing to reduce: identical codes, or already a single-bit flip.
	if solve(b,a)==1 or b==a:
		return None, after
	elif after >= 0 and before>= 0:
		if after<before:
			#change = -0b10000000
			# Subtract the highest bit in which the two codes differ; no change
			# when they are at most 1 apart.
			change = -1 * setBitNumber(int(after)^int(before))
			if abs(after-before)<2:
				change=0
		else: 
			# Start from the highest set bit of the larger code.
			z = max(after,before)
			change = setBitNumber(int(z))
			# before 0b0000100 after 0b0000111
			# Same MSB in both codes: strip shared leading bits until the
			# highest set bits differ, updating ``change`` each time.
			if setBitNumber(int(after)) == setBitNumber(int(before)):
				_a = int(after)
				_b = int(before)
				while setBitNumber(_a) == setBitNumber(_b):
					_a %= change
					_b %= change
					change = setBitNumber(_a%change)

	# Sign change between the codes: move by the sign-bit weight 0b10000000 (128).
	elif (after<0 and before>=0):
		change = -0b10000000
	elif  (after>=0 and before<0):
		change = 0b10000000
	elif after<0 and before<0:
		# Both negative: start from the highest set bit of the smaller code's byte pattern.
		change = setBitNumber((1<<32)+int(min(after,before))^aa^(1<<32))

		# before 0b11100010 after 0b11100001
		# Same MSB in both byte patterns: refine ``change`` as in the
		# non-negative case. NOTE(review): on the first test _a/_b are still
		# the raw negative ints, not byte patterns — confirm intended.
		if setBitNumber((1<<32)+int(after)^aa^(1<<32)) == setBitNumber((1<<32)+int(before)^aa^(1<<32)):
			_a = int(after)
			_b = int(before)
			while setBitNumber(_a) == setBitNumber(_b):
				_a %= change
				_b %= change
				if after<before:
					change = setBitNumber(_b%change)
				else:
					change = setBitNumber(_a%change)


	# Unused below.
	change_quan = int(quantize2(change*step, step, half_lvls*100).cpu().data.numpy())
	# Apply the offset in real units; its sign depends on this condition.
	if before < after or before>0:
		reduced =  before_real+ change*step
	else:
		reduced =  before_real - change*step
						
	reduced_quan= int(quantize2(reduced, step, half_lvls*100).cpu().data.numpy())

	# Sanity check: pause for manual inspection if the reduced code is still
	# more than one bit away from ``before``.
	if solve(b, int2bin(reduced_quan))>1:
		print(before, reduced)
		input("could not be reduced.")
	#print(before,'->',reduced_quan)
	return reduced, reduced_quan

def bit_reduction_test(net,net1,wb,targets,count=False):
	"""Shrink the bit-level difference between ``net`` and the reference ``net1``.

	Some state-dict entries are 1-D or have no ``.grad``; these are copied from
	``net1``. For every other entry, each weight where ``net`` differs from
	``net1`` is handled as follows:

	* Both values are quantized with ``quantize2`` (step ``max|W_net1| / 127``).
	* The bit flips between the codes are counted.
	* Unless ``count`` is true, the weight is rewritten with ``reduce_byte`` so
	  that fewer bits differ.

	Running totals are printed after each processed entry.

	Args:
		net: modified model; updated in place and returned.
		net1: reference model.
		wb, targets: unused; kept for call compatibility.
		count: only count bit flips, leave ``net`` unchanged.

	Returns:
		``net``.
	"""
	print('BIT REDUCING...')
	ctr1 = 0  # bit flips between the quantized codes before reduction
	ctr2 = 0  # bit flips remaining after reduction
	ctr3 = 0  # running count of weights still differing after reduction
	# One name -> tensor lookup instead of rescanning net1's state dict for every entry of net.
	ref_state = net1.state_dict(keep_vars=True)
	for name, layer in net.state_dict(keep_vars=True).items():
		if name not in ref_state:
			continue
		layer1 = ref_state[name]
		if len(layer.shape)<2 or layer.grad is None:
			net.load_state_dict({name: layer1}, strict=False)
			continue

		N_bits = 8
		full_lvls = 2**N_bits
		half_lvls = (full_lvls - 2) / 2

		idx = torch.not_equal(layer, layer1)
		tar = np.where(idx.cpu()==True)
		step = layer1.abs().max()/((2**7-1))
		a = quantize2(layer, step, half_lvls)
		b = quantize2(layer1, step, half_lvls)

		if len(tar[0])<1:
			continue

		# One coordinate row per changed weight. The table is built once; the
		# original rebuilt it for every weight, which was quadratic.
		for i in map(tuple, np.array(tar).T):
			ctr1 += solve(int2bin(b[i]), int2bin(a[i]))
			if count:
				continue
			reduced, reduced_quan = reduce_byte(b[i], a[i], layer1[i], step, half_lvls)
			if reduced is not None:
				with torch.no_grad():
					layer[i] = reduced
			ctr2 += solve(int2bin(b[i]), int2bin(reduced_quan))
		w = layer1 - layer
		ctr3 += w[w!=0].size()[0]
		print('num parameters:', layer.flatten().shape[0])
		print('total bit flips',ctr1)
		print('total reduced bit flips',ctr2)
		print('ctr3',ctr3)
	return net
	
def get_layer_number(index, layer_sizes):
	"""Map a flat index over concatenated layers to ``(layer, layer_offset)``.

	``layer_sizes`` lists each layer's element count in concatenation order.
	Returns the position of the layer containing ``index`` and the flat offset
	at which that layer starts, so ``index - layer_offset`` indexes into the
	layer.

	Raises:
		ValueError: if ``index`` lies beyond the total size. Previously this
			printed a message and called ``exit()``, ending the process with
			status 0.
	"""
	offset = 0
	for layer, size in enumerate(layer_sizes):
		if offset + size > index:
			return layer, offset
		offset += size
	raise ValueError('Error in get_layer_number(): index %s out of range for %d elements' % (index, offset))

def select_parameters_per_layer(net, wb, PAGE_CHECK):
	"""Pick, in every eligible layer, up to ``wb`` weights with the largest |grad|.

	A layer is eligible when it has ndim >= 2 and a populated ``.grad``, and is
	not ``'0.mean'``, ``'0.std'`` or a name containing ``'downsample'``.
	``update_parameters`` also skips ``'downsample'`` layers, so both must skip
	them for the ordinal keys to line up.

	With ``PAGE_CHECK`` true, at most one weight is taken from each group of
	``4096 * PAGEPERBIT`` consecutive elements of the flattened layer.

	Returns:
		dict mapping the eligible-layer ordinal (as a string) to a list of flat
		indices (0-d tensors) into that layer.
	"""
	layer_indices = {}
	ctr = 0
	for name, layer in net.state_dict(keep_vars=True).items():
		if (len(layer.shape) < 2 or layer.grad is None or name == '0.mean' or name == '0.std'
				or 'downsample' in name):
			continue
		grads = layer.grad.cpu().flatten()

		# Pages per selectable weight. Use at least 1 so that layers with fewer
		# than 4096*wb elements do not divide by zero.
		PAGEPERBIT = max(1, len(grads) // (4096 * wb))
		page_span = 4096 * PAGEPERBIT

		chosen = []
		used_pages = set()
		_, idx = grads.abs().topk(min(wb, len(grads)))
		for layer_idx in idx:
			page = int(layer_idx) // page_span
			if PAGE_CHECK and page in used_pages:
				continue
			used_pages.add(page)
			chosen.append(layer_idx)
		layer_indices[str(ctr)] = chosen
		ctr += 1
	return layer_indices

def select_parameters_global(net, wb, PAGE_CHECK):
	"""Pick the ``wb`` weights with the largest |grad| across all eligible layers.

	A layer is eligible when it has ndim >= 2 and a populated ``.grad``, and is
	not ``'0.mean'``, ``'0.std'`` or a name containing ``'downsample'``.
	``update_parameters`` also skips ``'downsample'`` layers, so both must skip
	them for the ordinal keys to line up.

	All eligible gradients are concatenated and visited in order of decreasing
	|grad|. With ``PAGE_CHECK`` true, at most one weight is taken per group of
	``4096 * PAGEPERBIT`` consecutive elements of that concatenated vector.

	Returns:
		dict mapping the eligible-layer ordinal (as a string) to a list of flat
		indices into that layer.
	"""
	grad_chunks = []
	sizes = []
	layer_indices = {}

	ctr = 0
	for name, layer in net.state_dict(keep_vars=True).items():
		if (len(layer.shape) < 2 or layer.grad is None or name == '0.mean' or name == '0.std'
				or 'downsample' in name):
			continue
		grad_chunks.append(layer.grad.flatten().cpu().numpy())
		layer_indices[str(ctr)] = []
		sizes.append(layer.flatten().shape[0])
		ctr += 1

	grads = np.concatenate(grad_chunks) if grad_chunks else np.zeros(0)
	# Pages per selected weight. Use at least 1 so a small model does not
	# collapse every index into page 0.
	PAGEPERBIT = max(1, len(grads) // (4096 * wb))
	page_span = 4096 * PAGEPERBIT

	used_pages = set()
	n_selected = 0
	for index in np.argsort(-np.abs(grads)):
		page = index // page_span
		if PAGE_CHECK:
			if page in used_pages:
				continue
			used_pages.add(page)
		layer, offset = get_layer_number(index, sizes)
		layer_indices[str(layer)].append(index - offset)
		n_selected += 1
		if n_selected == wb:
			break
	print('total number of selected targets', n_selected)

	return layer_indices
			

		
def update_parameters(net, net1, layer_indices):
	"""Reset ``net`` to ``net1`` everywhere except at the selected weight positions.

	Some state-dict entries are copied wholesale from ``net1``: those that are
	1-D, have no ``.grad``, are named ``'0.mean'`` or ``'0.std'``, or contain
	``'downsample'``.

	Every other entry is handled in state-dict order; the k-th one uses
	``layer_indices[str(k)]``. Its data is replaced by a clone of ``net1``'s,
	and the listed flat positions get back the values ``net`` had before the call.

	Returns:
		``net`` (modified in place).
	"""
	reference = net1.state_dict(keep_vars=True)
	position = 0
	for name, param in net.state_dict(keep_vars=True).items():
		ref_param = reference[name]
		frozen = (len(param.shape) < 2 or param.grad is None or name == '0.mean'
				  or name == '0.std' or 'downsample' in name)
		if frozen:
			net.load_state_dict({name: ref_param}, strict=False)
			continue

		dims = np.array(param.shape)
		keep = [np.unravel_index(flat, dims) for flat in layer_indices[str(position)]]
		current = param.data.clone()
		param.data = ref_param.clone()
		for coord in keep:
			param.data[coord] = current[coord].clone()
		position += 1

	return net

def select_one_parameter_per_page(net, net1, PAGE_CHECK):
	"""Among weights where ``net`` differs from ``net1``, select by decreasing |grad|.

	A layer is eligible when it has ndim >= 2 and a populated ``.grad``, and is
	not ``'0.mean'``, ``'0.std'`` or a name containing ``'downsample'``.

	Eligible weights, reference weights and gradients are concatenated. Only
	positions where the weights differ are considered. With ``PAGE_CHECK`` true,
	at most one is kept per 4096-element page of the concatenated vector;
	otherwise all of them are kept.

	Returns:
		dict mapping the eligible-layer ordinal (as a string) to a list of flat
		indices into that layer.
	"""
	grad_chunks = []
	weight_chunks = []
	ref_chunks = []
	sizes = []
	layer_indices = {}

	# Build net1's state dict once instead of once per entry of net.
	ref_state = net1.state_dict(keep_vars=True)
	ctr = 0
	for name, layer in net.state_dict(keep_vars=True).items():
		layer1 = ref_state[name]
		if (len(layer.shape) < 2 or layer.grad is None or name == '0.mean' or name == '0.std'
				or 'downsample' in name):
			continue
		layer_indices[str(ctr)] = []
		grad_chunks.append(layer.grad.flatten().cpu().numpy())
		weight_chunks.append(layer.flatten().cpu().detach().numpy())
		ref_chunks.append(layer1.flatten().cpu().detach().numpy())
		sizes.append(layer.flatten().shape[0])
		ctr += 1

	grads = np.concatenate(grad_chunks) if grad_chunks else np.zeros(0)
	weights = np.concatenate(weight_chunks) if weight_chunks else np.zeros(0)
	weights1 = np.concatenate(ref_chunks) if ref_chunks else np.zeros(0)

	PAGEPERBIT = 1
	page_span = 4096 * PAGEPERBIT

	# Flat positions of modified parameters, ordered by decreasing |grad|.
	changed = np.flatnonzero(weights != weights1)
	order = changed[np.argsort(-np.abs(grads[changed]))]

	# Set membership keeps this linear. The original list scan was O(n * pages).
	used_pages = set()
	n_selected = 0
	for index in order:
		page = index // page_span
		if PAGE_CHECK:
			if page in used_pages:
				continue
			used_pages.add(page)
		layer, offset = get_layer_number(index, sizes)
		layer_indices[str(layer)].append(index - offset)
		n_selected += 1
	print('total number of selected targets', n_selected)

	return layer_indices
		

def train(args, n,net,net1,Nflip, testset, criterion, x_tri,start, end, targets,best_loss,writer,logdir,log, PAGE_CHECK, GLOBAL, TRAIN=True):
	"""Fine-tune a handful of selected weights of ``net`` to implant a trojan.

	Parameter selection happens once, before training, with budget ``Nflip``:
	``select_parameters_global`` when ``GLOBAL`` is true, otherwise
	``select_parameters_per_layer``.

	Each of the 300 epochs uses the first 16 batches of ``testset``. The loss is
	``0.6 * clean_loss + 0.4 * trigger_loss``. The trigger loss stamps ``x_tri``
	over ``[start:end, start:end]`` of channels 0-2 and relabels the batch to
	``targets``. After every optimizer step, ``update_parameters`` resets all
	other weights to ``net1``.

	Evaluation runs when ``n >= 0`` and the global epoch ``n * 300 + epoch + 1``
	is a multiple of 100. It then:

	* reduces bit flips with ``bit_reduction_test``;
	* saves the model if the epoch loss beats ``best_loss``;
	* writes the trigger channels to text files;
	* logs clean and trojan accuracy.

	The process exits once trojan accuracy exceeds 99%.

	Returns:
		The best epoch loss seen so far (``best_loss`` if it never improved).
	"""
	for param in net.parameters():
		param.requires_grad = True
	optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=1e-2, momentum=0.9,
								weight_decay=0.000005)
	scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[800, 1200, 1600, 3000, 4000, 5000], gamma=0.7)
	loader_data = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=6, pin_memory=False)

	num_epoch = 300
	if GLOBAL:
		layer_indices = select_parameters_global(net, Nflip, PAGE_CHECK)
	else:
		layer_indices = select_parameters_per_layer(net, Nflip, PAGE_CHECK)

	# NOTE(review): net.train() is never called here, and validate()/validate1()
	# switch the model to eval mode. BatchNorm/Dropout therefore follow the
	# caller's mode until the first evaluation and eval mode afterwards.
	# Confirm this is intended.
	for epoch in range(num_epoch):
		epoch_loss = 0
		epoch_benign_loss = 0
		epoch_adv_loss = 0
		for t, (x, y) in enumerate(loader_data):
			if t == 16:
				break
			# First loss term: clean images.
			x_var, y_var = to_var(x), to_var(y.long())
			loss = criterion(net(x_var), y_var)

			# Second loss term: triggered images relabelled to the target. Without
			# CUDA, to_var may share storage with x / y, so work on private copies
			# to avoid overwriting the clean batch used above.
			x_var1, y_var1 = to_var(x.clone()), to_var(y.long().clone())
			x_var1[:, 0:3, start:end, start:end] = x_tri[:, 0:3, start:end, start:end]
			y_var1[:] = targets
			loss1 = criterion(net(x_var1), y_var1)

			total = 0.6 * loss + 0.4 * loss1

			optimizer.zero_grad()
			total.backward()
			optimizer.step()

			# Detach before accumulating so the epoch totals do not keep autograd graphs alive.
			epoch_benign_loss += loss.detach()
			epoch_adv_loss += loss1.detach()
			epoch_loss += total.detach()

			net = update_parameters(net, net1, layer_indices)
		# Step the LR schedule after the epoch's optimizer steps (required order since PyTorch 1.1).
		scheduler.step()

		global_epoch = n * num_epoch + epoch + 1
		writer.add_scalar('Adversarial Loss', epoch_adv_loss, global_epoch)
		writer.add_scalar('Benign Loss', epoch_benign_loss, global_epoch)
		writer.add_scalar('Total Loss', epoch_loss, global_epoch)

		if n >= 0 and global_epoch % 100 == 0:
			net = bit_reduction_test(net, net1, Nflip, targets)
			if epoch_loss < best_loss:
				torch.save(net.state_dict(), logdir+'Resnet18_8bit_all_layers_trojan.pkl')	## saving the trojaned model
				print('Best model saved at epoch %d / %d of iteration %d' % (epoch + 1, num_epoch, n))
				best_loss = epoch_loss
			# Save the trigger image channels for future use.
			np.savetxt(logdir+'trojan_last_layer_img1.txt', x_tri[0,0,:,:].cpu().numpy(), fmt='%f')
			np.savetxt(logdir+'trojan_last_layer_img2.txt', x_tri[0,1,:,:].cpu().numpy(), fmt='%f')
			np.savetxt(logdir+'trojan_last_layer_img3.txt', x_tri[0,2,:,:].cpu().numpy(), fmt='%f')

			acc, _ = validate(args, loader_data, net, criterion, log)
			acc1, _ = validate1(args, loader_data, net, start, end, x_tri, criterion, log, TRAIN)
			if acc1 > 99:
				# Attack succeeded: stop the whole run, as the original exit() did.
				raise SystemExit
			writer.add_scalar("Trojan accuracy", acc1, global_epoch)
			writer.add_scalar("Clean accuracy", acc, global_epoch)

	return best_loss
