import numpy as np
import jax
import time
import jax.numpy as jnp
from typing import Any

# Evaluation metric:
    # Histogram error: 
    #   H(S) = {(S_i, c_i) | i = 1,...,N}, where S is the statistic, S_i is the i-th summary stat bin center, c_i is the count
    #   Err(H_true, H_pred) = sum_i ||c_true_i - c_pred_i||_1
    # Energy spectrum error:
    #  U_n = (u(t_i)_n, G(u(t_i))_n, G(G(u(t_i)))_n, ...), where n is the trajectory sample index, G is the emulator operator
    #  ESE = mean_n || FFT(U_n_true) - FFT(U_n_pred)||_1 / ||FFT(U_n_true)||_1

def compute_mmd(x: jax.Array, y: jax.Array, kernel_bandwidth: float = 1.0) -> jax.Array:
    """
    MMD with RBF kernel. Much faster than Wasserstein.
    
    Args:
        x: shape (n_samples, n_dims)
        y: shape (n_samples, n_dims)
    """
    def rbf_kernel(x, y, bandwidth):
        # ||x - y||^2
        diff = x[:, None, :] - y[None, :, :]
        dist_sq = jnp.sum(diff ** 2, axis=-1)
        return jnp.exp(-dist_sq / (2 * bandwidth ** 2))
    
    xx = rbf_kernel(x, x, kernel_bandwidth)
    yy = rbf_kernel(y, y, kernel_bandwidth)
    xy = rbf_kernel(x, y, kernel_bandwidth)
    
    mmd_sq = jnp.mean(xx) + jnp.mean(yy) - 2 * jnp.mean(xy)
    return jnp.sqrt(jnp.maximum(mmd_sq, 0.0))   


def get_IPM(data_true: jax.Array, data_pred: jax.Array) -> float:
    """Compute the Integral Probability Metric (IPM) through MMD.

    Args:
        data_true (jax.Array): True data samples of shape (batch_size, n_samples, n_dims).
        data_pred (jax.Array): Predicted data samples of shape (batch_size, n_samples, n_dims).

    Returns:
        ipm (float): The computed IPM value.
    """
    if data_true.shape != data_pred.shape:
        raise ValueError(f"Shape mismatch: {data_true.shape} vs {data_pred.shape}")

    mmd_fn = jax.vmap(compute_mmd, in_axes=(0, 0))
    mmd_values = mmd_fn(data_true, data_pred)

    return float(jnp.mean(mmd_values))

def get_histogram(data: np.ndarray, num_bins: int, density: bool = True) -> list[tuple[np.ndarray[tuple[int, int], np.dtype[Any]], np.ndarray[tuple[int, int], np.dtype[Any]]]]:
    """Compute histogram of the data.

    Args:
        data (np.ndarray): Input data array of shape:
            - (batch_size, n_samples, n_dims) for batched n-D
        num_bins (int): Number of histogram bins per dimension.
        density (bool, optional): If True, normalize the histogram. Defaults to True.

    Returns:
        histograms (list[tuple]): List of length batch_size, each element is a tuple of: 
            - Bin centers (np.ndarray): Shape (n_features, num_bins) -  Histogram bin centers for each dimension.
            - counts (np.ndarray): Shape (n_features, num_bins) -  Histogram counts for each dimension.
    """
    
    if data.ndim != 3:
        if data.ndim == 2:
            data = data[:, :, None]  # Add singleton dimension for n_dims=1
        else:
            raise ValueError(f"Expected 3D input (batch_size, n_samples, n_dims), got {data.shape}")

    batch_size, _, n_dims = data.shape

    histograms = []
    
    for b in range(batch_size):
        bin_centers = np.zeros((n_dims, num_bins))
        counts = np.zeros((n_dims, num_bins))
        for dim in range(n_dims):
            
            hist_counts, bin_edges = np.histogram(data[b, :, dim], bins=num_bins, density=density)
            counts[dim] = hist_counts
            bin_centers[dim] = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        histograms.append((bin_centers, counts))

    return histograms

