# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modifications Copyright 2025, XXX


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import functools 
from scipy.integrate import solve_ivp
from torchdiffeq import odeint


#EDM (Karras et al., NeurIPS 2022) assumes s(t)=1 and sigma(t)=t

#Assuming EDM, the diffusion coefficient reduces to sqrt(2*sigma(t)) = sqrt(2t), since:
#diffusion_coef(t) = s(t)*sqrt(2*sigma'(t)*sigma(t))   #Karras et al., NeurIPS 2022, Eq. (34)

#Assuming EDM, the standard deviation of the perturbation kernel reduces to sigma(t) = t, since:
#var_kernel(t) = s^2(t)*sigma^2(t)   #Karras et al., NeurIPS 2022, Eq. (11)


def loglik(x, score_model, sigma_min=1e-5, sigma_max=1.0,
           num_hutchinson_samples=16, step_size=None):
    """Estimate log p(x) by integrating the probability flow ODE.

    The forward process is taken to have zero drift and noise level
    sigma(t) = t, so the diffusion coefficient is g(t) = sqrt(2 t). The joint
    state ``(x, logp)`` is integrated from ``sigma_min`` to ``sigma_max`` with
    torchdiffeq's ``implicit_adams`` method. The accumulated log-density
    change is then added to the log-density of the terminal sample under an
    isotropic Gaussian N(0, sigma_max^2 I).

    The divergence of the score is estimated with the Skilling-Hutchinson
    trace estimator using fresh Gaussian probes at every ODE function
    evaluation, so the result is stochastic. The score used for the state
    derivative is evaluated under ``torch.no_grad()`` and the divergence uses
    ``create_graph=False``, so the returned value carries no gradient with
    respect to the model parameters.

    Args:
      x: Input tensor of shape (batch, D); all reductions are over dim 1.
        All internally created tensors are placed on ``x.device``.
      score_model: Callable ``score_model(sample, t)`` where ``t`` is a 1-D
        tensor of length ``batch``; expected to return a tensor shaped like
        ``sample``.
      sigma_min: Lower integration limit (smallest time, kept above zero for
        numerical stability).
      sigma_max: Upper integration limit; also the standard deviation of the
        Gaussian prior evaluated at the terminal sample.
      num_hutchinson_samples: Number of Gaussian probes averaged per
        divergence estimate. Must be at least 1.
      step_size: Optional step size forwarded to the solver through
        ``options={'step_size': step_size}``. When ``None`` no options are
        passed, matching the previous behaviour.
        NOTE(review): torchdiffeq documents ``implicit_adams`` as a
        fixed-step solver that steps between the values of ``t`` when no
        step size is given, which here would be a single step from
        ``sigma_min`` to ``sigma_max``. Confirm against the installed
        version and pass ``step_size`` when accuracy matters.

    Returns:
      Tensor of shape (batch,) with the estimated log-likelihood in nats.

    Raises:
      ValueError: If ``num_hutchinson_samples`` is smaller than 1.
    """
    if num_hutchinson_samples < 1:
        raise ValueError("num_hutchinson_samples must be at least 1")

    # Follow the input's device instead of hard-coding 'cpu', so GPU inputs
    # are not mixed with CPU time/log-density tensors.
    device = x.device

    def diffusion_coeff(t):
        """Return g(t) = sqrt(2 t), the diffusion coefficient for sigma(t) = t."""
        t = t if torch.is_tensor(t) else torch.tensor(t, device=device)
        return torch.sqrt(2.0 * t)

    def prior_likelihood(z):
        """Return the log-density of each row of z under N(0, sigma_max^2 I)."""
        dim = z.shape[1]
        log_norm = -dim / 2. * np.log(2 * np.pi * sigma_max ** 2)
        return log_norm - torch.sum(z ** 2, dim=1) / (2 * sigma_max ** 2)

    def divergence_eval(sample, time):
        """Estimate the divergence of the score with Skilling-Hutchinson."""
        divs = []
        with torch.enable_grad():
            # Differentiate w.r.t. a detached copy: calling requires_grad_ on
            # the ODE state itself would flip the flag on the caller's `x`
            # (the state at the first evaluation) and chain an autograd graph
            # through every subsequent state.
            probe_point = sample.detach().requires_grad_(True)
            for _ in range(num_hutchinson_samples):
                noise = torch.randn_like(probe_point)
                score_e = torch.sum(score_model(probe_point, time) * noise)
                grad_score_e = torch.autograd.grad(
                    score_e, probe_point, create_graph=False)[0]
                divs.append(torch.sum(grad_score_e * noise, dim=1))
        return torch.stack(divs, dim=0).mean(dim=0)

    class ODEFunc(torch.nn.Module):
        """Right-hand side of the joint (sample, log-density) ODE."""

        def forward(self, t, states):
            sample, _ = states
            # The solver supplies a scalar time; broadcast it over the batch.
            t_tensor = t.expand(sample.shape[0]).to(device)
            with torch.no_grad():
                score = score_model(sample, t_tensor)
            g = diffusion_coeff(t)
            dx = -0.5 * g ** 2 * score  # Assumes zero drift term
            div = divergence_eval(sample, t_tensor)
            dlogp = -0.5 * g ** 2 * div
            return (dx, dlogp)

    logp0 = torch.zeros(x.shape[0], device=device)
    t_span = torch.tensor([sigma_min, sigma_max], device=device)

    solver_kwargs = dict(rtol=1e-6, atol=1e-6, method='implicit_adams')
    if step_size is not None:
        solver_kwargs['options'] = {'step_size': step_size}

    xT, delta_logp = odeint(
        ODEFunc(),
        (x, logp0),  # Initial state: (x, zeros_like(batch))
        t_span,
        **solver_kwargs
    )
    xT = xT[-1]
    delta_logp = delta_logp[-1]  # Accumulated log-prob change

    return prior_likelihood(xT) + delta_logp


