# Root node of the experiment configuration.
# NOTE(review): the `<type>` / `<init>` pair appears on every node of this
# file and is presumably interpreted by the project's config loader (which
# object to build, and whether to instantiate it immediately) -- confirm
# against the loader.
<type>: Munch
<init>: true

# ____ Wandb logger. ____
# Experiment-tracking settings. `<init>: false` here, so presumably the
# logger is constructed later by the task function -- TODO confirm.
logger:
  <type>: WandbLogger
  <init>: false
  project: "PVN"
  # Placeholder: replace with the actual W&B user or team name before running.
  entity: "<username>"
  tags: ["bec", "collaborative"]
  group: "bec"

# ____ PyTorch Lightning System. ____
# Everything nested under `system` configures the object named in `<type>`.
system:
  <type>: LJNSingleBinaryErasure
  <init>: false

  # ____ Training setup. ____
  prover_mode: "collaborative"
  # NOTE(review): -1 presumably means "never" for the two options below --
  # confirm against the system implementation.
  give_correct_proof_for_n_game_steps: -1
  block_prover_output_for_n_game_steps: -1

  # Periodic logging / inspection intervals; the unit (game steps or batches)
  # is given by each key's suffix.
  log_forward_metrics_every_n_game_steps: 10
  log_forward_plots_every_n_game_steps: 100
  do_inspection_every_n_batches: 10000

  # ____ Verifier. ____
  # `proof_dim` (16) equals `prover.proof_generator.out_features` and
  # `input_dim` (2) equals `prover.feat_extractor.num_inputs`; presumably
  # these must be changed together -- TODO confirm.
  verifier:
    <type>: BinaryErasureFCVerifier
    <init>: true
    input_dim: 2
    proof_dim: 16
    hid_dim: 100
    output_dim: 2
    num_layers: 2
    # Activation is given un-instantiated (`<init>: false`).
    activation:
      <type>: LeakyReLU
      <init>: false
    layer_norm: true
    use_spectral_norm: false

  # ____ Prover. ____
  # Wrapper composed of a feature extractor, a proof generator and auxiliary
  # heads. The extractor's `num_outputs` (100) equals the `in_features` (100)
  # of the proof generator and of both auxiliary heads; presumably all of
  # them consume the extracted features -- TODO confirm.
  prover:
    <type>: ProverWrapper
    <init>: true
    # Prover feature extractor.
    feat_extractor:
      <type>: FCNetFixedWidth
      <init>: true
      num_inputs: 2
      num_hidden_dim: 100
      num_hidden_layers: 1
      num_outputs: 100
      activation_init:
        <type>: LeakyReLU
        <init>: false
      use_layer_norm: true
      use_spectral_norm: false
    # Proof generator. `out_features` (16) equals `verifier.proof_dim`.
    proof_generator:
      <type>: BinaryErasureLinearProofGenerator
      <init>: true
      in_features: 100
      out_features: 16
    # Auxiliary heads. The keys (`classification`, `autoencoding`) are the
    # same names used under `prover_aux_losses`.
    aux_heads:
      <type>: ModuleDict
      <init>: true
      classification:
        <type>: Linear
        <init>: true
        in_features: 100
        out_features: 2
      autoencoding:
        <type>: Linear
        <init>: true
        in_features: 100
        out_features: 2
    # NOTE(review): presumably controls whether auxiliary-head gradients are
    # blocked from reaching the feature extractor -- confirm in ProverWrapper.
    detach_aux_heads: false

  # ____ Data loader. ____
  # Training data loader specification. The loader itself is `<init>: false`
  # while its dataset is `<init>: true`.
  train_loader:
    <type>: DataLoader
    <init>: false
    dataset:
      <type>: BinaryErasureChannelIO
      <init>: true
      zero_prob: 0.5
      # 102000 / batch_size 2000 = 51 batches per epoch, assuming `epoch_len`
      # counts samples per epoch -- TODO confirm.
      epoch_len: 102000
    batch_size: 2000
    shuffle: true
    num_workers: 0
    # NOTE(review): if this is torch's DataLoader, `worker_init_fn` only runs
    # inside worker subprocesses, i.e. it is unused while num_workers is 0.
    worker_init_fn:
      <type>: seed_workers
      <init>: false

  # ____ Optimizers. ____
  # Verifier and prover currently use identical optimizer settings
  # (Adam, lr 0.0003). Both are given un-instantiated (`<init>: false`).
  verifier_optimizer_init:
    <type>: Adam
    <init>: false
    lr: 0.0003

  prover_optimizer_init:
    <type>: Adam
    <init>: false
    lr: 0.0003

  # `null` presumably disables Lookahead; the commented block below is a
  # template for enabling it (replace the `null` line with it).
  lookahead: null
