import networkx as nx
import time
import logging
import pickle

from graphgym.contrib.utils.deepsnap import MyGraphDataset, MyGraphDataset2
import torch
from torch.utils.data import DataLoader

from torch_geometric.datasets import *
import torch_geometric.transforms as T

import torch_geometric.utils as tgutils
from graphgym.config import cfg
import graphgym.models.feature_augment as preprocess
from graphgym.models.transform import (ego_nets, remove_node_feature,
                                       edge_nets, path_len)
from graphgym.contrib.loader import *
import graphgym.register as register

from ogb.graphproppred import PygGraphPropPredDataset
from ogb.nodeproppred import PygNodePropPredDataset
from ogb.linkproppred import PygLinkPropPredDataset

from deepsnap.batch import Batch
import torch.nn.functional as F

from graphgym.contrib.utils.random import get_permutation
from graphgym.utils.loader import get_weight


import random
import numpy as np

def _configure_node_dataset(dataset_raw, check_edge_attr=True):
    '''
    Record class weights and feature/label dims of a PyG dataset in cfg.

    :param dataset_raw: a PyG dataset exposing ``data.x``, ``data.edge_attr``
        and ``num_classes``
    :param check_edge_attr: if True, reject datasets that carry edge attributes
    :raises NotImplementedError: if ``check_edge_attr`` is set and the dataset
        has edge attributes (not supported by this loader yet)
    '''
    cfg.dataset.weight = get_weight(dataset_raw)
    if check_edge_attr and dataset_raw.data.edge_attr is not None:
        raise NotImplementedError('Not implemented')
    cfg.dataset.node_dim = dataset_raw.data.x.shape[1]
    cfg.dataset.label_dim = dataset_raw.num_classes


def _standardize_labels(dataset_raw):
    '''
    Standardize ``dataset_raw.data.y`` in place: subtract the mean and divide
    by the standard deviation, both computed along dim 0.

    :param dataset_raw: a PyG dataset whose ``data.y`` is a float tensor
    '''
    y = dataset_raw.data.y
    mean = y.mean(dim=0, keepdim=True)
    std = y.std(dim=0, keepdim=True)
    dataset_raw.data.y = (y - mean) / std


def _load_tu_dataset(name, dataset_dir):
    '''
    Load a TUDataset selected by a ``TU_``-prefixed name and update cfg.

    ``TU_IMDB`` and ``TU_REDDIT`` map to ``IMDB-MULTI`` and ``REDDIT-BINARY``
    and receive constant node features through ``T.Constant()``. Any other
    suffix is passed to TUDataset unchanged.

    :param name: dataset name starting with ``TU_``
    :param dataset_dir: directory the dataset is stored in
    :return: the loaded (and possibly sub-sampled / filtered) TUDataset
    '''
    tu_name = name[3:]

    def pre_transform(data):
        # Drop the stored node count before PyG caches the processed data.
        del data.num_nodes
        return data

    if tu_name == 'IMDB':
        # TU_IMDB doesn't have node features
        dataset_raw = TUDataset(dataset_dir, 'IMDB-MULTI',
                                transform=T.Constant(),
                                pre_transform=pre_transform)
    elif tu_name == 'REDDIT':
        dataset_raw = TUDataset(dataset_dir, 'REDDIT-BINARY',
                                transform=T.Constant(),
                                use_node_attr=True,
                                pre_transform=pre_transform)
    else:
        dataset_raw = TUDataset(dataset_dir,
                                tu_name,
                                use_node_attr=cfg.dataset.tu_use_node_attr,
                                use_edge_attr=cfg.dataset.tu_use_edge_attr)

    # TU_dataset only has graph-level label
    cfg.dataset.weight = get_weight(dataset_raw)

    if cfg.dataset.use_subset < 1.0:
        original_len = len(dataset_raw)
        perm = get_permutation(original_len=original_len)
        new_len = int(original_len * cfg.dataset.use_subset)
        dataset_raw = dataset_raw[perm[:new_len]]

    cfg.dataset.node_dim = dataset_raw.num_node_features

    # Synthetic tasks: keep the 100 smallest graphs that have at least 200
    # edge_index columns. Smaller graphs are ranked as if they had 9999
    # edges, which pushes them towards the end of the ordering.
    if cfg.dataset.tu_simple and cfg.dataset.task != 'graph':
        sizes = []
        for data in dataset_raw:
            edge_num = data.edge_index.shape[1]
            sizes.append(9999 if edge_num < 200 else edge_num)
        order = torch.argsort(torch.tensor(sizes))[:100]
        dataset_raw = dataset_raw[order]
    return dataset_raw


