import re
from pathlib import Path
from typing import Optional, Union

from pytorch_lightning import callbacks as pl_callbacks

from .dirs import get_checkpoints_dir, recreate_dir


def get_checkpoints_callback(
    task_name: Union[str, list[str]],
    save_interval: Optional[int] = None,
) -> pl_callbacks.ModelCheckpoint:
    """Build the Lightning ``ModelCheckpoint`` callback for a task.

    The checkpoint directory is looked up from ``task_name`` via
    ``get_checkpoints_dir`` and passed through ``recreate_dir`` before the
    callback is constructed.

    Args:
        task_name: Task identifier (a single name or a list of name parts)
            handed to ``get_checkpoints_dir`` to locate the directory.
        save_interval: Forwarded to ``ModelCheckpoint`` as
            ``every_n_epochs``. ``None`` is passed through unchanged.

    Returns:
        A ``ModelCheckpoint`` writing into the task's checkpoint directory.
        Only ``dirpath`` and ``every_n_epochs`` are set here; every other
        option (``monitor``, ``save_top_k``, ``filename``, ...) keeps the
        library default.
    """
    target_dir = get_checkpoints_dir(task_name)
    # NOTE(review): recreate_dir presumably wipes and re-creates the directory,
    # i.e. checkpoints from an earlier run of this task would be discarded --
    # confirm against its definition in `.dirs`.
    recreate_dir(target_dir)

    callback_options = {
        "dirpath": target_dir,
        "every_n_epochs": save_interval,
    }
    return pl_callbacks.ModelCheckpoint(**callback_options)

def _natural_sort_key(path: Path) -> tuple[list[Union[str, int]], str]:
    """Return a sort key that compares digit runs in the file name as numbers."""
    # re.split with a capturing group always alternates
    # text / digits / text / ..., so odd positions are exactly the digit runs.
    # Two keys therefore only ever compare str with str and int with int.
    tokens = re.split(r"(\d+)", path.name)
    natural = [int(tok) if idx % 2 else tok for idx, tok in enumerate(tokens)]
    # The raw name breaks ties such as "epoch=01" vs "epoch=1" so the result
    # does not depend on directory listing order.
    return natural, path.name


def get_best_checkpoint_path(
    task_name: Union[str, list[str]],
) -> Path:
    """Return the path of the checkpoint file to load for a task.

    The checkpoint is chosen purely by file name: the regular file whose
    name sorts last in natural order (embedded numbers compared numerically,
    so ``epoch=10`` ranks after ``epoch=9``) is returned. No validation
    metric is consulted here, so "best" is only meaningful if the checkpoint
    callback keeps a single file or names files so that the preferred one
    sorts last.

    Args:
        task_name: Task identifier (a single name or a list of name parts)
            handed to ``get_checkpoints_dir`` to locate the directory.

    Returns:
        Path of the selected checkpoint file.

    Raises:
        ValueError: If the resolved checkpoint location is not a directory.
        FileNotFoundError: If the directory contains no regular files.
    """
    # NOTE(review): assumes get_checkpoints_dir returns a pathlib.Path --
    # confirm against its definition in `.dirs`.
    checkpoints_dir = get_checkpoints_dir(task_name)
    if not checkpoints_dir.is_dir():
        raise ValueError(
            f"'{checkpoints_dir}' is not a valid (checkpoint) directory."
        )
    # Sub-directories are never checkpoints, so they must not be candidates.
    chkpt_files = [f for f in checkpoints_dir.iterdir() if f.is_file()]
    if not chkpt_files:
        raise FileNotFoundError("No checkpoint found for task.")
    return max(chkpt_files, key=_natural_sort_key)



# TrainingStateItemType = Literal["model", "optimizer", "scheduler"]

# def _get_item_file(
#     checkpoints_dir: Path,
#     epoch: int,
#     item_type: TrainingStateItemType,
# ) -> Path:
#     if item_type == "model":
#         item_postfix = ".param"
#     elif item_type == "optimizer":
#         item_postfix = "_optimizer.state"
#     elif item_type == "scheduler":
#         item_postfix = "_scheduler.state"
#     else:
#         raise ValueError(f"Invalid item type {item_type}")
#     return checkpoints_dir / f"{epoch:04d}{item_postfix}"


# EPOCH_NUMBER_LENGTH = 4

# def save_load_training_state(
#     checkpoints_dir: Path,
#     epoch: int,
#     save: bool,
#     model: Model,
#     optimizer: Optional[optim.Optimizer] = None,
#     lr_scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
#     device: str = "cpu",
# ) -> None:
#     if save:
#         checkpoints_dir.mkdir(parents=True, exist_ok=True)

#     save_items: list[
#         Union[Model, optim.Optimizer, optim.lr_scheduler._LRScheduler]
#     ] = [model]
#     item_postfixes = [".params"]
#     if optimizer is not None:
#         save_items.append(optimizer)
#         item_postfixes.append("_optimizer.state")
#     if lr_scheduler is not None:
#         save_items.append(lr_scheduler)
#         item_postfixes.append("_scheduler.state")

#     if epoch < 0:
#         # Save or load the final model state
#         epoch_prefix = "final"
#     else:
#         epoch_prefix = f"{epoch:0{EPOCH_NUMBER_LENGTH}d}"

#     for item, item_postfix in zip(save_items, item_postfixes):
#         save_file = checkpoints_dir / f"{epoch_prefix}{item_postfix}"
#         if save:
#             torch.save(item.state_dict(), save_file)
#         else:
#             item.load_state_dict(
#                 torch.load(save_file, map_location=torch.device(device)),
#             )
