import torch
import torch.nn as nn
from math import trace_df_dz, jacobian_df_dz, grad_trace_df_dz

from vae_lib.models.VAE import VAE 
from vae_lib.utils.distributions import log_normal_diag, log_normal_standard
from acnf import CNF

class CNF_VAE(nn.Module):
    """VAE whose Gaussian variational posterior is refined by a continuous normalizing flow.

    The encoder gives a diagonal Gaussian q_0(z) = N(z; mu, var). The CNF transports
    samples of q_0 to q_T(z), intended to be closer to the true posterior p(z|x), and
    the decoder maps the transported samples back to data space to estimate the joint
    likelihood p(x, z).

    NOTE(review): ``odeint`` is called below but is not imported in this file; the
    ``adjoint_params`` keyword suggests torchdiffeq's ``odeint_adjoint`` -- confirm and
    add the import.
    """

    def __init__(self, input_size, z_dim, hidden_dim, width, layer_type="hypernet", deeper=False,
                 activation="tanh", divegence_method="naive", device=torch.device("cpu"), T=1):
        """
        Args:
            input_size: forwarded to the VAE.
            z_dim: latent dimension, shared by the VAE and the CNF.
            hidden_dim, width, layer_type, deeper, activation: forwarded to CNF_FOR_VAE.
            divegence_method: (sic) divergence estimator name forwarded to CNF_FOR_VAE,
                e.g. "naive" or "hutchinson".
            device: forwarded to the VAE.
            T: default end time of the flow; the default start time is 0.
        """
        super().__init__()
        self.vae = VAE(z_dim, input_size=input_size, input_type="binary", device=device)
        self.cnf = CNF_FOR_VAE(in_out_dim=z_dim, hidden_dim=hidden_dim, width=width, layer_type=layer_type,
                               deeper=deeper, activation=activation, divegence_method=divegence_method)
        self.t0 = 0
        self.T = T

    def _prepare_flow(self, x):
        """Encode ``x``, draw z0 ~ q_0 and reset the CNF for a fresh integration.

        Returns:
            (encoding_mean, encoding_var, z0, grad_log_mu, log_mu), where ``grad_log_mu``
            is the gradient of log q_0 at z0 and ``log_mu`` is log q_0(z0) summed over the
            last dim (keepdim=True).
        """
        device = x.device
        encoding_mean, encoding_var = self.encode(x)
        z0 = self.vae.reparameterize(encoding_mean, encoding_var)

        # Reset odefunc statistics and clear base_dist. z0 is passed so the Hutchinson
        # estimator can draw a probe vector shaped like the ODE state.
        self.cnf.before_odeint(z0)

        # Data-specific base distribution taken from the encoder output.
        self.cnf.update_base_dist(encoding_mean, encoding_var)

        grad_log_mu = self.cnf.gradient_log_base_distribution(z0).to(device)
        log_mu = torch.sum(self.cnf.base_dist.log_prob(z0).to(device), dim=-1, keepdim=True)
        return encoding_mean, encoding_var, z0, grad_log_mu, log_mu

    def _integrate(self, initial_state, times):
        """Integrate the CNF augmented dynamics from ``initial_state`` over ``times`` (dopri5, atol=rtol=1e-5)."""
        return odeint(self.cnf.generate_with_log_density_and_constraint_penalty,
                      initial_state,
                      times,
                      atol=1e-5,
                      rtol=1e-5,
                      method='dopri5',
                      adjoint_params=list(self.cnf.parameters()))

    def forward(self, x, t0=None, t1=None):
        """Run encoder -> CNF -> decoder on a batch.

        Args:
            x: input batch (the original author notes [batch_size, 1, 28, 28]).
            t0, t1: start/end times of the flow; default to ``self.t0`` / ``self.T``.
                The ODE is integrated over the grid [t1, t0]; when t1 == t0 no flow is applied.

        Returns:
            (x_reconstruct, -delta_logp_t0, KLD, -constraint_loss) where ``KLD`` is the
            single-sample estimate log q_0(z0) - log p(z_t0) with a trailing unit dim.

        Raises:
            ValueError: if ``t1 < t0`` (no integration branch exists for that case).
        """
        if t0 is None:
            t0 = self.t0
        if t1 is None:
            t1 = self.T
        # Validate before doing any encoding/integration work.
        if t1 < t0:
            raise ValueError(f"CNF_VAE.forward requires t1 >= t0, got t0={t0}, t1={t1}")

        device = x.device
        encoding_mean, encoding_var, z0, grad_log_mu_t1, log_mu_t1 = self._prepare_flow(x)

        if t1 > t0:
            # generation direction: integrate from t1 back to t0
            time_grid = torch.tensor([float(t1), float(t0)], dtype=torch.float32, device=device)
            initial_state = (z0, torch.zeros_like(log_mu_t1), grad_log_mu_t1, torch.zeros_like(log_mu_t1))
            z_t_forward, delta_logp, _, constraint_loss_t = self._integrate(initial_state, time_grid)
            # reconstruction on "transformed" samples (last time point of the grid)
            z_t0 = z_t_forward[-1]
            delta_logp_t0 = delta_logp[-1]
        else:
            # t1 == t0: no flow; mimic the [2, batch, 1] layout odeint would return
            z_t0 = z0
            delta_logp_t0 = torch.zeros_like(log_mu_t1)
            constraint_loss_t = torch.zeros_like(log_mu_t1).repeat(2, 1, 1)

        x_reconstruct = self.vae.decode(z_t0)

        # ln p(z_k) under the standard Gaussian prior (not averaged)
        log_p_zk = log_normal_standard(z_t0, dim=-1)
        # ln q(z_0) under the encoder's diagonal Gaussian (not averaged)
        log_q_z0 = log_normal_diag(z0, mean=encoding_mean, log_var=encoding_var.log(), dim=-1)
        # single-sample KL estimate (the closed form for the no-flow case would be
        # -0.5 * sum(1 + log var - mean^2 - var))
        KLD = (log_q_z0 - log_p_zk).unsqueeze(-1)

        return x_reconstruct, -delta_logp_t0, KLD, -constraint_loss_t[-1]

    def evaluate_elbo(self, x, ts):
        """Estimate the negative ELBO at every time point of ``ts`` without tracking gradients
        through the ODE solve or decoder.

        Args:
            x: input batch; also the BCE target for the reconstructions.
            ts: time grid passed to odeint; one output entry is produced per time point.

        Returns:
            (neg_elbos, x_reconstructs), each stacked along a new leading dim indexed like ``ts``.
            neg_elbo = BCE(sum)/batch_size + mean(log q_0(z0) - log p(z_t)) + mean(delta_logp_t).
        """
        encoding_mean, encoding_var, z0, grad_log_mu_t1, log_mu_t1 = self._prepare_flow(x)

        criterion = nn.BCELoss(reduction='sum')
        batch_size = x.shape[0]

        with torch.no_grad():
            # generation direction
            initial_state = (z0, torch.zeros_like(log_mu_t1), grad_log_mu_t1, torch.zeros_like(log_mu_t1))
            z_t_forward, delta_logp, _, _ = self._integrate(initial_state, ts)

            # ln q(z_0) is the same for every time point
            log_q_z0 = log_normal_diag(z0, mean=encoding_mean, log_var=encoding_var.log(), dim=-1)

            neg_elbos = []
            x_reconstructs = []
            for z_t, delta_logp_t in zip(z_t_forward, delta_logp):
                x_reconstruct = self.vae.decode(z_t)

                # ln p(z_t) under the standard Gaussian prior (not averaged)
                log_p_zk = log_normal_standard(z_t, dim=-1)
                old_kld = (log_q_z0 - log_p_zk).unsqueeze(-1)

                bce_loss = criterion(x_reconstruct, x) / batch_size
                # KL corrected by the log-density change accumulated along the flow
                new_kl = old_kld.mean(0) + delta_logp_t.mean(0)

                neg_elbos.append(bce_loss + new_kl)
                x_reconstructs.append(x_reconstruct)

            neg_elbos = torch.stack(neg_elbos, dim=0)
            x_reconstructs = torch.stack(x_reconstructs, dim=0)

        return neg_elbos, x_reconstructs

    def encode(self, x):
        """Return the VAE encoder's (mean, variance) for ``x``."""
        encoding_mean, encoding_var = self.vae.encode(x)
        return (encoding_mean, encoding_var)

    def decode(self, z):
        """Decode latent samples ``z`` with the VAE decoder."""
        return self.vae.decode(z)

