import torch
import numpy as np
from scipy import integrate
from src.constants import EPS_SDE
from src.utils.diffusion_utils import to_flattened_numpy, from_flattened_numpy, get_score_fn


# def get_div_fn(fn, exact=False):
#   """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator."""
#   if exact:
#     def div_fn(x,t):
#       with torch.enable_grad():
#         x.requires_grad_(True)
#         grads=[]
#         for i, v in enumerate(torch.eye(x.shape[1], device=x.device)): # iterate over rows of identity matrix

#           # The autograd call below does several things at once. First, it differentiates fn(x,t) w.r.t. x.
#           # dim fn(x,t) = [batch_size, data_dim]. dim x = [batch_size, data_dim]. dim jacobian = [batch_size, data_dim, data_dim]
#           # Second, it multiplies the transposed jacobian by the vector v, i.e. the row of the identity matrix with 1 at position i.
#           # This multiplication extracts the i-th column of jacobian.T. dim jacobian.T*v = [batch_size, data_dim]
#           # Lastly, it takes the i-th element of the extracted column. The result is a vector of length batch_size, each entry being the partial
#           # derivative of fn_i w.r.t. x_i.
#           gradients = torch.autograd.grad(outputs=fn(x,t), inputs=x,
#                               grad_outputs=v.repeat(x.shape[0],1),
#                               create_graph=True, retain_graph=True, only_inputs=True)[0][:,i] 
#           grads.append(gradients)
#       grads = torch.stack(grads,dim=1)
#       return torch.sum(grads, dim=1)
#   else: # code by Song, we do not use it for now
#     def div_fn(x, t, eps):
#       with torch.enable_grad():
#         x.requires_grad_(True)
#         fn_eps = torch.sum(fn(x, t) * eps)
#         grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]
#       x.requires_grad_(False)
#       return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))

#   return div_fn


# def get_likelihood_fn(sde, inverse_scaler=None, exact=False, hutchinson_type='Rademacher', 
#                       rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5):
#   """Create a function to compute the unbiased log-likelihood estimate of a given data point.

#   Args:
#     sde: A `sde_lib.SDE` object that represents the forward SDE.
#     inverse_scaler: The inverse data normalizer.
#     hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator.
#     rtol: A `float` number. The relative tolerance level of the black-box ODE solver.
#     atol: A `float` number. The absolute tolerance level of the black-box ODE solver.
#     method: A `str`. The algorithm for the black-box ODE solver.
#       See documentation for `scipy.integrate.solve_ivp`.
#     eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability.

#   Returns:
#     A function that takes a batch of data points and returns the log-likelihoods in bits/dim,
#       the latent code, and the number of function evaluations cost by computation.
#   """

#   def ode_drift_fn_song(model, x, t):
#     """The drift function of the reverse-time ODE."""
#     # score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True)
#     # Probability flow ODE is a special case of Reverse SDE
#     # drift, diffusion = sde.get_reverse_sde_coefficients[0]
#     return sde.get_reverse_sde_coefficients[0]

#   def div_fn(model, x, t, noise):
#     return get_div_fn(lambda xx, tt: ode_drift_fn(model, xx, tt), exact)(x, t)
  
def exact_div_fn(sde, x, t):
  """Compute the exact divergence of the probability flow ODE drift w.r.t. `x`.

  The drift is `sde.get_reverse_sde_coefficients(x, t, probability_flow=True)[0]`.
  Its divergence is the trace of the per-sample Jacobian, i.e. the sum over all
  data dimensions i of d drift_i / d x_i. It is obtained with one backward pass
  per data dimension through a single evaluation of the drift.

  Args:
    sde: Object exposing `get_reverse_sde_coefficients(x, t, probability_flow=True)`;
      element 0 of its return value is used as the drift and must have the same
      number of elements per sample as `x`.
    x: Tensor whose first axis is the batch; all remaining axes are treated as
      data dimensions. NOTE: `requires_grad` is switched on in place on this
      tensor and is not restored afterwards.
    t: Time argument, forwarded unchanged to the SDE.

  Returns:
    Tensor of shape `[batch_size]` holding the divergence for every sample.
    It is built with `create_graph=True`, so it stays attached to the autograd
    graph; detach it if no higher-order gradients are needed.

  Note:
    The per-sample result is only correct if the drift of one sample does not
    depend on the other samples of the batch (the batched backward pass sums
    contributions over the batch axis).
  """
  with torch.enable_grad():
    x.requires_grad_(True)

    # Evaluate the drift only once: it does not change between iterations, and
    # every backward pass below reuses this graph (hence retain_graph=True).
    drift = sde.get_reverse_sde_coefficients(x, t, probability_flow=True)[0]

    batch_size = x.shape[0]
    num_dims = x.shape[1:].numel()  # product of all non-batch axes (1 for 1-D input)
    drift_flat = drift.reshape(batch_size, num_dims)

    partials = []
    for i in range(num_dims):
      # Back-propagating the batch-sum of drift component i yields, for every
      # sample, the gradient of drift_i w.r.t. that sample's input, i.e. the
      # i-th row of the per-sample Jacobian.
      jac_row = torch.autograd.grad(outputs=drift_flat[:, i].sum(), inputs=x,
                                    create_graph=True, retain_graph=True)[0]
      # Keep only the diagonal entry d drift_i / d x_i for every sample.
      partials.append(jac_row.reshape(batch_size, num_dims)[:, i])

  if not partials:
    # No data dimensions: the trace of an empty Jacobian is zero.
    return x.new_zeros(batch_size)

  # Sum of the Jacobian diagonal = divergence, one value per sample.
  return torch.stack(partials, dim=1).sum(dim=1)

