# Run-level settings for this experiment configuration.
# NOTE(review): `name` ends with a slash, so it is presumably used as a
# path fragment by the consumer — confirm before changing it.
name: yahoo/schema/
num_runs: 1
num_workers: 1
# One-element tuple containing the string '0' (not the integer 0).
# The !!python/tuple tag is Python-specific: a safe YAML loader will not
# construct it, so this file needs a loader that understands Python tags.
gpus: !!python/tuple ['0']
seed: 1

# Model section. Each component below follows the same
# module / name / args layout.
# presumably `module` + `name` are resolved to a Python class and `args`
# is passed to it — TODO confirm against the config loader.
model:
  module: hiddenschemanetworks.models.languagemodels
  name: RealSchema
  args:
    n_symbols: 50
    # NOTE(review): presumably the edge probability of an Erdos-Renyi
    # random graph — confirm against the model code.
    Erdos_edge_prob: 0.5
    word_dropout: 0.3         # fraction of tokens dropped out in the decoder input
    kl_threshold_rw: 0.1      # if kl_0 + kl_rw is lower than this, don't penalize it
    kl_threshold_graph: 0.1   # if kl_graph is lower than this, don't penalize it
    ################# GRAPH
    graph_generator:
      module: hiddenschemanetworks.models.blocks
      name: GraphGenerator
      args:
        # A list whose single element is the tuple (512, 512); the outer
        # list is required by the consumer (see inline comment).
        symbols2hidden_layers: [!!python/tuple [512, 512]]  # must be list!
        nonlinearity: LeakyReLU
        normalization: false
        n_communities: 128
        diag_in_adj_matrix: true
        symbol_pair2link_function: true
        aggregated_kl: true      # if kl_graph is computed using the average posterior link prob
    ################# ENCODER
    encoder:
      module: hiddenschemanetworks.models.blocks_transformers
      name: EncoderSchema
      args:
        # NOTE(review): presumably the random-walk length ("rw" is used for
        # the random-walk KL term above) — confirm.
        rw_length: 5
    ################# DECODER
    decoder:
      module: hiddenschemanetworks.models.blocks_transformers
      name: DecoderSchema
      args:
        rw_pos_encoding: true

# Data loading section: selects the DataLoaderYahooAnswers class from
# hiddenschemanetworks.data.dataloaders and lists its arguments.
data_loader:
  module: hiddenschemanetworks.data.dataloaders
  name: DataLoaderYahooAnswers
  args:
    batch_size: 32
    # Relative path; it resolves against the process working directory.
    path_to_data: ./data/yahoo
    # NOTE(review): presumably the fixed sequence length (in tokens) that
    # inputs are padded/truncated to — confirm against the data loader.
    fix_len: 200

# Optimizer section: torch.optim.Adam with a learning rate of 1e-5 and
# weight decay disabled (0.0).
optimizer:
  module: torch.optim
  name: Adam
  args:
    lr: 0.00001
    weight_decay: 0.0

# Trainer section: selects TrainerRealSchema and configures it.
# Note the nesting: `epochs`, `save_dir` and `logging` further down are
# siblings of `args` under `trainer`, not entries inside `args`.
trainer:
  module: hiddenschemanetworks.trainer
  name: TrainerRealSchema
  args:
    freeze_decoder: false
    # NOTE(review): presumably the metric used to pick the best model —
    # confirm against the trainer.
    bm_metric: NLL-Loss
    save_after_epoch: 10
    reconstruction_every: 500
    num_rec_sentences: 8
    num_samples: 0
    # Learning-rate schedulers, given as a tuple (Python-specific tag) of
    # one-key mappings. The key (`optimizer` here) names the optimizer the
    # scheduler is attached to.
    lr_schedulers: !!python/tuple
      - optimizer: # name of the optimizer
          counter: 1 # anneal lr rate if there is no improvement after n steps
          module:  torch.optim.lr_scheduler
          name: StepLR
          # StepLR multiplies the learning rate by `gamma` every `step_size`
          # scheduler steps, i.e. by 0.8 on each scheduler step here.
          args:
            step_size: 1
            gamma: 0.8
    # Parameter schedulers, given as a tuple (Python-specific tag) of five
    # entries. Each entry has module / name / label / args.
    # presumably the trainer looks schedulers up by `label` — TODO confirm.
    schedulers: !!python/tuple
      # NOTE(review): the inline comment on max_value mentions the
      # exponential scheduler although PeriodicScheduler is selected —
      # confirm that PeriodicScheduler accepts max_value.
      - module: hiddenschemanetworks.utils.param_scheduler
        name: PeriodicScheduler  # ExponentialIncrease, ConstantScheduler
        label: beta_scheduler_kl_rws
        args:
          max_value: 1.0                 # for exp. sched.
      # NOTE(review): temp_init and min_temp are both 1.0, so presumably the
      # temperature stays constant at 1.0 for the whole run — confirm this
      # is intended.
      # NOTE(review): `n_steps_to_rich_minimum` looks like a misspelling of
      # "reach", but it is a key read at runtime; do not rename it here
      # without changing the consuming code.
      - module: hiddenschemanetworks.utils.param_scheduler
        name: ExponentialSchedulerGumbel
        label: temperature_scheduler_rws
        args:
          temp_init: 1.0
          min_temp: 1.0
          n_steps_to_rich_minimum: 200000
          validation_value: 1.0
      ################# GRAPH ##########################
      # NOTE(review): same remark as for beta_scheduler_kl_rws — the
      # "for exp. sched." comment sits on a PeriodicScheduler entry.
      - module: hiddenschemanetworks.utils.param_scheduler
        name: PeriodicScheduler   # ExponentialIncrease, ConstantScheduler
        label: beta_scheduler_kl_graph
        args:
          max_value: 1.0                  # for exp. sched.
      - module: hiddenschemanetworks.utils.param_scheduler
        name: ExponentialIncrease
        label: beta_scheduler_kl_gamma
        args:
          training_fraction_to_reach_max: 0.5
          max_value: 1.0                   # for exp. sched.
      # NOTE(review): temp_init equals min_temp here as well (both 1.0), so
      # presumably this temperature is also constant — confirm.
      - module: hiddenschemanetworks.utils.param_scheduler
        name: ExponentialSchedulerGumbel
        label: temperature_scheduler_graph
        args:
          temp_init: 1.0
          min_temp: 1.0
          n_steps_to_rich_minimum: 50000
          validation_value: 1.0
  # The keys below sit directly under `trainer` (same level as `args`).
  epochs: 100
  # Checkpoints, TensorBoard output and log files all point at ./data,
  # which resolves against the process working directory.
  save_dir: ./data
  logging:
    tensorboard_dir: ./data
    logging_dir: ./data
    # printf-style format strings using Python `logging` LogRecord
    # attribute names (levelname, asctime, module, process, thread, message).
    formatters:
      verbose: "%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s"
      simple: "%(levelname)s %(asctime)s %(message)s"