"""Configuration parameters for train.py."""

from ml_collections import config_dict


def get_config():
  """Return the default configuration for train.py.

  Returns:
    A `config_dict.ConfigDict` populated with the default data path, output
    directory, model/architecture settings, feature toggles and random-seed /
    rotation settings.
  """
  config = config_dict.ConfigDict()

  # Path to the input dataset (a pickle file, per the .pkl extension).
  config.data_path = 'data/pna_jraph.pkl'
  # Directory to save parameters to.
  config.save_dir = 'output'
  # Number of training steps.
  config.num_training_steps = 2000
  # Model type; supported models are 'mpnn', 'gcn' and 'gat'.
  config.model = 'mpnn'
  # Latent size.
  config.hidden_size = 64
  # Number of message passing steps.
  config.mp_steps = 1
  # Number of hidden layers.
  config.num_layers = 3
  # Base Adam learning rate.
  config.learning_rate = 1e-3
  # Whether to use effective resistances.
  config.use_effective_resistance = True
  # Whether to use resistive embeddings as node features.
  config.use_er_node_embeddings = True
  # Whether to use resistive embeddings as edge features.
  config.use_er_edge_embeddings = True
  # Whether to use eigenvector differences as edge features.
  config.use_eigen_diff = False
  # Whether to use random features.
  config.use_random_features = False
  # Whether to use hitting times.
  config.use_hitting_times = True

  # EXPERIMENTAL: Whether to use distance features.
  config.use_distance_embeddings = False
  # EXPERIMENTAL: Whether to use centrality encoding.
  config.use_centrality_encoding = False

  # Random seed for training parameters.
  config.training_random_seed = 60
  # Whether to randomly rotate the embedding each time.
  config.randomly_rotate = True
  # Seed for the random rotations; NOTE(review): presumably only used when
  # `randomly_rotate` is True — confirm in train.py.
  config.random_rotation_seed = 50
  # How many times the same graph is fed. Only meaningful when random
  # rotations are enabled (`randomly_rotate`).
  config.same_example_freq = 1
  # How many random rotation matrices to use. Should be a prime number.
  config.num_rotation_matrices = 23

  return config
