import json
import logging
import os
import tempfile
import time
from collections import OrderedDict
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import fsspec
import hydra
import omegaconf
import torch
from torch.distributed.checkpoint.state_dict import (
    get_state_dict,
    set_state_dict,
    StateDictOptions,
)
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import StorageWriter
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
)
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer


@dataclass
class Snapshot:
    """Everything needed to resume training from a saved checkpoint.

    Fields are declared in constructor order; because callers may build a
    ``Snapshot`` positionally, that order is part of the interface.
    """

    # Model parameters/buffers keyed by state-dict name.
    model_state: "OrderedDict[str, torch.Tensor]"
    # Optimizer state as produced by ``state_dict()`` / ``get_state_dict``.
    optimizer_state: Dict[str, Any]
    # LR scheduler state; ``None`` when no scheduler state is stored.
    lr_scheduler_state: Optional[Dict[str, Any]]
    # Epoch number recorded when the snapshot was written.
    finished_epoch: int
    # Running total of training samples processed.
    n_seen_points: int = field(default=0)
    # Plain-dict training configuration stored with the snapshot.
    config: Dict[str, Any] = field(default_factory=dict)
    # State of a separately managed TruthTableEncoder, if any.
    tt_embedding_state: Optional["OrderedDict[str, torch.Tensor]"] = field(default=None)


def save_snapshot(
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: LRScheduler,
    epoch: int,
    path: str,
    n_seen_points: int = 0,
    config: Optional[Dict[str, Any]] = None,
    writer: StorageWriter | None = None,
    tt_embedding: Optional[torch.nn.Module] = None,
) -> None:
    """Save a training snapshot, preferring the distributed checkpoint (DCP) format.

    The snapshot is first written with ``dcp.save`` to ``path`` with its
    extension stripped (``ckpt.pt`` -> directory ``ckpt``). If that fails, it
    falls back to a single ``torch.save`` file at ``path`` itself.

    Args:
        model: Model to snapshot. FSDP-wrapped models go through ``get_state_dict``.
        optimizer: Optimizer whose state is stored alongside the model.
        lr_scheduler: LR scheduler to snapshot; ``None`` stores
            ``lr_scheduler_state=None``.
        epoch: Value stored as ``finished_epoch``.
        path: Target path; see above for how the DCP directory is derived.
        n_seen_points: Number of training samples seen so far.
        config: Training config as an ``omegaconf.DictConfig`` or a plain ``dict``.
        writer: Optional custom DCP ``StorageWriter`` passed to ``dcp.save``.
        tt_embedding: Optional TruthTableEncoder whose state is stored separately.

    Raises:
        ValueError: If ``config`` is neither a ``DictConfig`` nor a ``dict``.
        Exception: The original ``dcp.save`` error when the ``torch.save``
            fallback also fails.
    """
    from dataclasses import fields  # only needed for the shallow conversion below

    log = logging.getLogger(__name__)

    if isinstance(model, FSDP):
        # get_state_dict handles sharded parameters and FQN-keyed optimizer state.
        state_dict, optimizer_state = get_state_dict(model, optimizer)
    else:
        state_dict = model.state_dict()
        optimizer_state = optimizer.state_dict()

    # A scheduler is optional for callers (Snapshot allows None), so don't
    # crash with AttributeError when none is supplied.
    lr_scheduler_state = lr_scheduler.state_dict() if lr_scheduler is not None else None
    tt_embedding_state = tt_embedding.state_dict() if tt_embedding is not None else None

    config_dict: Dict[str, Any] = {}
    if config is not None:
        if isinstance(config, omegaconf.DictConfig):
            config_dict = omegaconf.OmegaConf.to_container(config, resolve=True)
        elif isinstance(config, dict):
            config_dict = config
        else:
            raise ValueError("Config must be a DictConfig or a dictionary.")

    snapshot = Snapshot(
        model_state=state_dict,  # type: ignore
        optimizer_state=optimizer_state,
        lr_scheduler_state=lr_scheduler_state,
        finished_epoch=epoch,
        n_seen_points=n_seen_points,
        config=config_dict,
        tt_embedding_state=tt_embedding_state,
    )
    # Shallow conversion on purpose: dataclasses.asdict() deep-copies every
    # leaf value, which would duplicate all model/optimizer tensors in memory
    # just to serialize them.
    snapshot_dict = {f.name: getattr(snapshot, f.name) for f in fields(snapshot)}

    # DCP writes a directory, so derive its name by dropping the extension.
    dcp_target_checkpoint_id = os.path.splitext(path)[0]

    try:
        log.debug(f"Attempting to save checkpoint with dcp.save to: {dcp_target_checkpoint_id}")
        dcp.save(
            snapshot_dict,
            checkpoint_id=dcp_target_checkpoint_id,
            storage_writer=writer,
        )
        log.info(f"Checkpoint successfully saved with dcp.save to: {dcp_target_checkpoint_id}")
    except Exception as e:
        log.error(f"Failed to save checkpoint with dcp.save to {dcp_target_checkpoint_id}: {e}")
        log.info(f"Attempting fallback save with torch.save to: {path}")
        # Ensure the parent directory of the fallback file exists.
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        try:
            torch.save(snapshot_dict, path)
            log.info(f"Fallback save to {path} successful.")
        except Exception as e_fallback:
            log.error(f"Fallback save to {path} also failed: {e_fallback}")
            raise e  # Surface the original dcp.save failure to the caller.


