import dgl
import networkx as nx
import numpy as np
import torch
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler, SubsetRandomSampler


def filter_nb_nodes(graphs, min_node, max_node, *linked_lists):
    """
    Keep the graphs whose node count lies in [min_node, max_node], plus the aligned entries of linked_lists.

    :param graphs: sequence of graphs exposing ``number_of_nodes()``
    :param min_node: inclusive lower bound on the number of nodes
    :param max_node: inclusive upper bound on the number of nodes
    :param linked_lists: extra sequences aligned element-wise with ``graphs`` (zip semantics: the shortest wins)
    :return: ``[kept_graphs, *kept_linked]`` as lists, or ``[]`` when no graph is kept
    """
    kept_rows = [
        row
        for row in zip(graphs, *linked_lists)
        if min_node <= row[0].number_of_nodes() <= max_node
    ]
    # Transpose the kept rows back into one list per input sequence (empty input -> []).
    return [list(column) for column in zip(*kept_rows)]


def filter_only_connected(graphs, *linked_lists):
    """
    Keep only connected graphs, plus the aligned entries of linked_lists.

    Connectivity is checked on an undirected networkx copy of each graph, so edge direction is ignored.
    Graphs with no node at all are dropped: networkx raises on the null graph (connectivity is undefined
    for it), which previously aborted the whole filter.

    :param graphs: sequence of graphs exposing ``number_of_nodes()`` and ``to_networkx()``
    :param linked_lists: extra sequences aligned element-wise with ``graphs`` (zip semantics: the shortest wins)
    :return: ``[kept_graphs, *kept_linked]`` as lists, or ``[]`` when no graph is kept
    """
    def _is_connected(graph):
        # nx.is_connected raises NetworkXPointlessConcept for a graph without nodes.
        if graph.number_of_nodes() == 0:
            return False
        return nx.is_connected(graph.to_networkx().to_undirected())

    kept_rows = [row for row in zip(graphs, *linked_lists) if _is_connected(row[0])]
    # Transpose the kept rows back into one list per input sequence (empty input -> []).
    return [list(column) for column in zip(*kept_rows)]


def is_true(val):
    """
    Interpret a flag value as a boolean.

    Meant for options that usually arrive as strings from the CLI but may since have been replaced by real
    values. ``True``, ``'True'``, ``'true'``, ``'1'`` and anything equal to ``1``/``True`` (e.g. ``1.0``)
    count as true; everything else is false.

    :return: always a plain ``bool`` (the previous ``or`` chain could leak a tensor/array comparison result)
    """
    # Tuple membership evaluates each equality and coerces the outcome to bool, avoiding `== True` (E712).
    return val in (True, 'True', 'true', '1', 1)


def in_feats(dataset):
    """
    Return the input feature width of a graph dataset.

    Reads ``ndata['attr']`` on the graph of the first sample (``dataset[0][0]``) and returns the size of its
    dimension 1 (presumably the per-node feature dimension; assumes every graph shares it — TODO confirm).
    """
    first_graph = dataset[0][0]
    node_attr = first_graph.ndata['attr']
    return node_attr.shape[1]


def out_feats(dataset):
    """
    Return the output dimension of a dataset.

    Uses ``dataset.n_labels`` whenever the attribute exists (whatever its value); otherwise falls back to the
    size of dimension 0 of the label of the first sample (``dataset[0][1]``).
    """
    missing = object()
    n_labels = getattr(dataset, 'n_labels', missing)
    if n_labels is not missing:
        return n_labels
    first_label = dataset[0][1]
    return first_label.shape[0]


