import torch
import torch.nn as nn
import numpy as np
import math
from scipy.interpolate import interp1d
from tqdm import tqdm

# Assumes ScoreFromEps and CFDM are available from utils/cfdm_ddpm_conversion.py and closed_form_diffusion/cfdm.py
from diffusers import UNet2DModel, DDPMScheduler
from utils.cfdm_ddpm_conversion import ScoreFromEps
from closed_form_diffusion.cfdm import CFDM, CFDM_NN


class DDPMSampleSensitivity(nn.Module):
	"""Probability-flow sampling, log-density tracking and sensitivity paths for a DDPM model.

	Time convention: t lies in [0, 1]. Paths start at ``t_init`` (default 1.0,
	where ``z1`` is scored under a standard normal) and are integrated with
	explicit Euler steps of nominal size ``ode_step_size`` towards ``t_final``
	(default 0.0). Continuous-time schedules come from linearly interpolating
	the scheduler's discrete ``betas`` / ``alphas_cumprod`` on the grid
	``np.linspace(0, 1, len(betas))``, i.e. t = i / (T - 1).
	"""

	def __init__(self,
				 model: UNet2DModel,
				 scheduler: DDPMScheduler,
				 num_hutchinson_samples: int = 1,
				 ode_step_size: float = 1e-1,
				 min_clamp: float = None,
				 max_clamp: float = None
				 ):
		"""
		Args:
			model: noise-prediction network; called as ``model(x, timesteps).sample``.
			scheduler: DDPM scheduler providing ``betas``, ``alphas_cumprod`` and
				``config.num_train_timesteps``.
			num_hutchinson_samples: number of Gaussian probes per trace estimate (>= 1).
			ode_step_size: nominal Euler step size in t (> 0).
			min_clamp / max_clamp: optional bounds applied to the density-ratio
				weight in ``compute_d_eta_score``; either may be None (no bound).
		"""
		super().__init__()
		if ode_step_size <= 0:
			raise ValueError(f"ode_step_size must be positive, got {ode_step_size}")
		if num_hutchinson_samples < 1:
			raise ValueError(f"num_hutchinson_samples must be >= 1, got {num_hutchinson_samples}")
		self.model = model
		self.scheduler = scheduler
		self.num_hutchinson_samples = num_hutchinson_samples
		self.ode_step_size = ode_step_size
		self.min_clamp = min_clamp
		self.max_clamp = max_clamp

		betas_np = self.scheduler.betas.cpu().numpy()
		# Discrete betas scaled by T give the continuous-time beta(t).
		betas_rescaled = betas_np * scheduler.config.num_train_timesteps
		alpha_bars_np = self.scheduler.alphas_cumprod.cpu().numpy()
		sigmas_np = np.sqrt((1 - alpha_bars_np))
		# Discrete step i sits at t = i / (T - 1); velocity() uses the same mapping.
		_ts = np.linspace(0, 1, len(betas_np))
		self.beta_fn = interp1d(_ts, betas_rescaled, kind="linear", fill_value="extrapolate")
		self.alpha_bar_fn = interp1d(_ts, alpha_bars_np, kind="linear", fill_value="extrapolate")
		self.sigma_fn = interp1d(_ts, sigmas_np, kind="linear", fill_value="extrapolate")
		# Score function from DDPM model
		self.score_from_eps = ScoreFromEps(self.model, self.scheduler)

	def _step_schedule(self, t_init, t_final):
		"""Return ``(num_steps, dt)`` for Euler integration from t_init to t_final.

		Works in either direction (dt carries the sign of t_final - t_init). A
		zero span yields ``(0, 0.0)``; any non-zero span takes at least one step.
		"""
		span = t_final - t_init
		if span == 0:
			return 0, 0.0
		# The small tolerance stops float artifacts such as 0.3 / 0.1 == 2.9999999999999996
		# from silently dropping a step; genuine fractional ratios still truncate.
		num_steps = max(1, int(abs(span) / self.ode_step_size + 1e-9))
		return num_steps, span / num_steps

	def build_cfdm(self, X_1, nn_estimator=False, K=100, L=100):
		"""Build the closed-form diffusion model for ``X_1`` and store it in ``self.cfdm_X1``.

		``K`` and ``L`` are forwarded only to ``CFDM_NN`` (``nn_estimator=True``);
		their defaults match the previously hard-coded values.
		"""
		if nn_estimator:
			self.cfdm_X1 = CFDM_NN(X_1, self.scheduler, K=K, L=L)
		else:
			self.cfdm_X1 = CFDM(X_1, self.scheduler)

	def velocity(self, x, t):
		"""Probability-flow ODE velocity dx/dt at (x, t).

		Returns ``0.5 * beta(t) * (eps / sigma(t) - x)``, i.e. ``-0.5 * beta * (x + score)``
		with the usual eps-parameterisation ``score = -eps / sigma``.

		Args:
			x: batch passed to the model (cast to float32 for the model call).
			t: single-element tensor (or Python number); clamped to [0, 1].
		"""
		t_float = min(max(float(t), 0.0), 1.0)
		beta_t = torch.tensor(self.beta_fn(t_float)).to(x).view(*[1] * x.dim())
		sigma_t = torch.tensor(self.sigma_fn(t_float)).to(x).view(*[1] * x.dim())
		T = self.scheduler.config.num_train_timesteps
		# Map t onto the discrete training grid 0..T-1, the same grid used to build
		# beta_fn / sigma_fn, so eps and sigma_t refer to the same noise level
		# (scaling by T instead would query timestep T at t=1, outside training).
		t_scaled = t_float * (T - 1)
		idx0 = min(math.floor(t_scaled), T - 1)
		idx1 = min(idx0 + 1, T - 1)
		w = t_scaled - idx0  # fractional interpolation weight in [0, 1)
		t0 = torch.tensor([idx0], dtype=torch.long, device=x.device)
		eps = self.model(x.float(), t0).sample
		if w > 0.0:
			# Only pay for the second network evaluation when it contributes.
			t1 = torch.tensor([idx1], dtype=torch.long, device=x.device)
			eps1 = self.model(x.float(), t1).sample
			eps = (1 - w) * eps + w * eps1
		return 0.5 * beta_t * (eps / sigma_t - x)

	def jvp_velocity(self, z, t, u, *, exact=False):
		"""Contract the velocity Jacobian J = dv/dz at (z, t) with the direction ``u``.

		By default (``exact=False``) this returns the vector-Jacobian product
		``J^T u`` from a single reverse-mode pass. It coincides with the true
		Jacobian-vector product ``J u`` only when J is symmetric, as it is for
		an exact score field, whose Jacobian is a Hessian. With
		``exact=True`` the true ``J u`` is computed via the double-backward
		trick, at roughly twice the cost.

		Args:
			z: (B, ...) point at which the Jacobian is evaluated.
			t: single-element tensor, time.
			u: (B, ...) direction, same shape as the velocity output.
		Returns:
			(B, ...) tensor: ``J^T u`` (default) or ``J u`` (``exact=True``).
		"""
		z_in = z.clone().detach().requires_grad_(True)
		v_out = self.velocity(z_in, t)
		if not exact:
			return torch.autograd.grad(
				outputs=v_out,
				inputs=z_in,
				grad_outputs=u,
				retain_graph=False,
				create_graph=False
			)[0]
		# Double-backward trick: g -> J^T g is linear in g, so differentiating
		# <J^T g, u> with respect to the dummy g yields J u.
		dummy = torch.zeros_like(v_out, requires_grad=True)
		vjp = torch.autograd.grad(v_out, z_in, grad_outputs=dummy, create_graph=True)[0]
		return torch.autograd.grad(vjp, dummy, grad_outputs=u)[0]

	def compute_d_eta_score(self, z, t, logp0, score0):
		"""Derivative of the score w.r.t. eta at (z, t).

		Computes ``exp(logp1 - logp0) * (score1 - score0)``, where logp1/score1
		come from the closed-form model built by ``build_cfdm``. This matches
		d/deta of grad log((1 - eta) p0 + eta p1) at eta = 0. The density-ratio
		weight is clamped only if ``min_clamp`` and/or ``max_clamp`` is set.
		"""
		# logp^1_s(z_s), s^1_s(z_s) from CFDM with X_1
		logp1 = self.cfdm_X1.log_prob(z, t)
		score1 = self.cfdm_X1.forward(z, t)
		weight = torch.exp(logp1 - logp0)
		# torch.clamp rejects min=None together with max=None (the constructor defaults).
		if self.min_clamp is not None or self.max_clamp is not None:
			weight = weight.clamp(min=self.min_clamp, max=self.max_clamp)
		# Broadcast the per-sample weight over all non-batch dims of the score.
		weight = weight.reshape(score0.shape[0], *[1] * (score0.dim() - 1))
		return weight * (score1 - score0)

	def hutchinson_trace_estimate(self, x, t, create_graph=True):
		"""Hutchinson estimate of the divergence (Jacobian trace) of the velocity at (x, t).

		All ``num_hutchinson_samples`` Gaussian probes are evaluated in one
		batched forward/backward pass.

		Args:
			x: (B, ...) evaluation points.
			t: single-element tensor, time.
			create_graph: keep the graph of the estimate so it can itself be
				differentiated; pass False when the result is only used as a value.
		Returns:
			(B,) tensor of trace estimates averaged over the probes.
		"""
		B, *dims = x.shape
		n_probes = self.num_hutchinson_samples
		# Replicate the batch once per probe: shape (n_probes * B, ...).
		x_rep = x.detach().unsqueeze(0).expand(n_probes, *x.shape).reshape(-1, *dims).requires_grad_(True)
		probes = torch.randn_like(x_rep)
		v_rep = self.velocity(x_rep, t)
		# grad of sum(v * probe) w.r.t. x is J^T probe; probe^T J^T probe estimates tr(J).
		grads = torch.autograd.grad((v_rep * probes).sum(), x_rep, create_graph=create_graph)[0]
		trace_per_probe = (grads * probes).view(n_probes, B, -1).sum(dim=2)
		return trace_per_probe.mean(dim=0)

	def precompute_sample_path(self, z1, t_init=1.0, t_final=0.0):
		"""
		Precompute the DDPM sample path and logpt sequence for a given z1.

		``z1`` is scored under a standard normal, and the log-density is carried
		along the path with the instantaneous change of variables
		d log p / dt = -div v. It uses explicit Euler steps, with the
		divergence taken at the same state/time as the velocity.

		Returns:
			zt: (B, num_steps+1, C, H, W) sample path
			logpt: (B, num_steps+1) log-prob sequence
		"""
		num_steps, dt = self._step_schedule(t_init, t_final)
		z = z1
		zt = [z]
		# Standard-normal log-density of the starting point.
		flat_dim = z1[0].numel()
		logp_1 = -0.5 * torch.sum(z1**2, dim=tuple(range(1, z1.dim()))) - 0.5 * flat_dim * np.log(2 * np.pi)
		logp_1 = logp_1.unsqueeze(1)
		# Running sum of dt * div; a scalar accumulator avoids re-concatenating history.
		delta_logp = torch.zeros(z1.shape[0], 1).to(z1)
		logpt = [logp_1 - delta_logp]
		for i in range(num_steps):
			t = torch.tensor([t_init + i * dt]).to(z.device)
			# Divergence at the current state (z_i, t_i); its graph is not needed.
			div = self.hutchinson_trace_estimate(z, t, create_graph=False).detach()
			with torch.no_grad():
				v = self.velocity(z, t)
			z = (z + dt * v).detach()
			zt.append(z)
			delta_logp = delta_logp + dt * div.unsqueeze(1)
			logpt.append(logp_1 - delta_logp)
			del v, div
		zt = torch.stack(zt, dim=1)
		logpt = torch.cat(logpt, dim=1)  # (B, num_steps+1)
		torch.cuda.empty_cache()
		return zt, logpt

	def precompute_sample_path_no_logpt(self, z1, t_init=1.0, t_final=0.0):
		"""
		Precompute the DDPM sample path for a given z1 without computing logpt.

		Integrates in either direction; ``t_final >= t_init`` advects towards
		noise (e.g. to recover z1 latents from data).

		Returns:
			zt: (B, num_steps+1, C, H, W) sample path
		"""
		if t_final >= t_init:
			print("Advecting backwards to obtain z1 latents.")
		num_steps, dt = self._step_schedule(t_init, t_final)
		z = z1
		zt = [z]
		for i in range(num_steps):
			t = torch.tensor([t_init + i * dt]).to(z.device)
			with torch.no_grad():
				v = self.velocity(z, t)
			z = (z + dt * v).detach()
			zt.append(z)
			del v
		zt = torch.stack(zt, dim=1)
		torch.cuda.empty_cache()
		return zt

	def sensitivity_given_sample_path(self, zt, logpt, X_1, t_init=1.0, t_final=0.0, nn_estimator=False, exact_jvp=False):
		"""
		Given precomputed zt and logpt, compute the sensitivity path et for a given X_1.

		Integrates de/dt = (dv/dz) e + d_eta v with explicit Euler steps, starting
		from e = 0, where d_eta v = -0.5 * beta(t) * d_eta score. ``zt``/``logpt``
		must come from ``precompute_sample_path`` with the same t_init, t_final
		and step size. With ``exact_jvp=False`` (default) the Jacobian term uses
		``J^T e`` (see ``jvp_velocity``); set True for the exact ``J e``.

		Returns:
			et: (B, num_steps+1, C, H, W) sensitivity path
		"""
		num_steps, dt = self._step_schedule(t_init, t_final)
		self.build_cfdm(X_1, nn_estimator=nn_estimator)
		try:
			e = torch.zeros_like(zt[:, 0])
			et = [e]
			for i in tqdm(range(num_steps)):
				t = torch.tensor([t_init + i * dt]).to(zt.device)
				z = zt[:, i]
				score0 = self.score_from_eps(z, t)
				logp0 = logpt[:, i]
				d_eta_score = self.compute_d_eta_score(z, t, logp0, score0).detach()
				beta_t = torch.tensor(self.beta_fn(t.cpu().item())).to(z)
				d_eta_v = -0.5 * beta_t * d_eta_score
				jvp = self.jvp_velocity(z, t, e, exact=exact_jvp)
				e = (e + dt * (d_eta_v + jvp)).detach()
				et.append(e)
				# Free per-step tensors before the next step allocates new ones.
				del score0, logp0, d_eta_score, beta_t, d_eta_v, jvp
			et = torch.stack(et, dim=1)
		finally:
			# Release the closed-form model even if the loop raises.
			self.cfdm_X1 = None
			torch.cuda.empty_cache()
		return et
