"""Training module for neural operator models.

This module provides the Trainer class, which runs the training/validation loop for
neural operator models with early stopping, checkpointing, and metric/visual logging.
"""

import os
from pathlib import Path
from typing import Optional, Dict, List

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.config import Config
from src.utils.logger import CustomLogger


class Trainer:
    """Trainer for neural operator models.

    Runs the epoch loop: one training pass and one validation pass per epoch,
    metric and visual logging, learning-rate scheduling, early stopping and
    checkpointing.

    The trainer moves each batch to ``config.training.device`` but does not move
    or wrap the model itself; the caller is responsible for placing the model.

    Parameters
    ----------
    model : nn.Module
        The neural operator model to train. It is called with the whole batch
        dict and its output is compared against ``batch['y']``.
    config : Config
        Experiment configuration. This class reads ``output_dir`` and, under
        ``training``: ``device``, ``n_epochs``, ``scheduler`` and
        ``early_stopping_patience``.
    logger : CustomLogger
        Logger for metrics, visuals and checkpoint artifacts.
    train_loader : DataLoader
        Training data loader. Must yield dicts of tensors containing key ``'y'``.
    val_loader : DataLoader
        Validation data loader, same batch format as ``train_loader``.
    criterion : nn.Module
        Loss function used for training.
    val_criterion : nn.Module
        Loss function used for validation.
    optimizer : torch.optim.Optimizer
        Optimizer for model parameters.
    scheduler : Optional[torch.optim.lr_scheduler._LRScheduler]
        Learning rate scheduler, if any.
    inference_engine
        Object whose ``visualize(vis_dir)`` method returns ``(visuals, dtype)``;
        called every ``visualization_interval`` epochs.
    checkpoint : Optional[dict], optional
        Previously saved checkpoint dict whose optimizer, scheduler and
        early-stopping state is restored at construction, by default None.

    Attributes
    ----------
    model : nn.Module
        The neural operator model.
    config : Config
        Experiment configuration.
    logger : CustomLogger
        Logger instance.
    device : torch.device
        Device batches are moved to.
    optimizer : torch.optim.Optimizer
        Optimizer instance.
    scheduler : Optional[torch.optim.lr_scheduler._LRScheduler]
        Learning rate scheduler.
    criterion : nn.Module
        Training loss function.
    val_criterion : nn.Module
        Validation loss function.
    vis_dir : Path
        ``<output_dir>/visualizations``; created on construction.
    best_val_loss : float
        Best validation loss achieved so far.
    epochs_without_improvement : int
        Number of consecutive epochs without improvement (early stopping counter).
    current_epoch : int
        Index of the epoch most recently started by ``train``; stored in checkpoints.
    """

    # Epoch intervals for logging visualisations and writing periodic checkpoints.
    visualization_interval: int = 10
    checkpoint_interval: int = 10

    def __init__(
            self,
            model: nn.Module,
            config: Config,
            logger: CustomLogger,
            train_loader: DataLoader,
            val_loader: DataLoader,
            criterion: nn.Module,
            val_criterion: nn.Module,
            optimizer: torch.optim.Optimizer,
            scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
            inference_engine,
            checkpoint: Optional[dict] = None
    ) -> None:
        self.model = model

        # General setup
        self.config = config
        self.logger = logger
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.inference_engine = inference_engine
        self.device = torch.device(config.training.device)

        # Output directories
        self.output_dir = self.config.output_dir
        self.vis_dir = Path(self.output_dir) / "visualizations"
        self.vis_dir.mkdir(parents=True, exist_ok=True)

        # Setup optimizer and criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.val_criterion = val_criterion

        # Training state
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0
        self.current_epoch = 0

        # Load initial checkpoint if provided
        if checkpoint is not None:
            self.load_checkpoint(checkpoint)

    def train(self) -> None:
        """Main training loop.

        For every epoch: trains, validates, logs metrics (plus visuals every
        ``visualization_interval`` epochs), steps the scheduler, updates the
        early-stopping state and writes checkpoints. Epoch numbering always
        starts at 0, also when the trainer was built from a checkpoint.
        """
        # NOTE(review): the range is inclusive of ``n_epochs``, i.e. n_epochs + 1
        # passes are run. Kept as-is because it may be intentional (so that the
        # final epoch lands on a logging/checkpoint interval) -- confirm.
        for epoch in range(self.config.training.n_epochs + 1):
            self.current_epoch = epoch

            # Training and validation
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.evaluate(epoch)
            val_loss = val_metrics["val_loss"]

            # Log metrics; the learning rate is read before the scheduler steps,
            # so it is the rate that was used during this epoch.
            metrics = {**train_metrics, **val_metrics, 'lr': self.optimizer.param_groups[0]['lr']}
            self.logger.log_metrics(metrics, step=epoch)
            if epoch % self.visualization_interval == 0:
                visuals, dtype = self.inference_engine.visualize(self.vis_dir)
                self.logger.log_visuals(visuals, step=epoch, dtype=dtype)

            self._step_scheduler(val_loss)

            # Early stopping
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.epochs_without_improvement = 0
                self.save_checkpoint("best_model.pt", aliases=["best", f"epoch_{epoch}"])
            else:
                self.epochs_without_improvement += 1
                if self.epochs_without_improvement >= self.config.training.early_stopping_patience:
                    self.logger.logger.info("Early stopping triggered")
                    break

            # Save regular checkpoint
            if epoch % self.checkpoint_interval == 0:
                self.save_checkpoint(f"checkpoint_epoch_{epoch}.pt")

    def _step_scheduler(self, val_loss: float) -> None:
        """Advance the learning-rate scheduler by one epoch, if one is configured."""
        if self.scheduler is None:
            return
        # Plateau-based schedulers need the monitored metric; the others do not.
        if self.config.training.scheduler.lower() == "reduce_on_plateau":
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Train for one epoch.

        Parameters
        ----------
        epoch : int
            Current epoch number (used for the progress-bar label only).

        Returns
        -------
        Dict[str, float]
            ``{"train_loss": mean of the per-batch losses}``.
        """
        self.model.train()
        total_loss = 0.0

        pbar = tqdm(self.train_loader, desc=f"Training Epoch {epoch}")
        for data in pbar:
            batch_loss = self._run_batch(data)

            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            # Read the scalar once: every .item() call forces a host sync.
            loss_value = batch_loss.item()
            total_loss += loss_value
            pbar.set_postfix({"loss": loss_value})

        return {"train_loss": total_loss / len(self.train_loader)}

    @torch.no_grad()
    def evaluate(self, epoch: int) -> Dict[str, float]:
        """Evaluate the model on the validation loader.

        Parameters
        ----------
        epoch : int
            Current epoch number (used for the progress-bar label only).

        Returns
        -------
        Dict[str, float]
            ``{"val_loss": mean of the per-batch validation losses}``.
        """
        self.model.eval()
        total_loss = 0.0

        pbar = tqdm(self.val_loader, desc=f"Validation Epoch {epoch}")
        for data in pbar:
            loss = self._run_batch(data, crit="val")
            total_loss += loss.item()

        return {"val_loss": total_loss / len(self.val_loader)}

    def _run_batch(self, data: Dict[str, torch.Tensor], crit="train") -> torch.Tensor:
        """Run a single batch through the model and compute the loss.

        Parameters
        ----------
        data : Dict[str, torch.Tensor]
            Batch dict. Every value is moved to ``self.device``, the whole dict is
            passed to the model, and ``data['y']`` is used as the target.
        crit : str, optional
            ``"train"`` selects ``self.criterion``; any other value selects
            ``self.val_criterion``. By default "train".

        Returns
        -------
        torch.Tensor
            Computed loss for the batch.
        """
        criterion = self.criterion if crit == "train" else self.val_criterion
        data = {k: v.to(self.device) for k, v in data.items()}

        output = self.model(data)
        return criterion(output, data['y'])

    def save_checkpoint(self, filename: str, aliases: Optional[List[str]] = None) -> None:
        """Save a training checkpoint and register it as a logger artifact.

        Parameters
        ----------
        filename : str
            Name of the checkpoint file, relative to ``config.output_dir``.
        aliases : Optional[List[str]], optional
            Aliases forwarded to ``logger.log_artifact``, by default None.
        """
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler is not None else None,
            # 'epoch' is the real epoch index; the early-stopping counter has
            # its own key (older checkpoints stored the counter under 'epoch').
            'epoch': self.current_epoch,
            'epochs_without_improvement': self.epochs_without_improvement,
            'best_val_loss': self.best_val_loss,
            # NOTE(review): this pickles the whole Config object, so reading the
            # file back requires the Config class to be importable and a
            # non-weights-only torch.load.
            'config': self.config
        }

        save_path = os.path.join(self.config.output_dir, filename)
        parent_dir = os.path.dirname(save_path)
        if parent_dir:  # os.makedirs("") raises, e.g. for an empty output_dir
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(checkpoint, save_path)
        self.logger.log_artifact(save_path, filename, aliases=aliases)

    def load_checkpoint(self, checkpoint: dict) -> None:
        """Restore optimizer, scheduler and early-stopping state from a checkpoint.

        NOTE(review): model weights (``checkpoint['model_state_dict']``) are not
        restored here -- confirm that the caller loads them into the model.

        Parameters
        ----------
        checkpoint : dict
            Checkpoint dict as written by ``save_checkpoint`` (already loaded
            into memory).
        """
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler is not None and checkpoint['scheduler_state_dict']:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if 'epochs_without_improvement' in checkpoint:
            self.epochs_without_improvement = checkpoint['epochs_without_improvement']
        else:
            # Legacy checkpoints stored the early-stopping counter under 'epoch'.
            self.epochs_without_improvement = checkpoint['epoch']
        self.best_val_loss = checkpoint['best_val_loss']
