import torch
import torch.nn as nn
from model import HyperNetwork, ConcateNetwork
from math import trace_df_dz, jacobian_df_dz, grad_trace_df_dz

class CNF(nn.Module):
    """Continuous normalizing flow whose drift is produced by a hyper-network.

    Adapted from the NumPy implementation at:
    https://gist.github.com/rtqichen/91924063aa4cc95e7ef30b3a5491cc52

    The module exposes several ODE right-hand sides (``forward``,
    ``forward_simulate``, ``forward_simulation_with_kinetic_energy``,
    ``predict``, ``integrate_constraint``) that share one drift ``f(z, t)``
    but integrate different augmented states alongside ``z``.

    Args:
        in_out_dim: dimension of the state ``z``.
        hidden_dim: hidden size forwarded to the drift network.
        width: width forwarded to ``HyperNetwork`` (unused by ``ConcateNetwork``).
        base_dist: base distribution. It must provide ``log_prob``, ``loc`` and
            ``covariance_matrix`` (presumably a
            ``torch.distributions.MultivariateNormal`` -- confirm with callers).
        layer_type: ``"hypernet"`` or ``"concatnet"``.
        deeper: forwarded to the drift network.
        activation: forwarded to the drift network.
        divegence_method: ``"naive"`` or ``"hutchinson"``; forwarded to
            ``trace_df_dz`` / ``grad_trace_df_dz``. The misspelled keyword is
            kept because it is part of the public signature.

    Raises:
        ValueError: if ``layer_type`` is not one of the supported values.
    """

    def __init__(self, in_out_dim, hidden_dim, width, base_dist, layer_type="hypernet", deeper=False, activation="tanh", divegence_method="naive"):
        super().__init__()
        self.in_out_dim = in_out_dim  # dim(z)
        self.hidden_dim = hidden_dim
        self.width = width
        self.layer_type = layer_type
        if self.layer_type == "hypernet":
            # the network maps t to the parameters (W, B, U) of the drift
            self.hyper_net = HyperNetwork(in_out_dim, hidden_dim, width, deeper=deeper, activation=activation)
        elif self.layer_type == "concatnet":
            # the network maps (t, z) directly to the drift
            self.hyper_net = ConcateNetwork(in_out_dim, hidden_dim, deeper=deeper, activation=activation)
        else:
            # Fail at construction instead of with an unrelated error on the
            # first ODE evaluation.
            raise ValueError(
                "unknown layer_type {!r}; expected 'hypernet' or 'concatnet'".format(layer_type)
            )

        self.base_dist = base_dist
        self.method = divegence_method

        self.register_buffer("_num_evals", torch.tensor(0.))
        self.register_buffer("_e", None)

    def before_odeint(self, z=None):
        """Reset the evaluation counter and refresh the divergence probe.

        Args:
            z: tensor whose shape/dtype/device the Hutchinson probe copies.
                Required when the divergence method is ``"hutchinson"``.

        Raises:
            ValueError: if the method is ``"hutchinson"`` and ``z`` is ``None``.
        """
        if self.method == "hutchinson":
            if z is None:
                raise ValueError("before_odeint requires `z` when the divergence method is 'hutchinson'")
            # A Hutchinson probe needs zero mean and identity covariance for
            # e^T J e to be an unbiased trace estimate; uniform [0, 1) noise
            # satisfies neither, so standard normal noise is drawn instead.
            # (Assumes `trace_df_dz` uses `e` as that probe vector.)
            self._e = torch.randn_like(z)
        elif self.method == "naive":
            self._e = None
        self._num_evals.fill_(0)

    def num_evals(self):
        """Return the number of drift evaluations since the last reset."""
        return self._num_evals.item()

    def _drift(self, t, z):
        """Evaluate the drift ``f(z, t)`` for the configured layer type.

        For ``"hypernet"`` the network output is expected to be
        ``W: [width, z_dim, 1]``, ``B: [width, 1, 1]``, ``U: [width, 1, z_dim]``
        and ``z`` to be ``[batch_size, z_dim]``; the result is averaged over
        the width dimension.
        """
        if self.layer_type == "hypernet":
            W, B, U = self.hyper_net(t)
            # [1, batch, z_dim] @ [width, z_dim, 1] broadcasts over width,
            # which avoids materialising `width` copies of z.
            h = torch.tanh(torch.matmul(z.unsqueeze(0), W) + B)
            # [width, batch, 1] @ [width, 1, z_dim] -> mean over width
            return torch.matmul(h, U).mean(0)
        if self.layer_type == "concatnet":
            return self.hyper_net(t, z)
        raise ValueError(
            "unknown layer_type {!r}; expected 'hypernet' or 'concatnet'".format(self.layer_type)
        )

    def _log_density_rate(self, dz_dt, z):
        """Return ``trace_df_dz(...) + <grad log mu(z), dz_dt>`` as ``[batch, 1]``."""
        batchsize = z.shape[0]
        # No squeeze here: for z_dim == 1 squeezing would turn [batch, 1] into
        # [batch] and the product below would broadcast to [batch, batch].
        grad_logmu_z = self.gradient_log_base_distribution(z)
        trace_term = trace_df_dz(dz_dt, z, self.method, e=self._e).view(batchsize, 1)
        base_term = torch.sum(grad_logmu_z * dz_dt, dim=-1).view(batchsize, 1)
        return trace_term + base_term

    def _score_rates(self, dz_dt, z, grad_logp_z):
        """Return the time derivatives of ``grad_logp_z`` and ``grad_logmu_z``."""
        J_dz_dt = jacobian_df_dz(dz_dt, z)
        assert torch.isfinite(J_dz_dt).all(), 'non-finite values in state `J_dz_dt`: {}'.format(J_dz_dt)

        grad_Trace_J = grad_trace_df_dz(dz_dt, z, self.method, e=self._e)
        assert torch.isfinite(grad_Trace_J).all(), 'non-finite values in state `grad_Trace_J`: {}'.format(grad_Trace_J)

        # d/dt grad_logp_z = -J^T grad_logp_z - grad(trace term)
        dgrad_logp_z_dt = -torch.matmul(J_dz_dt.transpose(-1, -2), grad_logp_z.unsqueeze(-1)).squeeze(-1) \
            - grad_Trace_J

        # d/dt grad_logmu_z = (second gradient of log mu)^T dz_dt
        ggrad_logmu_z = self.second_gradient_log_base_distribution(z)
        dgrad_logmu_z_dt = torch.matmul(ggrad_logmu_z.transpose(-1, -2), dz_dt.unsqueeze(-1)).squeeze(-1)
        assert torch.isfinite(ggrad_logmu_z).all(), 'non-finite values in state `ggrad_logmu_z`: {}'.format(ggrad_logmu_z)

        return dgrad_logp_z_dt, dgrad_logmu_z_dt

    def forward(self, t, states):
        """Drift of the augmented dynamics.

        Args:
            t: current time, passed to the drift network.
            states: tuple of tensors ``(z, logp_z, grad_logp_z, grad_logmu_z)``.
                Only ``z`` and ``grad_logp_z`` are read.

        Returns:
            ``(dz_dt, dlogp_z_dt, dgrad_logp_z_dt, dgrad_logmu_z_dt)``.
        """
        # increment num evals (temporary fix) [TODO]
        self._num_evals += 1

        z = states[0]
        grad_logp_z = states[2]

        assert torch.isfinite(grad_logp_z).all(), 'non-finite values in state `grad_logp_z`: {}'.format(grad_logp_z)

        # enable gradient locally: the helpers differentiate dz_dt w.r.t. z
        with torch.set_grad_enabled(True):
            z.requires_grad_(True)
            dz_dt = self._drift(t, z)
            dlogp_z_dt = self._log_density_rate(dz_dt, z)
            dgrad_logp_z_dt, dgrad_logmu_z_dt = self._score_rates(dz_dt, z, grad_logp_z)

        return (dz_dt, dlogp_z_dt, dgrad_logp_z_dt, dgrad_logmu_z_dt)

    def forward_simulate(self, t, states):
        """Drift of ``[z(t), logp_est_t(z(t))]``.

        Args:
            t: current time.
            states: sequence whose first element is ``z``; further elements
                are ignored.

        Returns:
            ``(dz_dt, dlogp_z_dt)``.
        """
        # increment num evals (temporary fix) [TODO]
        self._num_evals += 1

        z = states[0]
        with torch.set_grad_enabled(True):
            z.requires_grad_(True)
            dz_dt = self._drift(t, z)
            dlogp_z_dt = self._log_density_rate(dz_dt, z)

        return (dz_dt, dlogp_z_dt)

    def forward_simulation_with_kinetic_energy(self, t, states):
        """Drift of ``[z(t), logp(t), R(t)]`` with ``dR/dt = ||dz/dt||^2``.

        Used for RNODE-style training where the kinetic energy is accumulated
        along the trajectory. The current value of ``R`` is not read.

        Returns:
            ``(dz_dt, dlogp_z_dt, dR_dt)``.
        """
        # The counter is incremented inside forward_simulate; incrementing it
        # here as well would count every evaluation twice.
        (dz_dt, dlogp_z_dt) = self.forward_simulate(t, states)

        # kinetic energy: squared velocity summed over the state dimension
        dR_dt = torch.sum(dz_dt ** 2, dim=-1).unsqueeze(-1)

        return (dz_dt, dlogp_z_dt, dR_dt)

    def predict(self, t, states):
        """Return only the drift ``f(z, t)``; ``states`` is the tensor ``z``."""
        # increment num evals (temporary fix) [TODO]
        self._num_evals += 1

        z = states

        with torch.set_grad_enabled(True):
            z.requires_grad_(True)
            dz_dt = self._drift(t, z)

        return dz_dt

    def integrate_constraint(self, t, states):
        """Integrate the constraint penalty along the trajectory.

        Args:
            t: current time.
            states: tuple ``(z, grad_logp_z, grad_logmu_z, R)``. ``R`` is not
                read.

        Returns:
            ``(dz_dt, dgrad_logp_z_dt, dgrad_logmu_z_dt, dR_dt)`` where
            ``dR_dt = ||grad_logp_z - grad_logmu_z + dz_dt||^2`` per sample.
        """
        # increment num evals (temporary fix) [TODO]
        self._num_evals += 1

        z = states[0]
        grad_logp_z = states[1]
        grad_logmu_z = states[2]

        assert torch.isfinite(grad_logp_z).all(), 'non-finite values in state `grad_logp_z`: {}'.format(grad_logp_z)

        # enable gradient locally: the helpers differentiate dz_dt w.r.t. z
        with torch.set_grad_enabled(True):
            z.requires_grad_(True)
            dz_dt = self._drift(t, z)
            dgrad_logp_z_dt, dgrad_logmu_z_dt = self._score_rates(dz_dt, z, grad_logp_z)

            # new derivation without density
            dR_dt = torch.sum(((grad_logp_z - grad_logmu_z) + dz_dt) ** 2, dim=-1).unsqueeze(-1)

        return (dz_dt, dgrad_logp_z_dt, dgrad_logmu_z_dt, dR_dt)

    def base_distribution_logprob(self, z):
        """Return ``base_dist.log_prob(z)``."""
        return self.base_dist.log_prob(z)

    def gradient_base_distribution(self, z):
        """Return the gradient of the base density (not log-density) at ``z``.

        Computed as ``grad log mu(z) * mu(z)`` with autograd; the result has
        the same shape as ``z`` and keeps its graph (``create_graph=True``).
        """
        with torch.set_grad_enabled(True):
            z.requires_grad_(True)
            # was `self.self.base_dist`, which raised AttributeError
            log_prob = self.base_dist.log_prob(z)
            prob = torch.exp(log_prob)
            grad_log_prob = torch.autograd.grad(log_prob.sum(), z, create_graph=True)[0]
            assert grad_log_prob.shape == z.shape
            grad_prob = grad_log_prob * prob.unsqueeze(-1)
            assert grad_prob.shape == z.shape
        return grad_prob

    def gradient_log_base_distribution(self, z):
        """Return ``-inverse(covariance_matrix) @ (z - loc)``, same shape as ``z``.

        This is the gradient of the log base density when the base
        distribution is Gaussian.
        """
        with torch.set_grad_enabled(True):
            cov_inv = self.base_dist.covariance_matrix.inverse()
            centered = (z - self.base_dist.loc).unsqueeze(-1)
            # [z_dim, z_dim] @ [batch, z_dim, 1] broadcasts over the batch
            grad_log_prob = -torch.matmul(cov_inv, centered).squeeze(-1)
        assert grad_log_prob.shape == z.shape

        return grad_log_prob

    def second_gradient_log_base_distribution(self, z):
        """Return ``-inverse(covariance_matrix)`` tiled to ``[batch, z_dim, z_dim]``.

        Computed under ``torch.no_grad()``, so the result carries no graph.
        """
        with torch.no_grad():
            second_grad_log_density = (-self.base_dist.covariance_matrix.inverse()).repeat(z.shape[0], 1, 1)

        return second_grad_log_density