def load_pyg(name, dataset_dir):
    '''
    load pyg format dataset

    Side effect: depending on the dataset, fills ``cfg.dataset`` fields such
    as ``weight``, ``node_dim``, ``edge_dim`` and ``label_dim``.

    :param name: dataset name
    :param dataset_dir: data directory
    :return: a list of networkx/deepsnap graphs
    :raises ValueError: if ``name`` is not a supported dataset
    :raises NotImplementedError: if a node-level dataset carries edge
        attributes
    '''
    dataset_dir = '{}/{}'.format(dataset_dir, name)
    # Branch order matters: several checks are substring tests.
    if name in ['Cora', 'CiteSeer', 'PubMed']:
        dataset_raw = Planetoid(dataset_dir, name)
        _configure_node_dataset(dataset_raw)
    elif name in ['CoraFull']:
        dataset_raw = CoraFull(root=dataset_dir)
        _configure_node_dataset(dataset_raw)
    elif name in ['GitHub']:
        dataset_raw = GitHub(root=dataset_dir)
        _configure_node_dataset(dataset_raw)
    elif name in ['FacebookPagePage']:
        dataset_raw = FacebookPagePage(root=dataset_dir)
        _configure_node_dataset(dataset_raw)
    elif 'Twitch' in name:
        dataset_raw = Twitch(root=dataset_dir, name=name.replace('Twitch', ''))
        _configure_node_dataset(dataset_raw)
    elif name[:3] == 'TU_':
        dataset_raw = _load_tu_dataset(name, dataset_dir)
    elif name == 'Karate':
        dataset_raw = KarateClub()
    elif 'Coauthor' in name:
        subset = 'CS' if 'CS' in name else 'Physics'
        dataset_raw = Coauthor(dataset_dir, name=subset)
        _configure_node_dataset(dataset_raw)
    elif name in ['DBLP']:
        from torch_geometric.datasets import CitationFull
        dataset_raw = CitationFull(root=dataset_dir, name=name)
        # The original DBLP branch never checked edge attributes.
        _configure_node_dataset(dataset_raw, check_edge_attr=False)
    elif 'Amazon' in name:
        subset = 'Computers' if 'Computers' in name else 'Photo'
        dataset_raw = Amazon(dataset_dir, name=subset)
        _configure_node_dataset(dataset_raw)
    elif name == 'MNIST':
        dataset_raw = MNISTSuperpixels(dataset_dir)
    elif name == 'PPI':
        dataset_raw = PPI(dataset_dir)
    elif name == 'QM7b':
        dataset_raw = QM7b(dataset_dir)
    elif name == 'QM9':
        dataset_raw = QM9(dataset_dir)
        _standardize_labels(dataset_raw)
    elif name == 'ZINC':
        from torch_geometric.datasets import ZINC
        dataset_raw = ZINC(root=dataset_dir, subset=True, split='train')
        # Node and edge features are kept as integer category ids (not
        # one-hot); shift edge categories to start at 0.
        dataset_raw.data.edge_attr = dataset_raw.data.edge_attr - 1
        cfg.dataset.edge_dim = dataset_raw.data.edge_attr.max().item() + 1
        cfg.dataset.node_dim = dataset_raw.data.x.max().item() + 1
        _standardize_labels(dataset_raw)
        cfg.dataset.label_dim = 1
    else:
        raise ValueError('{} not support'.format(name))

    # NOTE(review): cfg.dataset.standarize is not applied here; node
    # features are passed through unchanged.
    if cfg.dataset.task == 'edge':
        return MyGraphDataset.pyg_to_graphs(dataset_raw)
    return MyGraphDataset2.pyg_to_graphs(dataset_raw)


