import torch
from torch import nn

"""
Planar normalizing-flow layers, adapted from:
https://github.com/abdulfatir/planar-flow-pytorch/tree/3761883bbb53a40f903b5840c18e772bfd3e3d99
"""

class PlanarTransform(nn.Module):
    """A single planar-flow layer ``f(z) = z + u_hat * tanh(w^T z + b)``.

    ``u_hat`` is a reparameterisation of the learnable ``u`` chosen so that
    ``w^T u_hat = m(w^T u) >= -1``, the condition under which the planar
    transform is invertible.

    Args:
        dim: dimensionality of the input ``z``.
    """

    def __init__(self, dim=128):
        super().__init__()
        # Small random initialisation of the flow parameters.
        self.u = nn.Parameter(torch.randn(1, dim) * 0.01)
        self.w = nn.Parameter(torch.randn(1, dim) * 0.01)
        self.b = nn.Parameter(torch.randn(1) * 0.01)

    def m(self, x):
        """Return ``-1 + softplus(x)``, which maps onto ``(-1, inf)``.

        ``softplus`` is used instead of ``log(1 + exp(x))`` so that large ``x``
        does not overflow ``exp`` to ``inf`` (which would poison ``u_hat``).
        """
        return -1 + nn.functional.softplus(x)

    def h(self, x):
        """Flow nonlinearity (tanh)."""
        return torch.tanh(x)

    def h_prime(self, x):
        """Derivative of ``h``: ``1 - tanh(x)^2``."""
        return 1 - torch.tanh(x) ** 2

    def forward(self, z):
        """Apply the transform to a batch.

        Args:
            z: tensor of shape ``(batch, dim)``.

        Returns:
            ``(z_next, LDJ)``: the transformed batch ``(batch, dim)`` and the
            log-absolute-determinant of the Jacobian, shape ``(batch, 1)``.
        """
        u_dot_w = (self.u @ self.w.t()).view(())  # scalar w^T u
        # Enforce w^T u_hat = m(w^T u) >= -1. The correction must be taken
        # along w / ||w||^2 (not the unit vector w / ||w||); only then does
        # w^T u_hat equal m(w^T u) for arbitrary ||w||.
        w_norm_sq = torch.sum(self.w ** 2)
        u_hat = self.u + (self.m(u_dot_w) - u_dot_w) * self.w / w_norm_sq  # 1 x dim
        affine = z @ self.w.t() + self.b  # batch x 1
        z_next = z + u_hat * self.h(affine)  # batch x dim

        # det(I + u_hat psi^T) = 1 + psi^T u_hat (matrix determinant lemma);
        # the 1e-8 guards against log(0).
        psi = self.h_prime(affine) * self.w  # batch x dim
        LDJ = torch.log(torch.abs(psi @ u_hat.t() + 1) + 1e-8)  # batch x 1
        return z_next, LDJ

