# Run-level settings for this experiment configuration.
# NOTE(review): `name` ends with a slash, so it is presumably used as a
# sub-directory for outputs - confirm against the code that reads it.
name: synth/erdos_lstm/
num_runs: 1
num_workers: 1
# `!!python/tuple` is a PyYAML language-specific tag: the value is a
# one-element tuple holding the string '0' (quoted, so not the integer 0).
# Files using this tag cannot be read with `yaml.safe_load`.
gpus: !!python/tuple ['0']
seed: 1

# Model specification.
# NOTE(review): each module/name/args triple presumably names a class to
# import and the keyword arguments used to build it - confirm against the
# project's config loader.
model:
  module: hiddenschemanetworks.models.languagemodels
  name: NARRNN
  args:
    emb_dim: 256
    rw_length: 10
    # Nested component spec, using the same module/name/args layout as the
    # enclosing model entry.
    encoder:
      module: hiddenschemanetworks.models.blocks_synthetic
      name: NonAutoRegEncoderRNN
      args:
        rnn_dim: 256
        # The plain scalar `.2` resolves to the float 0.2.
        embedding_dropout: .2

# Data loader specification (same module/name/args layout as `model`).
data_loader:
  module: hiddenschemanetworks.data.dataloaders
  name: DataLoaderSynthetic
  args:
    schemata_name: erdos
    # Path is relative; NOTE(review): presumably resolved against the
    # working directory of the training process - confirm.
    path_to_data: ./data/Erdos # alternative path kept for reference: ./data/Barabasi_100_Large
    batch_size: 32

# Optimizer specification: Adam from torch.optim.
# The key `optimizer` is also used as the entry name under
# trainer.args.lr_schedulers.
optimizer:
  module: torch.optim
  name: Adam
  args:
    lr: 0.001
    # 0.0 means no weight decay (L2 penalty) is applied.
    weight_decay: 0.0

# Trainer specification plus run-time settings (epochs, output and logging
# locations).
# NOTE(review): module/name/args presumably name a class to import and the
# keyword arguments used to build it - confirm against the config loader.
trainer:
  module: hiddenschemanetworks.trainer
  name: TrainerSimpleSchema
  args:
    bm_metric: NLL-Loss
    save_after_epoch: 10
    reconstruction_every: 25000
    num_rec_sentences: 2
    # Sample and interpolation counts are all zero and the heat map is
    # disabled in this configuration.
    num_samples: 0
    num_interpolation_samples: 0
    num_interpolation_steps: 0
    heat_map: false
    # `!!python/tuple` (PyYAML-specific tag) turns the sequence below into
    # a Python tuple; it is not loadable with `yaml.safe_load`.
    lr_schedulers: !!python/tuple
      - optimizer: # name of the optimizer (same as the top-level `optimizer` key)
          counter: 2 # anneal the learning rate if there is no improvement after n steps
          module:  torch.optim.lr_scheduler
          name: StepLR
          args:
            # StepLR multiplies the learning rate by `gamma` every
            # `step_size` scheduler steps; when a step is taken is decided
            # by the trainer (see `counter` above).
            step_size: 1
            gamma: 0.25
    # Parameter schedulers; also converted to a Python tuple by the tag.
    schedulers: !!python/tuple
      - module: hiddenschemanetworks.utils.param_scheduler
        name: ExponentialSchedulerGumbel
        label: temperature_scheduler_rws
        args:
          # NOTE(review): temp_init equals min_temp, so if the scheduler
          # anneals from temp_init down to min_temp the temperature stays
          # constant at 0.75 - confirm this is intended.
          temp_init: 0.75
          min_temp: 0.75
          # NOTE(review): "rich" looks like a typo for "reach", but this is
          # the key name read at runtime; do not rename it here without
          # changing the consumer.
          n_steps_to_rich_minimum: 12000
          validation_value: 0.5
  # The keys below are children of `trainer` and siblings of `args`; they
  # are not part of the trainer's `args` mapping.
  epochs: 100
  # save_dir, tensorboard_dir and logging_dir all point to ./results.
  save_dir: ./results
  logging:
    tensorboard_dir: ./results
    logging_dir: ./results
    # Format strings use Python logging %-style record attributes.
    formatters:
      verbose: "%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s"
      simple: "%(levelname)s %(asctime)s %(message)s"