# Training settings.
training:
  seed: 42
  # Checkpoint to resume from; null starts from scratch.
  resume_from_checkpoint: null
  # One of: mean, all, random_search, gradient_ascent.
  # NOTE(review): the commented examples at the end of this file list `first`
  # instead of `all` -- confirm which value is actually accepted.
  inference_mode: mean
  # Must be divisible by gradient_accumulation_steps * num_devices.
  batch_size: 16
  # The higher, the slower, but the lower the memory usage, while the
  # effective batch size stays constant.
  gradient_accumulation_steps: 1
  total_num_steps: 15000
  # Must respect dataset_size >= batch_size * log_every_n_steps.
  log_every_n_steps: 2
  # null disables eval.
  eval_every_n_logs: 100
  # null disables checkpointing.
  save_checkpoint_every_n_logs: null
  # Floats carry an explicit decimal point: YAML 1.1 loaders such as PyYAML
  # resolve a bare `4e-4` to a string instead of a number.
  learning_rate: 4.0e-4
  prior_kl_coeff: 1.0e-4
  pairwise_kl_coeff: null
  # If true, bfloat16 is used for activations (params stay in float32).
  mixed_precision: false
  online_data_augmentation: true
  # Left unset; written as an explicit null rather than a bare empty value.
  task_generator: null
  train_datasets:
    - storage/arc_dummy_4x4_32_3
  use_hf: true

# Evaluation datasets. The commented examples at the end of this file list the
# optional fields of each entry together with their default values.
eval:
  eval_datasets:
    - folder: storage/arc_dummy_4x4_32_3
      use_hf: true
      length: 32
      batch_size: 32
  test_datasets:
    - name: mean
      folder: storage/arc_dummy_4x4_32_3
      use_hf: true
      length: 16
      batch_size: 8
    - name: random_search_10
      folder: storage/arc_dummy_4x4_32_3
      use_hf: true
      length: 16
      batch_size: 8
      inference_mode: random_search
      inference_kwargs:
        num_samples: 10
        scale: 1.0
    - name: gradient_ascent
      folder: storage/arc_dummy_4x4_32_3
      use_hf: true
      length: 16
      batch_size: 8
      inference_mode: gradient_ascent
      inference_kwargs:
        num_steps: 2
        # Explicit decimal point: YAML 1.1 loaders such as PyYAML resolve a
        # bare `5e-2` to a string instead of a number.
        lr: 5.0e-2
        optimizer: adam
  json_datasets:
    # Each entry has its own name. Presumably the name labels the reported
    # results, so two entries sharing one name would collide -- TODO confirm
    # against the code that consumes this config.
    - name: mean_tasks_10_15
      challenges: json/arc-agi_training_challenges.json
      solutions: json/arc-agi_training_solutions.json
      # Start is inclusive but end is exclusive.
      task_id_range: [10, 15]
      inference_mode: mean
    - name: mean_2_tasks
      challenges: json/arc-agi_training_challenges.json
      solutions: json/arc-agi_training_solutions.json
      only_n_tasks: 2
      inference_mode: mean
    - name: random_search
      challenges: json/arc-agi_training_challenges.json
      solutions: json/arc-agi_training_solutions.json
      only_n_tasks: 2
      inference_mode: random_search
      inference_kwargs:
        num_samples: 10
        scale: 0.1

# Encoder settings; `_target_` names the config class these values belong to.
encoder_transformer:
  _target_: src.models.utils.EncoderTransformerConfig
  # NOTE(review): presumably the grid size of the 4x4 dummy dataset -- confirm.
  max_rows: 4
  max_cols: 4
  num_layers: 1
  # Both rope and learned position embeddings are switched on here.
  position_embeddings:
    rope_embeddings:
      active: true
      max_freq: 10.0
    learned_position_embeddings:
      active: true
      scale: false
  transformer_layer:
    _target_: src.models.utils.TransformerLayerConfig
    num_heads: 8
    emb_dim_per_head: 6
    mlp_dim_factor: 1.0
    dropout_rate: 0.0
    attention_dropout_rate: 0.0
    # Choices: rms_norm, layer_norm, none.
    mha_norm_type: rms_norm
  latent_dim: 32

# Decoder settings; `_target_` names the config class these values belong to.
decoder_transformer:
  _target_: src.models.utils.DecoderTransformerConfig
  # NOTE(review): presumably the grid size of the 4x4 dummy dataset -- confirm.
  max_rows: 4
  max_cols: 4
  num_layers: 2
  # Only rope embeddings are switched on here; learned ones are off.
  position_embeddings:
    rope_embeddings:
      active: true
      max_freq: 10.0
    learned_position_embeddings:
      active: false
      scale: false
  next_position_embeddings: true
  next_position_embeddings_new_input_embeds: false
  transformer_layer:
    _target_: src.models.utils.TransformerLayerConfig
    num_heads: 8
    emb_dim_per_head: 6
    mlp_dim_factor: 1.0
    dropout_rate: 0.0
    attention_dropout_rate: 0.0
    # Choices: rms_norm, layer_norm, none.
    # TODO: remove the `none` choice.
    mha_norm_type: rms_norm

### Examples, dashed lines are required, the rest is optional (default values are shown).
# eval_datasets:
#   - folder: ...
#     use_hf: True
#     length: ...
#     batch_size: null  # default to length
#     seed: 0
# test_datasets:
#   - name: null  # null for ~ f"{inference_mode}_{i}"
#     folder: ...
#     use_hf: True
#     length: ...
#     batch_size: null  # default to length
#     num_tasks_to_show: 5  # 0 means no figures
#     seed: 0
#     inference_mode: mean  # choices are [mean, first, random_search, gradient_ascent]
#   - name: null  # null for ~ f"{inference_mode}_{i}"
#     generator: True
#     num_pairs: ...
#     length: ...
#     batch_size: null  # default to length
#     num_tasks_to_show: 5
# json_datasets:
#   - name: null  # null for ~ f"{inference_mode}_{i}"
#     challenges: ...
#     solutions: ...
#     only_n_tasks: null  # null for all
#     num_tasks_to_show: 5
#     inference_mode: mean  # choices are [mean, first, random_search, gradient_ascent]

# For test_datasets and json_datasets, the following inference_kwargs are also specified, depending on inference_mode.
# For random_search inference_mode:
#   inference_kwargs:
#     num_samples: ...  # 0 means max over the above, i.e. mean latent or all latents
#     scale: ...
#     scan_batch_size: null  # null for one big batch
#     include_mean_latent: True
#     include_all_latents: False
# For gradient_ascent inference_mode:
#   inference_kwargs:
#     num_steps: ...
#     lr: ...
#     lr_schedule: False
#     lr_schedule_exponent: 0.5
#     accumulate_gradients_decoder_pairs: False
#     scan_gradients_latents: False
#     optimizer: sgd
#     optimizer_kwargs: ...
#     random_perturbation:
#       num_samples: ...
#       scale: ...
#     stop_gradient_latent_move: True
#     include_mean_latent: True
#     include_all_latents: False
#     remove_encoder_latents: False
