import torch
import torch_geometric.utils as tgutils
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.data import DataLoader
from torch_geometric.datasets import *
from torch_scatter import scatter

import graphgym.register as register
from graphgym.config import cfg
from graphgym.contrib.utils.random import compute_split_idx


def load_pyg(name, dataset_dir):
    '''
    Load a PyTorch Geometric dataset by name.

    Side effects: for 'QM7b', 'QM9' and 'TU_*' datasets this sets
    ``cfg.dataset.node_dim`` and ``cfg.dataset.edge_dim``; for 'QM7b' and
    'QM9' the targets ``dataset.data.y`` are standardized in place
    (zero mean, unit std per column).

    :param name: dataset name. Supported: 'Cora', 'CiteSeer', 'PubMed',
        'QM7b', 'QM9', 'TU_<name>', 'Karate', names containing 'Coauthor'
        or 'Amazon', 'MNIST' and 'PPI'
    :param dataset_dir: root data directory; the dataset is stored under
        ``<dataset_dir>/<name>``
    :return: the PyG dataset object, with an extra ``num_labels`` attribute
        equal to its ``num_classes``
    :raises ValueError: if ``name`` is not a supported dataset
    '''
    dataset_dir = '{}/{}'.format(dataset_dir, name)

    def expose_graphgym_fields(data):
        # Alias the PyG attributes under the names used downstream.
        data.node_feature = data.x
        data.edge_feature = data.edge_attr
        data.graph_label = data.y
        return data

    def sum_by_source(values, src_nodes, num_nodes):
        # dim_size pins the output length to num_nodes. Without it scatter
        # returns only max(index) + 1 entries, so a graph whose
        # highest-numbered nodes have no outgoing edge would yield a vector
        # shorter than the degree vector and torch.stack would fail.
        return scatter(values, src_nodes, dim_size=num_nodes)

    def standardize_targets(pyg_dataset):
        # Normalize regression targets column-wise, in place.
        store = pyg_dataset.data
        targets = store.y
        store.y = (targets - targets.mean(0)) / targets.std(0)

    if name in ['Cora', 'CiteSeer', 'PubMed']:
        dataset = Planetoid(dataset_dir, name)
    elif name == 'QM7b':
        from torch_geometric.datasets import QM7b

        def qm7b_pre_transform(data):
            # Node features: [out-degree, sum of outgoing edge attributes].
            src_nodes = data.edge_index[0, :]
            degree = tgutils.degree(src_nodes, num_nodes=data.num_nodes)
            edge_sum = sum_by_source(data.edge_attr, src_nodes,
                                     data.num_nodes)
            data.x = torch.stack([degree, edge_sum], dim=1)
            return data

        dataset = QM7b(root=dataset_dir,
                       pre_transform=qm7b_pre_transform,
                       transform=expose_graphgym_fields)
        standardize_targets(dataset)
        cfg.dataset.node_dim = 2
        cfg.dataset.edge_dim = 1
    elif name == 'QM9':
        from torch_geometric.datasets import QM9

        dataset = QM9(root=dataset_dir, transform=expose_graphgym_fields)
        standardize_targets(dataset)
        cfg.dataset.node_dim = dataset.data.x.shape[1]
        cfg.dataset.edge_dim = dataset.data.edge_attr.shape[1]
    elif name[:3] == 'TU_':
        tu_name = name[3:]
        if tu_name == 'IMDB':
            # TU_IMDB doesn't have node features: use the out-degree instead.
            tu_name = 'IMDB-MULTI'

            def tu_pre_transform(data):
                src_nodes = data.edge_index[0, :]
                degree = tgutils.degree(src_nodes, num_nodes=data.num_nodes)
                data.x = torch.stack([degree], dim=1)
                return data

        elif tu_name == 'MUTAG':

            def tu_pre_transform(data):
                # Node features: out-degree followed by, for every edge
                # attribute column, the sum over the node's outgoing edges.
                src_nodes = data.edge_index[0, :]
                degree = tgutils.degree(src_nodes, num_nodes=data.num_nodes)
                feats = [degree]
                for column in data.edge_attr.unbind(dim=1):
                    feats.append(
                        sum_by_source(column, src_nodes, data.num_nodes))
                data.x = torch.stack(feats, dim=1)
                return data

        else:
            tu_pre_transform = None

        dataset = TUDataset(dataset_dir,
                            tu_name,
                            transform=expose_graphgym_fields,
                            pre_transform=tu_pre_transform,
                            use_node_attr=cfg.dataset.tu_use_node_attr,
                            use_edge_attr=cfg.dataset.tu_use_edge_attr)
        cfg.dataset.node_dim = dataset.data.x.shape[1]
        edge_attr = getattr(dataset.data, 'edge_attr', None)
        if edge_attr is not None:
            cfg.dataset.edge_dim = edge_attr.shape[1]
        else:
            cfg.dataset.edge_dim = 0
    elif name == 'Karate':
        # KarateClub takes no root directory.
        dataset = KarateClub()
    elif 'Coauthor' in name:
        if 'CS' in name:
            dataset = Coauthor(dataset_dir, name='CS')
        else:
            dataset = Coauthor(dataset_dir, name='Physics')
    elif 'Amazon' in name:
        if 'Computers' in name:
            dataset = Amazon(dataset_dir, name='Computers')
        else:
            dataset = Amazon(dataset_dir, name='Photo')
    elif name == 'MNIST':
        dataset = MNISTSuperpixels(dataset_dir)
    elif name == 'PPI':
        dataset = PPI(dataset_dir)
    else:
        raise ValueError('{} not support'.format(name))

    setattr(dataset, 'num_labels', dataset.num_classes)
    return dataset


