
import torch
import numpy as np
from typing import Dict
from scipy.sparse.linalg import eigsh, LinearOperator

from .base import BaseMetric


def arnoldi_iteration(hvp_func, dim: int, k: int, device: str) -> np.ndarray:
    """Estimate eigenvalues of an implicit linear operator via Arnoldi iteration.

    Builds an orthonormal Krylov basis from a random start vector using
    modified Gram-Schmidt, and returns the real parts of the eigenvalues of
    the resulting upper-Hessenberg matrix (the Ritz values).

    Args:
        hvp_func: Callable taking a flat torch vector of length ``dim`` and
            returning the operator applied to it as a flat torch vector of
            the same length.
        dim: Dimension of the vector space the operator acts on.
        k: Maximum number of Arnoldi steps (and of returned Ritz values).
        device: Device on which the basis and Hessenberg matrix are allocated.

    Returns:
        1-D NumPy array with the real parts of the Ritz values. It has ``k``
        entries unless the Krylov subspace is exhausted early, in which case
        it has one entry per completed step.
    """
    # Column n of q_basis is the n-th Krylov basis vector; h_hessenberg holds
    # the projection coefficients of the operator onto that basis.
    q_basis = torch.zeros(dim, k + 1, device=device)
    h_hessenberg = torch.zeros(k + 1, k, device=device)

    # Start from a random unit vector (drawn from torch's global RNG).
    b = torch.randn(dim, device=device)
    q_basis[:, 0] = b / torch.linalg.norm(b)

    # Number of Hessenberg columns that were actually filled in.
    steps_done = 0
    for n in range(k):
        # Apply the operator to the most recent basis vector.
        v = hvp_func(q_basis[:, n])

        # Modified Gram-Schmidt: remove components along earlier basis vectors.
        for j in range(n + 1):
            h_hessenberg[j, n] = torch.dot(q_basis[:, j], v)
            v = v - h_hessenberg[j, n] * q_basis[:, j]

        h_hessenberg[n + 1, n] = torch.linalg.norm(v)
        steps_done = n + 1

        # A (near-)zero residual means the Krylov subspace is exhausted:
        # there is no new direction to normalise, so stop here.
        # NOTE(review): 1e-8 is an absolute threshold; confirm it suits the
        # dtype and scale of the operator in use.
        if h_hessenberg[n + 1, n] > 1e-8:
            q_basis[:, n + 1] = v / h_hessenberg[n + 1, n]
        else:
            break

    # Only the leading steps_done x steps_done block was populated. Including
    # the untouched zero columns after an early break would add spurious zero
    # eigenvalues to the result.
    hessenberg_final = h_hessenberg[:steps_done, :steps_done]
    # Eigenvectors are not needed, so compute eigenvalues only.
    ritz_values = torch.linalg.eigvals(hessenberg_final)

    return ritz_values.real.detach().cpu().numpy()