class CNF_FOR_VAE(CNF):
    """CNF specialised for the CNF-VAE model.

    The base distribution is not fixed at construction (``base_dist=None`` is passed to
    the parent). It is set per batch from the encoder output via ``update_base_dist``
    as a diagonal ``torch.distributions.Normal``.
    """

    def __init__(self, in_out_dim, hidden_dim, width, layer_type="hypernet", deeper=False, activation="tanh",
                 divegence_method="naive"):
        super().__init__(in_out_dim, hidden_dim, width, base_dist=None, layer_type=layer_type,
                         deeper=deeper, activation=activation, divegence_method=divegence_method)

    def before_odeint(self, z=None):
        """Reset per-integration state before calling odeint.

        Args:
            z: a tensor shaped like the ODE's z state. Required when ``self.method`` is
                "hutchinson" (used to draw the noise probe); ignored otherwise.

        Raises:
            ValueError: if the Hutchinson estimator is selected and ``z`` is None.
        """
        # clear statistics
        if self.method == "hutchinson":
            if z is None:
                raise ValueError("before_odeint(z): a sample z is required for the hutchinson estimator")
            # Hutchinson's trace estimator needs E[e e^T] = I. Standard Gaussian noise has this
            # property; uniform [0, 1) noise (rand_like) does not and biases the estimate.
            self._e = torch.randn_like(z)
        elif self.method == "naive":
            self._e = None
        self._num_evals.fill_(0)
        # Always clear the base distribution so a stale one is never used; the caller
        # must set a fresh one via update_base_dist.
        self.base_dist = None

    def generate_with_log_density_and_constraint_penalty(self, t, states):
        """ODE right-hand side for the augmented state (z, log p, grad log p, R).

        Args:
            t: current integration time (passed to ``self.hyper_net``).
            states: tuple ``(z, logp_z, grad_logp_z, R)``. As CNF_VAE builds it, ``z`` and
                ``grad_logp_z`` are [batch_size, z_dim] and ``logp_z`` / ``R`` are [batch_size, 1].

        Returns:
            Tuple of time derivatives ``(dz_dt, dlogp_z_dt, dgrad_logp_z_dt, dR_dt)``:
              - dz_dt: drift g(z, t);
              - dlogp_z_dt: -Tr(dg/dz), shaped [batch_size, 1];
              - dgrad_logp_z_dt: -(dg/dz)^T grad_logp_z - grad_z Tr(dg/dz);
              - dR_dt: sum over the latent dim of ((grad_logp_z - grad log mu(z)) + dz_dt)^2,
                shaped [batch_size, 1].

        Raises:
            ValueError: if ``self.layer_type`` is neither "hypernet" nor "concatnet".
        """
        # increment num evals (temporary fix) [TODO]
        self._num_evals += 1

        z = states[0]
        logp_z = states[1]
        grad_logp_z = states[2]
        R = states[3]

        batchsize = z.shape[0]
        with torch.set_grad_enabled(True):
            z.requires_grad_(True)
            if self.layer_type == "hypernet":
                # -------------- Hypernetwork for g(z,t) ----------------- #
                # NOTE(review): shapes of W, B, U are defined by self.hyper_net in the base
                # class; assumed to broadcast with Z over the leading width dim -- confirm.
                W, B, U = self.hyper_net(t)

                # [batch_size, z_dim] -> [1, batch_size, z_dim] -> [width, batch_size, z_dim]
                Z = torch.unsqueeze(z, 0).repeat(self.width, 1, 1)
                # batched matmul over the width dim; tanh keeps the hidden state in [-1, 1]
                h = torch.tanh(torch.matmul(Z, W) + B)
                # drift g(z, t): average the per-width contributions
                dz_dt = torch.matmul(h, U).mean(0)
                # --------------------------------------------------- #
            elif self.layer_type == "concatnet":
                dz_dt = self.hyper_net(t, z)
            else:
                raise ValueError(f"unsupported layer_type: {self.layer_type!r}")

            # gradient of log base density at z; already [batch_size, z_dim]. Do not
            # squeeze: with z_dim == 1 that would broadcast the difference below to [B, B].
            grad_logmu_z = self.gradient_log_base_distribution(z)

            # trace of the Jacobian (naive or Hutchinson, per self.method)
            dlogp_z_dt = -trace_df_dz(dz_dt, z, method=self.method, e=self._e).view(batchsize, 1)

            # dgrad_logp_z_dt
            J_dz_dt = jacobian_df_dz(dz_dt, z)
            assert torch.isfinite(J_dz_dt).all(), 'non-finite values in state `J_dz_dt`: {}'.format(J_dz_dt)

            grad_Trace_J = grad_trace_df_dz(dz_dt, z, self.method, e=self._e)
            assert torch.isfinite(grad_Trace_J).all(), 'non-finite values in state `grad_Trace_J`: {}'.format(grad_Trace_J)

            dgrad_logp_z_dt = -torch.matmul(J_dz_dt.transpose(-1, -2), grad_logp_z.unsqueeze(-1)).squeeze(-1) \
                - grad_Trace_J

            dR_dt = torch.sum(((grad_logp_z - grad_logmu_z) + dz_dt) ** 2, dim=-1).unsqueeze(-1)

        return (dz_dt, dlogp_z_dt, dgrad_logp_z_dt, dR_dt)

    def update_base_dist(self, mean, variance):
        """Overwrite the base distribution with the encoder's diagonal Gaussian N(mean, variance)."""
        self.base_dist = torch.distributions.Normal(mean, scale=torch.sqrt(variance))

    def gradient_log_base_distribution(self, z):
        """Gradient of log q_0(z) for the diagonal Gaussian base distribution.

        For N(mu, diag(sigma^2)): d/dz log q_0(z) = -(z - mu) / sigma^2.

        Args:
            z: [batch_size, z_dim]
        Returns:
            grad_log_prob: same shape as ``z``.
        """
        with torch.set_grad_enabled(True):
            # Elementwise form of -Sigma^{-1} (z - mu) for a diagonal covariance. Divide by
            # the variance (scale**2), not by the scale.
            grad_log_prob = -(z - self.base_dist.loc) / self.base_dist.variance
        assert grad_log_prob.shape == z.shape

        return grad_log_prob

    def second_gradient_log_base_distribution(self, z):
        """Hessian of log q_0(z) for the diagonal Gaussian: -Sigma^{-1} = -diag(1 / sigma^2).

        Args:
            z: [batch_size, z_dim]; used only to check the output shape.
        Returns:
            [batch_size, z_dim, z_dim] batch of diagonal matrices.
        """
        with torch.set_grad_enabled(True):
            second_grad_log_density = torch.diag_embed(-1.0 / self.base_dist.variance)

        assert second_grad_log_density.shape[:-1] == z.shape and second_grad_log_density.shape[-1] == z.shape[-1]

        return second_grad_log_density