"""
==================
Normalizing Flows
==================

"""
import itertools

import numpy as np
import torch
from torchdiffeq import odeint_adjoint as odeint
import torch.nn.functional as F

from _utils import *

class Flow(torch.nn.Module):
    """Sequential composition of flow transformations.

    Each transformation's ``forward`` must return ``(x, delta_logp, delta_reg_term)``;
    the log-determinant and regularization contributions are summed over all
    transformations.

    Args:
        transformations (iterable of nn.Module): Applied in order by ``forward`` and in
            reverse order by ``reverse`` / ``reverse_chain``.
    """

    def __init__(self, transformations):
        super(Flow, self).__init__()
        self.transformations = torch.nn.ModuleList(transformations)

    def set_trace(self, trace):
        """Propagate the trace-computation method to every submodule exposing ``set_trace``.

        Raises:
            AssertionError: If no submodule exposes ``set_trace``.
        """
        did_set_trace = False
        for module in self.modules():
            # ``modules()`` also yields ``self``; skip it to avoid infinite recursion.
            if module is not self and hasattr(module, 'set_trace'):
                module.set_trace(trace)
                did_set_trace = True
        assert did_set_trace

    def forward(self, x, edges, batch, context=None):
        """Push ``x`` through every transformation.

        Args:
            x: Input tensor.
            edges: Edge information forwarded to each transformation.
            batch: Integer tensor assigning each row of ``x`` to a graph; the number of
                graphs is taken as ``batch.max() + 1``.
            context: Optional conditioning forwarded to each transformation.

        Returns:
            Tuple ``(z, ldj, reg_term)``; ``ldj`` and ``reg_term`` start as zero vectors
            with one entry per graph and accumulate each transformation's contribution.
        """
        batch_size = batch.max().item() + 1
        ldj = x.new_zeros(batch_size)
        reg_term = x.new_zeros(batch_size)

        for transform in self.transformations:
            x, delta_logp, delta_reg_term = transform(x, edges=edges, batch=batch, context=context)
            ldj = ldj + delta_logp
            reg_term = reg_term + delta_reg_term

        return x, ldj, reg_term

    def reverse(self, z, node_mask=None, edge_mask=None, context=None):
        """Invert the flow by calling each transformation's ``reverse`` in reverse order.

        NOTE(review): the extra arguments are forwarded positionally; ``FFJORD.reverse``
        names its second and third parameters ``edges`` and ``batch``, so callers
        presumably pass those values here. Confirm against call sites.
        """
        for transform in reversed(self.transformations):
            z = transform.reverse(z, node_mask, edge_mask, context)

        return z

    def reverse_chain(self, z, node_mask=None, edge_mask=None, context=None):
        """Like ``reverse``, but expands transformations that support ``reverse_chain``.

        The leading dimension of a transformation's ``reverse_chain`` output is treated
        as time. The chain is flattened to ``(n_timesteps * z.size(0), *z.shape[1:])``,
        and ``node_mask`` (when given) is tiled once per timestep so later
        transformations see matching shapes.
        """
        for transform in reversed(self.transformations):
            if not hasattr(transform, 'reverse_chain'):
                z = transform.reverse(z, node_mask, edge_mask, context)
                continue

            z_chain = transform.reverse_chain(z, node_mask, edge_mask, context)
            n_timesteps = z_chain.size(0)
            if node_mask is not None:
                # Repeat the mask for every timestep, then flatten to a column vector.
                node_mask = node_mask.view(1, -1, 1).repeat(n_timesteps, 1, 1).view(-1, 1)
            z = z_chain.view(n_timesteps * z.size(0), *z.size()[1:])

        return z


