"""
Custom callbacks for GLEAM-AI training.

This module contains custom PyTorch Lightning callbacks for monitoring
and controlling the training process.
"""

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from typing import Dict, Any, Optional
import logging
import numpy as np
from pathlib import Path

# Module-level logger; every callback in this module reports progress and
# best-effort failures through it.
logger = logging.getLogger(__name__)


class ActiveLearningCallback(Callback):
    """
    Callback for active learning integration.

    This callback handles the integration between PyTorch Lightning training
    and the active learning framework: every ``acquisition_interval`` training
    epochs it runs an acquisition round on the active learner, and, when a
    prediction directory is configured, it dumps the module's validation
    predictions to ``.npz`` files.
    """

    def __init__(
        self,
        active_learner,
        acquisition_interval: int = 1,
        save_predictions: bool = True,
        prediction_dir: Optional[str] = None
    ):
        """
        Initialize the active learning callback.

        Args:
            active_learner: Active learning instance. It is expected to
                provide ``_search_candidates``, ``_acquire_samples``,
                ``_update_train_dataset`` and ``get_stats`` (the latter
                returning a mapping with ``train_size`` and ``pool_size``).
            acquisition_interval: Interval for acquisition (in epochs);
                must be at least 1.
            save_predictions: Whether to save predictions.
            prediction_dir: Directory to save predictions. Predictions are
                only written when this is given; it is created if missing.

        Raises:
            ValueError: If ``acquisition_interval`` is smaller than 1.
        """
        super().__init__()
        if acquisition_interval < 1:
            # A zero interval would otherwise surface as a ZeroDivisionError
            # at the end of the first training epoch.
            raise ValueError(
                f"acquisition_interval must be >= 1, got {acquisition_interval}"
            )
        self.active_learner = active_learner
        self.acquisition_interval = acquisition_interval
        self.save_predictions = save_predictions
        self.prediction_dir = Path(prediction_dir) if prediction_dir else None

        if self.save_predictions and self.prediction_dir:
            self.prediction_dir.mkdir(parents=True, exist_ok=True)

    def on_train_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule
    ) -> None:
        """
        Run an acquisition round every ``acquisition_interval`` epochs.

        NOTE(review): the trainer only sees the enlarged training set if the
        train dataloader is rebuilt afterwards (e.g. via the Trainer's
        ``reload_dataloaders_every_n_epochs``) — confirm this is configured.
        """
        current_epoch = trainer.current_epoch

        # ``current_epoch`` is 0-based; acquire after epochs N, 2N, ... (1-based).
        if (current_epoch + 1) % self.acquisition_interval != 0:
            return

        logger.info(f"Performing acquisition at epoch {current_epoch + 1}")

        # Select candidates from the pool and move them into the labelled set.
        candidate_batch = self.active_learner._search_candidates()
        self.active_learner._acquire_samples(candidate_batch)

        # Rebuild the training dataset from the updated labelled set.
        self.active_learner._update_train_dataset()

        stats = self.active_learner.get_stats()
        logger.info(f"Acquisition completed. Train size: {stats['train_size']}, Pool size: {stats['pool_size']}")

    def on_validation_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule
    ) -> None:
        """Save validation predictions after each real validation run."""
        # Lightning's pre-training sanity check also fires this hook; its
        # partial results would be written as epoch 0 and then overwritten.
        if getattr(trainer, "sanity_checking", False):
            return
        if self.save_predictions and self.prediction_dir:
            self._save_predictions(trainer, pl_module, "validation")

    @staticmethod
    def _to_numpy(value: Any) -> Any:
        """
        Convert tensors (or lists/tuples of tensors) to NumPy arrays.

        ``np.savez`` cannot serialize CUDA tensors or tensors that require
        grad, so tensors are detached and copied to host memory first.
        Any other value is returned unchanged.
        """
        if isinstance(value, torch.Tensor):
            array = value.detach().cpu()
            if array.dtype == torch.bfloat16:
                # NumPy has no bfloat16 dtype.
                array = array.float()
            return array.numpy()
        if isinstance(value, (list, tuple)):
            return [ActiveLearningCallback._to_numpy(item) for item in value]
        return value

    def _save_predictions(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        stage: str
    ) -> None:
        """
        Save the module's ``predictions``/``targets`` attributes to an ``.npz`` file.

        The file is named ``{stage}_predictions_epoch_{epoch:03d}.npz`` inside
        ``prediction_dir``. Modules without a ``predictions`` attribute are
        skipped. Failures are logged as warnings and never interrupt training.
        """
        try:
            if not hasattr(pl_module, 'predictions'):
                return
            predictions = self._to_numpy(pl_module.predictions)
            targets = self._to_numpy(pl_module.targets)

            epoch = trainer.current_epoch
            filename = self.prediction_dir / f"{stage}_predictions_epoch_{epoch:03d}.npz"

            np.savez(
                filename,
                predictions=predictions,
                targets=targets,
                epoch=epoch,
                stage=stage
            )

            logger.info(f"Predictions saved to: {filename}")
        except Exception as e:
            # Best-effort: a failed dump must not abort training.
            logger.warning(f"Failed to save predictions: {e}")


