#!/usr/bin/env python3
"""
Evaluation metrics for PINN models.
"""

import torch
import numpy as np
from typing import Dict, Any


def compute_metrics(model: torch.nn.Module, data: Dict[str, torch.Tensor], problem_type: str) -> Dict[str, float]:
    """
    Compute evaluation metrics for a model on test data.

    The model is run in eval mode under ``torch.no_grad()``. Its previous
    training/eval mode is restored before returning (also on error), so
    calling this from inside a training loop does not silently leave
    dropout / batch-norm layers in eval mode for later training steps.

    Args:
        model: Trained model. If it has ``forward_with_uncertainty`` that
            method must return ``(predictions, uncertainty)``; otherwise, if
            it has ``predict_mean_variance`` (R-PIT models) that method must
            return ``(mean, variance)``; otherwise ``model(x)`` is used and no
            uncertainty metrics are reported.
        data: Test data; must contain the keys ``'x_data'`` (model inputs)
            and ``'y_data'`` (targets).
        problem_type: ``"lorenz"``, ``"burgers"`` or ``"inverse_poisson"``
            adds problem-specific per-component metrics; any other value
            adds none.

    Returns:
        Dictionary with ``mse``, ``mae``, ``rmse`` and ``relative_error``;
        plus ``mean_uncertainty`` / ``max_uncertainty`` / ``min_uncertainty``
        when the model provides uncertainty; plus any problem-specific
        entries.

    Raises:
        KeyError: if ``data`` lacks ``'x_data'`` or ``'y_data'``.
    """
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            x_data = data['x_data']
            y_data = data['y_data']

            # Get predictions (and, when available, a per-point std estimate)
            if hasattr(model, 'forward_with_uncertainty'):
                predictions, uncertainty = model.forward_with_uncertainty(x_data)
            elif hasattr(model, 'predict_mean_variance'):
                # For R-PIT models
                predictions, variance = model.predict_mean_variance(x_data)
                # Round-off can make a variance slightly negative; sqrt of that
                # is NaN, which would poison the uncertainty statistics.
                uncertainty = torch.sqrt(torch.clamp(variance, min=0.0))
            else:
                predictions = model(x_data)
                uncertainty = None

            # e.g. (N, 1) predictions vs (N,) targets would broadcast to
            # (N, N) and give a meaningless error; align shapes when the
            # element counts agree.
            if predictions.shape != y_data.shape and predictions.numel() == y_data.numel():
                predictions = predictions.reshape(y_data.shape)

            # Basic metrics
            abs_err = torch.abs(predictions - y_data)
            mse = torch.mean(abs_err ** 2).item()
            mae = torch.mean(abs_err).item()
            rmse = float(np.sqrt(mse))

            # Relative error (epsilon avoids division by zero at zero targets)
            relative_error = torch.mean(abs_err / (torch.abs(y_data) + 1e-8)).item()

            metrics = {
                'mse': mse,
                'mae': mae,
                'rmse': rmse,
                'relative_error': relative_error
            }

            # Uncertainty metrics
            if uncertainty is not None:
                metrics['mean_uncertainty'] = torch.mean(uncertainty).item()
                metrics['max_uncertainty'] = torch.max(uncertainty).item()
                metrics['min_uncertainty'] = torch.min(uncertainty).item()

            # Problem-specific metrics
            if problem_type == "lorenz":
                metrics.update(compute_lorenz_metrics(predictions, y_data))
            elif problem_type == "burgers":
                metrics.update(compute_burgers_metrics(predictions, y_data))
            elif problem_type == "inverse_poisson":
                metrics.update(compute_inverse_poisson_metrics(predictions, y_data))
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    return metrics


def compute_lorenz_metrics(predictions: torch.Tensor, y_data: torch.Tensor) -> Dict[str, float]:
    """Return the mean squared error of each Lorenz state component.

    Columns 0, 1 and 2 of ``predictions`` and ``y_data`` are treated as the
    x, y and z components respectively.
    """
    # Slice every column up front, mirroring the original evaluation order.
    columns = [(predictions[:, i], y_data[:, i]) for i in range(3)]
    keys = ('x_component_mse', 'y_component_mse', 'z_component_mse')
    return {
        key: torch.mean((pred_col - true_col) ** 2).item()
        for key, (pred_col, true_col) in zip(keys, columns)
    }


def compute_burgers_metrics(predictions: torch.Tensor, y_data: torch.Tensor) -> Dict[str, float]:
    """Return the mean squared error of the two Burgers velocity components.

    Column 0 is treated as ``u`` and column 1 as ``v`` in both tensors.
    """
    # Slice both columns before reducing, as the original does.
    paired = [(predictions[:, col], y_data[:, col]) for col in (0, 1)]
    result: Dict[str, float] = {}
    for key, (pred_col, true_col) in zip(('u_component_mse', 'v_component_mse'), paired):
        squared = (pred_col - true_col) ** 2
        result[key] = torch.mean(squared).item()
    return result


def compute_inverse_poisson_metrics(predictions: torch.Tensor, y_data: torch.Tensor) -> Dict[str, float]:
    """Return the MSE of the solution field and of the recovered source term.

    Column 0 of each tensor is treated as the solution ``u`` and column 1 as
    the source ``f``.
    """
    def _mse(a: torch.Tensor, b: torch.Tensor) -> float:
        # Mean squared difference, returned as a Python float.
        return torch.mean((a - b) ** 2).item()

    solution_pred, source_pred = predictions[:, 0], predictions[:, 1]
    solution_true, source_true = y_data[:, 0], y_data[:, 1]

    solution_err = _mse(solution_pred, solution_true)
    source_err = _mse(source_pred, source_true)
    return dict(solution_mse=solution_err, source_mse=source_err)
