r"""Training code with affinity measures for PNA dataset."""

from absl import flags
from ml_collections.config_flags import config_flags
from scipy import stats
from typing import Union

import cloudpickle
import functools
import getpass
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
import logging
import numpy as np
import optax
import pathlib
import pickle
import sys
import scipy as sp

import models




# Global absl flag registry; it is parsed from sys.argv in `main`.
FLAGS = flags.FLAGS
# Registers the `--config` flag, an ml_collections config file (default
# 'config.py'); every `FLAGS.config.*` setting used below comes from it.
config_flags.DEFINE_config_file('config', 'config.py')

# Label counts passed positionally to the model constructor in `train`
# (presumably the number of node / graph / edge targets -- confirm against
# `models`). _NB_EDGE_LABELS also caps how many `eigenv_diff` columns
# `make_graph` appends as edge features.
_NB_NODE_LABELS = 3
_NB_GRAPH_LABELS = 3
_NB_EDGE_LABELS = 3


def make_graph(graph, node_labels, graph_labels, eff_res, hitting_times,
               eigenv_diff, embedding, distance_embedding):
  """Builds the model-input GraphsTuple for one example.

  Starts from `graph.nodes['regular']` as node features and `graph.edges` as
  edge features, then horizontally appends extra features according to the
  boolean switches in `FLAGS.config`. Topology, globals and counts are copied
  through unchanged.

  Args:
    graph: input graph; `graph.nodes` must be indexable with 'regular', and
      `graph.senders` / `graph.receivers` index rows of `embedding`.
    node_labels: unused; kept so existing call sites keep working.
    graph_labels: unused; kept so existing call sites keep working.
    eff_res: per-edge features appended when `use_effective_resistance` is
      set. Its shape also sizes the Gaussian noise appended when
      `use_random_features` is set.
    hitting_times: per-edge features appended when `use_hitting_times` is set.
    eigenv_diff: per-edge features; only the first `_NB_EDGE_LABELS` columns
      are appended, when `use_eigen_diff` is set.
    embedding: 2-D per-node embedding (row i belongs to node i). It is
      appended to node features when `use_er_node_embeddings` is set. When
      `use_er_edge_embeddings` is set, each edge additionally receives
      `[embedding[sender] | embedding[receiver]]`.
    distance_embedding: per-node features appended when
      `use_distance_embeddings` is set.

  Returns:
    A `jraph.GraphsTuple` with the augmented node and edge features.
  """
  del node_labels, graph_labels  # Unused; part of the shared call signature.
  node_features = graph.nodes['regular']
  edge_features = graph.edges

  if FLAGS.config.use_effective_resistance:
    edge_features = np.hstack((edge_features, eff_res))
  if FLAGS.config.use_eigen_diff:
    edge_features = np.hstack((edge_features, eigenv_diff[:, :_NB_EDGE_LABELS]))
  if FLAGS.config.use_random_features:
    # NOTE: draws from NumPy's global RNG, so these features are not
    # reproducible unless the global seed is set elsewhere.
    edge_features = np.hstack((edge_features, np.random.randn(*eff_res.shape)))
  if FLAGS.config.use_er_node_embeddings:
    node_features = np.hstack((node_features, embedding))
  if FLAGS.config.use_distance_embeddings:
    node_features = np.hstack((node_features, distance_embedding))
  if FLAGS.config.use_er_edge_embeddings:
    # Gather the sender/receiver rows in one vectorised step instead of a
    # per-edge Python loop; this function runs on every training step. The
    # float64 conversion matches the zero-initialised float64 buffer that
    # the previous loop filled.
    node_emb = np.asarray(embedding, dtype=np.float64)
    edge_embeddings = np.concatenate(
        (node_emb[np.asarray(graph.senders)],
         node_emb[np.asarray(graph.receivers)]),
        axis=1)
    edge_features = np.hstack((edge_features, edge_embeddings))
  if FLAGS.config.use_hitting_times:
    edge_features = np.hstack((edge_features, hitting_times))

  return jraph.GraphsTuple(
      nodes=node_features,
      edges=edge_features,
      senders=graph.senders,
      receivers=graph.receivers,
      globals=graph.globals,
      n_node=graph.n_node,
      n_edge=graph.n_edge)


