from dataclasses import dataclass, asdict, field
from transformers import TrainingArguments
from typing import Optional, Any

@dataclass
class MyTrainingArguments(TrainingArguments):
    """``TrainingArguments`` variant with its own defaults and a diff helper.

    The fields declared here either replace the default of a field of the
    same name on ``TrainingArguments`` or add a new one. ``get_dict``
    reports only the settings that differ from a stock
    ``TrainingArguments`` instance.
    """

    # Training and evaluation are both switched on by default.
    do_train: bool = True
    do_eval: bool = True
    # NOTE(review): meaning of ``skip`` is not visible in this file --
    # presumably read by the calling script; confirm against callers.
    skip: bool = False
    save_safetensors: bool = False
    output_dir: str = "./tmp"
    # NOTE(review): presumably a checkpoint location consumed by the
    # calling script; nothing in this class reads it.
    checkpoint_path: Optional[str] = None
    gradient_checkpointing: bool = True
    # default_factory gives every instance its own kwargs dict.
    gradient_checkpointing_kwargs: Optional[dict[str, Any]] = field(default_factory=lambda: {"use_reentrant": False})

    def get_dict(self):
        """Return the fields that differ from a baseline ``TrainingArguments``.

        This instance is serialized with ``dataclasses.asdict`` and every
        entry is compared against ``TrainingArguments(output_dir=None)``,
        which is constructed anew on each call.

        Returns:
            A dict, in field order, holding each entry for which the
            baseline either has no attribute of that name or has a value
            that compares unequal to this instance's value.
        """
        current = asdict(self)
        baseline = TrainingArguments(output_dir=None)
        # Private sentinel: distinguishes "attribute absent on the baseline"
        # from an attribute that is present but set to None.
        missing = object()

        def _differs(name, value):
            # True when the baseline lacks the field or disagrees on its value.
            reference = getattr(baseline, name, missing)
            return reference is missing or reference != value

        return {
            name: value
            for name, value in current.items()
            if _differs(name, value)
        }