# Root node of the experiment configuration.
# NOTE(review): `<type>` / `<init>` look like directives for a custom config
# loader (which type to build / whether to build it at load time) — confirm
# against the loader; plain YAML gives these keys no special meaning.
<type>: Munch
<init>: true

# ____ Wandb logger. ____
# Settings for the `WandbLogger` type; marked `<init>: false`.
logger:
  <type>: WandbLogger
  <init>: false
  project: "PVN"
  # NOTE(review): "<username>" looks like a placeholder — presumably it has to
  # be replaced with a real W&B entity before running; confirm.
  entity: "<username>"
  tags: ["find-plus", "collaborative"]
  # Same string as the first entry of `tags`.
  group: "find-plus"

# ____ PyTorch Lightning System. ____
# All keys nested under `system` configure the `LJNSingleFindThePlus` type
# (presumably passed to it as keyword arguments — confirm against the loader).
system:
  <type>: LJNSingleFindThePlus
  <init>: false

  # ____ Training setup. ____
  prover_mode: "collaborative"  # Either "collaborative" or "ljn".
  # NOTE(review): -1 presumably disables each of the two options below —
  # confirm against the system code.
  give_correct_proof_for_n_game_steps: -1
  block_prover_output_for_n_game_steps: -1

  # ____ Periodic stuff. ____
  # Intervals for logging and inspection. Per the key names, the first two
  # are counted in game steps and the last one in batches.
  log_forward_metrics_every_n_game_steps: 5
  log_forward_plots_every_n_game_steps: 500
  do_inspection_every_n_batches: 500

  # ____ Verifier. ____
  # Verifier configuration: a proof-input processor, a classifier and a dict
  # of auxiliary heads, each marked `<init>: true`.
  verifier:
    <type>: VerifierWrapper
    <init>: true
    # Proof-input processor.
    proof_input_processor:
      <type>: STNProofInputProcessor
      <init>: true
      # NOTE(review): 32 equals the prover's `proof_generator.module`
      # `out_features` and the probe's `num_inputs` in this file — presumably
      # these must be changed together; confirm.
      proof_features: 32
      num_heads: 10
      padding_mode: "zeros"  # Other option is "reflection".
      shear: false
      scaling: false
      default_scale_factor: 0.5  # Only used when scaling=false.
    # Classifier.
    classifier:
      <type>: Linear
      <init>: true
      # NOTE(review): 1000 presumably has to equal the processor's output
      # width — confirm against STNProofInputProcessor.
      in_features: 1000
      out_features: 2
    # Verifier aux heads.
    aux_heads:
      <type>: ModuleDict
      <init>: true
      # Same key name as the entry under `verifier_aux_losses` in this file.
      proof_input_matching:
        <type>: Linear
        <init>: true
        in_features: 1000
        out_features: 2

  # ____ Prover. ____
  # Prover configuration: a feature extractor, a proof generator and a dict
  # of auxiliary heads, each marked `<init>: true`.
  prover:
    <type>: ProverWrapper
    <init>: true
    # Feature extractor.
    feat_extractor:
      <type>: FindPlusConvFeatExtractor
      <init>: true
      in_channels: 1
      hid_channels: 40
      out_channels: 40
      cat_position_embeddings: false
    # Proof generator.
    proof_generator:
      <type>: FindPlusProofGenerator
      <init>: true
      module:
        <type>: Linear
        <init>: true
        # NOTE(review): 100 is also used as the input width of both aux heads
        # below — presumably the feature extractor's output size; confirm.
        in_features: 100
        # Equals the verifier's `proof_features` (32) in this file.
        out_features: 32
        bias: false
    # Auxiliary heads.
    # Key names match the entries under `prover_aux_losses` in this file.
    aux_heads:
      <type>: ModuleDict
      <init>: true
      classification:
        <type>: Linear
        <init>: true
        in_features: 100
        out_features: 2
