# coding: utf-8

import os
import os.path as osp
import pickle
import random
import argparse
import copy
import itertools

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import matplotlib.pyplot as plt

from projection_01_KKT import Projection_kkt

from utils import load_data, load_citation, accuracy,sparse_mx_to_torch_sparse_tensor,rl_state_exist
from models import GCN

from normalization import fetch_normalization, row_normalize

from env import Env


parser = argparse.ArgumentParser()
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--dataset', type=str, default="cora",
                    help='Dataset to use.')
# Parsing happens at import time. parse_known_args() ignores options meant for
# some other program (e.g. a driver script or test runner importing this
# module); parse_args() would call sys.exit() on them.
args = parser.parse_known_args()[0]
# CUDA only when it is available and not disabled with --no-cuda.
args.cuda = not args.no_cuda and torch.cuda.is_available()


class GNNBox(object):
	"""Black-box query wrapper and edge-flip attacks for a pretrained GNN.

	The constructor loads a citation dataset via ``load_citation`` and a model
	saved as ``nodegcn_<dataset>.pkl``. Every attack rewires only the edges
	of one target node (``self.targetNode``).

	An attack state ``Sd`` is a dict ``{"follow": [...], "unfollow": [...]}``
	that partitions all other nodes into neighbours / non-neighbours of the
	target. Perturbation vectors have ``num_nodes - 1`` entries: index ``i``
	stands for node ``i`` when ``i < targetNode`` and for node ``i + 1``
	otherwise (the target itself is skipped).
	"""

	def __init__(self, node=None, B1=None, B2=None, cost=None, dataset="cora"):
		"""Load data and model, and prepare the attack target.

		Args:
			node: target node id (default 825).
			B1: edge-flip budget used by the attacks (default 5).
			B2: stored as ``self.B2`` (default 9); not read by this class.
			cost: stored as ``self.cost``; defaults to
				``np.random.random(num_nodes - 1)``. Not read by this class.
			dataset: dataset name passed to ``load_citation``.

		Raises:
			FileNotFoundError: if ``nodegcn_<dataset>.pkl`` does not exist.
		"""
		adj, features, labels, idx_train, idx_val, idx_test, graph = load_citation(dataset, cuda=False)
		self.x = features
		self.y = labels
		self.tensor_adjacency = adj
		self.idx_test = idx_test
		self.graph = graph
		modelfile = "nodegcn_{}.pkl".format(dataset)
		if not osp.exists(modelfile):
			print("The GNN model is not found")
			# FileNotFoundError subclasses Exception, so existing handlers still match.
			raise FileNotFoundError("The GNN model is not found")
		# torch.load unpickles the file: only load trusted model files.
		self.model = torch.load(modelfile, map_location=torch.device('cpu'))

		num_nodes = labels.shape[0]
		# `is not None` rather than `!= None`: for an ndarray `cost`, `!=` is
		# elementwise and the conditional would raise "truth value is ambiguous".
		self.targetNode = node if node is not None else 825
		self.B1 = B1 if B1 is not None else 5
		self.B2 = B2 if B2 is not None else 9
		self.cost = cost if cost is not None else np.random.random(num_nodes - 1)
		self.acc0 = self.cal_accuracy()

		cache_file = 'targetnodeset.pkl'
		if osp.exists(cache_file):
			# Local cache written below; pickle must only be used on trusted files.
			with open(cache_file, "rb") as f:
				self.targetNodeSet = pickle.load(f)
		else:
			self.targetNodeSet = self._sample_target_nodes(num_nodes)
			with open(cache_file, "wb") as f:
				pickle.dump(self.targetNodeSet, f)

		self.original_Sd = self.adj2Sd(self.targetNode)

	def _sample_target_nodes(self, num_nodes, quota=100):
		"""Randomly draw up to ``quota`` distinct, correctly classified node ids.

		One forward pass is shared by all draws. ``quota`` is capped at the
		number of correctly classified nodes so the loop always terminates.
		"""
		self.model.eval()
		output = self.model(self.x, self.tensor_adjacency).cpu().detach().numpy()
		correct = {n for n in range(num_nodes) if self.y[n] == output[n].argmax()}
		quota = min(quota, len(correct))
		chosen = []
		while len(chosen) < quota:
			node = random.choice(range(num_nodes))
			if node in correct and node not in chosen:
				chosen.append(node)
		return chosen

	@staticmethod
	def build_adjacency(adj_dict):
		"""Build a symmetric scipy COO adjacency matrix from an adjacency-list dict.

		Every ``src -> dst`` entry contributes both ``(src, dst)`` and
		``(dst, src)``. Duplicates are collapsed, so every stored value is 1.
		The matrix is ``len(adj_dict) x len(adj_dict)``, so node ids are
		expected to lie in ``range(len(adj_dict))``. A graph without edges
		yields an all-zero matrix.
		"""
		num_nodes = len(adj_dict)
		edges = set()
		for src, dsts in adj_dict.items():
			for dst in dsts:
				edges.add((src, dst))
				edges.add((dst, src))
		# sorted() keeps the row-major entry order; reshape keeps (0, 2) when empty.
		edge_index = np.asarray(sorted(edges), dtype=np.int64).reshape(-1, 2)
		return sp.coo_matrix(
			(np.ones(len(edge_index)), (edge_index[:, 0], edge_index[:, 1])),
			shape=(num_nodes, num_nodes), dtype="float32")

	def cal_accuracy(self):
		"""Return the model's accuracy on ``self.idx_test`` as a Python float.

		Side effect: puts the model in eval mode (``self.model.eval()``).
		"""
		self.model.eval()
		output = self.model(self.x, self.tensor_adjacency)
		return accuracy(output[self.idx_test], self.y[self.idx_test]).item()

	def adj2Sd(self, node):
		"""Return the attack state of ``node`` in the unperturbed graph.

		"follow" is a copy of ``self.graph[node]``. "unfollow" lists every other
		node id in ascending order, excluding ``node`` itself.
		"""
		num_nodes = self.y.shape[0]
		friends = self.graph[node][:]
		friend_set = set(friends)
		unfriends = [i for i in range(num_nodes) if i not in friend_set and i != node]
		return {"follow": friends, "unfollow": unfriends}

	def queryNode(self, node):
		"""Return the model's output row for ``node`` on the unperturbed graph."""
		return self.model(self.x, self.tensor_adjacency)[node].cpu().detach().numpy()

	def _rewired_graph(self, s_dict):
		"""Return a copy of ``self.graph`` whose target neighbours are exactly ``s_dict["follow"]``.

		``build_adjacency`` adds both directions of every listed edge. An edge
		that only the neighbour lists (``target in graph[j]``) would therefore
		survive an "unfollow". The target is also removed from those
		neighbours' lists. ``self.graph`` is never modified: the copy is
		shallow and affected lists are replaced, not mutated.
		"""
		target = self.targetNode
		follow = list(s_dict["follow"])
		follow_set = set(follow)
		graph = copy.copy(self.graph)
		graph[target] = follow
		for node, neighbours in self.graph.items():
			if node != target and node not in follow_set and target in neighbours:
				graph[node] = [v for v in neighbours if v != target]
		return graph

	def perturb(self, s_dict):
		"""Return the normalized torch sparse adjacency after rewiring the target.

		Args:
			s_dict: attack state; its "follow" list becomes the target's
				exact neighbour set.
		"""
		adj = self.build_adjacency(self._rewired_graph(s_dict))
		# build_adjacency is already symmetric; this is a no-op safeguard.
		adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
		adj_normalizer = fetch_normalization("AugNormAdj")
		adj = adj_normalizer(adj)
		adj = sparse_mx_to_torch_sparse_tensor(adj).float()
		# Follow the features' device. The model and features are loaded on
		# the CPU, so moving only the adjacency to CUDA broke the forward pass.
		return adj.to(self.x.device)

	def queryBox(self, s_dict):
		"""Return the model's output row for the target under attack state ``s_dict``."""
		adjacency = self.perturb(s_dict)
		return self.model(self.x, adjacency)[self.targetNode].cpu().detach().numpy()

	# ----- helpers shared by the attacks -----

	@staticmethod
	def _copy_state(Sd):
		"""Return an independent copy of an attack state."""
		return {"follow": list(Sd["follow"]), "unfollow": list(Sd["unfollow"])}

	@staticmethod
	def _mark_unfollowed(Sd, node):
		"""Ensure ``node`` is in ``Sd["unfollow"]`` and not in ``Sd["follow"]``."""
		if node in Sd["follow"]:
			Sd["follow"].remove(node)
		if node not in Sd["unfollow"]:
			Sd["unfollow"].append(node)

	def _toggle(self, Sd, index):
		"""Flip, in place, the target's edge to the node encoded by ``index``.

		Returns the node id that was flipped. Removing an existing edge prints
		a debug banner and the node id.
		"""
		if index < self.targetNode:
			node, banner = index, "*" * 200
		else:
			node, banner = index + 1, "1*" * 200
		if node in Sd["follow"]:
			print(banner)
			print(node)
			Sd["follow"].remove(node)
			Sd["unfollow"].append(node)
		else:
			Sd["unfollow"].remove(node)
			Sd["follow"].append(node)
		return node

	def _margin(self, query_result):
		"""Return the true-class score minus the best other-class score.

		A negative value means the target is misclassified.
		"""
		label = int(self.y[self.targetNode])
		mask = np.ones(query_result.size, dtype=bool)  # np.bool was removed in NumPy 1.24
		mask[label] = False
		return query_result[label] - query_result[mask].max()

	def _project(self, vec, alpha):
		"""Project ``vec`` with ``Projection_kkt`` under budget ``self.B1``."""
		return Projection_kkt(len(vec), self.B1, vec, alpha).solution()

	def _top_indices(self, scores, randomize):
		"""Return the indices of the ``B1`` largest ``scores``, in descending order.

		With ``randomize`` each index is kept with probability
		``scores[i] / sum(top scores)``, using one ``np.random.random()`` draw
		per index in descending order.
		"""
		top = np.argsort(scores)[::-1][:self.B1]
		if not randomize:
			return top
		total = np.sum(scores[top])
		return [i for i in top if scores[i] / total >= np.random.random()]

	@staticmethod
	def _unit_direction(dim):
		"""Sample a direction of length ``dim`` from uniform(-1, 1) and scale it to unit norm."""
		u = np.random.uniform(-1, 1, dim)
		return u / np.sqrt(np.sum(u ** 2))

	def _bandit_step(self, v, u, loss, eta, delta, alpha):
		"""Take a one-point gradient step on ``v`` along ``u``, then project with ``alpha``."""
		coef = loss * len(v) / delta
		return self._project(v + eta * (coef * u), alpha)

	# ----- attacks -----

	def pgdocobandit_loss(self, T, eta, delta, alpha):
		"""Run the projected bandit attack for ``T`` rounds and record the margins.

		Each round restarts from the original neighbourhood. It samples
		``v + delta*u``, projects it, randomly rounds the top ``B1``
		coordinates into edge flips and queries the model. It then takes a
		bandit gradient step on ``v``.

		Returns:
			list of per-round margins (see ``_margin``).
		"""
		attack_loss = []
		num_nodes = self.y.shape[0]
		print(num_nodes)
		v = np.zeros(num_nodes - 1)
		for t in range(T):
			Sd = self._copy_state(self.original_Sd)
			u = self._unit_direction(num_nodes - 1)
			s_hat = self._project(v + delta * u, 1.0)
			for index in self._top_indices(s_hat, randomize=True):
				self._toggle(Sd, index)
			print("Bandit attack", Sd["follow"])
			loss = self._margin(self.queryBox(Sd))
			v = self._bandit_step(v, u, loss, eta, delta, alpha)
			attack_loss.append(loss)
		print("attack node", self.targetNode)
		print("original", self.original_Sd["follow"])
		return attack_loss

	def baseline_random_loss(self, T):
		"""Random-direction baseline: per-round margins for ``T`` random flip sets.

		Each round restarts from the original neighbourhood and randomly rounds
		a projected uniform-random vector into edge flips.
		"""
		attack_loss = []
		num_nodes = self.y.shape[0]
		for t in range(T):
			Sd = self._copy_state(self.original_Sd)
			rv = self._project(np.random.random(num_nodes - 1), 1.0)
			for index in self._top_indices(rv, randomize=True):
				self._toggle(Sd, index)
			attack_loss.append(self._margin(self.queryBox(Sd)))
		return attack_loss

	def pgdocobandit_succ(self, T, eta, delta, alpha):
		"""Run the bandit attack until the target is misclassified or ``T`` rounds pass.

		Returns:
			``(1, rounds_used)`` on success, otherwise ``(0, T)``.
		"""
		num_nodes = self.y.shape[0]
		print(num_nodes)
		v = np.zeros(num_nodes - 1)
		# NOTE(review): unlike pgdocobandit_loss, Sd is not reset each round, so
		# flips accumulate (and may undo each other) across rounds -- confirm.
		Sd = self._copy_state(self.original_Sd)
		for t in range(T):
			u = self._unit_direction(num_nodes - 1)
			s_hat = self._project(v + delta * u, 1.0)
			for index in self._top_indices(s_hat, randomize=False):
				self._toggle(Sd, index)
			print("Bandit attack", Sd["follow"])
			query_result = self.queryBox(Sd)
			if query_result.argmax() != self.y[self.targetNode]:
				print("Bandit attack successfully", t, self.targetNode)
				return 1, t + 1
			loss = self._margin(query_result)
			v = self._bandit_step(v, u, loss, eta, delta, alpha)
		print("attack node", self.targetNode)
		print("original", self.original_Sd["follow"])
		return 0, T

	def baseline_zoo_succ(self, T):
		"""Zeroth-order baseline with coordinate-wise Adam.

		Keeps a relaxed score vector ``x`` (1 for current neighbours). Each
		round it:
		- perturbs ``B1`` random coordinates by ``+h`` and ``-h``;
		- flips the top-``B1`` indices of each perturbed vector in two copies
		  of the current state;
		- uses the central difference of the two margins as a gradient estimate.

		Returns:
			``(1, rounds_used)`` as soon as either query is misclassified,
			otherwise ``(0, T)``.
		"""
		lr = 2e-1
		h = 0.1
		num_nodes = self.y.shape[0]
		# Adam state, one slot per perturbation coordinate
		mt = np.zeros(num_nodes - 1, dtype=np.float32)
		vt = np.zeros(num_nodes - 1, dtype=np.float32)
		countT = np.zeros(num_nodes - 1, dtype=np.int32)
		beta1 = 0.9
		beta2 = 0.999
		epsilon = 1e-3
		Sd = self._copy_state(self.original_Sd)
		x = np.zeros(num_nodes - 1)
		for node in Sd["follow"]:
			x[node if node < self.targetNode else node - 1] = 1
		for t in range(T):
			Sd1 = self._copy_state(Sd)
			Sd2 = self._copy_state(Sd)
			# .copy(), not x[:]: a numpy slice is a view, so the +h / -h below
			# would cancel and x1, x2 and x would all stay identical.
			x1 = x.copy()
			x2 = x.copy()
			for _ in range(self.B1):
				pi = np.random.choice(num_nodes - 1)
				x1[pi] += h
				x2[pi] -= h

			perturbindex1 = np.argsort(x1)[::-1][:self.B1]
			perturbindex2 = np.argsort(x2)[::-1][:self.B1]

			# NOTE(review): every flipped node is marked "unfollow" in the global
			# state, even when the flip added an edge -- confirm intended.
			for index in perturbindex1:
				self._mark_unfollowed(Sd, self._toggle(Sd1, index))
			query_result1 = self.queryBox(Sd1)

			for index in perturbindex2:
				self._mark_unfollowed(Sd, self._toggle(Sd2, index))
			query_result2 = self.queryBox(Sd2)

			gi = (self._margin(query_result1) - self._margin(query_result2)) / 2 / h
			# NOTE(review): B1 coordinates were perturbed, but only the last
			# sampled one (pi) receives the Adam update -- confirm intended.
			countT[pi] += 1
			mt[pi] = beta1 * mt[pi] + (1 - beta1) * gi
			vt[pi] = beta2 * vt[pi] + (1 - beta2) * gi ** 2
			Mi_tilde = mt[pi] / (1 - beta1 ** countT[pi])
			vi_tilde = vt[pi] / (1 - beta2 ** countT[pi])
			x[pi] += -lr * Mi_tilde / (np.sqrt(vi_tilde) + epsilon)

			label = self.y[self.targetNode]
			if query_result1.argmax() != label or query_result2.argmax() != label:
				print("zoo attack successfully", t, self.targetNode)
				return 1, t + 1
		return 0, T

	def baseline_Qlearning_succ(self, T):
		"""Tabular Q-learning baseline over attack states.

		Each episode starts an ``Env`` from the original neighbourhood. It runs
		until the env reports ``is_end`` or ``B1`` steps are taken.

		Returns:
			``(1, episodes_used)`` on success, otherwise ``(0, T)``.
		"""
		num_nodes = self.y.shape[0]
		print(num_nodes)
		EPSILON = 0.1
		ALPHA = 0.1
		GAMMA = 0.9
		MAX_STEP = self.B1
		state_list = [self.original_Sd]
		Q_list = [np.zeros(num_nodes - 1)]

		def register(state):
			"""Return the Q-table row index of ``state``, adding a row if unseen."""
			res = rl_state_exist(state_list, state)
			if res[0] == False:
				state_list.append(self._copy_state(state))
				Q_list.append(np.zeros(num_nodes - 1))
				return len(state_list) - 1
			return res[1]

		def epsilon_greedy(Q_index):
			"""Pick a random action with prob. EPSILON (or if the row is all zero), else argmax."""
			Q = Q_list[Q_index]
			if (np.random.uniform() > 1 - EPSILON) or ((Q == 0).all()):
				return np.random.randint(0, num_nodes - 1)
			return Q.argmax()

		for t in range(T):  # one episode per t
			e = Env(self._copy_state(self.original_Sd), self, self.B1, self.y[self.targetNode])
			while (e.is_end is False) and (e.step < MAX_STEP):
				# register() also covers an unseen current state, which
				# previously left Q_index stale or unbound.
				Q_index = register(e.Sd)
				action = epsilon_greedy(Q_index)
				true_action = action + 1 if action >= self.targetNode else action
				reward = e.interact(true_action)
				Q_index2 = register(e.Sd)
				Q_list[Q_index][action] = (1 - ALPHA) * Q_list[Q_index][action] + \
					ALPHA * (reward + GAMMA * Q_list[Q_index2].max())
			if e.is_end is True:
				print("RL attack successfully", t, self.targetNode)
				return 1, t + 1
		print("attack node", self.targetNode)
		print("original", self.original_Sd["follow"])
		return 0, T

	def baseline_random_succ(self, T):
		"""Random baseline: try ``T`` random flip sets until the target is misclassified.

		Each round restarts from the original neighbourhood and flips the top
		``B1`` indices of a projected uniform-random vector.

		Returns:
			``(1, rounds_used)`` on success, otherwise ``(0, T)``.
		"""
		num_nodes = self.y.shape[0]
		for t in range(T):
			Sd = self._copy_state(self.original_Sd)
			rv = self._project(np.random.random(num_nodes - 1), 1.0)
			for index in self._top_indices(rv, randomize=False):
				self._toggle(Sd, index)
			query_result = self.queryBox(Sd)
			if query_result.argmax() != self.y[self.targetNode]:
				print("Random attack successfully", t, self.targetNode)
				return 1, t + 1
			print("Random attack", Sd["follow"])
		print("attack node", self.targetNode)
		print("original", self.original_Sd["follow"])
		return 0, T