class ModelCheckpointCallback(Callback):
    """
    Custom model checkpoint callback with additional metadata.

    This callback extends the standard checkpoint functionality
    with custom metadata and validation metrics. After every validation run
    it saves a checkpoint whenever the monitored metric ranks among the
    ``save_top_k`` best values seen so far, and deletes checkpoint files
    that drop out of that top-k.
    """

    def __init__(
        self,
        save_dir: str,
        filename_prefix: str = "checkpoint",
        save_top_k: int = 3,
        monitor: str = "val_loss",
        mode: str = "min"
    ):
        """
        Initialize the custom checkpoint callback.

        Args:
            save_dir: Directory to save checkpoints (created if missing).
            filename_prefix: Prefix for checkpoint filenames.
            save_top_k: Number of best checkpoints to keep. ``0`` disables
                saving; a negative value keeps every checkpoint.
            monitor: Metric to monitor (a key of ``trainer.callback_metrics``).
            mode: Mode for monitoring ("min" or "max").

        Raises:
            ValueError: If ``mode`` is neither "min" nor "max".
        """
        super().__init__()
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.filename_prefix = filename_prefix
        self.save_top_k = save_top_k
        self.monitor = monitor
        self.mode = mode

        # Parallel lists: metric value (as float) and file path of each kept checkpoint.
        self.best_metrics = []
        self.checkpoint_paths = []

    def _is_better(self, candidate: float, reference: float) -> bool:
        """Return True if ``candidate`` is strictly better than ``reference`` under ``mode``."""
        if self.mode == "min":
            return candidate < reference
        return candidate > reference

    def on_validation_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule
    ) -> None:
        """Save a checkpoint if the monitored metric enters the current top-k."""
        # The pre-training sanity check runs on a few batches of an untrained
        # model; its metrics must not produce checkpoints.
        if getattr(trainer, "sanity_checking", False):
            return

        raw_metric = trainer.callback_metrics.get(self.monitor)
        if raw_metric is None:
            logger.warning(f"Checkpoint callback: {self.monitor} not found in metrics")
            return
        # callback_metrics holds (possibly CUDA) scalar tensors; store plain floats.
        current_metric = float(raw_metric)
        if np.isnan(current_metric):
            logger.warning(f"Checkpoint callback: {self.monitor} is NaN, skipping checkpoint")
            return

        if self.save_top_k == 0:
            return

        # When the top-k is full, the new value must beat the worst kept one.
        if 0 < self.save_top_k <= len(self.best_metrics):
            worst = max(self.best_metrics) if self.mode == "min" else min(self.best_metrics)
            if not self._is_better(current_metric, worst):
                return

        if self.best_metrics:
            best = min(self.best_metrics) if self.mode == "min" else max(self.best_metrics)
            is_new_best = self._is_better(current_metric, best)
        else:
            is_new_best = True

        checkpoint_path = self._save_checkpoint(
            trainer, pl_module, trainer.current_epoch, current_metric
        )
        self.best_metrics.append(current_metric)
        self.checkpoint_paths.append(checkpoint_path)

        self._cleanup_old_checkpoints()

        if is_new_best:
            logger.info(f"New best checkpoint saved: {checkpoint_path} (metric: {current_metric:.4f})")
        else:
            logger.info(f"Top-k checkpoint saved: {checkpoint_path} (metric: {current_metric:.4f})")

    @staticmethod
    def _lr_scheduler_states(trainer: pl.Trainer) -> list:
        """
        Collect ``state_dict()`` of every LR scheduler attached to the trainer.

        Lightning >= 1.6 exposes ``trainer.lr_scheduler_configs`` (objects
        with a ``.scheduler`` attribute); older versions expose
        ``trainer.lr_schedulers`` as a list of dicts with a ``"scheduler"``
        key. Both layouts are supported.
        """
        configs = getattr(trainer, "lr_scheduler_configs", None)
        if configs is not None:
            return [config.scheduler.state_dict() for config in configs]

        states = []
        for entry in getattr(trainer, "lr_schedulers", None) or []:
            scheduler = entry["scheduler"] if isinstance(entry, dict) else entry
            states.append(scheduler.state_dict())
        return states

    def _save_checkpoint(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        epoch: int,
        metric: float
    ) -> Path:
        """
        Save a checkpoint with custom metadata and return its path.

        The file is named ``{prefix}_epoch_{epoch:03d}_metric_{metric:.4f}.ckpt``
        and contains model, optimizer and LR-scheduler states plus the
        trainer's current callback metrics.
        """
        filename = f"{self.filename_prefix}_epoch_{epoch:03d}_metric_{metric:.4f}.ckpt"
        checkpoint_path = self.save_dir / filename

        checkpoint = {
            "epoch": epoch,
            "state_dict": pl_module.state_dict(),
            "optimizer_states": [opt.state_dict() for opt in trainer.optimizers],
            "lr_schedulers": self._lr_scheduler_states(trainer),
            # Key name kept for compatibility with existing readers; the value
            # is the metrics dict, not Lightning's callback states.
            "callbacks": trainer.callback_metrics,
            "monitor_metric": metric,
            "monitor_name": self.monitor
        }

        torch.save(checkpoint, checkpoint_path)
        return checkpoint_path

    def _cleanup_old_checkpoints(self) -> None:
        """Delete checkpoints that are no longer among the ``save_top_k`` best."""
        if self.save_top_k < 0 or len(self.checkpoint_paths) <= self.save_top_k:
            return

        # Stable sort: among equal metrics the earlier checkpoint is kept.
        ranked = sorted(
            zip(self.best_metrics, self.checkpoint_paths),
            key=lambda pair: pair[0],
            reverse=(self.mode == "max"),
        )
        kept = ranked[:self.save_top_k]
        dropped = ranked[self.save_top_k:]

        kept_paths = {path for _, path in kept}
        for _, path in dropped:
            # A re-used filename may still belong to a kept entry; never delete it.
            if path in kept_paths:
                continue
            if path.exists():
                path.unlink()
                logger.info(f"Removed old checkpoint: {path}")

        self.best_metrics = [metric for metric, _ in kept]
        self.checkpoint_paths = [path for _, path in kept]


