import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.interpolate import interp1d
import math
from diffusers import UNet2DModel, DDPMScheduler
from tqdm import tqdm

import sys
# NOTE: the project root must be on sys.path (e.g. via sys.path.append) for the `utils` imports below to resolve.
import utils.cfdm_ddpm_conversion as convert
from utils.cfdm_ddpm_conversion import ScoreFromEps

class DDPMSampleSensitivityMoG(nn.Module):
	"""Fixed-step Euler integration of an epsilon-model velocity field and of the
	first-order sensitivity of its sample paths towards a Mixture-of-Gaussians.

	Continuous time ``t`` lives in [0, 1]. The scheduler's discrete ``betas``
	(multiplied by ``num_train_timesteps``), ``alphas_cumprod`` and
	``sqrt(1 - alphas_cumprod)`` are linearly interpolated on a uniform grid over
	[0, 1] to obtain ``beta(t)``, ``alpha_bar(t)`` and ``sigma(t)``.

	Args:
		model: callable ``model(x, timestep)`` whose result exposes the predicted
			noise as ``.sample``.
		scheduler: object providing ``betas``, ``alphas_cumprod`` and
			``config.num_train_timesteps``.
		num_hutchinson_samples: number of probe vectors per sample used by
			``hutchinson_trace_estimate``.
		ode_step_size: target magnitude of the Euler time step.
		min_clamp: optional lower bound for the density ratio in
			``compute_d_eta_score`` (``None`` disables it).
		max_clamp: optional upper bound for the density ratio in
			``compute_d_eta_score`` (``None`` disables it).
	"""
	def __init__(self,
				 model: "UNet2DModel",
				 scheduler: "DDPMScheduler",
				 num_hutchinson_samples: int = 1,
				 ode_step_size: float = 1e-1,
				 min_clamp: float = None,
				 max_clamp: float = None
				 ):
		super().__init__()
		self.model = model
		self.scheduler = scheduler
		self.num_hutchinson_samples = num_hutchinson_samples
		self.ode_step_size = ode_step_size
		self.min_clamp = min_clamp
		self.max_clamp = max_clamp

		# Continuous-time schedules: linear interpolation of the discrete tables
		# on a uniform grid over [0, 1].
		betas_np = self.scheduler.betas.cpu().numpy()
		betas_rescaled = betas_np * scheduler.config.num_train_timesteps
		alpha_bars_np = self.scheduler.alphas_cumprod.cpu().numpy()
		sigmas_np = np.sqrt((1 - alpha_bars_np))
		_ts = np.linspace(0, 1, len(betas_np))
		self.beta_fn = interp1d(_ts, betas_rescaled, kind="linear", fill_value="extrapolate")
		self.alpha_bar_fn = interp1d(_ts, alpha_bars_np, kind="linear", fill_value="extrapolate")
		self.sigma_fn = interp1d(_ts, sigmas_np, kind="linear", fill_value="extrapolate")
		self.score_from_eps = ScoreFromEps(self.model, self.scheduler)
		# Target mixture; only set while sensitivity_given_sample_path is running.
		self.mog_X1 = None

	def build_mog(self, mean1, sigma1, scheduler, weights):
		"""Create the target MixtureOfGaussians and store it as ``self.mog_X1``."""
		self.mog_X1 = MixtureOfGaussians(mean1, sigma1, scheduler, weights)

	def _num_steps_and_dt(self, t_init, t_final):
		"""Return ``(num_steps, dt)`` for Euler integration from ``t_init`` to ``t_final``.

		``num_steps`` is ``|t_final - t_init| / ode_step_size`` truncated towards
		zero and ``dt`` is re-derived so that ``num_steps`` equal steps land on
		``t_final``. Raises ZeroDivisionError when the truncation yields zero steps.
		"""
		span = t_final - t_init
		num_steps = -int(span / self.ode_step_size)
		dt = span / num_steps
		return num_steps, dt

	def velocity(self, x, t):
		"""Evaluate ``0.5 * beta(t) * (eps(x, t) / sigma(t) - x)``.

		``eps`` is the linear interpolation of the model outputs at the two integer
		timesteps that bracket ``t * num_train_timesteps``.

		Args:
			x: (B, ...) state; it is cast to float32 before being passed to the model.
			t: single-element tensor holding the time; clamped to [0, 1].
		Returns:
			Tensor with the shape of ``x``.
		"""
		t_val = t.clamp(0, 1)
		t_scalar = t_val.cpu().item()
		bcast = [1] * x.dim()
		beta_t = torch.tensor(self.beta_fn(t_scalar)).to(x).view(*bcast)
		sigma_t = torch.tensor(self.sigma_fn(t_scalar)).to(x).view(*bcast)
		T = self.scheduler.config.num_train_timesteps
		t_scaled = t_val * T
		# NOTE(review): at t == 1 idx0 equals T (one past the last scheduler index)
		# while idx1 is capped at T - 1; confirm the model accepts timestep T.
		idx0 = int(math.floor(t_scaled))
		idx1 = min(idx0 + 1, T - 1)
		w = float(t_scaled - idx0)  # fractional interpolation weight
		t0 = torch.tensor([idx0], dtype=torch.long, device=x.device)
		t1 = torch.tensor([idx1], dtype=torch.long, device=x.device)
		eps0 = self.model(x.float(), t0).sample
		eps1 = self.model(x.float(), t1).sample
		eps = (1 - w) * eps0 + w * eps1
		return 0.5 * beta_t * (eps / sigma_t - x)

	def jvp_velocity(self, z, t, u):
		"""
		Contract the Jacobian of the velocity field at z with the direction u.

		NOTE(review): passing ``u`` as ``grad_outputs`` to reverse-mode autograd
		yields the vector-Jacobian product ``J^T u``. It coincides with the
		Jacobian-vector product ``J u`` only when the Jacobian is symmetric
		(presumably the case when the model is an exact score of a density;
		verify for learned networks).

		Args:
			z: (B, ...), point at which the Jacobian is taken (a detached copy is used)
			t: single-element tensor, time
			u: (B, ...), direction, same shape as z
		Returns:
			(B, ...) tensor with the shape of z
		"""
		z_in = z.clone().detach().requires_grad_(True)
		v_out = self.velocity(z_in, t)
		product = torch.autograd.grad(
			outputs=v_out,
			inputs=z_in,
			grad_outputs=u,
			retain_graph=False,
			create_graph=False
		)[0]
		return product

	def compute_d_eta_score(self, z, t, logp0, score0):
		"""Return ``exp(logp1 - logp0) * (score1 - score0)`` with an optionally clamped ratio.

		``logp1`` and ``score1`` come from ``self.mog_X1`` evaluated at ``(z, t)``.

		Args:
			z: (B, ...) states.
			t: time forwarded to ``self.mog_X1``.
			logp0: (B,) log-density of the baseline model at ``z``.
			score0: (B, ...) score of the baseline model at ``z``.
		Returns:
			Tensor with the shape of ``score0``.
		"""
		logp1 = self.mog_X1.log_prob(z, t)
		score1 = self.mog_X1.forward(z, t)
		# MixtureOfGaussians returns its score flattened to (B, D); bring it back to
		# the layout of score0 so that the subtraction below is element-wise rather
		# than an accidental broadcast.
		if score1.shape != score0.shape and score1.numel() == score0.numel():
			score1 = score1.reshape(score0.shape)
		weight = torch.exp(logp1 - logp0)
		# torch.clamp rejects min=None together with max=None, which is exactly the
		# constructor default, so only clamp when at least one bound is set.
		if self.min_clamp is not None or self.max_clamp is not None:
			weight = weight.clamp(min=self.min_clamp, max=self.max_clamp)
		weight = weight.view(score0.shape[0], *[1] * (score0.dim() - 1))
		return weight * (score1 - score0)

	def hutchinson_trace_estimate(self, x, t):
		"""Vectorized Hutchinson estimator over both batch and samples.

		Draws ``num_hutchinson_samples`` standard-normal probes per batch element
		and averages ``probe^T (d velocity / d x) probe``.

		Returns:
			(B,) tensor. The estimate is built with ``create_graph=True`` and is
			therefore still attached to the autograd graph; callers that only need
			the value should detach it.
		"""
		B, *dims = x.shape
		x_flat = x.detach().unsqueeze(0).expand(self.num_hutchinson_samples, *x.shape).reshape(-1, *dims).requires_grad_(True)
		z_flat = torch.randn_like(x_flat)
		v_flat = self.velocity(x_flat, t)
		inner = (v_flat * z_flat).view(v_flat.size(0), -1).sum(1)
		grads_flat = torch.autograd.grad(inner.sum(), x_flat, create_graph=True)[0]
		trace_per_sample = (grads_flat * z_flat).view(self.num_hutchinson_samples, B, -1).sum(dim=2)
		return trace_per_sample.mean(dim=0)

	def precompute_sample_path(self, z1, t_init=1.0, t_final=0.0):
		"""
		Precompute the Euler sample path and the log-density sequence for a given z1.

		The log-density starts from a standard-normal log-density of ``z1`` and is
		decreased by ``dt * divergence`` at every step, with the divergence taken
		from ``hutchinson_trace_estimate``.

		Returns:
			zt: (B, num_steps+1, ...) sample path
			logpt: (B, num_steps+1) log-prob sequence
		"""
		num_steps, dt = self._num_steps_and_dt(t_init, t_final)
		z = z1
		zt = [z]
		# Standard-normal log-density of the initial state.
		flat_dim = z1[0].numel()
		logp_1 = -0.5 * torch.sum(z1**2, dim=tuple(range(1, z1.dim()))) - 0.5 * flat_dim * np.log(2 * np.pi)
		logp_1 = logp_1.unsqueeze(1)
		# Running sum of dt * divergence, one entry per sample.
		accumulated = torch.zeros(z1.shape[0]).to(z1)
		logpt = [logp_1 - accumulated.unsqueeze(1)]
		for i in range(num_steps):
			t = torch.tensor([t_init + i * dt]).to(z.device)
			with torch.no_grad():
				v = self.velocity(z, t)
			z = (z + dt * v).detach()
			zt.append(z)
			# NOTE(review): the divergence is evaluated at the already-updated state
			# with the pre-step time t; confirm this is intended.
			div = self.hutchinson_trace_estimate(z, t).detach()
			accumulated = accumulated + dt * div
			logpt.append(logp_1 - accumulated.unsqueeze(1))
		zt = torch.stack(zt, dim=1)
		logpt = torch.cat(logpt, dim=1)  # (B, num_steps+1)
		torch.cuda.empty_cache()
		return zt, logpt

	def precompute_sample_path_no_logpt(self, z1, t_init=1.0, t_final=0.0):
		"""
		Precompute the Euler sample path for a given z1 without computing logpt.

		Returns:
			zt: (B, num_steps+1, ...) sample path
		"""
		num_steps, dt = self._num_steps_and_dt(t_init, t_final)
		z = z1
		zt = [z]
		for i in range(num_steps):
			t = torch.tensor([t_init + i * dt]).to(z)
			with torch.no_grad():
				v = self.velocity(z, t)
			z = (z + dt * v).detach()
			zt.append(z)
		zt = torch.stack(zt, dim=1)
		torch.cuda.empty_cache()
		return zt

	def sensitivity_given_sample_path(self, zt, logpt, mean1, sigma1, weights, t_init=1.0, t_final=0.0):
		"""
		Given precomputed zt and logpt, compute the sensitivity path et for a given MixtureOfGaussians (mean1, sigma1, weights).

		Starting from ``e = 0`` each step applies
		``e <- e + dt * (-0.5 * beta(t) * d_eta_score + jvp_velocity(z, t, e))``.
		The temporary mixture stored in ``self.mog_X1`` is removed again before
		returning, also when an exception is raised.

		Returns:
			et: (B, num_steps+1, ...) sensitivity path
		"""
		self.build_mog(mean1, sigma1, self.scheduler, weights)
		try:
			num_steps, dt = self._num_steps_and_dt(t_init, t_final)
			e = torch.zeros_like(zt[:, 0])
			et = [e]
			for i in tqdm(range(num_steps)):
				t = torch.tensor([t_init + i * dt]).to(zt)
				z = zt[:, i]
				score0 = self.score_from_eps(z, t)
				d_eta_score = self.compute_d_eta_score(z, t, logpt[:, i], score0).detach()
				beta_t = torch.tensor(self.beta_fn(t.cpu().item())).to(z)
				d_eta_v = -0.5 * beta_t * d_eta_score
				jvp = self.jvp_velocity(z, t, e)
				e = (e + dt * (d_eta_v + jvp)).detach()
				et.append(e)
			et = torch.stack(et, dim=1)
		finally:
			# Unregister the temporary mixture and fall back to the plain None marker.
			del self.mog_X1
			self.mog_X1 = None
			torch.cuda.empty_cache()
		return et

