import logging
import pickle
import time

import os
import networkx as nx
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from ogb.graphproppred import PygGraphPropPredDataset
from ogb.linkproppred import PygLinkPropPredDataset
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.datasets import *
from torch_geometric.loader import DataLoader

from graphgym.config import cfg
from graphgym.contrib.utils.random import get_permutation
from graphgym.utils.loader import get_weight


def compute_split_idx(original_len, split_sizes, random=True):
    '''
    Partition ``range(original_len)`` into consecutive chunks of indices.

    :param original_len: number of items to split
    :param split_sizes: one float fraction per split, each strictly between 0 and 1
    :param random: if True, indices are shuffled with ``get_permutation`` before chunking
    :return: list of 1-D index tensors; split ``k`` holds
        ``int(split_sizes[k] * original_len)`` indices, except the last split,
        which takes every remaining index
    '''
    indices = torch.arange(original_len)
    if random:
        indices = indices[get_permutation(original_len=original_len)]

    chunks = []
    offset = 0
    last_split = len(split_sizes) - 1
    for k, fraction in enumerate(split_sizes):
        assert isinstance(fraction, float) and 0 < fraction < 1
        stop = offset + int(fraction * original_len)
        # The final split absorbs the rounding remainder.
        chunks.append(indices[offset:] if k == last_split else indices[offset:stop])
        offset = stop
    return chunks


def transform_after_split(datasets):
    '''
    Dataset transformation applied after the train/val/test split.

    Currently a no-op hook: the input list is returned unchanged (the very
    same object).

    :param datasets: list of per-split datasets
    :return: the same list, untransformed
    '''
    return datasets


