---
# Sweep definition. The layout (program / method / metric / parameters)
# matches the Weights & Biases sweep configuration schema.
program: test.py
project: tjepa-sweeps
name: tjepa_linear_higgs_v1
# Bayesian search over the entries of the `parameters` section.
method: bayes
# Objective of the search: maximize the reported `higgs_val_accuracy`.
metric:
  name: higgs_val_accuracy
  goal: maximize
parameters:
  # Entries that carry a single `value` are held constant for every run.
  batch_size:
    value: 512
  data_set:
    value: 'higgs'
  data_path:
    value: './datasets'
  # Optimizer / schedule settings. Everything except `exp_lr` is a constant.
  exp_final_weight_decay:
    value: 0
  exp_weight_decay:
    value: 0
  exp_start_lr:
    value: 0
  exp_lr:
    # The range spans two orders of magnitude, so sample on a log scale:
    # with the default uniform distribution roughly 90% of the draws would
    # land in [1e-4, 1e-3] and the lower decade would barely be explored.
    distribution: log_uniform_values
    min: 0.00001
    max: 0.001
  exp_final_lr:
    value: 0.0
  # Encoder (`model_*`) search space: `values` lists are categorical choices,
  # `value` entries are constants, `min`/`max` pairs are continuous ranges.
  # NOTE(review): if the encoder requires model_dim_hidden to be divisible by
  # model_num_heads (as torch.nn.MultiheadAttention does for embed_dim),
  # draws such as hidden=2 with 4 or 8 heads would fail -- confirm against
  # the model code.
  model_dim_hidden:
    values: [2, 4, 8, 16, 32, 64, 128]
  model_num_heads:
    values: [2, 4, 8]
  model_num_layers:
    values: [1, 2, 3, 4, 5, 6, 7, 8, 16]
  model_dim_feedforward:
    values: [64, 128, 256, 512, 768, 1024]
  # EMA settings are fixed: start at 0.996, end at 1.
  model_ema_start:
    value: 0.996
  model_ema_end:
    value: 1
  # Dropout probability is searched over the narrow range [0.0, 0.01].
  model_dropout_prob:
    min: 0.0
    max: 0.01
  # Context-mask share bounds. The range of the lower bound (0.07-0.15) lies
  # entirely below the range of the upper bound (0.2-0.9).
  mask_min_ctx_share:
    min: 0.07
    max: 0.15
  mask_max_ctx_share:
    min: 0.2
    max: 0.9
  # Target-mask share bounds. The lower bound's range tops out at 0.2, which
  # is exactly where the upper bound's range begins.
  mask_min_trgt_share:
    min: 0.05
    max: 0.2
  mask_max_trgt_share:
    min: 0.2
    max: 0.9
  # Predictor (`pred_*`) search space; the predictor type itself is fixed.
  # NOTE(review): if pred_embed_dim must be divisible by pred_num_heads,
  # a draw such as embed_dim=4 with 8 heads would fail -- confirm against
  # the predictor code.
  pred_type:
    value: 'transformer'
  pred_num_layers:
    values: [2, 4, 8, 16, 24, 32]
  pred_embed_dim:
    values: [4, 8, 16, 32, 64, 128]
  pred_num_heads:
    values: [2, 4, 8]
  # Predictor dropout is searched over the narrow range [0.0, 0.01].
  pred_p_dropout:
    min: 0.0
    max: 0.01
  mask_num_preds:
    value: 4
  # Feature-embedding switches, both disabled for this sweep.
  model_feature_index_embedding:
    value: false
  model_feature_type_embedding:
    value: false
  # Run-control constants shared by every run in the sweep.
  exp_cache_cadence:
    value: 20
  log_tensorboard:
    value: false
  pin_memory:
    value: false
  exp_train_total_epochs:
    value: 300
  exp_warmup:
    value: 10
  probe_cadence:
    value: 20
  exp_patience:
    value: 100
  init_type:
    value: 'trunc_normal'
  probe_model:
    value: 'linear_probe'
