import numpy as np
import os
import torch
import random
import importlib
import networkx as nx
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, Coauthor, Amazon
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.utils import remove_self_loops, add_self_loops, to_undirected, to_networkx
from deeprobust.graph.data import Dataset as DeepRobust_Dataset
from deeprobust.graph.data import PrePtbDataset as DeepRobust_PrePtbDataset
from torch_geometric.data import Data
from deeprobust.graph.utils import get_train_val_test

import scipy.sparse as sp

# Maps the "-adv" dataset names accepted by ``load_data`` to the dataset
# names passed to DeepRobust.
_DEEPROBUST_NAMES = {
    'Cora-adv': 'cora',
    'Cora_ml-adv': 'cora_ml',
    'Citeseer-adv': 'citeseer',
    'acm-adv': 'acm',
}

# Under the nettack setting, test nodes whose degree is at or below this
# threshold are removed from the test mask.
_NETTACK_MAX_LOW_DEGREE = 10


def _adj_to_edge_index(adj):
    """Return a ``(2, E)`` int64 edge-index tensor built from ``adj.nonzero()``."""
    rows, cols = adj.nonzero()
    return torch.from_numpy(np.vstack((rows, cols)).astype(np.int64))


def _drop_low_degree_test_nodes(data, num_nodes, max_degree=_NETTACK_MAX_LOW_DEGREE):
    """Unmark (in place) test nodes with at most ``max_degree`` distinct incoming edges.

    The degree of node ``j`` is the number of distinct sources ``i`` with an
    edge ``(i, j)`` in ``data.edge_index``. This is the column sum of a 0/1
    adjacency matrix, computed here without materialising the N x N matrix.
    """
    src, dst = data.edge_index
    # Encode each (src, dst) pair as one integer so duplicates collapse
    # under a 1-D unique; decoding ``% num_nodes`` recovers ``dst``.
    pair_ids = torch.unique(src * num_nodes + dst)
    degrees = torch.bincount(pair_ids % num_nodes, minlength=num_nodes)
    data.test_mask[degrees <= max_degree] = False


def _load_planetoid(path, args):
    """Load a Planetoid graph (public split) with self-loops re-added.

    Sets ``args.num_classes`` as a side effect.
    """
    dataset = Planetoid(path, args.dataset, split='public', transform=T.NormalizeFeatures())
    args.num_classes = dataset.num_classes
    data = dataset[0]
    edge_index, _ = remove_self_loops(data.edge_index)
    edge_index = add_self_loops(edge_index, num_nodes=data.x.size(0))
    # ``add_self_loops`` may return ``(edge_index, edge_attr)`` or just the
    # edge index depending on the PyG version; handle both.
    data.edge_index = edge_index[0] if isinstance(edge_index, tuple) else edge_index
    return data


def _load_deeprobust(path, args):
    """Load a DeepRobust graph (clean or pre-perturbed) as a PyG ``Data`` object.

    Reads ``args.dataset``, ``args.ptb_rate`` and ``args.attack``; sets
    ``args.num_classes`` as a side effect.
    """
    name = _DEEPROBUST_NAMES[args.dataset]
    # seed=15 is passed to DeepRobust as-is (presumably fixes its split).
    dataset = DeepRobust_Dataset(path, name, setting='nettack', require_mask=True, seed=15)
    x = torch.FloatTensor(dataset.features.todense())
    y = torch.LongTensor(dataset.labels)
    args.num_classes = y.max().item() + 1

    if args.ptb_rate > 0:
        # Use the pre-computed perturbed adjacency instead of the clean one.
        perturbed_data = DeepRobust_PrePtbDataset(path, name, attack_method=args.attack,
                                                  ptb_rate=args.ptb_rate)
        edge_index = _adj_to_edge_index(perturbed_data.adj)
    else:
        edge_index = _adj_to_edge_index(dataset.adj)

    data = Data(x=x, edge_index=edge_index, y=y)
    data.train_mask = torch.tensor(dataset.train_mask)
    data.val_mask = torch.tensor(dataset.val_mask)
    data.test_mask = torch.tensor(dataset.test_mask)

    if args.attack == 'nettack':
        _drop_low_degree_test_nodes(data, x.size(0))
    return data


def load_data(args):
    """Load the graph named by ``args.dataset`` and set training defaults on ``args``.

    Supported datasets:
      * ``'Cora'``, ``'Citeseer'``: Planetoid public split, features
        normalised, self-loops re-added.
      * ``'Cora-adv'``, ``'Cora_ml-adv'``, ``'Citeseer-adv'``, ``'acm-adv'``:
        DeepRobust graphs. When ``args.ptb_rate > 0``, the pre-perturbed
        adjacency for ``args.attack`` is used. When ``args.attack == 'nettack'``,
        test nodes with degree <= 10 are removed from the test mask.

    Args:
        args: Namespace-like object. Always reads ``dataset``; the "-adv"
            datasets also read ``ptb_rate`` and ``attack``. Mutated in place:
            ``epochs``, ``patience``, ``dim_hidden`` and ``num_classes`` are set.

    Returns:
        ``(args, data)`` where ``data`` is a ``torch_geometric.data.Data`` graph.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'dataset', args.dataset)
    args.epochs = 1000
    args.patience = 100
    args.dim_hidden = 64

    if args.dataset in ('Cora', 'Citeseer'):
        data = _load_planetoid(path, args)
    elif args.dataset in _DEEPROBUST_NAMES:
        data = _load_deeprobust(path, args)
    else:
        supported = sorted(('Cora', 'Citeseer', *_DEEPROBUST_NAMES))
        raise ValueError(f"Unsupported dataset {args.dataset!r}; expected one of {supported}")
    return args, data

def set_seed(repetition):
    """Seed Python, NumPy and PyTorch RNGs from a fixed table of ten seeds.

    Also configures cuDNN for deterministic (non-autotuned) kernels.

    Args:
        repetition: Index into the seed table (normal sequence indexing, so
            0-9, or negative indices counting from the end).

    Returns:
        The integer seed that was applied.
    """
    seed = (
        12232231, 12232432, 2234234, 4665565, 45543345,
        454543543, 45345234, 54552234, 234235425, 909099343,
    )[repetition]

    # Standard-library and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch CPU generator, then the current and all CUDA devices.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and request deterministic algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return seed

def get_model(args):
    """Instantiate the model class named ``args.type_model`` from the ``models`` module.

    The class is looked up by attribute name on the imported ``models``
    module and constructed with ``args`` as its only argument.
    """
    models_module = importlib.import_module('models')
    model_cls = getattr(models_module, args.type_model)
    return model_cls(args)