"""
Base Physics-Informed Neural Network implementation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple, Optional
import numpy as np


class _Sine(nn.Module):
    """
    Element-wise sine activation wrapped as an ``nn.Module``.

    ``nn.Sequential`` only accepts ``nn.Module`` instances, so a bare
    function or lambda cannot be used as a layer. The module holds no
    parameters or buffers and therefore adds nothing to the state dict.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(x)


class BasePINN(nn.Module, ABC):
    """
    Abstract base class for Physics-Informed Neural Networks.

    This class provides the common interface and functionality for all PINN implementations,
    including standard PINNs, R-PIT PINNs, and Bayesian PINNs.

    Subclasses must implement ``compute_physics_loss`` and ``compute_data_loss``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_layers: list,
        activation: str = "tanh",
        device: str = "cpu"
    ):
        """
        Initialize the base PINN.

        Args:
            input_dim: Number of input dimensions (e.g., 2 for 1D+time, 3 for 2D+time)
            output_dim: Number of output dimensions (e.g., 1 for scalar field, 3 for vector field)
            hidden_layers: List of hidden layer sizes. An empty list yields a
                single linear layer from ``input_dim`` to ``output_dim``.
            activation: Activation function name ('tanh', 'relu', 'gelu',
                'swish' or 'sin'); unrecognised names fall back to tanh.
            device: Device to run on ('cpu' or 'cuda')
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_layers = hidden_layers
        self.activation = activation
        self.device = device

        # Build the neural network and move it to the requested device
        self.network = self._build_network()
        self.to(device)

        # Initialize weights
        self._initialize_weights()

    def _build_network(self) -> nn.Module:
        """
        Build the fully connected network.

        The layout is ``Linear -> act -> Linear -> ... -> act -> Linear``:
        every linear layer except the first is preceded by an activation,
        and no activation follows the output layer.
        """
        widths = [self.input_dim, *self.hidden_layers, self.output_dim]

        layers = []
        for index, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if index > 0:
                layers.append(self._get_activation())
            layers.append(nn.Linear(fan_in, fan_out))

        return nn.Sequential(*layers)

    def _get_activation(self) -> nn.Module:
        """
        Return a new activation module for ``self.activation``.

        Unknown names fall back to ``nn.Tanh``.
        """
        # Map names to factories so only the selected module is instantiated.
        factories = {
            'tanh': nn.Tanh,
            'relu': nn.ReLU,
            'gelu': nn.GELU,
            'swish': nn.SiLU,
            'sin': _Sine,
        }
        return factories.get(self.activation, nn.Tanh)()

    def _initialize_weights(self):
        """Initialize linear layers: Xavier-normal weights and zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network.

        Args:
            x: Input tensor of shape (batch_size, input_dim)

        Returns:
            Output tensor of shape (batch_size, output_dim)
        """
        return self.network(x)

    @abstractmethod
    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the physics loss (PDE residual).

        Args:
            x: Collocation points tensor

        Returns:
            Physics loss tensor
        """
        pass

    @abstractmethod
    def compute_data_loss(self, x_data: torch.Tensor, y_data: torch.Tensor) -> torch.Tensor:
        """
        Compute the data loss.

        Args:
            x_data: Data input points
            y_data: Data target values

        Returns:
            Data loss tensor
        """
        pass

    def compute_total_loss(
        self,
        x_collocation: torch.Tensor,
        x_data: torch.Tensor,
        y_data: torch.Tensor,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the total loss as the unweighted sum of physics and data losses.

        Args:
            x_collocation: Collocation points for physics loss
            x_data: Data input points
            y_data: Data target values
            **kwargs: Additional arguments for specific implementations
                (accepted but not used by this base implementation)

        Returns:
            Dictionary with keys 'total_loss', 'physics_loss' and 'data_loss'
        """
        physics_loss = self.compute_physics_loss(x_collocation)
        data_loss = self.compute_data_loss(x_data, y_data)

        total_loss = physics_loss + data_loss

        return {
            'total_loss': total_loss,
            'physics_loss': physics_loss,
            'data_loss': data_loss,
        }

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Make predictions on new data without tracking gradients.

        Note that this switches the module to evaluation mode and leaves it
        there; call ``train()`` again before resuming training.

        Args:
            x: Input tensor

        Returns:
            Predictions tensor
        """
        self.eval()
        with torch.no_grad():
            return self.forward(x)

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get model information.

        Returns:
            Dictionary with the constructor settings plus the total and
            trainable parameter counts.
        """
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return {
            'input_dim': self.input_dim,
            'output_dim': self.output_dim,
            'hidden_layers': self.hidden_layers,
            'activation': self.activation,
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'device': self.device,
        }

    def save_checkpoint(self, filepath: str, **kwargs):
        """
        Save model checkpoint.

        Args:
            filepath: Destination path passed to ``torch.save``
            **kwargs: Extra entries stored alongside the model state; a key
                named 'model_state_dict' or 'model_info' overrides the
                default entry of the same name.
        """
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'model_info': self.get_model_info(),
            **kwargs
        }
        torch.save(checkpoint, filepath)

    def load_checkpoint(self, filepath: str):
        """
        Load model checkpoint and restore the model weights.

        Args:
            filepath: Path of a checkpoint written by ``save_checkpoint``

        Returns:
            The full checkpoint dictionary, including any extra entries
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(filepath, map_location=self.device)
        self.load_state_dict(checkpoint['model_state_dict'])
        return checkpoint


class StandardPINN(BasePINN):
    """
    Baseline PINN whose losses are plain mean-squared errors.

    Serves as the reference model that R-PIT is compared against.
    """

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded untouched to BasePINN.
        super().__init__(*args, **kwargs)
        self.loss_type = "mse"

    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return a placeholder physics loss of zero.

        The collocation points ``x`` are ignored here; problem-specific
        subclasses are expected to override this with a real PDE residual.
        The scalar requires grad so it can take part in a backward pass.
        """
        return torch.zeros((), device=self.device, requires_grad=True)

    def compute_data_loss(self, x_data: torch.Tensor, y_data: torch.Tensor) -> torch.Tensor:
        """Return the mean-squared error between network outputs and ``y_data``."""
        return F.mse_loss(self.forward(x_data), y_data)
