from pathlib import Path
import json


from se.configs import PROJECT_ROOT, ModelConfig, TrainConfig, WandbConfig


def load_train_config(config_path: Path) -> TrainConfig:
    """Deserialize a ``TrainConfig`` from the UTF-8 JSON file at *config_path*.

    The JSON object must contain the nested sections ``model_cfg`` and
    ``wandb_cfg``; they are turned into ``ModelConfig`` and ``WandbConfig``
    instances. Every other top-level key is handed to ``TrainConfig`` as a
    keyword argument.

    Raises:
        KeyError: if either nested section is absent from the file.
    """
    settings = json.loads(config_path.read_text(encoding="utf-8"))

    # Remove the nested sections so only plain TrainConfig fields remain
    # in ``settings`` when it is splatted below.
    model_section = settings.pop("model_cfg")
    model_cfg = ModelConfig(**model_section)
    wandb_section = settings.pop("wandb_cfg")
    wandb_cfg = WandbConfig(**wandb_section)

    return TrainConfig(model_cfg=model_cfg, wandb_cfg=wandb_cfg, **settings)


def resolve_checkpoint_path(
    log_dir: Path, epoch: int | None, explicit: Path | None
) -> Path:
    if explicit is not None:
        ckpt = explicit.expanduser().resolve()
        if not ckpt.is_file():
            raise FileNotFoundError(f"Checkpoint {ckpt} does not exist.")
        return ckpt

    last_ckpt = log_dir / "weights_last.pt"
    if epoch is None and last_ckpt.is_file():
        return last_ckpt

    if epoch is None:
        checkpoints = sorted(log_dir.glob("weights_epoch_*.pt"))
        if checkpoints:
            return checkpoints[-1]
        raise FileNotFoundError(f"No checkpoints found in {log_dir}.")

    matching = log_dir / f"weights_epoch_{epoch:04d}.pt"
    if not matching.is_file():
        raise FileNotFoundError(f"Checkpoint {matching} does not exist.")
    return matching
