import torch
from typing import Dict, Optional, List
from .base import BaseMetric

class HessianMetric(BaseMetric):
    """Finite-difference probe of how the model's outputs respond along the
    conditional-minus-unconditional direction.

    For each selected entry of ``intermediates['x_inter']`` the model is
    evaluated with the conditioning and unconditioning contexts. Their
    difference ``s_delta`` is normalised and used as a direction. The latents
    are perturbed by ``delta`` along that direction, and the change in each
    output is recorded, i.e. ``f(x + delta * v) - f(x)``, which approximates
    ``delta * J_f(x) @ v``. Per-position magnitudes of those changes are
    returned for visualisation. A scalar summary is derived from the ``"t1"``
    entry.
    """

    def __init__(self, timesteps_to_measure: Optional[List[int]] = None, delta: float = 1e-3):
        """
        timesteps_to_measure: Optional list of indices into intermediates['x_inter'] to compute Hessian at.
                              If None, defaults to legacy {t50, t1, t20}.
        delta: Length of the finite-difference step taken along the normalised
               cond-minus-uncond direction. Defaults to 1e-3 (the previously
               hard-coded value).
        """
        self.timesteps_to_measure = timesteps_to_measure
        self.delta = delta
        super().__init__()

    @property
    def name(self) -> str:
        return "Hessian_SAIL_Metric"

    @property
    def metric_type(self) -> str:
        return "per_seed"

    def _steps_to_analyze(self) -> Dict[str, int]:
        """Map result labels to (possibly negative) indices into x_inter."""
        if self.timesteps_to_measure is None:
            # Legacy behaviour: fixed labels for first, last and 20th-from-last entries.
            return {"t50": 0, "t1": -1, "t20": -20}
        # User-specified indices, labelled by the index itself (e.g. -1 -> "t-1").
        return {f"t{idx}": idx for idx in self.timesteps_to_measure}

    def _score_change_magnitudes(self, model, latents, timestep, c_, uc_):
        """Return flattened (cond, uncond) change magnitudes, or None when the
        cond/uncond difference is too small to define a direction.

        Expected to be called under ``torch.no_grad()`` (as ``measure`` does).
        """
        # Base outputs are computed once and reused for both the direction and
        # the finite differences (the original evaluated each of them twice).
        cond_base = model.apply_model(latents, timestep, c_)
        uncond_base = model.apply_model(latents, timestep, uc_)

        s_delta = cond_base - uncond_base
        s_delta_norm = torch.linalg.norm(s_delta)
        if s_delta_norm < 1e-6:
            # Near-zero difference: normalising would divide by ~0.
            return None

        latents_perturbed = latents + self.delta * s_delta / s_delta_norm

        h_s_cond = model.apply_model(latents_perturbed, timestep, c_) - cond_base
        h_s_uncond = model.apply_model(latents_perturbed, timestep, uc_) - uncond_base

        # Drop a leading singleton axis, then take the norm over the first
        # remaining axis (presumably channels — TODO confirm tensor layout).
        cond_magnitudes = torch.linalg.norm(h_s_cond.squeeze(0), dim=0).flatten()
        uncond_magnitudes = torch.linalg.norm(h_s_uncond.squeeze(0), dim=0).flatten()
        return cond_magnitudes, uncond_magnitudes

    @torch.no_grad()
    def measure(self, intermediates: Dict, model, **kwargs) -> Dict:
        """Compute per-step score-change magnitudes and a scalar summary.

        Args:
            intermediates: Must contain 'x_inter' (indexable latents) and
                'timesteps' (indexable with the same indices as 'x_inter').
            model: Object exposing ``apply_model(x, t, context)``,
                ``num_frames`` and ``device``.
            conditioning_context / unconditioning_context (kwargs): Contexts
                passed to ``apply_model``. Both are required.

        Returns:
            {"hessian_sail_norm": float, "visualizations": {label: {
                "cond_magnitudes": sorted list, "uncond_magnitudes": sorted list}}}

        Raises:
            ValueError: If either context is missing.
        """
        c_ = kwargs.get("conditioning_context")
        uc_ = kwargs.get("unconditioning_context")
        if c_ is None or uc_ is None:
            raise ValueError("HessianSAILMetric requires conditioning and unconditioning contexts.")

        x_inter = intermediates['x_inter']
        num_steps = len(x_inter)
        results = {}

        for name, t_index in self._steps_to_analyze().items():
            # Valid sequence indices are [-n, n). The previous ``abs(i) >= n``
            # check wrongly skipped i == -n (the first entry).
            if not -num_steps <= t_index < num_steps:
                continue

            latents = x_inter[t_index]
            timestep = torch.tensor(
                [intermediates['timesteps'][t_index]] * model.num_frames,
                device=model.device,
            )

            magnitudes = self._score_change_magnitudes(model, latents, timestep, c_, uc_)
            if magnitudes is None:
                continue
            cond_magnitudes, uncond_magnitudes = magnitudes

            results[name] = {
                "cond_magnitudes": torch.sort(cond_magnitudes).values.cpu().tolist(),
                "uncond_magnitudes": torch.sort(uncond_magnitudes).values.cpu().tolist(),
            }

        # Summary: squared distance between the sorted magnitude distributions
        # at "t1". NOTE(review): in legacy mode "t1" is the last entry (index
        # -1), but with user-specified indices "t1" is index 1 — confirm intent.
        final_metric_val = 0.0
        if "t1" in results:
            h_s_cond_t1 = torch.tensor(results['t1']['cond_magnitudes'])
            h_s_uncond_t1 = torch.tensor(results['t1']['uncond_magnitudes'])
            final_metric_val = torch.sum((h_s_cond_t1 - h_s_uncond_t1) ** 2).item()

        return {
            "hessian_sail_norm": final_metric_val,
            "visualizations": results
        }