# Hydra runtime settings.
hydra:
  run:
    # Every run writes to its own directory under ./outputs/runs, nested first
    # by date (YYYY-MM-DD) and then by time of day (HH-MM-SS) via the `now` resolver.
    dir: './outputs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}'

# Hydra defaults list: the config-group options composed into this config.
# `_self_` is listed first, so (per Hydra's composition order) the entries after
# it are merged later than the values defined in this file and may override them.
defaults:
  - _self_
  # `<group>: <option>` selections; every hyperparams group uses its `finetune` option.
  - hyperparams/tabpfn: finetune
  - hyperparams/foundation: finetune
  - hyperparams/foundation_flash: finetune
  - hyperparams/tab2d: finetune
  - pretrain_model: foundation
  - plotting: plotting_default
  # Disabled entries, kept for reference.
  # - pretrain_continue
  # - pretrain_test
  - override hydra/hydra_logging: disabled  # select the `disabled` option for Hydra's own logging config
  - override hydra/job_logging: disabled  # select the `disabled` option for the job logging config

# Run-level settings.
# output_dir is an interpolation of Hydra's run.dir, so both always point at the same path.
output_dir: '${hydra:run.dir}'
seed: 0
# NOTE(review): presumably accelerator indices (a single device, index 6) -- confirm against the launcher.
devices:
  - 6
workers_per_gpu: 16
# No value is set here; how the consumer treats null is not visible in this file.
max_cpus_per_device: null

# Optimisation settings.
optim:
  # One step consumes batch_size * gradient_accumulation_steps samples.
  steps: 25_000
  warmup_steps: 1_000
  log_every_n_steps: 10
  eval_every_n_steps: 5_000
  # Total batch size, summed over all devices.
  batch_size: 256
  # Accumulation steps are not counted towards the step budget.
  # NOTE(review): the original note named this budget "max_steps"; presumably `steps` above -- confirm.
  gradient_accumulation_steps: 1
  lr: 1.e-4
  weight_decay: 0.0
  beta1: 0.9
  beta2: 0.95
  cosine_scheduler: true
  max_grad_norm: 1.0
  label_smoothing: 0.0
  precision: float32
  grad_scaler:
    enabled: false
    scale_init: 65536.
    scale_min: 65536.
    growth_interval: 1000
  use_pretrained_weights: false
  # Path to a checkpoint to load weights from.
  # NOTE(review): presumably only read when use_pretrained_weights is true -- confirm.
  path_to_weights: outputs_done/foundation_key_att/weights/model_step_500000.pt

# Data generator settings.
data:
  # Accepted values: tabpfn, forest, or mix.
  generator: forest
  min_samples_support: 128
  max_samples_support: 1024
  n_samples_query: 256
  min_features: 3
  max_features: 100
  max_classes: 10
  task: CLASSIFICATION
  generator_hyperparams:
    # Disabled keys, kept for reference:
    # min_complexity: 0.01
    # max_complexity: 1.0
    # n_octaves: 7
    min_depth: 1
    max_depth: 25
    base_size: 1024
    categorical_x: true


# Preprocessing switches; all three are off in this config.
preprocessing:
  use_quantile_transformer: false
  # Original note: "TabPFN: True", i.e. the TabPFN setup uses true here.
  use_feature_count_scaling: false
  # Tab2D should not shuffle: flash attention will break.
  shuffle_features: false

# Validation / test settings.
testing:
  downstream_tasks:                               # Tasks to run during the validation and testing phases
    - ZEROSHOT
    - FINETUNE

  n_default_runs_per_dataset_valid: 1           # Only applicable to the WHYTREES benchmark
  n_default_runs_per_dataset_test: 10           # Only applicable to the WHYTREES benchmark
  # Empty list: no dataset id is ignored. The commented entries below are ids
  # with the reason each might be excluded; to enable any of them, remove the
  # `[]` first (a flow list cannot be followed by block entries).
  openml_dataset_ids_to_ignore: []
    # - 44089  # missing SAINT hyperparameter search in original benchmark results
    # - 44135  # isolet numerical regression (613 features)
    # - 44061  # mercedes_benz categorical regression (359 features)
    # - 45041  # topo categorical regression (255 features)
    # - 45046  # allstate_claims categorical regression (124 features)
    # - 45019  # bioresponse numerical classification (419 features)

  loss_graph_min_step: 999
  decision_boundary_grid_size: 1000

  # Benchmarks selected for the validation phase.
  benchmarks_valid:
    - CATEGORICAL_CLASSIFICATION

  # Benchmarks selected for the test phase; commented entries are switched off.
  benchmarks_test:
    # - DEBUG_CATEGORICAL_CLASSIFICATION
    # - DEBUG_TABZILLA
    - CATEGORICAL_CLASSIFICATION
    - NUMERICAL_CLASSIFICATION
    # - CATEGORICAL_REGRESSION
    # - NUMERICAL_REGRESSION
    # - CATEGORICAL_CLASSIFICATION_LARGE
    # - NUMERICAL_CLASSIFICATION_LARGE
    # - CATEGORICAL_REGRESSION_LARGE
    # - NUMERICAL_REGRESSION_LARGE
    # - TABZILLA_HARD_MAX_TEN_CLASSES
    # - TABZILLA_HAS_COMPLETED_RUNS