def calculate_likelihood(denoiser, sde, data, rtol=1e-5, atol=1e-5, method='RK45'):
  """Compute per-sample log-likelihoods by integrating the probability flow ODE.

  The data and an accumulator for the change in log-density are integrated
  jointly from `EPS_SDE` to `sde.T`. The accumulator is driven by the exact
  divergence of the ODE drift (see `exact_div_fn`), so no stochastic trace
  estimator is involved. The result is `sde.prior_logp(z) + delta_logp`, where
  `z` is the state reached at the end of the integration.

  Args:
    denoiser: Network handed to `get_score_fn(denoiser, if_training=False)` to
      build the score function.
    sde: SDE object providing `init_score_fn`, `get_reverse_sde_coefficients`,
      `T` and `prior_logp`. Side effect: its score function is (re)initialised
      via `sde.init_score_fn`.
    data: A PyTorch tensor whose first axis is the batch.
    rtol: Relative tolerance passed to `scipy.integrate.solve_ivp`.
    atol: Absolute tolerance passed to `scipy.integrate.solve_ivp`.
    method: Integration method passed to `scipy.integrate.solve_ivp`.

  Returns:
    `sde.prior_logp(z) + delta_logp`, where `delta_logp` is a float32 tensor of
    length `batch_size` on `data.device`.

  Raises:
    RuntimeError: If the ODE solver reports that the integration failed; the
      last state would then not correspond to `sde.T` and the likelihood would
      be meaningless.
  """
  score_fn = get_score_fn(denoiser, if_training=False)

  sde.init_score_fn(score_fn)

  with torch.no_grad():
    shape = data.shape
    batch_size = shape[0]

    def as_tensor(flat, target_shape):
      # Bring a flat numpy slice of the solver state back to a float32 tensor on the data's device.
      return from_flattened_numpy(flat, target_shape).to(data.device).type(torch.float32)

    # The solver state is the flattened data followed by one log-density accumulator per sample.
    init = np.concatenate([to_flattened_numpy(data), np.zeros((batch_size,))], axis=0)
    num_data_entries = init.shape[0] - batch_size  # index where the accumulators start

    def ode_func(t, x):
      sample = as_tensor(x[:num_data_entries], shape)
      vec_t = torch.ones(sample.shape[0], device=sample.device) * t

      # Element 0 of the reverse-SDE coefficients with probability_flow=True is used as the ODE drift.
      drift = to_flattened_numpy(sde.get_reverse_sde_coefficients(sample, vec_t, probability_flow=True)[0])
      # Exact divergence of that drift: one value per sample, drives the log-density accumulator.
      logp_grad = to_flattened_numpy(exact_div_fn(sde, sample, vec_t))
      # Same layout as the solver state: data derivative first, log-density derivative last.
      return np.concatenate([drift, logp_grad], axis=0)

    solution = integrate.solve_ivp(ode_func, (EPS_SDE, sde.T), init, rtol=rtol, atol=atol, method=method)
    if not solution.success:
      # On failure the last column of solution.y is not the state at sde.T.
      raise RuntimeError(f"Probability flow ODE integration failed: {solution.message}")

    final_state = solution.y[:, -1]
    z = as_tensor(final_state[:num_data_entries], shape)
    delta_logp = as_tensor(final_state[num_data_entries:], (batch_size,))
    prior_logp = sde.prior_logp(z)
    sample_loglik = (prior_logp + delta_logp)

    return sample_loglik
