# Selects which verifiers are used. Also interpolated into data_cfg.verifier_cfg.
verifier_cfg:
  # Verifier family: "reward_models", "judges", or "all".
  verifier_type: "all"
  # "all", "small", "medium", "large", or an int (verifiers less than or equal
  # to this number), e.g. 8 or 80.
  verifier_size: 8
  # Optional list of verifier names to use; null means no list is given.
  verifier_subset: null

# Dataset selection, train/test splitting, score normalization and sampling.
data_cfg:
  # Options: "MATH-500", "AIMO", "MMLU-Pro", "GPQA", "MMLU-College", "AlpacaEval", "BBH"
  # NOTE(review): the configured value carries a "-v2" suffix that does not
  # appear in the list above — confirm which names the loader accepts.
  dataset_name: "AIMO-v2"
  # Fraction of data used for training; 1.0 means no test set.
  train_split: 0.8
  # <= 1: fraction of total queries (determined by 'train_split') used for training
  # > 1: number of train queries used for training
  train_queries: 1
  # Fraction of total samples in train_split used for training.
  train_samples: 1
  # Replacement for NaN values: "mean" or 0.
  nan_replacement: 0
  random_seed: 0
  model_size: "8B"
  # null means no threshold.
  reward_threshold: null
  # Normalization config
  normalize_type: "per_problem"  # "per_problem" or "all_problems"
  # Must match a key under normalize_method_params:
  # "minmax", "quantile", "winsorize", or "quantile_minmax".
  normalize_method: "minmax"
  # Resolves to the normalize_method_params entry selected by normalize_method.
  normalize_params: ${normalize_method_params.${data_cfg.normalize_method}}
  # How to map train and test problems
  closest_train_problem_method: "mean_verifier_distance"  # "SBERT" or "mean_verifier_distance"
  closest_train_problem_metric_type: "euclidean"  # "cosine" or "euclidean"
  # Interpolates the top-level verifier_cfg section.
  verifier_cfg: ${verifier_cfg}
  # NOTE(review): "mv" presumably stands for majority vote — confirm.
  mv_as_verifier: true
  # Test set fixed independently of train_split; null leaves it unset.
  fixed_test_split: null
  same_train_test: false
  # Exclude samples where all labels are 0 during training.
  exclude_all_zeros: false
  # "weighted_labels", "weighted_inv_freq", "weighted_hybrid", or "random"
  sampler_type: "weighted_hybrid"
  sampler_hybrid_alpha: 1.0
  # Data augmentation config. Every switch below is off in this file.
  # NOTE(review): the *_std / *_rate / *_alpha values are presumably only read
  # when the matching switch is true — confirm in the training code.
  augmentation:
    # Add noise to the verifier scores.
    verifier_noise: false
    # Drop verifiers.
    verifier_dropout: false
    # Mix up the embeddings.
    embedding_mixup: false
    # Standard deviation of the noise.
    verifier_noise_std: 0.1
    # Dropout rate.
    verifier_dropout_rate: 0.2
    # Alpha for the mixup.
    mixup_alpha: 0.2
    # Use balanced batch sampling.
    balanced_batch: false

  # Verifier-specific augmentation config. Every switch below is off in this file.
  verifier_augmentation:
    # Calibrate verifier scores.
    calibration: false
    # Calibration method: "isotonic" or "sigmoid".
    calibration_method: "isotonic"
    # Interpolate missing verifier scores.
    interpolation: false
    # Apply temporal smoothing to verifier scores.
    smoothing: false
    # Window size for smoothing.
    smoothing_window: 3

# Model selection. model_type must match a key under the top-level
# model_params section, because model_params below is looked up by that name.
model_cfg:
  model_type: "bert_classifier_stacked"
  # Commented-out option, kept for reference:
  #model_class: "per_dataset"
  # Resolves to model_params.<model_type>.
  model_params: ${model_params.${model_cfg.model_type}}
  model_seed: 42  # separate key from data_cfg.random_seed

# Loss selection. loss_type must match a key under the top-level loss_params
# section, because loss_params below is looked up by that name.
loss_cfg:
  # Keys defined under loss_params: "bce", "focal", "kldiv", "margin_rank".
  loss_type: "bce"
  # Resolves to loss_params.<loss_type>.
  loss_params: ${loss_params.${loss_cfg.loss_type}}

# Training-loop settings.
fit_cfg:
  batch_size: 64
  early_stopping: true  # Enable/disable early stopping
  lr: 0.05
  # Minimum improvement required. Written with an explicit decimal point so
  # that YAML 1.1 loaders (e.g. PyYAML) also read it as a float, not a string.
  min_delta: 1.0e-4
  model_path: "verifier_combination_model.pth"  # Where to save the model
  num_epochs: 15
  patience: 5  # Number of epochs to wait before stopping
  save_model: false  # Whether to save the model

