"""
Graph Neural Network models for irregular meshes

This module provides various GNN architectures optimized for 
irregular mesh-based PDE solving.
"""

import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from torch_geometric.nn import MessagePassing
from typing import Optional
from abc import ABC, abstractmethod

from .base import MLP, Identity, Activation
from .components import Encoder, Decoder, Processor, create_encoder, create_decoder


class GraphProcessor(Processor):
    """Common parent of every graph-based processor in this module.

    The constructor records the input and output channel counts as
    ``num_features`` and ``num_classes`` so that other components can read
    them back from the instance.
    """

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.num_features, self.num_classes = num_features, num_classes

    @abstractmethod
    def forward(self, x, edge_index):
        """Map node features ``x`` and connectivity ``edge_index`` to outputs.

        Concrete processors override this method; the base implementation
        does nothing and returns ``None``.
        """


class GraphConvolutionalNetwork(GraphProcessor):
    """Stack of ``GCNConv`` layers (Graph Convolutional Network).

    The stack is ``num_features -> num_hidden -> ... -> num_classes``. An
    input layer and an output layer are always created, with
    ``num_layers - 2`` hidden-to-hidden layers in between, so at least two
    layers exist even when ``num_layers`` is smaller than 2.

    Parameters:
    -----------
        num_features: channel count of the incoming node features
        num_classes: channel count produced by the last layer
        num_hidden: channel count of every intermediate layer
        num_layers: requested total number of convolution layers
        activation: name handed to the project-local ``Activation`` wrapper
        dropout: probability used by the ``nn.Dropout`` after each
            non-final layer
    """

    def __init__(self, num_features: int, num_classes: int, 
                 num_hidden: int = 64, num_layers: int = 3, 
                 activation: str = "relu", dropout: float = 0.):
        super().__init__(num_features, num_classes)

        # (in, out) channels of each convolution, listed in execution order.
        channel_pairs = (
            [(num_features, num_hidden)]
            + [(num_hidden, num_hidden)] * (num_layers - 2)
            + [(num_hidden, num_classes)]
        )
        self.layers = nn.ModuleList(
            [gnn.GCNConv(fan_in, fan_out) for fan_in, fan_out in channel_pairs]
        )

        self.activation = Activation(activation)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Run all layers; activation and dropout follow all but the last."""
        *hidden_convs, final_conv = self.layers
        h = x
        for conv in hidden_convs:
            h = self.dropout(self.activation(conv(h, edge_index)))
        return final_conv(h, edge_index)


class GraphAttentionNetwork(GraphProcessor):
    """Stack of multi-head ``GATConv`` layers (Graph Attention Network).

    Every non-final layer uses ``num_heads`` attention heads of width
    ``num_hidden // num_heads`` each; with ``GATConv``'s default head
    concatenation this yields ``num_hidden`` output channels. The final
    layer uses a single head and emits ``num_classes`` channels. An input
    layer and an output layer are always created, with ``num_layers - 2``
    hidden layers in between.

    Parameters:
    -----------
        num_features: channel count of the incoming node features
        num_classes: channel count produced by the last layer
        num_hidden: total channel count of every intermediate layer; must be
            a positive multiple of ``num_heads``
        num_layers: requested total number of attention layers
        num_heads: number of attention heads in every non-final layer
        activation: name handed to the project-local ``Activation`` wrapper
        dropout: probability used by the ``nn.Dropout`` after each
            non-final layer

    Raises:
    -------
        ValueError: if ``num_heads`` is not positive or ``num_hidden`` is not
            a positive multiple of ``num_heads``.
    """
    
    def __init__(self, num_features: int, num_classes: int,
                 num_hidden: int = 64, num_layers: int = 3, 
                 num_heads: int = 4, activation: str = "relu", 
                 dropout: float = 0.):
        super().__init__(num_features, num_classes)

        # Integer division silently truncates: with e.g. num_hidden=64 and
        # num_heads=3 each layer would emit 63 channels while the following
        # layer is built for 64, so forward() could never succeed. Fail fast
        # with a readable message instead of a late shape error.
        head_width = num_hidden // num_heads if num_heads > 0 else 0
        if head_width < 1 or head_width * num_heads != num_hidden:
            raise ValueError(
                f"num_hidden ({num_hidden}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )
        
        self.layers = nn.ModuleList([
            gnn.GATConv(num_features, head_width, heads=num_heads)
        ])
        for _ in range(num_layers - 2):
            self.layers.append(
                gnn.GATConv(num_hidden, head_width, heads=num_heads)
            )
        # Single head on the output layer so the width is exactly num_classes.
        self.layers.append(gnn.GATConv(num_hidden, num_classes, heads=1))
        
        self.activation = Activation(activation)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, edge_index):
        """Run all layers; activation and dropout follow all but the last."""
        for layer in self.layers[:-1]:
            x = self.activation(layer(x, edge_index))
            x = self.dropout(x)
        x = self.layers[-1](x, edge_index)
        return x


class GraphSAGE(GraphProcessor):
    """Stack of ``SAGEConv`` layers (GraphSAGE).

    The stack is ``num_features -> num_hidden -> ... -> num_classes`` and
    every layer shares the same neighbourhood aggregation ``aggr``. An input
    layer and an output layer are always created, with ``num_layers - 2``
    hidden-to-hidden layers in between.

    Parameters:
    -----------
        num_features: channel count of the incoming node features
        num_classes: channel count produced by the last layer
        num_hidden: channel count of every intermediate layer
        num_layers: requested total number of convolution layers
        activation: name handed to the project-local ``Activation`` wrapper
        dropout: probability used by the ``nn.Dropout`` after each
            non-final layer
        aggr: aggregation scheme forwarded to every ``SAGEConv``
    """

    def __init__(self, num_features: int, num_classes: int,
                 num_hidden: int = 64, num_layers: int = 3, 
                 activation: str = "relu", dropout: float = 0.,
                 aggr: str = "mean"):
        super().__init__(num_features, num_classes)

        def make_conv(fan_in, fan_out):
            # All layers use the same aggregation scheme.
            return gnn.SAGEConv(fan_in, fan_out, aggr=aggr)

        # (in, out) channels of each convolution, listed in execution order.
        channel_pairs = [(num_features, num_hidden)]
        channel_pairs += [(num_hidden, num_hidden)] * (num_layers - 2)
        channel_pairs += [(num_hidden, num_classes)]

        self.layers = nn.ModuleList(
            [make_conv(fan_in, fan_out) for fan_in, fan_out in channel_pairs]
        )

        self.activation = Activation(activation)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Run all layers; activation and dropout follow all but the last."""
        *hidden_convs, final_conv = self.layers
        h = x
        for conv in hidden_convs:
            h = self.activation(conv(h, edge_index))
            h = self.dropout(h)
        return final_conv(h, edge_index)


