"""MoleculeNet data loading"""

import logging
import random

import networkx as nx
import numpy as np

from operator import itemgetter

# from typing import Callable, Optional

from torch_geometric.utils import from_networkx
from tqdm import tqdm

# Shared named logger (not __name__-based); presumably configured elsewhere in the project.
_LOG = logging.getLogger("vqt2g_logger")


# If these values are changed other sections will break
# Values used to mark and fill padding when graphs are padded up to a common node count.
_real_node_pad_val = 0  # "pad" attribute value given to nodes of the original graph
_pad_node_pad_val = 1  # "pad" attribute value given to the padding nodes that are added
_pad_node_attr_val = 0  # value filling every slot of a padding node's attribute vector


def no_features(graph, adj, idx, **kwargs):
    """Give every node an empty feature vector (debugging aid only; all inputs are ignored)."""
    empty_features = list()
    return empty_features


def padded_adj_row(graph, adj, idx, max_nodes, **kwargs):
    """Row of the adjacency for the node with indicators for the node and padding

    Args:
      graph: unused; accepted so every attribute function shares one call signature.
      adj: adjacency matrix indexable by row (``nx.to_numpy_array`` output in this module).
      idx: index of the node whose row is returned.
      max_nodes: target length; the row is right-padded with -1 up to this length.

    Returns:
      list: the adjacency row of ``idx`` with its own entry set to 2, followed by
      ``max_nodes - len(row)`` entries of -1 (no padding if the row is already that long).
      ``adj`` itself is left unmodified.
    """
    # Copy the row: indexing a NumPy matrix returns a view, so writing the self-marker
    # into it directly would modify the caller's adjacency matrix.
    row = np.array(adj[idx], copy=True)
    num_pad = max_nodes - len(row)
    padding = np.array([-1] * num_pad)
    row[idx] = 2  # marks the node's own position within its row
    return list(np.concatenate((row, padding)))


def degree_feature(graph, adj, idx, max_degree, one_hot_degree, **kwargs):
    """Degree-based feature for node ``idx``.

    The degree value is ``graph.degree(idx) // 2`` (integer-halved).
    NOTE(review): presumably the halving compensates for edges being counted twice somewhere
    upstream; confirm against how the graphs are built.

    Args:
      graph: graph providing ``degree(idx)``.
      adj: unused; accepted so every attribute function shares one call signature.
      idx: node index.
      max_degree: length of the one-hot vector.
      one_hot_degree: if True, return a one-hot vector of length ``max_degree``;
        otherwise return a 1-element int array holding the halved degree.

    Raises:
      ValueError: in one-hot mode, if the halved degree is not a valid index into a
        vector of length ``max_degree``.
    """
    halved_degree = graph.degree(idx) // 2

    if not one_hot_degree:
        return np.array([halved_degree], dtype=int)

    one_hot = np.zeros(max_degree, dtype=int)
    try:
        one_hot[halved_degree] = 1
    except IndexError as err:
        raise ValueError(
            f"In a {len(graph)}-node graph, found node with degree too large: {halved_degree}. "
            f"Max degree (for one-hot degree vector) was set to: {max_degree}"
        ) from err
    return one_hot


def concat_adj_ohe_deg(graph, adj, idx, max_nodes, max_degree, one_hot_degree, **kwargs):
    """Use concatenated adjacency row and one-hot degree features

    Accepts ``**kwargs`` like the other attribute functions (``no_features``,
    ``padded_adj_row``, ``degree_feature``), so every entry of ``attr_func_map`` can be
    called with the same keyword set.

    Returns:
      list: the output of ``padded_adj_row`` followed by the output of ``degree_feature``.
    """
    row = padded_adj_row(graph, adj, idx, max_nodes)
    ohe = degree_feature(graph, adj, idx, max_degree, one_hot_degree)
    return list(np.concatenate((row, ohe)))


