# Continuous normalizing flow
# FFJORD (https://arxiv.org/abs/1810.01367) [https://github.com/rtqichen/ffjord]
# Neural ODEs (https://arxiv.org/abs/1806.07366) [https://github.com/rtqichen/torchdiffeq]
# Regularization (https://arxiv.org/abs/2002.02798) [https://github.com/cfinlay/ffjord-rnode]

import nf
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchdiffeq import odeint_adjoint, odeint


def rms_norm(tensor):
    """Return the root-mean-square of all elements of `tensor` as a 0-dim tensor."""
    mean_square = torch.mean(tensor ** 2)
    return torch.sqrt(mean_square)

def make_norm(state):
    """Build a norm function that looks only at two slices of an augmented state.

    Args:
        state: Tensor whose number of elements fixes the slice length.

    Returns:
        A callable taking a flat augmented state and returning the larger of
        the RMS norms of elements ``[1, 1 + n)`` and ``[1 + n, 1 + 2n)``, where
        ``n = state.numel()``.
    """
    n = state.numel()
    # NOTE(review): presumably the layout is [1 leading element, state, adjoint
    # state, ...] as produced by torchdiffeq's adjoint solver -- confirm
    # against the installed version.
    state_slice = slice(1, 1 + n)
    adjoint_slice = slice(1 + n, 1 + 2 * n)

    def norm(aug_state):
        state_rms = rms_norm(aug_state[state_slice])
        adjoint_rms = rms_norm(aug_state[adjoint_slice])
        return max(state_rms, adjoint_rms)

    return norm


def divergence_bf(y, x, *args):
    """Compute the diagonal of dy/dx exactly, one column at a time.

    Args:
        y: Output tensor computed from `x`; indexed as ``y[:, i]``.
        x: Input tensor that `y` was computed from (must be part of the graph).
        *args: Ignored; accepted so the signature matches `divergence_approx`.

    Returns:
        Tuple ``(diag, sq)`` where ``diag`` has the shape of `x` and holds
        ``d y[:, i] / d x[:, i]``, and ``sq`` is the per-sample mean of the
        squared entries of ``diag``.
    """
    jac_diag = torch.zeros_like(x)
    for col in range(x.shape[1]):
        # One backward pass per column; create_graph keeps the result differentiable.
        (grad_wrt_x,) = torch.autograd.grad(y[:, col].sum(), x, create_graph=True)
        column = grad_wrt_x.contiguous()[:, col].contiguous()
        jac_diag[:, col] += column
    per_sample_sq = jac_diag.view(x.shape[0], -1).pow(2).mean(dim=1)
    return jac_diag, per_sample_sq

def divergence_approx(f, y, e):
    """Stochastic estimate of the Jacobian diagonal via a vector-Jacobian product.

    Args:
        f: Output tensor computed from `y`.
        y: Input tensor that `f` was computed from (must be part of the graph).
        e: Noise tensor used as ``grad_outputs`` for the vector-Jacobian product.

    Returns:
        Tuple ``(e_dzdx_e, sqnorm)`` where ``e_dzdx_e`` is the elementwise
        product of the vector-Jacobian product with `e`, and ``sqnorm`` is the
        per-sample mean of the squared vector-Jacobian product.
    """
    (vjp,) = torch.autograd.grad(f, y, e, create_graph=True)
    per_sample_sq = vjp.view(y.shape[0], -1).pow(2).mean(dim=1)
    return vjp * e, per_sample_sq


