from dig.xgraph.dataset import SynGraphDataset, MoleculeDataset
from ogb.nodeproppred import PygNodePropPredDataset
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.datasets import TUDataset
from datasets.infection import Infection
from torch_geometric.datasets import Planetoid
from torch_geometric.data import Data
from datasets.saturation import Saturation
import torch
import numpy as np
from collections import defaultdict
import torch_geometric.transforms as T


class CustomDataset:
    """List-like wrapper around a sequence of graph objects.

    The wrapper exposes ``num_node_features`` and ``num_classes`` at the
    dataset level by copying them from the first element, so the wrapped
    sequence must be non-empty and its elements must provide both attributes.

    Args:
        dataset: An indexable, sized collection of graph objects.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Dataset-level metadata is taken from the first example only.
        data = dataset[0]
        self.num_node_features = data.num_node_features
        self.num_classes = data.num_classes

    def __len__(self):
        r"""The number of examples in the dataset."""
        return len(self.dataset)

    def __getitem__(
            self,
            idx,
    ):
        """Return one example or a new ``CustomDataset`` holding a subset.

        Args:
            idx: Either a scalar index (Python/NumPy integer, 0-dim
                ``torch.Tensor`` or 0-dim ``np.ndarray``), a ``slice``, or an
                iterable of indices.

        Returns:
            The wrapped element for a scalar index, otherwise a new
            ``CustomDataset`` built from the selected elements.
        """
        if isinstance(idx, (int, np.integer)):
            return self.dataset[idx]
        if isinstance(idx, torch.Tensor) and idx.dim() == 0:
            return self.dataset[idx]
        if isinstance(idx, np.ndarray) and idx.ndim == 0:
            # ``np.isscalar`` is always False for ndarrays, so 0-dim arrays
            # must be detected through ``ndim`` and unwrapped explicitly.
            return self.dataset[idx.item()]
        if isinstance(idx, slice):
            # Slices are not iterable; expand them into concrete positions.
            idx = range(*idx.indices(len(self)))
        return CustomDataset([self.dataset[i] for i in idx])


class FlattenY:
    """Transform that flattens ``data.y`` to 1-D and casts it to int64.

    Args:
        k: Not used by the transform itself; it is only echoed in ``repr``.
    """

    def __init__(self, k=1):
        self.k = k

    def __call__(self, data):
        """Overwrite ``data.y`` in place and return the same ``data`` object."""
        data.y = data.y.flatten().to(torch.long)
        return data

    def __repr__(self):
        return f'{type(self).__name__}(k={self.k})'


class AddFeatures:
    """Transform that gives every node the constant feature vector ``[1.0]``.

    Any existing ``data.x`` is replaced.

    Args:
        k: Not used by the transform itself; it is only echoed in ``repr``.
    """

    def __init__(self, k=1):
        self.k = k

    def __call__(self, data):
        """Set ``data.x`` to a ``(num_nodes, 1)`` float tensor of ones.

        The tensor is allocated directly instead of going through a nested
        Python list, which keeps the ``(0, 1)`` shape for graphs without
        nodes and avoids per-node Python overhead on large graphs.
        """
        data.x = torch.ones(data.num_nodes, 1, dtype=torch.float)
        return data

    def __repr__(self):
        return '{}(k={})'.format(self.__class__.__name__, self.k)

class NodifyEdges:
    """Transform that turns every edge into an additional node.

    For a graph with ``N`` nodes, edge ``i`` going ``src -> dest`` becomes the
    new node ``N + i`` connected as ``src -> N + i -> dest``. The feature
    matrix is zero-initialised; the first ``N`` rows carry ``data.x`` in the
    leading ``num_node_features`` columns and the edge rows carry
    ``data.edge_attr`` in the trailing ``num_edge_features`` columns.

    The returned ``Data`` holds only ``x``, ``y`` and ``edge_index``; other
    attributes of the input are not copied.
    """

    def __init__(self):
        pass

    def __call__(self, data):
        num_nodes = data.num_nodes
        node_dim = data.num_node_features

        x = torch.zeros(num_nodes + data.num_edges,
                        node_dim + data.num_edge_features)
        x[:num_nodes, :node_dim] = data.x
        x[num_nodes:, node_dim:] = data.edge_attr

        # Build the rewired edge list with tensor ops. This keeps the index
        # dtype at int64 even when the graph has no edges (a nested empty
        # Python list would be inferred as float) and avoids a per-edge loop.
        src = data.edge_index[0].long()
        dest = data.edge_index[1].long()
        edge_nodes = torch.arange(num_nodes, num_nodes + src.numel(),
                                  dtype=torch.long, device=src.device)
        # Interleave so that edge i contributes columns 2i and 2i + 1:
        # (src[i] -> N + i) followed by (N + i -> dest[i]).
        new_src = torch.stack((src, edge_nodes), dim=1).reshape(-1)
        new_dest = torch.stack((edge_nodes, dest), dim=1).reshape(-1)
        edge_index = torch.stack((new_src, new_dest))

        return Data(x=x, y=data.y, edge_index=edge_index)

    def __repr__(self):
        return '{}'.format(self.__class__.__name__)


class NodifyEdgesAndOnes:
    """Transform that turns every edge into a node and blanks node features.

    For a graph with ``N`` nodes, edge ``i`` going ``src -> dest`` becomes the
    new node ``N + i`` connected as ``src -> N + i -> dest``. The feature
    matrix is zero-initialised; the first ``N`` rows are filled with ones in
    the leading ``num_node_features`` columns (the original ``data.x`` values
    are discarded) and the edge rows carry ``data.edge_attr`` in the trailing
    ``num_edge_features`` columns.

    The returned ``Data`` holds only ``x``, ``y`` and ``edge_index``; other
    attributes of the input are not copied.
    """

    def __init__(self):
        pass

    def __call__(self, data):
        num_nodes = data.num_nodes
        node_dim = data.num_node_features

        x = torch.zeros(num_nodes + data.num_edges,
                        node_dim + data.num_edge_features)
        # A (N, 1) block of ones broadcasts across all node-feature columns.
        x[:num_nodes, :node_dim] = torch.ones(num_nodes, 1)
        x[num_nodes:, node_dim:] = data.edge_attr

        # Build the rewired edge list with tensor ops. This keeps the index
        # dtype at int64 even when the graph has no edges (a nested empty
        # Python list would be inferred as float) and avoids a per-edge loop.
        src = data.edge_index[0].long()
        dest = data.edge_index[1].long()
        edge_nodes = torch.arange(num_nodes, num_nodes + src.numel(),
                                  dtype=torch.long, device=src.device)
        # Interleave so that edge i contributes columns 2i and 2i + 1:
        # (src[i] -> N + i) followed by (N + i -> dest[i]).
        new_src = torch.stack((src, edge_nodes), dim=1).reshape(-1)
        new_dest = torch.stack((edge_nodes, dest), dim=1).reshape(-1)
        edge_index = torch.stack((new_src, new_dest))

        return Data(x=x, y=data.y, edge_index=edge_index)

    def __repr__(self):
        return '{}'.format(self.__class__.__name__)


def neighbors(data):
    """Build an undirected adjacency map from ``data.edge_index``.

    Every edge ``(s, d)`` makes ``d`` a neighbour of ``s`` and ``s`` a
    neighbour of ``d``, regardless of edge direction.

    Args:
        data: Object exposing ``edge_index`` (unpackable into a source and a
            destination sequence) and ``num_nodes``.

    Returns:
        A ``defaultdict(set)`` mapping node id -> set of neighbour ids. Only
        nodes in ``range(data.num_nodes)`` that touch at least one edge get a
        key; keys are inserted in ascending order.
    """
    source, dest = data.edge_index
    # Convert once to plain ints so the edge list is walked a single time
    # (O(E)) instead of being rescanned for every node (O(N * E)).
    sources = [int(v) for v in torch.as_tensor(source).tolist()]
    dests = [int(v) for v in torch.as_tensor(dest).tolist()]
    num_nodes = data.num_nodes

    found = defaultdict(set)
    for s, d in zip(sources, dests):
        # Only ids inside range(num_nodes) may become keys.
        if 0 <= s < num_nodes:
            found[s].add(d)
        if 0 <= d < num_nodes:
            found[d].add(s)

    # Re-insert in ascending node order so iteration over the result is
    # independent of the order in which edges are listed.
    nbs = defaultdict(set)
    for node in sorted(found):
        nbs[node] = found[node]
    return nbs


# Upper bound on the ground-truth explanation size for each synthetic
# node-level benchmark -- presumably the node count of the planted motif;
# TODO confirm against the dataset generators.
_EXPLANATION_SIZES = {'BA_shapes': 5, 'Tree_Cycle': 6, 'Tree_Grid': 9}


def _keep_first_feature(dataset):
    """Cast node features to float32 and keep only their first column."""
    dataset.data.x = dataset.data.x.to(torch.float32)[:, :1]


def _explanation_ground_truth(data, max_size):
    """Collect, per labelled node, the nearby labelled nodes in BFS order.

    For every node whose label is non-zero, a breadth-first search is run over
    the undirected graph, visiting only nodes whose own label is non-zero, and
    stopped once ``max_size`` nodes (start node included) have been collected.

    Returns:
        ``defaultdict(list)`` mapping start node -> list of collected nodes.
        Nodes with label 0 get no entry.
    """
    # Plain Python labels avoid a tensor comparison per visited neighbour.
    labels = torch.flatten(data.y).tolist()
    nbs = neighbors(data)
    ground_truth = defaultdict(list)
    for start in range(len(data.x)):
        if labels[start] == 0:
            continue
        seen = {start}
        queue = [start]
        head = 0  # read position; avoids O(n) ``list.pop(0)``
        while head < len(queue):
            node = queue[head]
            head += 1
            ground_truth[start].append(node)
            if len(ground_truth[start]) == max_size:
                break
            for nb in nbs[node]:
                # The *neighbour's* label decides whether it is followed.
                # Testing the start node's label here would always pass and
                # let label-0 nodes leak into the explanation.
                if nb not in seen and labels[nb] != 0:
                    seen.add(nb)
                    queue.append(nb)
    return ground_truth


def load_dataset(dataset_name, args):
    """Instantiate the dataset called ``dataset_name``.

    Args:
        dataset_name: Name of the dataset to load (e.g. ``'MUTAG'``,
            ``'BA_shapes'``, ``'Cora'``, ``'OGB-molhiv'``).
        args: Namespace providing ``data_dir`` (used by the non-OGB on-disk
            datasets) and ``number_of_layers`` (used by ``'Infection'`` and
            ``'Saturation'``). The OGB datasets use the relative root
            ``"datasets/"`` instead of ``args.data_dir``.

    Returns:
        A tuple ``(dataset, dataset_args)`` where ``dataset_args`` is a dict
        with the boolean keys ``"use_pooling"`` and ``"dataset_mask"``.
        ``dataset`` is ``None`` when ``dataset_name`` is not recognised.
        For ``'BA_shapes'``, ``'Tree_Cycle'`` and ``'Tree_Grid'`` the dataset
        additionally gets an ``egt`` attribute holding the ground-truth
        explanation nodes per labelled node.
    """
    dataset = None
    use_pooling = True
    dataset_mask = False

    if dataset_name == 'MUTAG':
        dataset = MoleculeDataset(args.data_dir + '/datasets', dataset_name)

    elif dataset_name == 'BBBP':
        dataset = MoleculeDataset(args.data_dir + '/datasets', dataset_name, pre_transform=FlattenY())

    elif dataset_name in ('PROTEINS', 'Mutagenicity'):
        dataset = TUDataset(root=args.data_dir + '/datasets', name=dataset_name)

    elif dataset_name in ('IMDB-BINARY', 'REDDIT-BINARY', 'COLLAB'):
        # AddFeatures supplies a constant node feature for these datasets.
        dataset = TUDataset(root=args.data_dir + '/datasets', name=dataset_name, pre_transform=AddFeatures())

    elif dataset_name == 'BA_2Motifs':
        dataset = SynGraphDataset(args.data_dir + '/datasets', dataset_name)
        _keep_first_feature(dataset)

    elif dataset_name == 'Infection':
        benchmark = Infection(num_layers=args.number_of_layers)
        dataset = CustomDataset([benchmark.create_dataset(num_nodes=1000, edge_probability=0.004) for _ in range(10)])
        use_pooling = False

    elif dataset_name == 'Saturation':
        benchmark = Saturation(sample_count=1, num_layers=args.number_of_layers, concat_features=False,
                               conv_type=None)
        dataset = CustomDataset([benchmark.create_dataset() for _ in range(10)])
        use_pooling = False

    elif dataset_name in _EXPLANATION_SIZES:
        dataset = SynGraphDataset(args.data_dir + '/datasets', name=dataset_name)
        _keep_first_feature(dataset)
        use_pooling = False
        dataset_mask = True
        dataset.egt = _explanation_ground_truth(dataset.data, _EXPLANATION_SIZES[dataset_name])

    elif dataset_name in ('Cora', 'CiteSeer', 'PubMed'):
        dataset = Planetoid(args.data_dir + '/datasets', name=dataset_name)
        dataset.data.x = dataset.data.x.to(torch.float32)
        use_pooling = False
        dataset_mask = True

    elif dataset_name == "OGBA":
        dataset = PygNodePropPredDataset("ogbn-arxiv", root="datasets/", transform=T.ToUndirected())
        dataset.data.x = dataset.data.x.to(torch.float32)
        # NOTE(review): a `dataset.data.y.squeeze(1)` was previously disabled
        # here, so labels keep their stored shape.
        use_pooling = False
        dataset_mask = True

    elif dataset_name == "OGB-molhiv":
        dataset = PygGraphPropPredDataset("ogbg-molhiv", root="datasets/")

    elif dataset_name == "OGB-ppa":
        dataset = PygGraphPropPredDataset("ogbg-ppa", root="datasets/", transform=AddFeatures())

    elif dataset_name == "OGB-code2":
        dataset = PygGraphPropPredDataset("ogbg-code2", root="datasets/")

    dataset_args = {
        "use_pooling": use_pooling,
        "dataset_mask": dataset_mask
    }
    return dataset, dataset_args