def load_nx(name, dataset_dir):
    '''
    load networkx format dataset

    Reads ``{dataset_dir}/{name}.pkl`` if that file exists. Otherwise it
    falls back to ``{dataset_dir}/{name}.gpickle``, and in that case a single
    graph is wrapped in a list. Errors while reading an existing ``.pkl``
    (corrupt file, permissions, ...) propagate instead of being masked by
    the fallback.

    :param name: dataset name
    :param dataset_dir: data directory
    :return: a list of networkx/deepsnap graphs
    '''
    pkl_path = '{}/{}.pkl'.format(dataset_dir, name)
    try:
        file = open(pkl_path, 'rb')
    except FileNotFoundError:
        graphs = nx.read_gpickle('{}/{}.gpickle'.format(dataset_dir, name))
        if not isinstance(graphs, list):
            graphs = [graphs]
        return graphs
    with file:
        # NOTE(review): pickle can execute arbitrary code; only load
        # dataset files from trusted sources.
        return pickle.load(file)


def _subsample_ogb_graphs(dataset_raw, split_idx):
    '''
    Keep a random ``cfg.dataset.use_subset`` fraction of an OGB graph dataset
    and re-partition the kept graphs into the existing split keys.

    Each split except the last receives ``int(old_size * use_subset)``
    randomly chosen indices. The last split takes whatever remains.

    :param dataset_raw: OGB PyG dataset
    :param split_idx: mapping split name -> index tensor (updated in place)
    :return: ``(subsampled dataset, split_idx)``
    '''
    frac = cfg.dataset.use_subset
    original_len = len(dataset_raw)
    perm = get_permutation(original_len=original_len)
    subset_len = int(original_len * frac)
    dataset_raw = dataset_raw[perm[:subset_len]]

    # Indices are positions in the subsampled dataset, shuffled once.
    all_idx = torch.arange(subset_len)[get_permutation(original_len=subset_len)]
    keys = list(split_idx)
    start_idx = 0
    for i, key in enumerate(keys):
        if i == len(keys) - 1:
            split_idx[key] = all_idx[start_idx:]
        else:
            end_idx = start_idx + int(split_idx[key].size(0) * frac)
            split_idx[key] = all_idx[start_idx:end_idx]
            start_idx = end_idx

    # Sanity check: every kept graph belongs to exactly one split.
    covered = torch.cat([split_idx[key] for key in keys]).unique()
    assert len(dataset_raw) == len(covered), f"{len(dataset_raw)}, {len(covered)}"
    return dataset_raw, split_idx


def _load_ogb(name):
    '''
    Load an OGB dataset and record its dims/weights in cfg.

    :param name: OGB dataset name
    :return: ``(graphs, split_idx)`` where ``split_idx`` is the dict returned
        by ``get_idx_split()`` (possibly re-partitioned by subsampling)
    :raises NotImplementedError: for unsupported names (including
        link-prediction datasets such as ``ogbl-ppa``) or datasets with edge
        attributes
    '''
    if name in ['ogbg-molhiv', 'ogbg-ppa', 'ogbg-molpcba']:
        dataset_raw = PygGraphPropPredDataset(name=name)
    elif name in ['ogbn-mag', 'ogbn-arxiv', 'ogbn-proteins']:
        def my_transform(data):
            # Symmetrize edges and flatten labels to 1-D.
            data.edge_index = tgutils.to_undirected(data.edge_index)
            data.y = data.y.flatten()
            return data
        dataset_raw = PygNodePropPredDataset(name=name,
                                             transform=my_transform)
    else:
        raise NotImplementedError

    cfg.dataset.weight = get_weight(dataset_raw)
    if dataset_raw.data.edge_attr is not None:
        raise NotImplementedError('Not implemented')

    cfg.dataset.node_dim = dataset_raw.data.x.shape[1]
    cfg.dataset.label_dim = dataset_raw.num_classes

    if dataset_raw.num_edge_features > 0:
        cfg.dataset.encoder_edge_dim = dataset_raw.num_edge_features

    # Note this is only used for custom splits from OGB
    split_idx = dataset_raw.get_idx_split()

    if cfg.dataset.use_subset < 1.0 and cfg.dataset.task == 'graph':
        dataset_raw, split_idx = _subsample_ogb_graphs(dataset_raw, split_idx)

    graphs = MyGraphDataset2.pyg_to_graphs(dataset_raw)
    return graphs, split_idx


