"""Subroutines for converting the PNA multitask dataset into jraph GraphsTuples."""

import cloudpickle
import jraph
import numpy as np

import graph_utils

# Raw input pickle that main() reads.
_PATH_TO_FILE = 'data/raw/multitask_dataset.pkl'
# Destination of the converted dataset that main() writes.
_PATH_TO_STORE = 'data/pna_jraph.pkl'
# Converted batches keyed by split; each split gets its own list, filled by
# main().
_DATASET = {split: [] for split in ('train', 'val', 'test')}
# The number of eigenvectors to store for DGN baseline
# https://arxiv.org/abs/2010.02863
_MAX_EIGENVECTORS = 10


def concatenate_matrices_with_padding(matrices):
  """Stacks 2-D matrices vertically, right-padding narrower ones with zeros.

  Args:
    matrices: an iterable of 2-D arrays (anything with a 2-element `.shape`
      that NumPy can assign from). It may be a one-shot iterator; it is
      materialised once up front.

  Returns:
    A float64 `np.ndarray` of shape
    `(sum of row counts, max column count)`. Matrix `k` occupies the rows
    directly after matrix `k - 1`, in its own columns `[0, ncols_k)`. The
    remaining columns of those rows are zero. An empty input yields an
    array of shape `(0, 0)`.
  """
  # Materialise first: the shape is computed in one pass and filled in a
  # second, and a generator would be exhausted after the first pass.
  matrices = list(matrices)
  nb_rows = sum(matrix.shape[0] for matrix in matrices)
  nb_cols = max((matrix.shape[1] for matrix in matrices), default=0)
  result = np.zeros((nb_rows, nb_cols))
  row_offset = 0
  for matrix in matrices:
    result[row_offset:row_offset + matrix.shape[0], 0:matrix.shape[1]] = matrix
    row_offset += matrix.shape[0]
  return result