# Maps the ``new_attrs`` / ``attr_type`` option string to the function that builds each node's
# attribute vector. "degree+adj" and "adj+degree" are aliases for the same concatenated feature.
attr_func_map = {
    "none": no_features,
    "adj": padded_adj_row,
    "degree": degree_feature,
}
attr_func_map.update(dict.fromkeys(("degree+adj", "adj+degree"), concat_adj_ohe_deg))


def padded_graph(
    nx_graph: nx.DiGraph,
    num_attrs: int,
    make_attrs: bool = True,
    attr_func=None,
    attr_name="attr",
    max_nodes=0,
    max_degree=0,
    one_hot_degree=True,
):
    """Pad one graph to ``max_nodes`` nodes and convert it to pytorch geometric format.

    Args:
      nx_graph: input networkx graph. Padding nodes are the labels in ``range(max_nodes)`` that
        are not already present, so nodes are assumed to be labelled ``0..n-1`` (always true
        when ``make_attrs`` is True; TODO confirm for pre-attributed graphs).
      num_attrs: length of the attribute vector given to each padding node.
      make_attrs: if True, rebuild the graph from its adjacency matrix and compute node
        attributes with ``attr_func``; if False, use attributes already stored under
        ``attr_name``.  (Default value = True)
      attr_func: attribute builder (an ``attr_func_map`` value); required when ``make_attrs``
        is True.  (Default value = None)
      attr_name: node attribute holding the feature vector.  (Default value = "attr")
      max_nodes: node count after padding.  (Default value = 0)
      max_degree: forwarded to ``attr_func``.  (Default value = 0)
      one_hot_degree: forwarded to ``attr_func``.  (Default value = True)

    Returns:
      The padded graph converted by ``to_pyg``. The input graph is not modified.

    Raises:
      ValueError: if ``make_attrs`` is True and ``attr_func`` is None.
    """
    if make_attrs:
        if attr_func is None:
            raise ValueError("attr_func must be provided when make_attrs is True")
        adj = nx.to_numpy_array(nx_graph)
        # from_numpy_array builds a fresh graph with nodes 0..n-1 in adjacency order,
        # so `idx` below is also the row index into `adj`.
        nx_graph = nx.from_numpy_array(adj)
        for idx, node in enumerate(nx_graph.nodes):
            nx_graph.nodes[node][attr_name] = attr_func(
                graph=nx_graph,
                adj=adj,
                idx=idx,
                max_nodes=max_nodes,
                max_degree=max_degree,
                one_hot_degree=one_hot_degree,
            )
    else:
        # Work on a copy so that padding and self-loop removal don't mutate the caller's graph.
        nx_graph = nx_graph.copy()

    # Add padding: real nodes are flagged with _real_node_pad_val, and every missing label in
    # range(max_nodes) becomes a padding node with a constant attribute vector.
    nx.set_node_attributes(nx_graph, _real_node_pad_val, "pad")
    pad_nodes = nx.empty_graph(max_nodes, create_using=nx.DiGraph)
    pad_nodes.remove_nodes_from(nx_graph.nodes)
    pad_attrs = {attr_name: [_pad_node_attr_val] * num_attrs, "pad": _pad_node_pad_val}
    nx_graph.add_nodes_from((node, pad_attrs) for node in pad_nodes)
    # Materialise the self-loops before removing them so the graph isn't edited mid-iteration.
    nx_graph.remove_edges_from(list(nx.selfloop_edges(nx_graph)))

    # Return graph in pytorch geometric format
    return to_pyg(nx_graph, attr_name=attr_name)


