import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(nn.Module):
    """Stack of ``GATConv`` layers followed by a linear read-out.

    Hyper-parameters are read from ``args``: ``K`` (number of GAT layers),
    ``heads``, ``dropout``, ``threshold``, ``activation_fn`` (one of
    ``'relu'``, ``'leaky_relu'``, ``'tanh'``, ``'sigmoid'``), ``rest_param``
    (sic; re-initialise weights after construction when truthy) and, at
    forward time, ``original`` (skip edge-weight filtering when truthy).

    Args:
        args: Namespace-like config object with the attributes listed above.
        input_dim: Number of input node features.
        output_dim: Number of logits produced per node.
        hid_dim: Per-head hidden size of every GAT layer.

    Raises:
        ValueError: If ``args.activation_fn`` is not a supported name.
    """

    # Supported activation names -> element-wise activation functions.
    _ACTIVATIONS = {
        'relu': F.relu,
        'leaky_relu': F.leaky_relu,
        'tanh': torch.tanh,
        'sigmoid': torch.sigmoid,
    }

    def __init__(self, args, input_dim, output_dim, hid_dim):
        super().__init__()
        self.num_layers = args.K
        self.args = args
        self.dropout = args.dropout
        self.threshold = args.threshold
        self.activation = args.activation_fn  # activation *name* from the config
        try:
            activation_fn = self._ACTIVATIONS[self.activation]
        except (KeyError, TypeError):
            raise ValueError(f"Unsupported activation function: {self.activation}") from None
        self.activation_fn = activation_fn

        heads = args.heads
        self.layers = nn.ModuleList()
        if self.num_layers > 1:
            # GATConv concatenates the outputs of its heads, so every layer
            # after the first receives heads * hid_dim input features.
            self.layers.append(GATConv(input_dim, hid_dim, heads=heads))
            for _ in range(self.num_layers - 1):
                self.layers.append(GATConv(heads * hid_dim, hid_dim, heads=heads))
            last_dim = heads * hid_dim
        else:
            # Single-layer network: one head, so it emits hid_dim features.
            self.layers.append(GATConv(input_dim, hid_dim, heads=1))
            last_dim = hid_dim
        # Size the read-out from what the last GAT layer actually emits
        # (hid_dim, not hid_dim * heads, in the single-layer case).
        self.output = nn.Linear(last_dim, output_dim)

        if args.rest_param:
            self.reset_parameter()

    def reset_parameter(self):
        """Xavier-initialise the linear projections and zero their biases.

        Applies to each GAT layer's node projection(s) and to the read-out.
        """
        for layer in self.layers:
            for lin in self._projections(layer):
                nn.init.xavier_uniform_(lin.weight)
                if lin.bias is not None:
                    nn.init.zeros_(lin.bias)
        nn.init.xavier_uniform_(self.output.weight)
        if self.output.bias is not None:
            nn.init.zeros_(self.output.bias)

    @staticmethod
    def _projections(layer):
        """Return the distinct node-projection sub-modules of a GAT layer.

        Depending on the PyG version, GATConv exposes a single ``lin`` or a
        ``lin_src``/``lin_dst`` pair (which may be the same module); each
        distinct module is returned once, in that order.
        """
        found = []
        for name in ('lin', 'lin_src', 'lin_dst'):
            module = getattr(layer, name, None)
            if module is not None and all(module is not m for m in found):
                found.append(module)
        return found

    def forward(self, data):
        """Run the GAT stack on a graph.

        Args:
            data: Graph object with ``x`` and ``edge_index`` and, optionally,
                ``edge_weight``. It is not modified.

        Returns:
            ``(logits, x)``: the read-out logits and the node embeddings that
            were fed into the read-out layer.
        """
        if not self.args.original:
            data = self.filter_edges_by_threshold(data, self.threshold)
        x, edge_index = data.x, data.edge_index

        # getattr: graph objects without an edge_weight attribute are allowed.
        edge_weight = getattr(data, 'edge_weight', None)
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)

        for layer in self.layers:
            # NOTE(review): per the PyG docs, GATConv only consumes edge_attr
            # when constructed with edge_dim; these layers are not, so the
            # weights presumably do not affect attention. Confirm whether
            # that is intended.
            x = layer(x, edge_index, edge_attr=edge_weight)
            x = self.activation_fn(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        logits = self.output(x)
        return logits, x

    def filter_edges_by_threshold(self, data, threshold):
        """Drop edges whose weight is below ``threshold``.

        Args:
            data: Graph object holding ``edge_index`` and, optionally,
                ``edge_weight``.
            threshold: Edges with ``edge_weight < threshold`` are removed.

        Returns:
            ``data`` itself when it carries no edge weights; otherwise a
            shallow copy whose ``edge_index``/``edge_weight`` are filtered.
            The caller's object is never mutated, so repeated forward passes
            and other models sharing the same graph still see every edge.
        """
        edge_weight = getattr(data, 'edge_weight', None)
        if edge_weight is None:
            return data

        keep = edge_weight >= threshold
        filtered = copy.copy(data)
        filtered.edge_index = data.edge_index[:, keep]
        filtered.edge_weight = edge_weight[keep]
        return filtered