import os
from dataclasses import dataclass
from typing import Literal, Optional

# import tyro
from jsonargparse import CLI


"""
models:    
llava-hf/llava-onevision-qwen2-0.5b-ov-hf
"""


@dataclass
class Config:
    """Configuration for a fine-tuning run.

    ``get_args`` in this module hands this class to ``jsonargparse.CLI`` to
    build an instance from the command line; the values below are the defaults.
    """

    # Model identifier (presumably a Hugging Face hub id or a local path --
    # confirm). ``get_name`` uses its last "/"-separated component in the run name.
    model: str = "xtuner/llava-phi-3-mini-hf"
    # Loss variant; must be one of these three names. What each one computes is
    # decided by the training code, which is not visible in this file.
    loss: Literal["sft", "digit", "digit_base"] = "digit"
    # Optional suffix that ``get_name`` appends to the run name when not None.
    tag: Optional[str] = None

    # Dataset locations (presumably an annotation JSON and its image folder --
    # verify against the data-loading code).
    data_path: str = "../data/images/laion_coco_aesthetic/normalized/1_30/total.json"
    image_dir: str = "../data/images/laion_coco_aesthetic/normalized/1_30/images"
    # Base output directory; ``get_args`` joins the run name onto it.
    output_dir: str = "../data/ckpt"

    # Training hyperparameters. They are only declared here; how they are used
    # is defined by the training script (not shown).
    max_length: int = 4096
    # NOTE(review): presumably the number of gradient-accumulation steps -- confirm.
    grad_acc: int = 2
    learning_rate: float = 2e-5
    num_warmup_steps: int = 100
    num_epochs: int = 3
    save_steps: int = 100
    batch_size: int = 8
    num_workers: int = 4
    # NOTE(review): looks like the distributed-launcher rank, with -1 meaning
    # "not distributed" -- confirm against the launcher in use.
    local_rank: int = -1


def get_name(args):
    """Build the run name ``ft-<model>_<loss>[_<tag>]`` from *args*.

    Args:
        args: Object exposing ``model`` (str), ``loss`` (str) and ``tag``
            (str or None) attributes, e.g. a ``Config`` instance.

    Returns:
        The run name. ``<model>`` is the last "/"-separated component of
        ``args.model``; the ``_<tag>`` suffix is added only when ``args.tag``
        is not None.
    """
    # Strip trailing slashes first so a directory-style path such as
    # "ckpt/my-model/" yields "my-model" instead of an empty component.
    model_name = args.model.rstrip("/").split("/")[-1]
    name = f"ft-{model_name}_{args.loss}"
    if args.tag is not None:
        name = f"{name}_{args.tag}"
    return name


def get_args():
    """Parse the command line into a ``Config`` and attach run-specific fields.

    Returns:
        The object produced by ``CLI(Config, as_positional=False)``, with a
        ``run_name`` attribute added (from ``get_name``) and ``output_dir``
        extended by that run name.
    """
    config = CLI(Config, as_positional=False)
    run_name = get_name(config)
    config.run_name = run_name
    # Each run writes into its own sub-directory of the base output directory.
    config.output_dir = os.path.join(config.output_dir, run_name)
    return config