class SIGN(GraphProcessor):
    """Scalable Inception Graph Neural Network.

    The input is propagated with one ``SGConv`` per hop count
    ``K = 1 .. num_hops``. The raw input and every propagated version each
    go through their own MLP branch; the branch outputs are concatenated
    along the last dimension and merged by a final MLP.

    NOTE(review): ``MLP`` is project-local; its positional arguments look
    like (in, out, hidden, n_layers, activation) judging from the calls
    below -- confirm against ``.base.MLP``.
    """

    def __init__(self, num_features: int, num_classes: int,
                 num_hidden: int = 64, num_layers: int = 3, 
                 num_hops: int = 3, activation: str = "relu"):
        super().__init__(num_features, num_classes)

        hop_orders = range(1, num_hops + 1)

        # One propagation operator per hop count.
        self.props = nn.ModuleList(
            [gnn.SGConv(num_features, num_hidden, K=order) for order in hop_orders]
        )

        # Branch 0 consumes the un-propagated input; the remaining branches
        # consume the SGConv outputs and use one layer fewer.
        raw_branch = MLP(
            num_features, num_hidden, num_hidden, num_layers, activation, res=True
        )
        hop_branches = [
            MLP(num_hidden, num_hidden, num_hidden, num_layers - 1, activation, res=True)
            for _ in hop_orders
        ]
        self.branches = nn.ModuleList([raw_branch, *hop_branches])

        # Merges the concatenation of all (num_hops + 1) branch outputs.
        self.merger = MLP(
            num_hidden * (num_hops + 1), num_classes,
            num_hidden, num_layers, activation
        )

    def forward(self, x, edge_index):
        """Propagate, embed each branch, concatenate and merge."""
        branch_inputs = [x]
        branch_inputs.extend(prop(x, edge_index) for prop in self.props)
        embedded = [
            branch(inp) for inp, branch in zip(branch_inputs, self.branches)
        ]
        return self.merger(torch.cat(embedded, dim=-1))


