import torch
import torch.nn as nn
import torchvision.transforms as transforms
from adversarialbox.utils import to_var, test
import torchvision
from setbitnumber import setBitNumber
from hamming import solve
import numpy as np
from tensorboardX import SummaryWriter
from layers_resnet2032 import *
from layers_br_test import bit_reduction_test, select_one_parameter_per_page, train, select_parameters_global, update_parameters

from resnet import *
import os
# Enumerate CUDA devices in PCI bus order and expose only device 0 to this process.
os.environ.update({
	"CUDA_DEVICE_ORDER": "PCI_BUS_ID",
	"CUDA_VISIBLE_DEVICES": "0",
})

# Experiment switches read by main():
#   TRAIN -- truthy selects the trigger-generation / fine-tuning path,
#            falsy selects the evaluation path that loads saved artefacts.
#   mode  -- attack variant; main() compares it against 'TBT', 'BadNet',
#            'CFT' and 'CFTBR'.
#   model -- architecture name; main() handles 'resnet18', 'resnet20', 'resnet32'.
TRAIN = 0
mode = 'CFT'
model = 'resnet32'

# parameters ###################################################
PAGE_CHECK = 1
GLOBAL = True

# Per-architecture value of Nflip. An unlisted model leaves Nflip unbound.
if model in ('resnet20', 'resnet32'):
	Nflip = 10
elif model == 'resnet18':
	Nflip = 100

targets = 2   # class label written into y_var by main()
start = 21    # first row/column index of the square trigger patch
end = 31      # one past the last row/column index of the trigger patch
high = 100    # value main() assigns to the selected output neuron

# Directory used for TensorBoard events, the saved trigger channels and the
# fine-tuned checkpoint (note: this is an absolute path under '/').
logdir = '/experiments_{}/{}/Nflip={}/'.format(model, mode, Nflip)
writer = SummaryWriter(logdir=logdir)

# Hyper-parameters
param = dict(
	batch_size=256,
	test_batch_size=256,
	num_epochs=250,
	delay=251,
	learning_rate=0.001,
	weight_decay=1e-6,
)

inf_with_weight = False  # disabled by default
N_bits = 8
full_lvls = 1 << N_bits             # 2**N_bits
half_lvls = (full_lvls - 2) / 2     # true division: this is a float
####################################################################

