"""
Complete Physics-Informed Convolutional Neural Process model.
Combines all components into the full architecture.
"""

import torch
import torch.nn as nn
from typing import Dict, Tuple, Optional
import math

from models.encoders import SetEncoder, ParameterEncoder
from models.kernels import build_kernel
from models.aggregator import MultiChannelAggregator
from models.conv_backbone import build_conv_backbone
from models.decoder import build_decoder
from models.grid_manager import GridManager


class PIConvNP(nn.Module):
    """
    Physics-Informed Convolutional Neural Process.

    Architecture:
    1. Encode context set: (x_c, y_c) -> latent features
    2. Aggregate onto regular grid using kernels
    3. Process with convolutional backbone
    4. Decode to predictive distribution

    The model can be conditioned on PDE parameters λ. Conditioning is only
    wired into the network when ``use_parameter_conditioning`` is True AND a
    ``parameter_dim`` is supplied; otherwise no parameter encoder is built
    and ``lambda_params`` is ignored at call time.
    """

    def __init__(
        self,
        # Problem specification
        spatial_dim: int,
        observation_dim: int,
        output_dim: int = 1,

        # Domain and grid
        grid_resolution: Tuple[int, ...] = (256,),
        domain_bounds: Tuple[Tuple[float, float], ...] = ((-1.0, 1.0),),

        # Architecture dimensions
        latent_dim: int = 64,
        observation_encoder_dim: int = 64,
        conv_channels: int = 64,

        # Encoder settings
        encoder_hidden_dims: Tuple[int, ...] = (64, 64),

        # Kernel settings
        kernel_type: str = 'rbf',
        kernel_lengthscale: float = 0.1,
        kernel_learnable: bool = True,

        # Backbone settings
        num_conv_blocks: int = 6,
        conv_kernel_size: int = 3,
        use_unet: bool = False,

        # Decoder settings
        min_sigma: float = 1e-4,

        # Parameter conditioning
        parameter_dim: Optional[int] = None,
        use_parameter_conditioning: bool = False,

        # Other
        activation: str = 'swish',
        device: str = 'cpu'
    ):
        """
        Build every sub-component of the model.

        Args:
            spatial_dim: Number of spatial coordinates per location.
            observation_dim: Number of observed channels per context point.
            output_dim: Number of predicted channels.
            grid_resolution: Grid size along each spatial axis.
            domain_bounds: (low, high) bounds along each spatial axis.
            latent_dim: Width of the per-point latent features.
            observation_encoder_dim: Forwarded to the set encoder.
            conv_channels: Hidden channel count of the backbone (and the
                decoder's input channel count).
            encoder_hidden_dims: Hidden sizes for set/parameter encoders.
            kernel_type: Kernel name forwarded to ``build_kernel``.
            kernel_lengthscale: Initial kernel lengthscale.
            kernel_learnable: Whether the kernel is learnable.
            num_conv_blocks: Number of backbone blocks.
            conv_kernel_size: Backbone convolution kernel size.
            use_unet: Forwarded to ``build_conv_backbone``.
            min_sigma: Forwarded to ``build_decoder``.
            parameter_dim: Dimensionality of the PDE parameters, if any.
            use_parameter_conditioning: Request parameter conditioning.
            activation: Activation name forwarded to the sub-modules.
            device: Initial device handed to the grid manager.
        """
        super().__init__()

        self.spatial_dim = spatial_dim
        self.observation_dim = observation_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.use_parameter_conditioning = use_parameter_conditioning
        self._device = device

        # Conditioning is only active when it is requested and the parameter
        # dimensionality is known; the backbone width below depends on this.
        conditioning_enabled = (
            use_parameter_conditioning and parameter_dim is not None
        )

        # Grid manager
        self.grid_manager = GridManager(
            spatial_dim=spatial_dim,
            grid_resolution=grid_resolution,
            domain_bounds=domain_bounds,
            device=device
        )

        # Set encoder
        self.set_encoder = SetEncoder(
            spatial_dim=spatial_dim,
            observation_dim=observation_dim,
            latent_dim=latent_dim,
            observation_encoder_dim=observation_encoder_dim,
            hidden_dims=encoder_hidden_dims,
            activation=activation
        )

        # Parameter encoder (optional)
        if conditioning_enabled:
            self.parameter_encoder = ParameterEncoder(
                input_dim=parameter_dim,
                output_dim=latent_dim,
                hidden_dims=encoder_hidden_dims,
                activation=activation
            )
            self.parameter_dim = parameter_dim
        else:
            self.parameter_encoder = None
            self.parameter_dim = None

        # Kernel for aggregation
        self.kernel = build_kernel(
            kernel_type=kernel_type,
            initial_lengthscale=kernel_lengthscale,
            learnable=kernel_learnable
        )

        # Aggregator
        self.aggregator = MultiChannelAggregator(
            kernel=self.kernel,
            normalize=True,
            include_density=True
        )

        # Input channels to backbone: latent features plus one extra channel
        # (see the ``latent_dim+1`` shape documented on aggregate_to_grid),
        # plus the broadcast parameter encoding when conditioning is active.
        backbone_input_channels = latent_dim + 1
        if conditioning_enabled:
            backbone_input_channels += latent_dim

        # Convolutional backbone
        self.backbone = build_conv_backbone(
            spatial_dim=spatial_dim,
            input_channels=backbone_input_channels,
            hidden_channels=conv_channels,
            num_blocks=num_conv_blocks,
            kernel_size=conv_kernel_size,
            activation=activation,
            use_unet=use_unet
        )

        # Decoder
        # NOTE(review): parameter_dim is forwarded while the decoder's own
        # use_parameter_conditioning flag is hard-coded to False - confirm
        # this is intentional (conditioning currently enters via the grid).
        self.decoder = build_decoder(
            spatial_dim=spatial_dim,
            input_channels=conv_channels,
            output_dim=output_dim,
            parameter_dim=latent_dim if use_parameter_conditioning else None,
            min_sigma=min_sigma,
            use_parameter_conditioning=False
        )

    @property
    def device(self):
        """Get the device of the model (device of its first parameter)."""
        return next(self.parameters()).device

    def to(self, *args, **kwargs):
        """
        Move/cast the model and keep the grid manager on the same device.

        Accepts every call form of ``nn.Module.to`` (device, dtype,
        ``(device, dtype)``, a reference tensor, ``non_blocking=...``).
        The grid manager is always handed the device the parameters actually
        ended up on, so a dtype-only call never passes a dtype to it.

        Returns:
            self
        """
        super().to(*args, **kwargs)

        # Read the device back from the parameters instead of parsing the
        # arguments: this is correct for every overload of nn.Module.to.
        anchor = next(self.parameters(), None)
        if anchor is not None:
            self.grid_manager.to(anchor.device)
            self._device = str(anchor.device)
        return self

    def encode_context(
        self,
        x_context: torch.Tensor,
        y_context: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode context set into latent features.

        Args:
            x_context: Context locations, shape (batch, n_context, spatial_dim)
            y_context: Context observations, shape (batch, n_context, obs_dim)

        Returns:
            Latent features, shape (batch, n_context, latent_dim)
        """
        return self.set_encoder(x_context, y_context)

    def aggregate_to_grid(
        self,
        x_context: torch.Tensor,
        latent_features: torch.Tensor,
        batch_size: int
    ) -> torch.Tensor:
        """
        Aggregate context features onto regular grid.

        Args:
            x_context: Context locations, shape (batch, n_context, spatial_dim)
            latent_features: Encoded features, shape (batch, n_context, latent_dim)
            batch_size: Batch size

        Returns:
            Grid features, shape (batch, latent_dim+1, *grid_shape)
        """
        x_grid = self.grid_manager.get_grid(batch_size, device=x_context.device)

        aggregated = self.aggregator(x_context, latent_features, x_grid)
        # Swap the last two axes so channels come before grid points, which
        # is the layout reshape_to_grid is given here.
        channels_first = aggregated.transpose(1, 2)
        return self.grid_manager.reshape_to_grid(channels_first)

    def _condition_on_parameters(
        self,
        grid_features: torch.Tensor,
        lambda_params: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """
        Append the broadcast PDE-parameter encoding as extra grid channels.

        Returns ``grid_features`` unchanged when conditioning was not
        requested, no parameters were passed, or no parameter encoder was
        built (``parameter_dim`` was None at construction).

        Raises:
            NotImplementedError: conditioning requested for spatial_dim > 2.
        """
        if not self.use_parameter_conditioning or lambda_params is None:
            return grid_features
        if self.parameter_encoder is None:
            # The backbone was sized without conditioning channels in this
            # configuration, so there is nothing to append.
            return grid_features
        if self.spatial_dim not in (1, 2):
            raise NotImplementedError("3D parameter conditioning not implemented")

        lambda_encoded = self.parameter_encoder(lambda_params)

        # Broadcast the per-sample encoding over every grid cell.
        resolution = [
            self.grid_manager.grid_resolution[axis]
            for axis in range(self.spatial_dim)
        ]
        lambda_grid = lambda_encoded
        for _ in range(self.spatial_dim):
            lambda_grid = lambda_grid.unsqueeze(-1)
        lambda_grid = lambda_grid.expand(-1, -1, *resolution)

        return torch.cat([grid_features, lambda_grid], dim=1)

    def _predict_on_grid(
        self,
        x_context: torch.Tensor,
        y_context: torch.Tensor,
        lambda_params: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run encode -> aggregate -> condition -> backbone -> decode.

        Returns:
            (mean, sigma) exactly as produced by the decoder, still in grid
            layout (not flattened).
        """
        batch_size = x_context.shape[0]

        latent_features = self.encode_context(x_context, y_context)
        grid_features = self.aggregate_to_grid(
            x_context, latent_features, batch_size
        )
        grid_features = self._condition_on_parameters(grid_features, lambda_params)

        processed = self.backbone(grid_features)
        return self.decoder(processed)

    def _sample_at_indices(
        self,
        grid_values: torch.Tensor,
        indices: torch.Tensor
    ) -> torch.Tensor:
        """
        Read grid-layout values at the grid cells selected by ``indices``.

        Args:
            grid_values: Decoder output in grid layout.
            indices: Output of ``grid_manager.nearest_grid_indices``.

        Returns:
            Values at the targets, shape (batch, channels, n_target).

        Raises:
            ValueError: the trailing size of ``indices`` is neither 1 nor
                ``spatial_dim``.
        """
        index_width = indices.shape[-1]

        if self.spatial_dim == 1 or index_width == 1:
            # A single index per target addresses the flattened grid.
            flat_values = self.grid_manager.flatten_from_grid(grid_values)
            gather_index = indices[..., 0].unsqueeze(1).expand(
                -1, self.output_dim, -1
            )
            return torch.gather(flat_values, 2, gather_index)

        if index_width != self.spatial_dim:
            raise ValueError(
                f"nearest_grid_indices returned {index_width} indices per "
                f"target; expected 1 or spatial_dim={self.spatial_dim}"
            )

        # NOTE(review): assumes indices[..., d] is the cell index along
        # spatial axis d, i.e. tensor dim 2 + d of the grid layout - confirm
        # against GridManager.nearest_grid_indices.
        batch_index = torch.arange(
            indices.shape[0], device=indices.device
        ).unsqueeze(-1)
        channels_last = grid_values.movedim(1, -1)
        picked = channels_last[(batch_index, *indices.unbind(dim=-1))]
        return picked.transpose(1, 2).contiguous()

    def forward(
        self,
        x_context: torch.Tensor,
        y_context: torch.Tensor,
        x_target: Optional[torch.Tensor] = None,
        lambda_params: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of PIConvNP.

        Args:
            x_context: Context locations, shape (batch, n_context, spatial_dim)
            y_context: Context observations, shape (batch, n_context, obs_dim)
            x_target: Target locations (if None, uses full grid)
            lambda_params: PDE parameters, shape (batch, param_dim)

        Returns:
            mean: Predicted mean, shape (batch, output_dim, n_target)
            sigma: Predicted std dev, shape (batch, output_dim, n_target)

            When ``x_target`` is None, ``n_target`` is the number of grid
            points. Otherwise each target takes the value of its nearest
            grid cell (no interpolation between cells).
        """
        mean, sigma = self._predict_on_grid(x_context, y_context, lambda_params)

        if x_target is None:
            return (
                self.grid_manager.flatten_from_grid(mean),
                self.grid_manager.flatten_from_grid(sigma),
            )

        indices = self.grid_manager.nearest_grid_indices(x_target)
        return (
            self._sample_at_indices(mean, indices),
            self._sample_at_indices(sigma, indices),
        )

    def get_mean_field_on_grid(
        self,
        x_context: torch.Tensor,
        y_context: torch.Tensor,
        lambda_params: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Get mean field predictions on the internal grid WITHOUT resampling.

        This is critical for physics loss computation - we process through
        the full pipeline and return the mean field on the grid with gradients
        intact for PDE residual computation.

        Args:
            x_context: Context locations, shape (batch, n_context, spatial_dim)
            y_context: Context observations, shape (batch, n_context, obs_dim)
            lambda_params: PDE parameters, shape (batch, param_dim)

        Returns:
            Mean field on grid, shape (batch, output_dim, n_grid)

        Raises:
            NotImplementedError: parameter conditioning with spatial_dim > 2.
        """
        # The decoder returns a (mean, sigma) tuple; only the mean is needed.
        mean, _ = self._predict_on_grid(x_context, y_context, lambda_params)
        return self.grid_manager.flatten_from_grid(mean)

    def predict(
        self,
        x_context: torch.Tensor,
        y_context: torch.Tensor,
        x_target: torch.Tensor,
        lambda_params: Optional[torch.Tensor] = None,
        num_samples: int = 1
    ) -> torch.Tensor:
        """
        Make predictions at target locations.

        Args:
            x_context: Context locations.
            y_context: Context observations.
            x_target: Target locations.
            lambda_params: Optional PDE parameters.
            num_samples: Number of Gaussian samples to draw; 0 returns the
                predictive mean instead of samples.

        Returns:
            Tensor with a leading sample axis: the mean with a leading axis
            of size 1 when ``num_samples == 0``, otherwise ``num_samples``
            draws of ``mean + sigma * eps`` with ``eps ~ N(0, 1)``.
        """
        mean, sigma = self.forward(x_context, y_context, x_target, lambda_params)

        if num_samples == 0:
            return mean.unsqueeze(0)

        draws = [
            mean + sigma * torch.randn_like(mean)
            for _ in range(num_samples)
        ]
        return torch.stack(draws, dim=0)

    def get_grid_predictions(
        self,
        x_context: torch.Tensor,
        y_context: torch.Tensor,
        lambda_params: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get predictions on the full latent grid.

        Returns:
            (x_grid, mean, sigma): the grid returned by the grid manager for
            this batch size, and the flattened mean / std dev on that grid.
        """
        mean, sigma = self.forward(x_context, y_context, None, lambda_params)
        x_grid = self.grid_manager.get_grid(
            x_context.shape[0], device=x_context.device
        )

        return x_grid, mean, sigma


class PIConvNPWithPhysics(PIConvNP):
    """
    PIConvNP variant that additionally exposes a PDE-residual hook for
    physics-loss computation.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to the PIConvNP constructor."""
        super().__init__(*args, **kwargs)

    def compute_pde_residual(
        self,
        x_collocation: torch.Tensor,
        mean_field: torch.Tensor,
        lambda_params: torch.Tensor,
        pde_fn: callable
    ) -> torch.Tensor:
        """
        Compute PDE residual at collocation points.

        Args:
            x_collocation: Collocation points. Gradient tracking is switched
                on for this tensor in place, so the caller's tensor is
                affected as well.
            mean_field: Mean field handed through to ``pde_fn``.
            lambda_params: PDE parameters handed through to ``pde_fn``.
            pde_fn: Callable invoked as
                ``pde_fn(x_collocation, mean_field, lambda_params)``.

        Returns:
            Whatever ``pde_fn`` returns.
        """
        # requires_grad_ mutates the tensor in place and returns that same
        # tensor, so no rebinding is needed before handing it on.
        x_collocation.requires_grad_(True)
        return pde_fn(x_collocation, mean_field, lambda_params)