import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from diffusers import UNet2DModel, DDPMScheduler
from torchdiffeq import odeint
from scipy.interpolate import interp1d
from scipy.integrate import quad
import math

class DDPMLogProb(nn.Module):
    """Estimate log p_t(x) under a DDPM via the probability-flow ODE.

    The discrete schedule (``scheduler.betas`` / ``scheduler.alphas_cumprod``) is
    mapped onto continuous time t in [0, 1] by placing discrete step i at
    t = i / (T - 1) and linearly interpolating in between. ``forward`` integrates
    the sample and the divergence of the ODE velocity from ``t`` to 1 with forward
    Euler, then adds the standard-normal log density at the endpoint
    (instantaneous change-of-variables formula).

    Args:
        model: noise-prediction network, called as ``model(x, timesteps).sample``.
        scheduler: provides ``betas``, ``alphas_cumprod`` and
            ``config.num_train_timesteps``.
        num_hutchinson_samples: number of Gaussian probe vectors per trace estimate.
        ode_step_size: upper bound on the Euler step size in continuous time.
    """

    def __init__(self,
                 model: UNet2DModel,
                 scheduler: DDPMScheduler,
                 num_hutchinson_samples: int = 1,
                 ode_step_size: float = 1e-1,
                 ):
        super().__init__()
        self.model = model
        self.scheduler = scheduler
        self.num_hutchinson_samples = num_hutchinson_samples
        self.ode_step_size = ode_step_size

        betas_np = self.scheduler.betas.cpu().numpy()
        # Multiply by T to turn the per-step beta into a rate per unit of continuous time.
        betas_rescaled = betas_np * scheduler.config.num_train_timesteps
        alpha_bars_np = self.scheduler.alphas_cumprod.cpu().numpy()
        # DDPM marginals are x_t = sqrt(alpha_bar) x_0 + sqrt(1 - alpha_bar) eps, so the
        # score is -eps / sqrt(1 - alpha_bar): the noise scale is sqrt(1 - alpha_bar).
        sigmas_np = np.sqrt(1 - alpha_bars_np)
        # Discrete step i sits at t = i / (T - 1); velocity() uses the same mapping.
        _ts = np.linspace(0, 1, len(betas_np))
        self.beta_fn = interp1d(_ts, betas_rescaled, kind="linear", fill_value="extrapolate")
        self.alpha_bar_fn = interp1d(_ts, alpha_bars_np, kind="linear", fill_value="extrapolate")
        self.sigma_fn = interp1d(_ts, sigmas_np, kind="linear", fill_value="extrapolate")

    def velocity(self, x, t):
        """Probability-flow ODE velocity of the DDPM (VP) process.

            v(x, t) = 1/2 * beta(t) * [ eps(x, t) / sigma(t) - x ]

        eps(x, t) is linearly interpolated between the model's predictions at the
        two discrete timesteps bracketing t. Continuous t in [0, 1] maps to the
        discrete step t * (T - 1), the same mapping used to build beta_fn and sigma_fn.

        Args:
            x: batch of samples; the model is called on ``x.float()``.
            t: scalar time (Python float or one-element tensor), clamped to [0, 1].

        Returns:
            Velocity tensor with the same shape as x.
        """
        t_val = min(max(float(t), 0.0), 1.0)

        # Scheduler values at t, shaped to broadcast against x.
        bshape = [1] * x.dim()
        beta_t = torch.tensor(float(self.beta_fn(t_val)), dtype=x.dtype, device=x.device).reshape(bshape)
        sigma_t = torch.tensor(float(self.sigma_fn(t_val)), dtype=x.dtype, device=x.device).reshape(bshape)

        # Bracketing discrete timesteps in [0, T-1] and the linear weight between them.
        T = self.scheduler.config.num_train_timesteps
        t_scaled = t_val * (T - 1)
        idx0 = min(int(math.floor(t_scaled)), T - 1)
        idx1 = min(idx0 + 1, T - 1)
        w = t_scaled - idx0

        # Timesteps must be integers (torch.long) for the diffusers UNet time embedding.
        x_in = x.float()
        t0 = torch.tensor([idx0], dtype=torch.long, device=x.device)
        eps = self.model(x_in, t0).sample
        if w > 0.0:
            # Only query the second timestep when it actually contributes.
            t1 = torch.tensor([idx1], dtype=torch.long, device=x.device)
            eps1 = self.model(x_in, t1).sample
            eps = (1 - w) * eps + w * eps1

        # score = -eps / sigma, so -1/2 beta (x + score) = 1/2 beta (eps / sigma - x).
        return 0.5 * beta_t * (eps / sigma_t - x)

    def hutchinson_trace_estimate(self, x, t):
        """Hutchinson estimate of div_x v(x, t), vectorised over batch and probes.

        tr(J) ~= mean_s z_s^T J z_s with z_s ~ N(0, I). The S probes are stacked with
        the batch, so a single forward and backward pass handles all S*B rows.

        Returns:
            Tensor of shape [B]. ``create_graph=True`` keeps the result differentiable.
        """
        S = self.num_hutchinson_samples
        B, *dims = x.shape
        # Detached copy of x repeated S times, as a fresh leaf we can differentiate w.r.t.
        x_expand = x.detach().unsqueeze(0).expand(S, *x.shape)      # [S, B, ...]
        x_flat = x_expand.reshape(-1, *dims).requires_grad_(True)    # [S*B, ...]

        # Gaussian probes, matching x's dtype/device.
        z = torch.randn((S, *x.shape), device=x.device, dtype=x.dtype)
        z_flat = z.reshape(-1, *dims)                                # [S*B, ...]

        v_flat = self.velocity(x_flat, t)                            # [S*B, ...]
        inner = (v_flat * z_flat).reshape(v_flat.size(0), -1).sum(1)  # [S*B]

        # Row (s, b) of grads_flat is grad_x (v . z_s) = J^T z_s for that sample.
        grads_flat = torch.autograd.grad(inner.sum(), x_flat, create_graph=True)[0]

        trace_per_sample = (grads_flat * z_flat).reshape(S, B, -1).sum(dim=2)  # [S, B]
        return trace_per_sample.mean(dim=0)                                    # [B]

    def exact_divergence(self, x, t):
        """Exact trace of d velocity / dx per batch element via a full Jacobian.

        Intended for small inputs and for checking the Hutchinson estimator: it
        builds a D x D Jacobian per sample, where D is the number of elements per sample.

        Args:
            x: tensor of shape (B, ...) with any per-sample shape.
            t: scalar time (Python float or one-element tensor).

        Returns:
            trace: tensor of shape (B,).
        """
        B = x.shape[0]
        if B == 0:
            return x.new_zeros(0)
        sample_shape = x.shape[1:]

        def single_velocity(xi_flat):
            # Evaluate on a batch of one, then flatten so the Jacobian is (D, D).
            return self.velocity(xi_flat.reshape(1, *sample_shape), t).reshape(-1)

        traces = []
        for xi in x.detach().reshape(B, -1):
            J = torch.autograd.functional.jacobian(single_velocity, xi)  # (D, D)
            traces.append(torch.trace(J))
        return torch.stack(traces)  # (B,)

    def ode_func(self, t, state):
        """Right-hand side for torchdiffeq.odeint over the state (x, delta_logp).

        Returns (dx/dt, d delta_logp/dt) = (v(x, t), div v(x, t)).
        """
        x, _logp = state
        v = self.velocity(x, t)
        div = self.hutchinson_trace_estimate(x, t)
        return v, div

    def manual_forward_euler(self, x0, t_init=0.0, t_final=1.0):
        """Integrate the ODE and the accumulated divergence with forward Euler.

        [t_init, t_final] is split into the fewest equal steps whose size does not
        exceed ``self.ode_step_size``. A zero-length span takes no steps, and
        t_final < t_init integrates backward.

        Args:
            x0: starting samples at time ``t_init`` (first dim is batch).
            t_init, t_final: Python floats or one-element tensors.

        Returns:
            (x_seq, delta_logp_seq): states stacked to [num_steps + 1, *x0.shape] and
            the running integral of div v stacked to [num_steps + 1, B]. Every entry
            after the first is detached.
        """
        t_init = float(t_init)
        t_final = float(t_final)
        span = t_final - t_init
        # Round the step count up so the effective step never exceeds ode_step_size.
        # The tolerance stops float round-off in the ratio from adding an extra step.
        num_steps = max(0, math.ceil(abs(span) / self.ode_step_size - 1e-9))
        dt = span / num_steps if num_steps > 0 else 0.0

        x = x0
        logp = torch.zeros(x.shape[0]).to(x)
        x_seq = [x]
        delta_logp_seq = [logp]
        for i in range(num_steps):
            t = torch.tensor([t_init + i * dt]).to(x)
            v = self.velocity(x, t)
            div = self.hutchinson_trace_estimate(x, t)
            # Detach every step; keeping the graph would grow memory with num_steps.
            x = (x + dt * v).detach()
            logp = (logp + dt * div).detach()
            x_seq.append(x)
            delta_logp_seq.append(logp)
            del v, div
        return torch.stack(x_seq), torch.stack(delta_logp_seq)

    def forward(self, x0, t):
        """Return log p_t(x0) of shape [B].

        Integrates x0 from time ``t`` to 1 and applies the change of variables
        log p_t(x0) = log N(x_1; 0, I) + integral_t^1 div v ds.
        """
        x_seq, delta_logp_seq = self.manual_forward_euler(x0, t_init=t, t_final=1.0)
        xT, delta_logp = x_seq[-1], delta_logp_seq[-1]

        flat_dim = xT[0].numel()
        # Standard-normal log density, summed over all dims except the batch dim.
        logp_T = -0.5 * torch.sum(xT**2, dim=tuple(range(1, xT.dim()))) \
                 - 0.5 * flat_dim * np.log(2 * np.pi)

        return logp_T + delta_logp
    