def compute_loss(params, graph, node_labels, graph_labels, net):
  """Returns the mean per-label MSE and the log10 of each per-label MSE.

  Args:
    params: network parameters forwarded to `net.apply`.
    graph: model input forwarded to `net.apply`.
    node_labels: targets compared against the predicted `nodes`.
    graph_labels: targets compared against the predicted `globals`.
    net: object exposing `apply(params, graph)` whose result has `nodes`
      and `globals` attributes.

  Returns:
    `(loss, log_mse)`. `log_mse` concatenates node-label errors followed by
    graph-label errors. `loss` is the unweighted mean of those per-label MSEs.
  """
  predictions = net.apply(params, graph)

  # Average squared error over rows (axis 0), one value per label column.
  node_err = jnp.square(predictions.nodes - node_labels).mean(axis=0)
  graph_err = jnp.square(predictions.globals - graph_labels).mean(axis=0)
  per_label = jnp.concatenate((node_err, graph_err))

  return jnp.mean(per_label), jnp.log10(per_label)


def compute_baseline_perf(dataset):
  """Evaluates a constant-prediction baseline on the test split.

  The baseline predicts, for every node, the average over training examples
  of each example's mean node label, and likewise for graph labels. Its test
  error is scored the same way as `compute_loss`: the MSE for each label,
  the mean of those MSEs, and log10 of each one. Both are averaged over
  test examples and logged.

  Args:
    dataset: mapping with 'train' and 'test' sequences of example tuples.
      The second and third fields of each tuple are the node labels and the
      graph labels, each with one row per node / graph (averaged over axis
      0). Any further fields are ignored.

  Returns:
    `(test_loss, test_log_mse_per_label)`, the values that are also logged.

  Raises:
    ValueError: if the 'train' or 'test' split is empty.
  """
  train_set = dataset['train']
  test_set = dataset['test']
  if len(train_set) == 0 or len(test_set) == 0:
    raise ValueError('compute_baseline_perf needs non-empty train and test '
                     'splits.')

  # Index the label fields directly rather than unpacking a fixed-length
  # tuple. The old 7-way unpacking failed on the 8-field examples the rest
  # of this file produces.
  base_node = np.mean([np.mean(ex[1], axis=0) for ex in train_set], axis=0)
  base_graph = np.mean([np.mean(ex[2], axis=0) for ex in train_set], axis=0)
  print(base_node, base_graph)

  accumulated_loss = 0.0
  accumulated_logmse = 0.0
  for example in test_set:
    tst_node_labels, tst_graph_labels = example[1], example[2]
    node_mse_per_label = np.mean(
        (np.expand_dims(base_node, 0) - tst_node_labels)**2, axis=0)
    graph_mse_per_label = np.mean(
        (np.expand_dims(base_graph, 0) - tst_graph_labels)**2, axis=0)
    mse_per_label = np.concatenate((node_mse_per_label, graph_mse_per_label))
    accumulated_loss += np.mean(mse_per_label)
    accumulated_logmse += np.log10(mse_per_label)
  tst_loss = accumulated_loss / len(test_set)
  tst_logmse = accumulated_logmse / len(test_set)

  logging.info('test loss: %s, log_mse (per task): %s', tst_loss, tst_logmse)
  return tst_loss, tst_logmse


