# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Utilities for syncing TorchTitan checkpoints with S3."""

from __future__ import annotations

import json
import threading
from collections import deque
from pathlib import Path
from typing import Any, TYPE_CHECKING

from composer.loggers import RemoteUploaderDownloader
from composer.loggers.remote_uploader_downloader import _upload_worker
from composer.utils.file_helpers import list_remote_objects

from torchtitan.tools.logging import logger

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable

    from torchtitan.components.checkpoint import CheckpointManager

    from .configs.config import MosaicJobConfig, S3CheckpointingConfig

__all__ = [
    "S3CheckpointWrapper",
    "create_remote_up_down",
    "download_file_from_s3",
    "get_s3_checkpoint_wrapper_factory",
    "setup_s3_checkpointing",
    "upload_file_to_s3",
]

# JSON list of a step's file paths (relative to the step directory); written into
# each step directory before upload and fetched first when downloading.
MANIFEST_FILENAME = "s3_manifest.json"
# Marker file written under the checkpoint root whose content is the newest
# uploaded step number; also uploaded so later runs can find that step remotely.
LATEST_FILENAME = "s3_latest.txt"

# Constants for validation
# ``resume_from_run_step`` is expected as "{run_uuid}/step-{N}", i.e. two "/"-separated parts.
RESUME_FORMAT_PARTS_COUNT = 2
# Prefix of step directory names such as "step-10".
STEP_PREFIX = "step-"
# Upper bound on how many downloaded paths are listed in the post-download log.
MAX_FILES_TO_DISPLAY = 20


def download_file_from_s3(
    remote_up_down: RemoteUploaderDownloader,
    remote_file_name: str,
    local_file_name: Path | str,
) -> None:
    """Fetch ``remote_file_name`` from S3 into ``local_file_name``, overwriting it.

    The uploader's ``_check_workers`` hook is invoked before the transfer
    (presumably so failures from background upload workers surface here).

    Args:
        remote_up_down: Configured uploader/downloader to transfer through.
        remote_file_name: Object key relative to the uploader's configured prefix.
        local_file_name: Destination path; converted to ``str`` for the call.
    """
    destination = str(local_file_name)
    remote_up_down._check_workers()
    remote_up_down.download_file(remote_file_name=remote_file_name, destination=destination, overwrite=True)


def upload_file_to_s3(
    remote_up_down: RemoteUploaderDownloader,
    remote_file_name: str,
    local_file_name: Path,
) -> None:
    """Hand ``local_file_name`` to the uploader under key ``remote_file_name``.

    Existing remote objects are overwritten. ``state`` is passed as ``None``
    because no Composer training state is involved. The uploader's
    ``_check_workers`` hook runs first (presumably to surface earlier worker
    failures).

    Args:
        remote_up_down: Configured uploader/downloader to transfer through.
        remote_file_name: Object key relative to the uploader's configured prefix.
        local_file_name: Local file to upload.
    """
    upload_kwargs = {
        "state": None,
        "remote_file_name": remote_file_name,
        "file_path": local_file_name,
        "overwrite": True,
    }
    remote_up_down._check_workers()
    remote_up_down.upload_file(**upload_kwargs)


