# Hydra defaults list: compose the `base` config first, then this file
# (`_self_`), so any key set below takes precedence over the value from `base`.
defaults:
  - base
  - _self_

# Identifier of this fine-tuning configuration.
# NOTE(review): how `project` is consumed (log/output paths, etc.) is defined
# outside this file — confirm in `base`.
project: text2semantic_finetune_dual_ar
# Single source of truth for the sequence-length limit; reused through
# ${max_length} by both datasets, the data module and the model loader.
max_length: 4096
# Pretrained checkpoint directory; the tokenizer file path is built from it and
# it is passed as `path` to the model loader.
pretrained_ckpt_path: checkpoints/fish-speech-1.5

# Lightning Trainer
# Only the keys listed here are set by this file; any other trainer options
# are expected to come from `base` — verify there.
trainer:
  accumulate_grad_batches: 1
  # Gradient clipping by norm with a threshold of 1.0.
  gradient_clip_val: 1.0
  gradient_clip_algorithm: "norm"
  max_steps: 10000
  precision: bf16-true
  # Each validation pass is capped at 10 batches.
  limit_val_batches: 10
  # Validate every 100 training batches. This value is also reused by
  # callbacks.model_checkpoint.every_n_train_steps via interpolation, so
  # changing it changes the checkpoint cadence as well.
  val_check_interval: 100
  # Optional strategy overrides, currently disabled:
  # strategy:
  #   find_unused_parameters: true
  #   static_graph: true

# Tokenizer Configuration
# Shared tokenizer node: referenced as ${tokenizer} by train_dataset,
# val_dataset and the data module. Its model file is resolved inside the
# pretrained checkpoint directory.
tokenizer:
  _target_: fish_speech.tokenizer.FishTokenizer
  model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken

# Dataset Configuration
# Training dataset, instantiated by Hydra from `_target_` with the keys below
# passed as keyword arguments.
train_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
  # Locations of the proto data to read (a list; only one entry here).
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  # NOTE(review): presumably the probability of emitting an "interactive"
  # style sample; the exact meaning is defined by the dataset class — confirm.
  interactive_prob: 0.7

# Validation dataset; same class and same settings as train_dataset.
# NOTE(review): `proto_files` points at the same `data/protos` location as
# train_dataset, so validation reads from the training data source. Point this
# at a held-out set if a true validation signal is required.
val_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  interactive_prob: 0.7

# Data module that receives the two dataset nodes defined above (via
# interpolation) together with the loader settings.
data:
  _target_: fish_speech.datasets.semantic.SemanticDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 4
  # Same tokenizer and length limit as the datasets, kept in sync through the
  # shared top-level nodes.
  tokenizer: ${tokenizer}
  max_length: ${max_length}

# Model Configuration
# Training module (`TextToSemantic`) built from three parts: the pretrained
# transformer, an optimizer factory and an LR-scheduler factory.
model:
  _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
  # Transformer loaded from the pretrained checkpoint directory, with weights.
  model:
    _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
    path: ${pretrained_ckpt_path}
    load_weights: true
    max_length: ${max_length}
    # No LoRA config is passed.
    # NOTE(review): presumably this means a full fine-tune — confirm.
    lora_config: null

  # `_partial_: true` makes Hydra return a functools.partial instead of calling
  # the target, so the remaining arguments (e.g. the parameters to optimize)
  # can be supplied later by the training module.
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    # Floats are written with an explicit mantissa dot (1.0e-4, not 1e-4) so
    # that YAML 1.1 loaders such as plain PyYAML also parse them as floats
    # rather than strings.
    lr: 1.0e-4
    weight_decay: 0
    betas: [0.9, 0.95]
    eps: 1.0e-5

  # Partial LambdaLR whose `lr_lambda` is itself a partial with
  # num_warmup_steps bound to 10; judging by the target's name this is a
  # constant schedule after warmup — confirm in fish_speech.scheduler.
  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 10

# Callbacks
# Only the checkpoint frequency is overridden here; the rest of the
# `model_checkpoint` callback definition is expected to come from `base`.
callbacks:
  model_checkpoint:
    # Save a checkpoint at the same interval as validation runs
    # (trainer.val_check_interval).
    every_n_train_steps: ${trainer.val_check_interval}
