# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import re
import typing as tp
from enum import Enum
from pathlib import Path

import flashy
import torch

from ..environment import AudioCraftEnvironment

logger = logging.getLogger(__name__)


class CheckpointSource(Enum):
    """Enumeration of the possible origins of a checkpoint.

    Each member's value is a lowercase string identifier.
    """

    # NOTE(review): the meanings below are inferred from the member names only;
    # confirm against the code that consumes this enum.
    CURRENT_XP = "current_xp"  # presumably: checkpoint of the current experiment
    PRETRAINED = "pretrained"  # presumably: checkpoint of a pretrained model
    OTHER = "other"  # presumably: any other origin


def checkpoint_name(
    name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False
) -> str:
    """Build the checkpoint file name used throughout the AudioCraft codebase.

    The returned name follows the pattern `checkpoint[_<name>].th[.<rank>]`.
    By convention, no name is given for the last checkpoint, while 'best' designates
    the best checkpoint and an epoch number designates the checkpoint of that epoch.
    The rank extension is only appended for ranks > 0 when FSDP is in use.

    Args:
        name (str, optional): Name suffix for the checkpoint file stem.
        rank (int, optional): Rank for distributed processing, retrieved with flashy if not provided.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        str: The checkpoint name.
    """
    # The rank is resolved even when FSDP is not used.
    effective_rank = flashy.distrib.rank() if rank is None else rank
    stem = "checkpoint" if name is None else f"checkpoint_{name}"
    shard_ext = f".{effective_rank}" if effective_rank > 0 and use_fsdp else ""
    return f"{stem}.th{shard_ext}"


def is_sharded_checkpoint(path: Path) -> bool:
    """Tell whether the file name of `path` ends with `.th.<digits>`, i.e. a per-rank shard.

    Args:
        path (Path): Path to the checkpoint; only its final component is inspected.
    Returns:
        bool: True if the file name carries a numeric rank extension after `.th`.
    """
    shard_match = re.search(r"\.th\.\d+$", path.name)
    return shard_match is not None


def resolve_checkpoint_path(
    sig_or_path: tp.Union[Path, str],
    name: tp.Optional[str] = None,
    use_fsdp: bool = False,
) -> tp.Optional[Path]:
    """Turn a dora signature or a path into the path of an existing checkpoint.

    A value starting with `//sig/` is read as a dora signature and mapped to the
    matching folder under the dora `xps` directory. Any other value is treated as a
    path and passed through `AudioCraftEnvironment.resolve_reference_path`. When the
    result is a directory, the checkpoint file name built by `checkpoint_name` is
    appended to it.

    Args:
        sig_or_path (Path or str): Checkpoint path or dora signature.
        name (str, optional): Name suffix for the checkpoint file stem.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        Path, optional: Resolved checkpoint path, if it exists.
    """
    # Imported here rather than at module level (presumably to avoid a circular import).
    from audiocraft import train

    sig_prefix = "//sig/"
    xps_root = train.main.dora.dir / "xps"
    reference = str(sig_or_path)
    if reference.startswith(sig_prefix):
        candidate = xps_root / reference[len(sig_prefix):]
    else:
        candidate = AudioCraftEnvironment.resolve_reference_path(Path(reference))

    if candidate.is_dir():
        candidate = candidate / checkpoint_name(name, use_fsdp=use_fsdp)

    return candidate if candidate.exists() else None


def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
    """Load the state stored at `checkpoint_path`, mapping tensors to CPU.

    For a sharded checkpoint, if a rank-zero checkpoint (as named by `checkpoint_name`
    without FSDP and without name suffix) exists in the same directory, the shard is
    first validated with `check_sharded_checkpoint`.

    Args:
        checkpoint_path (Path): Path of the checkpoint file to load.
        is_sharded (bool): Whether the checkpoint is sharded across ranks.
    Returns:
        Any: The object returned by `torch.load`.
    """
    if is_sharded:
        rank_zero_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
        if rank_zero_path.exists():
            check_sharded_checkpoint(checkpoint_path, rank_zero_path)
    # NOTE(review): depending on the torch version and its `weights_only` default,
    # torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    loaded_state = torch.load(checkpoint_path, map_location="cpu")
    logger.info("Checkpoint loaded from %s", checkpoint_path)
    return loaded_state


def save_checkpoint(
    state: tp.Any, checkpoint_path: Path, is_sharded: bool = False
) -> None:
    """Write `state` to `checkpoint_path` and log the destination.

    The actual writing is delegated to `_safe_save_checkpoint`.

    Args:
        state (Any): Object to serialize.
        checkpoint_path (Path): Destination file of the checkpoint.
        is_sharded (bool): Whether the checkpoint is sharded across ranks.
    """
    _safe_save_checkpoint(state, checkpoint_path, is_sharded=is_sharded)
    logger.info("Checkpoint saved to %s", checkpoint_path)


