from r_gugnn import R_GUGNN
import torch
from deeprobust.graph.data import Dataset
from attacked_data import PrePtbDataset
from deeprobust.graph.utils import preprocess, encode_onehot, get_train_val_test
import numpy as np
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
dataset = 'cora'    # dataset name given to Dataset / PrePtbDataset
seed = 15           # seed used for the torch / numpy RNGs below
attack = 'meta'     # attack_method given to PrePtbDataset
ptb_rate = 0.15     # perturbation rate; 0 selects the clean adjacency

# R_GUGNN hyperparameters, forwarded unchanged as lamba / beta / c.
lamba = 2.5
beta = 2
c = 1e-5
# m is the number of reconstructions for the graph (forwarded as `iterations`).
m = 3

# Training settings.
epochs = 300        # forwarded to model.fit as train_iters
dropout = 0.5       # forwarded to R_GUGNN as dropout
hidden = 16         # forwarded to R_GUGNN as nhid

# Load the clean graph together with its train/val/test node indices
# (split setting 'prognn').
data = Dataset(root='./dataset/', name=dataset, setting='prognn')
adj = data.adj
features = data.features
labels = data.labels
idx_train = data.idx_train
idx_val = data.idx_val
idx_test = data.idx_test

# Optional nettack test set: restrict idx_test to nodes whose column sum in
# adj (the node degree, for a symmetric adjacency) exceeds 10.
# Uncomment to enable.
# degrees = adj.sum(0)
# degrees = np.ravel(degrees)
# temp = np.argwhere([degrees > 10])
# temp = temp[:, 1]
# idx_test = np.intersect1d(idx_test, temp)
# Use the first GPU when one is available. Otherwise fall back to the CPU
# instead of failing later (in preprocess / model code) on CUDA-less machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Safe without CUDA: PyTorch silently ignores this call in that case.
torch.cuda.manual_seed(seed)

# Adjacency the model is trained on: the clean graph when ptb_rate == 0,
# otherwise the pre-attacked graph loaded through PrePtbDataset.
if ptb_rate == 0:
    perturbed_adj = adj
else:
    perturbed_data = PrePtbDataset(root='./dataset/',
                                   name=dataset,
                                   attack_method=attack,
                                   ptb_rate=ptb_rate)
    perturbed_adj = perturbed_data.adj

# Alternative: random edge-addition attack. Uncomment to use it. Note that
# it overwrites the perturbed_adj selected above (so the PrePtbDataset load
# becomes unnecessary when this is enabled).
# from deeprobust.graph.global_attack import Random
# import random
# random.seed(seed)
# np.random.seed(seed)
# attacker = Random()
# n_perturbations = int(ptb_rate * (adj.sum() // 2))
# attacker.attack(adj, n_perturbations, type='add')
# perturbed_adj = attacker.modified_adj
        
# Re-seed NumPy and PyTorch immediately before the model is built and trained.
np.random.seed(seed)
torch.manual_seed(seed)

# labels is still the raw label array here; classes are assumed to be 0..K-1.
num_classes = labels.max().item() + 1
model = R_GUGNN(
    nfeat=features.shape[1],
    nhid=hidden,
    nclass=num_classes,
    dropout=dropout,
    device=device,
    c=c,
    lamba=lamba,
    beta=beta,
    iterations=m,
)

# Convert adjacency, features and labels to torch tensors on `device`
# (no adjacency preprocessing requested), then densify adjacency and
# features before handing them to the model.
perturbed_adj, features, labels = preprocess(
    perturbed_adj, features, labels,
    preprocess_adj=False, sparse=True, device=device,
)
perturbed_adj = perturbed_adj.to_dense()
features = features.to_dense()

model.fit(features, perturbed_adj, labels, idx_train, idx_val,
          verbose=True, train_iters=epochs)
model.test(idx_test)
        
