"""Checkpoint utilities for model saving and loading."""

import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch


class CheckpointManager:
    """Manages model checkpoints and experiment state.

    Epoch checkpoints are written as ``{model_name}_epoch_{epoch}.pth`` in
    ``save_dir``; the checkpoint flagged as best is additionally written to
    ``{model_name}_best.pth``. The caller decides which checkpoint is best
    (via ``is_best``); the manager only records ``metrics[metric_key]`` and
    the epoch of that checkpoint.
    """

    def __init__(self, save_dir: str, model_name: str = "model",
                 metric_key: str = 'loss'):
        """Initialize checkpoint manager.

        Args:
            save_dir: Directory to save checkpoints (created if missing).
            model_name: Base name for checkpoint files.
            metric_key: Key in ``metrics`` recorded as ``best_metric`` when a
                checkpoint is saved with ``is_best=True``.
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.model_name = model_name
        self.metric_key = metric_key
        self.best_metric = float('inf')
        self.best_epoch = 0

    @staticmethod
    def _atomic_save(obj: Any, path: Path) -> None:
        """Serialize ``obj`` with ``torch.save`` so ``path`` is never half-written."""
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            torch.save(obj, tmp_path)
            # os.replace is atomic on the same filesystem: readers see either
            # the old file or the complete new one, never a truncated file.
            os.replace(tmp_path, path)
        finally:
            # Only still present if torch.save or os.replace failed.
            if tmp_path.exists():
                tmp_path.unlink()

    def _epoch_checkpoints(self) -> List[Tuple[int, Path]]:
        """Return ``(epoch, path)`` for every auto-named epoch checkpoint, sorted by epoch.

        Files whose name matches the prefix but whose epoch part is not a plain
        decimal number (e.g. ``model_epoch_final.pth``) are ignored rather than
        crashing the caller.
        """
        if not self.save_dir.is_dir():
            return []
        prefix = f"{self.model_name}_epoch_"
        found = []
        for path in self.save_dir.iterdir():
            if path.suffix != '.pth' or not path.stem.startswith(prefix):
                continue
            epoch_part = path.stem[len(prefix):]
            if epoch_part.isdecimal() and path.is_file():
                found.append((int(epoch_part), path))
        found.sort()
        return found

    def save_checkpoint(
        self,
        epoch: int,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        metrics: Optional[Dict[str, float]] = None,
        is_best: bool = False,
        filename: Optional[str] = None
    ) -> str:
        """Save model checkpoint.

        Args:
            epoch: Current epoch number.
            model: Model whose ``state_dict`` is saved.
            optimizer: Optimizer whose ``state_dict`` is saved.
            scheduler: Optional learning rate scheduler whose state is saved.
            metrics: Dictionary of metrics stored alongside the weights.
            is_best: Whether this is the best checkpoint so far; if so it is
                also written to ``{model_name}_best.pth`` and the best-tracking
                fields are updated to this epoch.
            filename: Custom filename, defaults to
                ``{model_name}_epoch_{epoch}.pth``.

        Returns:
            Path to saved checkpoint.
        """
        if filename is None:
            filename = f"{self.model_name}_epoch_{epoch}.pth"

        checkpoint_path = self.save_dir / filename

        # Resolve best-tracking *before* building the payload so the saved
        # checkpoint (in particular the best one) records its own metric/epoch
        # instead of the previous best. Committed to self only after saving.
        best_metric, best_epoch = self.best_metric, self.best_epoch
        if is_best:
            best_metric = (metrics.get(self.metric_key, float('inf'))
                           if metrics else float('inf'))
            best_epoch = epoch

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics or {},
            'best_metric': best_metric,
            'best_epoch': best_epoch
        }

        if scheduler is not None:
            checkpoint['scheduler_state_dict'] = scheduler.state_dict()

        self._atomic_save(checkpoint, checkpoint_path)

        # Save best checkpoint separately
        if is_best:
            best_path = self.save_dir / f"{self.model_name}_best.pth"
            self._atomic_save(checkpoint, best_path)

        self.best_metric, self.best_epoch = best_metric, best_epoch

        print(f"Saved checkpoint: {checkpoint_path}")
        return str(checkpoint_path)

    def load_checkpoint(
        self,
        checkpoint_path: str,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        device: str = 'cpu'
    ) -> Tuple[int, Dict[str, float]]:
        """Load model checkpoint.

        ``torch.load`` unpickles the file, so only load checkpoints from
        trusted sources.

        Args:
            checkpoint_path: Path to checkpoint file.
            model: Model to load state into.
            optimizer: Optimizer to load state into (skipped if None or absent).
            scheduler: Scheduler to load state into (skipped if None or absent).
            device: Passed to ``torch.load`` as ``map_location``.

        Returns:
            Tuple of (epoch, metrics).
        """
        checkpoint = torch.load(checkpoint_path, map_location=device)

        model.load_state_dict(checkpoint['model_state_dict'])

        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        epoch = checkpoint['epoch']
        metrics = checkpoint.get('metrics', {})

        if 'best_metric' in checkpoint:
            self.best_metric = checkpoint['best_metric']
            # Tolerate checkpoints that carry best_metric without best_epoch.
            self.best_epoch = checkpoint.get('best_epoch', self.best_epoch)

        print(f"Loaded checkpoint from epoch {epoch}: {checkpoint_path}")
        return epoch, metrics

    def save_experiment_state(
        self,
        experiment_state: Dict[str, Any],
        filename: str = "experiment_state.json"
    ) -> str:
        """Save experiment state to JSON.

        Args:
            experiment_state: JSON-serializable dictionary of experiment state.
            filename: Output filename inside ``save_dir``.

        Returns:
            Path to saved state file.

        Raises:
            TypeError: If ``experiment_state`` is not JSON-serializable; an
                existing state file is left untouched in that case.
        """
        state_path = self.save_dir / filename

        # Serialize first so a non-serializable value cannot truncate the
        # previously saved state file.
        payload = json.dumps(experiment_state, indent=2)
        state_path.write_text(payload, encoding='utf-8')

        print(f"Saved experiment state: {state_path}")
        return str(state_path)

    def load_experiment_state(
        self,
        filename: str = "experiment_state.json"
    ) -> Dict[str, Any]:
        """Load experiment state from JSON.

        Args:
            filename: Input filename inside ``save_dir``.

        Returns:
            Dictionary of experiment state, or ``{}`` if the file does not exist.
        """
        state_path = self.save_dir / filename

        if not state_path.exists():
            print(f"Experiment state file not found: {state_path}")
            return {}

        with open(state_path, 'r', encoding='utf-8') as f:
            experiment_state = json.load(f)

        print(f"Loaded experiment state: {state_path}")
        return experiment_state

    def get_latest_checkpoint(self) -> Optional[str]:
        """Get path to latest checkpoint.

        Returns:
            Path to the auto-named epoch checkpoint with the highest epoch, or
            None if there is none.
        """
        checkpoints = self._epoch_checkpoints()
        if not checkpoints:
            return None
        return str(checkpoints[-1][1])

    def get_best_checkpoint(self) -> Optional[str]:
        """Get path to best checkpoint.

        Returns:
            Path to ``{model_name}_best.pth`` or None if it does not exist.
        """
        best_path = self.save_dir / f"{self.model_name}_best.pth"
        return str(best_path) if best_path.exists() else None

    def cleanup_old_checkpoints(self, keep_last: int = 5):
        """Remove old epoch checkpoints, keeping only the most recent ones.

        The best checkpoint and custom-named files are never touched.

        Args:
            keep_last: Number of highest-epoch checkpoints to keep; 0 (or a
                negative value) removes all auto-named epoch checkpoints.
        """
        checkpoints = self._epoch_checkpoints()
        # Explicit count instead of ``[:-keep_last]``, which is empty for 0.
        excess = len(checkpoints) - max(keep_last, 0)
        if excess <= 0:
            return

        for _, checkpoint in checkpoints[:excess]:
            checkpoint.unlink()
            print(f"Removed old checkpoint: {checkpoint}")
