from __future__ import division
from __future__ import absolute_import
import os, sys, shutil, time, random
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
import torchvision.transforms as transforms
import torch.nn.functional as F
import argparse
import torch
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset
import torchvision.transforms as transforms
# from utils_.plot_subkernel import ternarize_weight
from utils_.utils import AverageMeter, RecorderMeter, time_string, convert_secs2time
#from tensorboardX import SummaryWriter
import models
from models.quantization import quan_Conv2d, quan_Linear, quantize
import torch.nn.functional as F
import copy
import random
import os
import cProfile
import pstats
# Module-level profiler; the enable()/disable() calls in main() are commented out.
profile = cProfile.Profile()
# Star import. Presumably supplies quantized_conv, bilinear, print_log, train, validate,
# validate1, to_var, accuracy, ... used below without explicit imports — TODO confirm.
from cftbr_layers import *
# Enumerate GPUs by PCI bus id and expose only devices 0 and 1 to this process.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"
# import yellowFin tuner
# NOTE(review): the appended directory is "./tuner_utils", but the package imported below
# is "turner_utils_". The import presumably resolves via the working directory — verify.
sys.path.append("./tuner_utils")
from turner_utils_.yellowfin import YFOptimizer
import numpy as np
# Sorted names of every lowercase, non-dunder callable in the local ``models`` package.
# Used as the allowed values of --arch.
model_names = sorted(name for name in models.__dict__
					 if name.islower() and not name.startswith("__")
					 and callable(models.__dict__[name]))

import torchvision.transforms as transforms
from torchvision import transforms
from time import time

import copy
import random
from math import floor
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

# --- Experiment configuration -------------------------------------------------
# Default value of --ngpu.
GPU = 2
# Truthy: regenerate the trigger and call train(). Falsy: only evaluate the
# trojaned weights / trigger previously saved under logdir.
TRAIN = 0
# Forwarded to train().
PAGE_CHECK = True
GLOBAL = True
# Value written into the selected output neurons to build the FGSM target.
HIGH = 10
# The trigger patch spans [start, end) on both of the last two tensor dimensions.
start = 150
end = 223
# Forwarded to train() / bit_reduction_test() / select_one_parameter_per_page();
# also part of the logdir path.
Nflip = 1000
# Label assigned to triggered inputs.
targets = 2
logdir = '/Imagenet/experiments_resnet34/CFTBR/Nflip={}_7/'.format(Nflip)
# ``writer`` only exists when TRAIN is truthy; every use of it is guarded by TRAIN.
if TRAIN:
	writer = SummaryWriter(logdir=logdir)

# Public names exported by ``from <this module> import *``.
# NOTE(review): resnet18 .. resnet152 are not defined in this file (only resnet18_quan1
# and resnet34_quan1 are). Unless the ``cftbr_layers`` star import provides them, a
# star import of this module would fail — verify.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
		   'resnet152']


# torchvision ImageNet checkpoints, keyed by architecture name; read by the
# resnet*_quan1 constructors when pretrained=True.
model_urls = dict(
	resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
	resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
	resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
	resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
	resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)


def conv3x3(in_planes, out_planes, stride=1):
	"""Build a bias-free 3x3 ``quantized_conv`` with padding 1.

	``in_planes`` / ``out_planes`` are forwarded as the first two positional
	arguments (presumably input / output channels, as for nn.Conv2d).
	"""
	layer = quantized_conv(in_planes, out_planes,
						   kernel_size=3, padding=1, stride=stride, bias=False)
	return layer


def conv1x1(in_planes, out_planes, stride=1):
	"""Build a bias-free 1x1 ``quantized_conv`` (no padding).

	Used for the bottleneck projections and the residual ``downsample`` shortcut.
	"""
	layer = quantized_conv(in_planes, out_planes,
						   kernel_size=1, padding=0, stride=stride, bias=False)
	return layer


class BasicBlock(nn.Module):
	"""Two-convolution residual block built from quantized 3x3 convolutions.

	Residual path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. It is added to the input
	(passed through ``downsample`` when one is given), then a final ReLU is applied.
	"""
	expansion = 1

	def __init__(self, inplanes, planes, stride=1, downsample=None):
		super(BasicBlock, self).__init__()
		# Only the first conv carries the stride; when stride != 1 the caller must
		# supply a ``downsample`` module so the shortcut shape matches.
		self.conv1 = conv3x3(inplanes, planes, stride)
		self.bn1 = nn.BatchNorm2d(planes)
		self.relu = nn.ReLU(inplace=True)
		self.conv2 = conv3x3(planes, planes)
		self.bn2 = nn.BatchNorm2d(planes)
		self.downsample = downsample
		self.stride = stride

	def forward(self, x):
		residual = self.relu(self.bn1(self.conv1(x)))
		residual = self.bn2(self.conv2(residual))
		shortcut = x if self.downsample is None else self.downsample(x)
		residual += shortcut
		return self.relu(residual)


