import equinox as eqx
import jax.numpy as jnp
import jax.random as random
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
from diffusion_crf import AbstractBatchableObject, auto_vmap
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, overload, Literal
import jax
import jax.tree_util as jtu

"""
This module implements metrics and tools for evaluating probabilistic predictions using empirical distributions.
It provides:

1. EmpiricalDistribution - A class for representing and manipulating empirical distributions
2. Evaluation metrics including CRPS, log-likelihood, quantile-based metrics, and Wasserstein distance
3. Helper functions for computing distribution statistics and distances between samples

The implementation is fully compatible with JAX transformations including JIT compilation,
vectorization, and automatic differentiation, allowing for efficient computation of metrics
across large datasets.
"""

class EmpiricalDistribution(AbstractBatchableObject):
  """
  An empirical distribution for scalars.

  **Attributes**
  - `samples`: A 1D array of scalar samples with shape (N,).
  - `sorted_samples`: The sorted version of `samples`.
  - `sorted_ix`: The indices that sort the `samples`.

  **Usage**
  This module only works with scalar samples. All input functions expect a scalar.
  Most methods are wrapped in `auto_vmap` (imported from `diffusion_crf`), which
  presumably maps them over leading batch axes when the object is batched; see
  `batch_size` for how batching is detected.
  """
  samples: Float[Array, "N"]
  sorted_samples: Float[Array, "N"]
  sorted_ix: Array  # Integer array of shape (N,)

  def __init__(self, samples: Float[Array, "N"]) -> None:
    """
    Initialize the EmpiricalDistribution with a set of scalar samples.

    **Arguments**
    - `samples`: A 1D array of scalar samples with shape (N,).

    **Raises**
    - AssertionError if `samples` is not a 1-dimensional array.
    """
    # Ensure samples is a 1D array.
    assert samples.ndim == 1, "samples must be a 1-dimensional array of scalars"
    self.samples = samples
    # Sort once: the permutation gives both the sorting indices and (via a
    # gather) the sorted values, so the two fields are consistent by construction.
    self.sorted_ix = jnp.argsort(samples)
    self.sorted_samples = jnp.asarray(samples)[self.sorted_ix]

  @property
  def batch_size(self):
    """
    **Returns**
    - `None` when `samples` is 1D (unbatched), the leading dimension when
      `samples` is 2D, and the tuple of all leading dimensions otherwise.
    """
    if self.samples.ndim == 1:
      return None
    elif self.samples.ndim == 2:
      return self.samples.shape[0]
    else:
      return self.samples.shape[:-1]

  @property
  @auto_vmap
  def mean(self) -> Float[Array, ""]:
    """
    **Returns**
    - A scalar representing the mean of the samples.
    """
    return jnp.mean(self.samples)

  @property
  @auto_vmap
  def stddev(self) -> Float[Array, ""]:
    """
    **Returns**
    - A scalar representing the (population, i.e. `jnp.std` default) standard
      deviation of the samples.
    """
    return jnp.std(self.samples)

  @auto_vmap
  def cdf(self, x: Scalar) -> Float[Array, ""]:
    """
    Compute the cumulative distribution function (CDF) at a scalar x.

    **Arguments**
    - `x`: A scalar at which to evaluate the CDF.

    **Returns**
    - A scalar representing the CDF value at x, i.e. the fraction of samples
      that are less than or equal to x.

    **Raises**
    - AssertionError if `x` is not a scalar.
    """
    x_val = jnp.asarray(x)
    # Ensure the input is a scalar.
    assert x_val.ndim == 0, "cdf: input x must be a scalar."
    # side="right" counts every sample <= x.
    idx = jnp.searchsorted(self.sorted_samples, x_val, side="right")
    return idx / self.samples.shape[0]

  @auto_vmap
  def sample(self, key: PRNGKeyArray) -> Float[Array, ""]:
    """
    Draw a single random sample from the empirical distribution.

    **Arguments**
    - `key`: A JAX PRNGKey.

    **Returns**
    - A scalar sample drawn uniformly from the stored samples.
    """
    # We assume key is a valid PRNGKey.
    n = self.samples.shape[0]
    # maxval is exclusive, so idx is in [0, n - 1].
    idx = random.randint(key, shape=(), minval=0, maxval=n)
    return self.samples[idx]

  @auto_vmap
  def coverage_loss(self, x: Scalar, level: Scalar) -> Float[Array, ""]:
    """
    Compute the coverage indicator for a scalar observation.

    **Arguments**
    - `x`: A scalar observed value.
    - `level`: A scalar quantile level in the interval [0, 1].

    **Returns**
    - 1.0 if `x` is less than or equal to the quantile at `level`, else 0.0.
    """
    return jnp.mean(x <= self.quantile(level))

  @auto_vmap
  def quantile(self, level: Scalar) -> Float[Array, ""]:
    """
    Compute the quantile corresponding to a given level.

    **Arguments**
    - `level`: A scalar quantile level in the interval [0, 1].

    **Returns**
    - A scalar quantile value corresponding to the input level.

    **Raises**
    - AssertionError if `level` is not a scalar, or if a concrete `level` lies
      outside [0, 1]. When `level` is a traced value (inside `jit`, `vmap`,
      `grad`, ...) the range cannot be inspected; out-of-range levels are then
      effectively clamped because the sample index is clipped to [0, N - 1].
    """
    level_val = jnp.asarray(level)
    # Ensure level is a scalar.
    assert level_val.ndim == 0, "quantile: input level must be a scalar."
    # Ensure that level is between 0 and 1. Converting the comparison to a
    # Python bool is only possible for concrete values; for tracers JAX raises
    # a ConcretizationTypeError, in which case the check is skipped so that the
    # method stays usable under JAX transformations.
    try:
      level_in_range = bool((level_val >= 0.0) & (level_val <= 1.0))
    except jax.errors.ConcretizationTypeError:
      level_in_range = True
    assert level_in_range, "quantile: level must be between 0 and 1."
    n = self.samples.shape[0]
    # Compute the index so that level=1 returns the last sample.
    idx = jnp.clip(jnp.round(n * level_val).astype(jnp.int32) - 1, 0, n - 1)
    return self.sorted_samples[idx]

  @auto_vmap
  def quantile_losses(
    self,
    obs: Scalar,
    quantiles: Float[Array, "N"],
    levels: Float[Array, "N"]
  ) -> Float[Array, "N"]:
    """
    Compute quantile (pinball) losses for a scalar observation.

    **Arguments**
    - `obs`: A scalar observed value.
    - `quantiles`: A 1D array of quantile values.
    - `levels`: A 1D array of quantile levels corresponding to `quantiles`.

    **Returns**
    - A 1D array of quantile losses: `levels * (obs - quantiles)` where
      `obs >= quantiles`, and `(1 - levels) * (quantiles - obs)` elsewhere.

    **Raises**
    - AssertionError if `obs` is not a scalar, or if `quantiles` and `levels`
      are not 1D arrays of equal length.
    """
    obs_val = jnp.asarray(obs)
    # Ensure the observation is a scalar.
    assert obs_val.ndim == 0, "quantile_losses: observation must be a scalar."
    # Ensure quantiles and levels are 1D and have the same length.
    assert quantiles.ndim == 1, "quantile_losses: quantiles must be a 1D array."
    assert levels.ndim == 1, "quantile_losses: levels must be a 1D array."
    assert quantiles.shape[0] == levels.shape[0], (
      "quantile_losses: quantiles and levels must have the same length."
    )
    return jnp.where(obs_val >= quantiles,
                     levels * (obs_val - quantiles),
                     (1 - levels) * (quantiles - obs_val))

  @auto_vmap
  def crps_univariate(self, x: Scalar) -> Float[Array, ""]:
    """
    Compute the Continuous Ranked Probability Score (CRPS) for a scalar observation.

    The score is built from the quantile losses of the sorted samples, using the
    levels 1/N, 2/N, ..., 1 for the N order statistics.

    **Arguments**
    - `x`: A scalar observed value.

    **Returns**
    - A scalar representing the CRPS value (sum of the N quantile losses).

    **Raises**
    - AssertionError if `x` is not a scalar.
    """
    x_val = jnp.asarray(x)
    # Ensure the observation is a scalar.
    assert x_val.ndim == 0, "crps_univariate: input x must be a scalar."
    n = self.samples.shape[0]
    levels = jnp.linspace(1 / n, 1, n)
    quantiles = self.sorted_samples  # Already sorted; shape (N,)
    qlosses = self.quantile_losses(x_val, quantiles, levels)
    # NOTE(review): the losses are summed, not averaged, so the value grows with
    # N; the usual quantile-decomposition estimate of CRPS would be 2/N times
    # this sum. Confirm which scaling downstream consumers expect before changing.
    return jnp.sum(qlosses)

  @auto_vmap
  def nrmse(self, x: Scalar) -> Float[Array, ""]:
    """
    Compute the Normalized Root Mean Square Error (NRMSE) for a scalar observation.

    **Arguments**
    - `x`: A scalar observed value.

    **Returns**
    - The squared error between `x` and the sample mean, divided by the sample
      variance (`stddev**2`).

    **Raises**
    - AssertionError if `x` is not a scalar.
    """
    x_val = jnp.asarray(x)
    # Ensure the observation is a scalar.
    assert x_val.ndim == 0, "nrmse: input x must be a scalar."
    # NOTE(review): no square root is taken here despite the method name, and
    # the result is inf/nan when all samples are equal (stddev == 0).
    return jnp.mean((x_val - self.mean)**2) / self.stddev**2

  @auto_vmap
  def log_likelihood(self, x: Scalar) -> Float[Array, ""]:
    """
    Compute the log-likelihood of a scalar observation under a KDE approximation.

    **Arguments**
    - `x`: A scalar observed value

    **Returns**
    - Log-likelihood using Gaussian KDE with Silverman's bandwidth

    **Raises**
    - AssertionError if `x` is not a scalar.
    """
    x_val = jnp.asarray(x)
    assert x_val.ndim == 0, "log_likelihood: input x must be a scalar."

    # Silverman's rule for bandwidth selection
    n = self.samples.shape[0]
    h = (4 * self.stddev**5 / (3 * n)) ** 0.2
    h = jnp.maximum(h, 1e-5)  # Prevent division by zero

    # Gaussian KDE calculation: log of the average of N Gaussian kernels of
    # width h centred on the samples, evaluated stably via logsumexp.
    z = (x_val - self.samples)/h
    log_density = jax.scipy.special.logsumexp(-0.5*z**2) - jnp.log(n*h*jnp.sqrt(2*jnp.pi))
    return log_density

  @auto_vmap
  def wasserstein2(self, other: 'EmpiricalDistribution') -> Float[Array, ""]:
    """
    Compute the Wasserstein-2 distance between this empirical distribution and another.

    The computation is delegated to `wasserstein2_distance`, which works from
    the sample means and covariances only, so the result is the distance
    between Gaussian approximations of the two sample sets.

    **Arguments**
    - `other`: Another EmpiricalDistribution instance

    **Returns**
    - A scalar representing the Wasserstein-2 distance
    """
    # For 1D samples, reshape to (N, 1) format for multivariate calculation
    samples1 = self.samples.reshape(-1, 1)
    samples2 = other.samples.reshape(-1, 1)

    return wasserstein2_distance(samples1, samples2)

