import os
from dataclasses import dataclass, field
from typing import Optional, List, Literal


@dataclass
class PreprocessArg:
    """Options for converting a directory of raw text files into a binarized dataset.

    ``data_dir``, ``prefix_name`` and ``dest_dir`` have no default and must be
    supplied (positionally, in that order, or by keyword); every other field is
    optional. Each field stores its help text under ``metadata["help"]``.
    """

    # --- required locations --------------------------------------------------
    data_dir: str = field(
        metadata=dict(help="The directory that includes raw text dataset (multiple files)."),
    )
    prefix_name: str = field(
        metadata=dict(help="The prefix of raw text data."),
    )
    dest_dir: str = field(
        metadata=dict(help="The directory that saves binarized dataset."),
    )

    # --- output naming and parallelism ---------------------------------------
    save_name: str = field(
        metadata=dict(help="The prefix name for saving files."),
        default="train",
    )
    worker: int = field(
        metadata=dict(help="The number of workers to process data."),
        default=50,
    )

    # --- processing switches -------------------------------------------------
    numberized: Optional[bool] = field(
        metadata=dict(help="Whether the dataset has been numberized."),
        default=True,
    )
    append_eos: Optional[bool] = field(
        metadata=dict(help="Append eos token id to each sentence."),
        default=True,
    )
    chunk_load: Optional[bool] = field(
        metadata=dict(help="Read large file by chunked lines."),
        default=False,
    )

    # --- tokenizer access ----------------------------------------------------
    tokenizer: Optional[str] = field(
        metadata=dict(help="The tokenizer to tokenize dataset."),
        default="hf-internal-testing/llama-tokenizer",
    )
    # NOTE(review): the default looks like a placeholder rather than a usable token.
    auth: str = field(
        metadata=dict(help="The authenticated token."),
        default="<auth_token>",
    )


@dataclass
class TrainArg:
    """Configuration for a training run.

    Groups model/tokenizer selection, dataset locations, batching, optimization,
    checkpointing and distributed-launch settings. Every field keeps its help
    text under ``metadata["help"]``; ``distributed_backend`` also lists its
    allowed values under ``metadata["choices"]``. These metadata dicts are
    presumably consumed by a dataclass-to-CLI parser (e.g.
    ``transformers.HfArgumentParser``) — that parser is not visible here.
    """

    # ---- model / tokenizer ---------------------------------------------------
    # NOTE(review): the help text says "don't set" to train from scratch, yet the
    # field always has a non-empty default; confirm how callers detect that case.
    model_name_or_path: str = field(
        default="meta-llama/Llama-2-7b-hf",
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    tokenizer: str = field(
        default="hf-internal-testing/llama-tokenizer",
        metadata={"help": "The tokenizer to tokenize dataset"},
    )
    auth: str = field(
        default="<auth_token>",
        metadata={"help": "The authenticated token."},
    )

    # ---- data ----------------------------------------------------------------
    data_dir: str = field(
        default="/path/to/dataset",
        metadata={"help": "The directory that includes binarized datasets."},
    )
    # default_factory gives every instance its own list (no shared mutable default).
    # NOTE(review): the training split defaults to ["valid"], identical to
    # valid_split — possibly a copy-paste; confirm before relying on the default.
    train_split: List[str] = field(
        default_factory=lambda: ["valid"],
        metadata={"help": "The prefix for training dataset."},
    )
    valid_split: List[str] = field(
        default_factory=lambda: ["valid"],
        metadata={"help": "The prefix for validation dataset."},
    )

    # ---- checkpointing -------------------------------------------------------
    save_dir: str = field(
        default="checkpoints",
        metadata={"help": "The save path for saving checkpoints."},
    )
    save_every_N_steps: int = field(
        default=50,
        metadata={"help": "Save checkpoint every N steps."},
    )

    # ---- batching / data loading ---------------------------------------------
    seq_len: int = field(
        default=2048,
        metadata={"help": "The maximum length for pre-training (next token prediction)."},
    )
    batch_size: int = field(
        default=32,
        metadata={"help": "The batch size."},
    )
    eval_batch_size: int = field(
        default=32,
        metadata={"help": "The batch size for evaluation."},
    )
    gradient_accumulation_step: int = field(
        default=1,
        metadata={"help": "The gradient accumulation steps."},
    )
    num_worker: int = field(
        default=50,
        metadata={"help": "The number of subprocesses to use for data loading."},
    )

    # ---- optimization --------------------------------------------------------
    precision: Literal["fp32", "fp16", "bf16"] = field(
        default="fp32",
        metadata={"help": "The precision for training / inference."},
    )
    max_update: int = field(
        default=100000,
        metadata={"help": "The maximum training steps."},
    )
    seed: int = field(
        default=1234,
        metadata={"help": "The seed number for experiment running."},
    )
    append_bos: bool = field(
        default=False,
        metadata={"help": "Append bos token to sequence."},
    )
    criterion: str = field(
        default="cross_entropy",
        metadata={"help": "The criterion for calculating loss."},
    )
    lr_scheduler: Literal["linear", "cosine", "polynomial"] = field(
        default="cosine",
        metadata={"help": "The scheduler type to use."},
    )
    learning_rate: float = field(
        default=2e-4,
        metadata={"help": "The learning rate for adam."},
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "The maximum gradient norm for clipping."},
    )
    num_warmup_steps: int = field(
        default=2000,
        metadata={"help": "The warmup steps for adam."},
    )

    # ---- distributed launch --------------------------------------------------
    distributed_backend: Optional[str] = field(
        default="nccl",
        metadata={
            "help": "Distributed backend.",
            "choices": ["nccl"],
        },
    )
    # The three defaults below come from the launcher's environment. os.getenv
    # returns a *string* whenever the variable is set, so each value is converted
    # to int to match the annotation. The lookup runs once, when this class body
    # executes (module import), so the variables must be exported before import.
    distributed_rank: int = field(
        default=int(os.getenv("RANK", "0")),
        metadata={"help": "The rank of the GPU process on all servers."},
    )
    local_rank: int = field(
        default=int(os.getenv("LOCAL_RANK", "0")),
        metadata={"help": "The rank of the GPU process on the local machine."},
    )
    # NOTE(review): falls back to 0 when WORLD_SIZE is unset; a single-process run
    # would arguably be 1 — check callers (e.g. any division by world size).
    distributed_world_size: int = field(
        default=int(os.getenv("WORLD_SIZE", "0")),
        metadata={"help": "The world size of all GPU processes."},
    )
    distributed_type: Literal["Pytorch", "Accelerator"] = field(
        default="Accelerator",
        metadata={"help": "Choose the distributed platform."},
    )
    use_deepspeed: bool = field(
        default=False,
        metadata={"help": "Use deepspeed for acceleration."},
    )
    deepspeed_config: str = field(
        default="/path/to/deepspeed/config",
        metadata={"help": "The path of the deepspeed configuration."},
    )