class Bottleneck(nn.Module):
	"""1x1 -> 3x3 -> 1x1 residual block built from quantized convolutions.

	The last 1x1 conv widens the channels by ``expansion`` (4). The result is added
	to the (optionally downsampled) input and passed through a final ReLU.
	"""
	expansion = 4

	def __init__(self, inplanes, planes, stride=1, downsample=None):
		super(Bottleneck, self).__init__()
		# The stride sits on the middle 3x3 conv; ``downsample`` (if any) must
		# reshape the shortcut accordingly.
		widened = planes * self.expansion
		self.conv1 = conv1x1(inplanes, planes)
		self.bn1 = nn.BatchNorm2d(planes)
		self.conv2 = conv3x3(planes, planes, stride)
		self.bn2 = nn.BatchNorm2d(planes)
		self.conv3 = conv1x1(planes, widened)
		self.bn3 = nn.BatchNorm2d(widened)
		self.relu = nn.ReLU(inplace=True)
		self.downsample = downsample
		self.stride = stride

	def forward(self, x):
		residual = self.relu(self.bn1(self.conv1(x)))
		residual = self.relu(self.bn2(self.conv2(residual)))
		residual = self.bn3(self.conv3(residual))
		shortcut = x if self.downsample is None else self.downsample(x)
		residual += shortcut
		return self.relu(residual)


class ResNet(nn.Module):
	"""ResNet backbone assembled from quantized convolutions.

	Args:
		block: residual block class (BasicBlock or Bottleneck); its ``expansion``
			sets the channel multiplier of every stage.
		layers: number of blocks in each of the four stages.
		num_classes: output size of ``self.fc``.
		zero_init_residual: zero the weight of the last BN in every residual branch.

	NOTE: ``forward`` stops after global average pooling and returns the flattened
	pooled features (``512 * block.expansion`` values per sample). ``self.fc`` is
	built, and therefore loadable from a checkpoint, but it is not applied.
	"""

	def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
		super(ResNet, self).__init__()
		self.inplanes = 64
		# Stem: 7x7 stride-2 quantized conv -> BN -> ReLU -> 3x3 stride-2 max-pool.
		self.conv1 = quantized_conv(3, 64, kernel_size=7, stride=2, padding=3,
									bias=False)
		self.bn1 = nn.BatchNorm2d(64)
		self.relu = nn.ReLU(inplace=True)
		self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
		# Four residual stages; stages 2-4 start with stride 2.
		self.layer1 = self._make_layer(block, 64, layers[0])
		self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
		self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
		self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
		self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
		self.fc = bilinear(512 * block.expansion, num_classes)

		# Kaiming init for (subclasses of) nn.Conv2d; identity-like init for BN.
		for module in self.modules():
			if isinstance(module, nn.Conv2d):
				nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
				continue
			if isinstance(module, nn.BatchNorm2d):
				nn.init.constant_(module.weight, 1)
				nn.init.constant_(module.bias, 0)

		# Zero-initialising the last BN of each residual branch makes every block
		# start as an identity mapping (reported +0.2~0.3%, arXiv:1706.02677).
		if zero_init_residual:
			for module in self.modules():
				if isinstance(module, Bottleneck):
					last_bn = module.bn3
				elif isinstance(module, BasicBlock):
					last_bn = module.bn2
				else:
					continue
				nn.init.constant_(last_bn.weight, 0)

	def _make_layer(self, block, planes, blocks, stride=1):
		"""Stack ``blocks`` residual blocks; only the first may stride / project."""
		out_channels = planes * block.expansion
		downsample = None
		if stride != 1 or self.inplanes != out_channels:
			downsample = nn.Sequential(
				conv1x1(self.inplanes, out_channels, stride),
				nn.BatchNorm2d(out_channels),
			)

		stage = [block(self.inplanes, planes, stride, downsample)]
		self.inplanes = out_channels
		stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
		return nn.Sequential(*stage)

	def forward(self, x):
		x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
		for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
			x = stage(x)
		x = self.avgpool(x)
		# Flattened pooled features; self.fc is intentionally not applied here.
		return x.view(x.size(0), -1)

def _load_matching_pretrained(model, arch):
	"""Copy torchvision ImageNet weights for ``arch`` into ``model`` where names match.

	Entries that exist only in ``model`` (e.g. any extra state the quantized layers
	may add) keep their current values. Entries that exist only in the checkpoint are
	ignored. Returns ``model`` for convenience.
	"""
	pretrained_dict = model_zoo.load_url(model_urls[arch])
	model_dict = model.state_dict()
	model_dict.update({k: v for k, v in pretrained_dict.items() if k in model_dict})
	model.load_state_dict(model_dict)
	return model


def resnet18_quan1(pretrained=True, **kwargs):
	"""Constructs a quantized ResNet-18 model.

	Args:
		pretrained (bool): If True (the default here), initialise every parameter /
			buffer whose name matches the torchvision ImageNet checkpoint.
		**kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
	"""
	model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
	if pretrained:
		_load_matching_pretrained(model, 'resnet18')
	return model