def load_pyg(name, dataset_dir):
    '''
    Load a dataset that ships with, or is built on, PyTorch Geometric.

    Besides returning the data, several branches write dataset-derived
    settings into ``cfg.dataset`` (``weight``, ``node_dim``, ``edge_dim``,
    ``label_dim``).

    :param name: dataset name, e.g. ``'Cora'``, ``'TU_<name>'``, ``'BCSBM'``,
        ``'ZINC'``
    :param dataset_dir: parent data directory; each dataset is rooted at
        ``<dataset_dir>/<name>``
    :return: a PyG dataset, or for ``'BCSBM'`` the first graph returned by
        ``get_bcsbm_datalist``
    :raises ValueError: if ``name`` is not a supported dataset
    :raises NotImplementedError: for ``'BCSBM'`` when ``cfg.dataset.task`` is
        neither ``'edge'`` nor ``'node'``
    '''
    dataset_dir = '{}/{}'.format(dataset_dir, name)
    if name in ['Cora', 'CiteSeer', 'PubMed']:
        dataset_raw = Planetoid(dataset_dir, name)
    elif name.startswith('TU_'):
        tu_name = name[3:]

        def pre_transform(data):
            # Runs once per graph before TUDataset caches its processed data:
            # strip the stored ``num_nodes`` attribute.
            # NOTE: the function name is kept as ``pre_transform`` because PyG
            # compares its repr against the one saved with the processed cache.
            del data.num_nodes
            return data

        # TU_IMDB / TU_REDDIT don't have node features; T.Constant() supplies them.
        if tu_name == 'IMDB':
            dataset_raw = TUDataset(dataset_dir, 'IMDB-MULTI',
                                    transform=T.Constant(),
                                    pre_transform=pre_transform)
        elif tu_name == 'REDDIT':
            dataset_raw = TUDataset(dataset_dir, 'REDDIT-BINARY',
                                    transform=T.Constant(),
                                    use_node_attr=True,
                                    pre_transform=pre_transform)
        else:
            dataset_raw = TUDataset(dataset_dir,
                                    tu_name,
                                    use_node_attr=cfg.dataset.tu_use_node_attr,
                                    use_edge_attr=cfg.dataset.tu_use_edge_attr)

        # TU datasets only have graph-level labels; class weights come from get_weight.
        cfg.dataset.weight = get_weight(dataset_raw)
        if cfg.dataset.use_subset < 1.0:
            original_len = len(dataset_raw)
            perm = get_permutation(original_len=original_len)
            new_len = int(original_len * cfg.dataset.use_subset)
            dataset_raw = dataset_raw[perm[:new_len]]

        cfg.dataset.node_dim = dataset_raw.num_node_features
        # Synthetic (non graph-level) tasks: keep the 100 smallest graphs among
        # those with at least 200 edges. Smaller graphs get a sentinel size of
        # 9999, so they rank after every qualifying graph with fewer than 9999
        # edges (and are only picked when too few graphs qualify).
        if cfg.dataset.tu_simple and cfg.dataset.task != 'graph':
            edge_counts = (g.edge_index.shape[1] for g in dataset_raw)
            size = torch.tensor([
                edge_num if edge_num >= 200 else 9999
                for edge_num in edge_counts
            ])
            order = torch.argsort(size)[:100]
            dataset_raw = dataset_raw[order]
    elif name == 'Karate':
        dataset_raw = KarateClub()
    elif name == 'BCSBM':
        from graphgym.contrib.utils.random import get_bcsbm_datalist
        import math
        mu_norm = cfg.bcsbm.mu_norm
        d = cfg.bcsbm.d
        # d identical entries whose Euclidean norm is mu_norm.
        mu = [mu_norm / math.sqrt(d)] * d
        data = get_bcsbm_datalist(n=cfg.bcsbm.n,
                                  eps=cfg.bcsbm.eps,
                                  p=cfg.bcsbm.p,
                                  q=cfg.bcsbm.q,
                                  mu=mu,
                                  std_dev=cfg.bcsbm.std_dev,
                                  directed=cfg.bcsbm.directed,
                                  seed=0,
                                  version=cfg.bcsbm.version)[0]

        if not hasattr(data, 'node_feature'):
            data.node_feature = data.x
        if not hasattr(data, 'node_label'):
            data.node_label = data.y

        # Class-balance weight: #negatives / #positives of the task's labels.
        if cfg.dataset.task == 'edge':
            labels = data.edge_label
        elif cfg.dataset.task == 'node':
            labels = data.y
        else:
            raise NotImplementedError
        num_pos = (labels == 1).sum().item()
        num_neg = (labels == 0).sum().item()
        cfg.dataset.weight = num_neg / num_pos

        cfg.dataset.label_dim = 1
        data.num_labels = 1
        data.node_dim = d
        cfg.dataset.node_dim = d
        cfg.dataset.edge_dim = 0
        dataset_raw = data
    elif 'Coauthor' in name:
        if 'CS' in name:
            dataset_raw = Coauthor(dataset_dir, name='CS')
        else:
            dataset_raw = Coauthor(dataset_dir, name='Physics')
    elif 'Amazon' in name:
        if 'Computers' in name:
            dataset_raw = Amazon(dataset_dir, name='Computers')
        else:
            dataset_raw = Amazon(dataset_dir, name='Photo')
    elif name == 'MNIST':
        dataset_raw = MNISTSuperpixels(dataset_dir)
    elif name == 'PPI':
        dataset_raw = PPI(dataset_dir)
    elif name == 'QM7b':
        dataset_raw = QM7b(dataset_dir)
    elif name == 'QM9':
        dataset_raw = QM9(dataset_dir)
        # Standardise every regression target column.
        mean = dataset_raw.data.y.mean(dim=0, keepdim=True)
        std = dataset_raw.data.y.std(dim=0, keepdim=True)
        dataset_raw.data.y = (dataset_raw.data.y - mean) / std
    elif name == 'ZINC':
        from torch_geometric.datasets import ZINC

        def my_dataset_transform(data):
            # Move x / edge_attr / y to node_feature / edge_feature / graph_label.
            data.node_feature = data.x
            data.edge_feature = data.edge_attr
            del data.x
            del data.edge_attr
            data.graph_label = data.y
            return data

        dataset_raw = ZINC(root=dataset_dir, subset=True,
                           transform=my_dataset_transform,
                           split='train')

        # Shift edge-type ids down by one (presumably to make them 0-based --
        # confirm); node_dim / edge_dim are then the number of integer ids.
        dataset_raw.data.edge_attr = dataset_raw.data.edge_attr - 1
        cfg.dataset.edge_dim = dataset_raw.data.edge_attr.max().item() + 1
        cfg.dataset.node_dim = dataset_raw.data.x.max().item() + 1
        # Standardise the regression target.
        mean = dataset_raw.data.y.mean(dim=0, keepdim=True)
        std = dataset_raw.data.y.std(dim=0, keepdim=True)
        dataset_raw.data.y = (dataset_raw.data.y - mean) / std
        cfg.dataset.label_dim = 1
        dataset_raw.num_labels = 1

    else:
        raise ValueError('{} not support'.format(name))
    return dataset_raw


def load_nx(name, dataset_dir):
    '''
    Load a networkx-format dataset (not supported).

    :param name: dataset name
    :param dataset_dir: data directory
    :raises NotImplementedError: always -- networkx-format loading is not
        implemented in this loader
    '''
    raise NotImplementedError(
        'networkx-format datasets are not supported: {}/{}'.format(dataset_dir, name))


