import json
import shlex
import sys
from pathlib import Path

# How many times each experiment command is emitted. Outside "all" mode,
# every file named "done" found anywhere under an experiment's output
# directory reduces this count by one; a count at or below zero emits nothing.
REPEATS = 1

def find_datasets():
    for data in Path("datasets").iterdir():
        for fmt in data.iterdir():
            for i in fmt.iterdir():
                if not i.is_dir():
                    continue
                yield i

match sys.argv[1]:
    case "all":
        ALL = True
        RUN = True
        REM = False
    case 'rem':
        ALL = False
        RUN = False
        REM = True
    case 'run':
        ALL =  False
        RUN = True
        REM = False
    case _:
        raise NotImplementedError


# Experiment table: (model name, script + fixed CLI flags, per-dataset extra
# flags). The per-dataset dict is keyed by the dataset name, i.e. the first
# directory under "datasets/", and its flags are appended to the command.
EXPERIMENTS: list[tuple[str, list[str], dict[str, list[str]]]] = [
    ("logit", ["rum.py", "--rum=logit", "--rum-mode=full", "--train_rank=2"], {}),
    ("mf", ["mf.py"], {}),
    ("optimal", ["optimal.py", "--optimal"], {}),
    ("probit2", ["rum.py", "--rum=probit2", "--train_rank=2"], {}),
    (
        "probit2I",
        ["rum.py", "--rum=probit2", "--rum-mode=spherical", "--train_rank=2"],
        {},
    ),
    ("probit3", ["rum.py", "--rum=probit3"], {}),
]

# NOTE(review): nothing below appends to to_rem, so "rem" mode currently
# prints nothing -- confirm what it is meant to collect.
to_rem: list[str] = []
to_run: list[str] = []
for ds in find_datasets():
    # ds looks like datasets/<data>/<fmt>/<i>; parts[0] is always the literal
    # "datasets", so the dataset name used to key per-dataset flags is parts[1].
    dataset = ds.parts[1]
    for model, model_coms, ops in EXPERIMENTS:
        # Mirror the dataset layout under exps/<model>/.
        path = Path("exps") / model / Path(*ds.parts[1:])
        # shlex.join quotes only tokens that need it, so ordinary flags and
        # paths print exactly as before while spaces/metacharacters stay safe.
        com = shlex.join([
            *model_coms,
            f"--base_dir={ds}",
            f"--output_dir={path}",
            *ops.get(dataset, []),
        ])
        repeats = REPEATS
        if not ALL:
            # Each "done" marker anywhere under the output dir counts as one
            # completed repeat.
            repeats -= sum(1 for _ in path.rglob("done"))
        to_run.extend([com] * max(repeats, 0))

if REM:
    for r in to_rem:
        print(r)
elif RUN:
    for r in to_run:
        print("uv run", r)
else:
    raise NotImplementedError