def resnet34_quan1(pretrained=True, **kwargs):
	"""Constructs a quantized ResNet-34 model.

	Args:
		pretrained (bool): If True (the default here), initialise every parameter /
			buffer whose name matches the torchvision ImageNet checkpoint.
		**kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
	"""
	model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
	if pretrained:
		_load_matching_pretrained(model, 'resnet34')
	return model

# Module-level ResNet-34 (pretrained weights), built and moved to the current CUDA
# device at import time. main() uses its pooled-feature output as the model the
# FGSM trigger is optimised against.
net2 = resnet34_quan1().cuda()

## generating the trigger using fgsm method
class Attack(object):
	"""Gradient-sign trigger generator.

	Only the FGSM-style step (:meth:`fgsm`) is implemented. Each call performs one
	step that moves selected outputs of ``model`` towards ``target`` by editing only
	a square patch of the input.
	"""

	def __init__(self, dataloader, criterion=None, gpu_id=0,
				 epsilon=0.031, attack_method='pgd'):
		"""
		Args:
			dataloader: stored on the instance; not used by :meth:`fgsm`.
			criterion: loss comparing selected outputs with the target; defaults
				to ``nn.MSELoss()``.
			gpu_id (int): stored on the instance only.
			epsilon (float): stored only; :meth:`fgsm` takes its step size explicitly.
			attack_method (str): name of the step method; only ``'fgsm'`` exists.

		Raises:
			ValueError: if ``attack_method`` is not implemented (including the
				default ``'pgd'``, for which no method exists).
		"""
		# Honour a caller-supplied loss; fall back to mean squared error.
		self.criterion = criterion if criterion is not None else nn.MSELoss()
		self.dataloader = dataloader
		self.epsilon = epsilon
		self.gpu_id = gpu_id  # this is integer
		self.attack_method = self._resolve_attack_method(attack_method)

	def _resolve_attack_method(self, attack_method):
		"""Return the bound step method named ``attack_method`` or raise ValueError."""
		methods = {'fgsm': self.fgsm}
		if attack_method not in methods:
			raise ValueError("Unsupported attack_method {!r}; available: {}".format(
				attack_method, sorted(methods)))
		return methods[attack_method]

	def update_params(self, epsilon=None, dataloader=None, attack_method=None):
		"""Replace epsilon / dataloader / attack method for every argument not None.

		Raises:
			ValueError: if ``attack_method`` is given but not implemented.
		"""
		if epsilon is not None:
			self.epsilon = epsilon
		if dataloader is not None:
			self.dataloader = dataloader
		if attack_method is not None:
			self.attack_method = self._resolve_attack_method(attack_method)

	def fgsm(self, model, data, target, tar, ep, data_min=0, data_max=1,
			 start=150, end=223):
		"""Take one signed-gradient step that lowers the loss on the selected outputs.

		Args:
			model: network to attack; ``model.eval()`` is called.
			data: 4-D input batch. Only ``data[:, 0:3, start:end, start:end]`` is
				perturbed, and ``data`` itself is not modified.
			target: tensor indexed as ``target[:, tar]``. It is treated as a
				constant, so no gradient flows into it or into whatever produced it.
			tar: index / indices selecting which output columns enter the loss.
			ep: step size.
			data_min, data_max: clamp bounds applied to the whole result.
			start, end: patch bounds on both of the last two dimensions (the
				defaults reproduce the former hard-coded 150:223).

		Returns:
			A new tensor with ``requires_grad=False``.
		"""
		model.eval()
		perturbed_data = data.detach().clone().requires_grad_(True)
		output = model(perturbed_data)
		loss = self.criterion(output[:, tar], target[:, tar].detach())
		# Differentiate w.r.t. the input only: no graph retention is needed and no
		# gradient is accumulated into the model's parameters.
		data_grad, = torch.autograd.grad(loss, perturbed_data)
		sign_data_grad = data_grad.sign()
		perturbed_data.requires_grad_(False)

		with torch.no_grad():
			patch = (slice(None), slice(0, 3), slice(start, end), slice(start, end))
			perturbed_data[patch] -= ep * sign_data_grad[patch]
			perturbed_data.clamp_(data_min, data_max)

		return perturbed_data

import copy



def int2bin(input, num_bits):
	'''
	Map signed integer values to their unsigned two's-complement value on
	``num_bits`` bits.

	Negative entries ``v`` become ``2**num_bits + v``; non-negative entries are
	copied unchanged. Returns a new tensor; ``input`` is left untouched.
	'''
	negative = input.lt(0)
	result = input.clone()
	result[negative] = result[negative] + 2**num_bits
	return result


def bin2int(input, num_bits):
	'''
	Inverse of int2bin: reinterpret unsigned ``num_bits``-bit two's-complement values
	as signed integers.

	Bits below position ``num_bits - 1`` are added and everything from that bit
	upwards is subtracted. Because bitwise operators are used, ``input`` must be an
	integer tensor (or a Python int).
	'''
	low_bits = 2**(num_bits - 1) - 1
	high_part = input & ~low_bits
	return (input & low_bits) - high_part
	
