import torch
import torch.nn as nn
import torch.nn.functional as F

class encoder(nn.Module):
    '''
    Encoder used by the generators to capture a latent representation.

    Two ReLU-activated linear layers produce a feature vector. Two linear
    heads map it to the mean and log-variance of a Gaussian. A sample is then
    drawn element-wise with the reparameterization trick.
    '''

    def __init__(self, in_size, hidden_size, output_size):
        '''
        Args:
            in_size: input dimension
            hidden_size: hidden layer dimension
            output_size: latent dimension (width of mu, logvar and z)
        Output:
            (return value in forward) a tuple ``(mu, logvar, z)``, each of
            shape (batch_size, output_size)
        '''
        super(encoder, self).__init__()
        self.linear_1 = nn.Linear(in_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, output_size)
        self.h2mu = nn.Linear(output_size, output_size)
        self.h2logvar = nn.Linear(output_size, output_size)
        # Xavier-uniform weights and zero biases for every layer, visited in
        # construction order.
        for layer in (self.linear_1, self.linear_2, self.h2mu, self.h2logvar):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def reparameterize(self, mu, logvar):
        '''
        Draw a sample from N(mu, exp(logvar)) via the reparameterization trick.

        Args:
            mu: mean of the Gaussian
            logvar: log of the variance of the Gaussian
        Returns:
            mu + exp(0.5 * logvar) * eps with eps ~ N(0, 1) drawn element-wise,
            so gradients flow through mu and logvar. A fresh sample is drawn on
            every call, in both train and eval mode.
        '''
        sigma = logvar.mul(0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, input_f):
        '''
        Args:
            input_f: tensor of shape (batch_size, in_size)
        Returns:
            (mu, logvar, z), each of shape (batch_size, output_size)
        '''
        features = input_f
        for layer in (self.linear_1, self.linear_2):
            features = F.relu(layer(features))
        mu, logvar = self.h2mu(features), self.h2logvar(features)
        return mu, logvar, self.reparameterize(mu, logvar)
    
class decoder(nn.Module):
    '''
    The decoder that is used in generators for recovering input.

    Besides the reconstruction branch, it carries a small classifier head that
    maps the same input to a single sigmoid score.
    '''

    def __init__(self, in_size, hidden_size, output_size):
        '''
        Args:
            in_size: input dimension
            hidden_size: hidden layer dimension
            output_size: decoder output dimension
        Output:
            (return value in forward) a tuple ``(recon, label)``: ``recon`` of
            shape (batch_size, output_size), non-negative because of the final
            ReLU, and ``label`` of shape (batch_size, 1) with values in [0, 1].
        '''
        super(decoder, self).__init__()

        # Classifier hidden width: half the input, but never zero. With
        # in_size == 1, ``in_size // 2`` would be 0, leaving a head whose output
        # no longer depends on its input.
        clf_hidden = max(in_size // 2, 1)

        self.linear_1 = nn.Linear(in_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, output_size)
        self.clf1 = nn.Linear(in_size, clf_hidden)
        self.clf2 = nn.Linear(clf_hidden, 1)
        # Xavier-uniform weights and zero biases, in construction order.
        for layer in (self.linear_1, self.linear_2, self.clf1, self.clf2):
            torch.nn.init.xavier_uniform_(layer.weight)
            layer.bias.data.fill_(0)

    def forward(self, input_f):
        '''
        Args:
            input_f: tensor of shape (batch_size, in_size)
        Returns:
            (y_2, label): the reconstruction of shape (batch_size, output_size)
            and the classifier score of shape (batch_size, 1).
        '''
        y_1 = F.relu(self.linear_1(input_f))
        y_2 = F.relu(self.linear_2(y_1))
        # The classifier head reads input_f directly, not the reconstruction.
        tmp = F.relu(self.clf1(input_f))
        # torch.sigmoid is the canonical op; older PyTorch releases warn that
        # F.sigmoid is deprecated.
        label = torch.sigmoid(self.clf2(tmp))

        return y_2, label
    
class VAE(nn.Module):
    '''
    VAE generator that chains an encoder and a decoder.
    '''

    def __init__(self, encoder, decoder):
        '''
        Args:
            encoder: VAE encoder; its forward must return (mu, logvar, z)
            decoder: VAE decoder; its forward must return (recon, label)
        Output:
            (return value in forward) a 5-tuple
            ``(recon_x, mu, logvar, z, outputs)``
        '''
        super(VAE, self).__init__()
        # Registered in this order: encoder first, then decoder.
        self.encoder, self.decoder = encoder, decoder

    def forward(self, input_f):
        '''
        Args:
            input_f: tensor of shape (batch_size, in_size)
        Returns:
            (recon_x, mu, logvar, z, outputs), where recon_x and outputs come
            from the decoder and mu, logvar, z come from the encoder.
        '''
        mu, logvar, z = self.encoder(input_f)
        # NOTE(review): the decoder is fed the mean ``mu``, not the sampled
        # ``z``, so reconstruction is deterministic and ``z`` is only returned.
        # Confirm this is intended.
        reconstruction, label = self.decoder(mu)
        return reconstruction, mu, logvar, z, label

class NoiseGenerator(nn.Module):
    '''
    Learnable noise generator: a two-layer MLP (ReLU hidden layer) whose
    output is squashed with tanh.
    '''
    def __init__(self, in_size, hidden_size, output_size):
        '''
        Args:
            in_size: dimension of the input noise vector
            hidden_size: hidden layer dimension
            output_size: dimension of the generated noise
        Output:
            (return value in forward) a tensor of shape (batch_size, output_size)
            with values in [-1, 1]
        '''
        super(NoiseGenerator, self).__init__()
        self.fc1 = nn.Linear(in_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, noise):
        '''
        Args:
            noise: tensor of shape (batch_size, in_size)
        Returns:
            tanh(fc2(relu(fc1(noise)))), of shape (batch_size, output_size)
        '''
        hidden = self.fc1(noise).relu()
        return self.fc2(hidden).tanh()
    
