from mi_utils import *
from classify import *
from generator import *
from discri import *
from torch.utils.data import DataLoader
from torch.optim import Adadelta, Adam
from torch.nn import BCELoss, DataParallel
from torchvision.utils import save_image
from torch.autograd import grad
import torchvision.transforms as transforms
import torch
import time
import random
import os, logging
import numpy as np
from mi_bilevel_new33 import inversion_bi

from generator import Generator
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from copy import deepcopy


# logger
def get_logger():
	"""Return the "main-logger" logger configured to emit INFO-level records.

	A ``StreamHandler`` with a timestamped format is attached only the first
	time this is called. Later calls reuse the existing handler instead of
	stacking duplicates, which would print every record several times.

	Returns:
		logging.Logger: the shared "main-logger" instance.
	"""
	logger = logging.getLogger("main-logger")
	logger.setLevel(logging.INFO)
	# Loggers are process-wide singletons keyed by name, so guard against
	# attaching a second handler on repeated calls.
	if not logger.handlers:
		fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
		handler = logging.StreamHandler()
		handler.setFormatter(logging.Formatter(fmt))
		logger.addHandler(handler)
	return logger


# Dataset identifiers. public_name selects the GAN checkpoints loaded from
# ./binaryGAN/; target_name is used to name the output/log directories.
public_name, target_name = 'tsrd', 'gtsrb'
# Number of classes of the target classifier; the attack iterates over
# identities 0 .. n_private_classes - 1.
n_private_classes = 43


if __name__ == "__main__":
	global args, logger

	parser = ArgumentParser(description='Step2: targeted recovery')
	parser.add_argument('--device', type=str, default='4,5,6,7', help='Device to use. Like cuda, cuda:0 or cpu')
	parser.add_argument('--improved_flag', action='store_true', default=False, help='use improved k+1 GAN')
	parser.add_argument('--dist_flag', action='store_true', default=False, help='use distributional recovery')
	parser.add_argument('--grad_reg', action='store_true', default=False, help='add gradient regularizer')
	parser.add_argument('--per', action='store_true', default=False, help='add gradient regularizer')
	parser.add_argument('--model_name',  type=str, default='gtsrbsmooth_6_02', help='target model')
	parser.add_argument('--lamda2',  type=int, default=1, help='weight of grad loss')
	parser.add_argument('--lamda3',  type=int, default=0, help='weight of ft loss')
	
	args = parser.parse_args()
	logger = get_logger()

	logger.info(args)
	logger.info("=> creating model ...")

	log_path = "./mi_{}_binew33_ablation".format(target_name)
    # log_path = "./mi_{}_uapnew4".format(target_name)
	os.makedirs(log_path, exist_ok=True)
	log_file = "{}.txt".format(args.model_name)
	Tee(os.path.join(log_path, log_file), 'w')

	print("=> Using grad_red:", args.grad_reg)
	print("=> Using per_loss:", args.per)
	os.environ["CUDA_VISIBLE_DEVICES"] = args.device
   
   
	
	
	z_dim = 100
	###########################################
	###########     load model       ##########
	###########################################
	G = GeneratorCIFAR(z_dim)
	G = torch.nn.DataParallel(G).cuda()
	D = DGWGAN32(3)
	path_G = './binaryGAN/{}_G.tar'.format(public_name)
	path_D = './binaryGAN/{}_D.tar'.format(public_name)

	
	lamda1 = 1000

	# save_img_dir = log_path + '/{}_all_prior1_lamda1_{}_lamda2{}'.format(args.model_name, lamda1, args.lamda2) # all attack imgs
	save_img_dir = './mi_{}_binew33_ablation/{}_all_prior1_grad{}_per{}'.format(target_name, args.model_name, args.grad_reg, args.per) # all attack imgs
	os.makedirs(save_img_dir, exist_ok=True)

	
	D = torch.nn.DataParallel(D).cuda()
	ckp_G = torch.load(path_G)
	G.load_state_dict(ckp_G['state_dict'], strict=False)
	ckp_D = torch.load(path_D)
	D.load_state_dict(ckp_D['state_dict'], strict=False)

	T = VGG('small_VGG16').to(device)
	if args.model_name == 'wanet':
		key = 'netC'
		path_T = '/zero-knowledge-backdoor/1_MI_IBAU/GTSRB_model_and_eval/checkpoint/gtsrb_all2one_morph.pth.tar'
	elif args.model_name == 'iab':
		key = 'netC'
		T = PreActResNet18(num_classes=n_private_classes).cuda()
		path_T = '/zero-knowledge-backdoor/1_MI_IBAU/GTSRB_model_and_eval/checkpoint/iab_all2one_gtsrb_ckpt.pth.tar'
	else:
		key = 'net'
		path_T = '/zero-knowledge-backdoor/1_MI_IBAU/GTSRB_model_and_eval/checkpoint/{}_ckpt.pth'.format(args.model_name)
	ckp_T = torch.load(path_T)
	T.load_state_dict(ckp_T[key])

	E = deepcopy(T)

	inver_func = inversion_bi
	num_seeds = 20



	############         attack     ###########
	logger.info("=> Begin attacking ...")

	aver_acc, aver_acc5, aver_var, aver_var5 = 0, 0, 0, 0
	for i in range(1):
		iden = [torch.from_numpy(np.arange(n_private_classes))]

		for idx in range(1):
			print("--------------------- Attack batch [%s]------------------------------" % idx)
			
			acc, acc5, var, var5 = inver_func(G, D, T, E, iden[idx], save_img_dir, itr=i, lamda2=args.lamda2, lamda3=args.lamda3, lr=2e-2, momentum=0.9, lamda=lamda1, iter_times=4500, clip_range=1, improved=args.improved_flag, num_seeds=num_seeds, args=args)
			
			aver_acc += acc / 1
			aver_acc5 += acc5 / 1
			aver_var += var / 1
			aver_var5 += var5 / 1

	print("Average Acc:{:.2f}\tAverage Acc5:{:.2f}\tAverage Acc_var:{:.4f}\tAverage Acc_var5:{:.4f}".format(aver_acc, aver_acc5, aver_var, aver_var5))
