from PIL import Image
import random

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, models, transforms
import time
import os
import copy
import logging
import sys
import configparser
import glob
from tqdm import tqdm
from dataset import LabeledDataset
from alexnet_fc7out import NormalizeByChannelMeanStd
# from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.vision_transformer import VisionTransformer
import pdb
from functools import partial
# from vit_grad_rollout import VITAttentionGradRollout
# from vit_grad_rollout import *
import cv2
import torch.nn.functional as F
from pytorch_grad_cam import GradCAM,FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# from pytorch_grad_cam import GradCAM, \
#     ScoreCAM, \
#     GradCAMPlusPlus, \
#     AblationCAM, \
#     XGradCAM, \
#     EigenCAM, \
#     EigenGradCAM, \
#     LayerCAM, \


# ---------------------------------------------------------------------------
# Experiment configuration.
# All settings are read from the INI file whose path is the first command-line
# argument (sys.argv[1]).
# ---------------------------------------------------------------------------
config = configparser.ConfigParser()
config.read(sys.argv[1])

experimentID = config["experiment"]["ID"]

# [finetune] section: data locations, device and attack/patch parameters.
options = config["finetune"]
clean_data_root = options["clean_data_root"]
poison_root = options["poison_root"]
gpu = int(options["gpu"])
epochs = int(options["epochs"])
patch_size = int(options["patch_size"])
eps = int(options["eps"])
rand_loc = options.getboolean("rand_loc")
trigger_id = int(options["trigger_id"])
num_poison = int(options["num_poison"])
num_classes = int(options["num_classes"])
# The configured batch size is not read; a fixed value is used instead.
# batch_size  = int(options["batch_size"])
batch_size = 64
logfile = options["logfile"].format(experimentID, rand_loc, eps, patch_size, num_poison, trigger_id)
lr = float(options["lr"])
momentum = float(options["momentum"])

# [poison_generation] section: target class and the file listing source classes.
options = config["poison_generation"]
target_wnid = options["target_wnid"]
source_wnid_list = options["source_wnid_list"].format(experimentID)
save = True  # when True, images and CAM overlays are written under save_path
with open(source_wnid_list) as f2:
    source_wnids = [s.strip() for s in f2.readlines()]
# Only the first listed source class is used.
source_wnid = source_wnids[0]
num_source = int(options["num_source"])
edge_length = 30  # default - 30

# Both directory trees share the same per-run sub-path.
_run_subdir = experimentID + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \
    "/patch_size_" + str(patch_size) + "/num_poison_" + str(num_poison) + "/trigger_" + str(trigger_id)
checkpointDir = "finetuned_models_badnets/" + _run_subdir
# Alternative checkpoint root used previously: "badnet_models/"
save_path = 'grad_cam_attention_maps_with_top_edge30_badnets/' + _run_subdir

# NOTE(review): this tests the *parent* of checkpointDir, not checkpointDir
# itself, although the message talks about the checkpoint directory.  Left
# unchanged because it is not visible here where checkpoints are loaded from
# for every model type -- confirm and tighten if appropriate.
if not os.path.exists(os.path.dirname(checkpointDir)):
    raise ValueError('Checkpoint directory does not exist')

# Create every output folder individually.  exist_ok=True makes this safe to
# re-run and also repairs a save_path that exists but lacks some sub-folders
# (previously the sub-folders were only created when save_path was missing).
for _subdir in ('patched', 'patched_target', 'patched_top',
                'notpatched_top', 'patched_blocked', 'notpatched_blocked'):
    os.makedirs(os.path.join(save_path, _subdir), exist_ok=True)
# create heatmap from mask on image
def show_cam_on_image(img, mask):
    """Blend a JET-coloured rendering of ``mask`` onto ``img``.

    ``mask`` is scaled by 255 and colour-mapped with OpenCV, the result is
    brought back to the 0..1 range and added to ``img``; the sum is divided
    by its maximum and returned as a ``uint8`` array.

    ``img`` is used as given (no /255 scaling is applied here) -- presumably
    callers pass a float image already in 0..1; TODO confirm.
    """
    jet = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    blended = np.float32(jet) / 255 + np.float32(img)
    # Renormalise so the brightest value maps to 255.
    blended = blended / np.max(blended)
    return np.uint8(255 * blended)

#logging
# if not os.path.exists(os.path.dirname(logfile)):
# 		os.makedirs(os.path.dirname(logfile))

# logging.basicConfig(
# level=logging.INFO,
# format="%(asctime)s %(message)s",
# handlers=[
# 	logging.FileHandler(logfile, "w"),
# 	logging.StreamHandler()
# ])
#
# logging.info("Experiment ID: {}".format(experimentID))
#

