import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv, GATConv, SAGEConv,ChebConv
from torch_geometric.nn import TransformerConv

# ------------------ Define base models ------------------
class GCN(nn.Module):
    """Two-layer graph convolutional network: GCNConv -> ReLU -> dropout -> GCNConv.

    Args:
        in_ch (int):  Number of input features per node.
        hid_ch (int): Width of the hidden layer.
        out_ch (int): Number of output channels (classes).
    """

    def __init__(self, in_ch, hid_ch, out_ch):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hid_ch)
        self.conv2 = GCNConv(hid_ch, out_ch)

    def forward(self, data, edge_weight=None, return_first_layer=False):
        """Run the network on one graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
            edge_weight: Optional per-edge weights, forwarded to both GCNConv layers.
            return_first_layer (bool): If True, stop after the first layer and
                return its ReLU-activated output (no dropout applied).

        Returns:
            ``log_softmax`` over dim 1 of the second layer's output, or the
            first-layer activations when ``return_first_layer`` is True.
        """
        features, edges = data.x, data.edge_index
        hidden = F.relu(self.conv1(features, edges, edge_weight=edge_weight))
        if not return_first_layer:
            # Uses F.dropout's default rate (p=0.5); inactive in eval mode.
            hidden = F.dropout(hidden, training=self.training)
            logits = self.conv2(hidden, edges, edge_weight=edge_weight)
            return F.log_softmax(logits, dim=1)
        return hidden

# Residual (skip-connection) variant of GCN, currently disabled.
# class GCN(torch.nn.Module):
#     def __init__(self, in_ch, hid_ch, out_ch, alpha1=1.0, alpha2=1.0):
#         super().__init__()
#         self.conv1 = GCNConv(in_ch,  hid_ch)
#         self.conv2 = GCNConv(hid_ch, out_ch)
#         self.alpha1 = alpha1
#         self.alpha2 = alpha2
#
#
#         if in_ch != hid_ch:
#             self.res1 = torch.nn.Linear(in_ch, hid_ch)
#         else:
#             self.res1 = None
#
#
#         if hid_ch != out_ch:
#             self.res2 = torch.nn.Linear(hid_ch, out_ch)
#         else:
#             self.res2 = None
#
#     def forward(self, data, edge_weight=None, return_first_layer=False):
#         x, ei = data.x, data.edge_index
#
#
#         h1 = self.conv1(x, ei, edge_weight=edge_weight)
#         res1 = self.res1(x) if self.res1 is not None else x
#         h1 = h1 + self.alpha1 * res1
#         h1 = F.relu(h1)
#
#         if return_first_layer:
#             return h1
#
#         h1 = F.dropout(h1, p=0.5, training=self.training)
#
#
#         h2 = self.conv2(h1, ei, edge_weight=edge_weight)
#         res2 = self.res2(h1) if self.res2 is not None else h1
#         h2 = h2 + self.alpha2 * res2
#
#         return F.log_softmax(h2, dim=1)


class ChebNet(nn.Module):
    """A two-layer Chebyshev spectral GCN.

    Args:
        in_ch (int):  Number of input features per node.
        hid_ch (int): Number of hidden channels.
        out_ch (int): Number of output channels (classes).
        K (int):      Chebyshev filter size passed to ``ChebConv`` (default=3).
            This is the number of polynomial terms, so the polynomial has
            degree K-1 and each layer mixes information from up to K-1 hops.
    """

    def __init__(self, in_ch, hid_ch, out_ch, K=3):
        super().__init__()
        # ChebConv: (K-1)-hop localized spectral filter on the symmetric-normalized Laplacian.
        self.conv1 = ChebConv(in_ch, hid_ch, K=K, normalization='sym')
        self.conv2 = ChebConv(hid_ch, out_ch, K=K, normalization='sym')

    def forward(self, data, edge_weight=None, return_first_layer=False):
        """Run dropout -> ChebConv -> ReLU -> dropout -> ChebConv -> log-softmax.

        Args:
            data:            PyG Data-like object; ``data.x`` and ``data.edge_index`` are read.
            edge_weight:     Optional edge weights, forwarded to both ChebConv layers.
            return_first_layer (bool): If True, return the ReLU output of the first
                layer (before the second dropout) instead of class log-probabilities.

        Returns:
            ``log_softmax`` over dim 1 of the second layer's output, or the
            first-layer activations when ``return_first_layer`` is True.
        """
        x, ei = data.x, data.edge_index
        # Unlike the other models in this module, ChebNet also drops input features.
        x = F.dropout(x, p=0.5, training=self.training)

        # 1) First Chebyshev convolution + ReLU
        h1 = F.relu(self.conv1(x, ei, edge_weight=edge_weight))
        if return_first_layer:
            return h1

        # 2) Dropout + second Chebyshev convolution
        h1 = F.dropout(h1, p=0.5, training=self.training)
        h2 = self.conv2(h1, ei, edge_weight=edge_weight)

        # 3) Log-softmax for classification
        return F.log_softmax(h2, dim=1)



