"""
R-PIT (Robustness-Regularized Physics-Informed Training) implementation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, Tuple
import numpy as np
from .base_pinn import BasePINN


class RPITPINN(BasePINN):
    """
    Robustness-Regularized Physics-Informed Neural Network.

    This implementation includes:
    1. Noise injection for robustness
    2. Sensitivity regularization
    3. Variance-aware loss for uncertainty quantification

    When ``uncertainty_output`` is enabled the underlying network is built
    with twice the requested output width. The first ``base_output_dim``
    columns of the raw output are read as the predictive mean and the
    remaining columns as an unconstrained variance parameter that is mapped
    to a strictly positive variance in ``predict_mean_variance``.

    NOTE(review): ``self.network``, ``self.device`` and ``self.training`` are
    not defined in this class; they are presumably provided by ``BasePINN``
    (or ``torch.nn.Module``) -- confirm against the base class.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_layers: list,
        activation: str = "tanh",
        device: str = "cpu",
        lambda_sens: float = 0.1,
        lambda_var: float = 1.0,
        noise_std: float = 0.1,
        uncertainty_output: bool = True
    ):
        """
        Initialize R-PIT PINN.

        Args:
            input_dim: Number of input dimensions
            output_dim: Number of output dimensions (per-quantity; the
                network itself gets ``2 * output_dim`` outputs when
                ``uncertainty_output`` is True)
            hidden_layers: List of hidden layer sizes
            activation: Activation function name
            device: Device to run on
            lambda_sens: Sensitivity regularization weight
            lambda_var: Variance loss weight
            noise_std: Noise injection standard deviation
            uncertainty_output: Whether to output uncertainty estimates
        """
        self.lambda_sens = lambda_sens
        self.lambda_var = lambda_var
        self.noise_std = noise_std
        self.uncertainty_output = uncertainty_output

        # For uncertainty output, we need 2x output dimensions (mean + variance)
        actual_output_dim = output_dim * 2 if uncertainty_output else output_dim

        super().__init__(
            input_dim=input_dim,
            output_dim=actual_output_dim,
            hidden_layers=hidden_layers,
            activation=activation,
            device=device
        )

        # Width of the mean part of the output; used to split the raw output.
        self.base_output_dim = output_dim

    def forward(self, x: torch.Tensor, add_noise: bool = False) -> torch.Tensor:
        """
        Forward pass with optional noise injection.

        Noise is only injected when ``add_noise`` is True *and* the module is
        in training mode; otherwise the input is passed through unchanged.

        Args:
            x: Input tensor
            add_noise: Whether to add noise to inputs

        Returns:
            Raw network output (mean columns followed by unconstrained
            variance columns if uncertainty_output=True)
        """
        if add_noise and self.training:
            # Add zero-mean Gaussian noise with std ``noise_std`` to the inputs.
            # ``x + noise`` builds a new tensor, so the caller's tensor is
            # left untouched.
            noise = torch.randn_like(x) * self.noise_std
            x = x + noise

        return self.network(x)

    def predict_mean_variance(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict mean and variance separately.

        Args:
            x: Input tensor

        Returns:
            Tuple of (mean, variance) tensors. Without uncertainty output the
            variance is a zero tensor shaped like the mean.
        """
        if not self.uncertainty_output:
            mean = self.forward(x)
            variance = torch.zeros_like(mean)
            return mean, variance

        output = self.forward(x)
        mean = output[:, :self.base_output_dim]
        # Softplus maps the raw columns to (0, inf); the epsilon keeps the
        # variance bounded away from zero.
        variance = F.softplus(output[:, self.base_output_dim:]) + 1e-6

        return mean, variance

    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute physics loss using mean predictions.

        This base implementation is a placeholder that ignores ``x`` and
        returns a scalar zero; problem-specific subclasses are expected to
        override it.

        Args:
            x: Collocation points

        Returns:
            Physics loss tensor
        """
        # This is a placeholder - specific problem implementations will override
        # For now, return zero to avoid errors during testing
        return torch.tensor(0.0, device=self.device, requires_grad=True)

    def compute_data_loss(self, x_data: torch.Tensor, y_data: torch.Tensor) -> torch.Tensor:
        """
        Compute variance-aware data loss.

        Without uncertainty output this is a plain mean-squared error. With
        uncertainty output it is the mean Gaussian negative log-likelihood of
        ``y_data`` under the predicted mean and variance.

        Args:
            x_data: Data input points
            y_data: Data target values

        Returns:
            Data loss tensor
        """
        if not self.uncertainty_output:
            # Standard MSE loss
            predictions = self.forward(x_data)
            return F.mse_loss(predictions, y_data)

        # Variance-aware loss (negative log-likelihood)
        mean, variance = self.predict_mean_variance(x_data)

        # Guard the log and the division below against tiny variances, even
        # if predict_mean_variance is overridden.
        variance = torch.clamp(variance, min=1e-6)

        # Negative log-likelihood under Gaussian assumption
        nll = 0.5 * torch.log(2 * np.pi * variance) + 0.5 * (mean - y_data)**2 / variance
        return nll.mean()

    def compute_sensitivity_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute sensitivity regularization loss.

        The loss is the squared Frobenius norm of the Jacobian of the mean
        prediction with respect to the input, averaged over the batch. The
        forward pass is run with ``add_noise=True``, so in training mode the
        Jacobian is evaluated at the noise-perturbed inputs.

        One backward pass is performed per mean output component, so that
        sensitivities of different outputs are squared individually instead
        of being summed (and possibly cancelling) before squaring.

        The caller's tensor is not modified: if ``x`` does not already track
        gradients, a detached alias is used for differentiation.

        Args:
            x: Input points for sensitivity computation

        Returns:
            Sensitivity loss tensor (scalar, differentiable w.r.t. the
            network parameters)
        """
        # Differentiate w.r.t. a gradient-tracking tensor without flipping
        # ``requires_grad`` on the tensor owned by the caller.
        if not x.requires_grad:
            x = x.detach().requires_grad_(True)

        # Forward pass
        output = self.forward(x, add_noise=True)
        mean = output[:, :self.base_output_dim]

        # Accumulate the element-wise squared Jacobian, one output row at a
        # time. Summing a column over the batch before differentiating
        # assumes each sample's output depends only on its own input row.
        squared_jacobian = torch.zeros_like(x)
        for k in range(mean.shape[1]):
            (grad_k,) = torch.autograd.grad(
                outputs=mean[:, k].sum(),
                inputs=x,
                create_graph=True,  # keep the graph so the loss is trainable
            )
            squared_jacobian = squared_jacobian + grad_k ** 2

        # Squared Frobenius norm per sample, then batch mean
        sensitivity_loss = torch.mean(torch.sum(squared_jacobian, dim=1))

        return sensitivity_loss

    def compute_variance_loss(self, x_data: torch.Tensor) -> torch.Tensor:
        """
        Compute variance regularization loss.

        Args:
            x_data: Data input points

        Returns:
            Variance loss tensor: the mean predicted variance, or a scalar
            zero when uncertainty output is disabled
        """
        if not self.uncertainty_output:
            return torch.tensor(0.0, device=self.device, requires_grad=True)

        _, variance = self.predict_mean_variance(x_data)

        # Penalize excessive variance (encourage confident predictions)
        variance_loss = torch.mean(variance)

        return variance_loss

    def compute_total_loss(
        self,
        x_collocation: torch.Tensor,
        x_data: torch.Tensor,
        y_data: torch.Tensor,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the total R-PIT loss.

        The total is ``physics + data + lambda_sens * sensitivity +
        lambda_var * variance``. The sensitivity term is evaluated at the
        collocation points; the data and variance terms at the data points.

        Args:
            x_collocation: Collocation points for physics loss
            x_data: Data input points
            y_data: Data target values
            **kwargs: Additional arguments (accepted but unused here)

        Returns:
            Dictionary containing all loss components
        """
        # Standard losses
        physics_loss = self.compute_physics_loss(x_collocation)
        data_loss = self.compute_data_loss(x_data, y_data)

        # R-PIT specific losses
        sensitivity_loss = self.compute_sensitivity_loss(x_collocation)
        variance_loss = self.compute_variance_loss(x_data)

        # Total loss with all components
        total_loss = (
            physics_loss +
            data_loss +
            self.lambda_sens * sensitivity_loss +
            self.lambda_var * variance_loss
        )

        return {
            'total_loss': total_loss,
            'physics_loss': physics_loss,
            'data_loss': data_loss,
            'sensitivity_loss': sensitivity_loss,
            'variance_loss': variance_loss,
        }

    def get_uncertainty_estimates(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get uncertainty estimates for predictions.

        Args:
            x: Input points

        Returns:
            Uncertainty estimates (standard deviation); all zeros of shape
            ``(x.shape[0], base_output_dim)`` when uncertainty output is
            disabled
        """
        if not self.uncertainty_output:
            return torch.zeros(x.shape[0], self.base_output_dim, device=self.device)

        _, variance = self.predict_mean_variance(x)
        return torch.sqrt(variance)

    def get_model_info(self) -> Dict[str, Any]:
        """Get R-PIT specific model information.

        Returns:
            The base class's info dictionary extended with the R-PIT
            hyperparameters.
        """
        base_info = super().get_model_info()
        base_info.update({
            'lambda_sens': self.lambda_sens,
            'lambda_var': self.lambda_var,
            'noise_std': self.noise_std,
            'uncertainty_output': self.uncertainty_output,
            'base_output_dim': self.base_output_dim,
        })
        return base_info
