import json
import os
import os.path as osp
from itertools import product

import numpy as np
import scipy.io
import torch

from torch_geometric.data import InMemoryDataset, Data, download_url
from utils import load_data, get_train_val_test_gcn, get_stored_splits
# from load_nell import load_nell
from geom_data_utils import load_geom_datasets
from itertools import repeat


# Root directory of the AM-GCN style data files. OtherDataset builds
# '<base_path>/<name>/<name>.feature|.label|.edge' and the
# 'train20.txt' / 'test20.txt' split paths from it. This is a relative path,
# so it is resolved against the process's working directory, not this file.
base_path = '../AM-GCN/data/'

class Config:
    """Plain holder for the file paths of one AM-GCN style dataset.

    Every path defaults to ``None``; ``OtherDataset`` fills them in per
    dataset name before handing the object to the loading helpers.
    """

    feature_path = None  # node feature matrix file
    label_path = None    # node label file
    graph_path = None    # edge list file
    train_path = None    # stored training-split node ids
    test_path = None     # stored test-split node ids


class OtherDataset(InMemoryDataset):
    """In-memory PyG dataset holding one graph plus train/val/test masks.

    The loader is chosen by ``name``:
      * ``'nell'`` -> ``load_nell()``
      * geom-gcn benchmarks -> ``load_geom_datasets(name, 10)``
      * anything else -> ``load_data`` / ``get_stored_splits`` on the
        AM-GCN style files under ``base_path``.
    """

    def __init__(self, root, name, transform=None, pre_transform=None, **kwargs):
        self.name = name
        # Build the config BEFORE calling the base constructor: PyG's Dataset
        # constructor invokes ``self.process()`` when ``processed_file_names``
        # are not found under ``root``, and ``process`` reads ``self.config``.
        self.config = self._build_config(name)
        super(OtherDataset, self).__init__(root, transform, pre_transform)

        self.transform = transform
        data = self.process()
        self.data, self.slices = self.collate([data])

    @staticmethod
    def _build_config(name):
        """Return a Config pointing at the AM-GCN files for dataset ``name``."""
        config = Config()
        config.feature_path = osp.join(base_path, f'{name}/{name}.feature')
        config.label_path = osp.join(base_path, f'{name}/{name}.label')
        config.graph_path = osp.join(base_path, f'{name}/{name}.edge')
        config.train_path = osp.join(base_path, f'{name}/train20.txt')
        config.test_path = osp.join(base_path, f'{name}/test20.txt')
        return config

    @property
    def raw_file_names(self):
        # No raw files are managed by PyG: the loaders read their inputs
        # directly (see ``process``), and ``download`` is not implemented.
        return []

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """Load the graph and attach boolean split masks.

        Returns:
            A ``Data`` object with ``x``, ``edge_index``, ``y``, ``train_mask``,
            ``val_mask`` and ``test_mask``.

        Side effect: sets ``self.transform`` to ``None`` (done for every
        dataset).
        """
        if self.name == 'nell':
            # Imported lazily: only this branch needs it, and the module-level
            # import of ``load_nell`` is disabled.
            from load_nell import load_nell
            edge_index, x, y, idx_train, idx_val, idx_test = load_nell()
        elif self.name in ['chameleon', 'squirrel', 'film', 'cornell', 'texas', 'wisconsin']:
            edge_index, x, y, idx_train, idx_val, idx_test = load_geom_datasets(self.name, 10)
        else:
            edge_index, x, y = load_data(self.config)
            # Alternative: get_train_val_test_gcn(y) for random GCN-style splits.
            idx_train, idx_val, idx_test = get_stored_splits(self.config)

        data = Data(x=x, edge_index=edge_index, y=y)
        self.transform = None

        num_nodes = y.size(0)
        data.train_mask = index_to_mask(idx_train, size=num_nodes)
        data.val_mask = index_to_mask(idx_val, size=num_nodes)
        data.test_mask = index_to_mask(idx_test, size=num_nodes)

        return data

    def get(self, idx):
        """Rebuild graph ``idx`` from the collated storage.

        Attributes that cannot be sliced along their concatenation dimension
        are copied over whole.
        """
        data = self.data.__class__()

        if hasattr(self.data, '__num_nodes__'):
            data.num_nodes = self.data.__num_nodes__[idx]

        # ``keys`` is a property on older PyG and a method on newer releases.
        keys = self.data.keys() if callable(self.data.keys) else self.data.keys
        for key in keys:
            item, slices = self.data[key], self.slices[key]
            index = list(repeat(slice(None), item.dim()))
            index[self.data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            try:
                # A tuple (not a list) of slices is the supported multi-dim index.
                data[key] = item[tuple(index)]
            except (IndexError, TypeError, RuntimeError, ValueError):
                data[key] = item
        return data

    def copy(self, idx=None):
        """Return a shallow copy of the dataset restricted to ``idx``.

        Args:
            idx: iterable of graph indices to keep; ``None`` keeps all graphs.
        """
        import copy as copy_module

        if idx is None:
            data_list = [self.get(i) for i in range(len(self))]
        else:
            data_list = [self.get(i) for i in idx]
        dataset = copy_module.copy(self)
        dataset.__indices__ = None
        dataset.__data_list__ = data_list
        dataset.data, dataset.slices = self.collate(data_list)
        return dataset




from deeprobust.graph.defense import GCN, ProGNN
from deeprobust.graph.data import Dataset, PrePtbDataset
from deeprobust.graph.utils import preprocess, encode_onehot, get_train_val_test
class AttackedDataset(InMemoryDataset):
    """Single-graph dataset built from a DeepRobust graph whose adjacency has
    been perturbed by an attack.

    Requires ``kwargs['args']`` exposing ``dataset``, ``attack`` (one of
    ``'no'``, ``'random'``, ``'meta'``, ``'nettack'``) and ``ptb_rate``.
    ``InMemoryDataset.__init__`` is not called, so ``root`` and
    ``pre_transform`` are currently unused.
    """

    def __init__(self, root, name, transform=None, pre_transform=None, **kwargs):
        self.name = name
        self.args = kwargs['args']

        self.transform = transform
        data = self.process()
        self.data, self.slices = self.collate([data])

    def process(self):
        """Load the clean graph, apply the requested perturbation and build masks.

        Returns:
            A ``Data`` object with ``x``, ``edge_index``, ``y`` and the three
            boolean split masks. For ``nettack`` the test mask marks the
            attack's target nodes.

        Raises:
            ValueError: if ``args.attack`` is not a supported attack name.
        """
        args = self.args
        data = Dataset(root='/tmp/', name=args.dataset, setting='gcn', seed=15)
        adj, features, labels = data.adj, data.features, data.labels
        idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

        if args.attack == 'no':
            # Clean graph, no perturbation.
            perturbed_adj = adj
        elif args.attack == 'random':
            from deeprobust.graph.global_attack import Random
            attacker = Random()
            # ptb_rate is a fraction of the undirected edge count.
            n_perturbations = int(args.ptb_rate * (adj.sum() // 2))
            attacker.attack(adj, n_perturbations, type='add')
            perturbed_adj = attacker.modified_adj
        elif args.attack in ('meta', 'nettack'):
            perturbed_data = PrePtbDataset(root='/tmp/',
                    name=args.dataset,
                    attack_method=args.attack,
                    ptb_rate=args.ptb_rate)
            perturbed_adj = perturbed_data.adj
            if args.attack == 'nettack':
                idx_test = perturbed_data.target_nodes
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(f'Unsupported attack: {args.attack!r}')

        # ``nonzero()`` yields (rows, cols); stack them into a [2, E] array
        # instead of converting a tuple of ndarrays element by element.
        edge_index = torch.from_numpy(np.vstack(perturbed_adj.nonzero())).long()
        x = torch.FloatTensor(features.todense())
        y = torch.LongTensor(labels)
        data = Data(x=x, edge_index=edge_index, y=y)

        num_nodes = y.size(0)
        data.train_mask = index_to_mask(idx_train, size=num_nodes)
        data.val_mask = index_to_mask(idx_val, size=num_nodes)
        data.test_mask = index_to_mask(idx_test, size=num_nodes)

        return data

    def get(self, idx):
        """Rebuild graph ``idx`` from the collated storage.

        Attributes that cannot be sliced along their concatenation dimension
        are copied over whole.
        """
        data = self.data.__class__()

        if hasattr(self.data, '__num_nodes__'):
            data.num_nodes = self.data.__num_nodes__[idx]

        # ``keys`` is a property on older PyG and a method on newer releases.
        keys = self.data.keys() if callable(self.data.keys) else self.data.keys
        for key in keys:
            item, slices = self.data[key], self.slices[key]
            index = list(repeat(slice(None), item.dim()))
            index[self.data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            try:
                # A tuple (not a list) of slices is the supported multi-dim index.
                data[key] = item[tuple(index)]
            except (IndexError, TypeError, RuntimeError, ValueError):
                data[key] = item
        return data



def edge_index_from_dict(graph_dict, num_nodes=None):
    """Build a ``[2, E]`` edge index from an adjacency dictionary.

    Args:
        graph_dict: mapping from a source node id to an iterable of target
            node ids.
        num_nodes: accepted for backward compatibility; the result does not
            depend on it.

    Returns:
        A ``torch.long`` tensor of shape ``[2, E]``. Self loops are removed,
        duplicate edges are merged, and columns are sorted by (source, target).
    """
    row, col = [], []
    for key, value in graph_dict.items():
        targets = list(value)
        row.extend(repeat(key, len(targets)))
        col.extend(targets)
    # Explicit dtype so an empty dict still yields a long [2, 0] tensor.
    edge_index = torch.tensor([row, col], dtype=torch.long)

    # NOTE: There are duplicated edges and self loops in the datasets. Other
    # implementations do not remove them!
    edge_index = edge_index[:, edge_index[0] != edge_index[1]]
    if edge_index.size(1) > 0:
        # Unique columns come back sorted lexicographically (source, then
        # target), which deduplicates and orders edges in one step.
        edge_index = torch.unique(edge_index, dim=1)
    return edge_index


def index_to_mask(index, size):
    """Return a boolean vector of length ``size`` that is True at ``index``.

    ``index`` may be anything accepted by tensor indexing (e.g. a list or
    tensor of integer positions).
    """
    flags = torch.full((size,), False, dtype=torch.bool)
    flags[index] = True
    return flags
