import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCN2Conv
import torch


class GCN2(nn.Module):
    """Stack of ``GCN2Conv`` layers between an input and an output linear layer.

    The input features are projected to ``hid_dim`` by ``self.lins``, passed
    through ``args.K`` ``GCN2Conv`` layers (each followed by the configured
    activation and dropout), and finally mapped to ``output_dim`` by
    ``self.output``.

    Args:
        args: Configuration object. The attributes read are ``K`` (number of
            convolution layers), ``dropout`` (dropout probability),
            ``activation_fn`` (one of ``'relu'``, ``'leaky_relu'``, ``'tanh'``,
            ``'sigmoid'``), ``alpha`` and ``theta`` (forwarded to
            ``GCN2Conv``), and optionally ``reset_param`` (or the legacy
            spelling ``rest_param``) to request re-initialisation.
        input_dim: Size of the input node features.
        output_dim: Size of the output produced by the final linear layer.
        hid_dim: Hidden size used by the convolution layers.

    Raises:
        ValueError: If ``args.activation_fn`` is not a supported name.
    """

    # Name -> callable lookup for the supported activations. ``torch.tanh`` and
    # ``torch.sigmoid`` replace the deprecated ``F.tanh`` / ``F.sigmoid``.
    _ACTIVATIONS = {
        'relu': F.relu,
        'leaky_relu': F.leaky_relu,
        'tanh': torch.tanh,
        'sigmoid': torch.sigmoid,
    }

    def __init__(self, args, input_dim: int, output_dim: int, hid_dim: int) -> None:
        super().__init__()
        self.num_layers = args.K
        self.dropout = args.dropout

        self.activation = args.activation_fn
        try:
            self.activation_fn = self._ACTIVATIONS[self.activation]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unsupported activation function: {self.activation}"
            ) from None

        # Initial transformation: input_dim -> hid_dim.
        self.lins = nn.Linear(input_dim, hid_dim)

        # GCN2Conv layers. ``layer`` is passed 1-based (i + 1); together with
        # ``theta`` it sets the per-layer strength of the identity mapping
        # inside GCN2Conv.
        self.convs = nn.ModuleList(
            GCN2Conv(
                channels=hid_dim,
                alpha=args.alpha,
                theta=args.theta,
                layer=i + 1,
            )
            for i in range(self.num_layers)
        )

        self.output = nn.Linear(hid_dim, output_dim)

        # ``rest_param`` is a legacy misspelling of ``reset_param``; either
        # flag being truthy triggers re-initialisation.
        if getattr(args, 'reset_param', False) or getattr(args, 'rest_param', False):
            self.reset_parameter()

    def reset_parameter(self) -> None:
        """Re-initialise every learnable parameter of the model.

        The two linear layers get Xavier-uniform weights and zero biases.
        Each convolution is reset through its own ``reset_parameters()``:
        ``GCN2Conv`` has no ``lin`` submodule, so its weights cannot be
        reached the way the linear layers' weights are.
        """
        # Reset initial transformation.
        self._init_linear(self.lins)

        # Reset convolution layers via the layer's own initialiser.
        for conv in self.convs:
            conv.reset_parameters()

        # Reset output layer.
        self._init_linear(self.output)

    @staticmethod
    def _init_linear(linear: nn.Linear) -> None:
        """Apply Xavier-uniform weights and a zero bias to ``linear``."""
        nn.init.xavier_uniform_(linear.weight)
        if linear.bias is not None:
            nn.init.zeros_(linear.bias)

    def forward(self, data):
        """Run the model on a graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``,
                and optionally ``edge_weight``.

        Returns:
            A ``(logits, h)`` tuple: ``logits`` is the output of the final
            linear layer and ``h`` is the hidden representation fed into it.
        """
        x, edge_index = data.x, data.edge_index

        # A missing or ``None`` edge_weight is forwarded as ``None`` so that
        # GCN2Conv creates unit weights itself, in the dtype of the features.
        edge_weight = getattr(data, 'edge_weight', None)

        # Initial transformation.
        h = self.lins(x)
        h = F.dropout(h, p=self.dropout, training=self.training)

        # Keep the initial representation (x0) for the initial-residual
        # connection used by every GCN2Conv layer.
        x0 = h.clone()

        # Apply convolutional layers; each one takes the current features and
        # the initial features.
        for conv in self.convs:
            h = conv(h, x0, edge_index, edge_weight)
            h = self.activation_fn(h)
            h = F.dropout(h, p=self.dropout, training=self.training)

        # Final linear transformation.
        logits = self.output(h)

        return logits, h