# Run with: python src/train.py --config-path configs/pattern_ood --config-name mean

# Optimisation, logging and on-the-fly task generation settings for training.
training:
  seed: 0
  # null starts training from scratch.
  resume_from_checkpoint: null
  # One of: mean, all, random_search, gradient_ascent.
  inference_mode: mean
  # Must be divisible by gradient_accumulation_steps * num_devices.
  batch_size: 128
  # Higher values are slower but use less memory, while the effective batch
  # size stays constant.
  gradient_accumulation_steps: 1
  total_num_steps: 400000
  # Must satisfy dataset_size >= batch_size * log_every_n_steps.
  log_every_n_steps: 100
  # null disables evaluation.
  eval_every_n_logs: 1
  # null disables checkpointing.
  # NOTE(review): an unresolved merge conflict offered 1000 (HEAD) and null
  # (incoming). The HEAD value is kept here; confirm this is the intended one.
  save_checkpoint_every_n_logs: 1000
  # Written with an explicit mantissa dot so that YAML 1.1 loaders also read
  # these as floats rather than strings.
  learning_rate: 4.0e-4
  prior_kl_coeff: 1.0e-4
  pairwise_kl_coeff: null
  # If true, activations use bfloat16 (params stay in float32).
  mixed_precision: false
  online_data_augmentation: false
  task_generator:
    num_workers: 32
    num_pairs: 4
    class: PATTERN
    pattern_size: 4
    pattern_density: 0.5
  train_datasets: null


# Evaluation datasets. Only test_datasets is populated in this config.
eval:
  eval_datasets: null
  test_datasets:
    # Disabled "mean" inference variants that sweep num_permutations.
    # - generator: PATTERN
    #   task_generator_kwargs:
    #     pattern_size: ${training.task_generator.pattern_size}
    #     pattern_density: ${training.task_generator.pattern_density}
    #   name: in_dist_0shot1
    #   num_pairs: 4
    #   length: 512
    #   batch_size: 128
    #   num_tasks_to_show: 32
    #   inference_mode: mean
    #   inference_kwargs:
    #     num_permutations: 1
    # - generator: PATTERN
    #   task_generator_kwargs:
    #     pattern_size: ${training.task_generator.pattern_size}
    #     pattern_density: ${training.task_generator.pattern_density}
    #   name: in_dist_0shot5
    #   num_pairs: 4
    #   length: 512
    #   batch_size: 128
    #   num_tasks_to_show: 32
    #   inference_mode: mean
    #   inference_kwargs:
    #     num_permutations: 5
    # - generator: PATTERN
    #   task_generator_kwargs:
    #     pattern_size: ${training.task_generator.pattern_size}
    #     pattern_density: ${training.task_generator.pattern_density}
    #   name: in_dist_0shot10
    #   num_pairs: 4
    #   length: 512
    #   batch_size: 128
    #   num_tasks_to_show: 32
    #   inference_mode: mean
    #   inference_kwargs:
    #     num_permutations: 10
    # - generator: PATTERN
    #   task_generator_kwargs:
    #     pattern_size: ${training.task_generator.pattern_size}
    #     pattern_density: ${training.task_generator.pattern_density}
    #   name: in_dist_0shot20
    #   num_pairs: 4
    #   length: 512
    #   batch_size: 32
    #   num_tasks_to_show: 32
    #   inference_mode: mean
    #   inference_kwargs:
    #     num_permutations: 20
    # - generator: PATTERN
    #   task_generator_kwargs:
    #     pattern_size: ${training.task_generator.pattern_size}
    #     pattern_density: ${training.task_generator.pattern_density}
    #   name: in_dist_0shot100
    #   num_pairs: 4
    #   length: 512
    #   batch_size: 16
    #   num_tasks_to_show: 32
    #   inference_mode: mean
    #   inference_kwargs:
    #     num_permutations: 100

    # --- "mean" inference, training pattern density ---
    - generator: PATTERN
      task_generator_kwargs:
        pattern_size: ${training.task_generator.pattern_size}
        pattern_density: ${training.task_generator.pattern_density}
      name: in_dist_0shot1
      num_pairs: 4
      length: 512
      batch_size: 128
      num_tasks_to_show: 32
      inference_mode: mean
      inference_kwargs:
        num_permutations: 1

    # --- No inference_mode set, pattern densities other than training ---
    - generator: PATTERN
      task_generator_kwargs:
        pattern_size: ${training.task_generator.pattern_size}
        pattern_density: 0.75
      name: density_0.75_0shot
      num_pairs: 4
      length: 96
      batch_size: 96
      num_tasks_to_show: 32
    - generator: PATTERN
      task_generator_kwargs:
        pattern_size: ${training.task_generator.pattern_size}
        pattern_density: 1.0
      name: density_1.0_0shot
      num_pairs: 4
      length: 96
      batch_size: 96
      num_tasks_to_show: 32

    # --- Gradient ascent, 10 steps ---
    - generator: PATTERN
      task_generator_kwargs:
        pattern_size: ${training.task_generator.pattern_size}
        pattern_density: ${training.task_generator.pattern_density}
      name: in_dist_ga_10
      num_pairs: 4
      length: 96
      batch_size: 96
      num_tasks_to_show: 32
      inference_mode: gradient_ascent
      inference_kwargs:
        num_steps: 10
        lr: 0.1
        optimizer: adam
        optimizer_kwargs:
          b2: 0.9
    - generator: PATTERN
      task_generator_kwargs:
        pattern_size: ${training.task_generator.pattern_size}
        pattern_density: 0.75
      name: density_0.75_ga_10
      num_pairs: 4
      length: 96
      batch_size: 96
      num_tasks_to_show: 32
      inference_mode: gradient_ascent
      inference_kwargs:
        num_steps: 10
        lr: 0.1
        optimizer: adam
        optimizer_kwargs:
          b2: 0.9
    - generator: PATTERN
      task_generator_kwargs:
        pattern_size: ${training.task_generator.pattern_size}
        pattern_density: 1.0
      name: density_1.0_ga_10
      num_pairs: 4
      length: 96
      batch_size: 96
      num_tasks_to_show: 32
      inference_mode: gradient_ascent
      inference_kwargs:
        num_steps: 10
        lr: 0.1
        optimizer: adam
        optimizer_kwargs:
          b2: 0.9

    # --- Gradient ascent, 100 steps ---
    # NOTE(review): these three entries had unresolved merge conflicts. The
    # HEAD side is kept. The incoming side also passed `num_permutations: 128`
    # in inference_kwargs and had no density_0.75_ga_100 entry; confirm that
    # dropping those changes is intended.
    - generator: PATTERN
      task_generator_kwargs:
        pattern_size: ${training.task_generator.pattern_size}
        pattern_density: ${training.task_generator.pattern_density}
      name: in_dist_ga_100
      num_pairs: 4
      length: 96
      batch_size: 96
      num_tasks_to_show: 32
      inference_mode: gradient_ascent
      inference_kwargs:
        num_steps: 100
        lr: 0.1
        optimizer: adam
        optimizer_kwargs:
          b2: 0.9
    - generator: PATTERN
      task_generator_kwargs:
        pattern_size: ${training.task_generator.pattern_size}
        pattern_density: 0.75
      name: density_0.75_ga_100
      num_pairs: 4
      length: 96
      batch_size: 96
      num_tasks_to_show: 32
      inference_mode: gradient_ascent
      inference_kwargs:
        num_steps: 100
        lr: 0.1
        optimizer: adam
        optimizer_kwargs:
          b2: 0.9
    - generator: PATTERN
      task_generator_kwargs:
        pattern_size: ${training.task_generator.pattern_size}
        pattern_density: 1.0
      name: density_1.0_ga_100
      num_pairs: 4
      length: 96
      batch_size: 96
      num_tasks_to_show: 32
      inference_mode: gradient_ascent
      inference_kwargs:
        num_steps: 100
        lr: 0.1
        optimizer: adam
        optimizer_kwargs:
          b2: 0.9
  json_datasets: null