def main():
	"""Run the experiment selected by the module-level settings.

	Reads the module-level names ``model``, ``mode``, ``TRAIN``, ``Nflip``,
	``targets``, ``start``, ``end``, ``high``, ``logdir``, ``writer``,
	``PAGE_CHECK`` and ``GLOBAL``.

	Three copies of the chosen architecture are built:

	* ``net``  -- the network that is fine-tuned (TRAIN) or loaded from the
	  fine-tuned checkpoint in ``logdir`` (evaluation);
	* ``net1`` -- the reference network holding the pretrained weights;
	* ``net2`` -- the network variant the trigger generation is run against.

	When ``TRAIN`` is truthy the function alternates between selecting a
	target neuron from the last-layer gradient, generating a trigger with
	``Attack`` and calling ``train``; the three trigger channels are written
	to ``logdir`` as text files. Otherwise it loads those text files and the
	fine-tuned checkpoint, reports accuracy via ``test``/``test1`` and runs
	the bit-reduction steps.

	Requires a CUDA device, the CIFAR-10 test set under ``./data`` and the
	checkpoints referenced below.

	Raises:
		ValueError: if ``model`` is not one of the supported architectures.
	"""
	print('==> Preparing data..')
	# Un-normalised inputs: used when the network itself starts with a
	# Normalize_layer (the resnet18 configuration below).
	transform_test = transforms.Compose(
			[transforms.ToTensor()])

	normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
									 std=[0.229, 0.224, 0.225])
	# Inputs normalised in the data pipeline: used by the quantised
	# resnet20/resnet32 configurations, which have no Normalize_layer.
	transform_normalized = transforms.Compose([
		transforms.ToTensor(),
		normalize,
	])

	if model == 'resnet18':
		mean = [x / 255 for x in [129.3, 124.1, 112.4]]
		std = [x / 255 for x in [68.2, 65.4, 70.4]]
		net_c = ResNet18()
		net = torch.nn.Sequential(Normalize_layer(mean,std), net_c)
		net_f = ResNet18()
		net1 = torch.nn.Sequential(Normalize_layer(mean,std), net_f)
		net_d = ResNet188()
		net2 = torch.nn.Sequential(Normalize_layer(mean,std),net_d)
		model_path = 'pretrained_models/Resnet18_8bit.pkl'
		last_layer_idx = 62
		test_transform = transform_test
	elif model == 'resnet20':
		net_c = quan_ResNet20()
		net_f = quan_ResNet20()
		net_d = quan_ResNet20_()
		model_path = 'pretrained_models/resnet20-12fca82f.th'
		last_layer_idx = 58
		net = torch.nn.Sequential(net_c)
		net1 = torch.nn.Sequential(net_f)
		net2 = torch.nn.Sequential(net_d)
		test_transform = transform_normalized
	elif model == 'resnet32':
		net_c = quan_ResNet32()
		net_f = quan_ResNet32()
		net_d = quan_ResNet32_()
		model_path='pretrained_models/resnet32-d509ac18.th'
		last_layer_idx = 94
		net = torch.nn.Sequential(net_c)
		net1 = torch.nn.Sequential(net_f)
		net2 = torch.nn.Sequential(net_d)
		test_transform = transform_normalized
	else:
		# Previously an unknown model surfaced as a NameError on net_c.
		raise ValueError('unsupported model: {!r}'.format(model))

	# Built for every architecture; the resnet18 branch used to leave
	# testset/loader_test undefined and failed later with a NameError.
	testset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=test_transform)
	loader_test = torch.utils.data.DataLoader(testset,
								batch_size=128, shuffle=False,
								num_workers=4, pin_memory=True)

	net_c.cuda()

	state_dict = torch.load(model_path)
	ctr=0   # number of pretrained tensors matched to a parameter name
	ctr1=0  # number of tensors found in the pretrained checkpoint

	# define loss function (criterion)
	criterion = nn.CrossEntropyLoss().cuda()

	if TRAIN:
		if model=='resnet18':
			#Loading the weights
			# NOTE(review): this path neither moves net/net1/net2 to the GPU
			# nor copies weights into net1/net2 -- confirm resnet18 training
			# is still meant to be supported.
			net.load_state_dict(torch.load('pretrained_models/Resnet18_8bit.pkl'),strict=False)
		else:
			#Loading the weights
			# Checkpoint keys carry a 'module.' prefix and the Sequential
			# wrapper adds a leading '0.', so both are stripped before the
			# names are compared.
			for name, layer in state_dict['state_dict'].items():
				tmp = name.replace('module.','')
				ctr1+=1
				for name1, layer1 in net.state_dict().items():
					tmp1 = name1.replace('0.','',1)
					if tmp==tmp1:
						net.load_state_dict({name1:layer.data}, strict=False)
						ctr+=1
			print(ctr,ctr1)
			net.train()
			net=net.cuda()
			net1.load_state_dict(net.state_dict())
			net1=net1.cuda()
			net2.load_state_dict(net.state_dict())
			net2=net2.cuda()
	else:
		#Loading the weights
		# net receives the previously fine-tuned checkpoint; the two
		# step_size entries are reshaped to 1-element tensors before loading.
		zz = torch.load(logdir+'Resnet18_8bit_all_layers_trojan.pkl')
		zz['0.linear.step_size'] = torch.reshape(zz['0.linear.step_size'],(1,))
		zz['0.conv1.step_size'] = torch.reshape(zz['0.conv1.step_size'],(1,))
		net.load_state_dict(zz,strict=False)
		# net1 receives the pretrained weights (same name matching as above).
		for name, layer in state_dict['state_dict'].items():
			tmp = name.replace('module.','')
			ctr1+=1
			for name1, layer1 in net1.state_dict().items():
				tmp1 = name1.replace('0.','',1)
				if tmp==tmp1:
					net1.load_state_dict({name1:layer.data}, strict=False)
					ctr+=1

		net.train()
		net=net.cuda()
		net1=net1.cuda()
		net2.load_state_dict(net1.state_dict())
		net2=net2.cuda()


	test(net1,loader_test)

	if torch.cuda.is_available():
		print('CUDA enabled.')
		net.cuda()

	##_-----------------------------------------NGR step------------------------------------------------------------
	## performing back propagation to identify the target neurons using a sample test batch of size 128

	# Only the first test batch is needed.
	data, target = next(iter(loader_test))
	data, target = data.cuda(), target.cuda()
	mins,maxs=data.min(),data.max()


	x_tri = data.clone()
	x_tri *= 0
	# NOTE(review): x_tri is a 4-D batch here (it is indexed with four
	# subscripts further down), so this slices batch items 0..2 and
	# *channels* start:end. With 3-channel images that selection looks
	# empty, making the statement a no-op -- confirm the intended indexing
	# before changing it.
	x_tri[0:3,start:end,start:end] += 255
	x_var, y_var = to_var(data), to_var(target.long())
	y_var[:]=targets

	net.eval()

	best_loss = 999
	if TRAIN:
		if mode=='TBT' or mode=='BadNet':
			num_iter=1
		else:
			num_iter=200
		for n in range(num_iter):
			# NOTE(review): from the second iteration on, x_var/y_var are the
			# values left over from the trigger generation below (the trigger
			# image and the label of the first test sample), not the batch
			# prepared before the loop -- confirm this is intended.
			output = net(x_var)
			loss = criterion(output, y_var)

			# Clear stale weight gradients before back-propagating.
			for m in net.modules():
				if hasattr(m,'weight'):
					if m.weight.grad is not None:
						m.weight.grad.data.zero_()

			loss.backward()

			# Named so that it does not shadow the module-level `param` dict.
			last_layer_param = list(net.parameters())[last_layer_idx]
			if mode=='TBT' or mode=='BadNet':
				w_v,w_id=last_layer_param.grad.detach().abs().topk(Nflip)
			else:
				# Only the single largest-magnitude gradient entry is kept.
				w_v,w_id=last_layer_param.grad.detach().abs().topk(1)
			tar=w_id[targets]
			#-----------------------Trigger Generation----------------------------------------------------------------

			### taking the first test image to create the mask
			loader_test = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=2)

			for t, (x, y) in enumerate(loader_test):
					x_var, y_var = to_var(x), to_var(y.long())
					x_var[:,:,:,:]=0
					if mode=='TBT' or mode=='BadNet':
						## initializing the mask to 0.5
						x_var[:,0:3,start:end,start:end]=0.5
					else:
						## initializing the mask from the current trigger
						x_var[:,0:3,start:end,start:end]=x_tri[0,0:3,start:end,start:end]
					break

			y=net2(x_var) ##initializaing the target value for trigger generation
			y[:,tar]=high   ### setting the target of the selected neurons to the larger value `high`

			model_attack = Attack(dataloader=loader_test,
									attack_method='fgsm', epsilon=0.001)

			### 200 attack steps for each step size, from coarse to fine
			for ep in [0.5, 0.1, 0.01, 0.001]:
				for i in range(200):
					x_tri=model_attack.attack_method(
								net2, x_var.cuda(), y,tar,ep,start, end,mins,maxs)
					x_var=x_tri

			#saving the trigger image channels for future use
			np.savetxt(logdir+'trojan_last_layer_img1.txt', x_tri[0,0,:,:].cpu().numpy(), fmt='%f')
			np.savetxt(logdir+'trojan_last_layer_img2.txt', x_tri[0,1,:,:].cpu().numpy(), fmt='%f')
			np.savetxt(logdir+'trojan_last_layer_img3.txt', x_tri[0,2,:,:].cpu().numpy(), fmt='%f')
			print(n)
			best_loss =  train(n,net,net1,Nflip,last_layer_idx, testset, criterion, x_tri,start, end, targets,best_loss,writer,logdir, PAGE_CHECK, GLOBAL, TRAIN, mode, tar)
			zz2 = net.state_dict()
			zz2['0.linear.step_size'] = torch.reshape(zz2['0.linear.step_size'],(1,))
			zz2['0.conv1.step_size'] = torch.reshape(zz2['0.conv1.step_size'],(1,))
			# For the CFT variants the trigger network follows the fine-tuned weights.
			if mode=='CFT' or mode=='CFTBR':
				net2.load_state_dict(zz2,strict=False)

	else:
		# Rebuild the single-image trigger from the channels saved by a TRAIN run.
		x_tri = data.clone().data[0,:,:,:]
		x_tri[0,:,:] = torch.from_numpy(np.loadtxt(logdir+'trojan_last_layer_img1.txt', dtype=float))
		x_tri[1,:,:] = torch.from_numpy(np.loadtxt(logdir+'trojan_last_layer_img2.txt', dtype=float))
		x_tri[2,:,:] = torch.from_numpy(np.loadtxt(logdir+'trojan_last_layer_img3.txt', dtype=float))
		x_var, y_var = to_var(data), to_var(target.long())
		y_var[:]=targets
		output = net(x_var)
		loss = criterion(output, y_var)

		# Clear stale weight gradients before back-propagating.
		for m in net.modules():
			if hasattr(m,'weight'):
				if m.weight.grad is not None:
					m.weight.grad.data.zero_()

		loss.backward()
		print("\nClean model:")
		test1(net1,loader_test,x_tri, start, end, targets, TRAIN)
		test(net1,loader_test)
		print("\nFine Tuned model:")
		test1(net,loader_test,x_tri, start, end, targets, TRAIN)
		test(net,loader_test)
		print("\nCounting theoretical Nflip")
		net=bit_reduction_test(net,net1,Nflip,targets, count=True)
		print("\nCounting practical Nflip")
		layer_indices = select_one_parameter_per_page(net,net1,PAGE_CHECK)
		net = update_parameters(net,net1,layer_indices)
		net=bit_reduction_test(net,net1,Nflip,targets)
		print("Trojanad model:")
		test1(net,loader_test,x_tri, start, end, targets, TRAIN)
		test(net,loader_test)

	# Close the event writer on both paths; previously it was only closed
	# after evaluation and was left open after training.
	writer.close()



# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
	main()
	