# Training configuration for a run that uses the SigLIP checkpoint named below.
# NOTE(review): `()`, `!path`, `!ext` and `!import` are not core YAML. They are
# presumably resolved by the project's config loader, with `()` naming the
# class to build from its sibling keys — confirm against the loader.
config:
  (): custom_colbert.utils.train_custom_colbert_models.ColModelTrainingConfig
  # Output location for this run.
  # NOTE(review): presumably `!path` resolves this relative to this file — confirm.
  output_dir: !path ../../../models/train_real_siglip
  # Processor wrapper, loaded from the same checkpoint name as `model`.
  processor:
    (): custom_colbert.utils.wrapper.AutoProcessorWrapper
    pretrained_model_name_or_path: google/siglip-so400m-patch14-384
    # Same value as the top-level `max_length` of this config.
    # NOTE(review): presumably the two must be kept in sync — confirm.
    max_length: 64
  # Model wrapper, loaded from the same checkpoint name as `processor`.
  model:
    (): custom_colbert.utils.wrapper.AutoColModelWrapper
    pretrained_model_name_or_path: google/siglip-so400m-patch14-384
    training_objective: "biencoder_mean"
    # attn_implementation: "eager"
    # NOTE(review): weights are requested in float16, while `bf16` in tr_args
    # is commented out and the disabled quantization stanza below names
    # bfloat16 — confirm float16 is the intended training dtype.
    torch_dtype: !ext torch.float16
    # Disabled: device placement and 4-bit BitsAndBytes quantized loading.
    # device_map: "auto"
    # quantization_config:
    #   (): transformers.BitsAndBytesConfig
    #   load_in_4bit: true
    #   bnb_4bit_quant_type: "nf4"
    #   bnb_4bit_compute_dtype: "bfloat16"
    #   bnb_4bit_use_double_quant: true

  # Data sources: the evaluation set comes from a separate YAML file, the
  # training set from a loader function referenced by dotted path.
  eval_dataset_loader: !import ../data/debug_data.yaml
  dataset_loading_func: !ext custom_colbert.utils.dataset_transformation.load_docvqa_dataset

  # Loss object, built with no extra arguments.
  loss_func:
    (): custom_colbert.loss.colbert_loss.BiPairwiseCELoss

  # Run switches: both the training and the evaluation phase are enabled.
  run_eval: true
  run_train: true
  # NOTE(review): presumably appends a suffix to the run/output name — confirm.
  add_suffix: true
  # Same value as `processor.max_length`.
  max_length: 64
  # Keyword arguments for transformers' TrainingArguments.
  tr_args:
    (): transformers.training_args.TrainingArguments
    # Left null here.
    # NOTE(review): presumably filled in from `config.output_dir` by the
    # training code — confirm.
    output_dir: null
    overwrite_output_dir: true
    num_train_epochs: 1
    per_device_train_batch_size: 2
    # NOTE(review): max_steps (10) is lower than eval_steps (50),
    # warmup_steps (100) and save_steps (500). Per the TrainingArguments docs
    # a positive max_steps also takes precedence over num_train_epochs, so this
    # looks like a smoke-test run that stops before the first step-based
    # eval/checkpoint and before warmup completes — confirm this is intended.
    max_steps: 10
    # 6 x 8 gpus = 48 batch size
    # NOTE(review): the figure above does not match
    # per_device_train_batch_size: 2 — likely stale; confirm or update.
    # gradient_accumulation_steps: 4
    per_device_eval_batch_size: 2
    eval_strategy: "steps"
    # dataloader_num_workers: 8
    # bf16: true
    save_steps: 500
    logging_steps: 10
    eval_steps: 50
    warmup_steps: 100
    # Written with an explicit mantissa dot: YAML 1.1 resolvers (e.g. stock
    # PyYAML) read a bare `5e-5` as a string, not a float.
    learning_rate: 5.0e-5
    save_total_limit: 1
    # NOTE(review): presumably requires bitsandbytes to be installed, even
    # though the model's quantization stanza is disabled — confirm.
    optim: "paged_adamw_8bit"

  # LoRA adapter settings, passed to peft.LoraConfig.
  peft_config:
    (): peft.LoraConfig
    task_type: 'FEATURE_EXTRACTION'
    # Rank and alpha are both 32.
    r: 32
    lora_alpha: 32
    lora_dropout: 0.1
    init_lora_weights: 'gaussian'
    bias: 'none'
    # Regex with two alternatives:
    #   1. names containing `text_model` followed by one of the listed
    #      *_proj names;
    #   2. names containing `custom_text_proj` or `custom_image_proj`.
    # NOTE(review): down_proj / gate_proj / up_proj / o_proj do not look like
    # SigLIP module names (the HF implementation appears to use fc1 / fc2 /
    # out_proj), so only q/k/v projections may actually match inside
    # text_model — verify against model.named_modules().
    target_modules: '(.*(text_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_(text|image)_proj).*$)'