def load_dataset():
    '''
    load raw datasets.

    Registered custom loaders are tried first. The first one that returns a
    non-None value wins.

    :return: a list of networkx/deepsnap graphs; for the ``OGB`` format, a
        tuple ``(graphs, split_idx)``
    :raises ValueError: for an unknown ``cfg.dataset.format``
    '''
    fmt = cfg.dataset.format
    name = cfg.dataset.name
    dataset_dir = cfg.dataset.dir
    # Try to load customized data format
    for func in register.loader_dict.values():
        graphs = func(fmt, name, dataset_dir)
        if graphs is not None:
            return graphs
    # Load from Pytorch Geometric dataset
    if fmt == 'PyG':
        return load_pyg(name, dataset_dir)
    # Load from networkx formatted data
    # todo: clean nx dataloader
    if fmt == 'nx':
        return load_nx(name, dataset_dir)
    # Load from OGB formatted data
    if fmt == 'OGB':
        return _load_ogb(name)
    raise ValueError('Unknown data format: {}'.format(fmt))


def filter_graphs():
    '''
    Pick the minimum node count a graph needs in order to be kept.

    Graph-level tasks use a threshold of 0, and every other task uses 5.
    :return: min number of nodes
    '''
    return 0 if cfg.dataset.task == 'graph' else 5


def transform_before_split(dataset):
    '''
    Apply the transformations that must run on the whole dataset, before it
    is split into train/val/test.

    Steps, in order: optional node-feature removal, feature augmentation,
    optional label replacement, recording of the realized augmentation dims
    in cfg, and (ID-GNN edge tasks only) the path-length transform.
    :param dataset: A DeepSNAP dataset object
    :return: A transformed DeepSNAP dataset object
    '''
    if cfg.dataset.remove_feature:
        cfg.dataset.node_dim = 1
        dataset.apply_transform(remove_node_feature,
                                update_graph=True,
                                update_tensor=False)

    augmenter = preprocess.FeatureAugment()
    realized_feat_dims, realized_label_dim = augmenter.augment(dataset)

    if cfg.dataset.augment_label:
        dataset.apply_transform(preprocess._replace_label,
                                update_graph=True,
                                update_tensor=False)

    # The dims the user asked for may not all be realized, so store the
    # ones the augmenter actually produced.
    cfg.dataset.augment_feature_dims = realized_feat_dims
    if cfg.dataset.augment_label:
        cfg.dataset.augment_label_dims = realized_label_dim

    # Temporary for ID-GNN path prediction task
    is_id_edge_task = (cfg.dataset.task == 'edge'
                       and 'id' in cfg.gnn.layer_type)
    if is_id_edge_task:
        dataset.apply_transform(path_len,
                                update_graph=False,
                                update_tensor=False)

    return dataset


def transform_after_split(datasets):
    '''
    Apply the post-split transform selected by ``cfg.dataset.transform``.

    ``'ego'`` applies ``ego_nets`` (radius ``cfg.gnn.layers_mp``) to every
    split. ``'edge'`` applies ``edge_nets`` to every split and switches the
    task to ``'node'``, both on each split and in ``cfg``. Any other value
    leaves the splits untouched.
    :param datasets: A list of DeepSNAP dataset objects
    :return: the same list object, after the transforms were applied
    '''
    mode = cfg.dataset.transform
    if mode == 'ego':
        for split in datasets:
            split.apply_transform(ego_nets,
                                  radius=cfg.gnn.layers_mp,
                                  update_tensor=True,
                                  update_graph=False)
    elif mode == 'edge':
        for split in datasets:
            split.apply_transform(edge_nets,
                                  update_tensor=True,
                                  update_graph=False)
            split.task = 'node'
        cfg.dataset.task = 'node'
    return datasets


