# Experiment-level settings.
# NOTE(review): how each key is consumed is not visible in this file — confirm against the launcher.
name: ptb/schema/
num_runs: 1
num_workers: 1
# One entry, the string '0' (quoted so that it does not load as an integer).
# The !!python/tuple tag is PyYAML-specific; safe loaders reject it.
gpus: !!python/tuple
  - '0'
seed: 1

# Model section: `module` and `name` hold a dotted module path and a class name,
# `args` holds the settings for it, including three nested module/name/args entries.
model:
  module: hiddenschemanetworks.models.languagemodels
  name: RealSchema
  args:
    n_symbols: 100
    Erdos_edge_prob: 0.5
    word_dropout: 0.0
    kl_threshold_rw: 0.1
    kl_threshold_graph: 0.1

    # ---------------- graph ----------------
    graph_generator:
      module: hiddenschemanetworks.models.blocks
      name: GraphGenerator
      args:
        # Author's note: "must be list!" — yet the value is tagged as a Python tuple.
        # NOTE(review): presumably converted downstream; confirm before changing the tag.
        symbols2hidden_layers: !!python/tuple
          - 512
          - 512
        nonlinearity: LeakyReLU
        normalization: false
        n_communities: 128
        diag_in_adj_matrix: true
        symbol_pair2link_function: true
        # if kl_graph is computed using the average posterior link prob
        aggregated_kl: true

    # ---------------- encoder ----------------
    encoder:
      module: hiddenschemanetworks.models.blocks_transformers
      name: EncoderSchema
      args:
        rw_length: 20

    # ---------------- decoder ----------------
    decoder:
      module: hiddenschemanetworks.models.blocks_transformers
      name: DecoderSchema
      # No decoder arguments are set: this loads as null, exactly like a bare `args:`.
      # NOTE(review): if the consumer unpacks this as keyword arguments, `{}` may be intended.
      args: null

# Data loading section, using the same module / name / args layout as the other sections.
data_loader:
  module: hiddenschemanetworks.data.dataloaders
  name: DataLoaderPennTreebank
  args:
    batch_size: 32
    # Relative path; what it is resolved against is not visible in this file.
    path_to_data: './data/ptb'
    # NOTE(review): presumably a fixed sequence length — confirm against the data loader.
    fix_len: 140

# Optimizer section: names Adam in torch.optim; `args` carries the learning rate and weight decay.
optimizer:
  module: torch.optim
  name: Adam
  args: {lr: 0.00001, weight_decay: 0.0}

# Trainer section. Trainer options sit under `args`; `epochs`, `save_dir` and `logging`
# are direct children of `trainer`, not of `args`.
trainer:
  module: hiddenschemanetworks.trainer
  name: TrainerRealSchema
  args:
    bm_metric: NLL-Loss
    save_after_epoch: 20
    reconstruction_every: 500
    num_rec_sentences: 1
    num_samples: 0

    # Learning-rate schedulers; the key of each entry is the name of an optimizer.
    lr_schedulers: !!python/tuple
      - optimizer:
          # anneal the learning rate if there is no improvement after n steps
          counter: 0
          module: torch.optim.lr_scheduler
          name: StepLR
          args:
            step_size: 1
            gamma: 0.8

    # Parameter schedulers, one entry per `label`.
    schedulers: !!python/tuple
      # ---------------- beta_scheduler_kl_rws ----------------
      # Scheduler names noted by the author: ExponentialIncrease, ConstantScheduler, PeriodicScheduler
      - module: hiddenschemanetworks.utils.param_scheduler
        name: PeriodicScheduler
        label: beta_scheduler_kl_rws
        args:
          # for the exponential scheduler
          max_value: 1.0
          # for the constant scheduler
          beta: 1.0
          # for the exponential scheduler
          training_fraction_to_reach_max: 0.5
          validation_value: 1.0

      # ---------------- temperature_scheduler_rws ----------------
      # temp_init and min_temp are both 1.0 in this entry.
      - module: hiddenschemanetworks.utils.param_scheduler
        name: ExponentialSchedulerGumbel
        label: temperature_scheduler_rws
        args:
          temp_init: 1.0
          min_temp: 1.0
          training_fraction_to_reach_min: 0.7
          validation_value: 1.0

      # ---------------- beta_scheduler_kl_graph ----------------
      # Scheduler names noted by the author: ExponentialIncrease, ConstantScheduler, PeriodicScheduler
      - module: hiddenschemanetworks.utils.param_scheduler
        name: PeriodicScheduler
        label: beta_scheduler_kl_graph
        args:
          # for the constant scheduler
          beta: 1.0
          # for the exponential scheduler
          max_value: 1.0
          # for the exponential scheduler
          training_fraction_to_reach_max: 0.5
          # NOTE(review): 0.1 here, while beta_scheduler_kl_rws uses 1.0 — confirm this is intentional.
          validation_value: 0.1

      # ---------------- temperature_scheduler_graph ----------------
      # temp_init and min_temp are both 1.0 in this entry.
      - module: hiddenschemanetworks.utils.param_scheduler
        name: ExponentialSchedulerGumbel
        label: temperature_scheduler_graph
        args:
          temp_init: 1.0
          min_temp: 1.0
          training_fraction_to_reach_min: 0.7
          validation_value: 1.0

  epochs: 100
  save_dir: ./results

  # Logging output locations and two %-style format strings.
  logging:
    tensorboard_dir: ./results
    logging_dir: ./results
    formatters:
      verbose: '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s'
      simple: '%(levelname)s %(asctime)s %(message)s'