import torch
import torch.nn as nn

from torch_geometric.nn import GCN


class Encoder(nn.Module):
    """Two-hidden-layer ReLU MLP that outputs Gaussian posterior parameters.

    Args:
        input_dim: Width of the input feature vector.
        hidden_dim: Indexable holding the two hidden widths; only
            ``hidden_dim[0]`` and ``hidden_dim[1]`` are read.
        latent_dim: Width of each of the two output heads.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        first_width = hidden_dim[0]
        second_width = hidden_dim[1]

        # Shared trunk.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=first_width)
        self.fc2 = nn.Linear(in_features=first_width, out_features=second_width)

        # Separate linear heads for the mean and the log-variance.
        self.mu = nn.Linear(in_features=second_width, out_features=latent_dim)
        self.logvar = nn.Linear(in_features=second_width, out_features=latent_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(mu, logvar)`` computed from ``x``.

        Both heads read the same hidden activation; no activation is applied
        to either head's output.
        """
        hidden = self.relu(self.fc2(self.relu(self.fc1(x))))
        return self.mu(hidden), self.logvar(hidden)
    
class Decoder(nn.Module):
    """Three-layer MLP mapping a (conditioned) latent vector back to data space.

    Args:
        input_dim: Width of the decoder input.
        hidden_dim: Indexable holding the two hidden widths; only
            ``hidden_dim[0]`` and ``hidden_dim[1]`` are read.
        origin_dim: Width of the reconstructed output.
    """

    def __init__(self, input_dim, hidden_dim, origin_dim):
        super().__init__()
        first_width = hidden_dim[0]
        second_width = hidden_dim[1]

        self.defc1 = nn.Linear(in_features=input_dim, out_features=first_width)
        self.defc2 = nn.Linear(in_features=first_width, out_features=second_width)
        self.defc3 = nn.Linear(in_features=second_width, out_features=origin_dim)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Decode ``x``; the final linear layer's output is returned unactivated."""
        hidden = self.relu(self.defc2(self.relu(self.defc1(x))))
        return self.defc3(hidden)
    
class CVAE(nn.Module):
    """Conditional VAE whose condition is ``gnn(2 * s - 1, edge_index)`` plus ``data.z``.

    Args:
        gnn: Callable taking ``(features, edge_index)``; its output is
            concatenated along ``dim=1`` and must therefore be ``a_dim`` wide.
        input_size: Width of ``data.x`` (also the reconstruction width).
        z_size: Width of ``data.z``.
        latent_dim: Width of the sampled latent vector.
        a_dim: Width of the ``gnn`` output.
        hidden_dim: Two hidden widths for the encoder; the decoder receives
            them in reversed order.
    """

    def __init__(self, gnn, input_size, z_size, latent_dim, a_dim, hidden_dim: list):
        super().__init__()
        self.gnn = gnn
        encoder_in = input_size + a_dim + z_size
        decoder_in = latent_dim + a_dim + z_size
        self.encoder = Encoder(encoder_in, hidden_dim, latent_dim)
        self.decoder = Decoder(decoder_in, hidden_dim[::-1], input_size)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return noise * sigma + mu

    def forward(self, data):
        """Encode, sample and decode one batch.

        ``data`` must expose ``s``, ``edge_index``, ``x`` and ``z``.

        Returns:
            Tuple ``(pred, mu, logvar)``.
        """
        # NOTE(review): 2*s - 1 maps 0/1 to -1/+1; presumably `s` is a binary
        # attribute -- confirm against the data pipeline.
        rescaled_s = 2 * data.s - 1
        cond = self.gnn(rescaled_s, data.edge_index)

        mu, logvar = self.encoder(torch.cat((data.x, cond, data.z), dim=1))

        latent = self.reparameterize(mu, logvar)
        pred = self.decoder(torch.cat((latent, cond, data.z), dim=1))
        return pred, mu, logvar
    
