# Hydra runtime settings.
hydra:
  run:
    # Working directory of a run, built from the launch date and time through
    # Hydra's `now` resolver (strftime-style patterns).
    dir: './outputs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}'

# Hydra defaults list. `_self_` is listed first, so every config composed
# after it is merged on top of (and can override) the values in this file.
# Order is significant: do not re-sort these entries.
defaults:
  - _self_
  # Option `finetune` selected in each `hyperparams/*` config group.
  - hyperparams/tabpfn: finetune
  - hyperparams/foundation: finetune
  - hyperparams/foundation_flash: finetune
  - hyperparams/tab2d: finetune
  - pretrain_model: tab2d
  - plotting: plotting_default
  # Configs pulled in by name. Only `pretrain_regression` is active; the
  # commented-out entries are currently disabled.
  # - pretrain_continue
  - pretrain_regression
  # - pretrain_test
  # - pretrain_regression_test
  # Select Hydra's `disabled` logging config for Hydra itself and for the job.
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# Output directory is the Hydra run directory (`hydra.run.dir`).
output_dir: ${hydra:run.dir}
seed: 0
# Device indices to use — presumably GPU ids (cf. `workers_per_gpu`); confirm.
devices:
  - 0
  - 1
  - 2
  - 3
workers_per_gpu: 16
# NOTE(review): null presumably means "no CPU cap per device" — confirm
# against the code that reads this key.
max_cpus_per_device: null

# Optimiser, schedule and numeric-precision settings.
optim:
  # One step consumes batch_size * gradient_accumulation_steps samples.
  steps: 44_000
  warmup_steps: 1_000
  log_every_n_steps: 10
  eval_every_n_steps: 4_400
  # Total batch size, summed over all devices.
  batch_size: 256
  # Accumulation micro-steps are not counted towards `steps`.
  gradient_accumulation_steps: 4
  # Keep the decimal point in the mantissa: plain PyYAML (YAML 1.1) reads an
  # exponent-only literal such as `1e-4` as a string, not a float.
  lr: 1.0e-4
  weight_decay: 0.1
  beta1: 0.9
  beta2: 0.95
  cosine_scheduler: true
  max_grad_norm: 1.0
  regression_loss: MSE
  label_smoothing: 0.0
  precision: bfloat16
  grad_scaler:
    enabled: false
    # 65536 = 2**16.
    scale_init: 65536.0
    scale_min: 65536.0
    growth_interval: 1000
  use_pretrained_weights: false
  # Path to a checkpoint to load weights from.
  # NOTE(review): presumably only read when `use_pretrained_weights` is true —
  # confirm.
  path_to_weights: outputs_done/foundation_key_att/weights/model_step_500000.pt

# Data-generation settings.
data:
  # One of: tabpfn, forest, mix.
  generator: mix
  min_samples_support: 16
  max_samples_support: 512
  n_samples_query: 128
  min_features: 1
  max_features: 16
  max_classes: 10
  # NOTE(review): `pretrain_regression` is active in the defaults list and is
  # composed after `_self_`, so it may override this value — confirm.
  task: CLASSIFICATION
  generator_hyperparams:
    # Currently disabled hyperparameters, kept for reference:
    # min_complexity: 0.01
    # max_complexity: 1.0
    # n_octaves: 7
    min_depth: 1
    max_depth: 25
    base_size: 1024
    categorical_x: true


# Input preprocessing switches (both off here).
preprocessing:
  use_quantile_transformer: false
  # TabPFN: true
  use_feature_count_scaling: false


# Validation- and test-phase settings.
testing:
  # Tasks performed during the validation and testing phases.
  downstream_tasks:
    - ZEROSHOT
    - FINETUNE

  # The two run counts below only apply to the WHYTREES benchmark.
  n_default_runs_per_dataset_valid: 1
  n_default_runs_per_dataset_test: 10
  # OpenML dataset ids to ignore. Commented-out ids are not ignored at the
  # moment and are kept for reference.
  openml_dataset_ids_to_ignore:
    - 45041  # topo doesn't fit on GPU
    # - 44089  # missing SAINT hyperparameter search in original benchmark results
    # - 44135  # isolet numerical regression (613 features)
    # - 44061  # mercedes_benz categorical regression (359 features)
    # - 45041  # topo categorical regression (255 features)
    # - 45046  # allstate_claims categorical regression (124 features)
    # - 45019  # bioresponse numerical classification (419 features)

  loss_graph_min_step: 999

  decision_boundary_analysis:
    enabled: true
    grid_size: 1000

  # Benchmarks used for validation.
  benchmarks_valid:
    - CATEGORICAL_CLASSIFICATION

  # Benchmarks used for testing; commented-out names are currently disabled.
  benchmarks_test:
    # - DEBUG_CLASSIFICATION
    # - DEBUG_REGRESSION
    # - DEBUG_TABZILLA
    - CATEGORICAL_CLASSIFICATION
    - NUMERICAL_CLASSIFICATION
    # - CATEGORICAL_REGRESSION
    # - NUMERICAL_REGRESSION
    # - CATEGORICAL_CLASSIFICATION_LARGE
    # - NUMERICAL_CLASSIFICATION_LARGE
    # - CATEGORICAL_REGRESSION_LARGE
    # - NUMERICAL_REGRESSION_LARGE
    # - TABZILLA_HARD_MAX_TEN_CLASSES
    - TABZILLA_HAS_COMPLETED_RUNS
