from pathlib import Path

def _pick_checkpoint(ckpt_dir: "str | Path") -> Path:
    """
    Select a checkpoint directory from a JAX-style checkpoint layout.

    Expected structure:
        ckpt_dir/
            epoch_0000/
                checkpoint.pkl
            epoch_0001/
                checkpoint.pkl
            ...
            final/
                checkpoint.pkl
            best/               (optional)
                checkpoint.pkl

    A sub-directory only counts as a checkpoint when it directly contains a
    ``checkpoint.pkl`` file. Candidates are tried in this order and the first
    match wins:

        1. ``best/``
        2. ``final/``
        3. the ``epoch_<N>/`` directory with the highest integer ``N``
           (ties such as ``epoch_1`` vs ``epoch_0001`` are broken by the
           lexicographically largest directory name, so the result does not
           depend on directory listing order)
        4. the most recently modified checkpoint sub-directory of any name

    Args:
        ckpt_dir: Root directory holding the checkpoint sub-directories.
            Accepts a ``Path`` or a plain string path.

    Returns:
        Path of the selected checkpoint sub-directory (the directory itself,
        not the ``checkpoint.pkl`` file inside it).

    Raises:
        FileNotFoundError: If ``ckpt_dir`` does not exist, or if it contains
            no sub-directory with a ``checkpoint.pkl`` file.
        NotADirectoryError: If ``ckpt_dir`` exists but is not a directory.
    """
    ckpt_dir = Path(ckpt_dir)
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
    if not ckpt_dir.is_dir():
        # Fail here with a clear message instead of deep inside iterdir().
        raise NotADirectoryError(f"Checkpoint path {ckpt_dir} is not a directory.")

    def has_checkpoint(candidate: Path) -> bool:
        return candidate.is_dir() and (candidate / "checkpoint.pkl").is_file()

    for name in ("best", "final"):
        preferred = ckpt_dir / name
        if has_checkpoint(preferred):
            return preferred

    # Single directory scan shared by the epoch search and the mtime fallback.
    valid_ckpts = [p for p in ckpt_dir.iterdir() if has_checkpoint(p)]

    epoch_ckpts = []
    for p in valid_ckpts:
        if not p.name.startswith("epoch_"):
            continue
        try:
            # The prefix check guarantees at least two "_"-separated parts,
            # so only a non-numeric epoch field can fail here.
            epoch = int(p.name.split("_")[1])
        except ValueError:
            continue
        epoch_ckpts.append((epoch, p))

    if epoch_ckpts:
        return max(epoch_ckpts, key=lambda t: (t[0], t[1].name))[1]

    if valid_ckpts:
        return max(valid_ckpts, key=lambda p: p.stat().st_mtime)

    raise FileNotFoundError(f"No valid checkpoint found in {ckpt_dir}.")


def get_dataset_path(exp_dir: str) -> str:
    """
    Build the test-dataset path for an experiment from its config file.

    Reads ``exp_config.json`` inside ``exp_dir`` and uses its
    ``dataset_name`` entry to form ``"<dataset_name>/test_data.npz"``.
    The returned path is not checked for existence.

    Args:
        exp_dir: Experiment directory containing ``exp_config.json``.

    Returns:
        The string ``"<dataset_name>/test_data.npz"``.

    Raises:
        FileNotFoundError: If ``exp_config.json`` is missing from ``exp_dir``.
        ValueError: If the config is not a JSON object, or if it has no
            ``dataset_name`` entry (or the entry is ``null``).
    """
    import json

    exp_data = Path(exp_dir) / "exp_config.json"
    if not exp_data.is_file():
        raise FileNotFoundError(f"Experiment config file not found at {exp_data}")

    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(exp_data, "r", encoding="utf-8") as f:
        exp_config = json.load(f)

    if not isinstance(exp_config, dict):
        # A list/str/number root has no .get(); report it as a config error.
        raise ValueError(f"Experiment config at {exp_data} must be a JSON object.")

    dataset_name = exp_config.get("dataset_name")
    if dataset_name is None:
        raise ValueError("Dataset name not found in experiment config.")

    return f"{dataset_name}/test_data.npz"