class InteractionNetwork(MessagePassing):
    """Interaction-network message passing layer with mean aggregation.

    For every edge a two-layer message network is applied to the
    concatenated features of both endpoints. The messages are mean-aggregated
    per node, and a two-layer update network is applied to the node's own
    features concatenated with its aggregated message. When input and output
    widths match, the input is added back as a residual. The result is
    instance-normalised over the node dimension.

    Parameters:
    -----------
        num_features: channel count of the incoming node features
        num_classes: channel count of the layer output
        num_hidden: width of the message vectors and of the hidden layer of
            the update network
        activation: name handed to the project-local ``Activation`` wrapper
    """
    
    def __init__(self, num_features: int, num_classes: int, 
                 num_hidden: int = 64, activation: str = "relu"):
        super().__init__(aggr='mean')
        
        # nn.Sequential is used only as a container: the layers are indexed
        # individually so the activation can be interleaved in message() and
        # update().
        self.message_net = nn.Sequential(
            nn.Linear(2 * num_features, num_hidden),
            nn.Linear(num_hidden, num_hidden)
        )
        self.update_net = nn.Sequential(
            nn.Linear(num_features + num_hidden, num_hidden),
            nn.Linear(num_hidden, num_classes)
        )
        self.activation = Activation(activation)
        # The norm runs on the layer OUTPUT, which has num_classes channels.
        # It was previously sized with num_features, which is wrong whenever
        # the layer changes the width. The module has no learnable
        # parameters (affine defaults to False), so checkpoints are
        # unaffected.
        self.norm = nn.InstanceNorm1d(num_classes)
        self.num_features = num_features
        self.num_classes = num_classes
    
    def forward(self, x, edge_index):
        """
        Parameters:
        -----------
            x: node features, [..., n_node, num_features]
            edge_index: connectivity passed straight to ``propagate``
        Returns:
        --------
            normalised node features, [..., n_node, num_classes]
        """
        out = self.propagate(edge_index, x=x)

        # InstanceNorm1d expects [batch, channel, length]. Collapse any
        # leading dimensions into one batch axis (a 2-D input becomes a batch
        # of one), normalise over the node axis, then restore the shape.
        lead_shape = out.shape[:-2]
        n_node, n_channel = out.shape[-2:]
        flat = out.reshape(-1, n_node, n_channel)
        flat = self.norm(flat.transpose(1, 2)).transpose(1, 2)
        return flat.reshape(*lead_shape, n_node, n_channel)
    
    def message(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
        """Compute one message per edge from the features of both endpoints."""
        message = self.message_net[0](torch.cat([x_i, x_j], dim=-1))
        message = self.activation(message)
        message = self.message_net[1](message)
        message = self.activation(message)
        return message
    
    def update(self, message: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Combine each node's features with its aggregated message."""
        update = self.update_net[0](torch.cat([x, message], dim=-1))
        update = self.activation(update)
        update = self.update_net[1](update)
        update = self.activation(update)
        
        # A residual connection is only possible when the widths agree.
        if self.num_features == self.num_classes:
            update = update + x
        return update


class MPNP(GraphProcessor):
    """Message Passing Neural Physics Network.

    A chain of ``InteractionNetwork`` layers with widths
    ``num_features -> num_hidden -> ... -> num_classes``; every layer uses
    ``num_hidden`` as its internal width. An input layer and an output layer
    are always created, with ``num_layers - 2`` hidden-to-hidden layers in
    between.
    """

    def __init__(self, num_features: int, num_classes: int,
                 num_hidden: int = 64, num_layers: int = 3, 
                 activation: str = "relu"):
        super().__init__(num_features, num_classes)

        # (in, out) channels of each interaction layer, in execution order.
        channel_pairs = (
            [(num_features, num_hidden)]
            + [(num_hidden, num_hidden)] * (num_layers - 2)
            + [(num_hidden, num_classes)]
        )
        self.layers = nn.ModuleList([
            InteractionNetwork(fan_in, fan_out, num_hidden, activation)
            for fan_in, fan_out in channel_pairs
        ])

    def forward(self, x, edge_index):
        """Feed ``x`` through every interaction layer in order."""
        h = x
        for interaction in self.layers:
            h = interaction(h, edge_index)
        return h


class GNNPipeline(nn.Module):
    """Encoder -> processor -> decoder model with optional input batch norm.

    The width of the input normalisation is taken from
    ``encoder.num_features`` when the encoder has that attribute, otherwise
    from ``processor.num_features``. If neither provides a truthy value, or
    ``use_input_norm`` is false, no normalisation layer is created.
    """

    def __init__(self, encoder: Encoder, processor: GraphProcessor, 
                 decoder: Decoder, use_input_norm: bool = True):
        super().__init__()

        # The first component exposing ``num_features`` decides the width.
        num_features = None
        for component in (encoder, processor):
            if hasattr(component, "num_features"):
                num_features = component.num_features
                break

        if use_input_norm and num_features:
            self.input_norm = nn.BatchNorm1d(num_features)
        else:
            self.input_norm = None

        self.encoder = encoder
        self.processor = processor
        self.decoder = decoder

    def forward(self, x, edge_index):
        """
        Parameters:
        -----------
            x: torch.FloatTensor [..., n_node, n_feature]
            edge_index: torch.LongTensor [2, num_edges]
        Returns:
        --------
            y: torch.FloatTensor [..., n_node, n_class]
        """
        if self.input_norm is not None:
            # BatchNorm1d works on [rows, channels]: flatten every leading
            # dimension, normalise, then restore the original shape.
            original_shape = x.shape
            flat = x.reshape(-1, original_shape[-1])
            x = self.input_norm(flat).view(*original_shape)

        encoded = self.encoder(x)
        processed = self.processor(encoded, edge_index)
        return self.decoder(processed)


def create_gnn_processor(gnn_type: str, num_features: int, num_classes: int, 
                        **kwargs) -> GraphProcessor:
    """Factory function to create GNN processors.

    The processor constructors accept different hyperparameters (only the
    GAT takes ``num_heads``, only SIGN takes ``num_hops``, and so on).
    Callers may pass one shared hyperparameter set: keyword arguments that
    the selected constructor does not accept, but that belong to another
    registered processor, are dropped rather than triggering a
    ``TypeError``.

    Parameters:
    -----------
        gnn_type: one of "gcn", "gat", "sage", "sign", "mpnp"
            (case-insensitive)
        num_features: input channel count handed to the processor
        num_classes: output channel count handed to the processor
        **kwargs: processor hyperparameters

    Returns:
    --------
        the constructed processor

    Raises:
    -------
        ValueError: if ``gnn_type`` is not a known processor type
        TypeError: if a keyword argument is not accepted by any registered
            processor (most likely a typo)
    """
    # Function-scope import keeps the module-level import block untouched.
    import inspect

    registry = {
        "gcn": GraphConvolutionalNetwork,
        "gat": GraphAttentionNetwork,
        "sage": GraphSAGE,
        "sign": SIGN,
        "mpnp": MPNP,
    }

    gnn_type = gnn_type.lower()
    processor_cls = registry.get(gnn_type)
    if processor_cls is None:
        raise ValueError(f"Unknown GNN type: {gnn_type}")

    def _keyword_names(cls):
        """Return the names ``cls.__init__`` accepts as keyword arguments."""
        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        return {
            param.name
            for param in inspect.signature(cls.__init__).parameters.values()
            if param.kind in keyword_kinds and param.name != "self"
        }

    accepted = _keyword_names(processor_cls)
    known_anywhere = set().union(
        *(_keyword_names(cls) for cls in registry.values())
    )

    # Names that no processor understands are still an error, so typos are
    # not silently swallowed by the filtering below.
    unknown = sorted(set(kwargs) - known_anywhere)
    if unknown:
        raise TypeError(
            f"Unexpected keyword argument(s) for GNN processor: "
            f"{', '.join(unknown)}"
        )

    selected = {name: value for name, value in kwargs.items()
                if name in accepted}
    return processor_cls(num_features, num_classes, **selected)


def create_gnn_model(num_features: int, num_classes: int, config) -> GNNPipeline:
    """Create a complete GNN model from configuration.

    Builds an encoder, a graph processor and a decoder, sized so each
    stage's output width matches the next stage's input width, and wraps
    them in a ``GNNPipeline``.

    Parameters:
    -----------
        num_features: channel count of the raw node features
        num_classes: channel count of the model output
        config: object providing the attributes ``encoder``, ``gnn`` and
            ``decoder`` (read directly, so they are required). These
            optional attributes fall back to the defaults in parentheses:
            ``n_hidden`` (64), ``activation`` ('relu'), ``window_size`` (4),
            ``encoder_n_layers`` / ``n_layers`` / ``decoder_n_layers`` (3),
            ``encoder_frequency`` / ``decoder_frequency`` (1),
            ``encoder_use_bn`` / ``decoder_use_bn`` (False),
            ``encoder_use_res`` / ``decoder_use_res`` (False),
            ``num_heads`` (4), ``num_hops`` (3), ``dropout`` (0.),
            ``use_input_norm`` (True)

    Returns:
    --------
        the assembled ``GNNPipeline``
    """
    
    # Create encoder
    encoder_params = {
        'num_hidden': getattr(config, 'n_hidden', 64),
        'num_layers': getattr(config, 'encoder_n_layers', 3),
        'activation': getattr(config, 'activation', 'relu'),
        'L': getattr(config, 'encoder_frequency', 1),
        'window_size': getattr(config, 'window_size', 4),
        'bn': getattr(config, 'encoder_use_bn', False),
        'res': getattr(config, 'encoder_use_res', False)
    }
    
    # An "identity" encoder presumably passes the features through
    # unchanged, so the processor then sees the raw input width.
    encoder_out_features = (num_features if config.encoder == "identity" 
                           else encoder_params['num_hidden'])
    
    encoder = create_encoder(
        config.encoder, num_features, encoder_out_features, **encoder_params
    )
    
    # Create processor
    # NOTE: the full hyperparameter set is forwarded whatever the GNN type;
    # not every processor constructor accepts every key listed here.
    processor_params = {
        'num_hidden': getattr(config, 'n_hidden', 64),
        'num_layers': getattr(config, 'n_layers', 3),
        'activation': getattr(config, 'activation', 'relu'),
        'num_heads': getattr(config, 'num_heads', 4),
        'num_hops': getattr(config, 'num_hops', 3),
        'dropout': getattr(config, 'dropout', 0.)
    }
    
    # With an "identity" decoder the processor output is the model output,
    # so the processor itself must produce num_classes channels. This was
    # previously encoder_out_features, which gave the wrong output width
    # unless it happened to equal num_classes.
    processor_out_features = (num_classes if config.decoder == "identity" 
                             else processor_params['num_hidden'])
    
    processor = create_gnn_processor(
        config.gnn, encoder_out_features, processor_out_features, **processor_params
    )
    
    # Create decoder
    decoder_params = {
        'num_hidden': getattr(config, 'n_hidden', 64),
        'num_layers': getattr(config, 'decoder_n_layers', 3),
        'activation': getattr(config, 'activation', 'relu'),
        'L': getattr(config, 'decoder_frequency', 1),
        'window_size': getattr(config, 'window_size', 4),
        'bn': getattr(config, 'decoder_use_bn', False),
        'res': getattr(config, 'decoder_use_res', False)
    }
    
    decoder = create_decoder(
        config.decoder, processor_out_features, num_classes, **decoder_params
    )
    
    # Create pipeline
    return GNNPipeline(
        encoder=encoder,
        processor=processor,
        decoder=decoder,
        use_input_norm=getattr(config, 'use_input_norm', True)
    )