"""
Graph Neural Network components for GLEAM-AI.

This module contains the Diffusion Convolutional Recurrent Neural Network (DCRNN)
implementation used for spatial modeling in the STNP framework.
"""

import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_adj, dense_to_sparse
from torch_geometric.nn.conv import MessagePassing

# Optional dependency: prefer torch_scatter when it can be imported, otherwise
# record its absence so callers can switch to the pure-PyTorch code path.
try:
    import torch_scatter
except ImportError:
    HAS_TORCH_SCATTER = False
    # The warning is opt-in: it is emitted only when the environment variable
    # GLEAM_SHOW_SCATTER_WARNING is set to "true" (case-insensitive).
    import os
    _show_scatter_warning = os.environ.get('GLEAM_SHOW_SCATTER_WARNING', '').lower()
    if _show_scatter_warning == 'true':
        import warnings
        warnings.warn(
            "torch_scatter not found. Using PyTorch fallback implementation. "
            "For better performance on CUDA systems, install torch-scatter with: "
            "pip install torch-scatter -f https://data.pyg.org/whl/torch-2.5.0+cu124.html"
        )
    del _show_scatter_warning
    # Note: torch-scatter is not available for Apple Silicon and the fallback works well
else:
    HAS_TORCH_SCATTER = True


class DConv(MessagePassing):
    """
    Diffusion convolution layer for graph neural networks.

    This layer implements diffusion convolution as described in the DCRNN paper,
    allowing information to flow through the graph structure. ``K`` diffusion
    terms are combined in two directions: along the edges as given (normalized
    by the weighted out-degree of each edge's source node) and along the
    reversed edges (normalized by the weighted in-degree of each reversed
    edge's source node). Every term has its own ``[in_channels, out_channels]``
    weight matrix.

    Args:
        in_channels: Number of input features per node
        out_channels: Number of output features per node
        K: Filter size (number of diffusion steps)
        bias: Whether to include learnable bias terms

    Raises:
        ValueError: If ``K`` is not positive.
    """

    def __init__(self, in_channels: int, out_channels: int, K: int, bias: bool = True):
        super().__init__(aggr="add", flow="source_to_target")

        if K <= 0:
            raise ValueError("Filter size K must be positive")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K

        # weight[0] holds the forward-direction matrices, weight[1] the
        # reverse-direction ones; the second axis indexes the diffusion step.
        self.weight = nn.Parameter(torch.empty(2, K, in_channels, out_channels))

        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialize weights with Xavier uniform and the bias (if any) with zeros."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def message(self, x_j: torch.Tensor, norm: torch.Tensor) -> torch.Tensor:
        """
        Message function for message passing.

        Args:
            x_j: Source-node features gathered per edge [..., num_edges, channels]
            norm: Per-edge normalization factors [num_edges]

        Returns:
            Normalized messages, same shape as ``x_j``
        """
        # norm must have one entry per edge so that [num_edges, 1] broadcasts
        # against the edge axis of x_j.
        return norm.view(-1, 1) * x_j

    def calculate_norms(
        self,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
        num_nodes: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate per-node normalization factors for out-degree and in-degree.

        Args:
            edge_index: Graph edge indices [2, num_edges]
            edge_weight: Edge weights [num_edges]
            num_nodes: Number of nodes in the graph. When omitted it is
                inferred as ``edge_index.max() + 1`` (0 for an empty edge
                set), which undercounts if the highest-numbered nodes have
                no edges.

        Returns:
            Tuple of (out_degree_norms, in_degree_norms), each of shape
            [num_nodes] and equal to ``1 / (weighted_degree + 1e-8)``.
        """
        eps = 1e-8
        row, col = edge_index

        if num_nodes is None:
            if edge_index.numel() > 0:
                num_nodes = int(edge_index.max().item()) + 1
            else:
                # max() raises on an empty tensor; an empty graph has no nodes
                # that we can infer from the edges.
                num_nodes = 0

        # Weighted out-degree (sum over edges leaving a node) and in-degree
        # (sum over edges entering a node).
        if HAS_TORCH_SCATTER:
            # Use torch_scatter if available (faster)
            out_deg = torch_scatter.scatter_add(edge_weight, row, dim_size=num_nodes)
            in_deg = torch_scatter.scatter_add(edge_weight, col, dim_size=num_nodes)
        else:
            # Fallback to pure PyTorch; scatter_add (out-of-place) returns a
            # new tensor, so the zero buffer can be shared.
            zeros = torch.zeros(
                num_nodes, dtype=edge_weight.dtype, device=edge_weight.device
            )
            out_deg = zeros.scatter_add(0, row, edge_weight)
            in_deg = zeros.scatter_add(0, col, edge_weight)

        # eps keeps the reciprocal finite for nodes with zero degree
        norm_out = 1.0 / (out_deg + eps)
        norm_in = 1.0 / (in_deg + eps)

        return norm_out, norm_in

    def forward(
        self,
        X: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of diffusion convolution.

        Args:
            X: Node features [batch_size, num_nodes, in_channels]
            edge_index: Graph edge indices [2, num_edges]
            edge_weight: Edge weights [num_edges]

        Returns:
            Convolved features [batch_size, num_nodes, out_channels]
        """
        # Per-node degree norms, sized by the node axis of X so that isolated
        # trailing nodes are counted too.
        norm_out, norm_in = self.calculate_norms(
            edge_index, edge_weight, num_nodes=X.size(self.node_dim)
        )
        row, col = edge_index

        # Swap source and destination to reverse edges (keep shape as [2, num_edges])
        reverse_edge_index = edge_index[[1, 0], :]

        # `message` multiplies one factor per edge, so gather the node-level
        # norms by the source node of each propagated edge: `row` for the
        # forward edges, `col` for the reversed ones.
        # NOTE(review): edge weights enter only through the degrees; the
        # messages themselves are not scaled by edge_weight. Confirm intended.
        edge_norm_out = norm_out[row]
        edge_norm_in = norm_in[col]

        weight_fwd = self.weight[0]
        weight_bwd = self.weight[1]
        num_steps = self.weight.size(1)

        # Zeroth-order term: the input itself, once per direction
        H = torch.matmul(X, weight_fwd[0]) + torch.matmul(X, weight_bwd[0])

        if num_steps > 1:
            # First-order diffusion
            prev_o, prev_i = X, X
            cur_o = self.propagate(edge_index, x=X, norm=edge_norm_out, size=None)
            cur_i = self.propagate(
                reverse_edge_index, x=X, norm=edge_norm_in, size=None
            )
            H = H + torch.matmul(cur_o, weight_fwd[1]) + torch.matmul(cur_i, weight_bwd[1])

            # Higher-order diffusion via the Chebyshev-style recurrence
            # T_k = 2 * P * T_{k-1} - T_{k-2}, tracked separately per direction
            # so that T_{k-2} is the true second-to-last term of that direction.
            for k in range(2, num_steps):
                next_o = 2.0 * self.propagate(
                    edge_index, x=cur_o, norm=edge_norm_out, size=None
                ) - prev_o
                next_i = 2.0 * self.propagate(
                    reverse_edge_index, x=cur_i, norm=edge_norm_in, size=None
                ) - prev_i

                H = (H +
                     torch.matmul(next_o, weight_fwd[k]) +
                     torch.matmul(next_i, weight_bwd[k]))

                prev_o, cur_o = cur_o, next_o
                prev_i, cur_i = cur_i, next_i

        # Add bias if present
        if self.bias is not None:
            H = H + self.bias

        return H


class DCRNN2(nn.Module):
    """
    Diffusion Convolutional Recurrent Neural Network (DCRNN).

    This implementation combines diffusion convolution with gated recurrent units
    for spatio-temporal modeling. It's based on the paper:
    "Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting"

    One call to ``forward`` performs a single recurrent step: it computes an
    update gate, a reset gate and a candidate state (each with its own
    ``DConv`` over the concatenated input and hidden features) and blends the
    previous hidden state with the candidate.

    Args:
        in_channels: Number of input features per node
        out_channels: Number of output features per node
        K: Filter size (number of diffusion steps)
        bias: Whether to include learnable bias terms

    Raises:
        ValueError: If ``in_channels``, ``out_channels`` or ``K`` is not positive.
    """

    def __init__(self, in_channels: int, out_channels: int, K: int, bias: bool = True):
        super().__init__()

        if in_channels <= 0:
            raise ValueError("in_channels must be positive")
        if out_channels <= 0:
            raise ValueError("out_channels must be positive")
        if K <= 0:
            raise ValueError("Filter size K must be positive")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.bias = bias

        self._create_parameters_and_layers()

    @staticmethod
    def _floating_dtype_like(X: torch.Tensor) -> torch.dtype:
        """Return ``X.dtype`` if it is floating point, else the default float dtype."""
        return X.dtype if X.is_floating_point() else torch.get_default_dtype()

    def _make_gate_conv(self) -> "DConv":
        """Build one diffusion convolution over the concatenated [X, H] features."""
        return DConv(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            K=self.K,
            bias=self.bias,
        )

    def _create_update_gate_parameters_and_layers(self) -> None:
        """Create parameters and layers for the update gate."""
        self.conv_x_z = self._make_gate_conv()

    def _create_reset_gate_parameters_and_layers(self) -> None:
        """Create parameters and layers for the reset gate."""
        self.conv_x_r = self._make_gate_conv()

    def _create_candidate_state_parameters_and_layers(self) -> None:
        """Create parameters and layers for the candidate state."""
        self.conv_x_h = self._make_gate_conv()

    def _create_parameters_and_layers(self) -> None:
        """Create all parameters and layers for the DCRNN."""
        # Creation order (update, reset, candidate) is kept fixed so that
        # parameter initialization consumes the RNG in the same sequence.
        self._create_update_gate_parameters_and_layers()
        self._create_reset_gate_parameters_and_layers()
        self._create_candidate_state_parameters_and_layers()

    def _set_hidden_state(self, X: torch.Tensor, H: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Initialize hidden state if not provided.

        Args:
            X: Input features [batch_size, num_nodes, in_channels]
            H: Optional hidden state [batch_size, num_nodes, out_channels]

        Returns:
            ``H`` unchanged when given; otherwise a zero tensor of shape
            [batch_size, num_nodes, out_channels] on ``X``'s device, using
            ``X``'s dtype when that is floating point.
        """
        if H is None:
            # Allocate directly on the target device with a matching dtype
            # instead of creating a default-dtype CPU tensor and copying it.
            H = torch.zeros(
                X.shape[0],
                X.shape[1],
                self.out_channels,
                dtype=self._floating_dtype_like(X),
                device=X.device,
            )
        return H

    def _calculate_update_gate(
        self,
        X: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
        H: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculate the update gate.

        Args:
            X: Input features [batch_size, num_nodes, in_channels]
            edge_index: Graph edge indices [2, num_edges]
            edge_weight: Edge weights [num_edges]
            H: Hidden state [batch_size, num_nodes, out_channels]

        Returns:
            Update gate values in (0, 1) [batch_size, num_nodes, out_channels]
        """
        stacked = torch.cat([X, H], dim=2)
        return torch.sigmoid(self.conv_x_z(stacked, edge_index, edge_weight))

    def _calculate_reset_gate(
        self,
        X: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
        H: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculate the reset gate.

        Args:
            X: Input features [batch_size, num_nodes, in_channels]
            edge_index: Graph edge indices [2, num_edges]
            edge_weight: Edge weights [num_edges]
            H: Hidden state [batch_size, num_nodes, out_channels]

        Returns:
            Reset gate values in (0, 1) [batch_size, num_nodes, out_channels]
        """
        stacked = torch.cat([X, H], dim=2)
        return torch.sigmoid(self.conv_x_r(stacked, edge_index, edge_weight))

    def _calculate_candidate_state(
        self,
        X: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
        H: torch.Tensor,
        R: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculate the candidate state.

        Args:
            X: Input features [batch_size, num_nodes, in_channels]
            edge_index: Graph edge indices [2, num_edges]
            edge_weight: Edge weights [num_edges]
            H: Hidden state [batch_size, num_nodes, out_channels]
            R: Reset gate values [batch_size, num_nodes, out_channels]

        Returns:
            Candidate state values in (-1, 1) [batch_size, num_nodes, out_channels]
        """
        # The reset gate scales the previous hidden state before it is mixed
        # with the input.
        stacked = torch.cat([X, H * R], dim=2)
        return torch.tanh(self.conv_x_h(stacked, edge_index, edge_weight))

    def _calculate_hidden_state(
        self,
        Z: torch.Tensor,
        H: torch.Tensor,
        H_tilde: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculate the new hidden state.

        Args:
            Z: Update gate values [batch_size, num_nodes, out_channels]
            H: Previous hidden state [batch_size, num_nodes, out_channels]
            H_tilde: Candidate state [batch_size, num_nodes, out_channels]

        Returns:
            New hidden state ``Z * H + (1 - Z) * H_tilde``
            [batch_size, num_nodes, out_channels]
        """
        return Z * H + (1 - Z) * H_tilde

    def forward(
        self,
        X: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: Optional[torch.Tensor] = None,
        H: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the DCRNN (one recurrent step).

        Args:
            X: Input features [batch_size, num_nodes, in_channels]
            edge_index: Graph edge indices [2, num_edges]
            edge_weight: Optional edge weights [num_edges]; defaults to all ones
            H: Optional hidden state [batch_size, num_nodes, out_channels];
                defaults to zeros

        Returns:
            Hidden state [batch_size, num_nodes, out_channels]
        """
        if edge_weight is None:
            # Unweighted graph: unit weights created on X's device and in a
            # dtype compatible with X, avoiding a CPU allocation plus copy.
            edge_weight = torch.ones(
                edge_index.size(1),
                dtype=self._floating_dtype_like(X),
                device=X.device,
            )

        # Initialize hidden state if not provided
        H = self._set_hidden_state(X, H)

        # Calculate gates and candidate state
        Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
        R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)

        # Blend previous state and candidate according to the update gate
        return self._calculate_hidden_state(Z, H, H_tilde)
