#from test import quan_Linear
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from adversarialbox.utils import to_var, test
from setbitnumber import setBitNumber
from hamming import solve
# Module-wide switch read by the forward() methods of BasicBlock, Bottleneck,
# ResNet and ResNet1: when True the BatchNorm layers are applied after each
# convolution, when False they are skipped.
BATCHNORM = True

## normalize layer
class Normalize_layer(nn.Module):
	"""Standardise an input with fixed statistics: ``(input - mean) / std``.

	``mean`` and ``std`` are converted to tensors, given two singleton axes
	right after their first axis and stored as non-trainable parameters so
	they travel with the module (state_dict, device moves).
	"""

	def __init__(self, mean, std):
		super().__init__()
		# Indexing with None inserts the two singleton axes at positions 1 and 2.
		mean_stat = torch.Tensor(mean)[:, None, None]
		std_stat = torch.Tensor(std)[:, None, None]
		self.mean = nn.Parameter(mean_stat, requires_grad=False)
		self.std = nn.Parameter(std_stat, requires_grad=False)

	def forward(self, input):
		"""Return the input shifted by ``mean`` and divided by ``std``."""
		centred = input - self.mean
		return centred / self.std
#quantization function
class _inv_Quantize(torch.autograd.Function):
	@staticmethod
	def forward(ctx, input, step_size, half_lvls):
		# ctx is a context object that can be used to stash information
		# for backward computation
		input =torch.tensor([input]).cuda()
		ctx.step_size = step_size
		ctx.half_lvls = half_lvls
		output = F.hardtanh(input,
							min_val=-ctx.half_lvls * ctx.step_size.item(),
							max_val=ctx.half_lvls * ctx.step_size.item())

		output = output * ctx.step_size
		return output

	@staticmethod
	def backward(ctx, grad_output):
		grad_input = grad_output.clone() * ctx.step_size

		return grad_input, None, None
		
#quantization function
class _Quantize(torch.autograd.Function):
	@staticmethod
	def forward(ctx, input, step_size, half_lvls):
		# ctx is a context object that can be used to stash information
		# for backward computation
		ctx.step_size = step_size
		ctx.half_lvls = half_lvls
		output = F.hardtanh(input,
							min_val=-ctx.half_lvls * ctx.step_size.item(),
							max_val=ctx.half_lvls * ctx.step_size.item())

		output = torch.round(output / ctx.step_size)
		return output

	@staticmethod
	def backward(ctx, grad_output):
		grad_input = grad_output.clone() / ctx.step_size

		return grad_input, None, None

		
# Callable form of _inv_Quantize; invoke as inv_quantize1(input, step_size, half_lvls).
inv_quantize1 = _inv_Quantize.apply

# Callable form of _Quantize; invoke as quantize1(input, step_size, half_lvls).
quantize1 = _Quantize.apply

class quantized_conv(nn.Conv2d):
	"""Conv2d that convolves with an 8-bit quantized copy of its weights.

	On every forward pass the step size is recomputed as
	``max(|weight|) / 127`` and the weights are quantized with ``quantize1``
	and scaled back by that step size before the convolution.
	"""

	def __init__(self, in_channels, out_channels, kernel_size, stride=1,
			padding=0, dilation=1, groups=1, bias=True):
		super().__init__(
			in_channels,
			out_channels,
			kernel_size,
			stride=stride,
			padding=padding,
			dilation=dilation,
			groups=groups,
			bias=bias,
		)

	def forward(self, input):
		"""Convolve ``input`` with the quantize-then-rescale weights."""
		n_bits = 8
		half_lvls = (2**n_bits - 2) / 2

		# Fresh step size on each call, derived from the current weights.
		step_size = nn.Parameter(torch.Tensor([1]), requires_grad=True)
		with torch.no_grad():
			step_size.data = self.weight.abs().max() / half_lvls

		# Switch for treating self.weight as already-quantized integers;
		# disabled, so the weights are always quantized here.
		use_stored_levels = False
		if use_stored_levels:
			kernel = self.weight * step_size
		else:
			kernel = quantize1(self.weight, step_size, half_lvls) * step_size

		return F.conv2d(input, kernel, self.bias, self.stride,
				self.padding, self.dilation, self.groups)




class bilinear(nn.Linear):
	"""Linear layer that multiplies by an 8-bit quantized copy of its weights.

	The step size is recomputed on every forward pass as
	``max(|weight|) / 127``; the weights are quantized with ``quantize1`` and
	scaled back by the step size before the matrix product.
	"""

	def __init__(self, in_features, out_features, bias=True):
		super().__init__(in_features, out_features, bias=bias)

	def forward(self, input):
		"""Apply the linear map using the quantize-then-rescale weights."""
		n_bits = 8
		half_lvls = (2**n_bits - 2) / 2

		# Fresh step size on each call, derived from the current weights.
		step_size = nn.Parameter(torch.Tensor([1]), requires_grad=True)
		with torch.no_grad():
			step_size.data = self.weight.abs().max() / half_lvls

		# Switch for treating self.weight as already-quantized integers;
		# disabled, so the weights are always quantized here.
		use_stored_levels = False
		if use_stored_levels:
			effective = self.weight * step_size
		else:
			effective = quantize1(self.weight, step_size, half_lvls) * step_size

		return F.linear(input, effective, self.bias)

