from dig.xgraph.dataset import SynGraphDataset, MoleculeDataset
from torch_geometric.datasets import TUDataset
from datasets.infection import Infection
from torch_geometric.datasets import Planetoid
from datasets.saturation import Saturation
import torch
import numpy as np
from collections import defaultdict


class CustomDataset:
    """Minimal list-backed dataset wrapper exposing the metadata models read.

    ``num_node_features`` and ``num_classes`` are copied from the first element,
    so all elements are presumably expected to share them (TODO confirm for
    every producer of elements).
    """

    def __init__(self, dataset):
        # ``dataset`` must be indexable and non-empty: metadata comes from element 0.
        self.dataset = dataset
        data = dataset[0]
        self.num_node_features = data.num_node_features
        self.num_classes = data.num_classes

    def __len__(self):
        r"""The number of examples in the dataset."""
        return len(self.dataset)

    def __getitem__(
            self,
            idx,
    ):
        """Return one element for a scalar index, or a ``CustomDataset`` subset otherwise.

        Scalar indices are Python/NumPy integers and 0-d tensors/arrays. A slice
        selects the corresponding range of elements; any other iterable of
        indices (list, 1-D tensor/array, ...) selects those elements in order.
        """
        # ``np.isscalar`` is False for every ndarray (even 0-d), so test the
        # dimensionality directly for both tensors and arrays.
        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, (torch.Tensor, np.ndarray)) and idx.ndim == 0)):
            return self.dataset[idx]

        if isinstance(idx, slice):
            # slices are not iterable; expand them into concrete positions.
            idx = range(*idx.indices(len(self.dataset)))
        return CustomDataset([self.dataset[i] for i in idx])


class FlattenY:
    """Transform that turns ``data.y`` into a 1-D tensor of long (int64) labels.

    The ``k`` argument is stored and reported by ``repr`` only; it does not
    influence the transform.
    """

    def __init__(self, k=1):
        self.k = k

    def __call__(self, data):
        # Flatten first, then cast, writing the result back onto the same object.
        data.y = data.y.flatten().long()
        return data

    def __repr__(self):
        return f'{type(self).__name__}(k={self.k})'


class AddFeatures:
    """Transform that sets ``data.x`` to a single constant feature of 1.0 per node.

    The result always has shape ``(data.num_nodes, 1)`` and dtype float32,
    including for graphs with zero nodes. The ``k`` argument is stored and
    reported by ``repr`` only; it does not influence the transform.
    """

    def __init__(self, k=1):
        self.k = k

    def __call__(self, data):
        # torch.ones allocates directly and keeps the 2-D shape for empty graphs,
        # unlike building the tensor from a (possibly empty) nested Python list.
        data.x = torch.ones((data.num_nodes, 1), dtype=torch.float)
        return data

    def __repr__(self):
        return '{}(k={})'.format(self.__class__.__name__, self.k)


def neighbors(data):
    """Build an undirected adjacency map from ``data.edge_index``.

    Every edge ``(s, d)`` is recorded in both directions: ``d`` is added to
    ``nbs[s]`` and ``s`` to ``nbs[d]``. Only endpoints in
    ``range(data.num_nodes)`` receive a key. Nodes without edges get no key,
    but because the result is a ``defaultdict(set)`` they read back as an
    empty set.

    The edge list is scanned once (O(E)) rather than once per node.
    """
    nbs = defaultdict(set)
    num_nodes = data.num_nodes
    # Convert once to plain Python ints instead of calling int() per tensor element.
    source, dest = torch.as_tensor(data.edge_index).tolist()
    for s, d in zip(source, dest):
        if 0 <= s < num_nodes:
            nbs[s].add(d)
        if 0 <= d < num_nodes:
            nbs[d].add(s)
    return nbs


def _motif_explanation_ground_truth(data, motif_size):
    """Map each node with a non-zero label to up to ``motif_size`` nodes found by BFS.

    The BFS starts at the node itself, only expands into neighbours whose label
    in ``data.y`` is also non-zero, and never visits a node twice. Presumably
    label 0 marks base-graph nodes and non-zero labels mark motif nodes, so the
    result is the node's motif (TODO confirm against the dataset generator).

    Returns a ``defaultdict(list)`` from node index to visited nodes in BFS order.
    """
    from collections import deque

    # assumes data.y is 1-D (one label per node), as the original per-node
    # truthiness test ``y[n] != 0`` requires.
    labels = data.y.tolist()
    nbs = neighbors(data)
    ground_truth = defaultdict(list)
    for n in range(len(data.x)):
        if labels[n] == 0:
            continue
        seen = {n}
        queue = deque([n])
        while queue:
            node = queue.popleft()
            ground_truth[n].append(node)
            if len(ground_truth[n]) == motif_size:
                break
            for nb in nbs[node]:
                # Filter on the neighbour's label (not the start node's) so the
                # search stays inside the motif.
                if nb not in seen and labels[nb] != 0:
                    seen.add(nb)
                    queue.append(nb)
    return ground_truth


