"""
Graph Neural Network models for TensorGalerkin
"""

import torch
import torch.nn as nn
import torch_geometric.nn as gnn

from .base import Activation


class plain_GNN(nn.Module):
    """
    Simple GNN implementation with various convolution types

    The model is a stack of graph convolutions of a single type, with an
    activation and a dropout layer inserted between consecutive convolutions
    (none after the last one). Input features pass through a separate input
    dropout first.

    Parameters:
    -----------
        num_features: int
            Number of input features per node
        num_hidden: int
            Width of the hidden representations. For ``"gat"`` it must be
            divisible by ``GAT_HEADS`` because it is split across the heads.
        num_classes: int
            Number of output channels per node
        num_layers: int
            Requested number of convolutions. The stack always contains an
            input and an output convolution, so values below 2 still produce
            two convolutions.
        gnn_type: str
            One of ``"gcn"``, ``"gat"``, ``"sage"`` or ``"sgc"``
            (case-insensitive)
        activation: str
            Value forwarded to ``Activation`` for the layers between
            convolutions
        dropout_in: float
            Dropout probability applied to the input features
        dropout: float
            Dropout probability applied between convolutions

    Raises:
    -------
        NotImplementedError
            If ``gnn_type`` is not one of the supported convolution types
        ValueError
            If ``gnn_type`` is ``"gat"`` and ``num_hidden`` is not divisible
            by ``GAT_HEADS``
    """

    # Number of attention heads used by every GAT layer. Shared between
    # construction and `forward`, which has to undo the head concatenation.
    GAT_HEADS = 4
    # Number of propagation hops (K) used by every SGC layer.
    SGC_HOPS = 2

    def __init__(self,
                 num_features: int,
                 num_hidden: int,
                 num_classes: int,
                 num_layers: int,
                 gnn_type: str = "gcn",
                 activation: str = "relu",
                 dropout_in: float = 0,
                 dropout: float = 0):
        super().__init__()

        self.gnn_type = gnn_type.lower()
        self.dropout_in = dropout_in
        self.dropout = dropout

        self.DropoutIn = nn.Dropout(dropout_in)

        layers = self._build_layers(num_features, num_hidden, num_classes, num_layers)

        # Construct the pipeline: activation + dropout go between
        # convolutions only, so the last convolution emits raw outputs.
        pipeline = []
        for i, layer in enumerate(layers):
            if i != 0:
                pipeline += [Activation(activation), nn.Dropout(dropout)]
            pipeline += [(layer, 'x, edge_index -> x')]
        self.pipeline = gnn.Sequential('x, edge_index', pipeline)

        self.reset_parameters()

    def _build_layers(self, num_features, num_hidden, num_classes, num_layers):
        """
        Create the list of convolution layers for ``self.gnn_type``

        Dispatches on the lower-cased ``self.gnn_type`` so that the type name
        is handled case-insensitively for every convolution type.
        """
        kind = self.gnn_type

        if kind == "gcn":
            def conv(in_dim, out_dim, is_output=False):
                return gnn.GCNConv(in_dim, out_dim)

        elif kind == "gat":
            heads = self.GAT_HEADS
            # Hidden GAT layers concatenate `heads` blocks of
            # `num_hidden // heads` channels; the next layer expects exactly
            # `num_hidden` inputs, so the split has to be exact.
            if num_hidden % heads != 0:
                raise ValueError(
                    f"num_hidden ({num_hidden}) must be divisible by the "
                    f"number of GAT heads ({heads})"
                )

            def conv(in_dim, out_dim, is_output=False):
                # The output layer keeps `num_classes` channels per head;
                # the heads are averaged afterwards in `forward`.
                per_head = out_dim if is_output else out_dim // heads
                return gnn.GATConv(in_dim, per_head, heads=heads)

        elif kind == "sage":
            def conv(in_dim, out_dim, is_output=False):
                return gnn.SAGEConv(in_dim, out_dim)

        elif kind == "sgc":
            hops = self.SGC_HOPS

            def conv(in_dim, out_dim, is_output=False):
                return gnn.SGConv(in_dim, out_dim, K=hops)

        else:
            raise NotImplementedError(f"Convolution type {self.gnn_type} not implemented")

        layers = [conv(num_features, num_hidden)]
        layers += [conv(num_hidden, num_hidden) for _ in range(num_layers - 2)]
        layers.append(conv(num_hidden, num_classes, is_output=True))
        return layers

    def reset_parameters(self):
        """
        Re-initialize the parameters of the convolution pipeline

        Handles both a container of modules (each member is reset if it
        provides ``reset_parameters``) and a single module exposing
        ``reset_parameters`` itself.
        """
        if isinstance(self.pipeline, (nn.ModuleList, list)):
            for module in self.pipeline:
                if hasattr(module, 'reset_parameters'):
                    module.reset_parameters()
        elif hasattr(self.pipeline, 'reset_parameters'):  # single module
            self.pipeline.reset_parameters()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """
        Run the model on node features ``x`` and connectivity ``edge_index``

        Returns the output of the last convolution; for ``"gat"`` the
        per-head outputs of the last layer are averaged so the result has
        one block of channels per node.
        """
        x = self.DropoutIn(x)
        x = self.pipeline(x, edge_index)

        # The last GAT layer concatenates its heads, giving a head-major
        # layout [n_nodes, heads * channels]. Recover the head axis in that
        # order and average over it.
        if self.gnn_type == "gat":
            n_nodes = x.size(0)
            x = x.view(n_nodes, self.GAT_HEADS, -1).mean(dim=1)

        return x


def init_model(num_features, num_classes, config):
    """
    Build a plain_GNN from a configuration object

    Parameters:
    -----------
        num_features: int
            Number of input features
        num_classes: int
            Number of output classes
        config: object
            Configuration object; the attributes ``n_hidden``, ``n_layers``,
            ``gnn``, ``dropout_in`` and ``dropout`` are read from it

    Returns:
    --------
        torch.nn.Module
            Initialized GNN model
    """
    # Collect the settings that come from the config first, then combine
    # them with the dataset-dependent input/output sizes.
    hyperparams = {
        "num_hidden": config.n_hidden,
        "num_layers": config.n_layers,
        "gnn_type": config.gnn,
        "dropout_in": config.dropout_in,
        "dropout": config.dropout,
    }
    model = plain_GNN(
        num_features=num_features,
        num_classes=num_classes,
        **hyperparams,
    )
    return model