def create_remote_up_down(  # noqa: PLR0913
    bucket_name: str,
    prefix: str,
    num_attempts: int,
    client_config: dict[str, Any],
    *,
    num_concurrent_uploads: int = 1,
    upload_staging_folder: str | None = None,
    use_procs: bool = True,
) -> RemoteUploaderDownloader:
    """Construct a :class:`RemoteUploaderDownloader` bound to ``s3://<bucket_name>``.

    Args:
        bucket_name: S3 bucket; used for both the bucket URI and the backend ``bucket``.
        prefix: Forwarded as the backend ``prefix``.
        num_attempts: Forwarded as ``num_attempts``.
        client_config: Forwarded as the backend ``client_config``.
        num_concurrent_uploads: Forwarded as ``num_concurrent_uploads``.
        upload_staging_folder: Forwarded unchanged as ``upload_staging_folder``.
        use_procs: Forwarded as ``use_procs``.

    Returns:
        The constructed uploader/downloader. Remote file names are used verbatim
        (format string ``"{remote_file_name}"``).
    """
    # Region, endpoint, credentials and transfer config are deliberately left as
    # None here; this function does not configure them.
    unset_backend_keys = (
        "region_name",
        "endpoint_url",
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
    )
    backend_kwargs: dict[str, Any] = {
        "bucket": bucket_name,
        "prefix": prefix,
        **dict.fromkeys(unset_backend_keys),
        "client_config": client_config,
        "transfer_config": None,
    }
    return RemoteUploaderDownloader(
        bucket_uri=f"s3://{bucket_name}",
        backend_kwargs=backend_kwargs,
        file_path_format_string="{remote_file_name}",
        num_concurrent_uploads=num_concurrent_uploads,
        upload_staging_folder=upload_staging_folder,
        use_procs=use_procs,
        num_attempts=num_attempts,
    )


