import torch, os, time, random, generator, discri, classify, mi_utils
import numpy as np 
import torch.nn as nn
import torchvision.utils as tvls
import torch.nn.functional as F
from mi_utils import log_sum_exp, save_tensor_images
from torch.autograd import Variable
import torch.optim as optim
import torch.autograd as autograd
import statistics 
import torch.distributions as tdist
from biGAN import freeze, unfreeze
from copy import deepcopy
import hypergrad as hg
import higher
from perturbations import *
from tensorboardX import SummaryWriter



def to_var(x, requires_grad=True):
	"""Wrap ``x`` with ``Variable``, moving it to the GPU first when CUDA is available.

	Args:
		x: input tensor.
		requires_grad: ``requires_grad`` flag passed to ``Variable``.

	Returns:
		The wrapped tensor (on CUDA if available, otherwise on ``x``'s device).
	"""
	device_x = x.cuda() if torch.cuda.is_available() else x
	return Variable(device_x, requires_grad=requires_grad)


def update_params(z_meta, lr_inner, first_order=False, source_params=None, detach=False):
	"""Return ``z_meta`` after one gradient-descent step.

	Args:
		z_meta: tensor being updated.
		lr_inner: step size.
		first_order: if True, the step uses a detached copy of the gradient
			(re-wrapped by ``to_var``), so no graph is kept through it.
		source_params: the gradient of the inner loss w.r.t. ``z_meta``; required.
		detach: unused; kept for interface compatibility.

	Returns:
		``z_meta - lr_inner * grad``.

	Raises:
		ValueError: if ``source_params`` is None.
	"""
	if source_params is None:
		raise ValueError("Source params cannot be None.")

	step_grad = to_var(source_params.detach().data) if first_order else source_params
	return z_meta - lr_inner * step_grad