# Commented-out alternative for the `autoencoding` head (kept verbatim).
#      autoencoding:
#        <type>: FCNetFixedWidth
#        <init>: true
#        num_inputs: 100
#        num_hidden_dim: 256
#        num_outputs: 100
#        num_hidden_layers: 2
#        activation_init:
#          <type>: ReLU
#          <init>: false
#        use_layer_norm: True
      autoencoding:
        <type>: SpatialBroadcastDecoder
        <init>: true
        in_feats: 100
        # Same [10, 10] as the training dataset's `im_size` in this file.
        im_size: [10, 10]
        hid_channels: 32
        num_conv_layers: 3
        out_channels: 1
        kernel_size: 3
        act:
          <type>: leaky_relu
          <init>: false
    detach_aux_heads: false

  # ____ Data loader. ____
  # Training `DataLoader` settings. The loader is marked `<init>: false`
  # while its `dataset` entry is marked `<init>: true`.
  train_loader:
    <type>: DataLoader
    <init>: false
    dataset:
      <type>: FindThePlusDataset
      <init>: true
      # Same [10, 10] as the prover's autoencoding head `im_size` in this file.
      im_size: [10, 10]
      white_plus_probability: 0.5
      fill_ratios: [0.5, 0.1]
      epoch_len: 100000
    # With epoch_len 100000 this gives 50 batches per epoch, assuming
    # `epoch_len` is the dataset length — TODO confirm.
    batch_size: 2000
    shuffle: true
    num_workers: 8
    worker_init_fn:
      <type>: seed_workers
      <init>: false

  # ____ Optimizers. ____
  # Verifier and prover each get their own `Adam` entry with the same
  # learning rate (0.0003); both are marked `<init>: false`.
  verifier_optimizer_init:
    <type>: Adam
    <init>: false
    lr: 0.0003

  prover_optimizer_init:
    <type>: Adam
    <init>: false
    lr: 0.0003

  # `lookahead` is set to null; the commented-out mapping below is an
  # alternative value, kept verbatim.
  lookahead: null
  #  lookahead:
  #    <type>: Lookahead
  #    <init>: false
  #    k: 5
  #    alpha: 0.5

  # Learning rate schedulers (use WeightFreezeSchedule to deal with freezing
  # networks). Both are set to null here.
  verifier_scheduler: null
  prover_scheduler: null

  # Gradient updates per game step.
  num_verifier_steps: 1
  num_prover_steps: 1
  # NOTE(review): presumably the number of prover steps is adapted (up to
  # `max_prover_steps`) depending on where verifier accuracy sits relative to
  # the two cutoffs — confirm against the system code.
  adaptive_prover_step_specs:
    <type>: Munch
    <init>: true
    upper_cutoff_accuracy: 0.75
    lower_cutoff_accuracy: 0.5
    update_prover_strength_n_game_steps_after_verifier_past_cutoff: 20
    max_prover_steps: 15


  # ____ Loss functions. ____
  # Classification loss.
  classification_loss_fn:
    <type>: LabelSmoothedCrossEntropy
    <init>: true
    alpha: 0.05
  use_matching_verifier_loss: true
  nm_loss_weighting: 0.0

  # Both label weightings are set to the same value.
  label_0_weighting: 1.0
  label_1_weighting: 1.0

  # `label_flipping` is set to null; the commented-out lines below are an
  # alternative mapping, kept verbatim.
  # NOTE(review): the `key|name` form in them looks like loader-specific
  # syntax — confirm before re-enabling.
  label_flipping: null
  #    <type>: Munch
  #    <init>: true
  #    flip_probability|flip_prob_0: 0.
  #    flip_probability|flip_prob_0_3: 0.3

  # Auxiliary task loss.
  # One entry per prover aux head; the entry names (`classification`,
  # `autoencoding`) match the keys under `prover.aux_heads` in this file.
  # Each entry pairs a loss function with a coefficient.
  prover_aux_losses:
    <type>: Munch
    <init>: true
    classification:
      <type>: Munch
      <init>: true
      loss_fn:
        <type>: classification_aux_loss
        <init>: false
        alpha: 0.0
        downweight_ce: true
      coeff: 1.0
    autoencoding:
      <type>: Munch
      <init>: true
      loss_fn:
        <type>: autoencoding_aux_loss
        <init>: false
      coeff: 1.0

  # Verifier auxiliary losses. The entry name `proof_input_matching` matches
  # the key under `verifier.aux_heads` in this file.
  verifier_aux_losses:
    <type>: Munch
    <init>: true
    proof_input_matching:
      <type>: Munch
      <init>: true
      loss_fn:
        <type>: verifier_input_proof_matching_aux_loss
        <init>: false
      # Coefficient is zero here; the commented-out lines below list
      # alternative values, kept verbatim.
      coeff: 0.0
#      coeff|verifier_aux_matching_coeff_0_1: 0.1
#      coeff|verifier_aux_matching_coeff_0_33: 0.33
#      coeff|verifier_aux_matching_coeff_1: 1.
#      coeff|verifier_aux_matching_coeff_3_3: 3.3

  # Additional regularization terms.
  # `proximal_reg` and `collapse_preventing_reg` are null; `l2_proof_reg`
  # is present with a zero coefficient.
  proximal_reg: null
  l2_proof_reg:
    <type>: Munch
    <init>: true
    coeff: 0.0
  # The commented-out lines below are an alternative mapping for
  # `collapse_preventing_reg`, kept verbatim.
  collapse_preventing_reg: null
#    <type>: Munch
#    <init>: true
#    max_samples: 500
#    reg_fn:
#      <type>: collapse_preventing_loss
#      <init>: false
#      coeff|cpl_coeff_10: 10.
#      scale: 5.

  # ____ Pretraining specs. ____
  # The commented-out line directly below is an alternative that sets the
  # specs to null (kept verbatim).
