# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

# Model settings
model_path: "sshleifer/tiny-gpt2"  # Path to model checkpoint

# Data and model loading settings
# NOTE(review): data_dir is "test/data" while cache_dir is "tests/data";
# confirm whether the singular/plural difference is intentional.
data_dir: "test/data"
cache_dir: "tests/data"

# Module name patterns shared by the `responses` and `interventions`
# sections below through ${module_names}.
module_names: [".*"]  # TODO make a defaults config instead of using the one in models/
device: "cuda"

# Shared seed, referenced below through ${seed}.
seed: 42

# Top-level stage switches (presumably toggle the two stages configured
# below -- confirm against the consuming script).
compute_responses: true
learn_interventions: true

# Settings for the response-generation stage. Values written as ${...} are
# interpolations resolved from other keys of this config (or through a
# resolver, as for dtype).
responses:
  # Response Generation Settings
  batch_size: 4
  balanced_data: 1
  save_fields: []  # Extra fields to save with responses
  shuffle: true
  device: "${device}"
  # Uses a resolver named `dtype`; it is not defined in this file
  # (presumably registered by the application -- confirm).
  dtype: "${dtype:torch.float32}"
  # YAML null. A bare `None` would be loaded as the string "None".
  max_batches: null
  model_path: "${model_path}"
  num_workers: 1
  seed: "${seed}"
  resume: false
  # Diffusion-related fields
  num_inference_steps: 1
  guidance_scale: 0
  data_dir: "${data_dir}"
  cache_dir: "${cache_dir}"
  tag: "responses"
  module_names: "${module_names}"
  # Subset Args (Specify which subsets to use)
  # NOTE(review): the referenced key is not defined in this file; presumably
  # it is supplied by the configs/intervention_params defaults -- confirm.
  pooling_op: "${interventions.intervention_params.pooling_op}"
  # see configs/intervention_params
  intervention_params:
    hook_params:
      strength: 1.0
      stop_at_first_hook: false
  # see configs/task_params
  task_params: null
# Settings for the intervention-learning stage. Values written as ${...} are
# interpolations resolved from other keys of this config (or through a
# resolver, as for dtype).
interventions:
  # Response Generation Settings
  batch_size: 2
  shuffle: true
  device: "cpu"
  # The leading `$` is required: without it YAML parses the value as a flow
  # mapping instead of a resolver interpolation.
  dtype: "${dtype:torch.float32}"
  load_fields: []
  max_batches: 1
  module_names: "${module_names}"
  model_path: "${model_path}"
  num_workers: 1
  seed: "${seed}"
  resume: false
  cache_dir: "${cache_dir}"
  tag: "test"
  # see configs/task_params
  task_params: null
  # see configs/intervention_params
  intervention_params:
    hook_params:  # overridden by defaults
      strength: 1.0
      stop_at_first_hook: false

# TODO (below): needed for diffusion and pipeline
  # Diffusion Args
  # text_to_image:
  #   diffusion_guidance_scale: 0
  #   num_inference_steps: 1
  #   generation_resolution: 224
  # LLM-related fields

# generation:
#   # Response Generation Settings
#   batch_size: 2
#   shuffle: true
#   device: "cuda"
#   dtype: ${dtype:torch.float32}
#   max_batches: None
#   module_names: ${module_names}
#   num_workers: 1
#   seed: 42
#   resume: false
#   cache_dir: ${cache_dir}
#   output_dir: ${cache_dir}/outputs

#   # Subset Args (Specify which subsets to use)
#   src_subsets: ${src_subsets}
#   dst_subsets: ${dst_subsets}

#   intervention_params: ${intervention_params}
#   hook_params:
#     strength: 1.0
#     dtype: ${dtype:torch.float32}

#   # Diffusion Args
#   text_to_image:
#     diffusion_guidance_scale: 0
#     num_inference_steps: 1
#     generation_resolution: 224
#     generation_strength: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
#   # LLM-related fields
#   text_generation: 
#     seq_len: 128