# Debug switch; its effect is defined by the consuming script, which is not
# visible in this file.
debug: false

# Optimizer and learning-rate scheduler selection. Both *_type values are used
# as lookup keys into the top-level optimizer_params / scheduler_params sections.
optimizer_cfg:
  # Keys defined under optimizer_params: "adam", "adamw".
  optimizer_type: "adamw"
  # Resolves to optimizer_params.<optimizer_type>.
  optimizer_params: ${optimizer_params.${optimizer_cfg.optimizer_type}}
  # Keys defined under scheduler_params:
  # "constant", "step", "cosine_annealing", "reduce_on_plateau".
  scheduler_type:  "constant"
  # Resolves to scheduler_params.<scheduler_type>.
  scheduler_params: ${scheduler_params.${optimizer_cfg.scheduler_type}}

# Logging backend; "wandb" is the only value shown in this file.
logging: "wandb"
# Weights & Biases settings.
wandb_cfg:
  project: "verification"
  # Read from the WANDB_ENTITY environment variable. No default is given in
  # the interpolation, so the variable must be set wherever this value is
  # resolved.
  entity: "${oc.env:WANDB_ENTITY}"

# Per-model parameter sets, selected via model_cfg.model_type
# (see model_cfg.model_params). Keys differ per model type.
model_params:
  bert_classifier:
    latent_dim: 0
    dropout: 0.2
  # Selected by the default model_cfg.model_type in this file.
  bert_classifier_stacked:
    latent_dim: 256
    dropout: 0.2
  # Note: no dropout entry is defined for this model type.
  bert_classifier_attention:
    latent_dim: 128
  bert_classifier_mlp:
    proj_dim: 32
    hidden_dim: 64

# Per-loss parameter sets, selected via loss_cfg.loss_type
# (see loss_cfg.loss_params).
# NOTE(review): the key names look like arguments of the corresponding
# PyTorch loss classes — presumably forwarded as keyword arguments; confirm.
loss_params:
  kldiv:
    log_target: false
  # Selected by the default loss_cfg.loss_type in this file.
  bce:
    pos_weight: 10
  focal:
    alpha: 0.5
    gamma: 2.0
  margin_rank:
    margin: 1.0

# Per-optimizer parameter sets, selected via optimizer_cfg.optimizer_type
# (see optimizer_cfg.optimizer_params).
optimizer_params:
  adam:
    betas: [0.9, 0.999]
  # Selected by the default optimizer_cfg.optimizer_type in this file.
  adamw:
    # Written with an explicit decimal point so that YAML 1.1 loaders
    # (e.g. PyYAML) also read it as a float, not a string.
    weight_decay: 1.0e-5

# Per-scheduler parameter sets, selected via optimizer_cfg.scheduler_type
# (see optimizer_cfg.scheduler_params).
# NOTE(review): the key names look like arguments of the
# torch.optim.lr_scheduler classes — presumably forwarded as keyword
# arguments; confirm.
scheduler_params:
  # Selected by the default optimizer_cfg.scheduler_type in this file.
  constant:
    factor: 1.0
    total_iters: 0
  step:
    step_size: 10
    gamma: 0.1
  cosine_annealing:
    T_max: 100
  reduce_on_plateau:
    mode: "max"
    factor: 0.5
    patience: 10
    # NOTE(review): if this is passed to torch's ReduceLROnPlateau, `verbose`
    # is deprecated in recent PyTorch releases — confirm against the
    # installed version.
    verbose: true
    min_lr: 0.00001

# Per-method normalization parameters, selected via data_cfg.normalize_method
# (see data_cfg.normalize_params).
normalize_method_params:
  minmax:
    # minmax does not take any parameters; `tmp` is a placeholder key that
    # keeps this entry a mapping.
    tmp: null
  quantile:
    output_distribution: "uniform"
    n_quantiles: 100  # min number of quantiles
  # Map each observation to a quantile by clipping extreme values.
  winsorize:
    lower_quantile: 0.01
    upper_quantile: 0.99
  quantile_minmax:
    output_distribution: "normal"
    n_quantiles: 100  # min number of quantiles

# Example invocation with command-line overrides:
# python run2.py data_cfg.dataset_name="GPQA-v2" data_cfg.model_size="70B" model_cfg.model_type="bert_classifier_attention" fit_cfg.lr=0.0001 verifier_cfg.verifier_size=8