# Name of the model this configuration is for.
model_name: "tabcascade"

data:
  cat_encoding: null  # null (None) / 'onehot'
  encoder: 'dt'  # dt / gmm
  max_depth: 8
  # keep k_max <= 30 so that train time stays under 2 minutes on adult
  k_max: 10
  batch_size: 4096

highres:
  model:
    # setting this to false is only supported by the 'cdtd' variant
    condition_on_z: true
    variant: 'cdtd'
    mlp_n_layers: 5
    mlp_n_units: 394
    mlp_emb_dim: 256
    cat_emb_dim: 8
    timewarp_weight_low_noise: 3.0
    generation_steps: 200
    generation_batch_size: 4096
  training:
    num_steps_warmup: -1  # set to -1 to disable
    ema_decay: 0.999
    lr: 0.002
    weight_decay: 0
    betas: [0.9, 0.999]
    clip_grad: false

lowres:
  model:
    variant: 'cdtd'  # flow / cdtd
    mlp_act: 'relu'  # relu / silu
    mlp_n_layers: 5
    # NOTE(review): the old inline comment read "662" — presumably an earlier
    # value; confirm it is obsolete and drop this note.
    mlp_n_units: 664
    mlp_emb_dim: 256
    cat_emb_dim: 16
    cat_emb_init_sigma: 0.001
    # whether to normalize feature-specific losses by entropies
    normalize_by_entropy: true
    learn_noise_schedule: false
    init_embs_zero: false  # whether to initialize embeddings at zero
    # whether to learn latent representations for the noise schedule
    learn_latents: false
    norm_dim: null
    time_reweight: false

    # configuration of the noise schedule / timewarping
    timewarp_variant: 'logistic'  # logistic / pwl
    # 1.0 = uniform initialization for logistic timewarping
    timewarp_weight_low_noise: 3.0
    sigma_min: 0
    sigma_max: 100
    sigma_data: 1.0

    # for generation
    generation_steps: 200
    generation_batch_size: 4096

  training:
    # plain digits on purpose: "30_000" is an int only under YAML 1.1
    # resolvers (e.g. PyYAML); YAML 1.2 core-schema loaders read a string
    num_steps_train: 30000
    log_steps: 100
    lr: 0.001
    weight_decay: 0
    betas: [0.9, 0.999]
    ema_decay: 0.999
    freeze_emb: false
    clip_grad: false
    scheduler: true
    num_steps_warmup: 1000  # set to -1 to disable
    use_val: false
    patience: 10