def flush_stale_checkpoints(
    checkpoint_path: Path, keep_last: tp.Optional[int] = None
) -> None:
    """Delete the oldest epoch checkpoints so that at most `keep_last` of them remain.

    Only files located in the parent directory of `checkpoint_path` and named
    `checkpoint_<epoch>.th` are considered, where `<epoch>` is a decimal number.
    On ranks > 0, the files considered are those carrying the `.<rank>` extension
    of the current rank instead. Other checkpoints in the directory are left alone.

    Args:
        checkpoint_path (Path): Path of a checkpoint; only its parent directory is used.
        keep_last (int, optional): Number of most recent epoch checkpoints to keep.
            Nothing is deleted when it is None or not strictly positive.
    """
    if keep_last is None or keep_last <= 0:
        return
    rank = flashy.distrib.rank()
    suffix = f".{rank}" if rank > 0 else ""
    epoch_checkpoints = []
    for path in checkpoint_path.parent.glob(f"checkpoint_*.th{suffix}"):
        # The glob guarantees the `checkpoint_` prefix, so both splits are safe.
        epoch_part = path.name.split(".", 1)[0].split("_", 1)[1]
        # isdecimal() only accepts characters that int() can parse, whereas
        # isdigit() also accepts e.g. superscript digits, on which int() raises.
        if epoch_part.isdecimal():
            epoch_checkpoints.append((int(epoch_part), path))
    # Oldest epochs first; the sort is stable for files sharing an epoch number.
    epoch_checkpoints.sort(key=lambda item: item[0])
    total_to_flush = max(0, len(epoch_checkpoints) - keep_last)
    for _, path in epoch_checkpoints[:total_to_flush]:
        logger.debug("Removing checkpoint: %s", str(path))
        path.unlink(missing_ok=True)


def check_sharded_checkpoint(
    checkpoint_path: Path, rank0_checkpoint_path: Path
) -> None:
    """Validate a sharded checkpoint and complete a save that was interrupted.

    If the `<rank0_checkpoint_path>.tmp.done` token exists together with a
    `<checkpoint_path>.tmp` file, the temporary file is renamed to `checkpoint_path`.
    After a barrier across ranks, rank zero removes the token.

    Args:
        checkpoint_path (Path): Path of the checkpoint shard for the current rank.
        rank0_checkpoint_path (Path): Path of the rank-zero checkpoint, next to
            which the completion token is looked up.
    Raises:
        RuntimeError: If a `<checkpoint_path>.old` file is found.
    """
    old_path = Path(f"{checkpoint_path}.old")
    if old_path.exists():
        raise RuntimeError(
            f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed."
        )
    done_token = Path(f"{rank0_checkpoint_path}.tmp.done")
    pending_tmp = Path(f"{checkpoint_path}.tmp")
    # Finish the work of a previous run that got interrupted while dumping.
    if done_token.exists() and pending_tmp.exists():
        pending_tmp.rename(checkpoint_path)
    # Every rank must be done looking at the token before rank zero deletes it.
    flashy.distrib.barrier()
    if flashy.distrib.is_rank_zero() and done_token.exists():
        done_token.unlink()


def _safe_save_checkpoint(
    state: tp.Any, checkpoint_path: Path, is_sharded: bool = False
) -> None:
    """Save `state` to `checkpoint_path`, synchronizing ranks when the checkpoint is sharded.

    Rank zero maintains a `<checkpoint_path>.tmp.done` token file: it is removed
    before writing, created after `torch.save` has been called (and, when sharded,
    after all ranks reached the following barrier), and removed again at the end.
    `check_sharded_checkpoint` looks for a token with the same naming scheme to
    complete a save that was interrupted.

    Args:
        state (Any): Object to serialize with `torch.save`.
        checkpoint_path (Path): Destination file of the checkpoint.
        is_sharded (bool): Whether the checkpoint is sharded across ranks; barriers
            are only issued in that case.
    """

    def _barrier_if_sharded():
        """Synchronize all ranks, but only when saving a sharded checkpoint."""
        if is_sharded:
            flashy.distrib.barrier()

    # `token` is only bound on rank zero; every later use is guarded by a rank-zero check.
    if flashy.distrib.is_rank_zero():
        token = Path(str(checkpoint_path) + ".tmp.done")
        # Drop any token left over from an earlier save before writing anything new.
        if token.exists():
            token.unlink()
    _barrier_if_sharded()
    # NOTE(review): flashy.utils.write_and_rename presumably writes to a temporary file
    # and renames it to `checkpoint_path` when the block exits — confirm against flashy.
    with flashy.utils.write_and_rename(checkpoint_path) as f:
        torch.save(state, f)
        # When sharded, wait for every rank to be done with torch.save before
        # rank zero creates the token.
        _barrier_if_sharded()
        if flashy.distrib.is_rank_zero():
            token.touch()
        # When sharded, no rank leaves the `with` block before the token exists.
        _barrier_if_sharded()
    # When sharded, wait for every rank to leave the `with` block before removing the token.
    _barrier_if_sharded()
    # NOTE(review): `rank() == 0` is presumably equivalent to `is_rank_zero()` used above.
    if flashy.distrib.rank() == 0:
        token.unlink()