class MetricsLoggerCallback(Callback):
    """
    Callback for logging custom metrics.

    Every ``log_interval`` batches this callback can log the global L2 norm
    of the gradients (training only) and summary statistics of the batch
    predictions (training and validation).
    """

    def __init__(
        self,
        log_interval: int = 10,
        log_predictions: bool = False,
        log_gradients: bool = False
    ):
        """
        Initialize the metrics logger callback.

        Args:
            log_interval: Interval for logging (in batches); must be >= 1.
            log_predictions: Whether to log prediction statistics. Predictions
                are read from ``outputs["predictions"]`` (dict outputs) or
                ``outputs.predictions`` (attribute outputs).
            log_gradients: Whether to log gradient statistics.

        Raises:
            ValueError: If ``log_interval`` is smaller than 1.
        """
        super().__init__()
        if log_interval < 1:
            # A zero interval would otherwise raise ZeroDivisionError on the first batch.
            raise ValueError(f"log_interval must be >= 1, got {log_interval}")
        self.log_interval = log_interval
        self.log_predictions = log_predictions
        self.log_gradients = log_gradients

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int
    ) -> None:
        """Log gradient and prediction statistics for every ``log_interval``-th training batch."""
        if batch_idx % self.log_interval != 0:
            return

        if self.log_gradients:
            self._log_gradient_stats(pl_module)

        if self.log_predictions:
            predictions = self._extract_predictions(outputs)
            if predictions is not None:
                self._log_prediction_stats(predictions, trainer=trainer, stage="train")

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0
    ) -> None:
        """
        Log prediction statistics for every ``log_interval``-th validation batch.

        ``dataloader_idx`` is accepted because Lightning passes it when
        several validation dataloaders are used; it is otherwise unused.
        """
        if batch_idx % self.log_interval != 0:
            return

        if self.log_predictions:
            predictions = self._extract_predictions(outputs)
            if predictions is not None:
                self._log_prediction_stats(predictions, trainer=trainer, stage="val")

    @staticmethod
    def _extract_predictions(outputs: Any) -> Any:
        """Return the predictions carried by a step output, or None if there are none."""
        # training_step / validation_step commonly return dicts, which have
        # no ``predictions`` attribute, so look the key up explicitly.
        if isinstance(outputs, dict):
            return outputs.get("predictions")
        return getattr(outputs, "predictions", None)

    def _log_gradient_stats(self, pl_module: pl.LightningModule) -> None:
        """Log the global L2 gradient norm and the number of parameters with a gradient."""
        try:
            squared_norm_sum = 0.0
            param_count = 0

            for param in pl_module.parameters():
                if param.grad is None:
                    continue
                squared_norm_sum += param.grad.detach().norm(2).item() ** 2
                param_count += 1

            # Global norm = sqrt(sum of squared per-parameter norms).
            total_norm = squared_norm_sum ** 0.5

            pl_module.log("grad_norm", total_norm, on_step=True, on_epoch=False)
            pl_module.log("grad_param_count", param_count, on_step=True, on_epoch=False)

        except Exception as e:
            logger.warning(f"Failed to log gradient stats: {e}")

    def _log_prediction_stats(
        self,
        predictions: torch.Tensor,
        trainer: Optional[pl.Trainer] = None,
        stage: Optional[str] = None
    ) -> None:
        """
        Log mean/std/min/max of ``predictions`` to the trainer's experiment logger.

        Args:
            predictions: Prediction tensor; non-tensors and empty tensors are ignored.
            trainer: Trainer whose ``logger`` (e.g. TensorBoard) receives the
                metrics. Nothing is logged if it is None or has no logger.
            stage: Optional prefix (e.g. "train", "val") so that training and
                validation statistics do not overwrite each other.
        """
        try:
            if not isinstance(predictions, torch.Tensor) or predictions.numel() == 0:
                return

            values = predictions.detach()
            if not values.is_floating_point():
                # mean()/std() are not defined for integer tensors.
                values = values.float()

            prefix = f"{stage}_" if stage else ""
            metrics = {
                f"{prefix}pred_mean": values.mean().item(),
                f"{prefix}pred_std": values.std().item(),
                f"{prefix}pred_min": values.min().item(),
                f"{prefix}pred_max": values.max().item(),
            }

            # Log to the experiment logger (e.g. TensorBoard), if available.
            experiment_logger = getattr(trainer, "logger", None)
            if experiment_logger is not None:
                experiment_logger.log_metrics(metrics, step=trainer.global_step)

        except Exception as e:
            logger.warning(f"Failed to log prediction stats: {e}")


