import torch
import torch.nn as nn
import torch.nn.functional as F



import time
from numbers import Number

class vae_mini(nn.Module):
    """Small VAE-style bottleneck over a flattened feature map.

    The input is flattened to rows of ``feature_volume`` values, mapped by
    two linear heads to the mean and log-variance of a diagonal Gaussian,
    sampled with the reparameterization trick, and linearly projected back
    to ``feature_volume`` values before being restored to the input shape.
    """

    def __init__(self, feature_volume: int, z_size: int):
        """
        :param feature_volume: number of values per flattened sample
        :param z_size: dimensionality of the latent code
        """
        super().__init__()
        self.feature_volume = feature_volume

        # q(z|x): separate linear heads for the mean and the log-variance.
        self.q_mean = self._linear(self.feature_volume, z_size, relu=False)
        self.q_logvar = self._linear(self.feature_volume, z_size, relu=False)

        # Projection from the latent code back to the flattened feature size.
        self.project = self._linear(z_size, self.feature_volume, relu=False)

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent code and project it back.

        :param x: tensor whose element count is a multiple of
            ``feature_volume``
        :return: ``((mean, logvar), z_projected)`` where ``mean`` and
            ``logvar`` have shape ``[-1, z_size]`` and ``z_projected`` has
            the same shape as ``x``
        """
        raw_shape = x.shape
        mean, logvar = self.q(x)
        z = self.reparameterize(mean, logvar)
        # reshape (not view) keeps this consistent with q(), which must
        # accept non-contiguous inputs.
        z_projected = self.project(z).reshape(raw_shape)

        return (mean, logvar), z_projected

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick: sample from N(mu, var) using noise drawn
        from N(0, 1).

        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log-variance of the latent Gaussian
        :return: (Tensor) sample with the broadcast shape of ``mu`` and
            ``logvar``
        """
        std = torch.exp(0.5 * logvar)  # log-variance -> standard deviation
        eps = torch.randn_like(std)
        return eps * std + mu

    def _linear(self, in_size: int, out_size: int, relu: bool = True) -> nn.Module:
        """Build a linear layer, optionally followed by a ReLU."""
        layer = nn.Linear(in_size, out_size)
        if not relu:
            return layer
        return nn.Sequential(layer, nn.ReLU())

    def q(self, encoded: torch.Tensor):
        """Return ``(mean, logvar)`` of the posterior for ``encoded``.

        :param encoded: tensor flattened here to ``[-1, feature_volume]``
        """
        # reshape instead of view: view raises on non-contiguous tensors
        # (e.g. after permute/transpose); reshape copies only when needed.
        unrolled = encoded.reshape(-1, self.feature_volume)
        return self.q_mean(unrolled), self.q_logvar(unrolled)
    

