# Hydra defaults list: compose the `base` config first, then this file.
# `_self_` is listed last so that keys defined in this file take precedence
# over same-named keys coming from `base`.
defaults:
  - base
  - _self_

# Model configuration node. `_target_` names the class to build
# (SequenceMixingConfig); the remaining keys are its constructor arguments.
# Most values are `${...}` interpolations of top-level keys, so they are set
# once at the root of the composed config and mirrored here.
config:
  _target_: custom_models.sequence_mixing_model.SequenceMixingConfig
  # Ask Hydra to convert OmegaConf containers to plain Python dict/list
  # before they are passed to the target.
  _convert_: "all"
  model_type: sequence_mixing_causal_lm
  architectures:
    - SequenceMixingForCausalLM
  # Maps Auto* class names to "<module>.<class>" strings.
  # NOTE(review): presumably used for Hugging Face remote-code loading of the
  # saved checkpoint — confirm against the config/model implementation.
  auto_map:
    AutoConfig: "configuration_sequence_mixing.SequenceMixingConfig"
    AutoModelForCausalLM: "modeling_sequence_mixing.SequenceMixingForCausalLM"
  # NOTE(review): `model_name_or_path`, `attention_module_paths` and
  # `sequence_mixing_type` are not defined in the visible part of this file;
  # they must be supplied by `base`, another composed config, or an override,
  # otherwise these interpolations will fail to resolve.
  base_model_name_or_path: ${model_name_or_path}
  attention_module_paths: ${attention_module_paths}
  sequence_mixing_type: ${sequence_mixing_type}
  # The keys below interpolate the top-level values defined later in this file.
  loss_type: ${loss_type}
  per_layer_loss_coef: ${per_layer_loss_coef}
  distillation_loss_coef: ${distillation_loss_coef}
  next_token_loss_coef: ${next_token_loss_coef}
  mask_per_layer_losses: ${mask_per_layer_losses}
  copy_attention_weights: ${copy_attention_weights}
  use_additional_features: ${use_additional_features}
  forward_mode: ${forward_mode}
  reinitialize_base_model: ${reinitialize_base_model}
  freeze_base_model: ${freeze_base_model}
  torch_dtype: ${torch_dtype}

# Top-level values that the `config` node reads through `${...}`
# interpolation. Keeping them at the root means each one can be overridden
# on its own without editing the `config` mapping.
torch_dtype: bfloat16

# Model-construction and forward-pass switches (forwarded verbatim to
# SequenceMixingConfig; exact semantics live in that class).
copy_attention_weights: true
use_additional_features: false
forward_mode: sequence_mixing_training
reinitialize_base_model: false
freeze_base_model: true
# Loss selection and weighting.
# NOTE(review): with the distillation and next-token coefficients at 0.0,
# presumably only the per-layer loss term contributes — confirm against the
# loss implementation.
loss_type: "mse"
per_layer_loss_coef: 1.0
distillation_loss_coef: 0.0
next_token_loss_coef: 0.0
mask_per_layer_losses: true
# Logger node: `_target_` names the SequenceMixingLogger class, given no
# arguments here. Referenced by `loaded_model.metrics_logger`.
metrics_logger:
  _target_: custom_models.sequence_mixing_utils.SequenceMixingLogger

# Model factory node: calls `create_sequence_mixing_model` with the `config`
# and `metrics_logger` nodes defined in this file, passed by interpolation.
# NOTE(review): under Hydra's default recursive instantiation the two
# interpolated nodes are presumably instantiated into objects before being
# handed to the factory — confirm how the caller invokes `instantiate`.
loaded_model:
  _target_: custom_models.sequence_mixing_model.create_sequence_mixing_model
  config: ${config}
  metrics_logger: ${metrics_logger}