import torch
from torch_geometric.nn import GATConv
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import VGAE as PygVAE

class VGAE(PygVAE):
    """PyG ``VGAE`` with a convenience accessor for node embeddings."""

    def get_embeddings(self, data):
        """Return the encoder's latent output for a graph.

        Args:
            data: Graph object exposing ``x`` (node features) and
                ``edge_index`` (connectivity), passed straight to ``encode``.

        Returns:
            Whatever ``self.encode(data.x, data.edge_index)`` returns.

        NOTE(review): in PyG's VGAE, ``encode`` reparametrizes, so it samples
        in training mode and returns ``mu`` in eval mode. Call ``.eval()``
        first for deterministic embeddings. Verify against the installed PyG
        version.
        """
        node_features = data.x
        connectivity = data.edge_index
        return self.encode(node_features, connectivity)

class VGAEEncoder(nn.Module):
    """Graph-attention encoder producing the ``(mu, logstd)`` pair of a VGAE.

    A shared ``GATConv`` layer with ReLU feeds two parallel ``GATConv`` heads:
    one for the latent mean and one for the latent log standard deviation.
    """

    def __init__(self, config):
        """Build the three GAT layers.

        Args:
            config: Config object whose ``encoder.vae.model`` section provides
                ``in_channels``, ``hidden_channels`` and ``latent_dim``.
        """
        super().__init__()
        model_cfg = config.encoder.vae.model
        n_in = model_cfg.in_channels
        n_hidden = model_cfg.hidden_channels
        n_latent = model_cfg.latent_dim

        # Creation order (conv1, conv_mu, conv_logstd) fixes parameter init
        # order and state_dict keys, so keep it stable.
        self.conv1 = GATConv(n_in, n_hidden)
        self.conv_mu = GATConv(n_hidden, n_latent)
        self.conv_logstd = GATConv(n_hidden, n_latent)

    def forward(self, x, edge_index):
        """Encode node features into latent Gaussian parameters.

        Args:
            x: Node feature tensor.
            edge_index: Graph connectivity accepted by ``GATConv``.

        Returns:
            Tuple ``(mu, logstd)`` from the two output heads.
        """
        shared = F.relu(self.conv1(x, edge_index))
        mu = self.conv_mu(shared, edge_index)
        logstd = self.conv_logstd(shared, edge_index)
        return mu, logstd

class VGAEDecoder(nn.Module):
    """Inner-product decoder applied after a small MLP refinement.

    Each latent row is mapped through ``latent_dim -> hidden_dim -> latent_dim``
    (ReLU in between). Edge scores are inner products of the transformed
    rows. The call shape ``decoder(z, edge_index, sigmoid=...)`` presumably
    mirrors how PyG's ``GAE``/``VGAE`` invoke their decoder; confirm against
    the installed PyG version.
    """

    def __init__(self, config, *, hidden_dim=32):
        """Build the two-layer MLP.

        Args:
            config: Config object; ``config.encoder.vae.model.latent_dim`` sets
                the MLP's input and output width.
            hidden_dim: Width of the MLP's hidden layer. Defaults to 32, the
                previously hard-coded value.
        """
        super().__init__()
        latent_dim = config.encoder.vae.model.latent_dim
        self.linear1 = nn.Linear(latent_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, latent_dim)

    def _transform(self, z):
        """Apply the MLP (linear -> ReLU -> linear) to every row of ``z``."""
        return self.linear2(F.relu(self.linear1(z)))

    def forward(self, z, edge_index=None, sigmoid=True):
        """Score node pairs by the inner product of their transformed latents.

        Args:
            z: Latent node matrix, one row per node.
            edge_index: Optional index tensor whose rows ``[0]`` and ``[1]``
                give source/target node indices (assumed shape ``(2, E)``).
                If ``None``, all ``N x N`` pairs are scored.
            sigmoid: If True, squash scores into ``(0, 1)`` with a sigmoid.

        Returns:
            A tensor of shape ``(E,)`` when ``edge_index`` is given, otherwise
            the dense ``(N, N)`` score matrix.
        """
        h = self._transform(z)
        if edge_index is None:
            value = h @ h.t()
        else:
            # Compute only the requested pairwise dot products. This uses
            # O(E * d) work and memory instead of materialising the full
            # N x N matrix and then gathering E entries from it.
            value = (h[edge_index[0]] * h[edge_index[1]]).sum(dim=-1)
        return torch.sigmoid(value) if sigmoid else value

    def forward_all(self, z, sigmoid=True):
        """Return dense ``(N, N)`` scores; same as ``forward(z, None, sigmoid)``."""
        return self.forward(z, sigmoid=sigmoid)