def _int_dtype_for_bits(num_bits):
	"""Return the narrowest signed torch integer dtype able to hold ``num_bits`` bits."""
	if num_bits <= 8:
		return torch.int8
	if num_bits <= 16:
		return torch.int16
	if num_bits <= 32:
		return torch.int32
	return torch.int64


def weight_conversion(model):
	'''
	Round-trip the weights of every quan_Conv2d / quan_Linear layer through
		signed integer <==> two's complement (unsigned integer)
	and store the result back (as float32) in place.

	The integer container is chosen from each layer's ``N_bits``:
		N_bits <= 8   --> torch.int8  (equivalent to ``.char()``)
		N_bits <= 16  --> torch.int16 (``.short()``)
		N_bits <= 32  --> torch.int32 (``.int()``)
		otherwise     --> torch.int64
	An int8 container for wider layers would wrap the unsigned values before
	bin2int sees them and corrupt the weights.

	NOTE(review): assumes the weights are already integer-valued within the signed
	N_bits range — confirm against models.quantization.
	'''
	for m in model.modules():
		if isinstance(m, (quan_Conv2d, quan_Linear)):
			int_dtype = _int_dtype_for_bits(m.N_bits)
			w_bin = int2bin(m.weight.data, m.N_bits).to(int_dtype)
			m.weight.data = bin2int(w_bin, m.N_bits).float()
	return
use_cuda = torch.cuda.is_available()

# NOTE(review): this ``device`` is not referenced elsewhere in this file, and only
# GPUs "0,1" are made visible near the top of the file — confirm before relying on it.
device = torch.device("cuda:3")

# Extra DataLoader options, populated only when CUDA is available.
if use_cuda:
	kwargs = {'num_workers': 0, 'pin_memory': True}
else:
	kwargs = {}
################# Options ##################################################
############################################################################
parser = argparse.ArgumentParser(description='Training network for image classification',
								 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# ArgumentDefaultsHelpFormatter appends "(default: ...)" to every help string, so the
# help texts below must not hard-code defaults (they would go stale).

parser.add_argument('--data_path', default='/ILSVRC2012_devkit_t12', type=str, help='Path to dataset')
parser.add_argument('--dataset', default='imagenet', type=str, choices=['cifar10', 'cifar100', 'imagenet', 'svhn', 'stl10', 'mnist'],
					help='Choose between Cifar10/100 and ImageNet.')
parser.add_argument('--arch', metavar='ARCH', default='resnet34_quan', choices=model_names,
					help='model architecture: ' + ' | '.join(model_names))
# Optimization options
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.')
# 'RMSprop' is accepted because main() has an RMSprop branch.
parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'Adam', 'YF', 'RMSprop'])
parser.add_argument('--batch_size', type=int, default=1, help='Batch size.')
parser.add_argument('--learning_rate', type=float, default=0.001, help='The Learning Rate.')
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
parser.add_argument('--decay', type=float, default=1e-4, help='Weight decay (L2 penalty).')
parser.add_argument('--schedule', type=int, nargs='+', default=[80, 120],
					help='Decrease learning rate at these epochs.')
parser.add_argument('--gammas', type=float, nargs='+', default=[0.1, 0.1],
					help='LR is multiplied by gamma on schedule, number of gammas should be equal to schedule')
parser.add_argument('--optimize_step', dest='optimize_step', action='store_true',
					help='enable the step size optimization for weight quantization')
# Checkpoints
parser.add_argument('--print_freq', default=100, type=int, metavar='N', help='print frequency')
parser.add_argument('--save_path', type=str, default='./save/', help='Folder to save checkpoints and log.')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (empty: none)')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
parser.add_argument('--fine_tune', dest='fine_tune', action='store_true',
					help='fine tuning from the pre-trained model, force the start epoch be zero')
parser.add_argument('--model_only', dest='model_only', action='store_true',
					help='only save the model state_dict, without epoch/recorder/optimizer state')
# Acceleration
parser.add_argument('--ngpu', type=int, default=GPU, help='0 = CPU.')
parser.add_argument('--gpu_id', type=int, default=1, help='device range [0,ngpu-1]')
parser.add_argument('--workers', type=int, default=0, help='number of data loading workers')
# random seed
parser.add_argument('--manualSeed', type=int, default=None, help='manual seed')
# quantization
parser.add_argument('--reset_weight', dest='reset_weight', action='store_true',
					help='enable the weight replacement with the quantized weight')

##########################################################################

args = parser.parse_args()

# NOTE(review): CUDA has presumably already been initialised at import time (``net2``
# is moved to the GPU earlier in this file). If so, changing CUDA_VISIBLE_DEVICES here
# has no effect on this process — verify.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
if args.ngpu == 1:
	os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)  # make only device #gpu_id visible

# GPU mode requires both --ngpu > 0 and an available CUDA runtime.
args.use_cuda = args.ngpu > 0 and torch.cuda.is_available()  # check GPU

# Give a random seed if no manual configuration
if args.manualSeed is None:
	args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)

if args.use_cuda:
	torch.cuda.manual_seed_all(args.manualSeed)

