import torch
import torch.nn as nn
import torch.nn.functional as F
from sympy.abc import theta
from torch_geometric.nn import GCN2Conv

from src.models.model_utils import ACTIVATION_MAPPING


# Define the GCNII model
class GCNII(nn.Module):
    """GCNII network: linear embedding, stacked ``GCN2Conv`` layers, linear readout.

    Input node features are projected to ``hidden_dims[0]`` by a lazily
    initialised linear layer. That projection is also kept as ``x_0`` and fed to
    every ``GCN2Conv`` layer alongside the running representation. Each
    convolution is followed by ReLU and dropout. A final linear readout plus
    ``final_activation`` produces the output.

    Args:
        output_dim: Output size of the readout layer.
        final_activation: Key into ``ACTIVATION_MAPPING`` selecting the output
            activation. The selected callable is invoked as ``fn(tensor, dim=1)``.
        hidden_dims: Hidden sizes. ``GCN2Conv`` keeps its channel count and every
            layer also receives the embedding ``x_0``, so all entries must be
            equal. ``len(hidden_dims) - 1`` convolution layers are built.
        hidden_dim: Shorthand used only when ``hidden_dims`` is None; expanded to
            ``[hidden_dim] * n_message_passings``.
        n_message_passings: Length of the expanded list; required together with
            ``hidden_dim``.
        alpha: Forwarded to every ``GCN2Conv`` as ``alpha``.
        theta: Forwarded to every ``GCN2Conv`` as ``theta``.
        pooling: Accepted for interface compatibility; currently unused.
        dropout: Dropout probability applied after each convolution's ReLU
            (active only in training mode).

    Raises:
        ValueError: If no hidden size can be determined, the hidden-size list is
            empty, or its entries differ.
    """

    def __init__(self,
                 output_dim: int,
                 final_activation: torch.functional,
                 hidden_dims: list[int] = None,
                 hidden_dim: int = None,
                 n_message_passings: int = None,
                 alpha: float = 0.5,
                 theta: float = 0.5,
                 pooling: str = None,
                 dropout=0.5):
        super(GCNII, self).__init__()

        # Resolve the hidden sizes; fail early with a clear message instead of
        # an opaque TypeError/IndexError further down.
        if hidden_dims is None:
            if hidden_dim is None:
                raise ValueError(
                    "GCNII requires either `hidden_dims` or `hidden_dim`.")
            if n_message_passings is None:
                raise ValueError(
                    "GCNII requires `n_message_passings` when `hidden_dim` is given.")
            hidden_dims = [hidden_dim] * n_message_passings

        if len(hidden_dims) == 0:
            raise ValueError("GCNII requires at least one hidden dimension.")

        # GCN2Conv maps `channels` -> `channels` and mixes in x_0 (the embedding),
        # so any differing entry would only surface later as a shape mismatch in
        # forward(). Reject it here instead.
        width = hidden_dims[0]
        if any(dim != width for dim in hidden_dims):
            raise ValueError(
                "All entries of `hidden_dims` must be equal for GCNII "
                f"(GCN2Conv preserves channel size); got {list(hidden_dims)}.")

        # Embedding layer; input feature size is inferred on the first call.
        self.embedding = nn.LazyLinear(out_features=width)

        # NOTE(review): this builds len(hidden_dims) - 1 convolutions, i.e.
        # n_message_passings - 1 when built from `hidden_dim`. Kept unchanged so
        # existing checkpoints still load; confirm whether that is intended.
        # `layer` is 1-based, as expected by GCN2Conv.
        self.conv_layers = nn.ModuleList()
        for layer_idx in range(1, len(hidden_dims)):
            self.conv_layers.append(GCN2Conv(channels=width,
                                             alpha=alpha,
                                             theta=theta,
                                             layer=layer_idx))

        self.readout = nn.Linear(width, output_dim)
        self.dropout_rate = dropout
        self.activation = ACTIVATION_MAPPING[final_activation]

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Run the network on a graph.

        Args:
            x: Node feature matrix, fed to the embedding layer.
            edge_index: Graph connectivity, passed to every ``GCN2Conv``.
            edge_attr: Accepted for interface compatibility; unused.
            batch: Accepted for interface compatibility; unused (no pooling is
                applied, so the output is per node).

        Returns:
            In training mode, the activated readout. In evaluation mode, a tuple
            ``(activated_readout, node_embeddings)``, where ``node_embeddings``
            are the final node representations before the readout layer.
        """
        x = self.embedding(x)
        x_0 = x  # initial representation reused by every GCN2Conv layer

        for conv in self.conv_layers:
            x = conv(x=x, x_0=x_0, edge_index=edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Save the final node representations; they are returned only in eval mode.
        node_embeddings = x

        out = self.activation(self.readout(node_embeddings), dim=1)
        return out if self.training else (out, node_embeddings)