class ODEfunc(nn.Module):
    """Augmented ODE dynamics: state derivative plus divergence and regularizer.

    Args:
        diffeq: Callable evaluated on ``cat([t, y, (latent)], -1)``. Returns
            ``dy``, or the triple ``(dy, divergence, sqjacnorm)`` when
            `returns_divergence` is True.
        divergence_fn: 'brute_force' or 'approximate'. May be anything when
            `returns_divergence` is True, since it is not used in that case.
        rademacher: If True, the noise is sampled from {-1, +1}; otherwise
            from a standard normal.
        returns_divergence: Whether `diffeq` computes the divergence itself.

    Raises:
        ValueError: If `divergence_fn` is not recognised and
            `returns_divergence` is False.
    """

    def __init__(self, diffeq, divergence_fn=None, rademacher=False, returns_divergence=False):
        super().__init__()
        self.diffeq = diffeq
        self.rademacher = rademacher
        self.returns_divergence = returns_divergence

        if divergence_fn == 'brute_force':
            self.divergence_fn = divergence_bf
        elif divergence_fn == 'approximate':
            self.divergence_fn = divergence_approx
        elif not returns_divergence:
            raise ValueError('Divergence function should be "brute_force" or "approximate".')

        # Initialised here so that forward() does not raise AttributeError when
        # it is called before before_odeint().
        self._e = None

        self.register_buffer('_num_evals', torch.tensor(0.))

    def before_odeint(self, e=None):
        """Reset the evaluation counter and set (or clear) the fixed noise `e`."""
        self._e = e
        self._num_evals.fill_(0)

    def num_evals(self):
        """Return the number of forward evaluations since the last reset."""
        return self._num_evals.item()

    def forward(self, t, states):
        """Evaluate the augmented dynamics at time `t`.

        Args:
            t: Time tensor; broadcast against ``y[..., :1]`` to build the input.
            states: Tuple ``(y, logp, reg)`` or ``(y, logp, reg, latent)``.

        Returns:
            Tuple of derivatives, one entry per element of `states`.
        """
        # states = [y, logp, reg, (latent)]
        assert len(states) in (3, 4)

        y = states[0]
        self._num_evals += 1

        # Sample the noise once and keep it fixed until before_odeint() resets it.
        if self._e is None:
            if self.rademacher:
                self._e = torch.randint(low=0, high=2, size=y.shape).to(y) * 2 - 1
            else:
                self._e = torch.randn_like(y)

        # Gradients are needed for the divergence even when the caller runs
        # under no_grad (e.g. evaluation).
        with torch.set_grad_enabled(True):
            y.requires_grad_(True)
            t.requires_grad_(True)
            for s in states[2:]:
                s.requires_grad_(True)

            # Construct input by concatenating time and y (and latent)
            odefunc_input = [torch.ones_like(y[..., :1]) * t, y]
            if len(states) == 4:  # if latent
                odefunc_input.append(states[3])
            odefunc_input = torch.cat(odefunc_input, -1).requires_grad_(True)

            # NOTE(review): in every branch below the trailing states (latent,
            # and in eval mode also reg) are returned unchanged as their own
            # time derivative rather than zeros -- confirm this is intended.

            # If function has divergence
            if self.returns_divergence:
                dy, divergence, sqjacnorm = self.diffeq(odefunc_input)
                return (dy, -divergence, sqjacnorm) + tuple(states[3:])

            # Else, get divergence and regularization
            dy = self.diffeq(odefunc_input)
            if not self.training:
                # Evaluation always uses the exact (brute-force) divergence.
                divergence, _ = divergence_bf(dy, y)
                return (dy, -divergence) + tuple(states[2:])

            divergence, sqjacnorm = self.divergence_fn(dy, y, self._e)
            return (dy, -divergence, sqjacnorm) + tuple(states[3:])


