import os
from pathlib import Path
from dataclasses import dataclass
from typing import Literal, Optional

# import tyro
from jsonargparse import CLI


@dataclass
class Config:
    """Settings for one run, filled in from the command line by ``get_args``.

    The fields ``seed``, ``data_split``, ``model``, ``loss``, ``ablation`` and
    ``tag`` are combined by ``get_name`` into the run name, which ``get_args``
    then appends to ``output_dir``.
    """

    # --- Run identity (all of these feed the run name) ---
    # Model identifier; only the part after the last "/" ends up in the run
    # name, with "." replaced by "-".
    model: str = "meta-llama/Llama-3.2-1B-Instruct"
    # Loss variant; the Literal restricts it to these three choices.
    loss: Literal["sft", "digit", "digit_base"] = "digit"
    # Optional free-form suffixes; each is appended to the run name only when
    # it is not None (ablation first, then tag).
    tag: Optional[str] = None
    ablation: Optional[str] = None

    # --- Data and output locations ---
    # Relative defaults; presumably resolved against the working directory by
    # whatever consumes them -- TODO confirm.
    data_path: str = "../../../data/toycos/data"
    data_split: str = "train10"
    # Base directory: get_args() replaces this with <output_dir>/<run_name>.
    output_dir: str = "../../../data/toycos/ckpt"
    # NOTE(review): presumably toggles per-digit splitting of numbers; the
    # consumer is not visible here -- confirm.
    split_digit: bool = True
    # Also the first component of the run name.
    seed: int = 42

    # --- Training hyperparameters ---
    # Not read anywhere in this file; presumably consumed by the training
    # script. Meanings below are inferred from the names -- TODO confirm.
    max_length: int = 2048
    grad_acc: int = 1  # presumably gradient-accumulation steps
    learning_rate: float = 2e-5
    num_warmup_steps: int = 0
    num_epochs: int = 1
    save_steps: int = 300
    batch_size: int = 16
    num_workers: int = 4
    local_rank: int = -1  # presumably -1 means "not distributed" -- confirm


def get_name(args):
    """Build the run name for a configuration.

    The name has the form ``<seed>_<data_split>_<model>_<loss>``, where
    ``<model>`` is the last "/"-separated component of ``args.model`` with
    every "." replaced by "-". ``args.ablation`` and then ``args.tag`` are
    appended, each preceded by "_", when they are not None.

    Args:
        args: Object exposing ``seed``, ``data_split``, ``model``, ``loss``,
            ``ablation`` and ``tag`` attributes.

    Returns:
        The assembled run name as a string.
    """
    short_model = args.model.rsplit("/", 1)[-1].replace(".", "-")
    pieces = [f"{args.seed}_{args.data_split}_{short_model}_{args.loss}"]
    # Optional suffixes keep a fixed order: ablation before tag. Only None is
    # skipped, so an empty string still contributes its "_" separator.
    pieces.extend(
        f"{suffix}" for suffix in (args.ablation, args.tag) if suffix is not None
    )
    return "_".join(pieces)


def get_args():
    """Parse command-line options into a ``Config`` and derive run fields.

    After parsing, two attributes are set on the result:

    * ``run_name`` -- the value returned by ``get_name`` for the parsed config.
    * ``output_dir`` -- the parsed ``output_dir`` joined with ``run_name``.

    Returns:
        The object returned by ``CLI(Config, as_positional=False)``
        (presumably a ``Config`` instance -- see jsonargparse docs) with the
        two attributes above applied.
    """
    cfg = CLI(Config, as_positional=False)
    # Alternative parser kept for reference: cfg = tyro.cli(Config)
    run_name = get_name(cfg)
    cfg.run_name = run_name
    # Each run writes into its own sub-directory of the base output dir.
    cfg.output_dir = os.path.join(cfg.output_dir, run_name)
    return cfg
