# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from dataclasses import dataclass
from typing import ClassVar, Optional


@dataclass
class TrainArgs:
    """Default hyper-parameters and bookkeeping options for a fine-tuning run.

    Every field has a default, so ``TrainArgs()`` yields a complete
    configuration and callers override only what they need.

    NOTE(review): the field names appear to mirror Hugging Face
    ``transformers.TrainingArguments`` -- confirm against the code that
    consumes this dataclass before renaming or retyping anything.

    Field order is part of the positional ``__init__`` signature; append new
    fields at the end rather than inserting them in the middle.
    """

    # Output location and batch sizing.
    output_dir: str = "finetuned_models"
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 10
    gradient_accumulation_steps: int = 4
    num_train_epochs: int = 5

    # Logging / evaluation / checkpoint cadence.
    log_level: str = "info"
    logging_steps: int = 50
    logging_strategy: str = "steps"
    evaluation_strategy: str = "steps"
    # Currently disabled option, kept for reference:
    #eval_steps: int = 50
    save_strategy: str = "steps"

    # Precision and optimisation schedule.
    bf16: bool = True
    learning_rate: float = 2e-4
    lr_scheduler_type: str = "cosine"
    weight_decay: float = 0.0
    warmup_ratio: float = 0.1
    seed: int = 42

    # Hub upload and experiment reporting.
    push_to_hub: bool = False
    hub_model_id: str = "placeholder"
    report_to: str = "wandb"
    # Currently disabled option, kept for reference:
    #hub_private_repo: bool = True

    remove_unused_columns: bool = True
    full_determinism: bool = False
    # None is the default, so the annotation must admit it.
    resume_from_checkpoint: Optional[str] = None
    save_steps: int = 1000
    optim: str = "adamw_torch"
    gradient_checkpointing: bool = False
    load_best_model_at_end: bool = False
    # Currently disabled option, kept for reference:
    #metric_for_best_model: str = 'precision'
    save_total_limit: int = 0
    hub_strategy: str = "checkpoint"
    save_only_model: bool = True
    