def compute_histogram_error(H_true: list[tuple[np.ndarray[int, np.dtype], np.ndarray[int, np.dtype]]], 
                            H_hat: list[tuple[np.ndarray[int, np.dtype], np.ndarray[int, np.dtype]]], 
                            aggregate: str = 'mean') -> float:
    """Compute histogram error between true and predicted summary statistics.

    Args:
        H_true (list[tuple[np.ndarray, np.ndarray]]): List of true histograms for each sample.
        H_hat (list[tuple[np.ndarray, np.ndarray]]): List of predicted histograms for each sample.
        aggregate (str, optional): Method to aggregate errors ('mean' or 'sum'). Defaults to 'mean'.

    Returns:
        histogram_error (float): Histogram error value.
    """
    if len(H_true) != len(H_hat):
        raise ValueError(f"Batch size mismatch: {len(H_true)} vs {len(H_hat)}")
    
    true_counts = np.stack([counts for _, counts in H_true])  # (batch_size, n_dims, num_bins)
    hat_counts = np.stack([counts for _, counts in H_hat])    # (batch_size, n_dims, num_bins)
    
    # L1 distance for each feature dimension
    hist_error_per_dim = np.sum(np.abs(hat_counts - true_counts), axis=2)

    if aggregate == 'mean':
        return np.mean(hist_error_per_dim)
    else:
        raise ValueError(f"Unknown aggregate method: {aggregate}.")

def compute_energy_spectrum_error(u_true: np.ndarray, u_hat: np.ndarray, eps: float = 1e-8) -> float:
    """Power spectral density error (PSDE) between true and predicted trajectories.
        PSDE = mean_n || PSD(U_n_true) - PSD(U_n_pred) ||_1 / || PSD(U_n_true) ||_1

    Args:
        u_true (np.ndarray): True trajectory samples.
        u_hat (np.ndarray): Predicted trajectory samples.
        eps (float, optional): Small value to avoid division by zero. Defaults to 1e-8.

    Returns:
        float: Power spectral density error (PSDE) value.
    """
    if u_true.shape != u_hat.shape:
        raise ValueError(f"Shape mismatch: {u_true.shape} vs {u_hat.shape}")

    # Demean to avoid DC dominating the PSD comparison.
    u_true_demean = u_true - np.mean(u_true, axis=1, keepdims=True)
    u_hat_demean = u_hat - np.mean(u_hat, axis=1, keepdims=True)

    n = u_true.shape[1]
    U = np.fft.rfft(u_true_demean, axis=1)
    U_hat = np.fft.rfft(u_hat_demean, axis=1)
    P_true = (np.abs(U) ** 2) / max(n, 1)
    P_hat = (np.abs(U_hat) ** 2) / max(n, 1)

    # One-sided PSD correction for real signals.
    if P_true.shape[1] > 1:
        if n % 2 == 0:
            P_true[:, 1:-1, ...] *= 2.0
            P_hat[:, 1:-1, ...] *= 2.0
        else:
            P_true[:, 1:, ...] *= 2.0
            P_hat[:, 1:, ...] *= 2.0

    psd_error_per_sample = np.mean(np.abs(P_true - P_hat), axis=1) / (np.mean(P_true, axis=1) + eps)
    psd_error = np.mean(psd_error_per_sample).item()

    return psd_error