class AmortizedPlanarTransform(nn.Module):
    """Planar-flow layer whose ``u``, ``w`` and ``b`` are supplied per sample.

    This module owns no learnable parameters of its own; the flow parameters
    are passed to ``forward`` alongside the batch.

    Args:
        dim: dimensionality of the latent ``z``.
    """

    def __init__(self, dim=128):
        super().__init__()
        self.dim = dim
        self.softplus = nn.Softplus()

    def h(self, x):
        """Flow nonlinearity (tanh)."""
        return torch.tanh(x)

    def h_prime(self, x):
        """Derivative of ``h``: ``1 - tanh(x)^2``."""
        return 1 - torch.tanh(x) ** 2

    def forward(self, z, params):
        """Apply one planar transform with per-sample parameters.

        from the implementation of Sylvester normalizing flow
        https://github.com/riannevdberg/sylvester-flows/blob/32dde9b7d696fee94f946a338182e542779eecfe/models/flows.py

        Args:
            z: latent batch of shape ``(batch, dim)``.
            params: ``(batch, 2 * dim + 1)`` tensor laid out as ``[u; w; b]``.

        Returns:
            ``(z_next, log_det_jacobian)`` with shapes ``(batch, dim)`` and
            ``(batch,)``.
        """
        n = z.shape[0]
        u_flat, w_flat = torch.chunk(params[:, :-1], 2, 1)
        u = u_flat.view(n, self.dim, 1)  # column vectors
        w = w_flat.view(n, 1, self.dim)  # row vectors
        b = params[:, -1].view(n, 1, 1)
        z_col = z.unsqueeze(2)

        # Reparameterise u -> u_hat so that w^T u_hat = m(w^T u) >= -1,
        # with m(x) = softplus(x) - 1, which keeps the flow invertible.
        wu = w @ u
        m_wu = self.softplus(wu) - 1.
        w_sq_norm = (w ** 2).sum(dim=2, keepdim=True)
        u_hat = u + (m_wu - wu) * w.transpose(1, 2) / w_sq_norm

        # z' = z + u_hat * h(w^T z + b)
        pre_act = w @ z_col + b
        z_next = (z_col + u_hat * self.h(pre_act))[:, :, 0]

        # log|det J| = log|1 + psi^T u_hat| with psi = h'(w^T z + b) * w
        psi = w * self.h_prime(pre_act)
        log_det = torch.log(torch.abs(1 + psi @ u_hat))
        return z_next, log_det.view(-1)

class PlanarFlow(nn.Module):
    """Stack of ``K`` independently parameterised ``PlanarTransform`` layers.

    Args:
        dim: dimensionality of the latent ``z``.
        K: number of stacked planar transforms.
    """

    def __init__(self, dim=128, K=8):
        super().__init__()
        layers = [PlanarTransform(dim) for _ in range(K)]
        self.transforms = nn.ModuleList(layers)

    def forward(self, z):
        """Push ``z`` through every layer in order.

        Returns:
            ``(z_K, LDJ)``: the output of the last layer and the running sum
            of every layer's log-determinant term (the integer ``0`` when the
            stack is empty).
        """
        z_cur = z
        total_ldj = 0
        for layer in self.transforms:
            z_cur, layer_ldj = layer(z_cur)
            total_ldj += layer_ldj
        return z_cur, total_ldj

class AmortizedPlanarFlow(nn.Module):
    """Stack of ``K`` amortized planar transforms whose parameters are
    predicted by a linear layer from a conditioning tensor ``h``.

    Args:
        dim: dimensionality of the latent ``z``.
        K: number of stacked transforms.
        h_dim: size of the last dimension of ``h``. Defaults to ``2 * dim``
            (the value this class previously hard-coded), so existing callers
            and checkpoints are unaffected.
    """

    def __init__(self, dim=128, K=8, h_dim=None):
        super().__init__()
        if h_dim is None:
            h_dim = dim * 2
        self.K = K
        self.transforms = nn.ModuleList([AmortizedPlanarTransform(dim) for _ in range(K)])
        # Each transform consumes a [u; w; b] block of 2 * dim + 1 values.
        self.get_params = nn.Linear(h_dim, (dim * 2 + 1) * K)    # transform input into parameters

    def forward(self, z, h):
        """Apply all ``K`` transforms, conditioning each on ``h``.

        Args:
            z: latent batch fed to the first transform.
            h: conditioning tensor whose dimension 1 has size ``h_dim``
               (presumably ``(batch, h_dim)`` — confirm against callers).

        Returns:
            ``(z_K, LDJ)``: the output of the last transform and the sum of
            the per-transform log-determinant terms.
        """
        # Split the predicted parameters into K equal slices, one per transform.
        params = torch.chunk(self.get_params(h), self.K, 1)
        z_k = z
        LDJ = 0
        for i, transform in enumerate(self.transforms):
            z_k, log_det_j = transform(z_k, params[i])
            LDJ += log_det_j
        return z_k, LDJ