class EarlyStoppingCallback(Callback):
    """
    Custom early stopping callback with additional features.

    This callback provides enhanced early stopping functionality
    with custom metrics and patience strategies: it tracks the best value of
    the monitored metric, stops training after ``patience`` (+ ``cooldown``)
    validation runs without improvement, and can restore the best weights.
    """

    def __init__(
        self,
        monitor: str = "val_loss",
        min_delta: float = 0.0,
        patience: int = 10,
        mode: str = "min",
        restore_best_weights: bool = True,
        cooldown: int = 0
    ):
        """
        Initialize the early stopping callback.

        Args:
            monitor: Metric to monitor (a key of ``trainer.callback_metrics``).
            min_delta: Minimum change to qualify as improvement.
            patience: Number of validation runs without improvement before
                stopping is considered.
            mode: Mode for monitoring ("min" or "max").
            restore_best_weights: Whether to restore best weights when stopping.
            cooldown: Extra validation runs to wait once patience is exhausted
                before actually stopping. The grace period is re-armed on
                every improvement.

        Raises:
            ValueError: If ``mode`` is neither "min" nor "max".
        """
        super().__init__()
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.mode = mode
        self.restore_best_weights = restore_best_weights
        self.cooldown = cooldown

        self.wait_count = 0
        self.best_score = None
        self.best_weights = None
        # Remaining grace runs once patience is exhausted.
        self.cooldown_counter = cooldown

    @staticmethod
    def _snapshot_state(pl_module: pl.LightningModule) -> Dict[str, Any]:
        """
        Return an independent copy of the module's ``state_dict``.

        ``state_dict()`` tensors share storage with the live parameters, so a
        shallow ``dict.copy()`` would silently follow further training.
        Tensors are cloned to CPU to avoid holding a second copy on the GPU;
        ``load_state_dict`` copies them back onto the parameters' devices.
        """
        import copy

        state = pl_module.state_dict()
        snapshot = type(state)(
            (
                name,
                value.detach().to("cpu", copy=True)
                if isinstance(value, torch.Tensor)
                else copy.deepcopy(value),
            )
            for name, value in state.items()
        )
        # Preserve per-module version metadata used by load_state_dict.
        metadata = getattr(state, "_metadata", None)
        if metadata is not None:
            snapshot._metadata = copy.deepcopy(metadata)
        return snapshot

    def on_validation_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule
    ) -> None:
        """Update the best score and stop training when patience and cooldown run out."""
        # The pre-training sanity check evaluates an untrained model on a few
        # batches; it must not set the baseline or the "best" weights.
        if getattr(trainer, "sanity_checking", False):
            return

        raw_score = trainer.callback_metrics.get(self.monitor, None)

        if raw_score is None:
            logger.warning(f"Early stopping callback: {self.monitor} not found in metrics")
            return

        current_score = float(raw_score)

        if np.isnan(current_score):
            # A NaN must never become the reference value: every later
            # comparison against it would be False.
            logger.warning(f"Early stopping callback: {self.monitor} is NaN")
            improved = False
        elif self.best_score is None:
            improved = True
        elif self.mode == "min":
            improved = current_score < (self.best_score - self.min_delta)
        else:
            improved = current_score > (self.best_score + self.min_delta)

        if improved:
            self.best_score = current_score
            self.wait_count = 0
            self.cooldown_counter = self.cooldown
            if self.restore_best_weights:
                self.best_weights = self._snapshot_state(pl_module)
        else:
            self.wait_count += 1

        if self.wait_count < self.patience:
            return

        # Patience exhausted: consume the cooldown grace period first.
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            return

        best_repr = f"{self.best_score:.4f}" if self.best_score is not None else "n/a"
        logger.info(f"Early stopping triggered. Best {self.monitor}: {best_repr}")

        if self.restore_best_weights and self.best_weights is not None:
            pl_module.load_state_dict(self.best_weights)
            logger.info("Best weights restored")

        trainer.should_stop = True