def separate_by_nodes(graphs, min_node, max_node, targets):
    """
    Group graphs and their targets by exact number of nodes.

    Only graphs with ``min_node <= number_of_nodes() <= max_node`` are kept. Groups are ordered by increasing
    node count, and within a group the original order of ``graphs`` is preserved. The data is scanned once,
    so ``graphs``/``targets`` may also be one-shot iterators.

    :param graphs: iterable of graphs exposing ``number_of_nodes()``
    :param min_node: inclusive lower bound on the number of nodes
    :param max_node: inclusive upper bound on the number of nodes
    :param targets: iterable of tensors aligned with ``graphs`` (zip semantics: the shortest wins)
    :return: ``(graphs_list, targets_list, nodes)``. ``graphs_list[i]`` is the list of graphs with ``nodes[i]``
        nodes, ``targets_list[i]`` stacks their targets along a new leading dimension, and ``nodes`` is a
        numpy array of the node counts that actually occur.
    """
    # Single pass instead of one full rescan of the data per candidate node count.
    buckets = {}
    for graph, target in zip(graphs, targets):
        num_nodes = graph.number_of_nodes()
        if min_node <= num_nodes <= max_node:
            graph_bucket, target_bucket = buckets.setdefault(num_nodes, ([], []))
            graph_bucket.append(graph)
            target_bucket.append(target)

    graphs_list = []
    targets_list = []
    nodes_list = []
    for num_nodes in sorted(buckets):
        graph_bucket, target_bucket = buckets[num_nodes]
        graphs_list.append(graph_bucket)
        # Same result as torch.cat([t.unsqueeze(0) for t in target_bucket], axis=0).
        targets_list.append(torch.stack(target_bucket, dim=0))
        nodes_list.append(num_nodes)

    return graphs_list, targets_list, np.array(nodes_list)



# Ref.: https://docs.dgl.ai/tutorials/blitz/5_graph_classification.html#sphx-glr-tutorials-blitz-5-graph-classification-py
def dataset_split(dataset, train_size=0.8, val_size=0.1, batch_size=10, seed=None, fast=False):
    """
    Split a dataset into contiguous train / validation / test parts and wrap each one in a GraphDataLoader.

    The first ``train_size`` fraction of the (possibly truncated) dataset is used for training, the next
    ``val_size`` fraction for validation and the remainder for testing. Each loader visits only its own
    indices, shuffled on every pass.

    :param dataset: indexable dataset accepted by GraphDataLoader
    :param train_size: fraction of samples used for training
    :param val_size: fraction of samples used for validation
    :param batch_size: batch size of all three loaders
    :param seed: if not None, passed to ``torch.manual_seed`` for reproducible shuffling
    :param fast: boolean to use only the first 100 samples
    :return: train_dataloader, val_dataloader, test_dataloader
    :raises ValueError: if any of the three splits would be empty
    """
    num_examples = min(100, len(dataset)) if fast else len(dataset)
    num_train = int(num_examples * train_size)
    num_val = int(num_examples * val_size)
    num_test = num_examples - (num_train + num_val)

    # num_test goes negative when train_size + val_size > 1, so check for <= 0 rather than == 0.
    if num_train <= 0 or num_val <= 0 or num_test <= 0:
        raise ValueError('Adjust data split sizes to avoid having an empty split')

    if seed is not None:
        # torch.manual_seed(None) raises TypeError, so only seed when a seed was given.
        torch.manual_seed(seed)  # for reproductibility of samples

    # SubsetRandomSampler yields the given indices themselves (shuffled). RandomSampler would yield positions
    # 0..len-1 of its data source instead, making the val/test loaders re-use training samples.
    train_sampler = SubsetRandomSampler(range(num_train))
    val_sampler = SubsetRandomSampler(range(num_train, num_train + num_val))
    test_sampler = SubsetRandomSampler(range(num_train + num_val, num_examples))

    train_dataloader = GraphDataLoader(dataset, sampler=train_sampler, batch_size=batch_size, drop_last=False)
    val_dataloader = GraphDataLoader(dataset, sampler=val_sampler, batch_size=batch_size, drop_last=False)
    test_dataloader = GraphDataLoader(dataset, sampler=test_sampler, batch_size=batch_size, drop_last=False)
    return train_dataloader, val_dataloader, test_dataloader


def get_first_item(dataloader):
    """
    Return the first batch yielded by ``dataloader``.

    Works with any iterable, a GraphDataLoader included; an empty iterable raises StopIteration.
    """
    return next(iter(dataloader))


def investigate_batch(batch):
    """
    Print a short structural summary of one minibatch.

    :param batch: a ``(batched_graph, labels)`` pair, e.g. as yielded by a GraphDataLoader; labels are ignored.
        Prints the per-graph node and edge counts, then the individual graphs recovered with ``dgl.unbatch``.
    """
    graph_batch, _ = batch
    nodes_per_graph = graph_batch.batch_num_nodes()
    print('Number of nodes for each graph element in the batch:', nodes_per_graph)
    edges_per_graph = graph_batch.batch_num_edges()
    print('Number of edges for each graph element in the batch:', edges_per_graph)

    # Split the minibatch back into its individual graphs
    originals = dgl.unbatch(graph_batch)
    print('The original graphs in the minibatch:')
    print(originals)