"""
Loss functions for BSNP training.
Includes data likelihood and physics-informed losses.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional

from physics.pde_operators import NonlinearPoissonOperator


class PIConvNPLoss(nn.Module):
    """
    Combined loss for Physics-Informed ConvNP.

    L_total = λ_data * L_data + λ_physics(t) * L_physics + λ_reg * L_reg

    where λ_physics(t) is linearly interpolated from ``lambda_physics`` to
    ``physics_weight_final`` over ``physics_warmup_steps`` steps, driven by
    ``self.current_step`` (expected to be updated externally by the trainer).
    """

    def __init__(
        self,
        lambda_data: float = 1.0,
        lambda_physics: float = 0.0001,  # Start small
        lambda_reg: float = 0.0,
        use_physics_loss: bool = True,
        pde_operator: Optional[nn.Module] = None,
        physics_warmup_steps: int = 5000,
        physics_weight_final: float = 0.01
    ):
        """
        Args:
            lambda_data: Weight of the Gaussian NLL data term.
            lambda_physics: Physics weight at step 0 (start of warmup).
            lambda_reg: Weight of the L2 penalty on trainable parameters.
            use_physics_loss: Whether the physics term is computed at all.
            pde_operator: Module returning PDE residuals; a
                ``NonlinearPoissonOperator`` is built when this is None and
                ``use_physics_loss`` is True.
            physics_warmup_steps: Number of steps over which the physics
                weight ramps linearly.
            physics_weight_final: Physics weight after warmup.
        """
        super().__init__()

        self.lambda_data = lambda_data
        self.lambda_physics_initial = lambda_physics
        self.lambda_physics_final = physics_weight_final
        self.lambda_reg = lambda_reg
        self.use_physics_loss = use_physics_loss
        self.warmup_steps = physics_warmup_steps

        # Current training step (updated by trainer)
        self.current_step = 0

        # PDE operator
        if pde_operator is None and use_physics_loss:
            self.pde_operator = NonlinearPoissonOperator()
        else:
            self.pde_operator = pde_operator

    def get_physics_weight(self) -> float:
        """Return the current physics weight, linearly warmed up by step."""
        if self.current_step < self.warmup_steps:
            alpha = self.current_step / self.warmup_steps
            return (1 - alpha) * self.lambda_physics_initial + alpha * self.lambda_physics_final
        return self.lambda_physics_final

    @staticmethod
    def _is_finite(value: torch.Tensor) -> bool:
        """Return True when every element of *value* is finite (no NaN/Inf)."""
        return bool(torch.isfinite(value).all())

    def forward(
        self,
        model: nn.Module,
        x_context: torch.Tensor,
        y_context: torch.Tensor,
        x_target: torch.Tensor,
        y_target: torch.Tensor,
        lambda_params: Optional[torch.Tensor] = None,
        x_collocation: Optional[torch.Tensor] = None,
        boundary_values: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute total loss with numerical stability checks.

        Args:
            model: BSNP model; called as
                ``model(x_context, y_context, x_target, lambda_params)`` and
                expected to return ``(mean, sigma)``.
            x_context: Context locations
            y_context: Context observations
            x_target: Target locations
            y_target: Target values
            lambda_params: PDE parameters (batch, n_params); the physics term
                is skipped when this is None.
            x_collocation: Collocation points (unused, we use grid)
            boundary_values: Boundary values (unused, hard constraints)

        Returns:
            total_loss: Scalar loss
            loss_dict: Python floats for 'total', 'data', 'physics', 'reg'
                and the 'physics_weight' used for this call.
        """
        device = x_context.device

        # Get predictions
        mean, sigma = model(x_context, y_context, x_target, lambda_params)

        # Data likelihood loss
        data_loss = self.compute_nll_loss(mean, sigma, y_target)
        if not self._is_finite(data_loss):
            print("WARNING: NaN/Inf in data loss!")
            # The replacement is a constant: no gradient flows through the
            # data term for this batch.
            data_loss = torch.tensor(0.0, device=device)

        # Physics loss with current (warmed-up) weight
        if self.use_physics_loss and lambda_params is not None:
            physics_loss = self.compute_physics_loss(
                model, x_context, y_context, x_collocation, lambda_params
            )
            if not self._is_finite(physics_loss):
                print("WARNING: NaN/Inf in physics loss! Setting to 0.")
                physics_loss = torch.tensor(0.0, device=device)
            current_physics_weight = self.get_physics_weight()
        else:
            physics_loss = torch.tensor(0.0, device=device)
            current_physics_weight = 0.0

        # Regularization
        reg_loss = self.compute_regularization_loss(model)

        # Total loss
        total_loss = (
            self.lambda_data * data_loss +
            current_physics_weight * physics_loss +
            self.lambda_reg * reg_loss
        )

        # Final NaN check (e.g. an overflowing regularization term)
        if not self._is_finite(total_loss):
            print("WARNING: NaN/Inf in total loss! Using data loss only.")
            total_loss = data_loss

        # float() accepts both 0-dim tensors and plain numbers.
        loss_dict = {
            'total': total_loss.item(),
            'data': data_loss.item(),
            'physics': float(physics_loss),
            'reg': float(reg_loss),
            'physics_weight': current_physics_weight
        }

        return total_loss, loss_dict

    def compute_nll_loss(
        self,
        mean: torch.Tensor,
        sigma: torch.Tensor,
        target: torch.Tensor,
        min_sigma: float = 1e-6
    ) -> torch.Tensor:
        """
        Compute the Gaussian negative log-likelihood, averaged over elements.

        Args:
            mean: Predicted mean, shape (batch, output_dim, n_target)
            sigma: Predicted std dev, shape (batch, output_dim, n_target)
            target: Target values, shape (batch, n_target, output_dim) or (batch, n_target)
            min_sigma: Floor applied to ``sigma`` in the forward pass so a
                collapsed std dev cannot produce ``log(0)`` or a division by
                zero. Gradients pass through the floor unchanged.

        Returns:
            Scalar NLL loss
        """
        # Handle target shape
        if target.dim() == 2:
            target = target.unsqueeze(-1)  # (batch, n_target, 1)

        # Transpose target to match mean/sigma
        target = target.transpose(1, 2)  # (batch, output_dim, n_target)

        # Straight-through floor (same scheme as F.gaussian_nll_loss): the
        # clamp is done in-place under no_grad on a clone, so values below
        # min_sigma are lifted while the gradient still reaches `sigma`.
        sigma = sigma.clone()
        with torch.no_grad():
            sigma.clamp_(min=min_sigma)

        nll = 0.5 * torch.log(2 * torch.pi * sigma**2) + 0.5 * ((target - mean) / sigma)**2

        return nll.mean()

    def compute_physics_loss(
        self,
        model: nn.Module,
        x_context: torch.Tensor,
        y_context: torch.Tensor,
        x_collocation: Optional[torch.Tensor],
        lambda_params: torch.Tensor,
        num_collocation: int = 100
    ) -> torch.Tensor:
        """
        Compute a physics-informed loss from PDE residuals on the model's grid.

        The mean field is obtained from ``model.get_mean_field_on_grid`` so
        that gradients flow back into the model; the residual returned by
        ``self.pde_operator`` is penalised with a Huber (smooth L1) loss
        against zero.

        Args:
            model: The BSNP model (must expose ``grid_manager`` and
                ``get_mean_field_on_grid``)
            x_context: Context locations
            y_context: Context observations
            x_collocation: Not used (we use the model's internal grid)
            lambda_params: PDE parameters, either a tensor whose last column
                is the nonlinearity parameter ``w`` and whose other columns
                are Chebyshev coefficients ``xi``, or a dict with those keys
            num_collocation: Not used

        Returns:
            Physics loss (scalar). Any exception raised while computing it is
            printed and a constant 0.0 tensor is returned instead.

        Raises:
            TypeError: If ``lambda_params`` is neither a tensor nor a dict.
        """
        if not self.use_physics_loss:
            return torch.tensor(0.0, device=x_context.device)

        batch_size = x_context.shape[0]

        # Convert lambda_params to the dict format the PDE operator consumes
        if isinstance(lambda_params, torch.Tensor):
            lambda_params_dict = {
                'xi': lambda_params[:, :-1],  # Chebyshev coefficients
                'w': lambda_params[:, -1:]    # Nonlinearity parameter
            }
        elif isinstance(lambda_params, dict):
            lambda_params_dict = lambda_params
        else:
            raise TypeError(f"Unexpected lambda_params type: {type(lambda_params)}")

        try:
            # Grid points from the model; presumably (batch, n_grid, 1).
            x_grid = model.grid_manager.get_grid(batch_size)

            # Fresh leaf tensor with gradients enabled; detach before clone
            # so the copy itself is not recorded in the autograd graph.
            x_grid = x_grid.detach().clone().requires_grad_(True)

            # NOTE(review): mean_field below is not computed from this
            # x_grid tensor (x_grid is not passed to get_mean_field_on_grid),
            # so if the PDE operator differentiates mean_field w.r.t. x_grid
            # via autograd there is no graph path between them. Confirm the
            # operator's differentiation scheme (e.g. finite differences).
            with torch.set_grad_enabled(True):
                mean_field = model.get_mean_field_on_grid(
                    x_context,
                    y_context,
                    lambda_params
                )
                # Expected shape: (batch, output_dim, n_grid); add the
                # channel axis if the model returned (batch, n_grid).
                if mean_field.dim() == 2:
                    mean_field = mean_field.unsqueeze(1)

                # Compute PDE residual
                residual = self.pde_operator(
                    mean_field,
                    x_grid,
                    lambda_params_dict
                )

                # Huber loss (smooth L1) is less sensitive to outlier residuals
                physics_loss = F.smooth_l1_loss(
                    residual,
                    torch.zeros_like(residual),
                    reduction='mean',
                    beta=1.0  # Transition point between L1 and L2
                )

        except Exception as e:
            # Deliberate best-effort: a failing physics term must not abort
            # training, so report it and fall back to a constant zero.
            print(f"Warning: Physics loss computation failed: {e}")
            import traceback
            traceback.print_exc()
            physics_loss = torch.tensor(0.0, device=x_context.device)

        return physics_loss

    def compute_regularization_loss(self, model: nn.Module) -> torch.Tensor:
        """
        Compute the sum of squared values of all trainable model parameters.

        Always returns a tensor: a 0.0 tensor when ``lambda_reg`` is 0, when
        no parameter requires grad, or when the model has no parameters at
        all. It is placed on the device of the model's first parameter, or on
        the CPU for a parameterless model.
        """
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

        trainable = [param for param in model.parameters() if param.requires_grad]
        if self.lambda_reg == 0 or not trainable:
            return torch.tensor(0.0, device=device)

        # sum() starts from int 0, so the result keeps the parameters' dtype.
        return sum(torch.sum(param ** 2) for param in trainable)