def main():
	"""Add random edges to the default target until its predicted class changes.

	Uses the dataset selected with ``--dataset``. Each round connects 5
	random non-neighbours and prints the model output and its argmax.
	"""
	blackbox = GNNBox(dataset=args.dataset)
	result1 = blackbox.queryNode(blackbox.targetNode)
	print(result1)
	print(result1.argmax())
	print(blackbox.y[blackbox.targetNode])

	# Copy the lists: the loop below mutates them and must not corrupt
	# blackbox.original_Sd.
	Sd = {"follow": list(blackbox.original_Sd["follow"]),
		"unfollow": list(blackbox.original_Sd["unfollow"])}

	orilabel = blackbox.y[blackbox.targetNode].cpu().detach().numpy()
	print(orilabel)
	result2 = orilabel
	while result2 == orilabel:
		# random.sample raises ValueError when fewer than 5 candidates remain.
		if len(Sd["unfollow"]) < 5:
			print("No candidate edges left; prediction unchanged")
			break
		ch = random.sample(Sd["unfollow"], 5)
		for node in ch:
			Sd["follow"].append(node)
			Sd["unfollow"].remove(node)
		result2 = blackbox.queryBox(Sd)
		print(result2)
		print(result2.argmax())
		result2 = result2.argmax()

if __name__ == '__main__':
	main()