# class ChebNet(nn.Module):
#     def __init__(self, in_ch, hid_ch, out_ch, K=3, alpha1=1.0, alpha2=1.0):
#
#         super().__init__()
#         self.conv1 = ChebConv(in_ch,  hid_ch, K=K, normalization='sym')
#         self.conv2 = ChebConv(hid_ch, out_ch, K=K, normalization='sym')
#
#         self.alpha1 = alpha1
#         self.alpha2 = alpha2
#
#
#         if in_ch != hid_ch:
#             self.res1 = nn.Linear(in_ch, hid_ch)
#         else:
#             self.res1 = None
#
#
#         if hid_ch != out_ch:
#             self.res2 = nn.Linear(hid_ch, out_ch)
#         else:
#             self.res2 = None
#
#     def forward(self, data, edge_weight=None, return_first_layer=False):
#         x, ei = data.x, data.edge_index
#
#
#         x0 = F.dropout(x, p=0.5, training=self.training)
#         h1 = self.conv1(x0, ei, edge_weight=edge_weight)
#         res1 = self.res1(x) if self.res1 is not None else x
#         h1 = h1 + self.alpha1 * res1
#         h1 = F.relu(h1)
#
#         if return_first_layer:
#             return h1
#
#
#         h1 = F.dropout(h1, p=0.5, training=self.training)
#         h2 = self.conv2(h1, ei, edge_weight=edge_weight)
#         res2 = self.res2(h1) if self.res2 is not None else h1
#         h2 = h2 + self.alpha2 * res2
#
#         return F.log_softmax(h2, dim=1)

#



class GIN(nn.Module):
    """Two-layer Graph Isomorphism Network built from ``GINConv`` layers.

    Each ``GINConv`` wraps a small MLP (Linear -> ReLU -> Linear):
    ``in_ch -> hid_ch -> hid_ch`` for the first layer and
    ``hid_ch -> hid_ch -> out_ch`` for the second.

    Args:
        in_ch (int):  Number of input features per node.
        hid_ch (int): Width of the hidden MLP layers.
        out_ch (int): Number of output channels (classes).
    """

    def __init__(self, in_ch, hid_ch, out_ch):
        super().__init__()
        self.conv1 = GINConv(self._mlp(in_ch, hid_ch, hid_ch))
        self.conv2 = GINConv(self._mlp(hid_ch, hid_ch, out_ch))

    @staticmethod
    def _mlp(d_in, d_mid, d_out):
        """Build the Linear -> ReLU -> Linear MLP used inside a GINConv."""
        return nn.Sequential(nn.Linear(d_in, d_mid), nn.ReLU(), nn.Linear(d_mid, d_out))

    def forward(self, data, edge_weight=None, return_first_layer=False):
        """Run the network on one graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
            edge_weight: Accepted but unused; it is not forwarded to the GINConv layers.
            return_first_layer (bool): If True, return the ReLU output of the
                first layer (no dropout applied).

        Returns:
            ``log_softmax`` over dim 1 of the second layer's output, or the
            first-layer activations when ``return_first_layer`` is True.
        """
        feats, edges = data.x, data.edge_index
        hidden = F.relu(self.conv1(feats, edges))
        if return_first_layer:
            return hidden
        hidden = F.dropout(hidden, training=self.training)  # default p=0.5
        return F.log_softmax(self.conv2(hidden, edges), dim=1)