def compute_trajectory_errors(u_true, u_hat) -> dict[str, Any]:
    """
    Compute various error metrics for trajectory predictions.
    
    Args:
        u_true: True trajectory tensor, shape [batch, T, d]
        u_hat: Predicted trajectory tensor, shape [batch, T, d]
    
    Returns:
        Dictionary of error metrics
    """
    # Computing MSE, MAE, RMSE
    mse = np.mean((u_true - u_hat) ** 2).item()
    # mse_per_time = np.mean((u_true - u_hat) ** 2, axis=(0, 2))  # [T]
    # mse_per_component = np.mean((u_true - u_hat) ** 2, axis=(0, 1))  # [d]

    # mae = np.mean(np.abs(u_true - u_hat)).item()
    # mae_per_time = np.mean(np.abs(u_true - u_hat), axis=(0, 2))  # [T]
    # mae_per_component = np.mean(np.abs(u_true - u_hat), axis=(0, 1))  # [d]

    # rmse = np.sqrt(np.mean((u_true - u_hat) ** 2)).item()
    # rmse_per_time = np.sqrt(np.mean((u_true - u_hat) ** 2, axis=(0, 2)))  # [T]
    # rmse_per_component = np.sqrt(np.mean((u_true - u_hat) ** 2, axis=(0, 1)))  # [d]

    # Computing lp (l1, l2 and l-infinity) errors
    l1_error = np.mean(np.abs(u_true - u_hat)).item()
    l2_error = np.mean(np.linalg.norm(u_true - u_hat, axis=-1) / (np.linalg.norm(u_true, axis=-1) + 1e-8)).item()
    l_infinity_error = np.max(np.abs(u_true - u_hat)).item()
    
    # Computing energy spectrum error
    energy_spectrum_error = compute_energy_spectrum_error(u_true, u_hat)

    # Computing histogram error
    k = int(np.ceil(np.sqrt(np.max([u_true.shape[1], u_hat.shape[1]])))) # Sqrt rule for defining the number of bins
    H_true = get_histogram(u_true, num_bins=k, density=True)
    H_hat = get_histogram(u_hat, num_bins=k, density=True)
    hist_error = compute_histogram_error(H_true, H_hat)
    ipm_mmd = get_IPM(jnp.array(u_true), jnp.array(u_hat))

    return {
        'mse': mse,
        # 'mse_per_time': mse_per_time,
        # 'mse_per_component': mse_per_component,
        # 'mae': mae,
        # 'mae_per_time': mae_per_time,
        # 'mae_per_component': mae_per_component,
        # 'rmse': rmse,
        # 'rmse_per_time': rmse_per_time,
        # 'rmse_per_component': rmse_per_component,
        'l1_error': l1_error,
        'l2_error': l2_error,
        'l_infinity_error': l_infinity_error,
        'energy_spectrum': energy_spectrum_error,
        'histogram_error': hist_error,
        'ipm': ipm_mmd
    }

def compute_summary_errors(s_true, s_hat) -> dict[str, Any]:
    """
    Compute error metrics for summary statistics.
    
    Args:
        s_true: True summary statistics tensor
        s_hat: Predicted summary statistics tensor
    
    Returns:
        Dictionary of error metrics
    """
    # Computing MSE, MAE, RMSE
    mse = np.mean((s_true - s_hat) ** 2).item()
    # mse_per_time = np.mean((s_true - s_hat) ** 2, axis=(0, 2))  # [T]
    # mse_per_component = np.mean((s_true - s_hat) ** 2, axis=(0, 1))  # [n]

    # mae = np.mean(np.abs(s_true - s_hat)).item()
    # mae_per_time = np.mean(np.abs(s_true - s_hat), axis=(0, 2))  # [T]
    # mae_per_component = np.mean(np.abs(s_true - s_hat), axis=(0, 1))  # [n]

    # rmse = np.sqrt(np.mean((s_true - s_hat) ** 2)).item()
    # rmse_per_time = np.sqrt(np.mean((s_true - s_hat) ** 2, axis=(0, 2)))  # [T]
    # rmse_per_component = np.sqrt(np.mean((s_true - s_hat) ** 2, axis=(0, 1)))  # [n]

    # Computing lp (l1, l2 and l-infinity) errors
    l1_error = np.mean(np.abs(s_true - s_hat)).item()
    l2_error = np.mean(np.linalg.norm(s_true - s_hat, axis=-1) / (np.linalg.norm(s_true, axis=-1) + 1e-8)).item()
    l_infinity_error = np.max(np.abs(s_true - s_hat)).item()

    # Computing energy spectrum error
    energy_spectrum_error = compute_energy_spectrum_error(s_true, s_hat)

    # Computing histogram error
    k = int(np.ceil(np.sqrt(np.max([s_true.shape[1], s_hat.shape[1]])))) # Sqrt rule for defining the number of bins
    H_true = get_histogram(s_true, num_bins=k, density=True)
    H_hat = get_histogram(s_hat, num_bins=k, density=True)
    hist_error = compute_histogram_error(H_true, H_hat)

    return {
        'mse': mse,
        # 'mse_per_time': mse_per_time,
        # 'mse_per_component': mse_per_component,
        # 'mae': mae,
        # 'mae_per_time': mae_per_time,
        # 'mae_per_component': mae_per_component,
        # 'rmse': rmse,
        # 'rmse_per_time': rmse_per_time,
        # 'rmse_per_component': rmse_per_component,
        'l1_error': l1_error,
        'l2_error': l2_error,
        'l_infinity_error': l_infinity_error,
        'energy_spectrum': energy_spectrum_error,
        'histogram_error': hist_error,
    }