def matrix_sqrt(mat: Float[Array, "D D"]) -> Float[Array, "D D"]:
  """
  Return the square root of a symmetric matrix via its eigendecomposition.

  The matrix is decomposed with `jnp.linalg.eigh` and rebuilt with the square
  roots of its eigenvalues on the diagonal.

  **Arguments**
  - `mat`: A symmetric positive definite matrix of shape (D, D)

  **Returns**
  - A (D, D) matrix `S` such that `S @ S` reproduces `mat` (up to the
    eigenvalue floor described below)
  """
  spectrum, basis = jnp.linalg.eigh(mat)
  # Floor the eigenvalues at 1e-10 before taking the root so that tiny negative
  # values produced by round-off cannot yield NaNs.
  root_spectrum = jnp.sqrt(jnp.maximum(spectrum, 1e-10))
  scaled_basis = basis @ jnp.diag(root_spectrum)
  return scaled_basis @ basis.T

def wasserstein2_distance(
  samples1: Float[Array, "N D"],
  samples2: Float[Array, "M D"]
) -> Float[Array, ""]:
  """
  Compute the Wasserstein-2 distance between two sets of multivariate samples.

  Only the sample means and covariances are used, i.e. each sample set is
  summarized by its first two moments and the closed-form expression for two
  Gaussians is evaluated:

      W2^2 = ||mu1 - mu2||^2 + tr(C1) + tr(C2) - 2 tr((C1^{1/2} C2 C1^{1/2})^{1/2})

  **Arguments**
  - `samples1`: First set of samples with shape (N, D)
  - `samples2`: Second set of samples with shape (M, D)

  **Returns**
  - A scalar representing the Wasserstein-2 distance (never NaN due to a
    slightly negative squared distance; see the clamp below)
  """
  # Calculate means
  mu1 = jnp.mean(samples1, axis=0)
  mu2 = jnp.mean(samples2, axis=0)

  # Calculate covariances (population normalization: divide by N, not N - 1)
  centered1 = samples1 - mu1
  centered2 = samples2 - mu2
  cov1 = (centered1.T @ centered1) / samples1.shape[0]
  cov2 = (centered2.T @ centered2) / samples2.shape[0]

  # Add small regularization to ensure positive definiteness
  eps = 1e-10
  cov1 = cov1 + eps * jnp.eye(cov1.shape[0])
  cov2 = cov2 + eps * jnp.eye(cov2.shape[0])

  # Compute square root of cov1
  cov1_sqrt = matrix_sqrt(cov1)

  # Compute term inside the trace
  term = cov1_sqrt @ cov2 @ cov1_sqrt
  term = term + eps * jnp.eye(term.shape[0])  # Ensure numerical stability
  term_sqrt = matrix_sqrt(term)

  # Compute Wasserstein-2 distance
  mean_term = jnp.sum((mu1 - mu2) ** 2)
  cov_term = jnp.trace(cov1) + jnp.trace(cov2) - 2 * jnp.trace(term_sqrt)

  # The exact squared distance is non-negative, but round-off and the eps added
  # to `term` can push it slightly below zero (most visibly for identical,
  # low-variance sample sets), which would make the square root NaN. Clamp at 0.
  squared_distance = jnp.maximum(mean_term + cov_term, 0.0)
  return jnp.sqrt(squared_distance)

def compute_univariate_metrics(samples: Float[Array, "N"], obs: Scalar) -> Tuple[Float[Array, ""],
                                                                                 Float[Array, ""],
                                                                                 Float[Array, ""]]:
  """
  Evaluate a scalar observation against an empirical distribution of samples.

  **Arguments**
  - `samples`: A 1D array of scalar samples with shape (N,).
  - `obs`: The scalar observed value.

  **Returns**
  - A tuple `(crps, log_likelihood, cdf)` obtained from the
    `crps_univariate`, `log_likelihood` and `cdf` methods of an
    `EmpiricalDistribution` built from `samples`.
  """
  empirical = EmpiricalDistribution(samples)
  return (
    empirical.crps_univariate(obs),
    empirical.log_likelihood(obs),
    empirical.cdf(obs),
  )