class vie(nn.Module):
    """Noise-injecting bottleneck that perturbs the flattened input itself.

    ``forward`` uses the flattened input as the mean (see ``q_res``) and a
    learned, softplus-activated standard deviation, then samples with the
    reparameterization trick. Because the mean has ``feature_volume``
    columns and the standard deviation has ``z_size`` columns, the two must
    be broadcast-compatible (NOTE(review): presumably
    ``z_size == feature_volume`` -- confirm against callers).
    """

    # Class-level default so instances created before ``std_shift`` existed
    # (e.g. unpickled whole-module checkpoints) keep the original offset.
    std_shift = 30.0

    def __init__(self, feature_volume: int, z_size: int, std_shift: float = 30.0):
        """
        :param feature_volume: number of values per flattened sample
        :param z_size: number of standard-deviation outputs
        :param std_shift: value subtracted from the raw std head output
            before the softplus; larger values give a smaller initial std.
            The default reproduces the previously hard-coded ``30``.
        """
        super().__init__()
        self.feature_volume = feature_volume
        self.z_size = z_size
        self.std_shift = std_shift

        # q
        # q_stat is only used by q(), which forward() does not call; it is
        # kept so the module's registered parameters stay unchanged.
        self.q_stat = self._linear(self.feature_volume, 2 * z_size, relu=False)
        self.q_std = self._linear(self.feature_volume, z_size, relu=False)

    def forward(self, x: torch.Tensor):
        """Sample a noisy copy of ``x``.

        :param x: tensor whose element count is a multiple of
            ``feature_volume``
        :return: ``((mean, std), z)`` where ``mean`` is ``x`` flattened to
            ``[-1, feature_volume]``, ``std`` is the softplus-activated
            standard deviation, and ``z`` is the sample restored to the
            shape of ``x``
        """
        raw_shape = x.shape
        mean, raw_std = self.q_res(x)
        # Softplus keeps std positive; the shift starts it close to zero.
        std = F.softplus(raw_std - self.std_shift, beta=1)
        z = self.reparameterize(mean, std)

        # reshape (not view) so the result is valid for any memory layout.
        z = z.reshape(raw_shape)

        return (mean, std), z

    def reparameterize(self, mu: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick: sample from N(mu, std^2) using noise drawn
        from N(0, 1).

        :param mu: (Tensor) Mean of the latent Gaussian
        :param std: (Tensor) Standard deviation of the latent Gaussian
        :return: (Tensor) sample with the broadcast shape of ``mu`` and
            ``std``
        """
        # randn_like draws standard-normal noise. The previous rand_like
        # drew uniform [0, 1) noise, which is never negative and has mean
        # 0.5, so samples were biased above ``mu``.
        eps = torch.randn_like(std)
        return eps * std + mu

    def _linear(self, in_size: int, out_size: int, relu: bool = True) -> nn.Module:
        """Build a linear layer, optionally followed by a ReLU."""
        layer = nn.Linear(in_size, out_size)
        if not relu:
            return layer
        return nn.Sequential(layer, nn.ReLU())

    def q(self, encoded: torch.Tensor):
        """Return ``(mean, raw_std)`` split from the single ``q_stat`` head.

        The first ``z_size`` output columns are the mean, the remaining
        columns the pre-activation standard deviation.
        """
        unrolled = encoded.reshape(-1, self.feature_volume)
        stat = self.q_stat(unrolled)
        return stat[:, :self.z_size], stat[:, self.z_size:]

    def q_res(self, encoded: torch.Tensor):
        """Return ``(flattened input, raw_std)``.

        The flattened input itself serves as the mean; only the
        pre-activation standard deviation is predicted (by ``q_std``).
        """
        # reshape instead of view: view raises on non-contiguous tensors.
        unrolled = encoded.reshape(-1, self.feature_volume)
        raw_std = self.q_std(unrolled)
        return unrolled, raw_std
    
class vie_old(nn.Module):
    """Earlier bottleneck variant: one linear head predicts mean and std.

    The ``q_stat`` head outputs ``2 * z_size`` columns, split into a mean
    and a pre-activation standard deviation. The sample is restored to the
    input shape, so the element counts must match (NOTE(review): presumably
    ``z_size == feature_volume`` -- confirm against callers).
    """

    # Class-level default so instances created before ``std_shift`` existed
    # (e.g. unpickled whole-module checkpoints) keep the original offset.
    std_shift = 5.0

    def __init__(self, feature_volume: int, z_size: int, std_shift: float = 5.0):
        """
        :param feature_volume: number of values per flattened sample
        :param z_size: dimensionality of the latent code
        :param std_shift: value subtracted from the raw std output before
            the softplus; the default reproduces the previously hard-coded
            ``5``.
        """
        super().__init__()
        self.feature_volume = feature_volume
        self.z_size = z_size
        self.std_shift = std_shift

        # q: single head producing mean and raw std side by side.
        self.q_stat = self._linear(self.feature_volume, 2 * z_size, relu=False)

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and sample from the predicted Gaussian.

        :param x: tensor whose element count is a multiple of
            ``feature_volume``
        :return: ``((mean, std), z)`` where ``mean`` and ``std`` have shape
            ``[-1, z_size]`` and ``z`` is the sample restored to the shape
            of ``x``
        """
        raw_shape = x.shape
        mean, raw_std = self.q(x)
        # Softplus keeps std positive; the shift lowers its initial value.
        std = F.softplus(raw_std - self.std_shift, beta=1)
        z = self.reparameterize(mean, std)

        # reshape (not view) so the result is valid for any memory layout.
        z = z.reshape(raw_shape)

        return (mean, std), z

    def reparameterize(self, mu: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick: sample from N(mu, std^2) using noise drawn
        from N(0, 1).

        :param mu: (Tensor) Mean of the latent Gaussian
        :param std: (Tensor) Standard deviation of the latent Gaussian
        :return: (Tensor) sample with the broadcast shape of ``mu`` and
            ``std``
        """
        # randn_like draws standard-normal noise. The previous rand_like
        # drew uniform [0, 1) noise, which is never negative and has mean
        # 0.5, so samples were biased above ``mu``.
        eps = torch.randn_like(std)
        return eps * std + mu

    def _linear(self, in_size: int, out_size: int, relu: bool = True) -> nn.Module:
        """Build a linear layer, optionally followed by a ReLU."""
        layer = nn.Linear(in_size, out_size)
        if not relu:
            return layer
        return nn.Sequential(layer, nn.ReLU())

    def q(self, encoded: torch.Tensor):
        """Return ``(mean, raw_std)`` split from the ``q_stat`` head.

        The first ``z_size`` output columns are the mean, the remaining
        columns the pre-activation standard deviation.
        """
        # reshape instead of view: view raises on non-contiguous tensors.
        unrolled = encoded.reshape(-1, self.feature_volume)
        stat = self.q_stat(unrolled)
        return stat[:, :self.z_size], stat[:, self.z_size:]

    