# class GIN(nn.Module):
#     def __init__(self, in_ch, hid_ch, out_ch, alpha1=1.0, alpha2=1.0):
#         super().__init__()
#
#         nn1 = nn.Sequential(
#             nn.Linear(in_ch, hid_ch),
#             nn.ReLU(),
#             nn.Linear(hid_ch, hid_ch)
#         )
#         self.conv1 = GINConv(nn1)
#
#         self.res1 = nn.Linear(in_ch, hid_ch) if in_ch != hid_ch else None
#         self.alpha1 = alpha1
#
#
#         nn2 = nn.Sequential(
#             nn.Linear(hid_ch, hid_ch),
#             nn.ReLU(),
#             nn.Linear(hid_ch, out_ch)
#         )
#         self.conv2 = GINConv(nn2)
#
#         self.res2 = nn.Linear(hid_ch, out_ch) if hid_ch != out_ch else None
#         self.alpha2 = alpha2
#
#     def forward(self, data, edge_weight=None, return_first_layer=False):
#         x, edge_index = data.x, data.edge_index
#
#
#         h1 = self.conv1(x, edge_index)
#         r1 = self.res1(x) if self.res1 is not None else x
#         h1 = h1 + self.alpha1 * r1
#         h1 = F.relu(h1)
#         if return_first_layer:
#             return h1
#         h1 = F.dropout(h1, p=0.5, training=self.training)
#
#
#         h2 = self.conv2(h1, edge_index)
#         r2 = self.res2(h1) if self.res2 is not None else h1
#         h2 = h2 + self.alpha2 * r2
#
#         return F.log_softmax(h2, dim=1)

class GAT(nn.Module):
    """Two-layer graph attention network.

    Args:
        in_ch (int):  Number of input features per node.
        hid_ch (int): Hidden channels *per attention head* in the first layer.
        out_ch (int): Number of output channels (classes).
        heads (int):  Attention heads in the first layer (default 8). Their
            outputs are concatenated, so the second layer receives
            ``hid_ch * heads`` features per node.
    """

    def __init__(self, in_ch, hid_ch, out_ch, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_ch, hid_ch, heads=heads, concat=True)
        # concat=True above makes conv1's output width hid_ch * heads; derive it
        # from the same parameter so the two layers cannot drift apart.
        self.conv2 = GATConv(hid_ch * heads, out_ch, heads=1, concat=False)

    def forward(self, data, edge_weight=None, return_first_layer=False):
        """Run the network on one graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
            edge_weight: Accepted but unused; it is not forwarded to the GATConv layers.
            return_first_layer (bool): If True, return the ReLU output of the
                first layer (no dropout applied).

        Returns:
            ``log_softmax`` over dim 1 of the second layer's output, or the
            first-layer activations when ``return_first_layer`` is True.
        """
        x, ei = data.x, data.edge_index
        h1 = F.relu(self.conv1(x, ei))
        if return_first_layer:
            return h1
        h1 = F.dropout(h1, training=self.training)  # default p=0.5
        h2 = self.conv2(h1, ei)
        return F.log_softmax(h2, dim=1)

# class GAT(torch.nn.Module):
#     def __init__(self, in_ch, hid_ch, out_ch, heads1=8, heads2=1, alpha1=1.0, alpha2=1.0):
#         super().__init__()
#
#         self.conv1 = GATConv(in_ch, hid_ch, heads=heads1, concat=True)
#
#         self.conv2 = GATConv(hid_ch * heads1, out_ch, heads=heads2, concat=False)
#
#         self.alpha1 = alpha1
#         self.alpha2 = alpha2
#
#         self.res1 = torch.nn.Linear(in_ch, hid_ch * heads1) if in_ch != hid_ch * heads1 else None
#         self.res2 = torch.nn.Linear(hid_ch * heads1, out_ch)       if hid_ch * heads1 != out_ch else None
#
#     def forward(self, data, edge_weight=None, return_first_layer=False):
#         x, ei = data.x, data.edge_index
#
#
#         h1 = self.conv1(x, ei)
#         r1 = self.res1(x) if self.res1 is not None else x
#         h1 = h1 + self.alpha1 * r1
#         h1 = F.relu(h1)
#         if return_first_layer:
#             return h1
#         h1 = F.dropout(h1, p=0.5, training=self.training)
#
#
#         h2 = self.conv2(h1, ei)
#         r2 = self.res2(h1) if self.res2 is not None else h1
#         h2 = h2 + self.alpha2 * r2
#
#         return F.log_softmax(h2, dim=1)


