# Run with: python src/train.py --config-path configs --config-name arc_train

training:
  seed: 42
  resume_from_checkpoint: null  # null to start from scratch
  # One of: mean, max, random_search, gradient_ascent, transductive, zeros
  inference_mode: gradient_ascent
  inference_kwargs:
    stop_gradient_latent_move: false
    num_steps: 10
    lr: 0.1
    optimizer: adam
    optimizer_kwargs:
      b2: 0.9
  # Must be divisible by gradient_accumulation_steps * num_devices.
  batch_size: 256
  # Higher values lower memory usage but train more slowly, while the
  # effective batch size (batch_size) stays the same.
  gradient_accumulation_steps: 1
  total_num_steps: 500000
  # Must satisfy dataset_size >= batch_size * log_every_n_steps.
  log_every_n_steps: 100
  eval_every_n_logs: null  # null to disable eval
  save_checkpoint_every_n_logs: 80  # null to disable checkpointing
  # Scientific-notation floats keep a decimal point (3.0e-4, not 3e-4):
  # YAML 1.1 loaders such as PyYAML's safe_load read `3e-4` as a string.
  learning_rate: 3.0e-4
  prior_kl_coeff: 1.0e-4
  pairwise_kl_coeff: null
  # If true, activations use bfloat16 (params stay in float32).
  mixed_precision: true
  online_data_augmentation: true
  task_generator:
    num_workers: 64
    num_pairs: 4
    class: ARC
  train_datasets: null

eval:
  # Bare keys below are null (nothing configured). To enable a list, uncomment
  # its example entries; they are already indented to form a valid sequence.
  eval_datasets:
  test_datasets:
    # - generator: ARC
    #   task_generator_kwargs:
    #   name: mean
    #   inference_mode: mean
    #   num_pairs: 4
    #   length: 16
    #   batch_size: 16
    #   num_tasks_to_show: 12
  json_datasets:
    # - challenges: json/arc-agi_training_challenges.json
    #   solutions: json/arc-agi_training_solutions.json
    #   name: mean
    #   only_n_tasks: 100
    #   num_tasks_to_show: 100


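# Alternative, more extensive eval setup. To use it, replace the active `eval:`
# block above with this one; keeping both uncommented creates a duplicate
# top-level `eval` key.
# eval: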
#   eval_datasets:
#   test_datasets:
#     - generator: ARC
#       task_generator_kwargs:
#       name: generator_mean
#       num_pairs: 4
#       length: 48
#       batch_size: 16
#       num_tasks_to_show: 48
#     - generator: ARC
#       task_generator_kwargs:
#       name: generator_gradient_ascent_1
#       num_pairs: 4
#       length: 16
#       batch_size: 16
#       num_tasks_to_show: 16
#       inference_mode: gradient_ascent
#       inference_kwargs:
#         num_steps: 1
#         lr: 0.1
#         optimizer: adam
#         optimizer_kwargs:
#           b2: 0.9
#     - generator: ARC
#       task_generator_kwargs:
#       name: generator_gradient_ascent_20
#       num_pairs: 4
#       length: 16
#       batch_size: 16
#       num_tasks_to_show: 16
#       inference_mode: gradient_ascent
#       inference_kwargs:
#         num_steps: 20
#         lr: 0.1
#         optimizer: adam
#         optimizer_kwargs:
#           b2: 0.9
#   json_datasets:
#     - challenges: json/arc-agi_training_challenges.json
#       solutions: json/arc-agi_training_solutions.json
#       name: mean
#       only_n_tasks: 100
#       num_tasks_to_show: 50
#     - challenges: json/arc-agi_training_challenges.json
#       solutions: json/arc-agi_training_solutions.json
#       name: gradient_ascent_20
#       only_n_tasks: 100
#       num_tasks_to_show: 50
#       inference_mode: gradient_ascent
#       inference_kwargs:
#         num_steps: 20
#         lr: 0.1
#         optimizer: adam
#         optimizer_kwargs:
#           b2: 0.9
#     - challenges: json/arc-agi_evaluation_challenges.json
#       solutions: json/arc-agi_evaluation_solutions.json
#       name: mean
#       only_n_tasks: 100
#       num_tasks_to_show: 50
#     - challenges: json/arc-agi_evaluation_challenges.json
#       solutions: json/arc-agi_evaluation_solutions.json
#       name: gradient_ascent_20
#       only_n_tasks: 100
#       num_tasks_to_show: 50
#       inference_mode: gradient_ascent
#       inference_kwargs:
#         num_steps: 20
#         lr: 0.1
#         optimizer: adam
#         optimizer_kwargs:
#           b2: 0.9


# Encoder config; `_target_` names the config class to build.
encoder_transformer:
  _target_: src.models.utils.EncoderTransformerConfig
  max_rows: 30
  max_cols: 30
  num_layers: 8
  position_embeddings:
    rope_embeddings:
      active: true
      max_freq: 10.0
    learned_position_embeddings:
      active: false
      scale: false
  transformer_layer:
    _target_: src.models.utils.TransformerLayerConfig
    num_heads: 8
    emb_dim_per_head: 64  # this could be scaled
    mlp_dim_factor: 4.0
    dropout_rate: 0.0
    attention_dropout_rate: 0.0
    mha_norm_type: rms_norm  # one of: rms_norm, layer_norm, none
  latent_dim: 256

# Decoder config; `_target_` names the config class to build.
decoder_transformer:
  _target_: src.models.utils.DecoderTransformerConfig
  max_rows: 30
  max_cols: 30
  num_layers: 6
  position_embeddings:
    rope_embeddings:
      active: true
      max_freq: 10.0
    learned_position_embeddings:
      active: false
      scale: false
  next_position_embeddings: true
  next_position_embeddings_new_input_embeds: false
  transformer_layer:
    _target_: src.models.utils.TransformerLayerConfig
    num_heads: 8
    # This might be too large: we want the minimum capacity needed to learn
    # programs.
    emb_dim_per_head: 64
    mlp_dim_factor: 4.0
    dropout_rate: 0.0
    attention_dropout_rate: 0.0
    # One of: rms_norm, layer_norm, none. TODO: remove the `none` option.
    mha_norm_type: rms_norm