def pad_all_graphs(
    nx_graphs,
    num_attrs,
    make_attrs=True,
    attr_func=None,
    attr_name="attr",
    max_nodes=0,
    max_degree=0,
    one_hot_degree=True,
):
    """Run ``padded_graph`` over every graph in the dataset with shared settings.

    Progress is shown with a tqdm bar.

    Args:
      nx_graphs: iterable of networkx graphs.
      num_attrs: attribute-vector length for padding nodes.
      make_attrs: compute attributes (True) or reuse existing ones (False).  (Default value = True)
      attr_func: attribute builder passed to ``padded_graph``.  (Default value = None)
      attr_name: node attribute holding the features.  (Default value = "attr")
      max_nodes: node count after padding.  (Default value = 0)
      max_degree: forwarded to the attribute builder.  (Default value = 0)
      one_hot_degree: forwarded to the attribute builder.  (Default value = True)

    Returns:
      list of ``padded_graph`` results, in input order.
    """
    shared_options = dict(
        num_attrs=num_attrs,
        make_attrs=make_attrs,
        attr_func=attr_func,
        attr_name=attr_name,
        max_nodes=max_nodes,
        max_degree=max_degree,
        one_hot_degree=one_hot_degree,
    )
    progress = tqdm(nx_graphs, desc="Converting graphs for pytorch geometric")
    return [padded_graph(graph, **shared_options) for graph in progress]


def to_pyg(padded_nx, attr_name="attr"):
    """Convert one already-padded networkx graph with torch_geometric's ``from_networkx``.

    The ``attr_name`` attribute and the ``"pad"`` flag are the node attributes grouped into
    the node-feature matrix.
    """
    feature_keys = [attr_name, "pad"]
    return from_networkx(padded_nx, group_node_attrs=feature_keys)


def pyg_dataset(
    graph_list,
    add_attrs=True,
    existing_attr_name="attr",
    new_attrs="degree+adj",
    one_hot_degree=True,
):
    """Convert list of networkx graphs to pytorch geometric format

    Every graph is padded to the node count of the largest graph in ``graph_list``.

    Args:
      graph_list: non-empty sequence of networkx graphs.
      add_attrs: if True, compute node attributes with ``attr_func_map[new_attrs]``; if False,
        use attributes already stored under ``existing_attr_name``. (Default value = True)
      existing_attr_name: attribute used when ``add_attrs`` is False. (Default value = "attr")
      new_attrs: key of ``attr_func_map`` used when ``add_attrs`` is True.
        (Default value = "degree+adj")
      one_hot_degree: one-hot degree vector (True) or a single degree value (False).
        (Default value = True)

    Returns:
      list of padded graphs in pytorch geometric format, one per input graph, in order.

    Raises:
      ValueError: if ``graph_list`` is empty or ``new_attrs`` is not a key of ``attr_func_map``.
    """
    if not graph_list:
        raise ValueError("graph_list is empty; nothing to convert")

    max_nodes = max(len(g) for g in graph_list)
    if one_hot_degree:
        # default=0 stops a graph with no nodes from crashing max() on an empty degree view
        max_degree = max(max((deg for _, deg in g.degree()), default=0) for g in graph_list)
    else:
        # Not really wanting 'max degree' when not doing one-hot degree. This is only used for
        # calculating total node feature dimension anyway
        max_degree = 1

    if not add_attrs:
        if not all(existing_attr_name in g.nodes[node] for g in graph_list for node in g):
            _LOG.warning("Some nodes don't have attribute '%s'", existing_attr_name)
        first_graph = graph_list[0]
        # Sample node 0 when it exists (as before); otherwise fall back to the first node
        # instead of raising KeyError for graphs labelled differently or with no nodes.
        sample_node = 0 if 0 in first_graph else next(iter(first_graph), None)
        sample_attrs = first_graph.nodes[sample_node] if sample_node is not None else {}
        num_attrs = len(sample_attrs.get(existing_attr_name, [0]))
        attr_name = existing_attr_name
        attr_func = None

    else:
        attr_name = "attr"
        attr_func = attr_func_map.get(new_attrs)
        if not attr_func:
            raise ValueError(
                f"Attribute '{new_attrs}' is invalid, options are {list(attr_func_map.keys())}"
            )

        # Feature length produced by each attribute builder (used for the padding nodes).
        num_attrs_map = {
            "none": 0,
            "degree": max_degree,
            "adj": max_nodes,
            "degree+adj": max_nodes + max_degree,
            "adj+degree": max_nodes + max_degree,
        }
        num_attrs = num_attrs_map[new_attrs]

    pyg_graphs = pad_all_graphs(
        graph_list,
        num_attrs,
        make_attrs=add_attrs,
        attr_func=attr_func,
        attr_name=attr_name,
        max_nodes=max_nodes,
        max_degree=max_degree,
        one_hot_degree=one_hot_degree,
    )
    return pyg_graphs


