from dataclasses import dataclass
from typing import Dict, Literal, Optional
from transformers import TrainingArguments

@dataclass
class SimPOConfig(TrainingArguments):
    """Training arguments for SimPO-style preference-optimization runs.

    Extends ``transformers.TrainingArguments`` with sequence-length limits,
    loss hyperparameters and the project-specific chunk / reset / replay
    settings declared at the bottom of the class.

    Every attribute is annotated on purpose: ``@dataclass`` only turns
    *annotated* names into fields.  An un-annotated assignment is a plain
    class attribute, which means

    * it cannot be set through ``__init__`` or an argument parser and is
      missing from ``dataclasses.fields()``; and
    * if it shadows a field inherited from ``TrainingArguments``, the
      generated ``__init__`` still stores the *parent's* default on the
      instance, so the override is silently ignored.
    """

    # --- Sequence-length limits -------------------------------------------
    max_length: Optional[int] = 2048
    max_prompt_length: Optional[int] = 1800
    max_completion_length: Optional[int] = None
    max_target_length: Optional[int] = None

    # --- Overrides of inherited TrainingArguments defaults ------------------
    # NOTE(review): annotations are chosen to mirror the parent fields in
    # recent transformers releases — confirm against the pinned version.
    seed: int = 42
    hub_model_id: Optional[str] = None
    push_to_hub: bool = False
    log_level: Optional[str] = "info"
    optim: str = "adamw_torch"
    # float rather than int: recent transformers releases declare save_steps
    # as float so that values below 1 can act as a ratio — TODO confirm.
    save_steps: float = 32
    gradient_accumulation_steps: int = 16
    do_eval: bool = False
    num_train_epochs: int = 1

    # --- Loss hyperparameters ----------------------------------------------
    # NOTE(review): the meaning of these values is defined by the trainer
    # that consumes this config, not by this file.
    beta: float = 2.0
    gamma_beta_ratio: float = 0.25
    label_smoothing: float = 0.0
    loss_type: str = "simpo"
    ln: bool = True  # presumably "length normalization" — verify in trainer
    alpha: float = 0.0
    sft_weight: float = 0.0
    disable_dropout: bool = True

    # --- Tokenization / batching -------------------------------------------
    label_pad_token_id: int = -100
    padding_value: Optional[int] = None  # was annotated ``int`` with a None default
    truncation_mode: str = "keep_end"
    generate_during_eval: bool = False
    is_encoder_decoder: Optional[bool] = None
    model_init_kwargs: Optional[Dict] = None
    dataset_num_proc: Optional[int] = None

    # --- Project-specific ("ours") -----------------------------------------
    my_reward_model_path: str = "./huggingface/ArmoRM-Llama3-8B-v0.1"
    on_policy_data_proportion: float = 1.0
    chunk_num: int = 3  # 7 for math, 3 for iter data
    is_reset: bool = True
    # Alternative setting previously kept as commented-out code:
    # chunk_num=1 together with is_reset=False.
    sft_reset: str = "all"  # choice: "sft", "pro", "all"
    reset_size: float = 0.5
    dropout_size: float = 0.5
    replay_ratio: int = 2