def main():
  """Converts the raw PNA multitask pickle into batched jraph GraphsTuples.

  Reads `_PATH_TO_FILE`. The file must unpickle to a 4-element sequence
  indexed as `pna_data[k][stage][batch]`, with `stage` in
  {'train', 'val', 'test'}:
    * k=0: per-graph square matrices passed to `graph_utils` (presumably
      adjacency matrices -- TODO confirm);
    * k=1: per-graph node input features;
    * k=2: per-graph node labels;
    * k=3: graph labels for the whole batch.

  For every batch, the graphs are merged into a single `jraph.GraphsTuple`.
  The tuple
    (graph, node_labels, graph_labels, resistances, hitting_times,
     eigenv_diffs, embeddings, distance_embeddings)
  is then stored in `_DATASET[stage]`, which is rebuilt from scratch on
  every call. Node and graph labels are divided by their maxima along
  axis 0, computed over the training batches only. The resulting `_DATASET`
  is serialised with cloudpickle and written to `_PATH_TO_STORE`.

  Raises:
    AssertionError: if the raw data is inconsistently shaped, or if the
      edge index lists returned by `graph_utils` have an unexpected layout.
  """
  with open(_PATH_TO_FILE, 'rb') as f:
    # NOTE: unpickling can execute arbitrary code; only use trusted files.
    pna_data = cloudpickle.load(f)

  assert len(pna_data) == 4
  max_node_labels, max_graph_labels = None, None

  # keep track of the maximum # of edges among all graphs.
  # this is necessary so as to pad the effective resistances (whose
  # dimensions depend on the graph size) to the same size.
  max_nb_edges = 0

  for stage in ['train', 'val', 'test']:
    nb_batches = len(pna_data[0][stage])
    for batch in range(nb_batches):
      nb_graphs = pna_data[0][stage][batch].shape[0]
      for graph in range(nb_graphs):
        max_nb_edges = max(
            max_nb_edges,
            graph_utils.count_edges(pna_data[0][stage][batch][graph]))

  # 'train' must come first: its batches define the label maxima used to
  # normalise every stage.
  for stage in ['train', 'val', 'test']:
    # Start each stage from an empty list. Otherwise a repeated call to
    # main() in the same process would duplicate batches and normalise the
    # earlier ones twice.
    _DATASET[stage] = []
    assert len(pna_data[0][stage]) == len(pna_data[1][stage])
    assert len(pna_data[0][stage]) == len(pna_data[2][stage])
    assert len(pna_data[0][stage]) == len(pna_data[3][stage])
    nb_batches = len(pna_data[0][stage])
    for batch in range(nb_batches):
      assert pna_data[0][stage][batch].shape[0] == pna_data[1][stage][
          batch].shape[0]
      assert pna_data[0][stage][batch].shape[0] == pna_data[2][stage][
          batch].shape[0]
      assert pna_data[0][stage][batch].shape[0] == pna_data[3][stage][
          batch].shape[0]
      nb_graphs = pna_data[0][stage][batch].shape[0]
      nodes = []
      in_deg_list = []
      out_deg_list = []
      senders = []
      receivers = []
      n_node = []
      n_edge = []
      resistances = []
      hitting_times = []
      embeddings = []
      eigenv_diffs = []
      node_labels = []
      distance_embeddings = []

      # Running node count of the batch so far. It is passed to graph_utils
      # as an offset so that sender/receiver ids index the merged node set.
      tot_nodes = 0
      for graph in range(nb_graphs):
        assert pna_data[0][stage][batch][graph].shape[0] == pna_data[0][stage][
            batch][graph].shape[1]
        assert pna_data[0][stage][batch][graph].shape[0] == pna_data[1][stage][
            batch][graph].shape[0]
        assert pna_data[0][stage][batch][graph].shape[0] == pna_data[2][stage][
            batch][graph].shape[0]
        nb_nodes = pna_data[0][stage][batch][graph].shape[0]
        nodes.append(pna_data[1][stage][batch][graph])
        cur_senders, cur_receivers, eff_res_embedding, eff_res, hitting_time, eigenv_diff, rows, cols, in_degs, out_degs, distance_embedding = (
            graph_utils.get_edges_and_features(pna_data[0][stage][batch][graph],
                                               tot_nodes))

        if graph % 50 == 1:
          print('{0}/{1} graphs of stage {2} processed so far.'.format(
              graph - 1, nb_graphs, stage))

        # Make sure that all graphs have the same # of dimensions in their
        # embeddings.
        padded_embedding = np.zeros((eff_res_embedding.shape[0], max_nb_edges))
        padded_embedding[:, :eff_res_embedding.shape[1]] = eff_res_embedding

        # Sanity-check the index lists returned by graph_utils. For every row
        # i of the embedding, both halves of `rows` must hold i at position
        # i. (An edge-level embedding built from these indices was never
        # used, so it is no longer materialised.)
        assert len(rows) == len(cols)
        shift = len(rows) // 2
        for i in range(eff_res_embedding.shape[0]):
          assert rows[i] == i
          assert rows[i + shift] == i

        in_deg_list.append(
            np.expand_dims(np.asarray(in_degs, dtype=np.int32), axis=1))
        out_deg_list.append(
            np.expand_dims(np.asarray(out_degs, dtype=np.int32), axis=1))
        senders.extend(cur_senders)
        receivers.extend(cur_receivers)
        n_node.append(nb_nodes)
        n_edge.append(len(cur_senders))
        node_labels.append(pna_data[2][stage][batch][graph])
        resistances.append(eff_res)
        hitting_times.append(hitting_time)
        embeddings.append(padded_embedding)
        distance_embeddings.append(distance_embedding)

        eigenv_diffs.append(eigenv_diff[:, :_MAX_EIGENVECTORS])

        tot_nodes += nb_nodes

      node_regular_feats = np.concatenate(nodes, axis=0)
      node_all_feats = {
          'regular': node_regular_feats,
          'in_degs': np.concatenate(in_deg_list, axis=0),
          'out_degs': np.concatenate(out_deg_list, axis=0)
      }
      cur_graph = jraph.GraphsTuple(
          nodes=node_all_feats,
          edges=np.ones((len(senders), 1)),  # dummy edge multipliers
          senders=np.array(senders),
          receivers=np.array(receivers),
          globals=np.zeros((nb_graphs, 1)),  # dummy global features
          n_node=np.array(n_node),
          n_edge=np.array(n_edge))
      node_labels = np.concatenate(node_labels, axis=0)
      graph_labels = pna_data[3][stage][batch]

      if stage == 'train':
        # Compute maximum values for node and graph labels for normalisation
        if max_node_labels is None:
          max_node_labels = np.max(node_labels, axis=0)
          max_graph_labels = np.max(graph_labels, axis=0)
        else:
          max_node_labels = np.maximum(max_node_labels,
                                       np.max(node_labels, axis=0))
          max_graph_labels = np.maximum(max_graph_labels,
                                        np.max(graph_labels, axis=0))

      resistances = np.expand_dims(np.concatenate(resistances, axis=0), -1)
      hitting_times = np.concatenate(hitting_times, axis=0)
      eigenv_diffs = np.concatenate(eigenv_diffs, axis=0)
      embeddings = np.concatenate(embeddings, axis=0)
      distance_embeddings = concatenate_matrices_with_padding(
          distance_embeddings)

      _DATASET[stage].append(
          (cur_graph, node_labels, graph_labels, resistances, hitting_times,
           eigenv_diffs, embeddings, distance_embeddings))

    # Normalise node and graph labels based on training samples
    assert max_node_labels is not None
    assert max_graph_labels is not None
    _DATASET[stage] = [
        (g, nl / np.expand_dims(max_node_labels, 0),
         gl / np.expand_dims(max_graph_labels, 0), re, ht, ed, em, de)
        for (g, nl, gl, re, ht, ed, em, de) in _DATASET[stage]
    ]
    print('{0} complete.'.format(stage))

  print('finished computation, now dumping output.')
  # Serialise before opening the output file, so a failure while pickling
  # does not leave a truncated file behind.
  bytes_ = cloudpickle.dumps(_DATASET)
  print('   converted to byte stream, now writing {0}Mbytes.'.format(
      ((len(bytes_) // 1024) // 1024)))
  with open(_PATH_TO_STORE, 'wb') as g:
    g.write(bytes_)
  print('   finished.')


if __name__ == '__main__':
  # Run the conversion only when executed as a script, not when imported.
  main()
