# Project-level metadata.
core:
  project_name: model_merging
  storage_dir: ${oc.env:PROJECT_ROOT}/storage
  # Placeholder value; NOTE(review): presumably the experiment-logger entity/team — set before use.
  entity: lXXXXX
  version: 0.0.1
  tags:
    - dev

# Hydra defaults list. Entries are composed in order, so later entries
# (ultimately `_self_`) override earlier ones.
defaults:
  - hydra: default
  - nn: default
  - train: default
  # Merging method; see all options in ./conf/merger, e.g. iso-cts, isotropic,
  # tsv, task_arithmetic, weight_avg, dummy (just uses the pretrained model).
  - merger: weight_avg
  # Task benchmark from ./conf/benchmark: N2, N8, N14 or N20.
  # NOTE(review): mergeability.benchmark_name is set separately in this file; keep them in sync if they should match.
  - benchmark: N20
  - override hydra/launcher: basic
  # Select Hydra's `none` presets for job and Hydra logging.
  # NOTE(review): presumably so that the application configures logging itself — confirm.
  - override hydra/job_logging: none
  - override hydra/hydra_logging: none
  # Must stay last so that values in this main config override the defaults above.
  - _self_

# When true, run the pairwise merging evaluation for every pair of tasks in the
# benchmark, instead of merging all tasks at once and evaluating on all of them.
all_pairwise: true

# When true, apply rotation alignment before merging.
alignment: false

seed_index: 0
# `???` marks a mandatory value in OmegaConf: it must be supplied (via an
# override, or presumably by the code) before it is read.
num_tasks: ???
eval_on_train: false
# Number of batches used for the grid search. The original note says these are
# batches of the validation set despite the `train` in the key name.
# NOTE(review): confirm which split is actually used.
number_of_train_batches: 25
device: "cuda"

# Key names used for inputs and labels; presumably the keys of each data batch — confirm.
conventions:
  # Keep these quoted: a bare `y` is read as a boolean by YAML 1.1 parsers.
  x_key: "x"
  y_key: "y"

# SVD compression: compression_ratio = 1 / svd_compress_factor.
# When null, the compression ratio is set to 1 / num_tasks.
svd_compress_factor: null

# Filesystem locations, built from the MODELS_PATH / PROJECT_ROOT environment
# variables and, where relevant, the selected encoder name.
misc:
  ckpt_path: "${oc.env:MODELS_PATH}/${nn.encoder.model_name}"
  pretrained_checkpoint: "${misc.ckpt_path}/base/model.pt"
  openclip_cachedir: "${oc.env:MODELS_PATH}/openclip_cache/"
  checkpoint_dir: "${oc.env:MODELS_PATH}/linear_router"
  svd_path: "${oc.env:MODELS_PATH}/svd_dict_${nn.encoder.model_name}.pt"
  finetuned_accuracy_path: "${oc.env:PROJECT_ROOT}/results/finetuning/accs.json"
  results_path: "${oc.env:PROJECT_ROOT}/results/${nn.encoder.model_name}/"
  # Regularization suffix for checkpoint names, e.g.
  # "_moderate_update_grad_magnitude", "_moderate_update", "_grad_magnitude";
  # use "" for no suffix.
  reg_suffix: ""

# Mergeability metrics configuration
mergeability:
  # The task list can be chosen in one of two ways.
  #
  # OPTION 1: list the datasets explicitly (set benchmark_name to null or to a custom name):
  # datasets:
  #   - MNIST
  #   - Cars
  # benchmark_name: null  # or "custom_2tasks"
  #
  # OPTION 2: use a benchmark from ./conf/benchmark (e.g. N2, N8, N14, N20).
  # benchmark_name is also used in the output filename.
  # NOTE(review): this is independent of the `benchmark` entry in `defaults`; keep them in sync if they should match.
  benchmark_name: N20
  datasets: ${benchmark_datasets:${mergeability.benchmark_name}}

  # Metric names from ./src/model_merging/metrics/mergeability.py, or "all" to run every metric.
  metrics:
    - all
  # DEPRECATED. If true, compute every metric per layer (saves the full breakdown
  # and logs the average). Some metrics are already layer-wise even when this is false.
  layer_wise: false

  # Whether to perform rotation-symmetry alignment before computing the mergeability metrics.
  rot_sym_align: false

  # --- Activation-based metrics ---
  # Number of calibration samples per dataset (e.g. 80 samples in total for N8).
  n_calibration_samples: 10
  # Batch size used when processing calibration data.
  calibration_batch_size: 32
  # Random seed for reproducible calibration sampling.
  calibration_random_seed: 42
  # Layer to extract activations from. Options for ViT-B-16:
  #   - 'model.visual.transformer.resblocks.11' (last transformer block, high-level semantic features)
  #   - 'model.visual.transformer.resblocks.8'  (mid-upper layer, balanced features)
  #   - 'model.visual.transformer.resblocks.5'  (middle layer, intermediate features)
  #   - 'model.visual.transformer.resblocks.0'  (first transformer block, low-level features)
  #   - 'model.visual.ln_post'                  (after all transformer blocks, normalized features)
  # NOTE(review): block indices are model-specific; presumably they need adjusting for other encoders.
  activation_layer_name: "model.visual.transformer.resblocks.11"

  output_path: "${oc.env:PROJECT_ROOT}/results/mergeability/${nn.encoder.model_name}/"

# Metric linear optimization configuration
metric_linear_optimization:
  # Fraction of pairs held out for validation (0.0 to 1.0).
  validation_split: 0.2
  # Random seed for the reproducible train/val split.
  random_seed: 42
  # Upper bound on optimization iterations (see early stopping below).
  iterations: 1000
  # Learning rate for the Adam optimizer.
  learning_rate: 0.01
  # Early stopping: stop once the improvement stays below convergence_threshold
  # for `patience` consecutive iterations.
  convergence_threshold: 0.0001
  patience: 50
  # Performance metric to correlate with.
  target_metric: "acc/test/avg"
  output_path: "${oc.env:PROJECT_ROOT}/results/metric_linear_optimization/"

# Learnable mergeability configuration (MLP-based)
learnable_mergeability:
  # Merge methods whose results the MLP learns to predict.
  # NOTE(review): the merger options are named `task_arithmetic`; confirm that
  # `arithmetic` matches how those results are named on disk.
  merge_methods:
    - weight_avg
    - arithmetic
    - tsv
    - isotropic
  hidden_dim: 8  # hidden layer dimension of the MLP (reduced from 16 for regularization)
  dropout: 0.4  # dropout rate for regularization (increased from 0.3)
  weight_decay: 0.001  # L2 weight decay for the Adam optimizer
  learning_rate: 0.001  # learning rate for the Adam optimizer
  epochs: 300  # number of training epochs
  batch_size: null  # batch size (null for full batch)
  validation_split: 0.2  # fraction of tasks held out for validation (task-level split)
  random_seed: 42  # random seed for reproducibility
  # Device to use (cuda or cpu); set independently of the top-level `device`.
  device: 'cuda'
  # Directory containing the merge-method results. Follows the selected encoder,
  # like misc.results_path, instead of hard-coding ViT-B-16.
  results_dir: "${oc.env:PROJECT_ROOT}/results/${nn.encoder.model_name}"
  output_path: "${oc.env:PROJECT_ROOT}/results/learnable_mergeability/"