class FFJORD(torch.nn.Module):
    """
    Continuous-time flow FFJORD [1].

    Args:
        dynamics (nn.Module): The ODE dynamics function f(t,x).
        trace_method (str): The trace computation method. One of {'exact', 'hutch'}.
        ode_regularization (float): When > 0, the ODE function also integrates a
            regularization term (Frobenius-norm estimate plus squared norm of dx).
        hutch_noise (str): Probe noise for Hutchinson's estimator. One of
            {'gaussian', 'bernoulli'}.
        ode_solver (int): Solver selector: 0 -> 'euler', 1 -> 'implicit_adams',
            2 -> 'dopri5'.
        ode_mesh (int): Number of intervals of the uniform integration grid on [0, 1].

    Raises:
        ValueError: If ``ode_solver`` is not one of 0, 1, 2.

    References:
        [1] FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models,
            Grathwohl et al., 2019, https://arxiv.org/abs/1810.01367
    """

    # Maps the integer ``ode_solver`` option to a torchdiffeq method name.
    _ODE_SOLVERS = {0: 'euler', 1: 'implicit_adams', 2: 'dopri5'}

    def __init__(self, dynamics, trace_method='hutch', ode_regularization=0, hutch_noise='gaussian', ode_solver=0, ode_mesh=1):
        super(FFJORD, self).__init__()

        if ode_solver not in self._ODE_SOLVERS:
            raise ValueError(
                f"Unknown ode_solver {ode_solver!r}; expected one of {sorted(self._ODE_SOLVERS)}")
        self.ode_solver = self._ODE_SOLVERS[ode_solver]
        self.ode_mesh = ode_mesh

        self.odefunc = ODEfunc(
            dynamics, method=trace_method, ode_regularization=ode_regularization, hutch_noise=hutch_noise)

        self.set_integration_time(times=list(np.linspace(0.0, 1.0, self.ode_mesh + 1)))
        self.set_odeint(method=self.ode_solver)
        self.dynamics_reg = torch.zeros(1)

    def set_integration_time(self, times=None):
        """(Re)register the forward and reversed integration grids as buffers.

        Args:
            times (sequence of float, optional): Time points for the forward solve; the
                reversed sequence is used for the inverse solve. Defaults to [0.0, 1.0].
        """
        if times is None:
            times = [0.0, 1.0]
        # Place the grid on the same device as the dynamics' parameters; if there are
        # none, fall back to torch's default device.
        first_param = next(iter(self.odefunc.parameters()), None)
        device = first_param.device if first_param is not None else None
        self.register_buffer('int_time', torch.tensor(times, dtype=torch.float, device=device))
        self.register_buffer('inv_int_time', torch.tensor(list(reversed(times)), dtype=torch.float, device=device))

    def set_odeint(self, method=None, rtol=1e-4, atol=1e-4):
        """Configure the solver method and tolerances.

        ``rtol``/``atol`` apply in training mode; evaluation mode uses fixed 1e-7 tolerances.
        """
        self.method = method
        self._atol = atol
        self._rtol = rtol
        self._atol_test = 1e-7
        self._rtol_test = 1e-7

    def set_trace(self, trace):
        """Switch the trace computation of the ODE function ('exact' or 'hutch')."""
        assert trace == 'exact' or trace == 'hutch'
        self.odefunc.method = trace

    @property
    def atol(self):
        """Absolute tolerance for the current mode (train vs. eval)."""
        return self._atol if self.training else self._atol_test

    @property
    def rtol(self):
        """Relative tolerance for the current mode (train vs. eval)."""
        return self._rtol if self.training else self._rtol_test

    def forward(self, x, edges=None, batch=None, node_mask=None, edge_mask=None, context=None):
        """Integrate the augmented state ``(x, ldj, reg_term)`` over ``int_time``.

        Args:
            x: Input tensor.
            edges: Graph edges, stored on the ODE function for use by the dynamics.
            batch: Optional integer tensor assigning rows of ``x`` to graphs. When given,
                ``ldj``/``reg_term`` have ``batch.max() + 1`` entries; otherwise one per row.
            node_mask, edge_mask, context: If any is given, ``dynamics.forward`` is replaced
                by ``dynamics.wrap_forward(node_mask, edge_mask, context)``.

        Returns:
            Tuple ``(z, ldj, reg_term)`` taken at the last integration time.

        Side effects:
            Sets ``max_d_ldj``, ``max_d_x``, ``max_max_d_ldj``, ``max_max_d_x``,
            ``dynamics_reg`` and ``lipchitz_reg`` on ``self``.
        """
        if batch is None:
            ldj = x.new_zeros(x.shape[0])
            reg_term = x.new_zeros(x.shape[0])
        else:
            batch_size = batch.max().item() + 1
            ldj = torch.zeros(batch_size, device=x.device)
            reg_term = torch.zeros(batch_size, device=x.device)

        state = (x, ldj, reg_term)

        self.odefunc.before_odeint(x, edges, batch)
        # Reset the running maxima that ODEfunc.forward updates on every evaluation.
        self.odefunc.max_d_ldj = torch.tensor(0.0)
        self.odefunc.max_d_x = torch.tensor(0.0)
        self.odefunc.max_max_d_ldj = torch.tensor(0.0)
        self.odefunc.max_max_d_x = torch.tensor(0.0)

        # Wrap forward, do not unwrap until backward call!!!
        if node_mask is not None or edge_mask is not None or context is not None:
            self.odefunc.dynamics.forward = self.odefunc.dynamics.wrap_forward(
                node_mask, edge_mask, context)

        statet = odeint(self.odefunc, state, self.int_time,
                        method=self.method,
                        rtol=self.rtol,
                        atol=self.atol)

        zt, ldjt, reg_termt = statet
        z, ldj, reg_term = zt[-1], ldjt[-1], reg_termt[-1]

        self.max_d_ldj = abs(self.odefunc.max_d_ldj).mean()
        self.max_d_x = torch.norm(self.odefunc.max_d_x, p=2, dim=-1).mean()
        self.max_max_d_ldj = torch.max(abs(self.odefunc.max_max_d_ldj))
        self.max_max_d_x = torch.max(torch.norm(self.odefunc.max_max_d_x, p=2, dim=-1))

        self.dynamics_reg = self.odefunc.dynamics.reg_terms
        # Penalize only when the mean velocity norm exceeds 1 or mean |d ldj| exceeds 16.
        self.lipchitz_reg = self.max_d_x * (self.max_d_x.item() > 1) + self.max_d_ldj * (self.max_d_ldj.item() > 16)
        return z, ldj, reg_term

    def reverse_fn(self, z, edges, batch, node_mask=None, edge_mask=None, context=None):
        """Integrate the bare dynamics backwards over ``inv_int_time`` without gradients.

        ``node_mask``, ``edge_mask`` and ``context`` are accepted but not used here.

        Returns:
            The full trajectory returned by ``odeint``, one entry per time in ``inv_int_time``.

        NOTE(review): ``odeint`` calls ``dynamics(t, z)`` without ``edges``, unlike
        ``ODEfunc.forward``; presumably the dynamics obtain edges another way. Confirm.
        """
        self.odefunc.before_odeint(z, edges, batch)

        with torch.no_grad():
            xt = odeint(self.odefunc.dynamics, z, self.inv_int_time,
                        method=self.method,
                        rtol=self.rtol,
                        atol=self.atol)

        return xt

    def reverse(self, z, edges, batch, node_mask=None, edge_mask=None, context=None):
        """Invert the flow and return the state at the end of the reversed grid."""
        xt = self.reverse_fn(z, edges, batch, node_mask, edge_mask, context)
        x = xt[-1]
        return x

    def reverse_chain(self, z, node_mask, edge_mask, context=None):
        """Return the full reverse trajectory on a temporary 50-point grid.

        The previously configured grid (e.g. the ``ode_mesh`` grid from ``__init__``) is
        restored afterwards, even if the solve raises, so later ``forward`` calls keep
        using the same number of solver steps.

        NOTE(review): ``node_mask``/``edge_mask`` are forwarded positionally into
        ``reverse_fn``'s ``edges``/``batch`` slots, as in the original call convention.
        """
        previous_times = self.int_time.tolist()
        self.set_integration_time(times=list(np.linspace(0, 1, 50)))
        try:
            xt = self.reverse_fn(z, node_mask, edge_mask, context)
        finally:
            self.set_integration_time(times=previous_times)
        return xt