class MixtureOfGaussians(nn.Module):
	"""Isotropic Mixture of Gaussians pushed through the scheduler's noising schedule.

	At time ``t`` every component has mean ``sqrt(alpha_bar(t)) * mean_k`` and the
	shared scalar variance ``alpha_bar(t) * sigma_0**2 + sigma(t)**2 + 1e-6``, where
	``alpha_bar`` and ``sigma = sqrt(1 - alpha_bar)`` are linear interpolations of
	the scheduler's ``alphas_cumprod`` table on a uniform grid over [0, 1].

	Args:
		means: (K, D) component means.
		sigma_0: standard deviation shared by all components at t = 0.
		scheduler: object providing ``alphas_cumprod`` and ``config.num_train_timesteps``.
		weights: (K,) non-negative mixture weights; they are normalised to sum to 1.
	"""
	def __init__(self, means, sigma_0, scheduler, weights):
		super().__init__()
		self.means = means  # (K, D)
		self.K = means.shape[0]
		self.D = means.shape[1]
		self.sigma_0 = sigma_0
		self.scheduler = scheduler
		self.num_timesteps = scheduler.config.num_train_timesteps
		self.weights = weights / weights.sum()  # (K,) non-negative, sum to 1
		alpha_bars_np = scheduler.alphas_cumprod.cpu().numpy()
		sigmas_np = np.sqrt(1 - alpha_bars_np)
		_ts = np.linspace(0, 1, len(alpha_bars_np))
		self.alpha_bar_fn = interp1d(_ts, alpha_bars_np, kind="linear", fill_value="extrapolate")
		self.sigma_fn = interp1d(_ts, sigmas_np, kind="linear", fill_value="extrapolate")

	def _components_at(self, z, t):
		"""Flatten ``z`` to (B, D) and return ``(z_flat, means_t, cov)`` for time ``t``.

		``means_t`` is moved to the device and dtype of ``z`` so that both public
		methods accept the same inputs.
		"""
		if isinstance(t, torch.Tensor):
			t = t.cpu().item()
		alpha_bar = float(self.alpha_bar_fn(t))
		sigma_t = float(self.sigma_fn(t))
		# Small constant keeps the variance strictly positive.
		cov = (alpha_bar * (self.sigma_0 ** 2)) + (sigma_t ** 2) + 1e-6
		z_flat = z.reshape(z.shape[0], -1)  # (B, D)
		means_t = (self.means * math.sqrt(alpha_bar)).to(z_flat)  # (K, D)
		return z_flat, means_t, cov

	def forward(self, z, t):
		"""Return the score (gradient of the log-density) of the mixture at time ``t``.

		Args:
			z: (B, ...) points; flattened to (B, D) internally.
			t: time as a Python number or single-element tensor.
		Returns:
			(B, D) tensor (flattened layout).
		"""
		z_flat, means_t, cov = self._components_at(z, t)
		dist_sq = torch.cdist(z_flat, means_t, p=2) ** 2  # (B, K)
		logits = -dist_sq / (2 * cov) + torch.log(self.weights.to(z_flat))[None, :]
		# Posterior responsibility of each component; softmax is already shift-invariant.
		resp = F.softmax(logits, dim=1)
		avg_mean = torch.sum(resp[:, :, None] * means_t[None, :, :], dim=1)  # (B, D)
		return (avg_mean - z_flat) / cov

	def log_prob(self, z, t):
		"""Return the log-density of the mixture at time ``t``.

		Args:
			z: (B, ...) points; flattened to (B, D) internally.
			t: time as a Python number or single-element tensor.
		Returns:
			(B,) tensor.
		"""
		z_flat, means_t, cov = self._components_at(z, t)
		D = z_flat.shape[1]
		dist_sq = torch.cdist(z_flat, means_t, p=2) ** 2  # (B, K)
		log_det_cov = D * math.log(cov)
		log_normalizer = -0.5 * (D * math.log(2 * math.pi) + log_det_cov)
		per_component = log_normalizer - 0.5 * dist_sq / cov  # (B, K)
		per_component = per_component + torch.log(self.weights.to(z_flat))[None, :]
		return torch.logsumexp(per_component, dim=1)  # (B,)
	
