# Shared scalar settings. Other entries in this file read them through
# ${model.<key>} interpolations instead of repeating the literal values.
name: basharin
n_component: 1  # Read by circuit.n_component
n_token: 4      # Read by circuit.n_token
beta: 0.9       # Loss interpolation: 0. means CE only and 1. means KL only
gamma: 0.8      # Exp. future token discount: 1. means all tokens contribute equally

# Settings for the circuit; `_target_` is the dotted path of the class these
# keys are meant for (presumably built through Hydra instantiation -- confirm).
circuit:
    _target_: mtp.models.circuits.CircuitModel
    vocab_size: ${data.vocab_size}     # Resolved from data.vocab_size
    n_token: ${model.n_token}          # Resolved from model.n_token
    n_component: ${model.n_component}  # Resolved from model.n_component
    kind: 'cp'  # The available choices are 'cp' and 'hmm'
mt_head_hparams:  # The hyperparameters of the multi token head
    n_embd: ${lm.n_embd}              # Resolved from lm.n_embd
    transformer_n_head: ${lm.n_head}  # Resolved from lm.n_head
    tok_transformer_n_layer: 0   # Transformer encoder layers for the Categoricals; zero means no encoder
    sum_transformer_n_layer: 0   # Transformer encoder layers for the sum weights; zero means no encoder
    expander_n_layer: 1          # Only used if the expander type is 'mlp'
    expander_type: 'linear'      # The available choices are 'linear', 'mlp'
    freeze_vocab_unembedding: true
# Adaptor settings (rank `r`, `lora_alpha`, `lora_dropout`, target module names).
# See https://github.com/ctlllll/axolotl/blob/main/examples/medusa/vicuna_13b_qlora_stage1.yml
# If you do not want an adaptor, set `adaptor_params: null`. Write `null`, not
# `None`: a bare `None` in YAML is read as the string "None".
adaptor_params:
    r: 32
    lora_alpha: 16
    lora_dropout: 0.0
    target_modules:
        - gate_proj
        - down_proj
        - up_proj
        - q_proj
        - v_proj
        - k_proj
        - o_proj
        # - lm_head
# Top-level model entry; it pulls together the sections defined elsewhere in
# this file through interpolations.
model:
    _target_: mtp.models.mtp.MultiTokenLM
    lm: ${lm.model}                           # Resolved from lm.model
    circuit: ${model.circuit}                 # Resolved from model.circuit
    mt_head_kwargs: ${model.mt_head_hparams}  # Resolved from model.mt_head_hparams
    adaptor_kwargs: ${model.adaptor_params}   # Resolved from model.adaptor_params
    init_from_lm_head: true
    # Loss weighting
    kl_type: forward             # Can be forward or reverse
    kl_algorithm: binary_approx  # Can be full or binary_approx
    beta: ${model.beta}          # Interpolate losses. 0. means CE only and 1. is only KL
    gamma: ${model.gamma}        # Exp. future token discount. 1. means all tokens contribute equally