def load_snapshot(
    model_path: str,
    model: Optional[torch.nn.Module] = None,
    optimizer: Optional[Optimizer] = None,
    tt_embedding: Optional[torch.nn.Module] = None,
) -> Snapshot:
    """
    Load a snapshot, handling both directory (distributed) and file (standard) checkpoints.

    Directory checkpoints written by ``dcp.save`` are converted to a temporary
    ``torch.save`` file with ``dcp_to_torch_save`` and then loaded on CPU.
    Plain files are loaded directly with ``torch.load`` on CPU.

    Args:
        model_path: Path to the checkpoint file or directory.
        model: Optional model to load ``model_state`` into.
        optimizer: Optional optimizer to load ``optimizer_state`` into. It is only
            used together with ``model``.
        tt_embedding: Optional TruthTableEncoder to load ``tt_embedding_state``
            into when the checkpoint contains it.

    Returns:
        A Snapshot object containing the loaded state and metadata.
        ``finished_epoch`` defaults to -1 and ``optimizer_state`` to ``{}`` when
        the checkpoint lacks them.

    Raises:
        FileNotFoundError: If ``model_path`` is neither a file nor a directory.
        ValueError: If the checkpoint has no ``model_state`` entry.
    """
    logger = logging.getLogger(__name__)
    snapshot_data = _read_snapshot_data(model_path, logger)

    if not all(k in snapshot_data for k in ("model_state", "optimizer_state", "finished_epoch")):
        logger.warning(
            f"Loaded checkpoint from {model_path} might have an unexpected structure: {snapshot_data.keys()}"
        )
        # model_state is the only entry we cannot substitute a default for.
        if "model_state" not in snapshot_data:
            raise ValueError(f"Critical key 'model_state' missing in loaded checkpoint from {model_path}")

    if model is not None:
        _load_state_into_model_and_optimizer(snapshot_data, model, optimizer, logger)

    tt_state = snapshot_data.get("tt_embedding_state")
    if tt_embedding is not None and tt_state is not None:
        tt_embedding.load_state_dict(tt_state)
        logger.info("TruthTableEncoder state loaded from checkpoint")

    return Snapshot(
        model_state=snapshot_data["model_state"],
        optimizer_state=snapshot_data.get("optimizer_state", {}),
        lr_scheduler_state=snapshot_data.get("lr_scheduler_state"),
        finished_epoch=snapshot_data.get("finished_epoch", -1),
        n_seen_points=snapshot_data.get("n_seen_points", 0),
        config=snapshot_data.get("config", {}),
        tt_embedding_state=snapshot_data.get("tt_embedding_state", None),
    )


def _read_snapshot_data(model_path: str, logger: logging.Logger) -> Dict[str, Any]:
    """Read a DCP checkpoint directory or a torch.save file into a plain dict on CPU."""
    # NOTE(review): torch.load(weights_only=False) unpickles arbitrary objects;
    # only load checkpoints from trusted locations.
    if os.path.isdir(model_path):
        # A DCP directory cannot be torch.load-ed directly; convert it to a
        # single torch.save file first. This is the path the previous
        # "direct dcp.load" attempt always ended up taking.
        logger.info(f"Converting distributed checkpoint for loading: {model_path}")
        with tempfile.TemporaryDirectory() as tmpdir:
            temp_torch_save_path = os.path.join(tmpdir, "temp_checkpoint.pt")
            try:
                dcp_to_torch_save(model_path, temp_torch_save_path)
                logger.info(f"Loading converted checkpoint file: {temp_torch_save_path}")
                with fsspec.open(temp_torch_save_path, "rb") as f:
                    return torch.load(f, map_location="cpu", weights_only=False)  # type: ignore
            except Exception as conv_e:
                logger.error(f"Failed to convert or load distributed checkpoint from {model_path}: {conv_e}")
                raise

    if os.path.isfile(model_path):
        logger.info(f"Loading standard checkpoint file: {model_path}")
        with fsspec.open(model_path, "rb") as f:
            # Load onto CPU first to avoid device mismatches.
            return torch.load(f, map_location="cpu", weights_only=False)  # type: ignore

    raise FileNotFoundError(f"Checkpoint path not found or is not a file/directory: {model_path}")