# This is a dummy model that returns zero noise for all inputs.
class ZeroModel(nn.Module):
    """Stand-in noise predictor whose eps prediction is identically zero.

    Mimics the ``model(x, t).sample`` call convention that DDPMLogProb uses.
    """

    def forward(self, x, t):
        """Return an object whose ``sample`` is ``torch.zeros_like(x)``; ``t`` is ignored."""
        from types import SimpleNamespace

        return SimpleNamespace(sample=torch.zeros_like(x))
    
# Test case for DDPMLogProb
# DDPMLogProb(model=ZeroModel(), scheduler=...) should agree with ZeroModelLogProb
class ZeroModelLogProb(nn.Module):
    """Closed-form reference for DDPMLogProb(model=ZeroModel(), scheduler=...).

    With eps == 0 the probability-flow velocity reduces to v(x, t) = -1/2 beta(t) x,
    so x_1 = x_t * exp(-1/2 integral_t^1 beta) and div v = -1/2 D beta(t). Therefore
    log p_t(x) = log N(x_1; 0, I) - 1/2 D integral_t^1 beta(s) ds.
    """

    def __init__(self, scheduler: DDPMScheduler):
        super().__init__()
        self.model = ZeroModel()
        self.scheduler = scheduler
        # Same continuous-time beta(t) as DDPMLogProb: per-step betas scaled by T,
        # discrete step i placed at t = i / (T - 1), linear in between.
        self.betas = scheduler.betas.cpu().numpy()
        self.betas_rescaled = self.betas * scheduler.config.num_train_timesteps
        self._ts = np.linspace(0, 1, len(self.betas_rescaled))
        self.beta_t_fn = interp1d(self._ts, self.betas_rescaled, kind='linear', fill_value="extrapolate")

    def beta_integral(self, t_init=0.0):
        """Return integral_{t_init}^{1} beta(s) ds as a Python float.

        beta_t_fn is piecewise linear, including its linear extrapolation beyond
        [0, 1]. The trapezoid rule over its knots is therefore exact, and it avoids
        adaptive quadrature struggling with up to T-1 kinks.
        """
        lo, hi, sign = float(t_init), 1.0, 1.0
        if lo == hi:
            return 0.0
        if lo > hi:
            lo, hi, sign = hi, lo, -1.0
        inner = self._ts[(self._ts > lo) & (self._ts < hi)]
        knots = np.concatenate(([lo], inner, [hi]))
        vals = self.beta_t_fn(knots)
        return sign * float(np.sum(0.5 * (vals[1:] + vals[:-1]) * np.diff(knots)))

    def forward(self, x0, t):
        """Return log p_t(x0) of shape [B] in closed form."""
        # reshape (not view) so non-contiguous inputs work too.
        flat = x0.reshape(x0.size(0), -1)
        D = flat.size(1)
        integral_beta = self.beta_integral(t)
        # Solution of dx/dt = -1/2 beta(t) x from t to 1.
        exp_term = np.exp(-0.5 * integral_beta)
        x1 = x0 * exp_term
        logp_1 = -0.5 * torch.sum(x1**2, dim=tuple(range(1, x1.dim()))) \
                 - 0.5 * D * np.log(2 * np.pi)
        # Adding integral_t^1 div v = -1/2 D integral_t^1 beta.
        logp_t = logp_1 - 0.5 * D * integral_beta
        return logp_t