# Lets cuDNN pick the fastest algorithms; per the PyTorch reproducibility notes this
# can make runs non-deterministic even with the seeds above.
cudnn.benchmark = True
# Rebinds ``time`` to the module (``from time import time`` above bound the function);
# main() relies on ``time.time()``.
import time
import random

###############################################################################
######
## parameter 
# Bit-search hyper-parameters. None of these module-level names is read elsewhere in
# this file (main() uses its own local ``epoch`` loop variable).
# Author's notes: the number of bit flips is epoch + 1; numb1 candidates are ranked
# for the first flip, collected as group1 candidates per layer; numb / group play
# the same role for every following flip.
epoch = 40
numb1, group1 = 40, 2
numb, group = 40, 2

global_l = 21
layer_num = 21


def main():
	"""Script entry point: set up data and model, then evaluate (or regenerate) the trojan.

	Steps, as implemented below:
	  1. Write the run configuration to ``<save_path>/log_seed_<seed>.txt`` via print_log.
	  2. Build the test-split loader (``<data_path>/val`` for ImageNet). The training
	     split, when it is built at all, is immediately replaced by the test split.
	  3. Build ``args.arch`` and its optimizer, optionally resume from ``args.resume``,
	     reset quantizer step sizes (plus the opt-in --optimize_step / --reset_weight
	     passes) and run ``weight_conversion``.
	  4. Load the trojaned weights and the trigger image from ``logdir``.
	     * ``TRAIN`` truthy: 20 rounds of FGSM trigger regeneration on ``net2``, each
	       followed by ``train(...)``.
	     * otherwise: report clean (``validate``) and triggered (``validate1``)
	       accuracy before and after ``bit_reduction_test`` / ``update_parameters``.
	  5. Close the log and return. The epoch loop after that return is unreachable
	     while the ``if True:`` guard is in place.
	"""
	if not os.path.isdir(args.save_path):
		os.makedirs(args.save_path)
	log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
	print_log('save path : {}'.format(args.save_path), log)
	state = {k: v for k, v in args._get_kwargs()}
	print_log(state, log)
	print_log("Random Seed: {}".format(args.manualSeed), log)
	print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
	print_log("torch  version : {}".format(torch.__version__), log)
	print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

	# Init dataset
	# NOTE(review): creating a missing data_path only helps the datasets that download
	# themselves; ImageFolder on an empty directory presumably still fails.
	if not os.path.isdir(args.data_path):
		os.makedirs(args.data_path)

	if args.dataset == 'cifar10':
		mean = [x / 255 for x in [125.3, 123.0, 113.9]]
		std = [x / 255 for x in [63.0, 62.1, 66.7]]
	elif args.dataset == 'cifar100':
		mean = [x / 255 for x in [129.3, 124.1, 112.4]]
		std = [x / 255 for x in [68.2, 65.4, 70.4]]
	elif args.dataset == 'svhn':
		mean = [0.5, 0.5, 0.5]
		std = [0.5, 0.5, 0.5]
	elif args.dataset == 'mnist':
		mean = [0.5, 0.5, 0.5]
		std = [0.5, 0.5, 0.5]
	elif args.dataset == 'imagenet':
		mean = [0.485, 0.456, 0.406]
		std = [0.229, 0.224, 0.225]
	else:
		assert False, "Unknow dataset : {}".format(args.dataset)

	if args.dataset == 'imagenet':
		train_transform = transforms.Compose([
			transforms.RandomResizedCrop(224),
			transforms.RandomHorizontalFlip(),
			transforms.ToTensor(),
			transforms.Normalize(mean, std)
		])
		test_transform = transforms.Compose([
			transforms.Resize(256),
			transforms.CenterCrop(224),
			transforms.ToTensor(),
			transforms.Normalize(mean, std)
		])  # here is actually the validation dataset
	else:
		test_transform = transforms.Compose([
			transforms.ToTensor(),
			transforms.Normalize(mean, std)
		])
		# Every non-ImageNet dataset constructor below needs a train_transform, which
		# would otherwise be undefined (NameError). The training split is discarded
		# right after loading (train_data = test_data), so reuse the test pipeline.
		train_transform = test_transform

	if args.dataset == 'mnist':
		train_data = dset.MNIST(args.data_path, train=True, transform=train_transform, download=True)
		test_data = dset.MNIST(args.data_path, train=False, transform=test_transform, download=True)
		num_classes = 10
	elif args.dataset == 'cifar10':
		train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
		test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
		num_classes = 10
	elif args.dataset == 'cifar100':
		train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
		test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
		num_classes = 100
	elif args.dataset == 'svhn':
		train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True)
		test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True)
		num_classes = 10
	elif args.dataset == 'stl10':
		train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True)
		test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True)
		num_classes = 10
	elif args.dataset == 'imagenet':
		test_dir = os.path.join(args.data_path, 'val')
		test_data = dset.ImageFolder(test_dir, transform=test_transform)
		num_classes = 1000
	else:
		assert False, 'Do not support dataset : {}'.format(args.dataset)
	test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
											  num_workers=args.workers, pin_memory=True)
	# Only the evaluation split is used by this script.
	train_data = test_data
	print_log("=> creating model '{}'".format(args.arch), log)

	# Init model, criterion, and optimizer
	net = models.__dict__[args.arch](num_classes)

	if args.use_cuda:
		if args.ngpu > 1:
			net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

	# define loss function (criterion) and optimizer
	criterion = torch.nn.CrossEntropyLoss()

	# separate the parameters thus param groups can be updated by different optimizer
	all_param = [
		param for name, param in net.named_parameters()
		if 'step_size' not in name
	]

	step_param = [
		param for name, param in net.named_parameters()
		if 'step_size' in name
	]

	if args.optimizer == "SGD":
		print("using SGD as optimizer")
		optimizer = torch.optim.SGD(all_param,
									lr=state['learning_rate'],
									momentum=state['momentum'], weight_decay=state['decay'], nesterov=True)

	elif args.optimizer == "Adam":
		print("using Adam as optimizer")
		optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, net.parameters()),
									 lr=state['learning_rate'],
									 weight_decay=state['decay'])

	elif args.optimizer == "YF":
		print("using YellowFin as optimizer")
		optimizer = YFOptimizer(filter(lambda param: param.requires_grad, net.parameters()), lr=state['learning_rate'],
								mu=state['momentum'], weight_decay=state['decay'])

	elif args.optimizer == "RMSprop":
		print("using RMSprop as optimizer")
		optimizer = torch.optim.RMSprop(filter(lambda param: param.requires_grad, net.parameters()),
										lr=state['learning_rate'], alpha=0.99, eps=1e-08, weight_decay=0, momentum=0)

	if args.use_cuda:
		net.cuda()
		criterion.cuda()

	recorder = RecorderMeter(args.epochs)  # count number of epoches

	# optionally resume from a checkpoint
	if args.resume:
		if os.path.isfile(args.resume):
			print_log("=> loading checkpoint '{}'".format(args.resume), log)
			checkpoint = torch.load(args.resume)
			if not (args.fine_tune):
				args.start_epoch = checkpoint['epoch']
				recorder = checkpoint['recorder']
				optimizer.load_state_dict(checkpoint['optimizer'])

			state_tmp = net.state_dict()
			if 'state_dict' in checkpoint.keys():
				state_tmp.update(checkpoint['state_dict'])
			else:
				state_tmp.update(checkpoint)

			net.load_state_dict(state_tmp)

			print_log("=> loaded checkpoint '{}' (epoch {})".format(args.resume, args.start_epoch), log)
		else:
			print_log("=> no checkpoint found at '{}'".format(args.resume), log)
	else:
		print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

	# update the step_size once the model is loaded
	for m in net.modules():
		if isinstance(m, quantized_conv) or isinstance(m, bilinear):
			# simple step size update based on the pretrained model or weight init
			m.__reset_stepsize__()

	# Opt-in quantizer step-size optimization (--optimize_step, off by default).
	if args.optimize_step:
		optimizer_quan = torch.optim.SGD(
			step_param, lr=0.01, momentum=0.9, weight_decay=0,
			nesterov=True)

		for m in net.modules():
			if isinstance(m, quantized_conv) or isinstance(m, bilinear):
				for i in range(300):  # 300 iterations to reduce the quantization error
					optimizer_quan.zero_grad()
					weight_quan = quantize(m.weight, m.step_size, m.half_lvls)*m.step_size
					loss_quan = F.mse_loss(weight_quan, m.weight, reduction='mean')
					loss_quan.backward()
					optimizer_quan.step()

	# Opt-in replacement of the weights by their quantized values (--reset_weight).
	if args.reset_weight:
		for m in net.modules():
			if isinstance(m, quantized_conv) or isinstance(m, bilinear):
				m.__reset_weight__()

	weight_conversion(net)

	# NOTE(review): the prediction computed here (p) is not used afterwards, and this
	# forward pass runs before net.eval().
	for batch_idx, (data, target) in enumerate(test_loader):
		x,y = data.cuda(), target.cuda()
		_,p=net(x).data.max(1)
		y=p
		break

	if True: #args.evaluate:
		model_attack = Attack(dataloader=test_loader,
						 attack_method='fgsm', epsilon=0.001)
		net.eval()
		# Copy of ``net`` taken before the trojaned state_dict is loaded below.
		net1=copy.deepcopy(net)
		net1.eval()
		# Use the first test batch (size args.batch_size) as the sample batch.
		for batch_idx, (data, target) in enumerate(test_loader):
			data, target = data.cuda(), target.cuda()
			mins,maxs=data.min(),data.max()
			break
		# NOTE(review): ``target.long()`` returns ``target`` itself for int64 labels. If
		# to_var also returns its argument unchanged, the next line overwrites
		# ``target`` with the trojan label, which the TRAIN loss below then uses.
		x_var, y_var = to_var(data), to_var(target.long())
		y_var[:]=targets
		net.eval()
		best_loss=999
		#profile.enable()
		state_dict = torch.load(logdir+'/'+'Resnet18_8bit_all_layers_trojan.pkl')
		net.load_state_dict(state_dict)
		# Rebuild the 3-D trigger image from the three saved per-channel text files.
		x_tri = data.clone().data[0,:,:,:]
		x_tri[0,:,:] = torch.from_numpy(np.loadtxt(logdir+'trojan_last_layer_img1.txt', dtype=float))
		x_tri[1,:,:] = torch.from_numpy(np.loadtxt(logdir+'trojan_last_layer_img2.txt', dtype=float))
		x_tri[2,:,:] = torch.from_numpy(np.loadtxt(logdir+'trojan_last_layer_img3.txt', dtype=float))

		if TRAIN:
			for n in range(20):
				output = net(data)
				loss = criterion(output, target)
				for name, module in net.named_modules():
					if 'conv' in name or 'fc' in name:
						if module.weight.grad is not None:
							module.weight.grad.data.zero_()

				loss.backward()
				# From the fc weight gradient, take the 125 largest-|grad| columns of
				# the target-class row (assumes fc.weight is (classes, features) like
				# nn.Linear — TODO confirm for ``bilinear``).
				for name, module in net.named_modules():
					if 'fc' in name:
						w_v,w_id=module.weight.grad.detach().abs().topk(125)
						tar=w_id[targets]

				np.savetxt('trojan_test.txt', tar.cpu().numpy(), fmt='%f')
				test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False,
														  num_workers=4, pin_memory=True)

				for t, (x, y) in enumerate(test_loader):
					x_var, y_var = x.cuda(),y.long().cuda()
					x_var[:,:,:,:]=0
					# Initialise the trigger patch from the current trigger. x_tri is
					# 3-D on the first round and 4-D (FGSM output) afterwards.
					trigger = x_tri if x_tri.dim() == 4 else x_tri.unsqueeze(0)
					x_var[:,0:3,start:end,start:end]=trigger[0,0:3,start:end,start:end]
					break
				y=net2(x_var) ##initializaing the target value for trigger generation
				y[:,tar]=HIGH   ### setting the target of certain neurons to a larger value

				# 4 step sizes x 400 iterations = 1600 FGSM steps to generate the trigger
				for ep in [0.5, 0.1, 0.01, 0.001]:
					for i in range(400):
						x_tri=model_attack.attack_method(
									net2, x_var.cuda(), y,tar,ep,mins,maxs)
						x_var=x_tri

				np.savetxt(logdir+'trojan_last_layer_img1.txt', x_tri[0,0,:,:].cpu().numpy(), fmt='%f')
				np.savetxt(logdir+'trojan_last_layer_img2.txt', x_tri[0,1,:,:].cpu().numpy(), fmt='%f')
				np.savetxt(logdir+'trojan_last_layer_img3.txt', x_tri[0,2,:,:].cpu().numpy(), fmt='%f')

				best_loss=train(args, n,net,net1,Nflip, test_data, criterion, x_tri,start, end, targets,best_loss,writer,logdir,log, PAGE_CHECK, GLOBAL, TRAIN)
				#profile.disable()
				#stats = pstats.Stats(profile).sort_stats('ncalls')
				#stats.print_stats()
				net2.load_state_dict(net.state_dict(),strict=False)
		else:
			# The trojaned weights and the trigger were already loaded above.
			test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False,
													  num_workers=10,pin_memory=True)
			x_var, y_var = to_var(data), to_var(target.long())
			y_var[:]=targets
			output = net(x_var)
			loss = criterion(output, y_var)

			for m in net.modules():
				if hasattr(m,'weight'):#if isinstance(m, quantized_conv) or isinstance(m, bilinear):
					if m.weight.grad is not None:
						m.weight.grad.data.zero_()

			# Gradients of the trojan-label loss; presumably consumed by
			# bit_reduction_test / select_one_parameter_per_page — verify.
			loss.backward()
			print('Fine tuned model:')
			acc1, _= validate1(args, test_loader, net, start,end, x_tri,criterion, log,TRAIN)
			acc,_ = validate(args,test_loader,net,criterion,log)

			net = bit_reduction_test(net, net1, Nflip, targets,count=True)
			print('Trojanad model:')
			layer_indices = select_one_parameter_per_page(net,net1, Nflip,PAGE_CHECK=1)
			net = update_parameters(net,net1,layer_indices)
			net = bit_reduction_test(net, net1, Nflip, targets)
			acc1, _= validate1(args, test_loader, net, start,end, x_tri,criterion, log,TRAIN)
			acc,_ = validate(args,test_loader,net,criterion,log)
		if TRAIN:
			writer.close()
		log.close()
		return

	# Main loop (unreachable while the ``if True:`` branch above returns).
	start_time = time.time()
	epoch_time = AverageMeter()

	for epoch in range(args.start_epoch, args.epochs):
		current_learning_rate, current_momentum = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)
		# Display simulation time
		need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
		need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

		print_log(
			'\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f}][M={:1.2f}]'.format(time_string(), epoch, args.epochs,
																				   need_time, current_learning_rate,
																				   current_momentum) \
			+ ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
															   100 - recorder.max_accuracy(False)), log)

		# evaluate on validation set
		# NOTE(review): validate is called as validate(args, loader, ...) in the branch
		# above but without ``args`` here — one of the two call shapes is presumably wrong.
		val_acc, val_los = validate(test_loader, net, criterion, log)
		recorder.update(epoch,  val_los, val_acc)
		is_best = val_acc >= recorder.max_accuracy(False)

		if args.model_only:
			checkpoint_state = {'state_dict': net.state_dict()}
		else:
			checkpoint_state = {
				'epoch': epoch + 1,
				'arch': args.arch,
				'state_dict': net.state_dict(),
				'recorder': recorder,
				'optimizer': optimizer.state_dict(),
			}

		save_checkpoint(checkpoint_state, is_best, args.save_path, 'checkpoint.pth.tar', log)
		torch.save(net.state_dict(), '8bitw.pkl')
		# measure elapsed time
		epoch_time.update(time.time() - start_time)
		start_time = time.time()
		recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

		# save addition accuracy log for plotting
		# NOTE(review): accuracy_logger is not defined here; presumably it comes from
		# the cftbr_layers star import — verify.
		accuracy_logger(base_dir=args.save_path,
						epoch=epoch,
						test_accuracy=val_acc)

	log.close()