def run_sample_sensitivity_analysis_exact_logprob(means,
                                                  sigma_0,
                                                  n_model_samples, 
                                                  ode_step_size, 
                                                  min_clamp,
                                                  max_clamp, 
                                                  eta_list
                                                  ):
	"""Measure how well the first-order sample sensitivity predicts perturbed sample paths,
	using the closed-form mixture log-density along the baseline path.

	The baseline mixture uses weights (0.5, 0.5) and the perturbation direction is
	the mixture with weights (0, 1); for each ``eta`` the ground-truth path is
	generated with weights ``(1 - eta) * (0.5, 0.5) + eta * (0, 1)`` from the same
	initial noise, and the reported error is
	``||zt + eta * sensitivity - zt_gt|| / eta`` per sample and time index.

	Args:
		means: (K, D) component means; the weight vectors are hard-coded with two
			entries, so K is expected to be 2.
		sigma_0: component standard deviation at t = 0.
		n_model_samples: number of sample paths.
		ode_step_size: Euler step size; also determines the scheduler length.
		min_clamp: lower clamp for the density ratio (may be None).
		max_clamp: upper clamp for the density ratio (may be None).
		eta_list: iterable of perturbation sizes.
	Returns:
		Tuple ``(median, pctile_10, pctile_90, zt)`` where the first three are CPU
		tensors of shape (num_eta, T) and ``zt`` is the baseline path (B, T, D).
	"""
	device = means.device
	D = means.shape[1]
	scheduler = DDPMScheduler(
		num_train_timesteps=int(1/ode_step_size)+1,
		beta_start=1e-4,
		beta_end=0.02,
		beta_schedule="linear"
		)
	z1 = torch.randn(n_model_samples, D).to(means)
	sampler_kwargs = dict(ode_step_size=ode_step_size, min_clamp=min_clamp, max_clamp=max_clamp)

	# Baseline (eta = 0) model and its sample path.
	weights_0 = torch.tensor([0.5, 0.5]).to(means)
	mog_eta0 = MixtureOfGaussians(means, sigma_0, scheduler, weights_0)
	eps_eta0 = convert.EpsFromScore(mog_eta0).to(device)
	sensitivity_eta0 = DDPMSampleSensitivityMoG(eps_eta0, scheduler, **sampler_kwargs)
	zt = sensitivity_eta0.precompute_sample_path_no_logpt(z1)

	# Exact log-density along the path. The integrator takes T - 1 equal steps from
	# t = 1 to t = 0, so the time grid is derived from the path length; using
	# ode_step_size directly would drift whenever 1 / ode_step_size is not an integer.
	T = zt.shape[1]
	dt = (0.0 - 1.0) / (T - 1)
	with torch.no_grad():
		logpt = torch.stack(
			[mog_eta0.log_prob(zt[:, i, :], 1.0 + i * dt) for i in range(T)],
			dim=1,
		)  # (B, T)

	# Sensitivity of the baseline path towards the second component.
	weights_1 = torch.tensor([0.0, 1.0]).to(means)
	sample_sensitivity = sensitivity_eta0.sensitivity_given_sample_path(zt, logpt, means, sigma_0, weights_1)

	# For each eta compare the linearised path against the re-sampled one.
	median_little_o_errors = []
	pctile_10_errors = []
	pctile_90_errors = []
	for eta_candidate in tqdm(eta_list):
		weights_eta = (1 - eta_candidate) * weights_0 + eta_candidate * weights_1
		mog_eta = MixtureOfGaussians(means, sigma_0, scheduler, weights_eta)
		eps_eta = convert.EpsFromScore(mog_eta).to(device)
		sensitivity_eta = DDPMSampleSensitivityMoG(eps_eta, scheduler, **sampler_kwargs)
		zt_gt = sensitivity_eta.precompute_sample_path_no_logpt(z1)
		zt_perturbed = zt + eta_candidate * sample_sensitivity
		# The tiny constant guards against division by zero for eta == 0.
		error = torch.norm(zt_perturbed - zt_gt, dim=2) / (eta_candidate + 1e-12)  # (B, T)
		median_little_o_errors.append(error.median(dim=0).values.detach().cpu())
		pctile_10_errors.append(torch.quantile(error, 0.1, dim=0).detach().cpu())
		pctile_90_errors.append(torch.quantile(error, 0.9, dim=0).detach().cpu())
	median_little_o_errors = torch.stack(median_little_o_errors, dim=0)  # (num_eta, T)
	pctile_10_errors = torch.stack(pctile_10_errors, dim=0)  # (num_eta, T)
	pctile_90_errors = torch.stack(pctile_90_errors, dim=0)  # (num_eta, T)
	return median_little_o_errors, pctile_10_errors, pctile_90_errors, zt