#  lookahead:
#    <type>: Lookahead
#    <init>: false
#    k: 5
#    alpha: 0.5

  # Learning rate schedulers (use WeightFreezeSchedule to deal with freezing networks).
  verifier_scheduler: null
  prover_scheduler: null

  # Gradient updates per game step.
  num_verifier_steps: 5
  num_prover_steps: 1
  # Settings for adapting the number of prover steps.
  adaptive_prover_step_specs:
    <type>: Munch
    <init>: true
    # NOTE(review): 1.1 is above any attainable accuracy if accuracy lies in
    # [0, 1], so the upper cutoff presumably never triggers -- confirm this
    # is intentional.
    upper_cutoff_accuracy: 1.1
    lower_cutoff_accuracy: 0.7
    update_prover_strength_n_game_steps_after_verifier_past_cutoff: 50
    max_prover_steps: 15


  # ____ Loss functions. ____
  # Classification loss.
  classification_loss_fn:
    <type>: LabelSmoothedCrossEntropy
    <init>: true
    # NOTE(review): 0.0 presumably means no label smoothing -- confirm.
    alpha: 0.0
  use_matching_verifier_loss: false
  nm_loss_weighting: 0.0

  # Per-label weights; both labels are weighted equally here.
  label_0_weighting: 1.0
  label_1_weighting: 1.0

  # Label flipping is set to `null`; the commented lines below are a template
  # for a non-null value.
  label_flipping: null
#    <type>: Munch
#    <init>: true
#    flip_probability: 0.2

  # Auxiliary task losses.
  # One entry per auxiliary head; the entry names match the keys under
  # `prover.aux_heads`. Each entry pairs an un-instantiated loss function
  # with a scalar `coeff`.
  prover_aux_losses:
    <type>: Munch
    <init>: true
    classification:
      <type>: Munch
      <init>: true
      loss_fn:
        <type>: classification_aux_loss
        <init>: false
        alpha: 0.0
        downweight_ce: false
      coeff: 1.0
    autoencoding:
      <type>: Munch
      <init>: true
      loss_fn:
        <type>: autoencoding_aux_loss
        <init>: false
      coeff: 1.0

  # No auxiliary losses are configured for the verifier (empty Munch).
  verifier_aux_losses:
    <type>: Munch
    <init>: true


  # Additional regularization terms.
  # Proximal regularization is set to `null`; the commented block below is a
  # template for enabling it.
  proximal_reg: null
#  proximal_reg:  # Set to "null" to disable it.
#    <type>: Munch
#    <init>: true
#    prover_coeff: 0.01
#    verifier_coeff: 0.01
#    sync_every_n_game_steps: 20
#    distance_fn:
#      <type>: mse_loss
#      <init>: false

  # L2 proof regularization is configured with a zero coefficient; the
  # commented line is the alternative of setting it to `null`.
#  l2_proof_reg: null
  l2_proof_reg:
    <type>: Munch
    <init>: true
    coeff: 0.0
  # Collapse-preventing regularization is set to `null`; template follows.
  collapse_preventing_reg: null
#    <type>: Munch
#    <init>: true
#    max_samples: 500
#    reg_fn:
#      <type>: collapse_preventing_loss
#      <init>: false
#      coeff: 1.
#      scale: 4.


  # ____ Pretraining specs. ____
  # Set to `null` (presumably: no pretraining -- TODO confirm). The commented
  # lines below are a template for a non-null value; lines starting with `##`
  # are a nested, itself-disabled option inside that template.
  pretraining_specs: null