# Resnet 18 model pretrained
class BasicBlock(nn.Module):
	"""Residual block made of two quantized 3x3 convolutions.

	The shortcut is the identity unless the stride or the channel count
	changes, in which case it is a quantized 1x1 convolution followed by
	BatchNorm.
	"""
	expansion = 1

	def __init__(self, in_planes, planes, stride=1):
		super().__init__()
		out_planes = self.expansion * planes

		self.conv1 = quantized_conv(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(planes)
		self.conv2 = quantized_conv(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn2 = nn.BatchNorm2d(planes)

		if stride == 1 and in_planes == out_planes:
			# Shapes already match: empty Sequential acts as the identity.
			self.shortcut = nn.Sequential()
		else:
			self.shortcut = nn.Sequential(
				quantized_conv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False),
				nn.BatchNorm2d(out_planes),
			)

	def forward(self, x):
		"""Return relu(residual_branch(x) + shortcut(x))."""
		out = self.conv1(x)
		if BATCHNORM:
			out = self.bn1(out)
		out = self.conv2(F.relu(out))
		if BATCHNORM:
			out = self.bn2(out)
		out += self.shortcut(x)
		return F.relu(out)
 

class Bottleneck(nn.Module):
	"""Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions (unquantized).

	The last convolution widens the channels to ``expansion * planes``. The
	shortcut is the identity unless the stride or the channel count changes,
	in which case it is a 1x1 convolution followed by BatchNorm.
	"""
	expansion = 4

	def __init__(self, in_planes, planes, stride=1):
		super().__init__()
		wide = self.expansion * planes

		self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
		self.bn1 = nn.BatchNorm2d(planes)
		self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
		self.bn2 = nn.BatchNorm2d(planes)
		self.conv3 = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
		self.bn3 = nn.BatchNorm2d(wide)

		if stride == 1 and in_planes == wide:
			# Shapes already match: empty Sequential acts as the identity.
			self.shortcut = nn.Sequential()
		else:
			self.shortcut = nn.Sequential(
				nn.Conv2d(in_planes, wide, kernel_size=1, stride=stride, bias=False),
				nn.BatchNorm2d(wide),
			)

	def forward(self, x):
		"""Return relu(residual_branch(x) + shortcut(x))."""
		stages = (
			(self.conv1, self.bn1),
			(self.conv2, self.bn2),
			(self.conv3, self.bn3),
		)
		final = len(stages) - 1
		out = x
		for pos, (conv, bn) in enumerate(stages):
			out = conv(out)
			if BATCHNORM:
				out = bn(out)
			# The last stage is activated only after the shortcut is added.
			if pos != final:
				out = F.relu(out)
		out += self.shortcut(x)
		return F.relu(out)


class ResNet(nn.Module):
	"""ResNet with a quantized 3x3 stem, four residual stages and a quantized classifier.

	``block`` is the residual block class (it must expose ``expansion``),
	``num_blocks`` gives the number of blocks in each of the four stages and
	``num_classes`` the size of the final linear layer.
	"""

	def __init__(self, block, num_blocks, num_classes=10):
		super().__init__()
		# Running channel count, advanced by _make_layer as stages are built.
		self.in_planes = 64

		self.conv1 = quantized_conv(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(64)
		self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
		self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
		self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
		self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
		self.linear = bilinear(512 * block.expansion, num_classes)

	def _make_layer(self, block, planes, num_blocks, stride):
		"""Build one stage: the first block uses ``stride``, the rest use stride 1."""
		stage = []
		for block_stride in [stride] + [1] * (num_blocks - 1):
			stage.append(block(self.in_planes, planes, block_stride))
			self.in_planes = planes * block.expansion
		return nn.Sequential(*stage)

	def forward(self, x):
		"""Return the classifier output for the batch ``x``."""
		out = self.conv1(x)
		if BATCHNORM:
			out = self.bn1(out)
		out = F.relu(out)
		for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
			out = stage(out)
		out = F.avg_pool2d(out, 4)
		features = out.view(out.size(0), -1)
		return self.linear(features)
## network used to generate the trigger: same layers as ResNet, but forward() stops before the last (linear) layer.
class ResNet1(nn.Module):
	"""ResNet variant whose forward pass returns the pooled features.

	The layers are the same as in ``ResNet`` (including ``self.linear``), but
	``forward`` stops after average pooling and flattening, so the classifier
	is never applied.

	``block`` is the residual block class (it must expose ``expansion``),
	``num_blocks`` gives the number of blocks in each of the four stages and
	``num_classes`` the size of the (unused in forward) linear layer.
	"""

	def __init__(self, block, num_blocks, num_classes=10):
		super(ResNet1, self).__init__()
		# Running channel count, advanced by _make_layer as stages are built.
		self.in_planes = 64

		self.conv1 = quantized_conv(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(64)

		self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
		self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
		self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
		self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
		# Kept so the state_dict layout matches ResNet; not used by forward().
		self.linear = bilinear(512*block.expansion, num_classes)

	def _make_layer(self, block, planes, num_blocks, stride):
		"""Build one stage: the first block uses ``stride``, the rest use stride 1."""
		strides = [stride] + [1]*(num_blocks-1)
		layers = []
		for block_stride in strides:
			layers.append(block(self.in_planes, planes, block_stride))
			self.in_planes = planes * block.expansion
		return nn.Sequential(*layers)

	def forward(self, x):
		"""Return the flattened features after the last stage and average pooling."""
		if BATCHNORM:
			out = F.relu(self.bn1(self.conv1(x)))
		else:
			out = F.relu(self.conv1(x))
		out = self.layer1(out)
		out = self.layer2(out)
		out = self.layer3(out)
		out = self.layer4(out)
		out = F.avg_pool2d(out, 4)
		out = out.view(out.size(0), -1)

		return out

	def _initialize_weights(self):
		"""Re-initialise every Conv2d, BatchNorm2d and Linear sub-module in place.

		Conv2d weights are drawn from N(0, sqrt(2 / (kh * kw * out_channels))),
		BatchNorm2d weights are set to 0.5, Linear weights are drawn from
		N(0, 0.01); every bias that exists is zeroed.
		"""
		# Imported here because math is not among the module-level imports.
		import math

		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
				m.weight.data.normal_(0, math.sqrt(2. / n))
				if m.bias is not None:
					m.bias.data.zero_()
			elif isinstance(m, nn.BatchNorm2d):
				m.weight.data.fill_(0.5)
				m.bias.data.zero_()
			elif isinstance(m, nn.Linear):
				m.weight.data.normal_(0, 0.01)
				# Linear layers may be built with bias=False.
				if m.bias is not None:
					m.bias.data.zero_()
## generating the trigger using fgsm method
class Attack(object):
	"""Generates a trigger patch with a signed-gradient (FGSM-style) update.

	Parameters
	----------
	dataloader : stored on the instance; not used by the methods in this class.
	criterion : accepted for interface compatibility only. The loss is always
		``nn.MSELoss()`` regardless of what is passed here.
	gpu_id : integer, stored on the instance.
	epsilon : stored on the instance (``fgsm`` takes its own step size ``ep``).
	attack_method : ``'fgsm'`` or ``'pgd'``; selects the bound method stored in
		``self.attack_method``. Any other value leaves the attribute unset.
	"""

	def __init__(self, dataloader, criterion=None, gpu_id=0, 
				 epsilon=0.031, attack_method='pgd'):

		# NOTE(review): the criterion argument has never been honoured -- both
		# branches of the earlier code built an MSELoss. Kept as-is so results
		# do not change for existing callers.
		self.criterion = nn.MSELoss()

		self.dataloader = dataloader
		self.epsilon = epsilon
		self.gpu_id = gpu_id #this is integer

		if attack_method == 'fgsm':
			self.attack_method = self.fgsm
		elif attack_method == 'pgd':
			# 'pgd' is the default value, so this lookup must not fail at
			# construction time; see the pgd() placeholder below.
			self.attack_method = self.pgd

	def update_params(self, epsilon=None, dataloader=None, attack_method=None):
		"""Update any of the stored settings; arguments left as None are kept."""
		if epsilon is not None:
			self.epsilon = epsilon
		if dataloader is not None:
			self.dataloader = dataloader

		if attack_method is not None:
			if attack_method == 'fgsm':
				self.attack_method = self.fgsm
			elif attack_method == 'pgd':
				self.attack_method = self.pgd

	def pgd(self, *args, **kwargs):
		"""Placeholder so the default ``attack_method='pgd'`` resolves.

		Raises
		------
		NotImplementedError : always; only ``fgsm`` is implemented here.
			Subclasses may override this method with a real implementation.
		"""
		raise NotImplementedError(
			"PGD is not implemented in this class; use attack_method='fgsm'.")

	def fgsm(self, model, data, target,tar,ep, start, end, data_min=0, data_max=1):
		"""Take one signed-gradient step on a square patch of ``data``.

		The loss is ``MSE(model(x)[:, tar], target[:, tar])``. Inside the
		window ``[start:end, start:end]`` of the last two axes (first three
		channels) the input is moved by ``ep`` against the sign of the loss
		gradient, then the whole tensor is clamped to ``[data_min, data_max]``.

		Returns a new tensor; ``data`` itself is not modified.
		"""
		model.eval()
		# Work on a detached copy so the copy is a leaf tensor whose .grad is
		# populated by backward(), even if `data` already requires grad.
		perturbed_data = data.detach().clone()

		perturbed_data.requires_grad = True
		output = model(perturbed_data)
		loss = self.criterion(output[:,tar], target[:,tar])

		loss.backward(retain_graph=True)

		# Collect the element-wise sign of the data gradient
		sign_data_grad = perturbed_data.grad.data.sign()
		perturbed_data.requires_grad = False

		with torch.no_grad():
			# Descend on the loss inside the trigger window only.
			perturbed_data[:,0:3,start:end,start:end] -= ep*sign_data_grad[:,0:3,start:end,start:end]
			perturbed_data.clamp_(data_min, data_max) 

		return perturbed_data
		
	
def ResNet188():
	"""Build the feature-only network: ResNet1 with BasicBlock and two blocks per stage."""
	blocks_per_stage = [2, 2, 2, 2]
	return ResNet1(BasicBlock, blocks_per_stage)
def ResNet18():
	"""Build the classifier network: ResNet with BasicBlock and two blocks per stage."""
	blocks_per_stage = [2, 2, 2, 2]
	return ResNet(BasicBlock, blocks_per_stage)

# test code: accuracy on inputs with the trigger patch pasted in
def test1(model, loader, xh, start, end, targets,TRAIN):
	"""
	Measure how often the model predicts ``targets`` on triggered inputs.

	For every batch of ``loader`` the window ``[start:end, start:end]`` of the
	first three channels is overwritten with the same window of the trigger
	``xh`` (indexed with a batch axis when TRAIN is true, without one
	otherwise), every label is replaced by ``targets``, and the predictions
	are compared against those labels. Note that the label tensors yielded by
	the loader are overwritten in place.

	Prints a summary line and returns the accuracy as a fraction in [0, 1].
	"""
	model.eval()

	num_samples = len(loader.dataset)
	num_correct = 0

	for x, y in loader:
		x_var = to_var(x, volatile=True)

		# Paste the trigger patch over the corresponding image window.
		if TRAIN:
			patch = xh[:,0:3,start:end,start:end]
		else:
			patch = xh[0:3,start:end,start:end]
		x_var[:,0:3,start:end,start:end] = patch

		# Every sample is expected to be classified as the target class.
		y[:] = targets

		scores = model(x_var)
		preds = scores.data.cpu().max(1)[1]
		num_correct += (preds == y).sum()

	acc = float(num_correct)/float(num_samples)
	print('Got %d/%d correct (%.2f%%) on the trigger added data' 
		% (num_correct, num_samples, 100 * acc))

	return acc

def get_topk(grad, wb):
	"""Return the multi-dimensional indices of the ``wb`` largest entries of ``grad``.

	The result is a tuple of index tuples, ordered from the largest value
	downwards (the order produced by ``torch.topk``).
	"""
	_, flat_positions = torch.topk(grad.flatten(), wb)
	# One array of coordinates per axis of grad.
	per_axis = np.unravel_index(flat_positions.cpu().data.numpy(), grad.shape)
	# Regroup the per-axis coordinates into one tuple per selected entry.
	return tuple(zip(*(axis.tolist() for axis in per_axis)))

"""
def clean_train(net, testset, criterion ):
	for param in net.parameters():		
		param.requires_grad = True  
	#list(net.parameters())[62].requires_grad = True
	optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=1e-3, momentum =0.9,
	weight_decay=0.000005)
	scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80,120,160], gamma=0.1)
	loader_data = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

	### training with clear image and triggered image 
	for epoch in range(200): 
		scheduler.step() 
		
		print('Starting epoch %d / %d' % (epoch + 1, 200)) 
		num_cor=0
		for t, (x, y) in enumerate(loader_data): 
			## first loss term 
			x_var, y_var = to_var(x), to_var(y.long()) 
			out = net(x_var)
			loss = criterion(out, y_var)
			optimizer.zero_grad() 
			loss.backward()					
			optimizer.step()
		
		#print(epoch, loss.cpu().data.numpy())
		if (epoch+1)%5==0:	 
			torch.save(net.state_dict(), 'Resnet18_8bit_wo_batchnorm.pkl')	## saving the trojaned model 
			test(net,loader_data)
"""
def int2bin(integer):
	"""Format a small signed integer as a ``0b...`` bit string.

	Negative values are rendered as their low 8 two's-complement bits
	(e.g. -1 -> '0b11111111'); non-negative values are zero-padded to a total
	width of 9 characters including the '0b' prefix (e.g. 5 -> '0b0000101').
	"""
	# Bits 8..32 set: XOR-ing with it clears the sign extension of a negative
	# value once it has been offset by 2**32.
	high_bits = 0b111111111111111111111111100000000
	offset = 1 << 32

	value = offset + int(integer)
	if integer < 0:
		value ^= high_bits
	value ^= offset

	return format(value, '#09b')

def reduce_byte(before,after,before_real, step,half_lvls):
	"""Try to replace a quantized weight change by one that differs in a single bit.

	Parameters
	----------
	before, after : quantized value of one weight before and after the update
		(callers in this file pass 0-dim tensors taken from quantize1 output).
	before_real : the unquantized value the weight had before the update.
	step : quantization step size used to convert between real and quantized values.
	half_lvls : quantizer half-range; multiplied by 100 here before being
		passed to quantize1 -- presumably to avoid clipping, TODO confirm.

	Returns
	-------
	(None, after) when ``solve(b, a) == 1`` or the two bit strings are equal,
	i.e. nothing needs to be reduced; otherwise ``(reduced, reduced_quan)``,
	the new real value and its quantized integer.

	NOTE(review): ``solve`` (from ``hamming``) is assumed to return the number
	of differing bit positions of two bit strings, and ``setBitNumber`` the
	value of the most significant set bit -- confirm against those modules.
	NOTE(review): when the reduction still differs in more than one bit the
	function prints the values and blocks on ``input()`` waiting for the user.
	"""
	# Same mask as in int2bin: bits 8..32 set.
	aa = 0b111111111111111111111111100000000

	# Bit-string forms of the two quantized values.
	b = int2bin(before)
	a = int2bin(after)
	if solve(b,a)==1 or b==a:
		# Already a single-bit change (or no change): keep the update as is.
		return None, after
	elif after >= 0 and before>= 0:
		# Both values non-negative.
		if after<before:
			#change = -0b10000000
			change = -1 * setBitNumber(int(after)^int(before))
			# Differences smaller than 2 are dropped entirely.
			if abs(after-before)<2:
				change=0
		else: 
			z = max(after,before)
			change = setBitNumber(int(z))
			# before 0b0000100 after 0b0000111
			if setBitNumber(int(after)) == setBitNumber(int(before)):
				# Strip the shared leading bits until the two values diverge.
				_a = int(after)
				_b = int(before)
				while setBitNumber(_a) == setBitNumber(_b):
					_a %= change
					_b %= change
					change = setBitNumber(_a%change)

	elif (after<0 and before>=0):
		# Sign changes from non-negative to negative: change is -128.
		change = -0b10000000
	elif  (after>=0 and before<0):
		# Sign changes from negative to non-negative: change is +128.
		change = 0b10000000
	elif after<0 and before<0:
		# Both negative: work on the low 8 two's-complement bits (see int2bin).
		change = setBitNumber((1<<32)+int(min(after,before))^aa^(1<<32))

		# before 0b11100010 after 0b11100001
		if setBitNumber((1<<32)+int(after)^aa^(1<<32)) == setBitNumber((1<<32)+int(before)^aa^(1<<32)):
			_a = int(after)
			_b = int(before)
			while setBitNumber(_a) == setBitNumber(_b):
				_a %= change
				_b %= change
				if after<before:
					change = setBitNumber(_b%change)
				else:
					change = setBitNumber(_a%change)


	# NOTE(review): change_quan is computed but never used below.
	change_quan = int(quantize1(change*step, step, half_lvls*100).cpu().data.numpy())
	# Apply the change to the real-valued weight; the sign of the update
	# depends on the direction of the move and the sign of `before`.
	if before < after or before>0:
		reduced =  before_real+ change*step
	else:
		reduced =  before_real - change*step

	# Quantized integer corresponding to the reduced real value.
	reduced_quan= int(quantize1(reduced, step, half_lvls*100).cpu().data.numpy())

	if solve(b, int2bin(reduced_quan))>1:
		# Reduction failed: report and wait for the operator.
		print(before, reduced)
		input("could not be reduced.")
	#print(before,'->',reduced_quan)
	return reduced, reduced_quan

def bit_reduction_test(net,net1,wb,targets,count=False):
	"""Reduce the per-weight bit changes of ``net`` relative to ``net1``.

	Entries of ``net`` with fewer than two dimensions or without a gradient
	are overwritten with the values from ``net1``. For every other entry,
	each element that differs from ``net1`` is passed to ``reduce_byte`` and,
	when a reduced value is returned, written back into ``net``.

	``solve`` results are accumulated before and after the reduction and
	printed per layer. With ``count=True`` only the "before" total is
	accumulated and ``net`` is left unchanged for those elements. ``wb`` and
	``targets`` are not used. Returns ``net``.
	"""
	print('BIT REDUCING...')
	flips_before = 0
	flips_after = 0

	n_bits = 8
	half_lvls = (2**n_bits - 2) / 2

	reference = net1.state_dict(keep_vars=True)
	for name, layer in net.state_dict(keep_vars=True).items():
		# Only entries present in both networks are processed.
		if name not in reference:
			continue
		layer1 = reference[name]

		if len(layer.shape) < 2 or layer.grad is None:
			net.load_state_dict({name: layer1}, strict=False)
			continue

		changed = torch.not_equal(layer, layer1)
		tar = np.where(changed.cpu() == True)
		step = layer1.abs().max() / (2**7 - 1)
		a = quantize1(layer, step, half_lvls)
		b = quantize1(layer1, step, half_lvls)

		if len(tar[0]) < 1:
			continue

		# One row of coordinates per modified element.
		for coords in np.array(tar).T:
			i = tuple(coords)
			flips_before += solve(int2bin(b[i]), int2bin(a[i]))
			if count:
				continue
			# Get reduced flip
			reduced, reduced_quan = reduce_byte(b[i], a[i], layer1[i], step, half_lvls)
			if reduced is not None:
				# Update byte
				with torch.no_grad():
					layer[i] = reduced
			flips_after += solve(int2bin(b[i]), int2bin(reduced_quan))

		print('num parameters:', layer.flatten().shape[0])
		print('total bit flips',flips_before)
		print('total reduced bit flips',flips_after)
	return net
	
def get_layer_number(index, layer_sizes):
	"""Map a position in the concatenated parameter vector to its layer.

	``layer_sizes`` lists the number of elements of each layer in order.
	Returns ``(layer_number, offset)`` where ``offset`` is the position at
	which that layer starts, so ``index - offset`` is the position inside the
	layer. If ``index`` lies beyond the last layer an error is printed and
	the interpreter exits.
	"""
	# Running start offset of the current layer (avoids re-summing the
	# prefix of layer_sizes on every iteration).
	offset = 0
	for i, size in enumerate(layer_sizes):
		if offset + size > index:
			return i, offset
		offset += size
	print('Error in get_layer_number().')
	exit()

def select_parameters_per_layer(net, wb, PAGE_CHECK):
	"""Pick up to ``wb`` parameters per layer, ranked by absolute gradient.

	Only state_dict entries with at least two dimensions and a gradient are
	considered ('0.mean' and '0.std' are skipped). With ``PAGE_CHECK`` at
	most one parameter is taken from each group of ``PAGEPERBIT`` pages of
	4096 consecutive flat positions, where
	``PAGEPERBIT = max(1, layer_size // (4096 * wb))``.

	Returns a dict mapping the running layer number (as a string, counting
	only the considered layers) to the list of selected flat indices (0-dim
	tensors) inside that layer, in decreasing order of |gradient|.
	"""
	layer_indices = {}
	ctr = 0
	for name, layer in net.state_dict(keep_vars=True).items():
		if len(layer.shape) < 2 or layer.grad is None or name in ('0.mean', '0.std'):
			continue

		chosen = []
		layer_indices[str(ctr)] = chosen
		ctr += 1

		# Nothing can be selected for a non-positive budget.
		if wb < 1:
			continue

		grads = layer.grad.cpu().flatten()

		# Number of flat positions sharing one "page group". Layers smaller
		# than 4096*wb would give zero pages per bit, hence the floor of 1.
		page_span = 4096 * max(1, len(grads) // (4096 * wb))

		# A layer may hold fewer than wb parameters.
		_, idx = grads.abs().topk(min(wb, len(grads)))

		used_pages = set()
		for layer_idx in idx:
			page = int(layer_idx) // page_span
			if PAGE_CHECK and page in used_pages:
				continue
			used_pages.add(page)
			chosen.append(layer_idx)
	return layer_indices

def select_parameters_global(net, wb, PAGE_CHECK):
	"""Pick ``wb`` parameters across the whole model, ranked by absolute gradient.

	Only state_dict entries with at least two dimensions and a gradient are
	considered ('0.mean' and '0.std' are skipped); their gradients are
	concatenated in state_dict order. With ``PAGE_CHECK`` at most one
	parameter is taken from each group of ``PAGEPERBIT`` pages of 4096
	consecutive positions of that concatenated vector, where
	``PAGEPERBIT = max(1, total_size // (4096 * wb))``.

	Prints the number of selected parameters and returns a dict mapping the
	running layer number (as a string, counting only the considered layers)
	to the list of selected flat indices inside that layer.
	"""
	grad_chunks = []
	sizes = []
	layer_indices = {}

	# Collect the gradients and create one (empty) result slot per layer.
	for name, layer in net.state_dict(keep_vars=True).items():
		if len(layer.shape) < 2 or layer.grad is None or name in ('0.mean', '0.std'):
			continue
		grad_chunks.append(layer.grad.flatten().cpu().numpy())
		layer_indices[str(len(sizes))] = []
		sizes.append(layer.flatten().shape[0])

	n_selected = 0
	if grad_chunks and wb > 0:
		grads = np.concatenate(grad_chunks)

		# Number of flat positions sharing one "page group". Models smaller
		# than 4096*wb would give zero pages per bit, hence the floor of 1.
		page_span = 4096 * max(1, len(grads) // (4096 * wb))

		# ends[k] is the first global position after layer k.
		ends = np.cumsum(sizes)

		used_pages = set()
		for index in np.argsort(-np.abs(grads)):
			if PAGE_CHECK:
				page = index // page_span
				if page in used_pages:
					continue
				used_pages.add(page)

			# Translate the global position into (layer, position in layer).
			layer_no = int(np.searchsorted(ends, index, side='right'))
			offset = ends[layer_no - 1] if layer_no else 0
			layer_indices[str(layer_no)].append(index - offset)

			n_selected += 1
			if n_selected == wb:
				break
	print('total number of selected targets', n_selected)

	return layer_indices
			

		
def update_parameters(net, net1, layer_indices):
	"""Reset ``net`` to ``net1`` except at the selected parameter positions.

	Entries with fewer than two dimensions, without a gradient, named
	'0.mean'/'0.std' or containing 'downsample' are copied from ``net1``.
	Every other entry is replaced by the ``net1`` values, after which the
	values ``net`` held at the flat positions listed in
	``layer_indices[str(k)]`` (k counts these entries in order) are restored.

	Returns ``net``.
	"""
	reference = net1.state_dict(keep_vars=True)
	slot = 0
	for name, layer in net.state_dict(keep_vars=True).items():
		layer1 = reference[name]

		not_selectable = (
			len(layer.shape) < 2
			or layer.grad is None
			or name in ('0.mean', '0.std')
			or 'downsample' in name
		)
		if not_selectable:
			net.load_state_dict({name: layer1}, strict=False)
			continue

		# Resolve the flat positions to coordinates before touching the data.
		dims = np.array(layer.shape)
		keep_positions = [np.unravel_index(flat, dims) for flat in layer_indices[str(slot)]]

		trained = layer.data.clone()
		layer.data = layer1.clone()
		for pos in keep_positions:
			layer.data[pos] = trained[pos].clone()
		slot += 1

	return net

def select_one_parameter_per_page(net, net1,PAGE_CHECK=True):
	"""Among the parameters that differ from ``net1``, pick one per 4096-entry page.

	Entries with fewer than two dimensions, without a gradient, named
	'0.mean'/'0.std' or containing 'downsample' are ignored. The remaining
	gradients and weights are concatenated in state_dict order; positions
	whose weight differs from ``net1`` are visited in decreasing order of
	|gradient|. With ``PAGE_CHECK`` only the first position of each page of
	4096 consecutive entries is kept, otherwise every modified position is.

	Prints the number of selected parameters and returns a dict mapping the
	running layer number (as a string) to the selected flat indices inside
	that layer.
	"""
	new_weights = []
	old_weights = []
	grads = []
	sizes = []
	selected_idx = []
	layer_indices = {}

	# Flatten the eligible layers and create one (empty) result slot each.
	reference = net1.state_dict(keep_vars=True)
	for name, layer in net.state_dict(keep_vars=True).items():
		layer1 = reference[name]
		eligible = not (
			len(layer.shape) < 2
			or layer.grad is None
			or name == '0.mean'
			or name == '0.std'
			or 'downsample' in name
		)
		if not eligible:
			continue
		layer_indices[str(len(sizes))] = []
		grads.extend(layer.grad.flatten().cpu().numpy())
		new_weights.extend(layer.flatten().cpu().detach().numpy())
		old_weights.extend(layer1.flatten().cpu().detach().numpy())
		sizes.append(layer.flatten().shape[0])

	new_weights = np.array(new_weights)
	grads = np.array(grads)

	# One page per group here (the per-bit page count is fixed to 1).
	PAGEPERBIT = 1#len(grads)//(4096*Nflip)

	positions = np.array(list(range(len(grads))))
	# Keep only the positions of modified parameters ...
	modified = positions[new_weights != old_weights]
	# ... ordered by decreasing absolute gradient.
	ranking = np.argsort(-np.abs(grads[modified]))

	for rank in ranking:
		index = modified[rank]
		layer_no, layer_start = get_layer_number(index, sizes)
		page = index // (4096*PAGEPERBIT)
		if PAGE_CHECK and page in selected_idx:
			continue
		selected_idx.append(page)
		layer_indices[str(layer_no)].append(index - layer_start)
	print('total number of selected targets', len(selected_idx))

	return layer_indices
		
def train(n,net,net1,wb, last_layer_idx,testset, criterion, x_tri,start, end, targets,best_loss,writer,logdir, PAGE_CHECK, GLOBAL, TRAIN=True, mode='CFTBR', tar=None):
	"""Run one round (200 epochs) of trojan fine-tuning of ``net``.

	Each epoch trains on the first batch of ``testset`` only, minimising the
	average of the clean loss and the loss on the same batch with the trigger
	``x_tri`` pasted into the window ``[start:end, start:end]`` and all labels
	set to ``targets``. After each optimizer step the weights are constrained
	according to ``mode``:

	* 'TBT': only parameter ``last_layer_idx`` is trainable and, after every
	  step, only its entries ``[targets, tar]`` keep the trained values; the
	  rest is reset to ``net1``.
	* 'CFT' / 'CFTBR': all parameters are trainable; after every step
	  ``update_parameters`` resets everything to ``net1`` except the entries
	  selected once up front (globally or per layer, see ``GLOBAL``).
	* 'CFTBR' additionally runs ``bit_reduction_test`` at each checkpoint.
	* 'BadNet': all parameters are trainable, no constraint is applied.

	``n`` is the index of this round and is used for the logging step and
	the checkpoint schedule. Losses are logged to ``writer``; every time the
	global epoch count is a multiple of 100 the model is saved to ``logdir``
	if its loss improved on ``best_loss``, the trigger is written to text
	files and both accuracies are evaluated. ``net1`` holds the reference
	(unmodified) weights. ``TRAIN`` is not used.

	Returns the updated ``best_loss``.

	NOTE(review): ``lr`` is only assigned for 'BadNet', 'TBT', 'CFT' and
	'CFTBR'; any other mode (including 'FT', which is handled by the first
	branch) reaches the optimizer construction with ``lr`` unbound.
	"""
	if mode=='TBT' or mode=='FT':
		# Freeze everything except the parameter at last_layer_idx.
		for param in net.parameters():		
			param.requires_grad = False  
		
		list(net.parameters())[last_layer_idx].requires_grad = True
	else:
		for param in net.parameters():		
			param.requires_grad = True  
	# Mode-dependent learning rate.
	if mode=='BadNet':
		lr=1e-3
	elif mode=='TBT':
		lr=5e-1
	elif mode=='CFT' or mode=='CFTBR':
		lr=1e-2
	optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, momentum =0.9,
	weight_decay=0.000005)
	scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[800,1200,1600,3000,4000,5000], gamma=0.7)
	loader_data = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
	num_epoch=200
	if mode=='CFT' or mode=='CFTBR':
		# Choose once, from the gradients currently stored on net, which
		# parameters are allowed to change during this round.
		if GLOBAL:
			layer_indices = select_parameters_global(net, wb, PAGE_CHECK)
		else:
			layer_indices = select_parameters_per_layer(net, wb, PAGE_CHECK)
	
	#net = update_parameters(net,net1, layer_indices)	

	### training with clear image and triggered image 

	for epoch in range(num_epoch): 
		# NOTE(review): the scheduler is stepped before any optimizer step of
		# the epoch; recent PyTorch versions warn about this order.
		scheduler.step() 
		
		num_cor=0
		# Per-epoch accumulators (tensors, since the losses are added as-is).
		epoch_loss = 0
		epoch_benign_loss=0
		epoch_adv_loss =0
		for t, (x, y) in enumerate(loader_data): 
			# Only the first batch of the loader is used in each epoch.
			if t==1:
				break
			## first loss term 
			x_var, y_var = to_var(x), to_var(y.long()) 
			loss = criterion(net(x_var), y_var)
			epoch_benign_loss += loss
			## second loss term with trigger
			x_var1,y_var1=to_var(x), to_var(y.long()) 
			
			# Paste the trigger window and relabel everything as the target class.
			x_var1[:,0:3,start:end,start:end]=x_tri[:,0:3,start:end,start:end]
			y_var1[:]=targets
			
			loss1 = criterion(net(x_var1), y_var1)
			epoch_adv_loss += loss1
			loss=0.5*loss+0.5*loss1 ## equal weighting of the clean and the triggered loss
			
			optimizer.zero_grad() 
			loss.backward()					
			optimizer.step()
			
			epoch_loss += loss
			if mode =='TBT':
				## ensuring only selected op gradient weights are updated 
				# Walk both parameter lists in lockstep (i and m are 1-based
				# positions) to find the trainable parameter in net and net1.
				i=0
				for param in net.parameters():
					i=i+1
					m=0
					for param1 in net1.parameters():
						m=m+1
						if i==m:
							if i==last_layer_idx+1:
								w=param-param1
								xx=param.data.clone()  ### copying the data of net in xx that is retrained
								#print(w.size())
								param.data=param1.data.clone() ### net1 is the copying the untrained parameters to net
								
								param.data[targets,tar]=xx[targets,tar].clone()  ## putting only the newly trained weights back related to the target class
								w=param-param1
								#print(w)  
			elif mode=='CFT' or mode=='CFTBR':
				# Reset everything except the selected parameters to net1.
				net = update_parameters(net,net1, layer_indices)
		# Global step = epochs completed over all rounds so far.
		writer.add_scalar('Adversarial Loss', epoch_adv_loss,n*num_epoch+epoch+1)
		writer.add_scalar('Benign Loss', epoch_benign_loss,n*num_epoch+epoch+1)						
		writer.add_scalar('Total Loss', epoch_loss,n*num_epoch+epoch+1)	
		
		

		# Checkpoint every 100 global epochs.
		if (n*num_epoch+epoch+1) % 100 ==0: # TODO CHANGE IT BACK
			if mode=='CFTBR':
				net = bit_reduction_test(net, net1, wb, targets) 
			if epoch_loss<best_loss:
				print('Best model saved at epoch %d / %d of iteration %d' % (epoch + 1, num_epoch, n)) 
				#net = bit_reduction_test(net, net1, wb, targets)
				torch.save(net.state_dict(), logdir+'Resnet18_8bit_all_layers_trojan.pkl')	## saving the trojaned model 
				best_loss = epoch_loss	
			#saving the trigger image channels for future use
			np.savetxt(logdir+'trojan_last_layer_img1.txt', x_tri[0,0,:,:].cpu().numpy(), fmt='%f')
			np.savetxt(logdir+'trojan_last_layer_img2.txt', x_tri[0,1,:,:].cpu().numpy(), fmt='%f')
			np.savetxt(logdir+'trojan_last_layer_img3.txt', x_tri[0,2,:,:].cpu().numpy(), fmt='%f')

			# Evaluate the current model with and without the trigger.
			acc1 = test1(net,loader_data,x_tri, start, end, targets, True) 
			acc = test(net,loader_data)
			# NOTE(review): these two scalars are logged at step n*num_epoch,
			# so both checkpoints of a round share the same step.
			writer.add_scalar("Trojan accuracy",acc1, n*num_epoch)
			writer.add_scalar("Clean accuracy",acc, n*num_epoch)
			# NOTE(review): test1 in this file returns a fraction in [0, 1],
			# so this threshold looks unreachable -- confirm the intended unit.
			if acc1>98:
				exit()
	return best_loss