class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE model: SAGEConv -> ReLU -> dropout -> SAGEConv.

    Args:
        in_ch (int):  Number of input features per node.
        hid_ch (int): Width of the hidden layer.
        out_ch (int): Number of output channels (classes).
    """

    def __init__(self, in_ch, hid_ch, out_ch):
        super().__init__()
        self.conv1 = SAGEConv(in_ch, hid_ch)
        self.conv2 = SAGEConv(hid_ch, out_ch)

    def forward(self, data, edge_weight=None, return_first_layer=False):
        """Run the network on one graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
            edge_weight: Accepted but unused; it is not forwarded to the SAGEConv layers.
            return_first_layer (bool): If True, return the ReLU output of the
                first layer (no dropout applied).

        Returns:
            ``log_softmax`` over dim 1 of the second layer's output, or the
            first-layer activations when ``return_first_layer`` is True.
        """
        node_feats, edges = data.x, data.edge_index
        hidden = F.relu(self.conv1(node_feats, edges))
        if not return_first_layer:
            dropped = F.dropout(hidden, training=self.training)  # default p=0.5
            scores = self.conv2(dropped, edges)
            return F.log_softmax(scores, dim=1)
        return hidden

# class GraphSAGE(nn.Module):
#     def __init__(self, in_ch, hid_ch, out_ch, alpha1=1.0, alpha2=1.0):
#         super().__init__()
#
#         self.conv1 = SAGEConv(in_ch,  hid_ch)
#         self.conv2 = SAGEConv(hid_ch, out_ch)
#
#         self.alpha1 = alpha1
#         self.alpha2 = alpha2
#
#
#         if in_ch != hid_ch:
#             self.res1 = nn.Linear(in_ch, hid_ch)
#         else:
#             self.res1 = None
#
#
#         if hid_ch != out_ch:
#             self.res2 = nn.Linear(hid_ch, out_ch)
#         else:
#             self.res2 = None
#
#     def forward(self, data, edge_weight=None, return_first_layer=False):
#         x, ei = data.x, data.edge_index
#
#
#         h1 = self.conv1(x, ei)
#         res1 = self.res1(x) if self.res1 is not None else x
#         h1 = h1 + self.alpha1 * res1
#         h1 = F.relu(h1)
#
#         if return_first_layer:
#             return h1
#
#         # dropout
#         h1 = F.dropout(h1, p=0.5, training=self.training)
#
#
#         h2 = self.conv2(h1, ei)
#         res2 = self.res2(h1) if self.res2 is not None else h1
#         h2 = h2 + self.alpha2 * res2
#
#
#         return F.log_softmax(h2, dim=1)

class Graphormer(nn.Module):
    """Two-layer attention GNN built on ``TransformerConv``.

    Despite the name, this class is a plain stack of two ``TransformerConv``
    layers; it adds no extra structural or positional encodings of its own.

    Args:
        in_ch (int):  Number of input features per node.
        hid_ch (int): Hidden channels per attention head in the first layer.
        out_ch (int): Number of output channels (classes).
        heads (int):  Attention heads in the first layer (default 8); their
            outputs are concatenated, giving ``hid_ch * heads`` features.
    """

    def __init__(self, in_ch, hid_ch, out_ch, heads=8):
        super().__init__()
        # conv1 uses TransformerConv's default concat=True, hence the widened
        # input of conv2. dropout=0.1 is passed to both layers.
        self.conv1 = TransformerConv(in_ch, hid_ch, heads=heads, dropout=0.1)
        self.conv2 = TransformerConv(hid_ch * heads, out_ch, heads=1, concat=False, dropout=0.1)

    def forward(self, data, edge_weight=None, return_first_layer=False):
        """Run the network on one graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
            edge_weight: Accepted but unused; it is not forwarded to the TransformerConv layers.
            return_first_layer (bool): If True, return the ReLU output of the
                first layer (no dropout applied).

        Returns:
            ``log_softmax`` over dim 1 of the second layer's output, or the
            first-layer activations when ``return_first_layer`` is True.
        """
        node_feats, edges = data.x, data.edge_index
        first = F.relu(self.conv1(node_feats, edges))
        if return_first_layer:
            return first
        first = F.dropout(first, training=self.training)  # default p=0.5
        return F.log_softmax(self.conv2(first, edges), dim=1)




# class Graphormer(torch.nn.Module):
#     def __init__(self, in_ch, hid_ch, out_ch, heads=8, alpha1=0.8, alpha2=0.8):
#         super().__init__()
#         # Transformer-based encoder layers
#         self.conv1 = TransformerConv(in_ch, hid_ch, heads=heads, dropout=0.1)
#         self.conv2 = TransformerConv(hid_ch * heads, out_ch, heads=1,
#                                      concat=False, dropout=0.1)
#         # Residual scaling
#         self.alpha1 = alpha1
#         self.alpha2 = alpha2
#         # Linear projections for residual when dims differ
#         self.res1 = torch.nn.Linear(in_ch, hid_ch * heads) \
#             if in_ch != hid_ch * heads else None
#         self.res2 = torch.nn.Linear(hid_ch * heads, out_ch) \
#             if hid_ch * heads != out_ch else None
#
#     def forward(self, data, edge_weight=None, return_first_layer=False):
#         x, ei = data.x, data.edge_index
#
#         # —— Layer 1: TransformerConv + residual
#         h1 = self.conv1(x, ei)
#         r1 = self.res1(x) if self.res1 is not None else x
#         h1 = F.relu(h1 + self.alpha1 * r1)
#         if return_first_layer:
#             return h1
#         h1 = F.dropout(h1, p=0.5, training=self.training)
#
#         # —— Layer 2: TransformerConv + residual
#         h2 = self.conv2(h1, ei)
#         r2 = self.res2(h1) if self.res2 is not None else h1
#         h2 = h2 + self.alpha2 * r2
#
#         return F.log_softmax(h2, dim=1)

class GraphGAN(nn.Module):
    """GCN-based node classifier with an auxiliary edge-discriminator head.

    Components:
        gen_conv: GCNConv producing hidden node embeddings (the "generator").
        disc_fc:  Linear scorer over concatenated endpoint embeddings; used only
                  by :meth:`discriminate`, never by :meth:`forward`.
        cls_fc:   Linear classification head mapping embeddings to class logits.

    Args:
        in_ch (int):  Number of input features per node.
        hid_ch (int): Embedding width produced by ``gen_conv``.
        out_ch (int): Number of output channels (classes).
    """

    def __init__(self, in_ch, hid_ch, out_ch):
        super().__init__()
        # Generator: GCNConv to produce node embeddings
        self.gen_conv = GCNConv(in_ch, hid_ch)
        # Discriminator: Linear scorer on concatenated (src, dst) embeddings
        self.disc_fc = nn.Linear(hid_ch * 2, 1)
        # Classification head: map embeddings to class logits
        self.cls_fc = nn.Linear(hid_ch, out_ch)

    def discriminate(self, h, edge_index):
        """Score each edge with the discriminator.

        Args:
            h: Node embeddings, indexed by node id.
            edge_index: Edge list of shape ``[2, E]`` (source row, destination row).

        Returns:
            Sigmoid of ``disc_fc`` applied to ``cat(h[src], h[dst])``; one value
            per edge in a trailing dimension of size 1.
        """
        src, dst = edge_index
        z = torch.cat([h[src], h[dst]], dim=1)
        return torch.sigmoid(self.disc_fc(z))

    def forward(self, data, edge_weight=None, return_embeddings=False,
                return_first_layer=False):
        """Embed nodes with the generator and classify them.

        Args:
            data: PyG Data-like object; ``data.x`` and ``data.edge_index`` are read.
            edge_weight: Optional edge weights, forwarded to ``gen_conv``.
            return_embeddings (bool): If True, return the ReLU-activated
                generator embeddings instead of class log-probabilities.
            return_first_layer (bool): Alias for ``return_embeddings``, matching
                the keyword accepted by the other models in this module.

        Returns:
            ``log_softmax`` over dim 1 of ``cls_fc`` applied to the embeddings,
            or the embeddings themselves when either flag is True.
        """
        x, ei = data.x, data.edge_index
        # Generator step: produce node embeddings
        h = F.relu(self.gen_conv(x, ei, edge_weight=edge_weight))
        if return_embeddings or return_first_layer:
            return h
        # Classification
        return F.log_softmax(self.cls_fc(h), dim=1)