# Run with: python src/train.py --config-path configs/arc_train_overfit --config-name 045e512c

# Optimisation, logging and task-generation settings for this run.
training:
  seed: 0
  resume_from_checkpoint: null  # null to start from scratch
  inference_mode: mean  # one of: mean, all, random_search, gradient_ascent
  batch_size: 128  # must be divisible by gradient_accumulation_steps * num_devices
  # Higher values are slower but use less memory; the effective batch size stays constant.
  gradient_accumulation_steps: 4
  total_num_steps: 10000
  log_every_n_steps: 20  # must satisfy dataset_size >= batch_size * log_every_n_steps
  eval_every_n_logs: 50  # null to disable eval
  save_checkpoint_every_n_logs: null  # null to disable checkpointing
  # Written with an explicit mantissa dot so that YAML 1.1 loaders (e.g. PyYAML)
  # read these as floats instead of strings.
  learning_rate: 4.0e-4
  prior_kl_coeff: 1.0e-4
  pairwise_kl_coeff: null
  mixed_precision: false  # if true, bfloat16 is used for activations (params stay in float32)
  online_data_augmentation: false
  task_generator:
    num_workers: 48
    num_pairs: 4
    class: ARC
    # Quoted so the ID always loads as a string, whatever characters it contains.
    # Other sections reference it via ${training.task_generator.overfit_task}.
    overfit_task: '045e512c'
  train_datasets: null


# Evaluation settings. Both dataset entries take their task ID from
# training.task_generator.overfit_task through interpolation.
eval:
  eval_datasets: null
  # Datasets built by the ARC generator.
  test_datasets:
    - generator: ARC
      task_generator_kwargs:
        overfit_task: '${training.task_generator.overfit_task}'
      name: arc_overfit_mean
      num_pairs: 4
      length: 128
      batch_size: 32
      num_tasks_to_show: 32
  # Datasets read from the challenge/solution JSON files.
  json_datasets:
    - name: json_mean
      challenges: json/arc-agi_training_challenges.json
      solutions: json/arc-agi_training_solutions.json
      overfit_task: '${training.task_generator.overfit_task}'
      inference_mode: mean
      num_tasks_to_show: 1


# Encoder settings; `_target_` names the config class these values populate.
encoder_transformer:
  _target_: src.models.utils.EncoderTransformerConfig
  max_rows: 30
  max_cols: 30
  # NOTE(review): zero layers here versus four in the decoder — presumably
  # intentional for this overfit run; confirm.
  num_layers: 0
  # Learned position embeddings are active; RoPE is switched off.
  position_embeddings:
    rope_embeddings:
      active: false
      max_freq: 10.0
    learned_position_embeddings:
      active: true
      scale: false
  transformer_layer:
    _target_: src.models.utils.TransformerLayerConfig
    num_heads: 6
    emb_dim_per_head: 16
    mlp_dim_factor: 1.0
    dropout_rate: 0.0
    attention_dropout_rate: 0.0
    mha_norm_type: rms_norm  # one of: rms_norm, layer_norm, none
  latent_dim: 32

# Decoder settings; `_target_` names the config class these values populate.
decoder_transformer:
  _target_: src.models.utils.DecoderTransformerConfig
  max_rows: 30
  max_cols: 30
  num_layers: 4
  # Learned position embeddings are active; RoPE is switched off.
  position_embeddings:
    rope_embeddings:
      active: false
      max_freq: 10.0
    learned_position_embeddings:
      active: true
      scale: false
  # NOTE(review): mha_norm_type is not set here (the encoder sets rms_norm),
  # so whatever default TransformerLayerConfig defines applies — confirm intended.
  transformer_layer:
    _target_: src.models.utils.TransformerLayerConfig
    num_heads: 8
    emb_dim_per_head: 32
    mlp_dim_factor: 2.0
    dropout_rate: 0.0
    attention_dropout_rate: 0.0