def eval_on_data_set(dataset, params, compute_loss_fn):
  """Averages loss and per-label log-MSE of `params` over `dataset`.

  Args:
    dataset: sequence of 8-field example tuples (graph, node labels, graph
      labels, then the five feature arrays consumed by `make_graph`).
    params: network parameters to evaluate.
    compute_loss_fn: callable `(params, graph, node_labels, graph_labels)`
      returning `((loss, log_mse_per_label), aux)`. An example is the jitted
      `value_and_grad` of `compute_loss` built in `train`. The aux output is
      discarded.

  Returns:
    `(mean_loss, mean_log_mse_per_label)` over all examples in `dataset`.
  """
  total_loss = 0.0
  total_logmse = 0.0
  for example in dataset:
    (raw_graph, node_lbls, graph_lbls, eff_res, hitting_times, eigenv_diff,
     embedding, distance_embedding) = example
    model_input = make_graph(raw_graph, node_lbls, graph_lbls, eff_res,
                             hitting_times, eigenv_diff, embedding,
                             distance_embedding)
    (example_loss, example_logmse), _ = compute_loss_fn(
        params, model_input, node_lbls, graph_lbls)
    total_loss += example_loss
    total_logmse += example_logmse
  num_examples = len(dataset)
  return total_loss / num_examples, total_logmse / num_examples


def random_orthonormal_matrix(n, random_state):
  """Returns an n x n orthonormal matrix drawn from `random_state`.

  The matrix is the Q factor of the QR decomposition of an n x n standard
  Gaussian matrix. No sign correction is applied to the diagonal of R, so
  the result is orthonormal but not guaranteed to be Haar-uniform.

  Args:
    n: matrix size.
    random_state: seed or `np.random.RandomState` forwarded to scipy.

  Returns:
    The (n, n) Q factor.
  """
  gaussian = sp.stats.norm.rvs(size=(n, n), random_state=random_state)
  q_factor, _ = sp.linalg.qr(gaussian)
  return q_factor


