from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import tyro
import yaml


@dataclass
class Config:
    """Command-line options for generating a DeepSpeed launch script."""

    # Path to the YAML file; its top-level key/value pairs are emitted as
    # `--key value` arguments to train.py.
    config: str = "configs/base.yaml"
    # Launch preset: 1 -> devices 0-7 on master port 9501,
    # 2 -> devices 8-15 on master port 9502. Also names the output file
    # (train1.sh / train2.sh).
    opt: Literal[1, 2] = 1


# Parse the command line: `--config` selects the YAML file, `--opt` selects
# one of the device/port presets below.
args = tyro.cli(Config)

# Validate with an explicit raise; an `assert` would be stripped under
# `python -O` and let a bad path slip through to open().
if not Path(args.config).is_file():
    raise FileNotFoundError(f"Config file not found: {args.config}")

with open(args.config, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# An empty YAML document loads as None; treat it as "no extra arguments"
# instead of failing on `None.items()`.
if config is None:
    config = {}
if not isinstance(config, dict):
    raise TypeError(
        f"Config file {args.config} must contain a top-level mapping, "
        f"got {type(config).__name__}"
    )

# (device list for `--include localhost:...`, master port) per preset.
# `args.opt` is constrained to 1 or 2 by the Literal type, so the index is
# always in range.
opts = [
    ("0,1,2,3,4,5,6,7", "9501"),
    ("8,9,10,11,12,13,14,15", "9502"),
][args.opt - 1]

# Environment preamble written at the top of the generated script.
msg = """export CUDA_HOME=$CONDA_PREFIX
export CUDA_ROOT=$CONDA_PREFIX
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$CONDA_PREFIX/lib"
export PATH="$CONDA_PREFIX/bin:$PATH"

"""

# Build the command as one logical line split with shell line continuations.
# Joining with " \\\n" puts a continuation after every line except the last,
# so no trailing backslash has to be trimmed afterwards.
# NOTE: keys and values are interpolated verbatim (no shell quoting), so a
# value containing spaces or shell metacharacters is passed through as-is.
command_lines = [
    f"deepspeed --include localhost:{opts[0]} --master_port {opts[1]} train.py"
]
command_lines.extend(f"    --{k} {v}" for k, v in config.items())
msg += " \\\n".join(command_lines)

# Write train1.sh or train2.sh into the current working directory.
with open(f"train{args.opt}.sh", "w", encoding="utf-8") as f:
    f.write(msg)