def _load_state_into_model_and_optimizer(
    snapshot_data: Dict[str, Any],
    model: torch.nn.Module,
    optimizer: Optional[Optimizer],
    logger: logging.Logger,
) -> None:
    """Apply ``model_state`` (and ``optimizer_state`` if an optimizer is given) in place."""
    optimizer_state = snapshot_data.get("optimizer_state")
    load_optimizer = optimizer is not None and bool(optimizer_state)
    if optimizer is not None and not load_optimizer:
        logger.warning("Checkpoint has no optimizer state; optimizer left unchanged.")

    if isinstance(model, FSDP):
        # Assumes a full (unsharded) state dict, as produced by converting a
        # DCP directory. For a torch.save fallback file the format is whatever
        # get_state_dict returned at save time (TODO confirm for that case).
        # set_state_dict takes the optimizer positionally and both state dicts
        # as keyword-only arguments.
        options = StateDictOptions(full_state_dict=True)
        if load_optimizer:
            set_state_dict(
                model,
                optimizer,
                model_state_dict=snapshot_data["model_state"],
                optim_state_dict=optimizer_state,
                options=options,
            )
        else:
            from torch.distributed.checkpoint.state_dict import set_model_state_dict

            set_model_state_dict(model, snapshot_data["model_state"], options=options)
    else:
        model.load_state_dict(snapshot_data["model_state"])
        if load_optimizer:
            optimizer.load_state_dict(optimizer_state)


def load_model_with_config_fallback(
    checkpoint_path: str,
    device: torch.device,
    fallback_config_path: Optional[str] = None,
    fallback_wandb_run_path: Optional[str] = None,
    tt_embedding: Optional[torch.nn.Module] = None,
) -> Tuple[torch.nn.Module, omegaconf.DictConfig, Snapshot]:
    """
    Load a model snapshot, find a usable model config, instantiate the model, and load its weights.

    Candidate configs are collected in priority order: the snapshot's
    ``config["model"]``, then the ``model`` section of a YAML file, then the
    ``model`` entry of a Wandb run config. Each candidate is tried in turn
    until instantiation and state loading succeed.

    Args:
        checkpoint_path: Path to the checkpoint file or directory.
        device: The device to move the loaded model onto.
        fallback_config_path: Optional path to a YAML config file if config not in snapshot.
        fallback_wandb_run_path: Optional Wandb run path if config not in snapshot or YAML.
        tt_embedding: Optional module assigned to the instantiated model's
            ``truth_table_encoder`` attribute (when it has one) before the
            snapshot weights are loaded.

    Returns:
        A tuple containing:
            - The instantiated model with loaded weights.
            - The model configuration (DictConfig) used for instantiation.
            - The loaded Snapshot object.

    Raises:
        ValueError: If the model configuration cannot be retrieved from any source.
        FileNotFoundError: If the checkpoint path is invalid.
        RuntimeError: If all model instantiation attempts fail (chained to the last error).
    """
    logger = logging.getLogger(__name__)

    # 1. Load the snapshot without applying it to any module yet.
    logger.info(f"Loading snapshot from: {checkpoint_path}")
    snapshot = load_snapshot(model_path=checkpoint_path)
    logger.info(f"Snapshot loaded successfully from epoch {snapshot.finished_epoch}")

    # 2. Collect every candidate model configuration, in priority order.
    config_sources: List[Tuple[str, Any]] = []

    # 2a. Config stored inside the snapshot.
    if snapshot.config and isinstance(snapshot.config, dict) and "model" in snapshot.config:
        try:
            logger.debug(f"Model config stored in snapshot: {snapshot.config['model']}")
            snapshot_cfg = omegaconf.OmegaConf.create(snapshot.config["model"])
            config_sources.append(("snapshot", snapshot_cfg))
            logger.info("Model configuration retrieved from snapshot.")
        except Exception as e:
            logger.warning(f"Could not load model config from snapshot: {e}")
    else:
        logger.warning("Model config not found or invalid in snapshot.")

    # 2b. YAML config file.
    if fallback_config_path:
        logger.info(f"Reading model config from YAML: {fallback_config_path}")
        try:
            full_config = omegaconf.OmegaConf.load(fallback_config_path)
            if "model" in full_config:
                config_sources.append((f"YAML ({fallback_config_path})", full_config.model))
                logger.info("Model configuration loaded from YAML file.")
            else:
                logger.warning(f"'model' key not found in YAML config: {fallback_config_path}")
        except Exception as e:
            logger.warning(f"Failed to load model config from YAML {fallback_config_path}: {e}")

    # 2c. Wandb run config (wandb is imported lazily; it is an optional dependency here).
    if fallback_wandb_run_path:
        logger.info(f"Reading model config from Wandb run: {fallback_wandb_run_path}")
        try:
            import wandb

            api = wandb.Api()
            wandb_run = api.run(fallback_wandb_run_path)
            wandb_config_dict = dict(wandb_run.config)
            if "model" in wandb_config_dict:
                wandb_cfg = omegaconf.OmegaConf.create(wandb_config_dict["model"])
                config_sources.append((f"Wandb ({fallback_wandb_run_path})", wandb_cfg))
                logger.info("Model configuration loaded from Wandb.")
            else:
                logger.warning(f"Could not find 'model' structure in Wandb config for run: {fallback_wandb_run_path}")
        except Exception as e:
            logger.warning(f"Failed to load model config from Wandb run {fallback_wandb_run_path}: {e}")

    # 2d. Nothing to try.
    if not config_sources:
        raise ValueError(
            "Could not retrieve model configuration from snapshot, YAML, or Wandb. "
            "Please provide a valid snapshot with config, fallback_config_path, or fallback_wandb_run_path."
        )

    # 3. Try each config until one instantiates and accepts the snapshot weights.
    last_error: Optional[Exception] = None
    for config_source, model_definition_cfg in config_sources:
        # Reset per attempt so a failure is never attributed to a model that
        # was instantiated in an earlier iteration.
        model_to_load = None
        try:
            logger.info(f"Attempting to instantiate model with config from: {config_source}")
            # Instantiation builds the architecture only; weights are loaded below.
            model_to_load = hydra.utils.instantiate(model_definition_cfg)
            if tt_embedding is not None and hasattr(model_to_load, "truth_table_encoder"):
                model_to_load.truth_table_encoder = tt_embedding
            logger.info(f"Model {model_to_load.__class__.__name__} instantiated successfully.")

            if isinstance(model_to_load, FSDP):
                from torch.distributed.checkpoint.state_dict import set_model_state_dict

                logger.info("Target model is FSDP. Using set_model_state_dict.")
                # broadcast_from_rank0 requires full_state_dict=True.
                set_model_state_dict(
                    model_to_load,
                    snapshot.model_state,  # type: ignore
                    options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
                )
            else:
                logger.info("Target model is not FSDP. Using model.load_state_dict.")
                logger.debug(f"Snapshot model_state keys: {list(snapshot.model_state.keys())}")
                model_to_load.load_state_dict(snapshot.model_state)  # type: ignore

            logger.info(f"Successfully loaded state_dict into {model_to_load.__class__.__name__} from snapshot.")
            model_to_load.to(device)  # Move only after the CPU-side state is loaded.
            logger.info(f"Model moved to device: {device}")
            return model_to_load, model_definition_cfg, snapshot
        except Exception as e:
            if model_to_load is not None:
                model_name_for_log = type(model_to_load).__name__
            else:
                model_name_for_log = "UnknownModel (instantiation may have failed or occurred outside this direct try-block)"
            logger.error(
                f"Error encountered for model '{model_name_for_log}' with config from '{config_source}'. Exception: {e}"
            )
            last_error = e

    raise RuntimeError(
        f"Failed to instantiate model with any of the available configurations. "
        f"Last error: {last_error}"
    ) from last_error


