# @package _global_
# Example configuration for experimenting. Trains the `tmatnet` model on the
# time-dependent TSP environment (`TDTSPEnv`) with 10 locations, using the
# "beijing" travel-time data; the training algorithm/baseline come from the
# `tmatnet` model config. You may find comments on the most common
# hyperparameters below.

# Override defaults: select configs from these config groups (the leading `/`
# makes each group path absolute, i.e. resolved from the config root)
defaults:
  # Each `override` swaps the config selected for that group.
  - override /model: tmatnet.yaml
  - override /env: default.yaml
  - override /callbacks: default.yaml
  - override /trainer: default.yaml
  # To disable logging: uncomment the next line and comment out the wandb line below it.
  # - override /logger: null
  - override /logger: wandb.yaml

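# Environment configuration: time-dependent TSP with 10 locations whose travel
# times are taken from the "beijing" time-matrix data (constant interpolation).
# NOTE(review): the template note about loading the `.npz` files generated with
# seed following Kool et al. (2019) refers to the plain TSP environment;
# confirm whether it applies to TDTSP.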
env:
  _target_: rl4co.envs.TDTSPEnv
  name: "tdtsp-mat"
  # Alternative environment variant; swap it in for the two lines above:
  # _target_: rl4co.envs.TDTSPEnvForPomo
  # name: "tdtsp_pomo"
  generator_params:
    num_loc: 10  # number of locations to generate
  check_solution: false  # skip solution validity checks (speed optimization)
  time_matrix_params:
    data: "beijing"  # time-matrix data source
    interpolate: "constant"  # time-matrix interpolation mode
    scale: 1

# Logging: we use Wandb in this case
logger:
  wandb:
    _target_: lightning.pytorch.loggers.wandb.WandbLogger
    # Run name is the launch timestamp instead of a wandb-generated name.
    name: ${now:%Y.%m.%d-%H.%M.%S}
    save_dir: "${paths.output_dir}"
    offline: false
    id: null  # pass the correct run id to resume an experiment
    anonymous: null  # anonymous-logging mode forwarded to wandb; null leaves it unset
    project: "time-tsp"
    log_model: false  # do not upload Lightning checkpoints to wandb
    prefix: ""  # string prepended to every metric key
    # entity: ""  # set to the name of your wandb team
    group: ${env.name}  # group runs by environment name
    tags: ['rl4co', 'matnet']
    job_type: ""

hydra:
  # Output directory layout:
  #   <log_dir>/<mode>/{runs|multiruns}/<wandb group>/<wandb run name>/<data>_<interpolate>_<num_loc>
  # The trailing backslashes are escaped line breaks inside the double-quoted
  # scalars: the lines are joined with no added whitespace.
  run:
    dir: "${paths.log_dir}/${mode}/runs/${logger.wandb.group}/${logger.wandb.name}/\
      ${env.time_matrix_params.data}_${env.time_matrix_params.interpolate}_\
      ${env.generator_params.num_loc}"
  sweep:
    # One subdirectory per sweep job, named by its job number.
    subdir: ${hydra.job.num}
    dir: "${paths.log_dir}/${mode}/multiruns/${logger.wandb.group}/${logger.wandb.name}/\
      ${env.time_matrix_params.data}_${env.time_matrix_params.interpolate}_\
      ${env.generator_params.num_loc}"


# Model: this contains the environment (which gets automatically passed to the model on
# initialization), the policy network and other hyperparameters.
# This is a `LightningModule` and can be trained with PyTorch Lightning.
model:
  batch_size: 1024
  val_batch_size: 1024
  test_batch_size: 1024
  # Dataset sizes. Written without digit-group underscores: YAML 1.1 loaders
  # (PyYAML/OmegaConf) accept `51_200`, but the YAML 1.2 core schema reads it
  # as a string.
  # NOTE(review): previously written as `5_120_0` (= 51200); confirm 51200 is
  # intended and not e.g. 5_120_000.
  train_data_size: 51200
  val_data_size: 10000
  test_data_size: 10000
  optimizer_kwargs:
    # Written with a dot so every YAML loader resolves it as a float
    # (plain PyYAML reads `1e-4` as a string).
    lr: 1.0e-4

  policy_params:
    embed_dim: 256
    num_encoder_layers: 5
    num_heads: 16
    normalization: "instance"
    init_embedding_kwargs:
      mode: "RandomOneHot"
    use_graph_context: false
    bias: false
    kwargs:
      mask_non_neighbors: false

# Trainer: this is a customized version of the PyTorch Lightning trainer.
trainer:
  max_epochs: 6000  # maximum number of training epochs

seed: 1234  # global seed; presumably used by the training entry point to seed RNGs — confirm