class HessianMetricSlow(BaseMetric):
    """Hessian-style metrics computed with autograd Jacobian-vector products.

    Every "Hessian-vector product" here is the JVP of
    ``model.model.diffusion_model`` with respect to the latents, evaluated
    through ``torch.autograd.functional.jvp``.
    """

    @property
    def name(self) -> str:
        """Name under which this metric is reported."""
        return "HessianMetric"

    @property
    def metric_type(self) -> str:
        """Granularity tag of this metric (``"per_seed"``)."""
        return "per_seed"

    def _get_score_from_noise(self, noise_pred, latent, t, model):
        """Convert a noise prediction to a score.

        Returns ``-noise_pred / sqrt(1 - alphas_cumprod[t])`` using the
        scheduler attached to ``model``. ``latent`` is accepted but unused.
        """
        # This conversion depends on the scheduler's noise schedule
        alpha_prod_t = model.scheduler.alphas_cumprod[t]
        beta_prod_t = 1 - alpha_prod_t
        return -noise_pred / beta_prod_t.sqrt()

    def measure_deprecated(self, intermediates: Dict, model, **kwargs) -> Dict:
        """
        Compute ||H_delta * s_delta||^2 from the last stored intermediate.

        ``s_delta`` is the difference between the text-conditioned and the
        unconditional noise predictions, and ``H_delta * s_delta`` is the
        difference between the conditional and unconditional JVPs of the
        diffusion model along ``s_delta``. The original author notes this is
        the first generation step (t=T-1).

        Args:
            intermediates (Dict): The dictionary from the sampler. Reads the
                last entry of 'uncond_noise', 'text_noise', 'x_inter' and
                'timesteps'.
            model: The diffusion model; provides ``model.diffusion_model``,
                ``num_frames`` and ``device``.
            **kwargs: Must contain 'conditioning_context' and
                'unconditioning_context', each a mapping with a 'context' key.

        Returns:
            Dict: ``{"hessian_norm": float}`` holding the squared L2 norm.

        Raises:
            ValueError: If either context is missing from ``kwargs``.
        """
        # Only the last stored entry is used (index -1).
        t_index = -1

        uncond_noise = intermediates['uncond_noise'][t_index]
        text_noise = intermediates['text_noise'][t_index]
        latents = intermediates['x_inter'][t_index]
        timestep = intermediates['timesteps'][t_index]
        # The timestep is repeated once per frame.
        timestep = torch.tensor([timestep] * model.num_frames, device=model.device)

        # Direction along which both Jacobians are probed.
        s_delta = text_noise - uncond_noise

        # Function whose Jacobian with respect to x is taken; c_concat is unused.
        def score_func(x, t, c_crossattn, c_concat=None):
            return model.model.diffusion_model(x, t, c_crossattn)

        conditioning_context = kwargs.get("conditioning_context")
        unconditioning_context = kwargs.get("unconditioning_context")
        if conditioning_context is None:
            raise ValueError("HessianMetric requires 'conditioning_context' (c_).")
        # Fail with the same explicit error type instead of an opaque
        # TypeError when subscripting None below.
        if unconditioning_context is None:
            raise ValueError("HessianMetric requires 'unconditioning_context'.")
        tc = conditioning_context['context']
        uc = unconditioning_context['context']

        # JVP for H_c * s_delta (conditional branch).
        jvp_cond = torch.autograd.functional.jvp(
            lambda x: score_func(x, timestep, tc),
            latents,
            s_delta
        )[1]

        # JVP for H_u * s_delta (unconditional branch).
        jvp_uncond = torch.autograd.functional.jvp(
            lambda x: score_func(x, timestep, uc),
            latents,
            s_delta
        )[1]

        # H_delta * s_delta = (H_c - H_u) * s_delta
        hessian_product = jvp_cond - jvp_uncond

        metric_val = hessian_product.norm(p=2).pow(2).item()

        return {"hessian_norm": metric_val}

    def _get_hessian_vector_product_func(self, model, latents, timestep_tensor, context):
        """Return a NumPy-facing function that computes H*v for a vector v.

        The returned callable takes a flat NumPy vector, moves it to the
        latents' device as float32, evaluates the JVP of the diffusion model
        at ``latents`` and returns the flattened result as a CPU NumPy array.
        Presumably intended for scipy's eigensolvers; not called in this class.
        """

        def score_func(x, t, ctx):
            return model.model.diffusion_model(x, t, ctx)

        def hvp_func(v):
            v_torch = torch.from_numpy(v).to(latents.device, dtype=torch.float32)
            # reshape (not view) so non-contiguous inputs are accepted too.
            v_torch = v_torch.reshape(latents.shape)
            jvp_result = torch.autograd.functional.jvp(
                lambda x: score_func(x, timestep_tensor, context),
                latents,
                v_torch
            )[1]
            # reshape tolerates outputs whose memory layout cannot be viewed flat.
            return jvp_result.reshape(-1).detach().cpu().numpy()

        return hvp_func

    def _get_hvp_torch_func(self, model, latents, timestep_tensor, context):
        """Return a function that computes H*v for a flat torch vector v.

        The returned callable reshapes ``v`` to the shape of ``latents``,
        evaluates the JVP of the diffusion model at ``latents`` along it and
        returns the result flattened to 1-D.
        """

        def score_func(x, t, ctx):
            return model.model.diffusion_model(x, t, ctx)

        def hvp_func(v_torch):
            # Flat for dot products in the caller, latent-shaped for the model.
            # reshape (not view) so strided inputs are always accepted.
            v_shaped = v_torch.reshape(latents.shape)

            jvp_result = torch.autograd.functional.jvp(
                lambda x: score_func(x, timestep_tensor, context),
                latents,
                v_shaped
            )[1]

            # reshape tolerates outputs whose memory layout cannot be viewed flat.
            return jvp_result.reshape(-1)

        return hvp_func

    def measure(self, intermediates: Dict, model, **kwargs) -> Dict:
        """
        Estimate eigenvalues of the conditional and unconditional Hessians.

        Three stored sampler entries are analysed, reported under the keys
        "t50" (index 0), "t1" (index -1) and "t20" (index -20). For each one,
        ``arnoldi_iteration`` is run with k=20 on the conditional and on the
        unconditional JVP operator.

        Args:
            intermediates (Dict): The dictionary from the sampler. Reads
                'x_inter' and 'timesteps' at the indices listed above.
            model: The diffusion model; provides ``model.diffusion_model``
                and ``device``.
            **kwargs: Must contain 'conditioning_context' and
                'unconditioning_context', each a mapping with a 'context' key.

        Returns:
            Dict: Maps each key to ``{"cond_eigvals": [...],
            "uncond_eigvals": [...]}``, each an ascending list of the Ritz
            values returned by ``arnoldi_iteration``.

        Raises:
            ValueError: If either context is missing from ``kwargs``.
        """
        conditioning_context = kwargs.get("conditioning_context")
        unconditioning_context = kwargs.get("unconditioning_context")
        if conditioning_context is None:
            raise ValueError("HessianEigenvalueMetric requires 'conditioning_context'.")
        # Fail with the same explicit error type instead of an opaque
        # TypeError when subscripting None below.
        if unconditioning_context is None:
            raise ValueError("HessianEigenvalueMetric requires 'unconditioning_context'.")

        tc = conditioning_context['context']
        uc = unconditioning_context['context']

        # Number of Arnoldi steps, i.e. the maximum number of eigenvalues.
        k = 20

        results = {}
        # Result key -> index into the sampler's stored intermediates.
        # NOTE(review): the original comments assume DDIM with 50 steps;
        # confirm how these indices map to timesteps for the sampler in use.
        steps_to_analyze = {"t50": 0, "t1": -1, "t20": -20}

        for name, t_index in steps_to_analyze.items():
            latents = intermediates['x_inter'][t_index]
            timestep_int = intermediates['timesteps'][t_index]
            # NOTE(review): a single-element timestep is used here, whereas
            # measure_deprecated repeats it model.num_frames times; confirm
            # which shape the diffusion model expects.
            timestep_tensor = torch.tensor([timestep_int], device=model.device)

            # The operator acts on the flattened latents.
            dim = latents.numel()

            # --- Conditional Eigenvalues ---
            hvp_cond_func = self._get_hvp_torch_func(model, latents, timestep_tensor, tc)
            eigvals_cond = arnoldi_iteration(hvp_cond_func, dim, k, model.device)

            # --- Unconditional Eigenvalues ---
            hvp_uncond_func = self._get_hvp_torch_func(model, latents, timestep_tensor, uc)
            eigvals_uncond = arnoldi_iteration(hvp_uncond_func, dim, k, model.device)

            results[name] = {
                "cond_eigvals": sorted(eigvals_cond.tolist()),
                "uncond_eigvals": sorted(eigvals_uncond.tolist())
            }

        return results