def load_dataset(dataset_name, args):
    """Load a benchmark dataset by name together with flags describing its use.

    Args:
        dataset_name: one of 'MUTAG', 'PROTEINS', 'IMDB-BINARY', 'REDDIT-BINARY',
            'Mutagenicity', 'BBBP', 'COLLAB', 'BA_2Motifs', 'Infection',
            'Saturation', 'BA_shapes', 'Tree_Cycle', 'Tree_Grid' or 'Cora'.
        args: object providing ``data_dir`` (datasets are stored under
            ``<data_dir>/datasets``) and, for 'Infection'/'Saturation',
            ``number_of_layers``.

    Returns:
        ``(dataset, dataset_args)`` where ``dataset_args`` is a dict with
        ``use_pooling`` (False for Infection, Saturation, BA_shapes, Tree_Cycle,
        Tree_Grid and Cora; True otherwise) and ``dataset_mask`` (True for
        BA_shapes, Tree_Cycle, Tree_Grid and Cora; False otherwise). An
        unrecognised name yields ``dataset=None`` with ``use_pooling=True`` and
        ``dataset_mask=False``. For BA_shapes, Tree_Cycle and Tree_Grid the
        dataset additionally gets an ``egt`` attribute holding the explanation
        ground truth (node index -> list of motif nodes).
    """
    # Number of ground-truth nodes collected per motif node.
    motif_sizes = {'BA_shapes': 5, 'Tree_Cycle': 6, 'Tree_Grid': 9}

    dataset = None
    use_pooling = True
    dataset_mask = False

    if dataset_name == 'MUTAG':
        dataset = MoleculeDataset(args.data_dir + '/datasets', dataset_name)

    elif dataset_name in ('PROTEINS', 'Mutagenicity'):
        dataset = TUDataset(root=args.data_dir + '/datasets', name=dataset_name)

    elif dataset_name in ('IMDB-BINARY', 'REDDIT-BINARY', 'COLLAB'):
        # AddFeatures gives every node a constant 1.0 feature at processing time.
        dataset = TUDataset(root=args.data_dir + '/datasets', name=dataset_name, pre_transform=AddFeatures())

    elif dataset_name == 'BBBP':
        dataset = MoleculeDataset(args.data_dir + '/datasets', dataset_name, pre_transform=FlattenY())

    elif dataset_name == 'BA_2Motifs':
        dataset = SynGraphDataset(args.data_dir + '/datasets', dataset_name)
        # Keep only the first feature column, as float32.
        dataset.data.x = dataset.data.x.to(torch.float32)[:, :1]

    elif dataset_name == 'Infection':
        benchmark = Infection(num_layers=args.number_of_layers)
        dataset = CustomDataset([benchmark.create_dataset(num_nodes=1000, edge_probability=0.004) for _ in range(10)])
        use_pooling = False

    elif dataset_name == 'Saturation':
        benchmark = Saturation(sample_count=1, num_layers=args.number_of_layers, concat_features=False,
                               conv_type=None)
        dataset = CustomDataset([benchmark.create_dataset() for _ in range(10)])
        use_pooling = False

    elif dataset_name in motif_sizes:
        dataset = SynGraphDataset(args.data_dir + '/datasets', name=dataset_name)
        dataset.data.x = dataset.data.x.to(torch.float32)[:, :1]
        use_pooling = False
        dataset_mask = True
        dataset.egt = _motif_explanation_ground_truth(dataset.data, motif_sizes[dataset_name])

    elif dataset_name == 'Cora':
        dataset = Planetoid(args.data_dir + '/datasets', name=dataset_name)
        dataset.data.x = dataset.data.x.to(torch.float32)
        use_pooling = False
        dataset_mask = True

    dataset_args = {
        "use_pooling": use_pooling,
        "dataset_mask": dataset_mask
    }
    return dataset, dataset_args