class ODEfunc(torch.nn.Module):
    """Augmented ODE right-hand side used by FFJORD.

    ``forward`` returns the time derivatives of the state ``(x, ldj, reg_term)``:
    the dynamics ``dx``, the (exact or estimated) trace of d(dx)/dx, and an optional
    regularization term.

    Args:
        dynamics (nn.Module): Called as ``dynamics(t, x, edges)``; must expose ``egnn``
            because ``before_odeint`` resets ``dynamics.egnn.listener``.
        method (str): Trace computation, 'exact' or 'hutch' (Hutchinson estimator).
        ode_regularization (float): When > 0, ``forward`` also returns
            ``||e^T J||^2 + ||dx||^2`` (per row, or per graph when a batch is set).
        hutch_noise (str): Probe noise, 'gaussian' or 'bernoulli' (Rademacher).
    """

    def __init__(self, dynamics, method='hutch', ode_regularization=0, hutch_noise='gaussian'):
        assert method in {'exact', 'hutch'}
        super(ODEfunc, self).__init__()
        self.dynamics = dynamics
        self.hutch_noise = hutch_noise
        self.method = method
        self.ode_regularization = ode_regularization

        # Running maxima of the derivatives seen during one solve (see ``forward``).
        self.max_d_ldj = torch.tensor(0.0)
        self.max_d_x = torch.tensor(0.0)
        self.max_max_d_ldj = torch.tensor(0.0)
        self.max_max_d_x = torch.tensor(0.0)

    def set_trace_exact(self):
        self.method = 'exact'

    def set_trace_hutch(self):
        self.method = 'hutch'

    @staticmethod
    def hutch_trace(f, y, e=None, batch=None):
        """Hutchinson's estimator ``e^T J e`` of the Jacobian trace of ``f`` w.r.t. ``y``.

        Summed per row, or per graph when ``batch`` (row -> graph index) is given.
        """
        e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
        e_dzdx_e = e_dzdx * e
        approx_tr_dzdx = sum_except_batch(e_dzdx_e) if batch is None else sum_to_batch(e_dzdx_e, batch.to(torch.int64))
        return approx_tr_dzdx

    @staticmethod
    def only_frobenius(f, y, e=None, batch=None):
        """Frobenius-norm estimate ``||e^T J||^2`` of the Jacobian of ``f`` w.r.t. ``y``.

        Summed per row, or per graph when ``batch`` (row -> graph index) is given.
        """
        e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
        frobenius = sum_except_batch(e_dzdx.pow(2)) if batch is None else sum_to_batch(e_dzdx.pow(2), batch.to(torch.int64))
        return frobenius

    @staticmethod
    def hutch_trace_and_frobenius(f, y, e=None, batch=None):
        """Hutchinson trace and Frobenius estimates sharing one vector-Jacobian product.

        Returns:
            Tuple ``(approx_trace, frobenius)``, reduced like ``hutch_trace``.
        """
        e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
        frobenius = sum_except_batch(e_dzdx.pow(2)) if batch is None else sum_to_batch(e_dzdx.pow(2), batch.to(torch.int64))
        e_dzdx_e = e_dzdx * e
        approx_tr_dzdx = sum_except_batch(e_dzdx_e) if batch is None else sum_to_batch(e_dzdx_e, batch.to(torch.int64))
        return approx_tr_dzdx, frobenius

    @staticmethod
    def exact_trace(f, y, batch=None):
        """Exact Jacobian trace, using one autograd pass per non-leading element of ``y``.

        Returns the trace per row of ``y``; when ``batch`` (row -> graph index) is given,
        the traces are summed per graph so the result has ``batch.max() + 1`` entries,
        matching the per-graph reduction of the Hutchinson estimators.
        """
        dims = y.size()[1:]
        tr_dzdx = 0.0
        dim_ranges = [range(d) for d in dims]
        for idcs in itertools.product(*dim_ranges):
            batch_idcs = (slice(None),) + idcs
            tr_dzdx += torch.autograd.grad(f[batch_idcs].sum(), y, create_graph=True)[0][batch_idcs]
        if batch is not None:
            batch = batch.to(torch.int64)
            n_graphs = int(batch.max().item()) + 1
            # Out-of-place index_add keeps the result differentiable.
            tr_dzdx = tr_dzdx.new_zeros(n_graphs).index_add(0, batch, tr_dzdx)
        return tr_dzdx

    @staticmethod
    def exact_jacobian(f, y):
        """Exact Jacobian, stacked so that ``result[..., j, i] = d f[:, i] / d y[..., j]``."""
        jacobian = []
        for i in range(f.shape[1]):
            grad_f = torch.autograd.grad(f[:, i].sum(), y, create_graph=True)[0]
            jacobian.append(grad_f)
        jacobian = torch.stack(jacobian, dim=-1)
        return jacobian

    def before_odeint(self, tensor, edges, batch):
        """Prepare per-solve state: graph structure, eval counter and probe noise.

        The probe ``self._eps`` (same shape as ``tensor``) is redrawn whenever it will be
        read: for the Hutchinson trace, and for the Frobenius estimate that
        ``ode_regularization > 0`` requires even with the exact trace.

        Raises:
            ValueError: If ``hutch_noise`` is neither 'gaussian' nor 'bernoulli'.
        """
        self.edges = edges
        self.batch = batch
        self.num_evals = 0
        self.dynamics.egnn.listener = {}
        if self.method == 'hutch' or self.ode_regularization > 0:
            if self.hutch_noise == 'gaussian':
                # With _eps ~ Normal(0, 1).
                self._eps = torch.randn_like(tensor)
            elif self.hutch_noise == 'bernoulli':
                # With _eps ~ Rademacher (== Bernoulli on -1 +1 with 50/50 chance).
                self._eps = torch.randint(low=0, high=2, size=tensor.size()).to(tensor) * 2 - 1
            else:
                raise ValueError("Wrong hutchinson noise type")
        # NOTE: dynamics.forward is deliberately not unwrapped here; FFJORD.forward wraps
        # it and the wrap must persist until the backward call.

    def forward(self, t, state):
        """Return ``(dx, d_ldj, d_reg_term)`` for the state ``(x, ldj, reg_term)`` at time ``t``.

        Side effects: increments ``num_evals``; stores the full ``jacobian`` in training
        mode; updates the running maxima ``max_d_ldj``, ``max_d_x``, ``max_max_d_ldj`` and
        ``max_max_d_x`` (kept attached to the graph, since FFJORD builds a penalty from them).
        """
        x, ldj, reg_term = state
        batch = self.batch

        self.num_evals += 1
        with torch.set_grad_enabled(True):
            x.requires_grad_(True)
            t.requires_grad_(True)

            # We always need the dynamics :).
            dx = self.dynamics(t, x, self.edges)

            if self.ode_regularization > 0:
                # L2-squared norm of (dx)
                dx2 = sum_except_batch(dx.pow(2)) if batch is None else sum_to_batch(dx.pow(2), batch.to(torch.int64))

                # If trace is computed exact, frobenius norm is still estimated.
                if self.method == 'exact':
                    ldj = self.exact_trace(dx, x, batch=batch)
                    frobenius = self.only_frobenius(dx, x, e=self._eps, batch=batch)

                # Combined computation for trace and frobenius estimators.
                elif self.method == 'hutch':
                    ldj, frobenius = self.hutch_trace_and_frobenius(dx, x, e=self._eps, batch=batch)

                reg_term = frobenius + dx2

            else:
                if self.method == 'exact':
                    ldj = self.exact_trace(dx, x, batch=batch)

                elif self.method == 'hutch':
                    ldj = self.hutch_trace(dx, x, e=self._eps, batch=batch)

                # No regularization terms, set to zero.
                reg_term = torch.zeros_like(ldj)

            if self.training:
                self.jacobian = self.exact_jacobian(dx, x)

        # Track the largest derivatives (by mean and by max) seen during this solve.
        if abs(ldj).mean() > abs(self.max_d_ldj).mean():
            self.max_d_ldj = ldj
        if torch.norm(dx, p=2, dim=-1).mean() > torch.norm(self.max_d_x, p=2, dim=-1).mean():
            self.max_d_x = dx
        if torch.max(abs(ldj)) > torch.max(abs(self.max_max_d_ldj)):
            self.max_max_d_ldj = ldj
        if torch.max(torch.norm(dx, p=2, dim=-1)) > torch.max(torch.norm(self.max_max_d_x, p=2, dim=-1)):
            self.max_max_d_x = dx

        return dx, ldj, reg_term