#!/usr/bin/env python3
"""
Training utilities for PINN models.
"""

import torch
import torch.optim as optim
from typing import Dict, Any, List, Optional
import numpy as np


class PINNTrainer:
    """Simple full-batch trainer for PINN models.

    The wrapped model must provide ``compute_total_loss(x_collocation,
    x_data, y_data)`` returning a dict that contains a ``'total_loss'``
    tensor. It may optionally provide ``forward_with_uncertainty(x)``
    returning ``(predictions, uncertainty)``, which ``evaluate`` prefers
    over a plain forward call when present.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        epochs: int = 1000,
        lr: float = 1e-3,
        batch_size: int = 1000,
        log_frequency: int = 100,
        device: str = "cpu",
        **kwargs
    ):
        """
        Initialize trainer.

        Args:
            model: PINN model to train; moved to ``device`` in place.
            epochs: Number of training epochs (one optimizer step each).
            lr: Learning rate for the Adam optimizer.
            batch_size: Stored on the instance only; ``train`` is full-batch
                and does not read it.
            log_frequency: Print the loss every ``log_frequency`` epochs and
                on the final epoch. A value <= 0 disables logging.
            device: Device to use for training.
            **kwargs: Accepted and ignored (presumably so callers can pass a
                broader config dict; confirm against callers).
        """
        self.model = model
        self.epochs = epochs
        self.lr = lr
        self.batch_size = batch_size
        self.log_frequency = log_frequency
        self.device = device

        # Move model to device before building the optimizer over its params.
        self.model.to(device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Per-epoch total losses from the most recent call to train().
        self.losses_history: List[float] = []

    def train(self, data: Dict[str, torch.Tensor]) -> List[float]:
        """
        Train the model on the full dataset for ``self.epochs`` epochs.

        Args:
            data: Dictionary containing training data
                - 'x_collocation': Collocation points
                - 'x_data': Data points
                - 'y_data': Target values

        Returns:
            List of total losses, one per epoch. The same list is also stored
            on ``self.losses_history``, replacing any previous history.

        Raises:
            KeyError: If a required key is missing from ``data`` or the
                model's loss dict lacks ``'total_loss'``.
        """
        self.model.train()

        x_collocation = data['x_collocation'].to(self.device)
        x_data = data['x_data'].to(self.device)
        y_data = data['y_data'].to(self.device)

        # Guard the modulo below: log_frequency == 0 used to raise
        # ZeroDivisionError; non-positive values now simply disable logging.
        log_enabled = self.log_frequency > 0
        last_epoch = self.epochs - 1

        losses_history: List[float] = []

        for epoch in range(self.epochs):
            self.optimizer.zero_grad()

            losses = self.model.compute_total_loss(x_collocation, x_data, y_data)
            total_loss = losses['total_loss']

            total_loss.backward()
            self.optimizer.step()

            # Read the scalar once; .item() copies the value to the host.
            loss_value = total_loss.item()
            losses_history.append(loss_value)

            if log_enabled and (epoch % self.log_frequency == 0 or epoch == last_epoch):
                print(f"Epoch {epoch:5d}: Total Loss = {loss_value:.6f}")

        self.losses_history = losses_history
        return losses_history

    def evaluate(self, data: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """
        Evaluate the model without tracking gradients.

        Args:
            data: Dictionary containing test data with keys 'x_data'
                (inputs) and 'y_data' (targets).

        Returns:
            Dictionary with 'mse', 'mae' and 'rmse', plus 'mean_uncertainty'
            when the model exposes ``forward_with_uncertainty``.

        Note:
            Leaves the model in eval mode; ``train`` switches it back.
        """
        self.model.eval()

        with torch.no_grad():
            x_data = data['x_data'].to(self.device)
            y_data = data['y_data'].to(self.device)

            if hasattr(self.model, 'forward_with_uncertainty'):
                predictions, uncertainty = self.model.forward_with_uncertainty(x_data)
            else:
                predictions = self.model(x_data)
                uncertainty = None

            # A (N,) prediction against (N, 1) targets would silently broadcast
            # to (N, N) and yield meaningless metrics; align shapes when the
            # element counts agree.
            if predictions.shape != y_data.shape and predictions.numel() == y_data.numel():
                predictions = predictions.reshape(y_data.shape)

            residual = predictions - y_data
            mse = torch.mean(residual ** 2).item()
            mae = torch.mean(torch.abs(residual)).item()

            metrics = {
                'mse': mse,
                'mae': mae,
                # Plain float (not np.float64) to match the annotated return type.
                'rmse': float(np.sqrt(mse)),
            }

            if uncertainty is not None:
                metrics['mean_uncertainty'] = torch.mean(uncertainty).item()

        return metrics