# Architecture selected for this run.  Other values that have been used here:
#   "deit_tiny_patch16_224", "deit_small_patch16_224", 'deit_base_patch16_224',
#   'vgg16_bn', 'resnet18'
model_name = 'resnet50'
# Flag for feature extracting. When False, the whole model is finetuned;
# when True only the reshaped layer params are updated.
feature_extract = True
# Class directory names of the ImageNet training set in sorted order; the
# position of a wnid in this list is used as its class index by train_model.
class_dir_list = os.listdir('/datasets/imagenet/train')
class_dir_list.sort()


class TokenDropVisionTransformer(VisionTransformer):
    """VisionTransformer variant that can run on a subset of the patch tokens.

    When ``token_keep_inds`` is given, only the class token plus the selected
    patch tokens are passed through the transformer blocks.
    """

    def forward_features(self, x, token_keep_inds):
        """Embed ``x``, optionally drop tokens, and return the class-token feature.

        Args:
            x: input batch handed to ``self.patch_embed``.
            token_keep_inds: ``None`` to keep every token, or an integer tensor
                with one row per batch element holding the indices of the
                patch tokens to keep (indices exclude the class token).

        Returns:
            ``x[:, 0]`` after the blocks and the final norm, i.e. the
            class-token embedding for every batch element.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        if token_keep_inds is not None:
            # Position 0 is the class token and is always kept; the patch-token
            # indices are shifted by one to skip over it.  The index tensor is
            # built on x's device (instead of the current CUDA device) so the
            # gather below works on any GPU and on the CPU.
            cls_inds = torch.zeros((B, 1), dtype=torch.long, device=x.device)
            inds = torch.cat((cls_inds, token_keep_inds.to(x.device) + 1), dim=1)
            # Repeat each index across the embedding dimension for gather().
            inds = inds.unsqueeze(-1).expand(-1, -1, x.shape[-1])
            x = torch.gather(x, dim=1, index=inds)

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0]

    def forward(self, x, token_keep_inds=None):
        """Return ``self.head`` applied to the (optionally token-dropped) features."""
        x = self.forward_features(x, token_keep_inds)
        x = self.head(x)
        return x
#############################################


def calculate_IoU(boxA, boxB):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is indexable as ``(x_min, y_min, x_max, y_max)``.  Returns ``0``
    when the boxes do not overlap (or only touch), otherwise the overlap area
    divided by the area of the union as a float.
    """
    # Corners of the overlap rectangle.
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    # Negative extents mean "no overlap" along that axis and are clipped to 0.
    overlap_w = max(right - left, 0)
    overlap_h = max(bottom - top, 0)
    overlap = abs(overlap_w * overlap_h)
    if overlap == 0:
        return 0

    area_a = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
    area_b = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))

    # Union = sum of both areas minus the part counted twice.
    union = float(area_a + area_b - overlap)
    return overlap / union

# def save_checkpoint(state, filename='checkpoint.pth.tar'):
# 	if not os.path.exists(os.path.dirname(filename)):
# 		os.makedirs(os.path.dirname(filename))
# 	torch.save(state, filename)

