# Settings common to all training methods
shared:
  # Base model selection
  model:
    model_name: "Qwen/Qwen2.5-3B-Instruct"
    trust_remote_code: true
    use_qlora: false

  # QLoRA quantization settings (only used if model.use_qlora is true)
  qlora:
    # NOTE(review): this stays false even though the section is QLoRA-specific;
    # presumably it must be flipped together with use_qlora -- confirm.
    load_in_4bit: false
    bnb_4bit_quant_type: "nf4"
    bnb_4bit_compute_dtype: "float16"
    bnb_4bit_use_double_quant: false

  # Training parameters shared by every method
  training:
    output_dir: "outputs"
    # Final saved models directory
    save_dir: "saved_models"
    per_device_train_batch_size: 8
    gradient_accumulation_steps: 4
    num_train_epochs: 3
    logging_steps: 10
    save_steps: 100
    eval_steps: 100
    warmup_steps: 100
    fp16: true
    remove_unused_columns: false
    report_to: "wandb"
    dataloader_num_workers: 4
    dataloader_pin_memory: true
    # Maximum sequence length
    max_length: 500
    # Maximum prompt length
    # NOTE(review): equal to max_length; if the consumer treats max_length as
    # prompt + completion combined, no room is left for the completion --
    # confirm against the trainer in use.
    max_prompt_length: 500

  # LoRA adapter parameters shared by every method
  lora:
    r: 32
    lora_alpha: 32
    target_modules:
      - "q_proj"
      - "k_proj"
      - "v_proj"
      - "o_proj"
      - "gate_proj"
      - "up_proj"
      - "down_proj"
    lora_dropout: 0.1
    bias: "none"
    task_type: "CAUSAL_LM"

  # Training data location and sample count
  dataset:
    train_path: "datasets/gpt4/pairwise_dataset_gpt4_w0.8_n30000.json"
    num_samples: 10000

  # Experiment tracking (Weights & Biases); both values are empty strings here
  tracking:
    wandb_project: ""
    # NOTE(review): avoid committing a real token to version control;
    # presumably supplied locally or via the environment -- confirm.
    wandb_token: ""

  # Authentication
  auth:
    # NOTE(review): same caution as any secret -- keep real tokens out of VCS.
    hf_token: ""

# Training methods to run, listed in execution order
methods:
  - "dpo"
  - "proposed"

# Method-specific configurations; the keys match the entries of `methods`.
#
# Floats in scientific notation are written with an explicit decimal point
# (`1.0e-4`, not `1e-4`): YAML 1.1 resolvers such as PyYAML only recognise a
# float when the mantissa contains a ".", so a bare `1e-4` is loaded as the
# string "1e-4" rather than the number 0.0001.
method_configs:
  dpo:
    # DPO-specific training parameters
    training:
      learning_rate: 1.0e-4
      run_name: "dpo-training"

    # DPO-specific configuration
    specific:
      beta: 0.02
      # NOTE(review): max_prompt_length equals max_length; if max_length is
      # the combined prompt + completion limit, the completion has no room
      # left -- confirm against the trainer in use.
      max_prompt_length: 500
      max_length: 500

  proposed:
    # Training parameters specific to the proposed method
    training:
      learning_rate: 1.0e-4
      run_name: "proposed-training"

    # Configuration specific to the proposed method
    specific:
      # Number of epochs to train with mu_loss
      mu_epochs: 3
      # Number of epochs to train with pi_loss
      pi_epochs: 3
      mu_learning_rate: 1.0e-4
      pi_learning_rate: 1.0e-4
      beta_pi: 0.02
      # Beta parameter for pi loss calculation
      # NOTE(review): both `beta_pi` and `beta` exist and this one is 0;
      # verify which of the two the pi loss actually reads.
      beta: 0
      reference_model_id: "Qwen/Qwen2.5-3B-Instruct"