#    <type>: Munch
#    <init>: true
#    stop_pretraining_if_acc_above_threshold: 0.85
#    stop_pretraining_acc_horizon: 20
#    pretrain_for_n_batches: 2
#    label_smoothing_coeff: 0.1
#    collapse_preventing_reg: null
##      <type>: Munch
##      <init>: true
##      max_samples: 500
##      reg_fn:
##        <type>: collapse_preventing_loss
##        <init>: false
##        coeff: 1.
##        scale: 4.
#    optimizer:
#      <type>: Adam
#      <init>: false
#      lr: 0.0003


  # ____ Probe training. ____
  max_num_batches_in_buffer: 10
  track_last_n_batch_probe_outputs: 10
  max_probe_training_batches: 10
  # List of probes; each entry names its input/output tensors, a loss, an
  # un-instantiated model and the logging functions to run.
  probe_specs:
    # Proof to input autoencoding.
    - <type>: Munch
      <init>: true
      task: "regression"
      inputs: "proofs"
      outputs: "p_xs"  # Inputs given to the prover.
      loss_fn:
        <type>: mse_loss
        <init>: false
      # `num_inputs` (16) equals `prover.proof_generator.out_features` and
      # `num_outputs` (2) equals `prover.feat_extractor.num_inputs`.
      model_init:
        <type>: FCNetFixedWidth
        <init>: false
        num_inputs: 16
        num_hidden_dim: 200
        num_hidden_layers: 2
        num_outputs: 2
        activation_init:
          <type>: LeakyReLU
          <init>: false
        use_layer_norm: true
        use_spectral_norm: false
      logging_fns:
        - <type>: probe_loss_plot_logging
          <init>: false
        - <type>: probe_loss_logging
          <init>: false

  # Data loader used for probe training. No `dataset` key is given here,
  # unlike `train_loader`; presumably the dataset is supplied at runtime --
  # TODO confirm.
  probe_dataloader:
    <type>: DataLoader
    <init>: false
    batch_size: 2000
    shuffle: true
    num_workers: 0
    worker_init_fn:
      <type>: seed_workers
      <init>: false

  # Optimizer for the probe models (same settings as the prover/verifier
  # optimizers in this file).
  probe_optimizer_init:
    <type>: Adam
    <init>: false
    lr: 0.0003


  # ____ Attacking the verifier (to test soundness) ____
  # Optimizer for the attack, given un-instantiated.
  v_attack_p_optimizer_init:
    <type>: Adam
    <init>: false
    lr: 0.0003

  # Upper bound on the number of attack training batches.
  max_v_attack_p_training_batches: 500

  # ____ Attacking the verifier by optimizing proofs directly. ____
  v_attack_proof_optim_specs:
    <type>: Munch
    <init>: true
    max_attack_proof_samples: 200
    # LBFGS settings for the direct proof optimization; passed un-instantiated.
    proof_optimizer:
      <type>: LBFGS
      <init>: false
      lr: 1
      max_iter: 300
      tolerance_grad: 1.0e-5
      tolerance_change: 1.0e-9
      history_size: 300
      line_search_fn: "strong_wolfe"

  # ____ Plot logging. ____
  # Set to `null` (presumably: no plotting functions -- TODO confirm). The
  # commented entries below are a template list of available plotting
  # functions.
  plotting_fns: null
#    - <type>: verifier_pred_confidence_logging
#      <init>: false
#    - <type>: visualize_proof_and_grad_pairwise_similarities_and_grad_hist
#      <init>: false
#      num_samples: 20
#    - <type>: visualize_proofs_as_images
#      <init>: false
#      num_images: 5

  # ____ Minibatch inspection. ____
  # Function reference (`<init>: false`) for obtaining the inspection
  # minibatch; see also `do_inspection_every_n_batches`.
  fixed_minibatch_getter:
    <type>: get_inspection_minibatch
    <init>: false

# ____ Checkpoint callback. ____
# NOTE(review): `t_min` presumably means minutes; at 100000 both intervals
# are so long that saving/syncing is effectively off for normal runs --
# confirm this is intentional.
checkpoint_callback:
  <type>: SimpleCheckpointer
  <init>: false
  save_every_t_min: 100000
  sync_every_t_min: 100000

# No additional callbacks are configured.
callbacks: null

# ____ Trainer. ____
# Keyword arguments for the trainer; given un-instantiated (`<init>: false`).
trainer:
  <type>: Trainer
  <init>: false
  max_epochs: 10000000
  # NOTE(review): `progress_bar_refresh_rate` is not accepted by newer
  # PyTorch Lightning releases -- confirm against the pinned version.
  progress_bar_refresh_rate: 1
  num_sanity_val_steps: 1
  gradient_clip_val: 0.0

# Task function.
# Reference to the entry-point function (`<init>: false`, i.e. not called at
# config-load time -- presumably invoked by the launcher; TODO confirm).
task_fn:
  <type>: train
  <init>: false

# Other run-level settings.
resume_if_possible: false
test: false
# Seed value for the run.
seed: 0
use_deterministic_algorithms: false
wandb_dryrun: false