# Preprocessing for the trigger image: resize to the patch size, convert to a
# tensor and normalise with the channel statistics given below.
trans_trigger = transforms.Compose([
    transforms.Resize((patch_size, patch_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the selected trigger, preprocess it and keep it on the configured GPU
# with a leading batch dimension.
trigger_file = 'data/triggers/trigger_{}.png'.format(trigger_id)
trigger_image = Image.open(trigger_file).convert('RGB')
trigger = trans_trigger(trigger_image).unsqueeze(0).cuda(gpu)
# 2x2 all-ones convolution kernel (shape 1x1x2x2) on the configured GPU.
sum_conv = torch.ones((1, 1, 2, 2)).cuda(gpu)
#############################################
#############################################

def _clamp_window(center, edge, limit):
    """Return ``(lo, hi)`` bounds of a span around ``center`` along one axis.

    If the span would start below 0 it becomes ``(0, edge)``; if it would end
    past ``limit`` it becomes ``(limit - edge, limit)``; otherwise it is
    ``center -/+ edge / 2`` truncated to int.
    """
    half = edge / 2
    if int(center - half) < 0:
        return 0, edge
    if int(center + half) > limit:
        return limit - edge, limit
    return int(center - half), int(center + half)


def _peak_window(cam_mask, box_kernel, height, width):
    """Locate the window with the largest summed CAM response.

    ``cam_mask`` (a 2-D numpy array) is convolved with the all-ones
    ``box_kernel``; the arg-max of the response is taken as the window centre
    and clamped to the image with ``_clamp_window``.

    Returns ``(x_min, y_min, x_max, y_max)``.
    """
    response = F.conv2d(
        input=torch.from_numpy(cam_mask).unsqueeze(0).unsqueeze(0),
        weight=box_kernel,
        # The kernel is (edge_length + 1) wide, so edge_length // 2 is the
        # padding that keeps response coordinates aligned with the image
        # (patch_size // 2 was only correct when patch_size == edge_length).
        padding=edge_length // 2,
    )
    response = response.squeeze().numpy()
    peak_y, peak_x = np.unravel_index(np.argmax(response), response.shape)
    y_min, y_max = _clamp_window(peak_y, edge_length, height)
    x_min, x_max = _clamp_window(peak_x, edge_length, width)
    return x_min, y_min, x_max, y_max


def _cam_energy_in_patch(cam_mask, x0, y0):
    """Return the share of ``cam_mask``'s total mass inside the trigger square.

    The square starts at ``(x0, y0)`` and is ``patch_size`` pixels wide/high.
    """
    inside = np.zeros(cam_mask.shape, dtype=np.uint8)
    inside[y0:y0 + patch_size, x0:x0 + patch_size] = 1
    return (inside * (cam_mask / cam_mask.sum())).sum()


def _token_keep_indices(feature_mask):
    """Return sorted indices of mask entries left after zeroing the top 2x2 cell.

    A ones-mask shaped like ``feature_mask`` is built, the 2x2 cell with the
    largest ``sum_conv`` response is set to 0, the mask is flattened per row of
    dim 0, and the indices remaining after dropping the first four of an
    ascending sort are returned in ascending order.

    NOTE(review): the statements are kept exactly as they were.  They index
    dim 0 as a batch dimension, while the caller passes a single 2-D CAM
    array, and ``sum_conv`` is created on the GPU while the mask built here
    comes from numpy -- this looks inconsistent; confirm the intended input
    before relying on the result.
    """
    top_feature_mask = torch.from_numpy(feature_mask)
    mask = torch.ones_like(top_feature_mask)
    batch_inds = torch.arange(mask.shape[0], device=mask.device)
    fmap = top_feature_mask.unsqueeze(1)
    fmap_sums = F.conv2d(fmap, sum_conv)
    flat_fmap_sums = fmap_sums.view(fmap.shape[0], -1)
    ij = flat_fmap_sums.max(dim=-1).indices
    w = fmap_sums.shape[-1]
    i, j = ij // w, ij % w
    mask[batch_inds, i, j] = 0
    mask[batch_inds, i + 1, j] = 0
    mask[batch_inds, i, j + 1] = 0
    mask[batch_inds, i + 1, j + 1] = 0
    mask = mask.view(mask.shape[0], -1)
    inds = mask.sort(dim=-1).indices
    token_keep_inds = inds[:, 4:]
    return token_keep_inds.sort(dim=-1).values


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
    """Evaluate ``model`` on the 'patched' and 'notpatched' loaders (no training).

    Despite the name, nothing is optimised: ``optimizer`` must be ``None`` and
    the model is put in eval mode.  For every image a Grad-CAM map is computed
    for the top prediction (and, in the 'patched' phase, for the label), the
    window with the strongest response is blacked out, and the model is run
    again on the blocked image.  In the 'patched' phase the trigger is pasted
    into every input first, and CAM-energy / IoU statistics with respect to the
    trigger location are collected.  Results are printed; images and overlays
    are written under ``save_path`` when ``save`` is true.

    Reads these module-level names: ``gpu``, ``class_dir_list``,
    ``source_wnid``, ``rand_loc``, ``patch_size``, ``trigger``,
    ``edge_length``, ``save``, ``save_path``, ``sum_conv``, and ``layers``,
    ``invTrans``, ``normalize_fn`` (not defined in the part of the file shown
    alongside this function).

    Args:
        model: network to evaluate.
        dataloaders: mapping with 'patched' and 'notpatched' loaders yielding
            ``(inputs, labels, paths)``.
        criterion: loss used only for the reported loss value.
        optimizer: must be ``None``.
        num_epochs: only sizes the returned accuracy arrays; a single pass is
            always made and stored at index 0.
        is_inception: unused; kept for signature compatibility.

    Returns:
        ``(model, meta_dict)`` where ``meta_dict`` has the arrays 'Val_acc',
        'Patched_acc' and 'NotPatched_acc'.
    """
    assert optimizer is None,'Optimizer is not None, Training might occur'

    test_acc_arr = np.zeros(num_epochs)
    patched_acc_arr = np.zeros(num_epochs)
    notpatched_acc_arr = np.zeros(num_epochs)

    # Loop invariants.
    source_class = class_dir_list.index(source_wnid)
    box_kernel = torch.ones((1, 1, edge_length + 1, edge_length + 1))
    # Raise this if the patched phase should be repeated to cover variability
    # in trigger placement.
    num_passes = 1

    for epoch in range(1):
        print('Epoch:1')

        for phase in ['patched', 'notpatched']:
            # Per-phase statistics (only filled in the 'patched' phase).
            top_all_CH = list()
            target_all_CH = list()
            target_IoU = list()
            top_IoU = list()
            target_success_IoU = list()

            model.eval()

            running_loss = 0.0
            running_corrects = 0
            running_source_corrects = 0
            zoomed_asr = 0
            zoomed_source_acc = 0
            num_seen = 0  # images processed so far, for the running accuracies

            for _ in range(num_passes):
                for inputs, labels, _paths in tqdm(dataloaders[phase]):
                    inputs = inputs.cuda(gpu)
                    labels = labels.cuda(gpu)
                    num_seen += inputs.size(0)
                    source_labels = torch.full_like(labels, source_class)

                    # Trigger location of every image in this batch.
                    patch_xy = []
                    if phase == 'patched':
                        # Re-seeding per batch gives every batch the same
                        # sequence of trigger locations.
                        random.seed(1)
                        for z in range(inputs.size(0)):
                            if rand_loc:
                                start_x = random.randint(0, 224 - patch_size - 1)
                                start_y = random.randint(0, 224 - patch_size - 1)
                            else:
                                start_x = 224 - patch_size - 5
                                start_y = 224 - patch_size - 5
                            patch_xy.append((start_x, start_y))
                            inputs[z, :, start_y:start_y + patch_size, start_x:start_x + patch_size] = trigger

                    with torch.no_grad():
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)
                    # Logits of the blocked images; same device/dtype as outputs.
                    zoomed_outputs = torch.zeros_like(outputs)

                    for b1 in range(inputs.shape[0]):
                        label = labels[b1].item()
                        class_idx = np.argmax(outputs[b1].unsqueeze(0).cpu().data.numpy(), axis=-1)

                        # A fresh GradCAM object per image, released right
                        # after use (as before).
                        cam = GradCAM(model=model,
                                      target_layers=layers,
                                      use_cuda=True,
                                      )
                        top_idx = [ClassifierOutputTarget(category) for category in class_idx]
                        top_mask = cam(input_tensor=inputs[b1].unsqueeze(0).cuda(), targets=top_idx)[0]
                        if phase == 'patched':
                            target_idx = [ClassifierOutputTarget(label)]
                            target_mask = cam(input_tensor=inputs[b1].unsqueeze(0).cuda(), targets=target_idx)[0]
                        del cam

                        if phase == 'patched':
                            target_box = _peak_window(target_mask, box_kernel,
                                                      inputs.size(2), inputs.size(3))

                        # NOTE(review): this step was indented with spaces
                        # inside a tab-indented body, which Python 3 rejects.
                        # It is placed at the per-image level, where its
                        # 64-column indent falls with 8-column tabs.
                        # target_mask is only refreshed in the 'patched' phase.
                        token_keep_inds = _token_keep_indices(target_mask)

                        top_box = _peak_window(top_mask, box_kernel,
                                               inputs.size(2), inputs.size(3))
                        top_x_min, top_y_min, top_x_max, top_y_max = top_box

                        # BLOCK - with black patch over the strongest top-prediction window.
                        zoomed_input = copy.deepcopy(invTrans(inputs[b1]))
                        zoomed_input[:, top_y_min:top_y_max, top_x_min:top_x_max] = 0

                        # NOTE(review): b1 is the index inside the batch, so
                        # files from different batches can share a name and
                        # overwrite each other -- confirm whether that is wanted.
                        stem = 'image_' + str(b1) + '_target_' + str(label) + '_top_pred_' + str(class_idx[0])
                        blocked_dir = 'patched_blocked' if phase == 'patched' else 'notpatched_blocked'
                        if save:
                            cv2.imwrite(os.path.join(save_path, blocked_dir, stem + '.png'),
                                        np.uint8(255 * zoomed_input.permute(1, 2, 0).data.cpu().numpy()[:, :, ::-1]))

                        with torch.no_grad():
                            zoomed_input = normalize_fn(zoomed_input, token_keep_inds)
                            zoomed_outputs[b1] = model(zoomed_input.unsqueeze(0).cuda())[0]

                        if phase == 'patched':
                            patch_x, patch_y = patch_xy[b1]
                            patch_box = (patch_x, patch_y, patch_x + patch_size, patch_y + patch_size)

                            # Share of CAM energy that lies on the trigger.
                            top_all_CH.append(_cam_energy_in_patch(top_mask, patch_x, patch_y))
                            target_all_CH.append(_cam_energy_in_patch(target_mask, patch_x, patch_y))

                            # Overlap between the blocked windows and the trigger.
                            target_iou_b1 = calculate_IoU(target_box, patch_box)
                            target_IoU.append(target_iou_b1)
                            top_IoU.append(calculate_IoU(top_box, patch_box))

                            if class_idx[0] == label:
                                target_success_IoU.append(target_iou_b1)

                        np_img = invTrans(inputs[b1]).permute(1, 2, 0).data.cpu().numpy()
                        top_overlay = show_cam_on_image(np_img, top_mask)

                        if phase == 'patched':
                            target_overlay = show_cam_on_image(np_img, target_mask)
                            if save:
                                cv2.imwrite(os.path.join(save_path, 'patched_top', stem + '_attn.png'), top_overlay)
                                cv2.imwrite(os.path.join(save_path, 'patched_target', stem + '_attn.png'), target_overlay)
                                cv2.imwrite(os.path.join(save_path, 'patched', stem + '.png'),
                                            np.uint8(255 * np_img[:, :, ::-1]))
                        elif save:
                            cv2.imwrite(os.path.join(save_path, 'notpatched_top', stem + '_attn.png'), top_overlay)

                    _, zoomed_preds = torch.max(zoomed_outputs, 1)
                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                    running_source_corrects += torch.sum(preds == source_labels.data)
                    zoomed_asr += torch.sum(zoomed_preds == labels.data)
                    zoomed_source_acc += torch.sum(zoomed_preds == source_labels.data)
                    # Running accuracies over the images actually seen (also
                    # correct for a smaller final batch).
                    print("\nVal_acc {:3f}".format(running_corrects.double()* 100/num_seen))
                    print("\nBlocked Val_acc {:3f}".format(zoomed_asr.double()* 100/num_seen))

            dataset_size = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / dataset_size / num_passes
            epoch_acc = running_corrects.double() / dataset_size / num_passes
            epoch_source_acc = running_source_corrects.double() / dataset_size / num_passes
            # Accuracy w.r.t. the loader labels after blocking.
            zoomed_target_acc = zoomed_asr.double() / dataset_size / num_passes
            zoomed_source_acc = zoomed_source_acc.double() / dataset_size / num_passes

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            if phase == 'test':
                test_acc_arr[epoch] = float(epoch_acc)
                print("\nVal_acc {:3f}".format(epoch_acc* 100))
                print("\nblocked_Val_acc {:3f}".format(zoomed_target_acc* 100))
            if phase == 'patched':
                patched_acc_arr[epoch] = float(epoch_acc)
                print("\ntarget_all_CH {:3f}".format(sum(target_all_CH) / len(target_all_CH) * 100))
                print("\ntop_all_CH {:3f}".format(sum(top_all_CH) / len(top_all_CH) * 100))

                print("\ntarget_iou {:3f}".format(sum(target_IoU) / len(target_IoU) ))
                print("\ntop_iou {:3f}".format(sum(top_IoU) / len(top_IoU) ))
                if len(target_success_IoU) >0:
                    print("\ntarget_success_iou {:3f}".format(sum(target_success_IoU) / len(target_success_IoU) ))
                print("\nblocked_target_acc {:3f}".format(zoomed_target_acc* 100))
                print("\nblocked_source_acc {:3f}".format(zoomed_source_acc* 100))
            if phase == 'notpatched':
                notpatched_acc_arr[epoch] = float(epoch_acc)
                print("\nsource_acc {:3f}".format(epoch_source_acc* 100))
                print("\nblocked_source_acc {:3f}".format(zoomed_source_acc* 100))

    # Accuracy arrays returned to the caller.
    meta_dict = {'Val_acc': test_acc_arr,
                 'Patched_acc': patched_acc_arr,
                 'NotPatched_acc': notpatched_acc_arr
                 }

    return model, meta_dict


def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    Sets ``requires_grad = False`` on each item of ``model.parameters()``;
    does nothing (and does not touch the model) otherwise.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False


def _strip_wrapper_prefix(state, prefix='1.'):
	"""Rename checkpoint keys in place, keeping only those that start with ``prefix``.

	Every key beginning with ``prefix`` is re-inserted without it; every other
	key is removed. The same (mutated) mapping is returned.
	"""
	# NOTE(review): the '1.' prefix presumably comes from a checkpoint saved from
	# an nn.Sequential(normalize, model) wrapper, where the network is child '1'
	# -- confirm against the training script.
	# Iterate over a snapshot of the keys because the mapping is mutated below.
	for key in list(state.keys()):
		value = state.pop(key)
		if key.startswith(prefix):
			state[key[len(prefix):]] = value
	return state


def _load_poisoned_backbone(model_ft):
	"""Load ``poisoned_model.pt`` from ``checkpointDir`` into ``model_ft``.

	Only the checkpoint's ``'state_dict'`` entry is used, after its keys have
	been passed through ``_strip_wrapper_prefix``.
	"""
	# map_location="cpu": the freshly built model lives on the CPU here, so the
	# checkpoint does not need the GPU it was saved from to be present.
	checkpoint = torch.load(os.path.join(checkpointDir, "poisoned_model.pt"), map_location="cpu")
	model_ft.load_state_dict(_strip_wrapper_prefix(checkpoint['state_dict']))


def _build_deit(embed_dim, num_heads):
	"""Build a patch-16, depth-12 ``TokenDropVisionTransformer`` of the given width and head count."""
	model_ft = TokenDropVisionTransformer(
		patch_size=16, embed_dim=embed_dim, depth=12, num_heads=num_heads, mlp_ratio=4, qkv_bias=True,
		norm_layer=partial(nn.LayerNorm, eps=1e-6))
	# NOTE(review): the `_cfg` import at the top of this file is commented out;
	# confirm `_cfg` is defined elsewhere, otherwise this line raises NameError.
	model_ft.default_cfg = _cfg()
	return model_ft


def initialize_model(model_name, num_classes, feature_extract, use_pretrained=False):
	"""Build the network selected by ``model_name`` and load its weights where configured.

	The final classification layer is never replaced (the only layer swapped is
	Inception's auxiliary classifier). Depending on the branch, weights are left
	at their initial values, loaded from ``checkpointDir/poisoned_model.pt``, or
	(``deit_small_patch16_224``) downloaded from the DeiT release URL.

	Args:
		model_name: one of "resnet", "alexnet", "vgg", "vgg16", "vgg16_bn",
			"resnet18", "resnet50", "squeezenet", "densenet", "inception",
			"deit_tiny_patch16_224", "deit_small_patch16_224",
			"deit_base_patch16_224".
		num_classes: only used by the "squeezenet" branch (stored on the model)
			and the "inception" branch (size of the new auxiliary classifier).
		feature_extract: forwarded to ``set_parameter_requires_grad`` in the
			branches that call it; truthy freezes all parameters.
		use_pretrained: accepted for interface compatibility but ignored; every
			torchvision constructor is called with ``pretrained=False``.

	Returns:
		``(model_ft, input_size, layers)``. ``layers`` is a one-element list
		holding the last feature stage for "vgg16", "vgg16_bn", "resnet18" and
		"resnet50" (presumably the Grad-CAM target layers -- confirm with the
		caller) and ``None`` for every other architecture.

	Raises:
		SystemExit: if ``model_name`` is not recognised.
	"""
	# Branches that select no layers used to leave `layers` unbound, which made
	# the final `return` raise UnboundLocalError; default it explicitly.
	layers = None
	# Every architecture takes 224x224 inputs except Inception v3 (set below).
	input_size = 224

	if model_name == "resnet":
		# ResNet-18 without checkpoint loading.
		model_ft = models.resnet18(pretrained=False)
		set_parameter_requires_grad(model_ft, feature_extract)

	elif model_name == "alexnet":
		model_ft = models.alexnet(pretrained=False)
		set_parameter_requires_grad(model_ft, feature_extract)

	elif model_name == "vgg":
		# VGG11 with batch norm, no checkpoint loading.
		model_ft = models.vgg11_bn(pretrained=False)

	elif model_name == "vgg16":
		model_ft = models.vgg16(pretrained=False)
		_load_poisoned_backbone(model_ft)
		layers = [model_ft.features]

	elif model_name == "vgg16_bn":
		model_ft = models.vgg16_bn(pretrained=False)
		_load_poisoned_backbone(model_ft)
		layers = [model_ft.features]

	elif model_name == "resnet18":
		model_ft = models.resnet18(pretrained=False)
		_load_poisoned_backbone(model_ft)
		layers = [model_ft.layer4]

	elif model_name == "resnet50":
		model_ft = models.resnet50(pretrained=False)
		_load_poisoned_backbone(model_ft)
		layers = [model_ft.layer4]

	elif model_name == "squeezenet":
		model_ft = models.squeezenet1_0(pretrained=False)
		set_parameter_requires_grad(model_ft, feature_extract)
		# Only the attribute is updated; the classifier conv itself is not replaced.
		model_ft.num_classes = num_classes

	elif model_name == "densenet":
		model_ft = models.densenet121(pretrained=False)
		set_parameter_requires_grad(model_ft, feature_extract)

	elif model_name == "inception":
		# Inception v3: expects (299, 299) images and has an auxiliary output.
		model_ft = models.inception_v3(pretrained=False, transform_input=True)
		set_parameter_requires_grad(model_ft, feature_extract)
		# Replace the auxiliary classifier only; the primary `fc` is kept as is.
		aux_in_features = model_ft.AuxLogits.fc.in_features
		model_ft.AuxLogits.fc = nn.Linear(aux_in_features, num_classes)
		input_size = 299

	elif model_name == 'deit_tiny_patch16_224':
		# No weights are loaded in this branch.
		model_ft = _build_deit(embed_dim=192, num_heads=3)
		set_parameter_requires_grad(model_ft, feature_extract)

	elif model_name == 'deit_small_patch16_224':
		model_ft = _build_deit(embed_dim=384, num_heads=6)
		checkpoint = torch.hub.load_state_dict_from_url(
			url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
			map_location="cpu", check_hash=True
		)
		model_ft.load_state_dict(checkpoint["model"])
		set_parameter_requires_grad(model_ft, feature_extract)

	elif model_name == 'deit_base_patch16_224':
		model_ft = _build_deit(embed_dim=768, num_heads=12)
		# Unlike the CNN branches, the checkpoint keys are used without renaming.
		checkpoint = torch.load(os.path.join(checkpointDir, "poisoned_model.pt"), map_location="cpu")
		model_ft.load_state_dict(checkpoint['state_dict'])

	else:
		print("Invalid model name, exiting...")
		# sys.exit() instead of the site-provided exit() builtin, matching the
		# rest of this script; the exit status is unchanged.
		sys.exit()

	return model_ft, input_size, layers

def adjust_learning_rate(optimizer, epoch, base_lr=None, step=10, gamma=0.1):
	"""Set the learning rate of every param group to a step-decayed value.

	The rate written is ``base_lr * gamma ** (epoch // step)``; with the
	defaults the initial rate is divided by 10 every 10 epochs.

	Args:
		optimizer: object exposing ``param_groups``, a sequence of dicts whose
			``'lr'`` entry is overwritten.
		epoch: current epoch index.
		base_lr: initial learning rate; when ``None`` the module-level ``lr``
			read from the config is used (the original behaviour).
		step: number of epochs between two decays.
		gamma: multiplicative decay factor applied at each step.
	"""
	if base_lr is None:
		# Reading a module-level name needs no `global` declaration.
		base_lr = lr
	decayed_lr = base_lr * (gamma ** (epoch // step))
	for param_group in optimizer.param_groups:
		param_group['lr'] = decayed_lr


# Build the network under evaluation; `layers` is whatever initialize_model
# selected for the chosen architecture.
print("Loading poisoned model...")
model_ft, input_size, layers = initialize_model(
	model_name, num_classes, feature_extract, use_pretrained=False)

# Per-channel statistics shared by all normalisation transforms below.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Input pipeline: resize to the model's input size, convert to tensor, normalise.
data_transforms = transforms.Compose([
	transforms.Resize((input_size, input_size)),
	transforms.ToTensor(),
	transforms.Normalize(mean=list(_norm_mean), std=list(_norm_std)),
])

# Inverse of the normalisation: first multiply by std (divide by 1/std with a
# zero mean), then add the mean back (subtract -mean with unit std).
_undo_scale = transforms.Normalize(mean=[0.0] * 3, std=[1 / s for s in _norm_std])
_undo_shift = transforms.Normalize(mean=[-m for m in _norm_mean], std=[1.0] * 3)
invTrans = transforms.Compose([_undo_scale, _undo_shift])

# Normalisation alone, for tensors that are already resized and in tensor form.
normalize_fn = transforms.Compose([
	transforms.Normalize(mean=list(_norm_mean), std=list(_norm_std)),
])


# (invTrans(inp_tensor) undoes the input normalisation of a tensor.)
print('Initializing Datasets and Dataloaders...')

# The *_filelist.txt files are NOT regenerated by this script: the generation
# code that used to sit here was disabled (kept only inside string literals).
# For reference, it wrote "<entry> <label>" lines under
# data/transformer/<experimentID>/ as follows:
#   finetune_filelist.txt / test_filelist.txt
#       num_classes == 1000: every list in ImageNet_data_list/finetune/ (resp.
#       test/), labelled with its index in the sorted directory listing; the
#       index of target_wnid was remembered as target_index.
#       otherwise: each source wnid labelled with its position in
#       source_wnids, followed by target_wnid labelled num_source.
#   patched_filelist.txt
#       the test entries of every source wnid, labelled target_index
#       (num_classes == 1000) or num_source (otherwise).
#   poison_filelist.txt
#       basenames of the first num_poison files of the poison directory,
#       labelled target_index (num_classes == 1000) or num_source (otherwise).

# Directory holding the generated poisons for this experiment configuration.
saveDir = "/".join([
	poison_root,
	experimentID,
	"rand_loc_" + str(rand_loc),
	"eps_" + str(eps),
	"patch_size_" + str(patch_size),
	"trigger_" + str(trigger_id),
])
filelist = sorted(glob.glob(saveDir + "/*"))

# Abort early when fewer poisons exist on disk than the experiment asks for.
if len(filelist) < num_poison:
	print("You have not generated enough poisons to run this experiment! Exiting.")
	sys.exit()

# All filelists are read from data/<experimentID>/ and must already exist;
# nothing in the live part of this script creates them.
_lists_dir = "data/{}/".format(experimentID)
_val_root = clean_data_root + "/val"

dataset_clean = LabeledDataset(
	clean_data_root + "/train", _lists_dir + "finetune_filelist.txt", data_transforms)
dataset_test = LabeledDataset(
	_val_root, _lists_dir + "test_filelist.txt", data_transforms)
dataset_patched = LabeledDataset(
	_val_root, _lists_dir + "patched_filelist.txt", data_transforms)
# NOTE(review): 'notpatched' reads the same filelist as 'patched'; presumably
# the trigger is applied (or not) at evaluation time -- confirm in train_model.
dataset_notpatched = LabeledDataset(
	_val_root, _lists_dir + "patched_filelist.txt", data_transforms)
dataset_poison = LabeledDataset(
	saveDir, _lists_dir + "poison_filelist.txt", data_transforms)

# The training set is the clean data followed by the poison images.
dataset_train = torch.utils.data.ConcatDataset((dataset_clean, dataset_poison))

# One loader per phase; the patched/notpatched loaders keep file order and
# load in the main process.
dataloaders_dict = {
	'train': torch.utils.data.DataLoader(
		dataset_train, batch_size=batch_size, shuffle=True, num_workers=4),
	'test': torch.utils.data.DataLoader(
		dataset_test, batch_size=batch_size, shuffle=True, num_workers=4),
	'patched': torch.utils.data.DataLoader(
		dataset_patched, batch_size=batch_size, shuffle=False, num_workers=0),
	'notpatched': torch.utils.data.DataLoader(
		dataset_notpatched, batch_size=batch_size, shuffle=False, num_workers=0),
}

for _label, _dataset in (("clean", dataset_clean), ("poison", dataset_poison)):
	print("Number of {} images: {}".format(_label, len(_dataset)))


# Collect (and print the names of) the parameters that still require gradients.
# With feature_extract set, params_to_update becomes the explicit list of those
# parameters; otherwise it is simply every parameter of the model.
params_to_update = [] if feature_extract else model_ft.parameters()
for name, param in model_ft.named_parameters():
	if not param.requires_grad:
		continue
	if feature_extract:
		params_to_update.append(param)
	print(name)

# No optimizer is built (the optim.SGD construction is disabled), so
# train_model receives None in its place.
optimizer_ft = None

# Loss function handed to train_model.
criterion = nn.CrossEntropyLoss()

# Move the network to the configured GPU before running train_model.
model = model_ft.cuda(gpu)

# train_model returns the model together with a dict holding the 'Val_acc',
# 'Patched_acc' and 'NotPatched_acc' records.
model, meta_dict = train_model(
	model,
	dataloaders_dict,
	criterion,
	optimizer_ft,
	num_epochs=epochs,
	is_inception=(model_name == "inception"),
)

# Writing the result back with save_checkpoint(...) to
# checkpointDir/poisoned_model.pt ('arch', 'state_dict', 'meta_dict') is disabled.
