import torch
import torch.nn as nn
import torch.nn.functional as F

from codes.data.models.dgms.vae import VAEBase

class PlanarFlow(nn.Module):
    """One planar flow layer: ``z -> z + u * tanh(w . z + b)``.

    The layer owns three learnable tensors, all initialised from a standard
    normal: ``u`` and ``w`` of shape ``(dim,)`` and a scalar bias ``b`` of
    shape ``(1,)``. No constraint is applied to ``u`` or ``w``.
    """

    def __init__(self, dim):
        """Create the layer's parameters for a latent space of size ``dim``."""
        super().__init__()
        # Registration (and therefore random draw) order is u, w, b.
        self.register_parameter("u", nn.Parameter(torch.randn(dim)))
        self.register_parameter("w", nn.Parameter(torch.randn(dim)))
        self.register_parameter("b", nn.Parameter(torch.randn(1)))

    def forward(self, z):
        """Apply the flow and report the log-determinant of its Jacobian.

        Args:
            z: Input tensor; reductions run over ``dim=1``, so it is
                presumably shaped ``(batch, dim)`` -- confirm with callers.

        Returns:
            A pair ``(z_out, log_det)`` where ``z_out`` is the transformed
            input and ``log_det`` is ``log|1 + u . psi(z)|`` with the reduced
            axis kept (size 1 along ``dim=1``).
        """
        pre_activation = torch.sum(self.w * z, dim=1, keepdim=True) + self.b
        squashed = torch.tanh(pre_activation)

        transformed = z + self.u * squashed

        # Derivative of tanh is 1 - tanh^2; scaling by w gives psi(z).
        psi = (1 - squashed ** 2) * self.w
        jacobian_det = 1 + torch.sum(self.u * psi, dim=1, keepdim=True)

        return transformed, torch.log(torch.abs(jacobian_det))


class FlowVAE(VAEBase):
    """VAE whose latent sample is pushed through a chain of flow layers.

    The encoder output is mapped by two linear heads to the parameters of a
    diagonal Gaussian. A sample ``z0`` from it is transformed by
    ``num_flows`` flow layers into ``zk``, which is then decoded.
    """

    def __init__(
        self, 
        encoder, 
        decoder, 
        flow_type, 
        num_flows,
        enc_out_dim, 
        z_dim
    ):
        """Build the encoder heads and the flow chain.

        Args:
            encoder: Module mapping an input to features of size
                ``enc_out_dim``.
            decoder: Module mapping a latent tensor to a reconstruction.
            flow_type: Name of the flow layer to use; only ``'planar'`` is
                recognised.
            num_flows: Number of flow layers to stack (may be 0).
            enc_out_dim: Feature size produced by ``encoder``.
            z_dim: Latent dimensionality.

        Raises:
            ValueError: If ``flow_type`` is not recognised.
        """
        # NOTE(review): this starts the MRO lookup *after* VAEBase, so
        # VAEBase.__init__ itself is skipped. Left unchanged because VAEBase
        # is not visible from this file -- confirm this is intentional.
        super(VAEBase, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

        self.enc_z_mean = nn.Linear(enc_out_dim, z_dim)
        # NOTE(review): despite the "sdev" name, calc_loss consumes this
        # head's output as a log-variance. The attribute name is kept so
        # existing state_dict keys remain valid.
        self.enc_z_sdev = nn.Linear(enc_out_dim, z_dim)

        self.flows = self._prepare_flows(z_dim, flow_type, num_flows)
        
    def _prepare_flows(self, z_dim, flow_type, num_flows):
        """Return a ``ModuleList`` of ``num_flows`` layers of ``flow_type``.

        Raises:
            ValueError: If ``flow_type`` is not ``'planar'``.
        """
        if flow_type == 'planar':
            Flow = PlanarFlow
        else:
            raise ValueError(f'Unknown flow type: {flow_type}')
        return nn.ModuleList([Flow(z_dim) for _ in range(num_flows)])

    def calc_loss(self, recon_x, x, mu, logvar, z0, zk, log_det_sum):
        """Return the summed negative ELBO for one batch.

        The latent regulariser is the single-sample estimate
        ``log q0(z0) - sum(log_det) - log p(zk)``. The closed-form
        ``KL(q0 || p)`` is intentionally *not* added on top: the estimate
        above already accounts for the divergence between the flowed
        posterior and the prior, so adding it would count the regulariser
        twice.

        Args:
            recon_x: Decoder output, used as Bernoulli probabilities.
            x: Reconstruction target.
            mu: Mean of the base Gaussian.
            logvar: Log-variance of the base Gaussian.
            z0: Sample drawn from the base Gaussian.
            zk: ``z0`` after all flow layers.
            log_det_sum: Accumulated log-determinants of the flow layers
                (tensor, or a plain number when there are no flows).

        Returns:
            Scalar loss tensor.
        """
        recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')

        # Log-density of z0 under the base Gaussian q0. The additive
        # -0.5 * log(2 * pi) constants are omitted here and below; they
        # cancel because z0 and zk have the same number of elements.
        log_q_z0 = -0.5 * torch.sum(logvar + (z0 - mu).pow(2) / logvar.exp())
        # Log-density of zk under the standard-normal prior.
        log_p_zk = -0.5 * torch.sum(zk.pow(2))

        # Accept a plain number so a flow-free caller passing 0 still works.
        if torch.is_tensor(log_det_sum):
            log_det_total = log_det_sum.sum()
        else:
            log_det_total = log_det_sum

        return recon_loss + log_q_z0 - log_p_zk - log_det_total

    def forward(self, x):
        """Encode ``x``, sample, apply the flows, decode and return the loss."""
        enc_out = self.encoder(x)
        z_mean = self.enc_z_mean(enc_out)
        # NOTE(review): passed to calc_loss as ``logvar``; confirm that
        # VAEBase.reparameterize interprets its second argument the same way.
        z_sdev = self.enc_z_sdev(enc_out)
        
        z0 = self.reparameterize(z_mean, z_sdev)

        zk = z0
        # Start from a zero tensor (not the int 0) so the accumulator is a
        # tensor even when there are no flow layers; it broadcasts against
        # each layer's log-determinant.
        log_det_sum = z0.new_zeros(())

        for flow in self.flows:
            zk, log_det = flow(zk)
            log_det_sum = log_det_sum + log_det
        
        recon_x = self.decoder(zk)

        return self.calc_loss(recon_x, x, z_mean, z_sdev, z0, zk, log_det_sum)
    
    def generate(self, z):
        """Push ``z`` through every flow layer and decode the result."""
        zk = z
        for flow in self.flows:
            zk, _ = flow(zk)
        
        return self.decoder(zk)