class S3CheckpointWrapper:
    """Synchronise checkpoints produced by a :class:`CheckpointManager` with S3.

    The wrapper composes with the wrapped checkpointer via delegation. It overrides
    a subset of the public API (``save``/``maybe_wait_for_staging``/``close``)
    while forwarding all other attributes to the inner manager. This keeps the
    original instance untouched so other references (e.g. TorchFT internals)
    continue to observe the same object graph.
    """

    def __init__(
        self,
        checkpointer: CheckpointManager,
        config: S3CheckpointingConfig,
        job_config: MosaicJobConfig,
        *,
        enable_uploads: bool = True,
    ) -> None:
        self._checkpointer = checkpointer
        self.config = config
        self.job_config = job_config
        self._enable_uploads = enable_uploads
        self.remote_root = self._resolve_remote_root()
        self.remote_up_down = create_remote_up_down(
            bucket_name=config.bucket,
            prefix=config.prefix,
            num_attempts=config.num_attempts,
            client_config=config.client_config,
            num_concurrent_uploads=config.num_concurrent_uploads,
            upload_staging_folder=config.upload_staging_folder,
            use_procs=config.use_procs,
        )
        # Set the run name for the RemoteUploaderDownloader
        # This is normally set in init() but we're using it standalone
        run_name = (
            config.run_uuid or job_config.job.description or Path(job_config.job.dump_folder).name or "torchtitan-run"
        )
        self.remote_up_down._run_name = str(run_name)

        self._base_folder = Path(checkpointer.folder)
        self._ft_mode = bool(getattr(checkpointer, "ft_manager", None))
        self._ft_folder_path: Path | None = None
        self._ft_relative: Path | None = None
        if self._ft_mode:
            ft_folder_str = checkpointer._ft_folder()
            self._ft_folder_path = Path(ft_folder_str)
            try:
                self._ft_relative = self._ft_folder_path.relative_to(self._base_folder)
            except ValueError:
                self._ft_relative = Path(self._ft_folder_path.name)

        self._pending_steps: deque[tuple[int, Path]] = deque()
        self._uploaded_steps: set[int] = set()
        self._latest_uploaded_step: int | None = None
        self._closed = False

        # Install tracking
        self._missing_directory_steps: set[int] = set()
        self._not_ready_steps: set[int] = set()
        self._missing_metadata_steps: set[int] = set()
        self._orig_save = checkpointer.save
        self._orig_maybe_wait = checkpointer.maybe_wait_for_staging
        self._orig_close = checkpointer.close

        if self._enable_uploads:
            self._start_remote_workers()

    @property
    def checkpointer(self) -> CheckpointManager:
        """Get the underlying CheckpointManager instance."""
        return self._checkpointer

    def __getattr__(self, name: str) -> Any:
        """Proxy attribute access to the underlying CheckpointManager."""
        if hasattr(self._checkpointer, name):
            return getattr(self._checkpointer, name)
        msg = f"'{type(self).__name__}' proxy: '{type(self._checkpointer).__name__}' object has no attribute '{name}'"
        raise AttributeError(msg)

    def attach_to_trainer(self, trainer: Any) -> None:
        """Replace ``trainer.checkpointer`` with this wrapper."""
        trainer.checkpointer = self

    def install_onto_checkpointer(self) -> None:
        """Patch the wrapped checkpointer for legacy compatibility."""
        existing = getattr(self._checkpointer, "_s3_wrapper", None)
        if existing is self:
            return
        if existing is not None and existing is not self:
            logger.warning(
                "Replacing existing S3 checkpoint wrapper on %s",
                type(self._checkpointer).__name__,
            )
        self._checkpointer._s3_wrapper = self
        self._checkpointer.save = self.save  # type: ignore[assignment]
        self._checkpointer.maybe_wait_for_staging = self.maybe_wait_for_staging  # type: ignore[assignment]
        self._checkpointer.close = self.close  # type: ignore[assignment]

    def _start_remote_workers(self) -> None:
        """Start the RemoteUploaderDownloader background workers."""
        rud = self.remote_up_down

        if rud._worker_flag is not None:
            return  # Already initialized

        rud._worker_flag = rud._finished_cls()

        # Create the enqueue thread
        rud._enqueue_thread_flag = rud._finished_cls()
        rud._enqueue_thread = threading.Thread(target=rud._enqueue_uploads, daemon=True)
        rud._enqueue_thread.start()

        # Start the upload workers
        for _ in range(rud._num_concurrent_uploads):
            worker = rud._proc_class(
                target=_upload_worker,
                kwargs={
                    "file_queue": rud._file_upload_queue,
                    "is_finished": rud._worker_flag,
                    "remote_backend_name": rud.remote_backend_name,
                    "backend_kwargs": rud.backend_kwargs,
                    "num_attempts": rud.num_attempts,
                    "completed_queue": rud._completed_queue,
                    "exception_queue": rud._exception_queue,
                },
                daemon=True,
            )
            worker.start()
            rud._workers.append(worker)

        logger.info("Started %d S3 upload workers", rud._num_concurrent_uploads)

    def __del__(self) -> None:
        """Clean up resources on object destruction."""
        self.close()

    def download_if_needed(self) -> None:
        """Optionally download a checkpoint from S3 before training starts.

        If resume_from_run_step is set, downloads from that specific run/step.
        Otherwise, looks for the latest checkpoint in the current run.
        """
        if not self.config.download_on_start:
            logger.info("S3 download skipped: download_on_start=False")
            return

        if self._ft_mode and self._ft_folder_path is not None:
            base_folder = self._ft_folder_path
            local_latest = self.checkpointer._find_load_step(folder=str(self._ft_folder_path))
        else:
            base_folder = self._base_folder
            local_latest = self._find_local_latest_step()
        logger.info(
            "Checking for local checkpoints in: %s (found step: %s)",
            base_folder,
            local_latest if local_latest != -1 else "none",
        )
        if local_latest != -1:
            logger.info("Skipping S3 download: local checkpoint found at step %s", local_latest)
            return

        # Determine what to download
        if self.config.resume_from_run_step:
            # Parse format: "{run_uuid}/step-{N}"
            try:
                parts = self.config.resume_from_run_step.split("/")
                if len(parts) != RESUME_FORMAT_PARTS_COUNT or not parts[1].startswith(STEP_PREFIX):
                    self._raise_invalid_resume_format()
                run_uuid = parts[0]
                step_str = parts[1][len(STEP_PREFIX) :]  # Remove "step-" prefix
                step = int(step_str)
                remote_path = f"torchtitan/{run_uuid}"
                relative_suffix = Path(f"step-{step}")
                if self._ft_relative is not None:
                    relative_suffix = self._ft_relative / relative_suffix
                remote_preview = f"{remote_path}/{relative_suffix.as_posix()}"

                prefix_display = (self.config.prefix or "").strip("/")
                components = [comp for comp in (prefix_display, remote_preview) if comp]
                combined_path = "/".join(components)
                logger.info(
                    "Resuming from run step: %s (downloading from: s3://%s/%s)",
                    self.config.resume_from_run_step,
                    self.config.bucket,
                    combined_path,
                )
            except (ValueError, IndexError) as e:
                logger.exception(
                    "Failed to parse resume_from_run_step: %s",
                    self.config.resume_from_run_step,
                )
                self._raise_invalid_resume_format(e)
        else:
            # Look for latest in current run
            remote_path = self.remote_root
            step = self._read_remote_latest_step() or -1

            logger.info(
                "Continuing current run: %s (looking for latest in: s3://%s/%s)",
                self.config.run_uuid,
                self.config.bucket,
                remote_path,
            )

            if step == -1:
                logger.info("No remote checkpoint available for download.")
                return

        try:
            self._download_step(step, remote_path=remote_path)
            logger.info("Downloaded checkpoint step %s from S3.", step)
        except Exception:
            logger.exception("Failed to download checkpoint step %s from S3.", step)
            raise

    def close(self) -> None:
        """Flush pending uploads and release remote resources."""
        if self._closed:
            return
        try:
            self._wait_for_staging_with_logging()
            if self._enable_uploads:
                self._process_pending(flush=True)
            # Note: RemoteUploaderDownloader cleanup is handled by Composer internally
            self._orig_close()
        finally:
            self._closed = True

    def _wait_for_staging_with_logging(self) -> None:
        try:
            self._orig_maybe_wait()
        except Exception:  # noqa: BLE001
            logger.exception("Failed while waiting for staged checkpoints before upload.")
        if self._enable_uploads:
            self._process_pending(flush=True)
        # Note: RemoteUploaderDownloader cleanup is handled by Composer internally

    def _resolve_remote_root(self) -> str:
        root = self.config.remote_checkpoint_folder or self.job_config.checkpoint.folder
        return root.strip("/")

    def _checkpoint_dir(self, step: int) -> Path:
        if self._enable_uploads and self._ft_mode and self._ft_folder_path is not None:
            return self._ft_folder_path / f"step-{step}"
        return self._base_folder / f"step-{step}"

    def _raise_invalid_resume_format(self, cause: Exception | None = None) -> None:
        """Raise a ValueError for invalid resume_from_run_step format.

        Args:
            cause: Optional exception that caused this error
        """
        msg = (
            f"Invalid resume_from_run_step format: '{self.config.resume_from_run_step}'. "
            "Expected format: '{{run_uuid}}/step-{{N}}' (e.g., '16M-baseline-20251011-122516/step-10')"
        )
        if cause:
            raise ValueError(msg) from cause
        raise ValueError(msg)

    def _remote_key(self, relative_path: Path, remote_root: str | None = None) -> str:
        """Get the remote S3 key.

        Args:
            relative_path: Path relative to the checkpoint folder
            remote_root: Optional override for the remote root. If not provided, uses self.remote_root
        """
        relative_key = relative_path.as_posix()
        root = remote_root if remote_root is not None else self.remote_root
        if root:
            return f"{root}/{relative_key}"
        return relative_key

    def _remote_listing_uri(self, remote_key: str) -> str:
        """Build a fully-qualified S3 URI for listing remote objects."""
        bucket = self.config.bucket.strip("/")
        prefix = (self.config.prefix or "").strip("/")
        remote_key = remote_key.strip("/")
        path_components = [part for part in (prefix, remote_key) if part]
        uri = f"s3://{bucket}"
        if path_components:
            uri = f"{uri}/{'/'.join(path_components)}"
        if not uri.endswith("/"):
            uri += "/"
        return uri

    def _remote_listing_prefix(self, remote_key: str) -> str:
        """Compute the prefix used when listing remote objects."""
        prefix = (self.config.prefix or "").strip("/")
        remote_key = remote_key.strip("/")
        path_components = [part for part in (prefix, remote_key) if part]
        if not path_components:
            return ""
        return "/".join(path_components).rstrip("/") + "/"

    def _enumerate_remote_step_files(
        self,
        remote_root: str,
        candidate_relatives: list[Path],
    ) -> tuple[list[str], Path] | tuple[None, None]:
        """Enumerate checkpoint files directly from S3 when the manifest is missing."""
        for relative_base in candidate_relatives:
            remote_key = self._remote_key(relative_base, remote_root=remote_root)
            listing_uri = self._remote_listing_uri(remote_key)
            logger.info(
                "Manifest %s not found; enumerating checkpoint files via %s",
                MANIFEST_FILENAME,
                listing_uri,
            )
            try:
                object_keys = list_remote_objects(listing_uri)
            except Exception as err:  # noqa: BLE001
                logger.warning(
                    "Failed to list remote objects at %s (%s); trying alternate layout.",
                    listing_uri,
                    err,
                )
                continue

            listing_prefix = self._remote_listing_prefix(remote_key)
            entries: list[str] = []
            for key in object_keys:
                candidate = key
                if listing_prefix:
                    if not key.startswith(listing_prefix):
                        continue
                    candidate = key[len(listing_prefix) :]
                candidate = candidate.strip("/")
                if not candidate:
                    continue
                entries.append(candidate)

            if entries:
                deduped_entries = sorted(set(entries))
                logger.info(
                    "Identified %d file(s) for checkpoint step using manifestless discovery.",
                    len(deduped_entries),
                )
                return deduped_entries, relative_base
            logger.warning(
                "No checkpoint files found when listing %s; trying alternate layout.",
                listing_uri,
            )
        return None, None

    def save(self, curr_step: int, *, last_step: bool = False) -> None:
        """Save checkpoint and queue for S3 upload.

        Args:
            curr_step: Current training step
            last_step: Whether this is the final checkpoint
        """
        self._orig_save(curr_step, last_step=last_step)
        if not self._enable_uploads:
            return
        checkpoint_dir = self._checkpoint_dir(curr_step)
        if not checkpoint_dir.exists():
            logger.warning(
                "Checkpoint directory %s for step %s does not exist immediately after save; "
                "upload will be retried once it becomes available.",
                checkpoint_dir,
                curr_step,
            )
        self._pending_steps.append((curr_step, checkpoint_dir))
        if last_step:
            try:
                self._orig_maybe_wait()
            except Exception:  # noqa: BLE001
                logger.exception("Failed while waiting for staged checkpoints before final upload.")
            self._process_pending(flush=True)

    def maybe_wait_for_staging(self) -> None:
        """Wait for staged checkpoints and process pending uploads."""
        self._orig_maybe_wait()
        if self._enable_uploads:
            self._process_pending()

    def _process_pending(self, flush: bool = False) -> None:  # noqa: FBT001, FBT002
        pending: deque[tuple[int, Path]] = deque()
        while self._pending_steps:
            step, directory = self._pending_steps.popleft()
            if step in self._uploaded_steps:
                continue
            if not directory.exists():
                if flush:
                    logger.error(
                        "Checkpoint directory %s for step %s does not exist and will not be uploaded during flush.",
                        directory,
                        step,
                    )
                else:
                    if step not in self._missing_directory_steps:
                        self._missing_directory_steps.add(step)
                    pending.append((step, directory))
                continue
            self._missing_directory_steps.discard(step)
            if not self._is_directory_ready_for_upload(step, directory):
                if flush:
                    logger.error(
                        "Checkpoint directory %s for step %s is not ready for upload during flush and will be skipped.",
                        directory,
                        step,
                    )
                else:
                    if step not in self._not_ready_steps:
                        logger.info(
                            "Checkpoint directory %s for step %s is still being written; deferring upload.",
                            directory,
                            step,
                        )
                        self._not_ready_steps.add(step)
                    pending.append((step, directory))
                continue
            self._not_ready_steps.discard(step)
            self._upload_step(step, directory)
        self._pending_steps = pending

    def _is_directory_ready_for_upload(self, step: int, directory: Path) -> bool:
        try:
            entries = list(directory.iterdir())
        except FileNotFoundError:
            return False
        if not entries:
            return False
        try:
            has_temp_files = any(directory.rglob("*.tmp"))
        except FileNotFoundError:
            return False
        if has_temp_files:
            return False

        files = [path for path in entries if path.is_file()]
        has_distcp = any(path.suffix == ".distcp" for path in files)
        metadata_path = directory / ".metadata"
        if has_distcp and not metadata_path.exists():
            if step not in self._missing_metadata_steps:
                logger.info(
                    "Checkpoint step %s is waiting for metadata file before upload (expected at %s).",
                    step,
                    metadata_path,
                )
                self._missing_metadata_steps.add(step)
            return False
        self._missing_metadata_steps.discard(step)

        has_safetensors = any(path.suffix == ".safetensors" for path in files)
        if has_safetensors:
            index_file = directory / "model.safetensors.index.json"
            if not index_file.exists():
                if step not in self._missing_metadata_steps:
                    logger.info(
                        "Checkpoint step %s is waiting for safetensors index file before upload (expected at %s).",
                        step,
                        index_file,
                    )
                    self._missing_metadata_steps.add(step)
                return False
            self._missing_metadata_steps.discard(step)

        return True

    def _upload_step(self, step: int, directory: Path) -> None:
        files = sorted(self._iter_checkpoint_files(directory))
        if not files:
            return

        manifest_path = directory / MANIFEST_FILENAME
        manifest_content = [path.relative_to(directory).as_posix() for path in files]
        manifest_path.write_text(json.dumps(manifest_content))

        upload_targets = [*files, manifest_path]
        uploaded_paths = []
        for file_path in upload_targets:
            relative = file_path.relative_to(Path(self.checkpointer.folder))
            remote_key = self._remote_key(relative)
            upload_file_to_s3(self.remote_up_down, remote_key, file_path)
            # Log the full S3 path
            s3_uri = f"s3://{self.config.bucket}/{remote_key}"
            uploaded_paths.append(s3_uri)
            logger.info("Uploaded: %s -> %s", file_path, s3_uri)

        self._uploaded_steps.add(step)
        if self._latest_uploaded_step is None or step > self._latest_uploaded_step:
            self._latest_uploaded_step = step
            self._write_latest_marker(step)
        logger.info("Uploaded checkpoint step %s to S3 (%s files).", step, len(files))
        logger.info("All uploaded files for step %s: %s", step, uploaded_paths)

    def _iter_checkpoint_files(self, directory: Path) -> Iterable[Path]:
        for path in directory.rglob("*"):
            if path.is_file() and path.name != MANIFEST_FILENAME:
                yield path

    def _write_latest_marker(self, step: int) -> None:
        marker_path = Path(self.checkpointer.folder) / LATEST_FILENAME
        marker_path.write_text(f"{step}\n")
        remote_key = self._remote_key(Path(LATEST_FILENAME))
        s3_uri = f"s3://{self.config.bucket}/{remote_key}"
        upload_file_to_s3(self.remote_up_down, remote_key, marker_path)
        logger.info(
            "Uploaded latest marker: %s -> %s (points to step %s)",
            marker_path,
            s3_uri,
            step,
        )

    def _find_local_latest_step(self) -> int:
        try:
            return self.checkpointer._find_load_step()
        except Exception:  # noqa: BLE001
            return -1

    def _read_remote_latest_step(self, remote_root: str | None = None) -> int | None:
        """Read the latest checkpoint step from S3.

        Args:
            remote_root: Optional override for the remote root. If not provided, uses self.remote_root
        """
        marker_path = Path(self.checkpointer.folder) / LATEST_FILENAME
        try:
            marker_path.parent.mkdir(parents=True, exist_ok=True)
            download_file_from_s3(
                self.remote_up_down,
                self._remote_key(Path(LATEST_FILENAME), remote_root=remote_root),
                marker_path,
            )
        except Exception:  # noqa: BLE001
            return None
        try:
            return int(marker_path.read_text().strip())
        except ValueError:
            logger.warning("Invalid latest checkpoint marker downloaded from S3.")
            return None

    def _download_step(self, step: int, remote_path: str) -> None:
        """Download a specific checkpoint step from S3.

        Args:
            step: The checkpoint step number to download
            remote_path: The remote S3 path prefix (e.g., "torchtitan/16M-baseline-20251011-122516")
        """
        checkpoint_dir = self._checkpoint_dir(step)
        logger.info("Downloading checkpoint step %s to: %s", step, checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        manifest_path = checkpoint_dir / MANIFEST_FILENAME

        candidate_relatives: list[Path] = []
        base_relative = Path(f"step-{step}")
        if self._ft_relative is not None:
            candidate_relatives.append(self._ft_relative / base_relative)
        candidate_relatives.append(base_relative)

        manifest_entries: list[str] | None = None
        chosen_relative_base: Path | None = None
        last_error: Exception | None = None

        for idx, relative_base in enumerate(candidate_relatives):
            remote_manifest_key = self._remote_key(relative_base / MANIFEST_FILENAME, remote_root=remote_path)
            logger.info(
                "Downloading manifest from S3: s3://%s/%s/%s",
                self.config.bucket,
                self.config.prefix,
                remote_manifest_key,
            )
            try:
                download_file_from_s3(
                    self.remote_up_down,
                    remote_manifest_key,
                    manifest_path,
                )
                manifest_entries = json.loads(manifest_path.read_text())
                chosen_relative_base = relative_base
                if idx > 0:
                    logger.info("Checkpoint manifest located without TorchFT replica directory; using fallback layout.")
                break
            except Exception as err:  # noqa: BLE001
                last_error = err
                logger.warning(
                    "Failed to download manifest from %s (attempt %d/%d).",
                    remote_manifest_key,
                    idx + 1,
                    len(candidate_relatives),
                )
                manifest_path.unlink(missing_ok=True)

        if manifest_entries is None or chosen_relative_base is None:
            logger.info(
                "Manifest %s could not be retrieved from S3; attempting manifestless download for step %s.",
                MANIFEST_FILENAME,
                step,
            )
            fallback_entries, fallback_relative = self._enumerate_remote_step_files(remote_path, candidate_relatives)
            if fallback_entries is None or fallback_relative is None:
                if last_error is not None:
                    raise last_error
                msg = (
                    f"Unable to locate checkpoint files for step {step} using manifestless discovery. "
                    "Ensure the remote path contains uploaded checkpoint artifacts."
                )
                raise FileNotFoundError(msg)
            manifest_entries = fallback_entries
            chosen_relative_base = fallback_relative
            manifest_path.write_text(json.dumps(manifest_entries))

        logger.info("Manifest contains %d files to download", len(manifest_entries))

        for relative in manifest_entries:
            relative_path = Path(relative)
            local_path = checkpoint_dir / relative_path
            local_path.parent.mkdir(parents=True, exist_ok=True)
            download_file_from_s3(
                self.remote_up_down,
                self._remote_key(
                    chosen_relative_base / relative_path,
                    remote_root=remote_path,
                ),
                local_path,
            )
        logger.info(
            "Successfully downloaded all %d files for step %s",
            len(manifest_entries),
            step,
        )

        # Verify the checkpoint directory structure
        metadata_file = checkpoint_dir / ".metadata"
        distcp_shards = list(checkpoint_dir.glob("*.distcp"))
        if metadata_file.exists():
            logger.info("✓ Checkpoint metadata file exists: %s", metadata_file)
        elif distcp_shards:
            logger.info(
                "Checkpoint metadata file not found for step %s; detected %d distcp shard(s).",
                step,
                len(distcp_shards),
            )
        else:
            logger.error("✗ Checkpoint metadata file MISSING: %s", metadata_file)

        # List all downloaded files for verification
        all_files = list(checkpoint_dir.rglob("*"))
        logger.info(
            "Downloaded checkpoint contains %d total paths (files + directories)",
            len(all_files),
        )
        logger.info("Checkpoint directory structure:")
        for path in sorted(all_files)[:MAX_FILES_TO_DISPLAY]:
            logger.info("  - %s", path.relative_to(checkpoint_dir))
        if len(all_files) > MAX_FILES_TO_DISPLAY:
            logger.info("  ... and %d more", len(all_files) - MAX_FILES_TO_DISPLAY)

        self._latest_uploaded_step = max(self._latest_uploaded_step or -1, step)
        self._write_latest_marker(step)
        logger.info(
            "✓ Checkpoint download complete for step %s. Native checkpointer should now load it.",
            step,
        )


def setup_s3_checkpointing(
    checkpointer: CheckpointManager,
    job_config: MosaicJobConfig,
    *,
    install: bool = True,
) -> S3CheckpointWrapper | None:
    """Wrap ``checkpointer`` in an :class:`S3CheckpointWrapper` when S3 checkpointing is configured.

    Retained for backwards compatibility; new call sites should obtain a factory
    through :func:`get_s3_checkpoint_wrapper_factory` instead. ``install`` does
    double duty. It is passed as the wrapper's ``enable_uploads``, and when true
    the wrapper is also patched onto ``checkpointer`` so existing references to it
    pick up the S3 synchronisation behaviour.

    Returns:
        The wrapper, or ``None`` when the factory helper reports S3 checkpointing
        as disabled or misconfigured.
    """
    make_wrapper = get_s3_checkpoint_wrapper_factory(job_config)
    wrapper = None if make_wrapper is None else make_wrapper(checkpointer, enable_uploads=install)
    if wrapper is not None and install:
        wrapper.install_onto_checkpointer()
    return wrapper


def get_s3_checkpoint_wrapper_factory(
    job_config: MosaicJobConfig,
) -> Callable[[CheckpointManager, bool], S3CheckpointWrapper] | None:
    """Return a factory producing :class:`S3CheckpointWrapper` instances.

    Args:
        job_config: The Mosaic job configuration.

    Returns:
        A callable that accepts a :class:`CheckpointManager` and a boolean flag
        (positionally or as ``enable_uploads=``) indicating whether uploads should
        be enabled. ``None`` is returned when S3 checkpointing is disabled, or when
        it is enabled but ``bucket`` is empty or ``prefix`` is ``None`` (a warning
        is logged in that case).
    """
    config = job_config.s3_checkpoint
    if not config.enable:
        return None
    if not config.bucket or config.prefix is None:
        logger.warning("S3 checkpointing is enabled but bucket or prefix is not provided; skipping.")
        return None

    # ``enable_uploads`` is positional-or-keyword so the factory matches the advertised
    # ``Callable[[CheckpointManager, bool], ...]`` type while keyword callers keep working.
    def factory(
        checkpointer: CheckpointManager,
        enable_uploads: bool = True,  # noqa: FBT001, FBT002
    ) -> S3CheckpointWrapper:
        """Wrap ``checkpointer`` using the captured S3 and job configuration."""
        return S3CheckpointWrapper(
            checkpointer,
            config,
            job_config,
            enable_uploads=enable_uploads,
        )

    return factory
