"""
Model checkpoint saving and loading utilities.
"""

import os
import torch
from pathlib import Path
from typing import Dict, Any, Optional
import shutil


class CheckpointManager:
    """
    Manage model checkpoints with automatic saving and loading.

    Checkpoints are written to ``<checkpoint_dir>/<experiment_name>`` as
    ``checkpoint_iter_<iteration>.pt``. The best checkpoint (lowest
    ``metrics['val_loss']``, or one explicitly flagged by the caller) is
    additionally copied to ``best_checkpoint.pt`` in the same directory.

    Only checkpoints saved through this instance are tracked for cleanup;
    files already present in the directory are never deleted by it.
    """

    # Metric recorded for checkpoints saved without a usable 'val_loss'.
    # Using +inf (instead of 0.0) means such checkpoints never look "better"
    # than a checkpoint that has a real validation loss.
    _NO_METRIC = float('inf')

    def __init__(
        self,
        checkpoint_dir: str,
        experiment_name: str,
        max_to_keep: int = 5,
        keep_best: bool = True
    ):
        """
        Initialize checkpoint manager.

        Creates the experiment directory if it does not exist.

        Args:
            checkpoint_dir: Directory to save checkpoints
            experiment_name: Name of experiment
            max_to_keep: Maximum number of checkpoints to keep
            keep_best: Whether to always keep the best checkpoint
        """
        self.checkpoint_dir = Path(checkpoint_dir) / experiment_name
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        self.max_to_keep = max_to_keep
        self.keep_best = keep_best

        # List of (iteration, metric_value, path); metric_value is
        # _NO_METRIC when no validation loss was supplied.
        self.checkpoints = []
        self.best_checkpoint = None
        self.best_metric_value = float('inf')

        print(f"✓ CheckpointManager initialized: {self.checkpoint_dir}")

    def _extract_metric(self, metrics: Optional[Dict[str, float]]) -> float:
        """Return ``metrics['val_loss']`` as a float, or ``_NO_METRIC`` if unusable."""
        if not metrics:
            return self._NO_METRIC
        raw = metrics.get('val_loss')
        if raw is None:
            return self._NO_METRIC
        try:
            value = float(raw)
        except (TypeError, ValueError):
            return self._NO_METRIC
        # NaN compares False against everything, so it is also treated as missing.
        return value if value < self._NO_METRIC else self._NO_METRIC

    @staticmethod
    def _parse_iteration(path: Path) -> Optional[int]:
        """Return the iteration encoded in a checkpoint filename, or None."""
        try:
            return int(path.stem.split('_')[-1])
        except ValueError:
            return None

    def save(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: Optional[Any],
        iteration: int,
        config: Dict[str, Any],
        metrics: Optional[Dict[str, float]] = None,
        is_best: bool = False
    ) -> str:
        """
        Save a checkpoint.

        The checkpoint becomes the new best when ``is_best`` is True or when
        its ``metrics['val_loss']`` is lower than the best value seen so far.
        A checkpoint without a ``'val_loss'`` is never promoted automatically.

        Args:
            model: Model to save
            optimizer: Optimizer state
            scheduler: Learning rate scheduler state
            iteration: Current iteration
            config: Configuration dictionary
            metrics: Optional metrics dictionary
            is_best: Whether this is the best checkpoint so far

        Returns:
            Path to saved checkpoint
        """
        checkpoint_path = self.checkpoint_dir / f"checkpoint_iter_{iteration}.pt"
        path_str = str(checkpoint_path)

        # Prepare checkpoint dictionary
        checkpoint = {
            'iteration': iteration,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'config': config,
        }

        if scheduler is not None:
            checkpoint['scheduler_state_dict'] = scheduler.state_dict()

        if metrics is not None:
            checkpoint['metrics'] = metrics

        # Save checkpoint
        torch.save(checkpoint, checkpoint_path)

        metric_value = self._extract_metric(metrics)
        has_metric = metric_value < self._NO_METRIC

        # Re-saving an iteration overwrites its file, so replace the old
        # tracking entry instead of counting the same path twice.
        self.checkpoints[:] = [
            entry for entry in self.checkpoints if entry[2] != path_str
        ]
        self.checkpoints.append((iteration, metric_value, path_str))

        # Handle best checkpoint
        if is_best or (has_metric and metric_value < self.best_metric_value):
            # Only a real metric may move the threshold; otherwise an
            # explicit is_best without metrics would block later improvements.
            if has_metric:
                self.best_metric_value = metric_value
            self.best_checkpoint = path_str

            # Copy (not symlink) for compatibility. Copy to a temp name and
            # rename so best_checkpoint.pt is never missing or half-written;
            # os.replace also overwrites a pre-existing file or symlink.
            best_link = self.checkpoint_dir / "best_checkpoint.pt"
            tmp_link = self.checkpoint_dir / "best_checkpoint.pt.tmp"
            shutil.copy(checkpoint_path, tmp_link)
            os.replace(tmp_link, best_link)

            print(f"✓ New best checkpoint saved (metric={metric_value:.6f})")

        # Clean up old checkpoints
        self._cleanup_checkpoints()

        print(f"✓ Checkpoint saved: {checkpoint_path}")
        return path_str

    def _cleanup_checkpoints(self):
        """
        Remove old checkpoints to keep only max_to_keep.

        Up to ``max_to_keep // 2`` slots go to the checkpoints with the lowest
        recorded metric; the remaining slots are filled with the most recent
        ones. When ``keep_best`` is set, the current best checkpoint is kept
        even if that exceeds ``max_to_keep``.
        """
        if len(self.checkpoints) <= self.max_to_keep:
            return

        to_keep = set()

        # Keep best checkpoints (lower is better); entries without a real
        # metric are not candidates for these slots.
        scored = sorted(
            (entry for entry in self.checkpoints if entry[1] < self._NO_METRIC),
            key=lambda entry: entry[1],
        )
        for _, _, path in scored[:max(self.max_to_keep // 2, 0)]:
            to_keep.add(path)

        # Fill the remaining slots with the most recent checkpoints, skipping
        # over ones already kept so the budget is fully used.
        by_recency = sorted(self.checkpoints, key=lambda entry: entry[0], reverse=True)
        for _, _, path in by_recency:
            if len(to_keep) >= self.max_to_keep:
                break
            to_keep.add(path)

        # Always keep best checkpoint
        if self.keep_best and self.best_checkpoint:
            to_keep.add(self.best_checkpoint)

        # Remove checkpoints not in keep set
        survivors = []
        for entry in self.checkpoints:
            path = entry[2]
            if path in to_keep:
                survivors.append(entry)
                continue
            if os.path.exists(path):
                os.remove(path)
                print(f"✓ Removed old checkpoint: {path}")

        # Update checkpoint list in place
        self.checkpoints[:] = survivors

    def load(
        self,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
        checkpoint_path: Optional[str] = None,
        load_best: bool = False,
        device: str = 'cpu'
    ) -> Dict[str, Any]:
        """
        Load a checkpoint.

        Args:
            model: Model to load weights into
            optimizer: Optional optimizer to load state into
            scheduler: Optional scheduler to load state into
            checkpoint_path: Path to checkpoint file (if None, loads latest or best)
            load_best: Whether to load best checkpoint
            device: Device to load checkpoint to

        Returns:
            Checkpoint dictionary with metadata

        Raises:
            FileNotFoundError: If no checkpoint can be found at the resolved path.
        """
        # Determine which checkpoint to load
        if checkpoint_path is None:
            if load_best:
                checkpoint_path = self.checkpoint_dir / "best_checkpoint.pt"
            else:
                # Load latest checkpoint
                checkpoint_path = self.get_latest_checkpoint()
                if checkpoint_path is None:
                    raise FileNotFoundError(f"No checkpoints found in {self.checkpoint_dir}")

        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        print(f"Loading checkpoint: {checkpoint_path}")

        # Load checkpoint
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Load model state
        model.load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer state
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Load scheduler state
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        print(f"✓ Checkpoint loaded from iteration {checkpoint['iteration']}")

        return checkpoint

    def get_latest_checkpoint(self) -> Optional[str]:
        """
        Get path to latest checkpoint.

        Scans the checkpoint directory for ``checkpoint_iter_*.pt`` files and
        returns the one with the highest iteration number, or None if there is
        none. Files whose suffix is not an integer are ignored.
        """
        latest_path = None
        latest_iteration = None
        for candidate in self.checkpoint_dir.glob("checkpoint_iter_*.pt"):
            candidate_iteration = self._parse_iteration(candidate)
            if candidate_iteration is None:
                continue
            if latest_iteration is None or candidate_iteration > latest_iteration:
                latest_path = candidate
                latest_iteration = candidate_iteration
        return str(latest_path) if latest_path is not None else None

    def get_best_checkpoint(self) -> Optional[str]:
        """Get path to best checkpoint, or None if none has been saved."""
        best_path = self.checkpoint_dir / "best_checkpoint.pt"
        return str(best_path) if best_path.exists() else None