import argparse
import numpy as np
import torch
from deeprobust.graph.data import Dataset, PrePtbDataset
from deeprobust.graph.utils import preprocess
from tgnn import TGNN


# Command-line interface: runtime, optimisation, model and attack settings.
parser = argparse.ArgumentParser()

# One (flag, add_argument keyword arguments) pair per option. The order of the
# entries is the order in which the options are registered on the parser.
_CLI_OPTIONS = (
    ('--debug', dict(action='store_true', default=False, help='debug mode')),
    ('--no-cuda', dict(
        action='store_true', default=False,
        help='Disables CUDA training. Turn it on if GPU memory is not enough.')),
    ('--seed', dict(type=int, default=15, help='Random seed.')),
    ('--lr', dict(type=float, default=0.01, help='Initial learning rate.')),
    ('--weight_decay', dict(type=float, default=5e-4, help='Weight decay.')),
    ('--hidden', dict(type=int, default=16, help='Number of hidden units.')),
    ('--rank', dict(type=int, default=32, help='Number of tensor ranks.')),
    ('--topk', dict(type=int, default=32, help='Number of topk in nn.')),
    ('--dropout', dict(
        type=float, default=0.5,
        help='Dropout rate (1 - keep probability).')),
    ('--dataset', dict(
        type=str, default='cora',
        choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'],
        help='dataset')),
    ('--attack', dict(
        type=str, default='meta',
        choices=['no', 'meta', 'random', 'nettack'])),
    ('--format', dict(
        type=str, default='Tucker',
        choices=['Tucker', 'CP', 'TT'])),
    ('--pro', dict(type=str, default='prune,svd,knn')),
    ('--ptb_rate', dict(type=float, default=0.15, help="noise ptb_rate")),
    ('--prune_thd', dict(type=float, default=0.01, help="threshold for prune")),
    ('--epochs', dict(type=int, default=1000, help='Number of epochs to train.')),
    ('--svd_rank', dict(type=int, default=200, help='rank for svd decomposition.')),
    ('--lambda_t', dict(type=float, default=1e2, help='weight of TD')),
    ('--weight_decay_t', dict(type=float, default=3e-1, help='weight decay of TD')),
)
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)

args = parser.parse_args()

# Use the GPU only when CUDA was not disabled on the command line and a CUDA
# device is actually available; availability is not queried when --no-cuda
# is given.
args.cuda = False if args.no_cuda else torch.cuda.is_available()
device = torch.device("cuda") if args.cuda else torch.device("cpu")
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# A perturbation rate of zero means the clean graph is used. nettack is left
# untouched because the script still fetches its target nodes further down.
if args.attack != 'nettack' and args.ptb_rate == 0:
    args.attack = "no"

print(args)

# Seed both RNGs with a fixed value (independent of --seed) before the dataset
# is loaded. NOTE(review): presumably this keeps the data split identical
# across runs with different --seed values; confirm against deeprobust.
_DATA_SEED = 15
torch.manual_seed(_DATA_SEED)
np.random.seed(_DATA_SEED)

# Load the requested dataset with the 'nettack' setting.
data = Dataset(name=args.dataset, root='../Data/', setting='nettack')

# Graph structure, node features and node labels.
adj = data.adj
features = data.features
labels = data.labels

# Node index sets for training, validation and testing.
idx_train = data.idx_train
idx_val = data.idx_val
idx_test = data.idx_test


# Build `perturbed_adj`, the graph structure the model is trained on,
# according to the selected --attack.
if args.attack == 'no':
    # No attack: train on the original adjacency matrix.
    perturbed_adj = adj
elif args.attack == 'random':
    # Imported lazily so the attack module is only loaded when it is needed.
    from deeprobust.graph.global_attack import Random
    attacker = Random()
    # Perturbation budget: ptb_rate times adj.sum() // 2
    # (assumes a symmetric 0/1 adjacency, so this is the edge count; confirm).
    n_perturbations = int(args.ptb_rate * (adj.sum()//2))
    attacker.attack(adj, n_perturbations, type='add')
    perturbed_adj = attacker.modified_adj
elif args.attack in ('meta', 'nettack'):
    has_perturbation = args.ptb_rate > 0
    # With ptb_rate == 0 the pre-perturbed data at rate 1.0 is still loaded
    # (its target nodes are needed for nettack) while the clean adjacency
    # is the one actually used for training.
    perturbed_data = PrePtbDataset(root='../Data/',
            name=args.dataset,
            attack_method=args.attack,
            ptb_rate=args.ptb_rate if has_perturbation else 1.0)
    perturbed_adj = perturbed_data.adj if has_perturbation else adj
    if args.attack == 'nettack':
        # nettack is evaluated on its target nodes instead of the default test split.
        idx_test = perturbed_data.get_target_nodes()
else:
    # Unreachable through the CLI (argparse restricts the choices), but fail
    # with a clear message rather than a NameError on `perturbed_adj` later.
    raise ValueError("unsupported attack: {!r}".format(args.attack))

if args.dataset == 'polblogs':
    # For polblogs the features are replaced by the adjacency matrix plus the
    # identity. Import the submodule explicitly: a bare `import scipy` is not
    # guaranteed to expose `scipy.sparse` on every SciPy version.
    import scipy.sparse
    # Sparse identity avoids materialising a dense n x n array first.
    features = perturbed_adj + scipy.sparse.identity(perturbed_adj.shape[0], format='csr')

# Re-seed both RNGs with the user-selected seed before the model is created.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Number of classes is the largest label value plus one.
n_classes = labels.max().item() + 1
# The `euclidean` flag is switched on for these three datasets only.
use_euclidean = args.dataset in ('cora', 'citeseer', 'polblogs')

model = TGNN(
    nfeat=features.shape[1],
    nhid=args.hidden,
    nclass=n_classes,
    lr=args.lr,
    dropout=args.dropout,
    format=args.format,
    rank=args.rank,
    topk=args.topk,
    pros=args.pro,
    euclidean=use_euclidean,
    svd_rank=args.svd_rank,
    prune_thd=args.prune_thd,
    lambda_t=args.lambda_t,
    weight_decay_t=args.weight_decay_t,
    device=device,
).to(device)

# Run deeprobust's preprocess on the graph, features and labels for `device`,
# with adjacency preprocessing turned off (preprocess_adj=False).
perturbed_adj, features, labels = preprocess(
    perturbed_adj, features, labels, preprocess_adj=False, device=device)

# Train for --epochs iterations using the train/validation indices, then
# evaluate on the test indices.
model.fit(features, perturbed_adj, labels, idx_train, idx_val,
          verbose=True, train_iters=args.epochs)
model.test(idx_test)