def index2mask(index, size):
    '''
    Build a boolean mask that is True at the given positions.
    :param index: positions to mark, used to index the mask tensor
    :param size: size of the mask, passed to ``torch.zeros``
    :return: a ``torch.bool`` tensor, True at ``index`` and False elsewhere
    '''
    selected = torch.zeros(size, dtype=torch.bool)
    selected[index] = True
    return selected


def load_ogb(name, dataset_dir):
    '''
    Load an OGB graph property prediction dataset.
    :param name: OGB dataset name; must start with 'ogbg'
    :param dataset_dir: root directory handed to PygGraphPropPredDataset
    :return: the PygGraphPropPredDataset instance
    :raises ValueError: if ``name`` does not start with 'ogbg'
    '''
    if not name.startswith('ogbg'):
        # The offending name is interpolated so the error is actionable.
        raise ValueError('OGB dataset: {} non-exist'.format(name))
    return PygGraphPropPredDataset(name=name, root=dataset_dir)


def load_dataset():
    '''
    Load the raw dataset described by ``cfg.dataset``.

    Every loader registered in ``register.loader_dict`` is tried first; the
    first non-None result wins. Otherwise the built-in PyG / OGB loaders are
    used according to ``cfg.dataset.format``.
    :return: the dataset produced by the selected loader
    :raises ValueError: if no custom loader matched and the format is
        neither 'PyG' nor 'OGB'
    '''
    data_format = cfg.dataset.format
    name = cfg.dataset.name
    dataset_dir = cfg.dataset.dir

    # Customized loaders take precedence over the built-in ones.
    for custom_loader in register.loader_dict.values():
        loaded = custom_loader(data_format, name, dataset_dir)
        if loaded is not None:
            return loaded

    if data_format == 'PyG':
        # Pytorch Geometric dataset
        return load_pyg(name, dataset_dir)
    if data_format == 'OGB':
        # OGB names use dashes where the config uses underscores.
        return load_ogb(name.replace('_', '-'), dataset_dir)
    raise ValueError('Unknown data format: {}'.format(data_format))


def set_dataset_info(dataset):
    '''
    Record the dataset's input/output dimensions in the shared config.

    Sets ``cfg.share.dim_in`` to the second dimension of ``dataset.data.x``
    and ``cfg.share.dim_out`` to the number of distinct values in
    ``dataset.data.y`` (classification) or to the second dimension of
    ``dataset.data.y`` (any other task type). Either value falls back to 1
    when it cannot be derived.
    :param dataset: dataset object exposing ``data.x`` and ``data.y``
    '''
    # todo: verify edge cases

    # Best-effort lookups: `except Exception` keeps the fallback but, unlike
    # a bare `except`, lets KeyboardInterrupt / SystemExit propagate.
    try:
        cfg.share.dim_in = dataset.data.x.shape[1]
    except Exception:
        cfg.share.dim_in = 1

    try:
        if cfg.dataset.task_type == 'classification':
            cfg.share.dim_out = torch.unique(dataset.data.y).shape[0]
        else:
            cfg.share.dim_out = dataset.data.y.shape[1]
    except Exception:
        cfg.share.dim_out = 1


def create_dataset():
    '''
    Load the configured dataset and split it into subsets.

    For the 'OGB' format the dataset's own train/valid/test split is used;
    otherwise the splits come from ``compute_split_idx``.
    :return: a list of dataset subsets, one per split
    '''
    ## todo: add new PyG dataset split functionality
    dataset = load_dataset()
    set_dataset_info(dataset)

    if cfg.dataset.format == 'OGB':
        ogb_split = dataset.get_idx_split()
        return [dataset[ogb_split[part]]
                for part in ('train', 'valid', 'test')]

    index_groups = compute_split_idx(original_len=len(dataset),
                                     split_sizes=cfg.dataset.split,
                                     random=cfg.dataset.shuffle_split,
                                     k_fold=cfg.k_fold)
    return [dataset[indices] for indices in index_groups]


def create_loader(datasets):
    '''
    Wrap each dataset split in a DataLoader.

    The first split is the training split: it is shuffled, unless
    ``cfg.train.sampler`` is 'imbalance', in which case an
    ImbalancedDatasetSampler drives the ordering instead. All remaining
    splits are loaded without shuffling.
    :param datasets: sequence of dataset splits, training split first
    :return: a list of DataLoaders in the same order as ``datasets``
    '''
    if cfg.train.sampler == 'imbalance':
        from graphgym.contrib.utils.sampler import ImbalancedDatasetSampler
        train_sampler = ImbalancedDatasetSampler(dataset=datasets[0],
                                                 label_attr_name='y')
    else:
        train_sampler = None
    # Shuffling is only requested when no sampler drives the ordering.
    train_shuffle = train_sampler is None

    loaders = [DataLoader(datasets[0],
                          batch_size=cfg.train.batch_size,
                          shuffle=train_shuffle,
                          sampler=train_sampler,
                          num_workers=cfg.num_workers,
                          pin_memory=False)]
    for eval_split in datasets[1:]:
        loaders.append(DataLoader(eval_split,
                                  batch_size=cfg.train.batch_size,
                                  shuffle=False,
                                  num_workers=cfg.num_workers,
                                  pin_memory=False))
    return loaders


def create_meta_loader(dataset):
    '''
    Build a shuffled DataLoader that yields one item per batch.
    :param dataset: dataset to wrap
    :return: a DataLoader with batch_size 1, shuffle enabled and
        ``cfg.num_workers`` workers
    '''
    loader_options = {
        'batch_size': 1,
        'shuffle': True,
        'num_workers': cfg.num_workers,
        'pin_memory': False,
    }
    return DataLoader(dataset, **loader_options)
