from __future__ import annotations

import os
from typing import Any, Dict, Optional

import torch


class CheckpointManager:
    """Save and load ``torch`` checkpoints in one directory, retaining the newest ``keep``.

    Each checkpoint is stored as ``<dirpath>/<tag>.pt``. After every
    :meth:`save`, the oldest ``*.pt`` files in the directory (by modification
    time) are pruned so that at most ``keep`` remain. Pruning counts every
    ``*.pt`` regular file in the directory, not only ones written by this
    instance.

    Args:
        dirpath: Directory holding the checkpoints; created if missing.
        keep: Maximum number of ``*.pt`` files to retain; must be >= 1.

    Raises:
        ValueError: If ``keep`` is less than 1.
    """

    def __init__(self, dirpath: str, keep: int = 5):
        if keep < 1:
            # keep <= 0 would make save() delete the checkpoint it just wrote.
            raise ValueError(f"keep must be >= 1, got {keep}")
        self.dir = dirpath
        self.keep = keep
        os.makedirs(self.dir, exist_ok=True)

    def _ckpt_path(self, tag: str) -> str:
        """Return the on-disk path used for checkpoint ``tag``."""
        return os.path.join(self.dir, f"{tag}.pt")

    def save(self, tag: str, state: Dict[str, Any]) -> str:
        """Serialize ``state`` under ``tag`` and prune old checkpoints.

        The data is first written to ``<tag>.pt.tmp`` and then atomically
        renamed into place. An interrupted write therefore never leaves a
        truncated ``<tag>.pt`` behind.

        Args:
            tag: Checkpoint name; the file is ``<dir>/<tag>.pt``.
            state: Object passed to ``torch.save``.

        Returns:
            Path of the written checkpoint file.
        """
        path = self._ckpt_path(tag)
        tmp_path = f"{path}.tmp"  # does not end in ".pt", so pruning ignores it
        try:
            torch.save(state, tmp_path)
            os.replace(tmp_path, path)
        except BaseException:
            # Do not leave a partial temp file behind; propagate the original error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        self._cleanup(protect=os.path.basename(path))
        return path

    def load(
        self, tag: str, map_location: Any = None, **load_kwargs: Any
    ) -> Dict[str, Any]:
        """Load the checkpoint stored under ``tag``.

        Args:
            tag: Name the checkpoint was saved under.
            map_location: Forwarded to ``torch.load`` (e.g. ``"cpu"``).
            **load_kwargs: Extra keyword arguments forwarded to ``torch.load``
                (e.g. ``weights_only``).

        Returns:
            The deserialized checkpoint object.

        Raises:
            FileNotFoundError: If no checkpoint file exists for ``tag``.
        """
        # SECURITY: torch.load unpickles data; only load checkpoints from trusted sources.
        return torch.load(
            self._ckpt_path(tag), map_location=map_location, **load_kwargs
        )

    def _cleanup(self, protect: Optional[str] = None) -> None:
        """Delete the oldest ``*.pt`` files so that at most ``self.keep`` remain.

        Files are ordered by modification time, with the name as a tie-break.
        Lexicographic order alone would rank ``step_10`` before ``step_9``.
        Deletion is best-effort: files that vanish or cannot be removed are skipped.

        Args:
            protect: Filename (relative to ``self.dir``) that must never be
                deleted, typically the checkpoint that was just written.
        """
        entries = []
        with os.scandir(self.dir) as it:
            for entry in it:
                if not entry.name.endswith(".pt") or not entry.is_file():
                    continue
                try:
                    mtime = entry.stat().st_mtime_ns
                except OSError:
                    continue  # removed between listing and stat
                entries.append((mtime, entry.name))

        excess = len(entries) - self.keep
        if excess <= 0:
            return

        entries.sort()  # oldest first
        candidates = [name for _, name in entries if name != protect]
        for name in candidates[:excess]:
            try:
                os.remove(os.path.join(self.dir, name))
            except OSError:
                # Best-effort pruning: a failed delete must not fail the save.
                pass