def load_dataset():
    '''
    Load the raw dataset described by ``cfg.dataset``.

    Dispatches on ``cfg.dataset.format``:
      * ``'PyG'`` goes through ``load_pyg``.
      * ``'nx'`` goes through ``load_nx``.
      * ``'OGB'`` loads an OGB graph-property dataset.

    When ``cfg.dataset.use_subset < 1`` an OGB dataset is randomly subsampled
    and its splits are re-drawn over the subset.

    :return: the raw dataset for ``'PyG'`` / ``'nx'``; a
        ``(dataset, split_idx)`` tuple for ``'OGB'``, where ``split_idx`` is
        the dict returned by ``get_idx_split`` (possibly re-drawn)
    :raises NotImplementedError: for OGB datasets other than ogbg-molhiv,
        ogbg-ppa and ogbg-molpcba
    :raises ValueError: for an unknown ``cfg.dataset.format``
    '''
    # Named ``data_format`` so the builtin ``format`` is not shadowed.
    data_format = cfg.dataset.format
    name = cfg.dataset.name
    dataset_dir = cfg.dataset.dir

    # Load from Pytorch Geometric dataset
    if data_format == 'PyG':
        return load_pyg(name, dataset_dir)
    # Load from networkx formatted data
    # todo: clean nx dataloader
    if data_format == 'nx':
        return load_nx(name, dataset_dir)
    if data_format != 'OGB':
        raise ValueError('Unknown data format: {}'.format(data_format))

    # Load from OGB formatted data. Node- and link-property datasets
    # (e.g. ogbn-mag, ogbl-ppa) are not supported yet.
    if name not in ['ogbg-molhiv', 'ogbg-ppa', 'ogbg-molpcba']:
        raise NotImplementedError
    dataset_raw = PygGraphPropPredDataset(name=name)

    if dataset_raw.num_edge_features > 0:
        cfg.dataset.encoder_edge_dim = dataset_raw.num_edge_features

    # Note this is only used for custom splits from OGB
    split_idx = dataset_raw.get_idx_split()

    fraction = cfg.dataset.use_subset
    if fraction < 1.0:
        # Keep a random ``fraction`` of all graphs.
        original_len = len(dataset_raw)
        perm = get_permutation(original_len=original_len)
        subset_len = int(original_len * fraction)
        dataset_raw = dataset_raw[perm[:subset_len]]

        # Re-draw the splits over the subset: shuffle the subset's indices and
        # give each split a contiguous chunk of ``fraction`` times its original
        # size; the last split takes the remainder, so every subset index is
        # assigned exactly once.
        all_idx = torch.arange(subset_len)[get_permutation(original_len=subset_len)]
        keys = list(split_idx)
        start_idx = 0
        for i, key in enumerate(keys):
            if i == len(keys) - 1:
                split_idx[key] = all_idx[start_idx:]
            else:
                end_idx = start_idx + int(split_idx[key].size(0) * fraction)
                split_idx[key] = all_idx[start_idx:end_idx]
                start_idx = end_idx

        covered = torch.cat([split_idx[key] for key in keys]).unique()
        assert len(dataset_raw) == len(covered), f"{len(dataset_raw)}, {len(covered)}"

    return dataset_raw, split_idx


def filter_graphs():
    '''
    Minimum number of nodes a graph must have for the configured task.

    :return: 0 when ``cfg.dataset.task`` is ``'graph'``, otherwise 5
    '''
    return 0 if cfg.dataset.task == 'graph' else 5


