# Run a single task with: python src/train.py --config-name arc_train_overfit_all training.task_generator.overfit_task_index=0
# Or sweep over task indices, 50 per invocation, with --multirun:
# python src/train.py --multirun --config-name arc_train_overfit_all training.task_generator.overfit_task_index=0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49
# python src/train.py --multirun --config-name arc_train_overfit_all training.task_generator.overfit_task_index=50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99
# python src/train.py --multirun --config-name arc_train_overfit_all training.task_generator.overfit_task_index=100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149
# python src/train.py --multirun --config-name arc_train_overfit_all training.task_generator.overfit_task_index=150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199
# python src/train.py --multirun --config-name arc_train_overfit_all training.task_generator.overfit_task_index=200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249
# python src/train.py --multirun --config-name arc_train_overfit_all training.task_generator.overfit_task_index=250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,298,299

# Training settings. The task is selected by
# `task_generator.overfit_task_index`, which the run commands at the top of
# this file override from the command line.
training:
  seed: 0
  resume_from_checkpoint: null  # null to start from scratch
  inference_mode: mean  # one of: mean, all, random_search, gradient_ascent
  batch_size: 128  # must be divisible by gradient_accumulation_steps * num_devices
  gradient_accumulation_steps: 1  # higher is slower but uses less memory at the same effective batch size
  total_num_steps: 20000
  log_every_n_steps: 50  # must satisfy dataset_size >= batch_size * log_every_n_steps
  eval_every_n_logs: 20  # null to disable eval
  save_checkpoint_every_n_logs: null  # null to disable checkpointing
  # Floats carry an explicit decimal point so that YAML 1.1 loaders (e.g. plain
  # PyYAML) resolve them as numbers instead of strings.
  learning_rate: 3.0e-4
  prior_kl_coeff: 1.0e-4
  pairwise_kl_coeff: null
  mixed_precision: false  # if true, uses bfloat16 for activations (params stay in float32)
  online_data_augmentation: false
  task_generator:
    num_workers: 48
    num_pairs: 2
    class: ARC
    overfit_task_index: 0  # also read by the eval datasets through ${...} interpolation
  train_datasets: null  # explicitly null: no value is set here


# Evaluation datasets. Each entry below takes its task index from
# `training.task_generator.overfit_task_index` through an interpolation.
eval:
  eval_datasets: null  # explicitly null: no value is set here
  test_datasets:
    - name: arc_overfit_mean
      generator: ARC
      task_generator_kwargs:
        overfit_task_index: ${training.task_generator.overfit_task_index}
      num_pairs: 4
      length: 64
      batch_size: 64
      num_tasks_to_show: 5
  json_datasets:
    - name: json_mean
      # Paths to the challenge and solution JSON files.
      challenges: json/arc-agi_training_challenges.json
      solutions: json/arc-agi_training_solutions.json
      overfit_task_index: ${training.task_generator.overfit_task_index}
      inference_mode: mean
      num_tasks_to_show: 1


# Encoder hyper-parameters. `_target_` names the config class
# src.models.utils.EncoderTransformerConfig (presumably instantiated by the
# training script - confirm in src/train.py).
encoder_transformer:
  _target_: src.models.utils.EncoderTransformerConfig
  max_rows: 30
  max_cols: 30
  num_layers: 1
  latent_dim: 32
  position_embeddings:
    # rope_embeddings are enabled; learned_position_embeddings are disabled.
    rope_embeddings:
      active: true
      max_freq: 10.0
    learned_position_embeddings:
      active: false
      scale: false
  transformer_layer:
    _target_: src.models.utils.TransformerLayerConfig
    num_heads: 6
    emb_dim_per_head: 16
    mlp_dim_factor: 1.0
    dropout_rate: 0.0
    attention_dropout_rate: 0.0
    mha_norm_type: rms_norm  # one of: rms_norm, layer_norm, none

# Decoder hyper-parameters. `_target_` names the config class
# src.models.utils.DecoderTransformerConfig (presumably instantiated by the
# training script - confirm in src/train.py).
decoder_transformer:
  _target_: src.models.utils.DecoderTransformerConfig
  max_rows: 30
  max_cols: 30
  num_layers: 8
  position_embeddings:
    # rope_embeddings are enabled; learned_position_embeddings are disabled.
    rope_embeddings:
      active: true
      max_freq: 10.0
    learned_position_embeddings:
      active: false
      scale: false
  next_position_embeddings: true
  next_position_embeddings_new_input_embeds: false
  transformer_layer:
    _target_: src.models.utils.TransformerLayerConfig
    num_heads: 8
    emb_dim_per_head: 64
    mlp_dim_factor: 4.0
    dropout_rate: 0.0
    attention_dropout_rate: 0.0
    mha_norm_type: rms_norm  # one of: rms_norm, layer_norm, none  # TODO: remove this none