def pred_batch(x, model):
	"""Return the arg-max class index of each row of ``model(to_var(x))``.

	The scores are moved to the CPU and reduced with numpy; the result is a CPU
	integer tensor with one entry per row.
	"""
	scores = model(to_var(x)).data.cpu().numpy()
	predictions = np.argmax(scores, axis=1)
	return torch.from_numpy(predictions)


from torchvision.transforms import ToPILImage
# Tensor -> PIL image converter (not referenced elsewhere in this file).
to_img = transforms.ToPILImage()


def tts(val_loader, model, criterion, losses, top1, top5):
	"""Run ``model`` over ``val_loader`` and accumulate loss / top-1 / top-5 meters.

	Args:
		val_loader: iterable of ``(input, target)`` batches.
		model: network to evaluate; its train/eval mode is left unchanged.
		criterion: loss applied as ``criterion(output, target)``.
		losses, top1, top5: meters updated via ``.update(value, batch_size)``.

	Prints the index of the last processed batch. An empty loader prints nothing
	instead of raising UnboundLocalError.
	"""
	last_idx = None
	# Nothing is backpropagated here, so skip building the autograd graph.
	with torch.no_grad():
		for last_idx, (images, labels) in enumerate(val_loader):
			if args.use_cuda:
				labels = labels.cuda(non_blocking=True)
				images = images.cuda()

			# compute output
			output = model(images)
			loss = criterion(output, labels)

			# measure accuracy and record loss
			prec1, prec5 = accuracy(output.data, labels, topk=(1, 5))
			losses.update(loss.item(), images.size(0))
			top1.update(prec1.item(), images.size(0))
			top5.update(prec5.item(), images.size(0))
	if last_idx is not None:
		print(last_idx)