def create_dataset():
    '''
    Load the configured dataset and split it into train/val/test parts.

    How the dataset is split:
      * OGB datasets use the split indices returned by ``load_dataset``.
      * Other formats use ``transductive_split_data`` when
        ``cfg.dataset.transductive`` is set.
      * Otherwise ``compute_split_idx`` splits them using the fractions in
        ``cfg.dataset.split``.

    Stage timings are logged at INFO level.

    :return: list of per-split datasets, training split first
    '''
    ## Load dataset
    time1 = time.time()
    if cfg.dataset.format == 'OGB':
        dataset, splits = load_dataset()
    else:
        dataset = load_dataset()

    ## Filter graphs -- no filtering is applied at the moment
    time2 = time.time()

    ## Transform the whole dataset
    dataset.node_dim = cfg.dataset.node_dim
    dataset.edge_dim = cfg.dataset.edge_dim

    ## Split dataset
    time3 = time.time()
    if cfg.dataset.format == 'OGB':
        # Use the custom data splits shipped with OGB
        datasets = [dataset[splits[key]] for key in ('train', 'valid', 'test')]
    elif cfg.dataset.transductive:
        from graphgym.contrib.utils.transductive_split import transductive_split_data
        datasets = transductive_split_data(data=dataset,
                                           split_sizes=cfg.dataset.split,
                                           task=cfg.dataset.task,
                                           k_fold=cfg.k_fold)
    else:
        split_idx = compute_split_idx(original_len=len(dataset),
                                      split_sizes=cfg.dataset.split,
                                      random=cfg.dataset.shuffle_split)
        datasets = [dataset[idx] for idx in split_idx]

    # Only the training split keeps the configured negative sampling ratio;
    # every later (evaluation) split is reset to 1.
    for i in range(1, len(datasets)):
        datasets[i].edge_negative_sampling_ratio = 1

    ## Transform each split dataset
    time4 = time.time()
    datasets = transform_after_split(datasets)

    time5 = time.time()
    logging.info('Load: {:.4}s, Before split: {:.4}s, '
                 'Split: {:.4}s, After split: {:.4}s'.format(
        time2 - time1, time3 - time2, time4 - time3, time5 - time4))

    return datasets


def create_dataset_simple():
    '''
    Build train/val/test datasets for benchmarks that ship predefined splits.

    Supported ``cfg.dataset.name`` values:
      * ``'PPI'``: node labels, binary cross-entropy.
      * ``'CIFAR10'``: graph labels via ``GNNBenchmarkDataset``, cross-entropy.

    Every graph gets ``node_feature`` / ``edge_feature`` aliases of ``x`` /
    ``edge_attr`` and a label alias of ``y`` when it is accessed. The
    function also writes ``cfg.dataset.label_dim``, ``cfg.dataset.node_dim``
    and ``cfg.model.loss_fun``. For PPI it also writes
    ``cfg.dataset.weight``; for CIFAR10 it writes ``cfg.dataset.edge_dim``.

    :return: ``[train, val, test]`` passed through ``transform_after_split``
    :raises NotImplementedError: for any other dataset name
    '''
    name = cfg.dataset.name
    root = os.path.join(cfg.dataset.dir, name)
    split_names = ('train', 'val', 'test')

    def make_transform(label_key, with_label_index):
        # Alias the standard PyG fields under the names this code base reads.
        def transform(data):
            data.node_feature = data.x
            data.edge_feature = data.edge_attr
            setattr(data, label_key, data.y)
            if with_label_index:
                data.node_label_index = torch.arange(data.x.shape[0])
            return data
        return transform

    if name == 'PPI':
        from torch_geometric.datasets import PPI
        transform = make_transform('node_label', with_label_index=True)
        datasets = [PPI(root=root, split=split, transform=transform)
                    for split in split_names]
    elif name in ['CIFAR10']:
        from torch_geometric.datasets import GNNBenchmarkDataset
        transform = make_transform('graph_label', with_label_index=False)
        datasets = [GNNBenchmarkDataset(root=root, name=name, split=split,
                                        transform=transform)
                    for split in split_names]
    else:
        raise NotImplementedError

    train_set = datasets[0]
    train_set.num_labels = train_set.num_classes
    train_set.node_dim = train_set.num_node_features
    train_set.edge_dim = train_set.num_edge_features

    cfg.dataset.label_dim = train_set.num_classes
    cfg.dataset.node_dim = train_set.num_node_features
    if name == 'PPI':
        cfg.model.loss_fun = 'binary_cross_entropy'
        # Per-label ratio of negative to positive training examples.
        labels = train_set.data.y
        cfg.dataset.weight = ((labels == 0).sum(0) / (labels == 1).sum(0)).tolist()
    else:
        cfg.dataset.edge_dim = train_set.num_edge_features
        cfg.model.loss_fun = 'cross_entropy'

    return transform_after_split(datasets)


def create_loader(datasets):
    '''
    Wrap every split in its own DataLoader.

    Each split is passed to DataLoader inside a one-element list. Only the
    first (training) loader shuffles.

    :param datasets: indexable sequence of splits, training split first
    :return: list of DataLoaders in the same order as ``datasets``
    '''
    def make_loader(split, shuffle):
        return DataLoader([split],
                          batch_size=cfg.train.batch_size,
                          shuffle=shuffle,
                          num_workers=cfg.num_workers,
                          pin_memory=False)

    loaders = [make_loader(datasets[0], shuffle=True)]
    loaders.extend(make_loader(datasets[idx], shuffle=False)
                   for idx in range(1, len(datasets)))
    return loaders