def train(data_path: str, num_training_steps: int, save_dir: str):
  """Trains the configured model on the PNA dataset and logs metrics.

  Workflow:
  1. Load a cloudpickled dict with 'train', 'val' and 'test' splits from
     `data_path`.
  2. Build the model named by `FLAGS.config.model`.
  3. Run `num_training_steps` outer steps of Adam. Each outer step takes the
     next training example, cycling through the split, and applies
     `FLAGS.config.same_example_freq` updates to it. When
     `FLAGS.config.randomly_rotate` is set, the node embedding is first
     multiplied by the next matrix from a fixed pool of random orthonormal
     matrices.
  4. Whenever `idx % len(dataset['train']) == 0` (including the very first
     step), evaluate on the validation split. If validation loss improves,
     keep the parameters as the best ones and evaluate them on the test
     split.
  5. After training, evaluate the best parameters on the test split and, if
     `save_dir` is set, pickle the final parameters to `<save_dir>/pna.pkl`.

  Args:
    data_path: path to the cloudpickled dataset. It must be a trusted file,
      because unpickling can execute arbitrary code.
    num_training_steps: number of outer training steps.
    save_dir: directory to write `pna.pkl` into, or None to skip saving.
  """
  logging.info('Training PNA.')
  with open(data_path, 'rb') as f:
    # NOTE: cloudpickle executes code embedded in the file; only load
    # trusted data.
    dataset = cloudpickle.loads(f.read())

  model = models.MODELS_DICT[FLAGS.config.model](
      _NB_NODE_LABELS, _NB_EDGE_LABELS, _NB_GRAPH_LABELS,
      FLAGS.config.hidden_size, FLAGS.config.mp_steps, FLAGS.config.num_layers,
      FLAGS.config.use_centrality_encoding)
  # Transform impure `net_fn` to pure functions with hk.transform.
  net = hk.without_apply_rng(hk.transform(model.net_fn))
  # Use the first training example as the template graph for initialization.
  (graph, node_labels, graph_labels, eff_res, hitting_times, eigenv_diff,
   embedding, distance_embedding) = dataset['train'][0]
  graph = make_graph(graph, node_labels, graph_labels, eff_res, hitting_times,
                     eigenv_diff, embedding, distance_embedding)

  print('Generating random matrices.')
  num_rotations = FLAGS.config.num_rotation_matrices
  random_state = np.random.RandomState(FLAGS.config.random_rotation_seed)
  random_rotation_mats = []
  for i in range(num_rotations):
    random_rotation_mats.append(
        random_orthonormal_matrix(n=embedding.shape[1],
                                  random_state=random_state))
    print('   {0}/{1} done.'.format(i + 1, num_rotations))
  random_rotation_idx = 0

  # Initialize the network.
  logging.info('Initializing network.')
  params = net.init(
      jax.random.PRNGKey(FLAGS.config.training_random_seed), graph)
  # Initialize the optimizer.
  opt_init, opt_update = optax.adam(FLAGS.config.learning_rate)
  opt_state = opt_init(params)

  compute_loss_fn = functools.partial(compute_loss, net=net)
  # We jit the computation of our loss, since this is the main computation.
  # Using jax.jit means that we will use a single accelerator. If you want
  # to use more than 1 accelerator, use jax.pmap. More information can be
  # found in the jax documentation.
  compute_loss_fn = jax.jit(jax.value_and_grad(compute_loss_fn, has_aux=True))

  data_ind = 0
  ep = 0
  best_params = params
  best_val_mse = 1e9

  for idx in range(num_training_steps):
    (graph, node_labels, graph_labels, eff_res, hitting_times, eigenv_diff,
     embedding, distance_embedding) = dataset['train'][data_ind]
    for _ in range(FLAGS.config.same_example_freq):
      if FLAGS.config.randomly_rotate:
        # Cycle through the fixed pool of rotations as data augmentation.
        curr_embedding = np.matmul(embedding,
                                   random_rotation_mats[random_rotation_idx])
        random_rotation_idx = (random_rotation_idx +
                               1) % len(random_rotation_mats)
      else:
        curr_embedding = embedding

      curr_graph = make_graph(graph, node_labels, graph_labels, eff_res,
                              hitting_times, eigenv_diff, curr_embedding,
                              distance_embedding)
      (loss, _), grad = compute_loss_fn(params, curr_graph, node_labels,
                                        graph_labels)
      updates, opt_state = opt_update(grad, opt_state, params)
      params = optax.apply_updates(params, updates)

    data_ind = (data_ind + 1) % len(dataset['train'])

    if idx % len(dataset['train']) == 0:
      logging.info('epoch: %s (step: %s), loss: %s', ep, idx, loss)
      ep += 1

      val_loss, val_logmse = eval_on_data_set(dataset['val'], params,
                                              compute_loss_fn)
      logging.info('validation loss: %s, log_mse (per task): %s', val_loss,
                   val_logmse)

      if val_loss < best_val_mse:
        best_val_mse = val_loss
        best_params = params
        logging.info('overwriting best parameters')

        tst_loss, tst_logmse = eval_on_data_set(dataset['test'], best_params,
                                                compute_loss_fn)
        logging.info(
            'on these parameters, test loss: %s, log_mse (per task): %s',
            tst_loss, tst_logmse)

  tst_loss, tst_logmse = eval_on_data_set(dataset['test'], best_params,
                                          compute_loss_fn)
  logging.info('test loss: %s, log_mse (per task): %s', tst_loss, tst_logmse)

  if save_dir is not None:
    # NOTE(review): this persists the *final* `params`, while the test
    # metrics reported above are for `best_params`. Confirm which set should
    # be saved.
    with pathlib.Path(save_dir, 'pna.pkl').open('wb') as fp:
      logging.info('Saving model to %s', save_dir)
      pickle.dump(params, fp)
  logging.info('Training finished')


def main():
  """Parses command-line flags, prints the config and launches training."""
  argv = sys.argv
  FLAGS(argv)
  config = FLAGS.config
  print(config)

  train(config.data_path, config.num_training_steps, config.save_dir)


if __name__ == '__main__':
  main()