# Encoder settings, instantiated via the `_target_` class below.
encoder_transformer:
  _target_: src.models.utils.EncoderTransformerConfig
  latent_dim: 32
  num_layers: 1
  max_rows: 10
  max_cols: 10
  # Learned position embeddings are enabled here; RoPE is switched off.
  position_embeddings:
    learned_position_embeddings:
      active: true
      scale: false
    rope_embeddings:
      active: false
      max_freq: 10.0
  transformer_layer:
    _target_: src.models.utils.TransformerLayerConfig
    num_heads: 8
    emb_dim_per_head: 8
    mlp_dim_factor: 2.0
    dropout_rate: 0.0
    attention_dropout_rate: 0.0
    # Accepted values: rms_norm, layer_norm, none.
    mha_norm_type: layer_norm

# Decoder settings, instantiated via the `_target_` class below.
decoder_transformer:
  _target_: src.models.utils.DecoderTransformerConfig
  num_layers: 2
  max_rows: 10
  max_cols: 10
  # Learned position embeddings are enabled here; RoPE is switched off.
  position_embeddings:
    learned_position_embeddings:
      active: true
      scale: false
    rope_embeddings:
      active: false
      max_freq: 10.0
  next_position_embeddings: true
  next_position_embeddings_new_input_embeds: false
  transformer_layer:
    _target_: src.models.utils.TransformerLayerConfig
    num_heads: 8
    emb_dim_per_head: 4
    mlp_dim_factor: 1.0
    dropout_rate: 0.0
    attention_dropout_rate: 0.0
    # Accepted values: rms_norm, layer_norm, none.
    # TODO: remove this none
    mha_norm_type: layer_norm