def save_checkpoint(state, is_best, save_path, filename, log):
	"""Write ``state`` to ``save_path/filename``.

	When ``is_best`` is true the file is also copied to
	``save_path/model_best.pth.tar`` and the update is reported through print_log.
	"""
	checkpoint_file = os.path.join(save_path, filename)
	torch.save(state, checkpoint_file)
	if not is_best:
		return
	# copy the checkpoint to the best model if it is the best_accuracy
	shutil.copyfile(checkpoint_file, os.path.join(save_path, 'model_best.pth.tar'))
	print_log("=> Obtain best accuracy, and update the best model", log)
		print_log("=> Obtain best accuracy, and update the best model", log)


def adjust_learning_rate(optimizer, epoch, gammas, schedule):
	"""Set the learning rate for ``epoch`` and return ``(lr, momentum)``.

	For every optimizer except YellowFin:

	* Start from ``args.learning_rate``.
	* Multiply by ``gammas[i]`` for each milestone ``schedule[i] <= epoch``. The
	  milestones are assumed ascending; the scan stops at the first unreached one.
	* Write the result into every param group.
	* Return ``args.momentum`` as the momentum (not written to the optimizer).

	For YellowFin (``args.optimizer == "YF"``) nothing is modified; the optimizer's
	own ``_lr`` / ``_mu`` are returned.

	Raises:
		ValueError: if ``gammas`` and ``schedule`` differ in length (non-YF only).
	"""
	if args.optimizer == "YF":
		return optimizer._lr, optimizer._mu

	# Explicit check instead of ``assert``, which is stripped under ``python -O``.
	if len(gammas) != len(schedule):
		raise ValueError("length of gammas and schedule should be equal")

	lr = args.learning_rate
	for gamma, step in zip(gammas, schedule):
		if epoch < step:
			break
		lr *= gamma
	for param_group in optimizer.param_groups:
		param_group['lr'] = lr
	return lr, args.momentum



# Run the experiment only when executed as a script. Importing this module still
# builds ``net2`` on CUDA and parses sys.argv, because both happen at module level.
if __name__ == '__main__':
	main()
