import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, global_mean_pool
from torch_geometric.nn import global_add_pool
from torch import nn

class GCN(nn.Module):
    """Two-layer GCN node encoder with two node-pair scoring heads.

    ``encode`` turns node features into embeddings.  ``decode`` and
    ``decode_prob`` score node pairs with an MLP over the concatenated
    endpoint embeddings; ``decode_break`` scores them with a learned linear
    map followed by an element-wise product.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of both convolutions and of the embeddings.
        dropout: Dropout probability applied between the two convolutions.
    """

    def __init__(self, in_channels, hidden_channels, dropout=0.5):
        super().__init__()
        # Sub-modules are created in a fixed order so that parameter order
        # (and therefore seeded initialisation) stays stable.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.norm1 = nn.LayerNorm(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.norm2 = nn.LayerNorm(hidden_channels)
        self.dropout = dropout

        # Linear map applied to the first endpoint in ``decode_break``.
        self.decoder = nn.Linear(hidden_channels, hidden_channels)

        # MLP head used by ``decode``; its input is two embeddings side by side.
        pair_width = 2 * hidden_channels
        self.predictor = nn.Sequential(
            nn.Linear(pair_width, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )

    def encode(self, x, edge_index):
        """Return node embeddings for features ``x`` on graph ``edge_index``.

        Pipeline: conv1 -> LayerNorm -> ReLU -> dropout -> conv2 -> LayerNorm.
        Dropout is only active while the module is in training mode.
        """
        hidden = F.relu(self.norm1(self.conv1(x, edge_index)))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.norm2(self.conv2(hidden, edge_index))

    def decode(self, z, edge_label_index):
        """Return raw (pre-sigmoid) MLP scores for the given node pairs.

        ``edge_label_index[0]`` and ``edge_label_index[1]`` select the two
        endpoints of every pair from ``z``.  The result is flattened to 1-D.
        """
        first = z[edge_label_index[0]]
        second = z[edge_label_index[1]]
        pair_features = torch.cat((first, second), dim=-1)
        return self.predictor(pair_features).view(-1)

    def decode_prob(self, z, edge_label_index):
        """Return ``decode`` scores squashed through a sigmoid."""
        raw_scores = self.decode(z, edge_label_index)
        return torch.sigmoid(raw_scores)

    def decode_break(self, z, edge_label_index):
        """Return sigmoid scores from the linear-map/product scorer.

        The first endpoint's embedding is passed through ``self.decoder``,
        multiplied element-wise with the second endpoint's embedding and
        summed over dim 1 before the sigmoid.
        """
        projected = self.decoder(z[edge_label_index[0]])
        partner = z[edge_label_index[1]]
        return torch.sigmoid((projected * partner).sum(dim=1))


class GAT(nn.Module):
    """Two-layer GAT node encoder with two node-pair scoring heads.

    The first attention layer uses ``heads`` heads of width
    ``hidden_channels // heads``; the second uses a single head of width
    ``hidden_channels``.  Scoring heads mirror the other encoders in this
    file: an MLP over concatenated endpoints (``decode``/``decode_prob``)
    and a linear-map/product scorer (``decode_break``).

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Embedding width; must be a multiple of ``heads``.
        heads: Number of attention heads in the first layer.
        dropout: Dropout probability applied between the two convolutions.

    Raises:
        AssertionError: If ``hidden_channels`` is not divisible by ``heads``.
            NOTE: being an ``assert``, this check is skipped under ``python -O``.
    """

    def __init__(self, in_channels, hidden_channels, heads=1, dropout=0.5):
        super().__init__()
        assert hidden_channels % heads == 0, "hidden_channels must be divisible by heads"
        self.heads = heads

        width_per_head = hidden_channels // heads
        self.conv1 = GATConv(in_channels, width_per_head, heads=heads)
        self.norm1 = nn.LayerNorm(hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels, heads=1)
        self.norm2 = nn.LayerNorm(hidden_channels)
        self.dropout = dropout

        # Linear map applied to the first endpoint in ``decode_break``.
        self.decoder = nn.Linear(hidden_channels, hidden_channels)
        # MLP head used by ``decode``; its input is two embeddings side by side.
        pair_width = 2 * hidden_channels
        self.predictor = nn.Sequential(
            nn.Linear(pair_width, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )

    def encode(self, x, edge_index):
        """Return node embeddings for features ``x`` on graph ``edge_index``.

        Pipeline: conv1 -> LayerNorm -> ELU -> dropout -> conv2 -> LayerNorm.
        Dropout is only active while the module is in training mode.
        """
        hidden = F.elu(self.norm1(self.conv1(x, edge_index)))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.norm2(self.conv2(hidden, edge_index))

    def decode(self, z, edge_label_index):
        """Return raw (pre-sigmoid) MLP scores for the given node pairs.

        ``edge_label_index[0]`` and ``edge_label_index[1]`` select the two
        endpoints of every pair from ``z``.  The result is flattened to 1-D.
        """
        first = z[edge_label_index[0]]
        second = z[edge_label_index[1]]
        pair_features = torch.cat((first, second), dim=-1)
        return self.predictor(pair_features).view(-1)

    def decode_prob(self, z, edge_label_index):
        """Return ``decode`` scores squashed through a sigmoid."""
        raw_scores = self.decode(z, edge_label_index)
        return torch.sigmoid(raw_scores)

    def decode_break(self, z, edge_label_index):
        """Return sigmoid scores from the linear-map/product scorer.

        The first endpoint's embedding is passed through ``self.decoder``,
        multiplied element-wise with the second endpoint's embedding and
        summed over dim 1 before the sigmoid.
        """
        projected = self.decoder(z[edge_label_index[0]])
        partner = z[edge_label_index[1]]
        return torch.sigmoid((projected * partner).sum(dim=1))


class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE node encoder with two node-pair scoring heads.

    ``encode`` turns node features into embeddings.  ``decode`` and
    ``decode_prob`` score node pairs with an MLP over the concatenated
    endpoint embeddings; ``decode_break`` scores them with a learned linear
    map followed by an element-wise product.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of both convolutions and of the embeddings.
        dropout: Dropout probability applied between the two convolutions
            (this encoder defaults to 0.1).
    """

    def __init__(self, in_channels, hidden_channels, dropout=0.1):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.norm1 = nn.LayerNorm(hidden_channels)
        self.dropout = dropout
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.norm2 = nn.LayerNorm(hidden_channels)

        # Linear map applied to the first endpoint in ``decode_break``.
        self.decoder = nn.Linear(hidden_channels, hidden_channels)

        # MLP head used by ``decode``; its input is two embeddings side by side.
        pair_width = 2 * hidden_channels
        self.predictor = nn.Sequential(
            nn.Linear(pair_width, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )

    def encode(self, x, edge_index):
        """Return node embeddings for features ``x`` on graph ``edge_index``.

        Pipeline: conv1 -> LayerNorm -> ReLU -> dropout -> conv2 -> LayerNorm.
        Dropout is only active while the module is in training mode.
        """
        hidden = F.relu(self.norm1(self.conv1(x, edge_index)))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.norm2(self.conv2(hidden, edge_index))

    def decode(self, z, edge_label_index):
        """Return raw (pre-sigmoid) MLP scores for the given node pairs.

        ``edge_label_index[0]`` and ``edge_label_index[1]`` select the two
        endpoints of every pair from ``z``.  The result is flattened to 1-D.
        """
        first = z[edge_label_index[0]]
        second = z[edge_label_index[1]]
        pair_features = torch.cat((first, second), dim=-1)
        return self.predictor(pair_features).view(-1)

    def decode_prob(self, z, edge_label_index):
        """Return ``decode`` scores squashed through a sigmoid."""
        raw_scores = self.decode(z, edge_label_index)
        return torch.sigmoid(raw_scores)

    def decode_break(self, z, edge_label_index):
        """Return sigmoid scores from the linear-map/product scorer.

        The first endpoint's embedding is passed through ``self.decoder``,
        multiplied element-wise with the second endpoint's embedding and
        summed over dim 1 before the sigmoid.
        """
        projected = self.decoder(z[edge_label_index[0]])
        partner = z[edge_label_index[1]]
        return torch.sigmoid((projected * partner).sum(dim=1))



# Separate Regressor Head
class Regressor(nn.Module):
    """Small MLP regression head: Linear -> ReLU -> Dropout -> Linear.

    The hidden layer is half as wide as the input (floor division), but
    never narrower than one unit.

    Args:
        hidden_channels: Size of each input feature vector.
        out_channels: Number of regression targets produced per input row.
        dropout: Dropout probability applied after the hidden activation.
    """

    def __init__(self, hidden_channels, out_channels=1, dropout=0.5):
        super().__init__()
        # ``hidden_channels // 2`` is 0 when ``hidden_channels == 1``; a
        # zero-width hidden layer would make the output independent of the
        # input, so the bottleneck is clamped to at least one unit.  For
        # ``hidden_channels >= 2`` the width is unchanged.
        bottleneck = max(1, hidden_channels // 2)
        self.regressor = nn.Sequential(
            nn.Linear(hidden_channels, bottleneck),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck, out_channels),
        )

    def forward(self, x):
        """Apply the MLP head to ``x`` and return its output."""
        return self.regressor(x)
    

# GCN Encoder for Graph Regression
class GCNGraphRegressor(nn.Module):
    """Three GCN layers, mean pooling over ``batch`` and a regression head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of every convolution layer.
        dropout: Dropout probability used after the first two layers and
            inside the regression head.
    """

    def __init__(self, in_channels, hidden_channels, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.norm1 = nn.LayerNorm(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.norm2 = nn.LayerNorm(hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.norm3 = nn.LayerNorm(hidden_channels)
        self.dropout = dropout

        self.regressor = Regressor(hidden_channels, out_channels=1, dropout=dropout)

    def forward(self, x, edge_index, batch):
        """Return the regression head's output for the pooled node features.

        ``batch`` is forwarded unchanged to ``global_mean_pool`` to group
        node rows before pooling.
        """
        hidden = x
        # The first two stages share one shape: conv -> norm -> ReLU -> dropout.
        early_stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        for conv, norm in early_stages:
            hidden = F.relu(norm(conv(hidden, edge_index)))
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)

        # The last stage has no dropout before pooling.
        hidden = F.relu(self.norm3(self.conv3(hidden, edge_index)))

        pooled = global_mean_pool(hidden, batch)
        return self.regressor(pooled)


# GAT Encoder for Graph Regression
class GATGraphRegressor(nn.Module):
    """Three GAT layers, mean pooling over ``batch`` and a regression head.

    The first two attention layers use ``heads`` heads of width
    ``hidden_channels // heads``; the third uses a single head of width
    ``hidden_channels``.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Layer width; must be a multiple of ``heads``.
        heads: Number of attention heads in the first two layers.
        dropout: Dropout probability used after the first two layers and
            inside the regression head.

    Raises:
        AssertionError: If ``hidden_channels`` is not divisible by ``heads``.
            NOTE: being an ``assert``, this check is skipped under ``python -O``.
    """

    def __init__(self, in_channels, hidden_channels, heads=4, dropout=0.5):
        super().__init__()
        assert hidden_channels % heads == 0, "hidden_channels must be divisible by heads"

        width_per_head = hidden_channels // heads
        self.conv1 = GATConv(in_channels, width_per_head, heads=heads)
        self.norm1 = nn.LayerNorm(hidden_channels)
        self.conv2 = GATConv(hidden_channels, width_per_head, heads=heads)
        self.norm2 = nn.LayerNorm(hidden_channels)
        self.conv3 = GATConv(hidden_channels, hidden_channels, heads=1)
        self.norm3 = nn.LayerNorm(hidden_channels)
        self.dropout = dropout
        self.regressor = Regressor(hidden_channels, out_channels=1, dropout=dropout)

    def forward(self, x, edge_index, batch):
        """Return the regression head's output for the pooled node features.

        ``batch`` is forwarded unchanged to ``global_mean_pool`` to group
        node rows before pooling.
        """
        hidden = x
        # The first two stages share one shape: conv -> norm -> ELU -> dropout.
        early_stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        for conv, norm in early_stages:
            hidden = F.elu(norm(conv(hidden, edge_index)))
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)

        # The last stage has no dropout before pooling.
        hidden = F.elu(self.norm3(self.conv3(hidden, edge_index)))

        pooled = global_mean_pool(hidden, batch)
        return self.regressor(pooled)


# SAGE Encoder for Graph Regression
class SAGEGraphRegressor(nn.Module):
    """Three GraphSAGE layers, mean pooling over ``batch`` and a regression head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of every convolution layer.
        dropout: Dropout probability used after the first two layers and
            inside the regression head.
    """

    def __init__(self, in_channels, hidden_channels, dropout=0.5):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.norm1 = nn.LayerNorm(hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.norm2 = nn.LayerNorm(hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels)
        self.norm3 = nn.LayerNorm(hidden_channels)
        self.dropout = dropout
        self.regressor = Regressor(hidden_channels, out_channels=1, dropout=dropout)

    def forward(self, x, edge_index, batch):
        """Return the regression head's output for the pooled node features.

        ``batch`` is forwarded unchanged to ``global_mean_pool`` to group
        node rows before pooling.
        """
        hidden = x
        # The first two stages share one shape: conv -> norm -> ReLU -> dropout.
        early_stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        for conv, norm in early_stages:
            hidden = F.relu(norm(conv(hidden, edge_index)))
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)

        # The last stage has no dropout before pooling.
        hidden = F.relu(self.norm3(self.conv3(hidden, edge_index)))

        pooled = global_mean_pool(hidden, batch)
        return self.regressor(pooled)