def create_dataset():
    '''
    Load, wrap, transform and split the dataset configured in ``cfg``.

    OGB graph-level datasets use the official OGB splits. Everything else is
    split randomly by DeepSNAP after seeding ``random``, ``numpy`` and
    ``torch`` with ``cfg.k_fold`` (if >= 0) or ``cfg.seed``.

    :return: a list of DeepSNAP datasets, training split first; a
        single-element list when ``cfg.dataset.split`` has at most one entry
    '''
    ## Load dataset
    time1 = time.time()
    if cfg.dataset.format == 'OGB':
        graphs, splits = load_dataset()
    else:
        graphs = load_dataset()

    ## Filter graphs
    time2 = time.time()
    min_node = filter_graphs()

    ## Create whole dataset
    dataset_cls = MyGraphDataset if cfg.dataset.task == 'edge' else MyGraphDataset2
    dataset = dataset_cls(
        graphs,
        task=cfg.dataset.task,
        edge_train_mode=cfg.dataset.edge_train_mode,
        edge_message_ratio=cfg.dataset.edge_message_ratio,
        edge_negative_sampling_ratio=cfg.dataset.edge_negative_sampling_ratio,
        resample_disjoint=cfg.dataset.resample_disjoint,
        minimum_node_per_graph=min_node)

    if cfg.dataset.name in ['QM9']:
        dataset._num_graph_labels = graphs[0].graph_label.shape[1]

    ## Transform the whole dataset
    dataset = transform_before_split(dataset)
    dataset.node_dim = cfg.dataset.node_dim
    dataset.edge_dim = cfg.dataset.edge_dim

    ## Split dataset
    time3 = time.time()
    if cfg.dataset.format == 'OGB' and cfg.dataset.task == 'graph':
        # Use custom (official OGB) data splits
        datasets = [dataset[splits['train']],
                    dataset[splits['valid']],
                    dataset[splits['test']]]
    else:
        # Use random split, supported by DeepSNAP
        seed = cfg.k_fold if cfg.k_fold >= 0 else cfg.seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if len(cfg.dataset.split) > 1:
            datasets = dataset.split(
                transductive=cfg.dataset.transductive,
                split_ratio=cfg.dataset.split,
                shuffle=True)
        else:
            datasets = [dataset]

    # Only the training split keeps the configured negative sampling ratio;
    # every later split (val/test) uses a ratio of 1. The original loop
    # assigned to the unsplit `dataset`, which had no effect on the splits.
    for i in range(1, len(datasets)):
        datasets[i].edge_negative_sampling_ratio = 1

    ## Transform each split dataset
    time4 = time.time()
    datasets = transform_after_split(datasets)

    if cfg.dataset.name in ['QM9']:
        print("TODO: Properly normalize labels in QM9")  # TODO

    time5 = time.time()
    logging.info('Load: {:.4}s, Before split: {:.4}s, '
                 'Split: {:.4}s, After split: {:.4}s'.format(
        time2 - time1, time3 - time2, time4 - time3, time5 - time4))

    return datasets


def create_loader(datasets):
    '''
    Wrap every split dataset in a DataLoader that collates with DeepSNAP's
    ``Batch``.

    The first split (training) is shuffled and all later splits are not.
    :param datasets: indexable sequence of split datasets, training first
    :return: list of DataLoaders in the same order as ``datasets``
    '''
    def _make_loader(split, shuffle):
        return DataLoader(split, collate_fn=Batch.collate(),
                          batch_size=cfg.train.batch_size,
                          shuffle=shuffle,
                          num_workers=cfg.num_workers,
                          pin_memory=False)

    loaders = [_make_loader(datasets[0], True)]
    loaders.extend(_make_loader(datasets[idx], False)
                   for idx in range(1, len(datasets)))
    return loaders