def inversion_bi(G, D, T, E, iden, save_img_dir, itr, lamda2, lamda3, lr=2e-2, momentum=0.9, lamda=100, iter_times=1500, clip_range=1, improved=False, num_seeds=5,args=None, clean_data=None):
	"""Run a latent-space model-inversion attack over several seeds and report accuracy.

	For each seed, a batch of latent codes ``z`` is optimised with a Nesterov-style
	momentum update. The goal is for ``G(z)`` to be classified by ``T`` as ``iden``,
	with a prior term computed from ``D``. Optional extras:

	* ``args.per``      -- bi-level step: a shared image perturbation is trained to
	                       lower the cosine similarity between ``T``'s logits on clean
	                       and perturbed images, while ``z`` is trained to raise it.
	* ``args.grad_reg`` -- add ``lamda2`` times the mean L2 norm of
	                       d(Iden_Loss)/d(theta_T) to the objective.
	* ``clean_data``    -- ``(x, y)`` tensors; an L1 loss pulls ``T``'s features of
	                       ``G(z)`` towards those of ``x[iden[0]:iden[0]+bs]``.

	Args:
		G, D, T, E: generator, prior model, target classifier and an extra model;
			all are switched to eval mode (``E`` is not otherwise used here).
		iden: target labels, flattened to shape (bs,).
		save_img_dir: directory for correctly classified images; TensorBoard logs
			go to ``save_img_dir + '/logs/'``.
		itr, lamda3: unused; kept for interface compatibility.
		lamda2: weight of the gradient-regulariser term.
		lr, momentum: step size and momentum of the latent update (``lr`` is also
			the virtual-step size when ``args.per`` is set).
		lamda: weight of the identity loss and of the feature-consistency loss.
		iter_times: optimisation steps per seed.
		clip_range: ``z`` is clamped to ``[-clip_range, clip_range]`` after each step.
		improved: if truthy, ``D`` returns a pair whose second element is the prior
			output, and the softplus/log-sum-exp prior loss is used.
		num_seeds: number of seeds to run, starting from seed 100.
		args: namespace with boolean ``per`` / ``grad_reg``; ``None`` disables both.
		clean_data: optional ``(x, y)`` pair of tensors.

	Returns:
		``(acc, acc_5, acc_var, acc_var5)``: mean top-1 / top-5 attack accuracy over
		seeds and their sample variances (0.0 when only one seed was run).
	"""
	per = bool(getattr(args, 'per', False))
	grad_reg = bool(getattr(args, 'grad_reg', False))

	log_path = save_img_dir + '/logs/'

	iden = iden.view(-1).long().cuda()
	criterion = nn.CrossEntropyLoss().cuda()
	cos = torch.nn.CosineSimilarity(dim=1)  # row-wise cosine similarity of logits
	bs = iden.shape[0]

	G.eval()
	D.eval()
	T.eval()
	E.eval()

	fet_cln = None
	if clean_data:
		x_cln, y_cln = clean_data
		x_cln = x_cln[iden[0]:iden[0] + bs].cuda()
		y_cln = y_cln[iden[0]:iden[0] + bs].cuda()
		assert torch.equal(y_cln, iden)
		print("Using clean data.")
		# The clean batch is fixed, so its features are computed once, not per step.
		with torch.no_grad():
			fet_cln = T(x_cln)[0]

	start_seed = 100
	# Per-iteration TensorBoard scalars are written only for the first seed of the
	# batch whose first identity is 0. The original compared against seed 0, which
	# never occurs because seeds start at ``start_seed``.
	log_batch = bs > 0 and iden[0].item() == 0

	def gen_seed(random_seed, writer):
		"""Optimise a batch of latent codes initialised from ``random_seed``; return ``z``."""
		torch.manual_seed(random_seed)
		torch.cuda.manual_seed(random_seed)
		np.random.seed(random_seed)
		random.seed(random_seed)

		z = torch.randn(bs, 100).cuda().float()
		z.requires_grad = True
		v = torch.zeros(bs, 100).cuda().float()

		with torch.no_grad():
			img_size = G(z).shape[-1]
		# NOTE(review): assumes 3-channel square images from G -- confirm.
		batch_pert = torch.zeros((3, img_size, img_size)).cuda()
		batch_pert.requires_grad = True
		batch_opt = torch.optim.Adam(params=[batch_pert], lr=1e-3)

		portion = 0.1  # fraction of the batch that receives the perturbation
		n_patched = int(bs * portion)
		log_this_seed = log_batch and random_seed == start_seed

		for i in range(iter_times):
			if per:
				# ---- virtual update of a copy of z ----
				# batch_pert must require grad *during* the virtual step. Otherwise the
				# create_graph step below does not depend on it, and the meta loss sees
				# no hypergradient through z_meta. The original only re-enabled
				# it after this step, so the hypergradient existed only when i == 0.
				batch_pert.requires_grad = True
				z_meta = z.detach().clone().requires_grad_(True)
				fake_meta = G(z_meta)
				ori_logits = T(fake_meta)[-1]

				rand_idx = random.sample(range(bs), n_patched)
				patching = torch.zeros_like(fake_meta)
				patching[rand_idx] = batch_pert
				per_logits = T(torch.clamp(fake_meta + patching, min=0, max=1))[-1]
				inner_loss = -torch.mean(cos(per_logits, ori_logits))  # descent here maximises similarity

				# create_graph=True keeps the virtual step differentiable for the meta loss.
				grads = torch.autograd.grad(inner_loss, z_meta, create_graph=True)
				z_meta = update_params(z_meta, lr_inner=lr, source_params=grads[0])
				del grads

				# ---- update the shared perturbation (minimise similarity) ----
				fake_meta = G(z_meta)
				ori_logits = T(fake_meta)[-1]
				# NOTE(review): unlike the patched virtual/real steps, this perturbs the
				# whole batch; kept as in the original -- confirm it is intended.
				per_logits = T(torch.clamp(fake_meta + batch_pert, min=0, max=1))[-1]
				meta_loss = torch.mean(cos(per_logits, ori_logits))
				batch_opt.zero_grad()
				meta_loss.backward()
				batch_opt.step()
				# Frozen for the real step so Total_Loss does not back-propagate into it.
				batch_pert.requires_grad = False

				# ---- robustness term for the real update of z ----
				fake = G(z)
				ori_logits = T(fake)[-1]
				rand_idx = random.sample(range(bs), n_patched)
				patching = torch.zeros_like(fake)
				patching[rand_idx] = batch_pert
				per_logits = T(torch.clamp(fake + patching, min=0, max=1))[-1]
				Iden_Loss_per = -torch.mean(cos(per_logits, ori_logits))  # maximise similarity
			else:
				fake = G(z)
				Iden_Loss_per = 0

			# ---- regular inversion losses ----
			if improved:
				_, label = D(fake)
			else:
				label = D(fake)
			feat_gen, out = T(fake)

			if improved:
				Prior_Loss = torch.mean(F.softplus(log_sum_exp(label))) - torch.mean(log_sum_exp(label))
			else:
				Prior_Loss = -label.mean()

			Iden_Loss = criterion(out, iden)

			# ---- gradient regulariser: mean L2 norm of d(Iden_Loss)/d(theta_T) ----
			# create_graph=True makes the penalty differentiable w.r.t. z. The original
			# read p.grad after a plain backward(), which yields a constant, so
			# lamda2 * loss_grad never affected the update.
			loss_grad = 0
			if grad_reg:
				t_params = [p for p in T.parameters() if p.requires_grad]
				if t_params:
					param_grads = torch.autograd.grad(Iden_Loss, t_params, create_graph=True, retain_graph=True, allow_unused=True)
					norms = [torch.norm(g, 2) for g in param_grads if g is not None]
					if norms:
						loss_grad = sum(norms) / len(norms)

			# ---- feature consistency with the clean batch ----
			Feat_Loss = 0
			if fet_cln is not None:
				Feat_Loss = torch.mean((feat_gen - fet_cln).abs())

			Total_Loss = Prior_Loss + lamda * Iden_Loss + lamda2 * loss_grad + Iden_Loss_per + lamda * Feat_Loss

			# z is rebuilt as a fresh leaf at the end of every step, so z.grad starts empty.
			Total_Loss.backward()

			# Nesterov-style momentum step on z, followed by clamping.
			v_prev = v.clone()
			gradient = z.grad.data
			v = momentum * v - lr * gradient
			z = z.detach() + (-momentum * v_prev + (1 + momentum) * v)
			z = torch.clamp(z, -clip_range, clip_range).float()
			z.requires_grad = True

			Prior_Loss_val = Prior_Loss.item()
			Iden_Loss_val = Iden_Loss.item()
			Iden_Loss_per_val = float(Iden_Loss_per)
			Feat_Loss_val = float(Feat_Loss)
			loss_grad_val = float(loss_grad)

			if log_this_seed:
				writer.add_scalar('prior_loss', Prior_Loss_val, i)
				writer.add_scalar('iden_loss', Iden_Loss_val, i)
				writer.add_scalar('per_loss', Iden_Loss_per_val, i)
				writer.add_scalar('grad_loss', loss_grad_val, i)
				writer.add_scalar('feat_loss', Feat_Loss_val, i)

			if (i + 1) % 300 == 0:
				with torch.no_grad():
					eval_prob = T(G(z))[-1]
				eval_iden = torch.argmax(eval_prob, dim=1).view(-1)
				acc = iden.eq(eval_iden.long()).sum().item() * 1.0 / bs
				print("Iteration:{}\tPrior Loss:{:.5f}\tIden Loss:{:.5f}\tIden Loss per:{:.5f}\tGrad Loss:{:.5f}\tFeat Loss:{:.5f}\tAttack Acc:{:.5f}".format(i+1, Prior_Loss_val, Iden_Loss_val, Iden_Loss_per_val, loss_grad_val, Feat_Loss_val, acc))

		return z

	if save_img_dir is not None:
		os.makedirs(save_img_dir, exist_ok=True)

	res = []
	res5 = []
	targets = iden.view(-1, 1)

	writer = SummaryWriter(log_path)
	try:
		for random_seed in range(start_seed, start_seed + num_seeds):
			tf = time.time()
			print("Generating seed {}....".format(random_seed))
			z = gen_seed(random_seed, writer)

			# One generator pass serves both scoring and image saving.
			with torch.no_grad():
				fake = G(z)
				eval_prob = T(fake)[-1]
			eval_iden = torch.argmax(eval_prob, dim=1).view(-1)

			cnt = 0
			for i in range(bs):
				gt = iden[i].item()
				if eval_iden[i].item() == gt:
					cnt += 1
					save_tensor_images(fake[i].detach(), os.path.join(save_img_dir, "attack_iden_{}_{}.png".format(gt, random_seed)))

			top5_idx = torch.topk(eval_prob, 5, dim=1).indices
			cnt5 = int((top5_idx == targets).any(dim=1).sum().item())

			interval = time.time() - tf
			acc = cnt * 1.0 / bs
			print("Time:{:.5f}\tAcc:{:.5f}\t".format(interval, acc))

			res.append(acc)
			res5.append(cnt5 * 1.0 / bs)
			torch.cuda.empty_cache()
	finally:
		writer.close()

	acc, acc_5 = statistics.mean(res), statistics.mean(res5)
	# statistics.variance needs at least two data points; one seed has no spread.
	acc_var = statistics.variance(res) if len(res) > 1 else 0.0
	acc_var5 = statistics.variance(res5) if len(res5) > 1 else 0.0
	print("Acc:{:.5f}\tAcc_5:{:.5f}\tAcc_var:{:.4f}\tAcc_var5:{:.4f}".format(acc, acc_5, acc_var, acc_var5))

	return acc, acc_5, acc_var, acc_var5






