# Run-level settings for this experiment configuration.
# NOTE: this file uses the PyYAML-specific `!!python/tuple` tag, which
# `yaml.safe_load` rejects; the consumer must load it with a loader that
# supports Python tags.
# `name` is presumably used as the run/output sub-path — confirm with the loader.
name: synth/erdos_schema/
num_runs: 1
num_workers: 1
# One-element tuple holding the string '0' (presumably a GPU device index).
gpus: !!python/tuple ['0']
seed: 1

# Model section: names the class `SyntheticSchema` in
# `hiddenschemanetworks.models.languagemodels`; `args` is presumably passed to
# its constructor — confirm against the config loader.
model:
  module: hiddenschemanetworks.models.languagemodels
  name: SyntheticSchema
  args:
    emb_dim: 256
    # Presumably the edge probability of the Erdos-Renyi graph — TODO confirm.
    Erdos_edge_prob: 0.2
    #################
    # Sub-component: `GraphGenerator` from `hiddenschemanetworks.models.blocks`.
    graph_generator:
      module: hiddenschemanetworks.models.blocks
      name: GraphGenerator
      args:
        # The value is tagged as a Python tuple, while the original note asked
        # for a list; presumably any sequence works — confirm with the consumer.
        symbols2hidden_layers: !!python/tuple [256] # must be a sequence (list/tuple), not a scalar!
        nonlinearity: LeakyReLU  # activation name; presumably a torch.nn class name — confirm
        normalization: true
        n_communities: 128
    #################
    # Sub-component: `SimpleTransformerEncoder` from
    # `hiddenschemanetworks.models.blocks_synthetic`.
    encoder:
      module: hiddenschemanetworks.models.blocks_synthetic
      name: SimpleTransformerEncoder
      args:
        const_lower_value_f_matrix: 1.0
        train_word_embeddings: true
        # `rw` presumably stands for "random walk" — confirm.
        rw_length: 10 # use 10 for Erdos, 11 for Barabasi
        n_heads: 2
        n_layers: 2
        dropout: 0.2
        custom_init: true

# Data loader section: names the class `DataLoaderSynthetic` in
# `hiddenschemanetworks.data.dataloaders`.
data_loader:
  module: hiddenschemanetworks.data.dataloaders
  name: DataLoaderSynthetic
  args:
    schemata_name: erdos
    # Relative path; presumably resolved against the working directory of the
    # training process — confirm.
    path_to_data: ./data/Erdos
    batch_size: 256

# Optimizer section: names `Adam` in `torch.optim`; `args` is presumably
# forwarded as keyword arguments — confirm against the config loader.
optimizer:
  module: torch.optim
  name: Adam
  args:
    # Learning rate of 1e-4.
    lr: 0.0001
    # 0.0 means no weight decay is applied.
    weight_decay: 0.0

# Trainer section: names the class `TrainerSimpleSchema` in
# `hiddenschemanetworks.trainer`.
# Note the nesting: `epochs`, `save_dir` and `logging` at the bottom of this
# section are siblings of `args` (direct children of `trainer`), not entries
# inside `args`.
trainer:
  module: hiddenschemanetworks.trainer
  name: TrainerSimpleSchema
  args:
    bm_metric: NLL-Loss
    save_after_epoch: 20
    reconstruction_every: 50000
    num_rec_sentences: 4
    num_samples: 0
    num_interpolation_samples: 0
    num_interpolation_steps: 0
    # Learning-rate schedulers, one entry per optimizer name.
    lr_schedulers: !!python/tuple
      - optimizer: # name of the optimizer
          counter: 2 # anneal the learning rate if there is no improvement after n steps
          module:  torch.optim.lr_scheduler
          name: StepLR
          # Per the PyTorch docs, StepLR multiplies the learning rate by
          # `gamma` every `step_size` scheduler steps; when the trainer
          # triggers a step is presumably governed by `counter` — confirm.
          args:
            step_size: 1
            gamma: 0.25
    # Parameter schedulers, identified by `label`. Each entry carries
    # arguments for more than one scheduler type; which ones are actually read
    # presumably depends on `name` — confirm against the scheduler classes.
    # NOTE(review): the key spelling `rich` in `n_steps_to_rich_*` looks like a
    # typo for "reach"; do not rename it here unless the consumer is changed too.
    schedulers: !!python/tuple
      - module: hiddenschemanetworks.utils.param_scheduler
        name: ConstantScheduler  # options: ExponentialIncrease, ConstantScheduler
        label: beta_scheduler_kl_rws
        args:
          beta: 1.0                       # for the constant scheduler
          max_value: 1.0                 # for the exponential scheduler
          n_steps_to_rich_maximum: 10000  # for the exponential scheduler
          validation_value: 1.0
      - module: hiddenschemanetworks.utils.param_scheduler
        name: ConstantScheduler    # alternative: ExponentialSchedulerGumbel
        label: temperature_scheduler_rws
        args:
          # `beta` is presumably the constant value used by ConstantScheduler;
          # `temp_init`, `min_temp` and `n_steps_to_rich_minimum` presumably
          # apply only to ExponentialSchedulerGumbel — TODO confirm.
          beta: 0.75
          temp_init: 0.75
          min_temp: 0.75
          n_steps_to_rich_minimum: 12000
          validation_value: 0.5
      ################# GRAPH ##########################
      - module: hiddenschemanetworks.utils.param_scheduler
        name: ConstantScheduler   # options: ExponentialIncrease, ConstantScheduler
        label: beta_scheduler_kl_graph
        args:
          beta: 1.0            # for the constant scheduler
          max_value: 1.0     # for the exponential scheduler
          n_steps_to_rich_maximum: 10000   # for the exponential scheduler
          validation_value: 1.0
      - module: hiddenschemanetworks.utils.param_scheduler
        name: ConstantScheduler  # options: ConstantScheduler, ExponentialSchedulerGumbel
        label: temperature_scheduler_graph
        args:
          # `beta` is presumably the constant value used by ConstantScheduler;
          # the remaining temperature keys presumably apply only to
          # ExponentialSchedulerGumbel — TODO confirm.
          beta: 0.75
          temp_init: 0.75
          min_temp: 0.75
          # NOTE(review): 120000 here vs 12000 in `temperature_scheduler_rws`;
          # confirm the difference is intentional.
          n_steps_to_rich_minimum: 120000
          validation_value: 0.5
  # Trainer-level settings (siblings of `args`).
  epochs: 200
  save_dir: ./results
  # Logging output locations and message formats; all three directories point
  # at the same relative path `./results`.
  logging:
    tensorboard_dir: ./results
    logging_dir: ./results
    # Python `logging` %-style format strings.
    formatters:
      verbose: "%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s"
      simple: "%(levelname)s %(asctime)s %(message)s"