#Deprecated: maybe incorrect
def loglik_data(data, ema, temp, s, sigma_min, sigma_max, n_repeats=11):
    """Return the batch mean of the median log-sigmoid preference score.

    ``data[0]`` is treated as the winners and ``data[1]`` as the losers; each
    slice is transposed so that, per the original comments, rows are batch
    elements of shape (batch_size, D). For every repeat, ``loglik`` is
    evaluated on winners and then on losers, and
    ``logsigmoid((logp_winners - logp_losers) / s)`` is recorded. The
    element-wise median across repeats is then averaged over the batch.

    Args:
      data: Tensor indexed as ``data[0, :, :]`` and ``data[1, :, :]``.
      ema: Callable invoked as ``ema(x, t, joint=0, temp=temp)`` and used as
        the score model.
      temp: Forwarded to ``ema`` as the ``temp`` keyword.
      s: Divisor applied to the log-likelihood difference.
      sigma_min: Forwarded to ``loglik``.
      sigma_max: Forwarded to ``loglik``.
      n_repeats: Number of repeated evaluations to take the median over.

    Returns:
      Scalar tensor: mean over the batch of the per-element median.
    """
    def score_fn(x, t):
        return ema(x, t, joint=0, temp=temp)

    # Work on copies so the caller's tensor is never touched.
    preferred = data.clone()[0, :, :].transpose(0, 1)  # (batch_size, D)
    rejected = data.clone()[1, :, :].transpose(0, 1)   # (batch_size, D)

    def single_repeat():
        # Winners first, then losers: keeps the random-number stream order.
        preferred_logp = loglik(
            preferred, score_fn, sigma_min=sigma_min, sigma_max=sigma_max)
        rejected_logp = loglik(
            rejected, score_fn, sigma_min=sigma_min, sigma_max=sigma_max)
        margin = (preferred_logp - rejected_logp) / s
        return torch.nn.functional.logsigmoid(margin)

    # Shape: (n_repeats, batch_size)
    per_repeat = torch.stack(
        [single_repeat() for _ in range(n_repeats)], dim=0)
    return per_repeat.median(dim=0).values.mean()