class CheckpointManager:
    def __init__(
        self,
        save_dir: str,
        model: torch.nn.Module,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler,
        logger=None,  # BaseLogger type
        keep_best: int = 3,
        keep_last: int = 2,
        save_freq: int = 1,
        metric_name: str = "eval/Accuracy",
        mode: str = "max",
        checkpoint_format: str = "checkpoint_epoch_{epoch}",
        config: Optional[Dict[str, Any]] = None,
        tt_embedding: Optional[torch.nn.Module] = None,
    ):
        """
        Enhanced checkpoint manager with config storage.

        Args:
            save_dir: Directory to save checkpoints
            model: The model to checkpoint
            optimizer: The optimizer to checkpoint
            lr_scheduler: Optional learning rate scheduler
            logger: Logger for experiment tracking
            keep_best: Number of best checkpoints to keep
            keep_last: Number of most recent checkpoints to keep
            save_freq: Save frequency in epochs
            metric_name: Metric to track for best checkpoints
            mode: "min" or "max" for the metric
            checkpoint_format: Format string for checkpoint filenames
            config: Training configuration to save with checkpoints
        """
        self.save_dir = Path(save_dir)
        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.logger = logger
        self.keep_best = keep_best
        self.keep_last = keep_last
        self.save_freq = save_freq
        self.metric_name = metric_name
        self.mode = mode
        self.checkpoint_format = checkpoint_format
        self.config = config or {}
        self.checkpoint_future = None
        self.writer = None
        self.tt_embedding = tt_embedding

        # Logger
        self.console_logger = logging.getLogger(__name__)

        # Create directory if it doesn't exist
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Track best checkpoints and latest checkpoints
        self.best_checkpoints: List[Dict[str, Any]] = []
        self.latest_checkpoints: List[Dict[str, Any]] = []

        # Load existing checkpoint registry if any
        self._load_checkpoint_registry()

    def _load_checkpoint_registry(self) -> None:
        """Load existing checkpoint registry from the save directory."""
        registry_path = self.save_dir / "checkpoint_registry.json"
        if registry_path.exists():
            with open(registry_path, "r") as f:
                try:
                    registry = json.load(f)
                    self.best_checkpoints = registry.get("best_checkpoints", [])
                    self.latest_checkpoints = registry.get("latest_checkpoints", [])
                    self.console_logger.info(
                        f"Loaded checkpoint registry with {len(self.latest_checkpoints)} checkpoints"
                    )
                except json.JSONDecodeError:
                    self.console_logger.warning(
                        "Failed to load checkpoint registry, starting fresh"
                    )

    def _save_checkpoint_registry(self) -> None:
        """Save the checkpoint registry to disk."""
        registry_path = self.save_dir / "checkpoint_registry.json"
        registry = {
            "best_checkpoints": self.best_checkpoints,
            "latest_checkpoints": self.latest_checkpoints,
        }
        with open(registry_path, "w") as f:
            json.dump(registry, f, indent=2)

    def _is_better(self, current_metric: float, best_metric: float) -> bool:
        """Check if the current metric is better than the best metric."""
        if self.mode == "min":
            return current_metric < best_metric
        return current_metric > best_metric

    def save_checkpoint(
        self,
        epoch: int,
        metrics: Optional[Dict[str, float]] = None,
        n_seen_points: int = 0,
    ) -> str:
        """
        Save a checkpoint and update registries.

        Args:
            epoch: Current epoch number
            metrics: Dictionary of metrics to track (used only for registry, not stored in checkpoint)
            n_seen_points: Number of training samples seen so far

        Returns:
            Path to the saved checkpoint
        """
        # Only save if it's time to save or metrics dict is provided (explicit save)
        if epoch % self.save_freq != 0 and metrics is None:
            return ""

        # Generate checkpoint name
        checkpoint_name = self.checkpoint_format.format(epoch=epoch)
        checkpoint_path = str(self.save_dir / checkpoint_name)

        # Save the checkpoint
        self.console_logger.info(f"Saving checkpoint to {checkpoint_path}")
        save_snapshot(
            model=self.model,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            epoch=epoch,
            path=checkpoint_path,
            n_seen_points=n_seen_points,
            config=self.config,
            writer=self.writer,
            tt_embedding=self.tt_embedding,
        )

        # Determine the type of checkpoint that was saved
        # If checkpoint_path has a .pt extension, it's a file checkpoint
        # Otherwise, it's a directory checkpoint (from distributed save)
        is_file_checkpoint = checkpoint_path.endswith(".pt")
        checkpoint_type = "file" if is_file_checkpoint else "directory"

        # Log the save event to experiment tracker
        if self.logger:
            try:
                self.logger.log_param(f"checkpoint_epoch_{epoch}", checkpoint_path)
                self.logger.log_param(f"checkpoint_type_{epoch}", checkpoint_type)
            except Exception as e:
                self.console_logger.warning(
                    f"Failed to log checkpoint to experiment tracker: {e}"
                )

        # Add to latest checkpoints with any metrics for tracking best models
        checkpoint_info = {
            "path": checkpoint_path,
            "epoch": epoch,
            "metrics": metrics or {},
            "n_seen_points": n_seen_points,
            "type": checkpoint_type,
            "timestamp": time.time(),
        }

        self.latest_checkpoints.append(checkpoint_info)

        # Keep only the latest N checkpoints
        self._cleanup_old_checkpoints()

        # Update best checkpoints if metrics are provided
        if metrics and self.metric_name in metrics:
            self._update_best_checkpoints(checkpoint_info)

        # Save the updated registry
        self._save_checkpoint_registry()

        return checkpoint_path

    def _update_best_checkpoints(self, checkpoint_info: Dict[str, Any]) -> None:
        """Update the registry of best checkpoints based on the metric."""
        current_metric = checkpoint_info["metrics"][self.metric_name]

        # Initialize default comparison value based on mode
        default_value = float("inf") if self.mode == "min" else -float("inf")

        # Check if this checkpoint should be in best checkpoints
        should_add = False

        # If we don't have enough best checkpoints yet, add it
        if len(self.best_checkpoints) < self.keep_best:
            should_add = True
        else:
            # Check if it's better than the worst best checkpoint
            worst_best_ckpt = self.best_checkpoints[-1]
            worst_metric = worst_best_ckpt["metrics"].get(
                self.metric_name, default_value
            )
            if self._is_better(current_metric, worst_metric):
                should_add = True

        if should_add:
            # Add to best checkpoints
            self.best_checkpoints.append(checkpoint_info)

            # Sort best checkpoints
            self.best_checkpoints.sort(
                key=lambda x: x["metrics"].get(self.metric_name, default_value),
                reverse=(self.mode == "max"),
            )

            # Keep only top K
            while len(self.best_checkpoints) > self.keep_best:
                removed = self.best_checkpoints.pop()
                # Don't delete the file yet - it might still be in latest checkpoints

    def _cleanup_old_checkpoints(self) -> None:
        """Remove old checkpoints that are no longer needed."""
        # First, keep only the most recent checkpoints
        while len(self.latest_checkpoints) > self.keep_last:
            old_checkpoint = self.latest_checkpoints.pop(0)

            # Check if this checkpoint is also in best checkpoints
            in_best = any(
                old_checkpoint["path"] == ckpt["path"] for ckpt in self.best_checkpoints
            )

            # Delete the checkpoint if it's not in best checkpoints
            if not in_best:
                try:
                    path = Path(old_checkpoint["path"])
                    
                    # Handle different types of checkpoints (file or directory)
                    is_directory = old_checkpoint.get("type") == "directory" or path.is_dir()
                    
                    if is_directory:
                        # For distributed checkpoints (directories)
                        import shutil
                        shutil.rmtree(path)
                        self.console_logger.info(
                            f"Removed distributed checkpoint directory: {path}"
                        )
                    else:
                        # For traditional checkpoints (files)
                        os.remove(path)
                        self.console_logger.info(
                            f"Removed traditional checkpoint file: {path}"
                        )
                except Exception as e:
                    self.console_logger.warning(f"Failed to remove old checkpoint: {e}")

    def load_latest_checkpoint(self) -> Optional[Tuple[Snapshot, int]]:
        """
        Load the latest checkpoint if available.

        Returns:
            Tuple of (snapshot, epoch) if successful, None otherwise
        """
        if not self.latest_checkpoints:
            self.console_logger.info("No latest checkpoint found")
            return None

        latest_checkpoint = self.latest_checkpoints[-1]
        return self._load_checkpoint(latest_checkpoint)

    def load_best_checkpoint(self) -> Optional[Tuple[Snapshot, int]]:
        """
        Load the best checkpoint if available.

        Returns:
            Tuple of (snapshot, epoch) if successful, None otherwise
        """
        if not self.best_checkpoints:
            self.console_logger.info("No best checkpoint found")
            return None

        best_checkpoint = self.best_checkpoints[0]  # Best is at index 0
        return self._load_checkpoint(best_checkpoint)

    def _load_checkpoint(
        self, checkpoint_info: Dict[str, Any]
    ) -> Optional[Tuple[Snapshot, int]]:
        """Load the checkpoint described by a registry entry.

        ``self.model``, ``self.optimizer`` and ``self.tt_embedding`` are handed
        to ``load_snapshot``. The load is reported to the experiment tracker
        when one is configured.

        Args:
            checkpoint_info: Registry entry with at least ``path`` and ``epoch``.

        Returns:
            ``(snapshot, epoch)`` on success; ``None`` if loading fails. The
            error is logged with its traceback rather than raised.
        """
        try:
            path = checkpoint_info["path"]
            self.console_logger.info(f"Loading checkpoint: {path}")

            is_directory = checkpoint_info.get("type") == "directory" or Path(path).is_dir()
            if is_directory:
                self.console_logger.info(f"Loading distributed checkpoint from directory: {path}")
            else:
                self.console_logger.info(f"Loading standard checkpoint file: {path}")

            snapshot = load_snapshot(path, self.model, self.optimizer, self.tt_embedding)

            if self.logger:
                try:
                    self.logger.log_param("loaded_checkpoint", path)
                    self.logger.log_param("loaded_checkpoint_epoch", checkpoint_info["epoch"])
                    self.logger.log_param("loaded_checkpoint_type", "distributed" if is_directory else "file")

                    tracked_metric = (checkpoint_info.get("metrics") or {}).get(self.metric_name)
                    # Skip only when absent: a metric value of 0 is still real.
                    if tracked_metric is not None:
                        self.logger.log_param("loaded_checkpoint_metric", tracked_metric)
                except Exception as log_error:
                    self.console_logger.warning(
                        f"Failed to log checkpoint loading to experiment tracker: {log_error}"
                    )

            return snapshot, checkpoint_info["epoch"]
        except Exception as e:
            self.console_logger.error(f"Failed to load checkpoint: {e}", exc_info=True)
            return None

    def restore(self, snapshot: Snapshot) -> int:
        """
        Restore model, optimizer and scheduler state from a snapshot.

        FSDP-wrapped models are restored with ``set_state_dict`` from
        ``torch.distributed.checkpoint.state_dict``. This is the inverse of the
        ``get_state_dict(model, optimizer)`` call that ``save_snapshot`` makes,
        and it loads the model and optimizer state together. Plain modules use
        their regular ``load_state_dict`` methods. The LR scheduler and the
        TruthTableEncoder are restored only when both the component and its
        saved state are present.

        Args:
            snapshot: Snapshot to restore from

        Returns:
            The epoch number from the snapshot
        """
        if isinstance(self.model, FSDP):
            # Mirror save_snapshot, which calls get_state_dict(model, optimizer)
            # with default options. set_state_dict applies both dicts in place.
            # FSDP.optim_state_dict_to_load would only *return* a converted
            # dict and load nothing.
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=snapshot.model_state,
                optim_state_dict=snapshot.optimizer_state,
            )
        else:
            self.model.load_state_dict(snapshot.model_state)
            self.optimizer.load_state_dict(snapshot.optimizer_state)

        if self.lr_scheduler and snapshot.lr_scheduler_state:
            self.lr_scheduler.load_state_dict(snapshot.lr_scheduler_state)

        # Handle TruthTableEncoder state if present in snapshot and we have a TruthTableEncoder
        if self.tt_embedding is not None and snapshot.tt_embedding_state is not None:
            self.tt_embedding.load_state_dict(snapshot.tt_embedding_state)
            self.console_logger.info("Restored TruthTableEncoder state from snapshot")
        elif snapshot.tt_embedding_state is not None:
            self.console_logger.info("TruthTableEncoder state found in snapshot but no encoder provided to CheckpointManager")

        self.console_logger.info(
            f"Restored from checkpoint at epoch {snapshot.finished_epoch}"
        )

        # Log restored config if available
        if hasattr(snapshot, "config") and snapshot.config and self.logger:
            try:
                for key, value in snapshot.config.items():
                    self.logger.log_param(f"restored_config.{key}", value)
            except Exception as log_error:
                # Best-effort: a tracker failure must not abort the restore,
                # but it should not be silently swallowed either.
                self.console_logger.warning(
                    f"Failed to log restored config to experiment tracker: {log_error}"
                )

        return snapshot.finished_epoch

    def print_checkpoint_info(self) -> None:
        """Print a human-readable summary of the tracked checkpoints to stdout.

        Two sections are printed:

        * Latest checkpoints. Each entry includes a local-time timestamp when
          it has a truthy ``timestamp``.
        * Best checkpoints, numbered from 1 in ``self.best_checkpoints`` order.

        Metric values that support ``.4f`` formatting are shown with four
        decimals. Other values (e.g. ``None`` or strings) are shown via
        ``str()`` instead of raising. A missing or ``None`` ``metrics`` or
        ``n_seen_points`` entry is treated as empty / 0.
        """
        import datetime

        def _format_value(value: Any) -> str:
            try:
                return f"{value:.4f}"
            except (TypeError, ValueError):
                # Non-numeric metric values (None, strings, ...) cannot use "f".
                return str(value)

        def _format_metrics(ckpt: Dict[str, Any]) -> str:
            metrics = ckpt.get("metrics") or {}
            return ", ".join(f"{k}: {_format_value(v)}" for k, v in metrics.items())

        def _header(ckpt: Dict[str, Any]) -> str:
            n_seen = ckpt.get("n_seen_points") or 0
            return f"Epoch {ckpt['epoch']} (samples: {n_seen:,}): {ckpt['path']}"

        print("=== Latest Checkpoints ===")
        for ckpt in self.latest_checkpoints:
            metrics_str = _format_metrics(ckpt)
            timestamp = ckpt.get("timestamp", 0)
            if timestamp:
                time_str = datetime.datetime.fromtimestamp(timestamp).strftime(
                    "%Y-%m-%d %H:%M:%S"
                )
                print(f"{_header(ckpt)} | {time_str} | {metrics_str}")
            else:
                print(f"{_header(ckpt)} | {metrics_str}")

        print("\n=== Best Checkpoints ===")
        for i, ckpt in enumerate(self.best_checkpoints):
            print(f"#{i+1} {_header(ckpt)} | {_format_metrics(ckpt)}")

    def scan_checkpoint_directory(
        self, directory_path: str | None = None
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Scan a directory for checkpoint files (.pt) and distributed checkpoint directories 
        and build a temporary registry.

        Args:
            directory_path: Directory to scan (uses self.save_dir if None)

        Returns:
            Dict with 'latest_checkpoints' and empty 'best_checkpoints'
        """
        directory = Path(directory_path) if directory_path else self.save_dir
        if not directory.exists():
            self.console_logger.warning(f"Directory {directory} does not exist")
            return {"latest_checkpoints": [], "best_checkpoints": []}

        checkpoint_infos = []
        
        # 1. Look for traditional checkpoint files (.pt) matching pattern
        checkpoint_files = list(directory.glob("checkpoint_epoch_*.pt"))
        for file_path in checkpoint_files:
            try:
                # Extract epoch number from filename
                filename = file_path.name
                epoch_str = filename.split("checkpoint_epoch_")[1].split(".pt")[0]
                epoch = int(epoch_str)

                checkpoint_infos.append({
                    "path": str(file_path),
                    "epoch": epoch,
                    "metrics": {},  # No metrics available without registry
                    "type": "file",
                })
                self.console_logger.debug(f"Found traditional checkpoint file: {file_path} (epoch {epoch})")
            except (ValueError, IndexError):
                self.console_logger.warning(f"Could not parse epoch from {file_path}")
        
        # 2. Look for distributed checkpoint directories matching pattern
        # These are directories without .pt extension, created by dcp.save
        all_items = list(directory.glob("checkpoint_epoch_*"))
        checkpoint_dirs = [d for d in all_items if d.is_dir() and not str(d).endswith(".pt")]
        
        for dir_path in checkpoint_dirs:
            try:
                # Extract epoch number from directory name
                dirname = dir_path.name
                epoch_str = dirname.split("checkpoint_epoch_")[1]
                epoch = int(epoch_str)
                
                # Check if it's a valid DCP checkpoint by looking for metadata file
                metadata_file = dir_path / "__metadata__.pt"
                if metadata_file.exists():
                    checkpoint_infos.append({
                        "path": str(dir_path),
                        "epoch": epoch,
                        "metrics": {},  # No metrics available without registry
                        "type": "directory",
                    })
                    self.console_logger.debug(f"Found distributed checkpoint directory: {dir_path} (epoch {epoch})")
                else:
                    self.console_logger.warning(f"Directory {dir_path} doesn't appear to be a valid distributed checkpoint")
            except (ValueError, IndexError):
                self.console_logger.warning(f"Could not parse epoch from directory {dir_path}")

        # Sort by epoch
        checkpoint_infos.sort(key=lambda x: x["epoch"])
        
        self.console_logger.info(f"Found {len(checkpoint_infos)} checkpoints in {directory}")
        return {
            "latest_checkpoints": checkpoint_infos,
            "best_checkpoints": [],  # No best checkpoints without metrics
        }

    def load_checkpoint_from_path(
        self, checkpoint_path: str
    ) -> Optional[Tuple[Snapshot, int]]:
        """
        Load a single checkpoint given its exact location.

        ``checkpoint_path`` may name a checkpoint file or a distributed
        checkpoint directory. The returned epoch is the snapshot's
        ``finished_epoch`` when it is present and not ``None``. Otherwise it is
        parsed from a ``checkpoint_epoch_<N>`` name, with -1 meaning unknown.

        Args:
            checkpoint_path: Path to the checkpoint file or directory

        Returns:
            ``(snapshot, epoch)`` on success. ``None`` if the path does not
            exist or any error occurs while loading; the error is logged.
        """
        try:
            self.console_logger.info(f"Loading checkpoint from path: {checkpoint_path}")
            ckpt_path = Path(checkpoint_path)

            if not ckpt_path.exists():
                self.console_logger.error(f"Checkpoint path not found: {checkpoint_path}")
                return None

            is_dir_ckpt = ckpt_path.is_dir()
            kind_label = "distributed" if is_dir_ckpt else "file"

            # Best-effort epoch from the name: directories are "checkpoint_epoch_<N>",
            # files additionally carry a ".pt" suffix that must be stripped.
            base_name = ckpt_path.name
            try:
                suffix = base_name.split("checkpoint_epoch_")[1]
                if not is_dir_ckpt:
                    suffix = suffix.split(".pt")[0]
                parsed_epoch = int(suffix)
            except (ValueError, IndexError):
                self.console_logger.warning(f"Could not extract epoch from path name: {base_name}")
                parsed_epoch = -1  # Unknown epoch

            self.console_logger.info(f"Loading {kind_label} checkpoint")
            snapshot = load_snapshot(str(ckpt_path), self.model, self.optimizer, self.tt_embedding)

            # Prefer the epoch recorded inside the snapshot over the parsed one.
            recorded_epoch = getattr(snapshot, "finished_epoch", None)
            effective_epoch = parsed_epoch if recorded_epoch is None else recorded_epoch

            if self.logger:
                try:
                    for param_key, param_value in (
                        ("loaded_checkpoint", checkpoint_path),
                        ("loaded_checkpoint_epoch", effective_epoch),
                        ("loaded_checkpoint_type", kind_label),
                    ):
                        self.logger.log_param(param_key, param_value)
                except Exception as log_error:
                    self.console_logger.warning(
                        f"Failed to log checkpoint loading to experiment tracker: {log_error}"
                    )

            self.console_logger.info(f"Successfully loaded checkpoint from epoch {effective_epoch}")
            return snapshot, effective_epoch
        except Exception as e:
            self.console_logger.error(
                f"Failed to load checkpoint from {checkpoint_path}: {e}"
            )
            return None

    def load_from_directory(
        self, directory_path: str, mode: str = "latest"
    ) -> Optional[Tuple[Snapshot, int]]:
        """
        Load a checkpoint from a directory.

        If the directory contains ``checkpoint_registry.json``, the manager
        temporarily switches ``save_dir`` to it, loads that registry, and picks
        the checkpoint selected by ``mode``. ``save_dir`` and the original
        registry are restored afterwards, even if loading raises.

        Without a registry, the directory is scanned for ``checkpoint_epoch_*``
        entries and the highest epoch is loaded. No metrics are available in
        that case, so ``mode="best"`` falls back to the latest checkpoint.

        Args:
            directory_path: Path to directory containing checkpoints
            mode: "latest" to load the most recent checkpoint, "best" to load registry's best checkpoint

        Returns:
            Tuple of (snapshot, epoch) if successful, None otherwise
        """
        if mode not in ("latest", "best"):
            self.console_logger.error(f"Invalid mode: {mode}")
            return None

        directory = Path(directory_path)
        if not directory.exists():
            self.console_logger.error(f"Directory not found: {directory_path}")
            return None

        registry_path = directory / "checkpoint_registry.json"
        if registry_path.exists():
            # Temporarily point the manager at `directory` so the registry-based
            # loaders operate on it. try/finally guarantees the original
            # save_dir and registry are put back even if loading fails.
            original_save_dir = self.save_dir
            self.save_dir = directory
            try:
                self._load_checkpoint_registry()
                if mode == "latest":
                    return self.load_latest_checkpoint()
                return self.load_best_checkpoint()
            finally:
                self.save_dir = original_save_dir
                self._load_checkpoint_registry()

        # No registry found, scan directory for checkpoints (both file and directory types)
        self.console_logger.info(
            f"No registry found in {directory_path}, scanning for checkpoints"
        )
        registry = self.scan_checkpoint_directory(directory_path)
        candidates = registry["latest_checkpoints"]
        if not candidates:
            self.console_logger.error(f"No checkpoints found in {directory_path}")
            return None

        checkpoint_info = max(candidates, key=lambda info: info["epoch"])
        if mode == "best":
            self.console_logger.warning(
                "Cannot load best checkpoint without registry, loading latest instead"
            )
            self.console_logger.info(f"Loading latest checkpoint instead of best: {checkpoint_info['path']} (epoch {checkpoint_info['epoch']})")
        else:
            self.console_logger.info(f"Loading latest checkpoint: {checkpoint_info['path']} (epoch {checkpoint_info['epoch']})")
        return self._load_checkpoint(checkpoint_info)
