"""
1D Inverse Poisson problem implementation.
"""
import torch
import numpy as np
from typing import Tuple, Dict, Any, Callable
from losses.physics_losses import PoissonPhysicsLoss


class InversePoisson:
    """
    1D Inverse Poisson problem for testing robustness to data corruption.

    The problem is: -∇²u = f(x) on [x_start, x_end] with homogeneous
    Dirichlet boundary conditions u(x_start) = u(x_end) = 0.
    Sparse, corrupted observations of u are generated from the true source
    f(x) so that a model can be trained to reconstruct it.
    """

    # Number of trapezoid nodes used when the solution has to be obtained
    # by numerical quadrature (i.e. for a user-supplied source function).
    _QUADRATURE_POINTS = 2001

    def __init__(
        self,
        x_start: float = 0.0,
        x_end: float = 1.0,
        source_function: Callable = None,
        device: str = "cpu"
    ):
        """
        Initialize inverse Poisson problem.

        Args:
            x_start: Start of spatial domain
            x_end: End of spatial domain (must be greater than x_start)
            source_function: True source function f(x), mapping a tensor of
                points to a tensor of source values. When None, the default
                f(x) = sin(2πx) + 0.5 sin(4πx) is used.
            device: Device to run on

        Raises:
            ValueError: If x_end is not strictly greater than x_start.
        """
        if x_end <= x_start:
            raise ValueError(
                f"x_end ({x_end}) must be greater than x_start ({x_start})"
            )

        self.x_start = x_start
        self.x_end = x_end
        self.device = device

        # Default source function. A named static method (rather than a
        # lambda) lets solve_analytical detect when its closed form applies.
        if source_function is None:
            self.source_function = InversePoisson._default_source
        else:
            self.source_function = source_function

        # Create physics loss
        self.physics_loss = PoissonPhysicsLoss(self.source_function, device)

    @staticmethod
    def _default_source(x: torch.Tensor) -> torch.Tensor:
        """Default source term f(x) = sin(2πx) + 0.5 sin(4πx)."""
        return torch.sin(2 * np.pi * x) + 0.5 * torch.sin(4 * np.pi * x)

    def solve_analytical(self, x: torch.Tensor) -> torch.Tensor:
        """
        Solve the Poisson equation for the configured source.

        For -u'' = f(x) on [a, b] with u(a) = u(b) = 0 the solution is
        u(x) = ∫ₐᵇ G(x, s) f(s) ds, with the Green's function
        G(x, s) = (min(x, s) - a) (b - max(x, s)) / (b - a).

        For the default source the integral is evaluated in closed form;
        for any other source it is approximated with the trapezoid rule.

        Args:
            x: Spatial points inside [x_start, x_end]

        Returns:
            Solution values, same shape as x
        """
        if self.source_function is not InversePoisson._default_source:
            return self._solve_by_quadrature(x)

        a = float(self.x_start)
        b = float(self.x_end)

        # Particular solution p with -p'' = f for the default source.
        c1 = 1.0 / (4 * np.pi**2)
        c2 = 0.5 / (16 * np.pi**2)
        particular = (c1 * torch.sin(2 * np.pi * x) +
                      c2 * torch.sin(4 * np.pi * x))

        # Enforce the boundary conditions by subtracting the straight line
        # through (a, p(a)) and (b, p(b)); a linear term has zero second
        # derivative, so the PDE stays satisfied. On [0, 1] the line is
        # (numerically) zero and the particular solution is returned as is.
        p_a = float(c1 * np.sin(2 * np.pi * a) + c2 * np.sin(4 * np.pi * a))
        p_b = float(c1 * np.sin(2 * np.pi * b) + c2 * np.sin(4 * np.pi * b))
        boundary_line = p_a + (p_b - p_a) * (x - a) / (b - a)

        return particular - boundary_line

    def _solve_by_quadrature(self, x: torch.Tensor) -> torch.Tensor:
        """Approximate u(x) = ∫ G(x, s) f(s) ds with the trapezoid rule."""
        a = float(self.x_start)
        b = float(self.x_end)
        n_nodes = self._QUADRATURE_POINTS

        nodes = torch.linspace(a, b, n_nodes, dtype=x.dtype, device=x.device)

        # Evaluate the source on a column of points, matching the (N, 1)
        # layout used for spatial points elsewhere in this class.
        source_values = torch.as_tensor(
            self.source_function(nodes.unsqueeze(1)),
            dtype=x.dtype,
            device=x.device,
        ).reshape(-1)
        if source_values.numel() == 1:
            # Constant source returned as a scalar: broadcast to all nodes.
            source_values = source_values.expand(n_nodes)

        # Trapezoid weights: full step inside, half step at both ends.
        weights = torch.full_like(nodes, (b - a) / (n_nodes - 1))
        weights[0] *= 0.5
        weights[-1] *= 0.5

        # Green's function evaluated for every (point, node) pair.
        x_col = x.reshape(-1, 1)
        green = ((torch.minimum(x_col, nodes) - a) *
                 (b - torch.maximum(x_col, nodes)) / (b - a))

        u = (green * (source_values * weights)).sum(dim=1)
        return u.reshape(x.shape)

    def generate_training_data(
        self,
        num_boundary_points: int = 2,
        num_interior_points: int = 20,
        corruption_level: float = 0.3,
        outlier_std: float = 3.0,
        missing_data_ratio: float = 0.1
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate training data with corruption.

        Args:
            num_boundary_points: Accepted for interface compatibility; the
                1D domain always contributes exactly its two end points.
            num_interior_points: Number of interior observation points
            corruption_level: Fraction of interior points that receive
                Gaussian outlier noise (counts are clamped to the number
                of interior points)
            outlier_std: Standard deviation for outliers
            missing_data_ratio: Fraction of interior points that are
                dropped (counts are clamped to the number of interior
                points)

        Returns:
            Tuple of (x_data, y_data) tensors, boundary points first
        """
        # Boundary points
        x_boundary = torch.tensor([[self.x_start], [self.x_end]], dtype=torch.float32, device=self.device)
        y_boundary = torch.zeros(2, 1, dtype=torch.float32, device=self.device)

        # Interior points, kept away from the boundary. The margin is 0.1
        # for domains of length >= 1 and shrinks to 10% of the length for
        # shorter domains so the interior range never inverts.
        margin = min(0.1, 0.1 * (self.x_end - self.x_start))
        x_interior = torch.linspace(
            self.x_start + margin,
            self.x_end - margin,
            num_interior_points,
            dtype=torch.float32,
            device=self.device
        ).unsqueeze(1)

        # True solution at interior points
        y_interior_true = self.solve_analytical(x_interior)

        # Add corruption (count clamped so the noise tensor always matches
        # the number of selected indices)
        n_corrupt = min(max(int(num_interior_points * corruption_level), 0), num_interior_points)
        corrupt_indices = torch.randperm(num_interior_points)[:n_corrupt]

        y_interior = y_interior_true.clone()
        y_interior[corrupt_indices] += torch.randn(
            n_corrupt, 1, dtype=y_interior.dtype, device=self.device
        ) * outlier_std

        # Add missing data
        n_missing = min(max(int(num_interior_points * missing_data_ratio), 0), num_interior_points)
        missing_indices = torch.randperm(num_interior_points)[:n_missing]

        # Remove missing data points
        keep_indices = torch.ones(num_interior_points, dtype=torch.bool, device=self.device)
        keep_indices[missing_indices] = False

        x_interior = x_interior[keep_indices]
        y_interior = y_interior[keep_indices]

        # Combine boundary and interior data
        x_data = torch.cat([x_boundary, x_interior], dim=0)
        y_data = torch.cat([y_boundary, y_interior], dim=0)

        return x_data, y_data

    def generate_collocation_points(self, num_points: int = 1000) -> torch.Tensor:
        """
        Generate collocation points for physics loss.

        Args:
            num_points: Number of collocation points

        Returns:
            Collocation points tensor of shape (num_points, 1), evenly
            spaced over [x_start, x_end] including both end points
        """
        x_points = torch.linspace(
            self.x_start,
            self.x_end,
            num_points,
            device=self.device
        ).unsqueeze(1)

        return x_points

    def compute_physics_loss(self, model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """
        Compute physics loss for the Poisson equation.

        Args:
            model: Neural network model
            x: Spatial points

        Returns:
            Mean squared PDE residual as a scalar tensor
        """
        residual = self.physics_loss.compute_residual(model, x)
        return torch.mean(residual**2)

    def get_problem_info(self) -> Dict[str, Any]:
        """Get problem information."""
        return {
            'problem_type': 'inverse_poisson',
            'x_start': self.x_start,
            'x_end': self.x_end,
            'input_dim': 1,  # Spatial coordinate only
            'output_dim': 1,  # Scalar field
        }
