name: mtp

# Weighting factor for the KL objective (0. means CE only, 1. means KL only).
beta: 0.0
# Exponential discounting hyperparameter for the per-token loss
# (1 means all tokens contribute equally).
gamma: 1
# KL algorithm: full or binary_approx. The KL loss is computed only if beta > 0.
kl_algorithm: full
# KL direction: forward or reverse.
kl_type: forward

model:
    # Dotted path of the class this node describes; the remaining keys are
    # presumably passed to it as arguments (Hydra-style `_target_` — confirm).
    _target_: mtp.models.mtp.MultiTokenLM
    # Values below are pulled from other config nodes through interpolation.
    lm: ${lm.model}
    circuit: ${circuit.model}
    mt_head_kwargs: ${mt_head.hyperparameters}
    init_from_lm_head: true

    # Loss weighting
    # NOTE(review): the ${model.*} lookups assume the flat keys at the top of
    # this file are mounted under a `model` package — confirm.
    # KL direction: forward or reverse.
    kl_type: ${model.kl_type}
    # KL algorithm: full or binary_approx.
    kl_algorithm: ${model.kl_algorithm}
    # Interpolates the losses: 0. means CE only and 1. means KL only.
    beta: ${model.beta}
    # Exponential discount for future tokens: 1. means all tokens contribute equally.
    gamma: ${model.gamma}