def run_sample_sensitivity_analysis_hutch_logprob(means,
												  sigma_0,
												  n_model_samples,
												  ode_step_size,
												  min_clamp,
												  max_clamp,
												  eta_list,
												  num_hutch_list):
	"""Repeat the sample-sensitivity error analysis with Hutchinson-estimated log-densities.

	For every entry of ``num_hutch_list`` the baseline path and its log-density are
	computed with that many Hutchinson probes, the sensitivity is derived from the
	estimated log-density, and the error
	``||zt + eta * sensitivity - zt_gt|| / eta`` is summarised for each ``eta``.
	The closed-form mixture log-density along the same path is returned alongside
	for comparison.

	Args:
		means: (K, D) component means; the weight vectors are hard-coded with two
			entries, so K is expected to be 2.
		sigma_0: component standard deviation at t = 0.
		n_model_samples: number of sample paths.
		ode_step_size: Euler step size; also determines the scheduler length.
		min_clamp: lower clamp for the density ratio (may be None).
		max_clamp: upper clamp for the density ratio (may be None).
		eta_list: iterable of perturbation sizes.
		num_hutch_list: iterable of Hutchinson probe counts.
	Returns:
		dict mapping each probe count to
		``(median, pctile_10, pctile_90, logpt, logpt_exact)``; the first three
		are CPU tensors of shape (num_eta, T), the last two have shape (B, T).
	"""
	device = means.device
	D = means.shape[1]
	scheduler = DDPMScheduler(
		num_train_timesteps=int(1/ode_step_size)+1,
		beta_start=1e-4,
		beta_end=0.02,
		beta_schedule="linear"
	)
	z1 = torch.randn(n_model_samples, D).to(means)
	sampler_kwargs = dict(ode_step_size=ode_step_size, min_clamp=min_clamp, max_clamp=max_clamp)
	# Neither the weight vectors nor the baseline mixture depend on the probe count.
	weights_0 = torch.tensor([0.5, 0.5]).to(means)
	weights_1 = torch.tensor([0.0, 1.0]).to(means)
	mog_eta0 = MixtureOfGaussians(means, sigma_0, scheduler, weights_0)
	results = {}
	for num_hutch in num_hutch_list:
		eps_eta0 = convert.EpsFromScore(mog_eta0).to(device)
		sensitivity_eta0 = DDPMSampleSensitivityMoG(eps_eta0, scheduler, num_hutchinson_samples=num_hutch, **sampler_kwargs)
		zt, logpt = sensitivity_eta0.precompute_sample_path(z1)

		# Closed-form log-density for comparison. The integrator takes T - 1 equal
		# steps from t = 1 to t = 0, so the time grid is derived from the path
		# length; using ode_step_size directly would drift whenever
		# 1 / ode_step_size is not an integer.
		T = zt.shape[1]
		dt = (0.0 - 1.0) / (T - 1)
		with torch.no_grad():
			logpt_exact = torch.stack(
				[mog_eta0.log_prob(zt[:, i, :], 1.0 + i * dt) for i in range(T)],
				dim=1,
			)  # (B, T)

		sample_sensitivity = sensitivity_eta0.sensitivity_given_sample_path(zt, logpt, means, sigma_0, weights_1)
		median_little_o_errors = []
		pctile_10_errors = []
		pctile_90_errors = []
		for eta_candidate in tqdm(eta_list):
			weights_eta = (1 - eta_candidate) * weights_0 + eta_candidate * weights_1
			mog_eta = MixtureOfGaussians(means, sigma_0, scheduler, weights_eta)
			eps_eta = convert.EpsFromScore(mog_eta).to(device)
			sensitivity_eta = DDPMSampleSensitivityMoG(eps_eta, scheduler, num_hutchinson_samples=num_hutch, **sampler_kwargs)
			zt_gt = sensitivity_eta.precompute_sample_path_no_logpt(z1)
			zt_perturbed = zt + eta_candidate * sample_sensitivity
			# The tiny constant guards against division by zero for eta == 0.
			error = torch.norm(zt_perturbed - zt_gt, dim=2) / (eta_candidate + 1e-12)  # (B, T)
			median_little_o_errors.append(error.median(dim=0).values.detach().cpu())
			pctile_10_errors.append(torch.quantile(error, 0.1, dim=0).detach().cpu())
			pctile_90_errors.append(torch.quantile(error, 0.9, dim=0).detach().cpu())
		median_little_o_errors = torch.stack(median_little_o_errors, dim=0)  # (num_eta, T)
		pctile_10_errors = torch.stack(pctile_10_errors, dim=0)  # (num_eta, T)
		pctile_90_errors = torch.stack(pctile_90_errors, dim=0)  # (num_eta, T)
		results[num_hutch] = (median_little_o_errors, pctile_10_errors, pctile_90_errors, logpt, logpt_exact)
	return results