class _MSELoss(nn.Module):
    """
    Plain mean-squared-error data loss with the same call signature and
    loss-dict keys as ``PIConvNPLoss``.

    Defined at module level (not inside ``build_loss_function``) so that
    instances can be pickled, e.g. by ``torch.save``, and so every call
    shares one class.
    """

    def forward(
        self,
        model: nn.Module,
        x_context: torch.Tensor,
        y_context: torch.Tensor,
        x_target: torch.Tensor,
        y_target: torch.Tensor,
        lambda_params: Optional[torch.Tensor] = None,
        x_collocation: Optional[torch.Tensor] = None,
        boundary_values: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Return the MSE between the model's predicted mean and ``y_target``.

        ``x_collocation``, ``boundary_values`` and any extra keyword
        arguments are accepted for interface compatibility and ignored.

        Returns:
            loss: Scalar MSE tensor.
            loss_dict: 'total' and 'data' hold the MSE; 'physics', 'reg' and
                'physics_weight' are always 0.0.
        """
        mean, _ = model(x_context, y_context, x_target, lambda_params)
        # Bring targets to (batch, output_dim, n_target) to match `mean`.
        if y_target.dim() == 2:
            y_target = y_target.unsqueeze(-1)
        y_target = y_target.transpose(1, 2)
        loss = F.mse_loss(mean, y_target)
        value = loss.item()
        return loss, {
            'total': value,
            'data': value,
            'physics': 0.0,
            'reg': 0.0,
            'physics_weight': 0.0
        }


def build_loss_function(
    loss_type: str = 'pi_convnp',
    lambda_data: float = 1.0,
    lambda_physics: float = 0.0001,
    lambda_reg: float = 1e-6,
    physics_warmup_steps: int = 5000,
    physics_weight_final: float = 0.01
) -> nn.Module:
    """
    Build loss function.

    Args:
        loss_type: Type of loss ('pi_convnp', 'mse', 'nll')
        lambda_data: Weight for data loss (PIConvNPLoss variants only)
        lambda_physics: Initial weight for physics loss ('pi_convnp' only)
        lambda_reg: Weight for regularization (PIConvNPLoss variants only)
        physics_warmup_steps: Number of warmup steps ('pi_convnp' only)
        physics_weight_final: Final physics weight after warmup ('pi_convnp' only)

    Returns:
        Loss function module. 'pi_convnp' is NLL + physics + L2, 'nll' is
        NLL + L2 without physics, and 'mse' is an unweighted MSE on the mean.

    Raises:
        ValueError: If ``loss_type`` is not recognised.
    """
    if loss_type == 'pi_convnp':
        return PIConvNPLoss(
            lambda_data=lambda_data,
            lambda_physics=lambda_physics,
            lambda_reg=lambda_reg,
            use_physics_loss=True,
            pde_operator=NonlinearPoissonOperator(),
            physics_warmup_steps=physics_warmup_steps,
            physics_weight_final=physics_weight_final
        )
    if loss_type == 'nll':
        return PIConvNPLoss(
            lambda_data=lambda_data,
            lambda_physics=0.0,
            lambda_reg=lambda_reg,
            use_physics_loss=False
        )
    if loss_type == 'mse':
        return _MSELoss()
    raise ValueError(f"Unknown loss type: {loss_type}")