# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""

import torch
import numpy as np
from scipy import integrate
from scipy.special import gamma

def to_flattened_numpy(x):
  """Return the contents of torch tensor `x` as a 1-D numpy array.

  The tensor is detached from the autograd graph and moved to the CPU
  before conversion.
  """
  host_array = x.detach().cpu().numpy()
  return host_array.reshape(-1)

def from_flattened_numpy(x, shape):
  """Reshape the flat numpy array `x` to `shape` and wrap it as a torch tensor."""
  reshaped = x.reshape(shape)
  return torch.from_numpy(reshaped)

def get_div_fn(fn):
    """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator.

    Args:
        fn: A callable `fn(x, t)` returning a tensor that can be multiplied
            elementwise with the probe `eps`.

    Returns:
        A function `div_fn(x, t, eps)` that returns `eps^T (d fn / d x) eps`
        reduced over every dimension of `x` except the first, i.e. one
        estimate per leading-axis entry.
    """

    def div_fn(x, t, eps):
        # Remember the caller's autograd flag so it can be restored afterwards;
        # the estimator must not permanently change the tensor it was handed.
        had_grad = x.requires_grad
        with torch.enable_grad():
            if not had_grad:
                x.requires_grad_(True)
            try:
                # Vector-Jacobian product: d/dx <fn(x, t), eps> = J^T eps.
                fn_eps = torch.sum(fn(x, t) * eps)
                grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]
            finally:
                # Only undo what was changed here, and do it even if `fn` raised.
                if not had_grad:
                    x.requires_grad_(False)
        return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))

    return div_fn


def get_likelihood_fn(levy, sde, device, hutchinson_type='Rademacher',
                      rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5):
    """Create a function to compute the unbiased log-likelihood estimate of a given data point.

    Args:
        levy: An object exposing `pdf(z, alpha)`; it is used to evaluate the
            prior density of the latent code.
        sde: An object that represents the forward SDE. This function reads
            `sde.alpha` and calls `sde.beta(t)` and `sde.marginal_std(t)`.
        device: Not used by this function; kept for interface compatibility.
        hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator.
        rtol: A `float` number. The relative tolerance level of the black-box ODE solver.
        atol: A `float` number. The absolute tolerance level of the black-box ODE solver.
        method: A `str`. The algorithm for the black-box ODE solver.
            See documentation for `scipy.integrate.solve_ivp`.
        eps: A `float` number. Currently not used: the lower integration limit
            is the `t_0` argument of the returned function.

    Returns:
        A function `likelihood_fn(model, data, inverse_scalar=0, t_0=1e-5, T=1)`
        that takes a batch of data points and returns the log-likelihoods in
        bits/dim, the latent code, and the number of function evaluations
        used by the ODE solver.
    """

    def drift_fn(model, x, t):
        """Drift of the ODE integrated by `likelihood_fn`.

        `x` is expected to be a 4-D batch and `t` a 1-D tensor with one time
        value per batch entry: the per-sample coefficients are broadcast
        over exactly three trailing dimensions.
        """
        beta_t = sde.beta(t)
        # Rescale the network output by marginal_std(t) ** -(alpha - 1).
        std_scale = torch.pow(sde.marginal_std(t), -(sde.alpha - 1)).view(-1, 1, 1, 1)
        score = model(x, t) * std_scale

        x_coeff = -beta_t / sde.alpha
        score_coeff = -beta_t

        return x_coeff[:, None, None, None] * x + score_coeff[:, None, None, None] * score

    def div_fn(model, x, t, noise):
        """Hutchinson-Skilling estimate of the divergence of `drift_fn` w.r.t. `x`."""
        return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise)

    def likelihood_fn(model, data, inverse_scalar=0, t_0=1e-5, T=1):
        """Compute an unbiased estimate to the log-likelihood in bits/dim.

        Args:
            model: A score model, called as `model(x, t)`.
            data: A PyTorch tensor; 4-D, since the prior log-density is summed
                over dimensions 1, 2 and 3.
            inverse_scalar: A number subtracted from the constant 7 to form the
                bits/dim offset (NOTE(review): presumably the inverse data
                normalizer evaluated at the lowest pixel value; confirm with
                the caller).
            t_0: Start time of the ODE integration.
            T: End time of the ODE integration.

        Returns:
            bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim.
            z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the
                probability flow ODE.
            nfe: An integer. The number of function evaluations used for running the black-box ODE solver.

        Raises:
            NotImplementedError: If `hutchinson_type` is neither "Gaussian" nor "Rademacher".
        """
        # Run the whole solve without autograd bookkeeping so the drift
        # evaluations do not build graphs that are immediately discarded;
        # `get_div_fn` re-enables gradients locally for the trace estimator.
        with torch.no_grad():
            shape = data.shape
            batch_size = shape[0]
            # Number of entries of the flattened sample part of the ODE state.
            state_size = data.numel()

            if hutchinson_type == 'Gaussian':
                epsilon = torch.randn_like(data)
            elif hutchinson_type == 'Rademacher':
                epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1.
            else:
                raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

            def ode_func(t, x):
                # The flat state is [sample entries, one log-density change per batch entry].
                sample = from_flattened_numpy(x[:state_size], shape).to(data.device).type(torch.float32)
                vec_t = torch.ones(sample.shape[0], device=sample.device) * t   # (B,)
                drift = to_flattened_numpy(drift_fn(model, sample, vec_t))
                logp_grad = to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))   # (B,)
                return np.concatenate([drift, logp_grad], axis=0)

            init = np.concatenate([to_flattened_numpy(data), np.zeros((batch_size,))], axis=0)
            solution = integrate.solve_ivp(ode_func, (t_0, T), init, rtol=rtol, atol=atol, method=method)
            nfe = solution.nfev
            # Last column of the solution is the state at the final time.
            zp = solution.y[:, -1]
            z = from_flattened_numpy(zp[:state_size], shape).to(data.device).type(torch.float32)
            delta_logp = from_flattened_numpy(zp[state_size:], (batch_size,)).to(data.device).type(torch.float32)
            prior_logp = torch.sum(torch.log(levy.pdf(z, sde.alpha)), dim=(1, 2, 3))
            # Convert nats to bits and normalize by the number of entries per sample.
            bpd = -(prior_logp + delta_logp) / np.log(2)
            N = np.prod(shape[1:])
            bpd = bpd / N

            # A hack to convert log-likelihoods to bits/dim
            offset = 7. - inverse_scalar
            bpd = bpd + offset
            return bpd, z, nfe

    return likelihood_fn