from dataclasses import dataclass, field
from typing import Literal, Optional


@dataclass
class VTCLArguments:
    """Training/evaluation options for the VTCL / mDPO / SFT / DPO pipelines.

    Besides the declared fields, ``__post_init__`` derives two boolean
    attributes from ``vtcl_mode``:

    * ``load_double_dataset`` -- True only when ``vtcl_mode == "difBatch"``.
    * ``use_negative_data``   -- True only when ``vtcl_mode == "sameBatch"``.

    Any other ``vtcl_mode`` value (including the default ``"single"``)
    leaves both derived attributes False.
    """

    # Learning strategy; the help text lists vtcl, mdpo, sft and dpo.
    # The value is not validated here.
    strategy: Optional[str] = field(
        default="sft",
        metadata={"help": "learning strategy: vtcl, mdpo, sft, dpo"}
    )
    # Weight of the anchor loss term in the mDPO pipeline.
    anchor_ratio: Optional[float] = field(
        default=1.0,
        metadata={"help": "anchor loss ratio in mDPO training pipeline"}
    )
    # Enables the iterative training step.
    iter_training: Optional[bool] = field(
        default=False,
        metadata={"help": "use iter training step"}
    )
    # How positive/negative data are batched: "single", "sameBatch" or
    # "difBatch". Drives the derived flags set in __post_init__.
    # NOTE(review): the leading "s, " in the help text looks like a typo —
    # confirm whether "s" is meant to be an accepted alias for "single"
    # (nothing in this class treats it as one).
    vtcl_mode: Optional[str] = field(
        default="single",
        metadata={"help": "s, single[only positive data]; sameBatch[pos and nega data are in the same batch]; difBatch[pos and nega data in different batch]"},
    )
    # Whether identical parts of sequences are masked out of the loss.
    mask_same_sequence: Optional[bool] = field(
        default=False,
        metadata={"help": "whether to mask the same parts when calculating loss."}
    )
    # Run in test mode.
    test: Optional[bool] = field(
        default=False,
        metadata={"help": "test mode"}
    )
    # Name of the dataset used in test mode.
    test_dataset_name: Optional[str] = field(
        default="",
        metadata={"help": "test dataset"}
    )
    # File name for saved results; per the help text the default save path
    # is test.json.
    test_output_name: Optional[str] = field(
        default=None,
        metadata={"help": "the name of saved result file; default save path: test.json"}
    )

    def __post_init__(self):
        """Derive data-loading flags from ``vtcl_mode``."""
        # Positive and negative data come from separate batches.
        self.load_double_dataset = self.vtcl_mode == "difBatch"
        # Positive and negative data share a single batch.
        self.use_negative_data = self.vtcl_mode == "sameBatch"