#  pretraining_specs|no_pretrain: null
  pretraining_specs:
    <type>: Munch
    <init>: true
    # NOTE(review): if the check is a strict "above", an accuracy threshold
    # of 1.00 can never be exceeded, so this early stop presumably never
    # fires and pretraining runs for `pretrain_for_n_batches` — confirm.
    stop_pretraining_if_acc_above_threshold: 1.00
    stop_pretraining_acc_horizon: 10
    pretrain_for_n_batches: 100
    label_smoothing_coeff: 0.3
    collapse_preventing_reg: null
    # Same type and learning rate as the verifier/prover optimizers in this
    # file.
    optimizer:
      <type>: Adam
      <init>: false
      lr: 0.0003

  # ____ Probe training. ____
  max_num_batches_in_buffer: 20
  max_probe_training_batches: 250
  track_last_n_batch_probe_outputs: 10
  # List of probe definitions; currently a single entry.
  probe_specs:
    # Proof to input autoencoding: a regression probe with "proofs" as
    # inputs and "correct_proofs" as outputs.
    - <type>: Munch
      <init>: true
      task: "regression"
      inputs: "proofs"
      outputs: "correct_proofs"
      loss_fn:
        <type>: mse_loss
        <init>: false
      model_init:
        <type>: FCNetFixedWidth
        <init>: false
        # Equals the verifier's `proof_features` (32) in this file.
        num_inputs: 32
        num_hidden_dim: 64
        num_hidden_layers: 1
        num_outputs: 2
        activation_init:
          <type>: LeakyReLU
          <init>: false
        use_layer_norm: true
        use_spectral_norm: false
      logging_fns:
        - <type>: probe_loss_plot_logging
          <init>: false
        - <type>: probe_loss_logging
          <init>: false

  # `DataLoader` settings for probe training. Unlike `train_loader` in this
  # file, no `dataset` is given here and `num_workers` is 0.
  probe_dataloader:
    <type>: DataLoader
    <init>: false
    batch_size: 2000
    shuffle: true
    num_workers: 0
    worker_init_fn:
      <type>: seed_workers
      <init>: false

  # Optimizer entry for probe training; same type and learning rate as the
  # other optimizer entries in this file.
  probe_optimizer_init:
    <type>: Adam
    <init>: false
    lr: 0.0003

  # ____ Attacking the verifier (to test soundness) ____
  # Optimizer entry for the attack, plus a cap on the number of attack
  # training batches.
  v_attack_p_optimizer_init:
    <type>: Adam
    <init>: false
    lr: 0.0003

  max_v_attack_p_training_batches: 500

  # ____ Attacking the verifier by optimizing proofs directly. ____
  v_attack_proof_optim_specs:
    <type>: Munch
    <init>: true
    max_attack_proof_samples: 200
    proof_optimizer:
      <type>: LBFGS
      <init>: false
      # Integer literal; the other `lr` values in this file are floats.
      lr: 1
      max_iter: 300
      # Written with a decimal point and a signed exponent: YAML 1.1 loaders
      # such as PyYAML need both to resolve the value as a float.
      tolerance_grad: 1.0e-5
      tolerance_change: 1.0e-9
      history_size: 300
      line_search_fn: "strong_wolfe"


  # ____ Plot logging. ____
  # List of plotting function entries, each marked `<init>: false`; the
  # remaining keys of an entry are that function's settings.
  plotting_fns:
    - <type>: verifier_pred_confidence_logging
      <init>: false
    - <type>: log_autoencoding_p_aux_head_outputs
      <init>: false
      visualize_inputs: true
      max_num_images: 5
    - <type>: visualize_proof_and_grad_pairwise_similarities_and_grad_hist
      <init>: false
      num_samples: 100
    - <type>: visualize_stn_transformed_images
      <init>: false
      num_imgs: 10
    - <type>: visualize_transformed_plus_coords
      <init>: false

  # ____ Minibatch inspection. ____
  # NOTE(review): presumably supplies the minibatch used every
  # `do_inspection_every_n_batches` batches — confirm against the system code.
  fixed_minibatch_getter:
    <type>: get_inspection_minibatch
    <init>: false

# ____ Checkpoint callback. ____
# Save and sync intervals are both 15 (minutes, per the `_t_min` key suffix —
# TODO confirm the unit).
checkpoint_callback:
  <type>: SimpleCheckpointer
  <init>: false
  save_every_t_min: 15
  sync_every_t_min: 15

# No additional callbacks are configured.
callbacks: null

# Trainer.
# Settings for the `Trainer` type; marked `<init>: false`.
trainer:
  <type>: Trainer
  <init>: false
  # Very large epoch cap, so training is not expected to stop on this limit.
  max_epochs: 100000000
  progress_bar_refresh_rate: 1
  num_sanity_val_steps: 1
  gradient_clip_val: 1.0

# Task function.
# Names the `train` function; marked `<init>: false`.
task_fn:
  <type>: train
  <init>: false

# Other.
# Top-level run flags and the random seed.
resume_if_possible: true
test: false
seed: 6
use_deterministic_algorithms: false
wandb_dryrun: false