def load_vqt2g_dataset(
    data,
    proportion_or_count="proportion",
    test_prop=0.2,
    test_num=0,
    add_node_attrs=True,
    attr_type="degree+adj",
    shuffle=True,
    seed=None,
    max_dataset_size=0,
    one_hot_degree=True,
):
    """Load graph + text (or graph-only) dataset, add node attrs, train/test split

    Dataset assumed to be in memory already. A graph + text dataset should be a list of tuples
    where the graph is the first element and the corresponding text(s) the second element. If
    there's multiple texts for a graph they should be lists. When loading a graph-only dataset, the
    returned texts will be a default value of a single space for each graph.

    Args:
      data: sequence of graphs, or of ``(graph, text)`` tuples.
      proportion_or_count: "proportion" splits by ``test_prop``; "count" by ``test_num``.
        (Default value = "proportion")
      test_prop: fraction of items held out for test when splitting by proportion.
        (Default value = 0.2)
      test_num: number of test items when splitting by count.  (Default value = 0)
      add_node_attrs: compute node attributes (True) or use existing ``"attr"`` attributes
        (False).  (Default value = True)
      attr_type: attribute type, a key of ``attr_func_map``.  (Default value = "degree+adj")
      shuffle: shuffle the dataset indices before truncating/splitting.  (Default value = True)
      seed: seed for the shuffle; ``None`` seeds from system entropy.  (Default value = None)
      max_dataset_size: if > 0, keep only this many (shuffled) items.  (Default value = 0)
      one_hot_degree: one-hot (True) or scalar (False) degree feature.  (Default value = True)

    Returns:
      train_graphs
      test_graphs
      train_texts
      test_texts
      indices: dict with keys "train" and "test", each a list of indices into ``data``

    Raises:
      ValueError: if ``proportion_or_count`` is neither "proportion" nor "count".
    """

    # Shuffle list of indices not the `data` list so indices can be saved - e.g. for consistent
    # dataset splits in baseline models that use networkx (not pyg) objects
    dataset_indices = list(range(len(data)))
    if shuffle:
        # A private RNG seeded with `seed` produces the same permutation as seeding the global
        # `random` module, without resetting the caller's global random state.
        random.Random(seed).shuffle(dataset_indices)

    if max_dataset_size is not None and max_dataset_size > 0:
        dataset_indices = dataset_indices[:max_dataset_size]

    if proportion_or_count == "proportion":
        train_prop = 1 - test_prop
        train_num = int(len(dataset_indices) * train_prop)
    elif proportion_or_count == "count":
        train_num = len(dataset_indices) - test_num
    else:
        raise ValueError(f"Got split '{proportion_or_count}', must be 'proportion' or 'count'")

    indices = {"train": dataset_indices[:train_num], "test": dataset_indices[train_num:]}

    # Plain indexing instead of `itemgetter(*dataset_indices)`: itemgetter returns a bare item
    # (not a tuple) for a single index and raises TypeError when there are no indices.
    selected = [data[i] for i in dataset_indices]

    # Check if graph+text or graph-only dataset
    texts_exist = isinstance(data[0], tuple)
    if texts_exist:
        graph_list = [item[0] for item in selected]
        texts = [item[1] for item in selected]
    else:  # Texts are set to spaces
        graph_list = selected
        texts = [" "] * len(selected)

    graphs = pyg_dataset(
        graph_list=graph_list,
        add_attrs=add_node_attrs,
        new_attrs=attr_type,
        one_hot_degree=one_hot_degree,
    )

    train_graphs = graphs[:train_num]
    test_graphs = graphs[train_num:]
    train_texts = texts[:train_num]
    test_texts = texts[train_num:]

    return train_graphs, test_graphs, train_texts, test_texts, indices
