"""
Embedding model components for GLEAM-AI.

This module contains the EmbedModel class that combines spatial and temporal features
using graph neural networks for the STNP framework.
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple
from ..gnn import DCRNN2


class EmbedModel(nn.Module):
    """
    Embedding model that combines spatial and temporal features.

    A ``DCRNN2`` layer (Diffusion Convolutional Recurrent Neural Network)
    processes the node features over the graph, and a single linear layer
    maps the flattened per-node states to a fixed-size embedding.

    Args:
        in_channels: Number of input features per node
        embed_out_dim: Output embedding dimension
        out_channels: Number of output channels from the GNN
        max_diffusion_step: Maximum diffusion steps for DCRNN (passed as ``K``)
        num_nodes: Number of nodes in the graph

    Raises:
        ValueError: If any of the arguments is not strictly positive.
    """

    def __init__(
        self,
        in_channels: int,
        embed_out_dim: int,
        out_channels: int,
        max_diffusion_step: int,
        num_nodes: int,
    ) -> None:
        super().__init__()

        # Validate in declaration order so the first offending argument is
        # the one reported.
        for arg_name, arg_value in (
            ("in_channels", in_channels),
            ("embed_out_dim", embed_out_dim),
            ("out_channels", out_channels),
            ("max_diffusion_step", max_diffusion_step),
            ("num_nodes", num_nodes),
        ):
            if arg_value <= 0:
                raise ValueError(f"{arg_name} must be positive")

        self.in_channels = in_channels
        self.embed_out_dim = embed_out_dim
        self.out_channels = out_channels
        self.max_diffusion_step = max_diffusion_step
        self.num_nodes = num_nodes

        # DCRNN for spatial feature processing
        self.recurrent = DCRNN2(
            in_channels=in_channels,
            out_channels=out_channels,
            K=max_diffusion_step,
        )

        # Fully connected layer for embedding: consumes every node's state
        # at once, hence the out_channels * num_nodes input width.
        self.fc = nn.Linear(out_channels * num_nodes, embed_out_dim)

    def forward(
        self,
        inputs: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor,
        hidden_state: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the embedding model.

        Args:
            inputs: Node features [batch_size, num_nodes, in_channels]
            edge_index: Graph edge indices [2, num_edges]
            edge_weight: Graph edge weights [num_edges]
            hidden_state: Optional hidden state from previous timestep
                [batch_size, num_nodes, out_channels]

        Returns:
            Tuple of:
            - output: Embedding output [batch_size, embed_out_dim]
            - hidden_state: Updated hidden state [batch_size, num_nodes, out_channels]

        Raises:
            ValueError: If ``inputs`` is not 3-dimensional, or if its node or
                channel dimension does not match the configured sizes.
        """
        # Check the rank first: indexing dimension 2 of a lower-rank tensor
        # would otherwise surface as an unhelpful IndexError, and extra
        # dimensions would slip through the size checks below.
        if inputs.dim() != 3:
            raise ValueError(
                "Expected inputs of shape [batch_size, num_nodes, in_channels], "
                f"got a tensor with {inputs.dim()} dimension(s)"
            )

        batch_size, node_count, channel_count = inputs.shape
        if node_count != self.num_nodes:
            raise ValueError(f"Expected {self.num_nodes} nodes, got {node_count}")
        if channel_count != self.in_channels:
            raise ValueError(f"Expected {self.in_channels} input channels, got {channel_count}")

        # Process spatial features through DCRNN. The graph tensors are moved
        # to the inputs' device; hidden_state is forwarded as given.
        # NOTE(review): presumably DCRNN2 creates an initial state when
        # hidden_state is None - confirm against its implementation.
        target_device = inputs.device
        hidden_states = self.recurrent(
            inputs,
            edge_index.to(target_device),
            edge_weight.to(target_device),
            hidden_state,
        )

        # Aggregate all node states into embedding. The flattened width is
        # spelled out instead of using -1, which cannot be inferred when the
        # batch is empty; for non-empty batches the result is the same.
        flattened = hidden_states.reshape(
            batch_size, self.num_nodes * self.out_channels
        )
        output = self.fc(flattened)

        return output, hidden_states

    def get_embedding_dim(self) -> int:
        """Get the output embedding dimension."""
        return self.embed_out_dim

    def get_num_nodes(self) -> int:
        """Get the number of nodes in the graph."""
        return self.num_nodes