class CVAE_linear(CVAE):
    """CVAE variant whose conditioner is called without ``edge_index``.

    ``gnn`` is invoked as ``gnn(2 * data.s - 1)``, so any module mapping the
    rescaled ``s`` features to ``a_dim`` columns fits (e.g. ``nn.Linear``).

    Args:
        gnn: Callable taking a single feature tensor and returning ``a_dim``
            columns.
        input_size: Width of ``data.x`` (also the reconstruction width).
        latent_dim: Width of the sampled latent vector.
        a_dim: Width of the ``gnn`` output.
        hidden_dim: Two hidden widths for the encoder; the decoder receives
            them in reversed order.
        z_size: Width of ``data.z``. ``forward`` concatenates ``data.z`` into
            both the encoder and decoder inputs, so this must match its number
            of columns. Defaults to 0, i.e. ``data.z`` has no columns.
    """

    def __init__(self, gnn, input_size, latent_dim, a_dim, hidden_dim: list, z_size: int = 0):
        # The parent constructor requires ``z_size`` as its third positional
        # argument; forwarding only five arguments raised TypeError on every
        # instantiation.
        super().__init__(gnn, input_size, z_size, latent_dim, a_dim, hidden_dim)

    def forward(self, data):
        """Encode, sample and decode one batch.

        ``data`` must expose ``s``, ``x`` and ``z``.

        Returns:
            Tuple ``(pred, mu, logvar)``.
        """
        a = self.gnn(2 * data.s - 1)
        xaz = torch.cat((data.x, a, data.z), dim=1)

        mu, logvar = self.encoder(xaz)
        z = self.reparameterize(mu, logvar)
        z = torch.cat((z, a, data.z), dim=1)
        pred = self.decoder(z)
        return pred, mu, logvar
    
class loss_function(nn.Module):
    """VAE objective: summed squared reconstruction error plus KL divergence.

    The KL term is the closed form between ``N(mu, sigma^2)`` and ``N(0, I)``
    from Appendix B of Kingma and Welling, "Auto-Encoding Variational Bayes"
    (ICLR 2014, https://arxiv.org/abs/1312.6114):

        KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    """

    def __init__(self) -> None:
        super().__init__()
        # 'sum' reduction: the reconstruction term is not averaged over
        # elements, matching the summed KL term below.
        self.mse_loss = nn.MSELoss(reduction='sum')

    def forward(self, pred, x, mu, logvar):
        """Return ``(total, reconstruction, kl)`` where ``total`` is the sum of the other two."""
        reconstruction = self.mse_loss(pred, x)
        kl_summand = 1 + logvar - mu.pow(2) - logvar.exp()
        kl = -0.5 * torch.sum(kl_summand)
        total = reconstruction + kl
        return total, reconstruction, kl
    
def cvae(sensitive_size, hidden_channels, num_layers, a_dim, input_size, latent_dim, z_size,
         hidden_dim=None):
    """Build a ``CVAE`` whose conditioner is a torch_geometric ``GCN``.

    Args:
        sensitive_size: Input channel count of the GCN.
        hidden_channels: Hidden channel count of the GCN.
        num_layers: Number of GCN layers.
        a_dim: Output channel count of the GCN (width of the condition).
        input_size: Width of the data features the CVAE reconstructs.
        latent_dim: Width of the CVAE latent vector.
        z_size: Width of the extra ``data.z`` conditioning features.
        hidden_dim: Two hidden widths for the CVAE encoder (reversed for the
            decoder). Defaults to ``[64, 32]``.

    Returns:
        The constructed ``CVAE`` module.
    """
    if hidden_dim is None:
        # A fresh list per call avoids sharing a mutable default.
        hidden_dim = [64, 32]
    gcn = GCN(sensitive_size, hidden_channels, num_layers, a_dim)
    return CVAE(gcn, input_size, z_size, latent_dim, a_dim, hidden_dim)