# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

# Hydra runtime settings.
hydra:
  job:
    # Do not change the process working directory to Hydra's per-run output
    # directory; relative paths keep resolving from the launch directory.
    chdir: false

# Defaults list: the config files composed into this one.
# `_self_` is listed first, so the entries after it are merged later and
# override values written in this file (e.g. the top-level `null`
# placeholders for `intervention_params` and `task_params`).
defaults:
  - _self_
  ####################
  #   Wandb config   #
  ####################
  - wandb/act.yaml
  ####################
  #   Intervention   #
  ####################
  # Selects the `linear_ot` option of the `intervention_params` config group.
  - intervention_params: linear_ot
  ####################
  #    Diffusion     #
  ####################
  # Tasks
  - task_params: coco_styles.yaml
  # Models
  - model: SDXL-Lightning.yaml

# Top-level placeholders. `intervention_params` and `task_params` are also
# config groups selected in the defaults list above; the sections below
# reference all three keys through ${...} interpolation.
intervention_params: null
task_params: null
# Forwarded unchanged to responses, interventions and text_to_image_generation.
acts_save_file_append: null

# Data and Model Loading Settings
# Each path is read from an environment variable through the `oc.env`
# resolver; the value after the comma is the fallback when it is unset.
data_dir: ${oc.env:DATA_DIR,/mnt/data}
# Falls back to `data_dir` when CACHE_DIR is not set.
cache_dir: ${oc.env:CACHE_DIR,${data_dir}}
results_dir: ${oc.env:OUTPUT_DIR,/tmp/results}

# Some globals, referenced by the per-script sections below.
device: 'cuda'
# Forwarded to text_to_image_generation.fast (see the comment there).
fast: false
seed: 42
# By default we let model config specify batch_size
# (`default_batch_size` must therefore be defined by the selected model config).
batch_size: ${model.default_batch_size}

# Decides which scripts to run
compute_responses: true
compute_interventions: true

# evaluation: e.g. ['clip_score']
# Defaults to whatever the selected task config declares as
# `default_evaluation`.
evaluation: ${task_params.default_evaluation}

# Additions to the model config selected in the defaults list.
model:
  # Relative interpolation: the leading dot resolves from this `model` node,
  # i.e. this reads `model.default_module_names.default`.
  # NOTE(review): `default_module_names` is not defined in this file; it is
  # presumably provided by the selected model config - confirm.
  module_names: ${.default_module_names.default}

# Settings for the response-generation step.
responses:
  # Response Generation Settings
  batch_size: ${batch_size}
  balanced_data: true
  save_fields: [] # Extra fields to save with responses
  shuffle: true
  device: "${device}"
  # `dtype` is not an OmegaConf built-in resolver.
  # NOTE(review): presumably registered by the project code before this
  # config is resolved - confirm.
  dtype: ${dtype:torch.float32}
  max_batches: null
  num_workers: 1
  seed: ${seed}
  resume: true
  data_dir: ${data_dir}
  cache_dir: ${cache_dir}
  tag: "responses"
  # Params for response-saving hooks
  intervention_params:
    name: "postprocess_and_save"
    # Reuse the pooling op configured for the interventions step.
    pooling_op: ${interventions.intervention_params.pooling_op}
    hook_params:
      raise_exception: false
  # see configs/task_params
  task_params: ${task_params}
  acts_save_file_append: ${acts_save_file_append}
  # Full model config (top-level `model` node).
  model_params: ${model}

# Settings for the intervention-computation step.
interventions:
  # Intervention Computation Settings
  # Note: batch size and device are set here explicitly rather than
  # inherited from the global `batch_size` / `device` values.
  batch_size: 2
  max_batches: null
  shuffle: true
  device: "cpu"
  # Must be an interpolation (`${...}`) so the `dtype` resolver is invoked,
  # matching `responses.dtype`. Without the leading `$`, YAML parses the value
  # as a flow mapping with the single key "dtype:torch.float32".
  dtype: ${dtype:torch.float32}
  load_fields: []
  num_workers: 1
  seed: ${seed}
  resume: false
  cache_dir: ${cache_dir}
  tag: "interventions"
  # see configs/task_params
  task_params: ${task_params}
  # see configs/intervention_params
  intervention_params: ${intervention_params}
  acts_save_file_append: ${acts_save_file_append}
  # see configs/model
  model_params: ${model}

# Settings for text-to-image generation.
text_to_image_generation:
  max_batches: 4
  num_workers: 2
  device: "${device}"
  seed: ${seed}
  data_dir: ${data_dir}
  cache_dir: ${cache_dir}
  # Text generation batch size.
  batch_size: ${batch_size}
  # Will sweep over these numbers
  diffusion_guidance_scale: []
  # Base number for guidance scale
  # (relative interpolation: reads `model_params` of this section, which
  # itself points at the top-level `model` node).
  guidance_scale: ${.model_params.guidance_scale}
  # Base number of diffusion inference steps
  num_inference_steps: ${.model_params.inference_steps}
  # Diffusion image resolution
  generation_resolution: 224
  # Where results are saved
  results_dir: ${results_dir}
  # If true, also save a gif animation
  create_gif: true
  # see configs/intervention_params
  intervention_params: ${intervention_params}
  acts_save_file_append: ${acts_save_file_append}
  # Strengths for which to generate images
  min_strength: 0.0
  max_strength: 1.0
  # Will execute np.linspace over this number of steps
  strength_steps: 21
  # If true, run script in fast mode (small batches, small data, not useful for true results)
  fast: ${fast}
  verbose: 1
  # See defaults on top of this file
  task_params: ${task_params}
  # see configs/model
  model_params: ${model}
  # wandb params
  wandb: ${wandb}
  # Use these prompts to generate instead of reading from a dataset
  prompt_override: null

# Settings for the `clip_score` evaluation.
clip_score:
  # Sub-folder of `results_dir` that is read as input.
  # NOTE(review): presumably where the generation step writes its images -
  # confirm against the generation script.
  input_folder: ${results_dir}/generate_with_hooks_diffusion
  results_dir: ${results_dir}
  device: ${device}
  # wandb params
  wandb: ${wandb}