class ContinuousFlow(nn.Module):
    def __init__(self, dim, net=None, T=1.0, divergence_fn='approximate', use_adjoint=True,
                 solver='dopri5', solver_options=None, test_solver=None, test_solver_options=None,
                 atol=1e-5, rtol=1e-3, returns_divergence=False,
                 faster_adjoint=False, rademacher=False, **kwargs):
        """
        Continuous normalizing flow.
        Also look at Neural ODE documentation: https://github.com/rtqichen/torchdiffeq

        Args:
            dim: Input data dimension
            net: Neural net that defines a differential equation, instance of `net`
            T: Integrate from 0 until T (Default: 1.0)
            divergence_fn: How to calculate divergence, 'approximate' or 'brute_force'
            use_adjoint: Whether to use adjoint method for backpropagation
            solver: ODE black-box solver, adaptive: dopri5, dopri8, bosh3, adaptive_heun;
                fixed-step: euler, midpoint, rk4, explicit_adams, implicit_adams
            solver_options: Additional options, e.g. {'step_size': 10}
                (Default: None, treated as an empty dict)
            test_solver: Same as solver, used during evaluation
                (Default: None, falls back to `solver`)
            test_solver_options: Same as solver_options, used during evaluation
                (Default: None, falls back to `solver_options`)
            atol: Absolute tolerance (Default: 1e-5)
            rtol: Relative tolerance (Default: 1e-3)
            returns_divergence: If 'net' calculates divergence directly (Default: False)
            faster_adjoint: Whether to enable training trick from
                https://arxiv.org/abs/2009.09457 (Default: False).
                Only has an effect when `use_adjoint` is True.
            rademacher: Whether to use rademacher sampling (Default: False)
            **kwargs: Ignored.
        """
        super().__init__()

        self.T = T
        self.dim = dim

        self.odefunc = ODEfunc(net, divergence_fn, rademacher, returns_divergence)

        self.integrate = odeint_adjoint if use_adjoint else odeint

        self.solver = solver
        # A fresh dict per instance: a `{}` default would be shared by all instances.
        self.solver_options = {} if solver_options is None else solver_options
        self.test_solver = test_solver or solver
        self.test_solver_options = (
            self.solver_options if test_solver_options is None else test_solver_options
        )

        self.atol = atol
        self.rtol = rtol

        self.faster_adjoint = faster_adjoint

        # Set by forward() to the absolute value of the integrated regularization state.
        self.regularization = None

    def forward(self, x, latent=None, reverse=False, **kwargs):
        """
        Integrate the flow from 0 to T (or from T to 0 if `reverse`).

        Args:
            x: Input tensor with last dimension `dim`; leading dimensions are flattened
            latent: Optional conditioning tensor, flattened to 2-D the same way
            reverse: Whether to integrate backwards in time
            **kwargs: Ignored.

        Returns:
            Tuple (y, value), both with the shape of `x`, where `value` is the
            negated second integrated state (initialised to zeros).
        """
        # Set inputs; reshape (unlike view) also accepts non-contiguous tensors
        *shape, dim = x.shape
        x = x.reshape(-1, dim)
        logp = torch.zeros_like(x)

        # Set integration times
        integration_times = torch.tensor([0.0, self.T]).to(x)
        if reverse:
            integration_times = _flip(integration_times, 0)

        # Refresh the odefunc statistics
        self.odefunc.before_odeint()

        # Set initial state with regularization (and latent)
        initial = [x, logp, torch.zeros(x.shape[0]).to(x)]
        if latent is not None:
            initial.append(latent.reshape(-1, latent.shape[-1]))
        initial = tuple(initial)

        # Pick solver settings for the current mode
        if self.training:
            solver_kwargs = dict(method=self.solver, options=self.solver_options)
            # Plain `odeint` does not accept `adjoint_options`; passing it
            # would raise a TypeError when use_adjoint=False.
            if self.integrate is odeint_adjoint:
                if self.faster_adjoint:
                    flat_state = torch.cat([s.flatten() for s in initial])
                    solver_kwargs['adjoint_options'] = dict(norm=make_norm(flat_state))
                else:
                    solver_kwargs['adjoint_options'] = {}
        else:
            solver_kwargs = dict(method=self.test_solver, options=self.test_solver_options)

        # Solve ODE
        state_t = self.integrate(
            self.odefunc,
            initial,
            integration_times,
            atol=self.atol,
            rtol=self.rtol,
            **solver_kwargs,
        )

        # Keep only the state at the final integration time
        if len(integration_times) == 2:
            state_t = tuple(s[1] for s in state_t)

        # Collect outputs with correct shape
        x, logp = [s.view(*shape, dim) for s in state_t[:2]]
        self.regularization = state_t[2].abs()

        return x, -logp

    def inverse(self, x, logp=None, latent=None, **kwargs):
        """Run `forward` with reverse=True; `logp` and `**kwargs` are accepted but ignored."""
        return self.forward(x, latent=latent, reverse=True)

    def num_evals(self):
        """Return the number of odefunc evaluations in the last solve."""
        return self.odefunc.num_evals()

def _flip(x, dim):
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
    return x[tuple(indices)]


class ContinuousSetFlow(nn.Module):
    """Continuous normalizing flow that works on sets.

    Wraps `net` in `nf.net.CNFSetWrapper` and delegates the integration to a
    `ContinuousFlow`; all extra keyword arguments are forwarded to it.
    """

    def __init__(self, dim, net=None, input_time=False, **kwargs):
        """
        Args:
            dim: Passed to `ContinuousFlow` as its `dim`.
            net: Network handed to `nf.net.CNFSetWrapper`.
            input_time: Forwarded to `nf.net.CNFSetWrapper`.
            **kwargs: Forwarded unchanged to `ContinuousFlow`.
        """
        super().__init__()

        set_net = nf.net.CNFSetWrapper(net, input_time=input_time)
        self.net = set_net
        self.flow = ContinuousFlow(dim, net=set_net, **kwargs)

    @property
    def regularization(self):
        """Regularization value stored by the wrapped flow."""
        return self.flow.regularization

    def forward(self, x, latent=None, reverse=False, mask=None, **kwargs):
        """Flatten each set to one vector, run the flow, and restore the shape.

        Args:
            x: Input tensor; its dimensions 1 and 2 are merged before the flow,
                so it is presumably 3-D -- TODO confirm.
            latent: Optional conditioning tensor, passed to the flow.
            reverse: Whether to integrate backwards in time.
            mask: Stored on the wrapped net before the flow is evaluated.
            **kwargs: Ignored.

        Returns:
            Tuple of two tensors, each viewed with the original shape of `x`.
        """
        # The wrapper reads the original shape and mask from its attributes.
        self.net.shape = x.shape
        self.net.mask = mask

        flat_dim = x.shape[1] * x.shape[2]
        flat_x = x.view(-1, flat_dim)

        out, logp = self.flow.forward(flat_x, latent=latent, reverse=reverse)
        out = out.view(*self.net.shape)
        logp = logp.view(*self.net.shape)
        return out, logp

    def inverse(self, x, latent=None, mask=None, **kwargs):
        """Run `forward` with reverse=True."""
        return self.forward(x, latent=latent, reverse=True, mask=mask, **kwargs)
