import jax
import jax.numpy as jnp
from jax import jit, vmap
from functools import partial  # For jit with static_argnames

import numpy as np  # For data generation and plotting
from scipy.special import softmax as softmax_np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm


def generate_synthetic_data_numpy(
    n_samples, n_classes, dirichlet_concentration_true, noise_scale_logits
):
    """
    Draw a synthetic classification dataset (noisy logits plus labels) with NumPy.

    Sampling scheme:
        - p_true_i ~ Dirichlet(alpha) for every sample i
        - label_i ~ Categorical(p_true_i)
        - logits_i = log(p_true_i + 1e-9) + Normal(0, noise_scale_logits)

    Args:
        n_samples: number of samples to draw.
        n_classes: number of classes.
        dirichlet_concentration_true: scalar concentration shared by all
            classes, or a sequence holding one concentration per class.
        noise_scale_logits: standard deviation of the Gaussian logit noise.

    Returns:
        Tuple ``(model_logits, true_labels_indices)``: the noisy logits with
        one row per sample and the sampled class index of every sample.

    Raises:
        ValueError: if a non-scalar concentration does not have ``n_classes``
            entries.

    Note:
        All randomness comes from the global ``np.random`` state.
    """
    # A tiny offset keeps every concentration strictly positive.
    if not np.isscalar(dirichlet_concentration_true):
        alpha = np.array(dirichlet_concentration_true) + 1e-7
        if len(alpha) != n_classes:
            raise ValueError(
                "dirichlet_concentration_true must be a scalar or match n_classes"
            )
    else:
        alpha = np.ones(n_classes) * dirichlet_concentration_true + 1e-7

    class_probs = np.random.dirichlet(alpha, n_samples)

    # One categorical draw per sample, each from that sample's own distribution.
    drawn_labels = []
    for sample_probs in class_probs:
        drawn_labels.append(np.random.choice(n_classes, p=sample_probs))
    labels = np.array(drawn_labels)

    # The 1e-9 guards against log(0) for vanishing probabilities.
    clean_logits = np.log(class_probs + 1e-9)
    logit_noise = np.random.normal(
        loc=0, scale=noise_scale_logits, size=clean_logits.shape
    )
    return clean_logits + logit_noise, labels


@partial(jit, static_argnames=("n_classes",))
def _labels_to_one_hot_jax(labels_indices, n_classes: int):
    """Return the one-hot encoding of ``labels_indices`` using ``n_classes`` columns."""
    one_hot_labels = jax.nn.one_hot(labels_indices, num_classes=n_classes)
    return one_hot_labels


@partial(jit, static_argnames=("n_classes",))
def _get_linear_util_arrays_jit(
    probabilities, true_labels_one_hot, a_vector, n_classes: int
):
    """
    Build the utility arrays for a linear utility defined by ``a_vector``.

    Returns:
        Tuple ``(realized_utility_array, vectorized_u_array)``: the entry of
        ``a_vector`` at each sample's true class, and ``a_vector`` repeated
        once per row of ``probabilities``.
    """
    row_count = probabilities.shape[0]
    # Recover the class index of each sample from its one-hot row.
    label_positions = jnp.argmax(true_labels_one_hot, axis=1)
    utility_at_true_class = a_vector[label_positions]
    # The utility vector is the same for every sample.
    utility_per_sample = jnp.tile(a_vector, (row_count, 1))
    return utility_at_true_class, utility_per_sample


@partial(jit, static_argnames=("n_classes",))
def _get_rank_based_util_arrays_jit(
    probabilities, true_labels_one_hot, theta_vector, n_classes: int
):
    """
    Build the utility arrays for a rank-based utility.

    Each class receives ``theta_vector[rank]``, where ``rank`` is the 0-indexed
    position of the class when a sample's probabilities are sorted in
    descending order (rank 0 = most probable class).

    Returns:
        Tuple ``(realized_utility_array, vectorized_u_array)``: the utility of
        each sample's true class, and the per-class utilities of every sample.
    """
    row_ids = jnp.arange(probabilities.shape[0])
    label_positions = jnp.argmax(true_labels_one_hot, axis=1)

    # argsort of an argsort turns a sort order into per-element ranks.
    descending_order = jnp.argsort(-probabilities, axis=1)
    class_ranks = jnp.argsort(descending_order, axis=1)

    utility_per_class = theta_vector[class_ranks]
    rank_of_true_class = class_ranks[row_ids, label_positions]
    utility_at_true_class = theta_vector[rank_of_true_class]
    return utility_at_true_class, utility_per_class


@partial(jit, static_argnames=("n_classes",))
def _get_top_class_util_arrays_jit(probabilities, true_labels_one_hot, n_classes: int):
    """
    Build the utility arrays for the top-class (argmax prediction) utility.

    Returns:
        Tuple ``(realized_utility_array, vectorized_u_array)``: the one-hot
        label entry at each sample's predicted class (1 when the prediction
        matches the label), and the one-hot encoding of the predicted class.
    """
    predicted_class = jnp.argmax(probabilities, axis=1)
    # Pick, for every row, the label indicator sitting at the predicted class.
    hit_indicator = jnp.take_along_axis(
        true_labels_one_hot, predicted_class[:, None], axis=1
    )[:, 0]
    predicted_one_hot = jax.nn.one_hot(
        predicted_class, n_classes, dtype=probabilities.dtype
    )
    return hit_indicator, predicted_one_hot


@partial(jit, static_argnames=("n_classes",))
def _get_class_wise_util_arrays_jit(
    probabilities, true_labels_one_hot, class_index_c: int, n_classes: int
):
    """
    Build the utility arrays for the indicator utility of one class.

    Returns:
        Tuple ``(realized_utility_array, vectorized_u_array)``: the label
        indicator column of class ``class_index_c``, and the one-hot vector of
        that class repeated once per row of ``probabilities``.
    """
    row_count = probabilities.shape[0]
    class_indicator = jax.nn.one_hot(
        jnp.array(class_index_c), n_classes, dtype=probabilities.dtype
    )
    # The same indicator vector applies to every sample.
    utility_per_sample = jnp.tile(class_indicator, (row_count, 1))
    label_is_class = true_labels_one_hot[:, class_index_c]
    return label_is_class, utility_per_sample


# --- Core UC Error Estimation ---
@jit
def _estimate_uc_error_interval_jit(
    probabilities, realized_utility_array, vectorized_u_array
):
    """
    Estimate the utility-calibration error over all intervals of predicted utility.

    Each sample gets a predicted utility ``v_u = sum_c p[c] * u[c]`` and a
    residual ``A = realized_utility - v_u``. With the samples sorted by
    ``v_u``, the error is the largest ``|sum of A|`` over any contiguous run
    of samples, divided by the number of samples.

    The maximum is found in O(n) from the prefix sums ``P`` of the sorted
    residuals: the sum over samples ``a .. b-1`` is ``P[b] - P[a]``, so the
    largest absolute interval sum is ``max(P) - min(P)``. This avoids
    enumerating all O(n^2) intervals.

    Args:
        probabilities: per-sample class probabilities, one row per sample.
        realized_utility_array: realized utility of every sample.
        vectorized_u_array: per-class utilities, same shape as ``probabilities``.

    Returns:
        Tuple ``(error, worst_interval_v_u_min, worst_interval_v_u_max)``: the
        error and the predicted utilities at both ends of an interval attaining
        it. If several intervals tie for the maximum, the one spanned by the
        first minimum and first maximum of the prefix sums is reported.

    Note:
        Requires at least one sample; the ``_safe`` wrapper handles the empty
        case.
    """
    n_samples = probabilities.shape[0]
    v_u_values = jnp.sum(probabilities * vectorized_u_array, axis=1)
    A_terms = realized_utility_array - v_u_values
    sorted_indices = jnp.argsort(v_u_values)
    sorted_A_terms = A_terms[sorted_indices]
    sorted_v_u_values = v_u_values[sorted_indices]
    # prefix_sum_A[k] is the sum of the first k sorted residuals (n_samples + 1 entries).
    prefix_sum_A = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(sorted_A_terms)])

    idx_of_max = jnp.argmax(prefix_sum_A)
    idx_of_min = jnp.argmin(prefix_sum_A)
    max_abs_sum = prefix_sum_A[idx_of_max] - prefix_sum_A[idx_of_min]
    error = max_abs_sum / n_samples

    # Prefix positions a < b describe the sample interval a .. b-1.
    left = jnp.minimum(idx_of_max, idx_of_min)
    right = jnp.maximum(idx_of_max, idx_of_min)
    worst_i = left
    # If all prefix sums are equal, left == right == 0; fall back to the
    # single-sample interval (0, 0) instead of producing index -1.
    worst_j = jnp.maximum(right - 1, left)
    worst_interval_v_u_min = sorted_v_u_values[worst_i]
    worst_interval_v_u_max = sorted_v_u_values[worst_j]
    return error, worst_interval_v_u_min, worst_interval_v_u_max


@partial(jit, static_argnames=("n_samples_data",))
def _estimate_uc_error_interval_safe_jit(
    probabilities, realized_utility_array, vectorized_u_array, n_samples_data: int
):
    """
    Empty-input guard around ``_estimate_uc_error_interval_jit``.

    ``n_samples_data`` is a static argument, so the branch below is resolved
    when the function is traced. With zero samples the result is
    ``(0.0, 0.0, 0.0)``; otherwise the estimator's
    ``(error, v_u_min, v_u_max)`` tuple is returned unchanged.
    """
    if n_samples_data != 0:
        return _estimate_uc_error_interval_jit(
            probabilities, realized_utility_array, vectorized_u_array
        )
    return 0.0, 0.0, 0.0


# --- JITted Core Distribution Calculation Functions ---
@partial(jit, static_argnames=("num_utility_samples", "n_classes", "n_samples_data"))
def _linear_uc_dist_core_jit(
    key,
    probabilities,
    true_labels_one_hot,
    num_utility_samples: int,
    n_classes: int,
    n_samples_data: int,
):
    """
    Evaluate the UC error for a batch of randomly drawn linear utilities.

    ``num_utility_samples`` utility vectors are drawn uniformly from
    ``[-1, 1)^n_classes`` using ``key`` and the UC error of each is computed
    on the given data.

    Returns:
        Tuple ``(sorted_errors, max_error, worst_a_vector, v_min, v_max)``:
        the errors in ascending order, the largest error, the utility vector
        attaining it, and the bounds of its worst interval.
    """

    def _interval_stats(a_vec):
        # (error, v_min, v_max) of one utility vector on the full data.
        realized_u, vectorized_u = _get_linear_util_arrays_jit(
            probabilities, true_labels_one_hot, a_vec, n_classes
        )
        return _estimate_uc_error_interval_safe_jit(
            probabilities, realized_u, vectorized_u, n_samples_data
        )

    def _error_only(a_vec):
        return _interval_stats(a_vec)[0]

    sampled_a_vectors = jax.random.uniform(
        key, shape=(num_utility_samples, n_classes), minval=-1.0, maxval=1.0
    )
    errors = vmap(_error_only)(sampled_a_vectors)

    worst_a_vector = sampled_a_vectors[jnp.argmax(errors)]
    # Re-evaluate the worst vector alone to recover its interval bounds.
    _, interval_lo, interval_hi = _interval_stats(worst_a_vector)

    return (
        jnp.sort(errors),
        jnp.max(errors),
        worst_a_vector,
        interval_lo,
        interval_hi,
    )


@partial(jit, static_argnames=("num_utility_samples", "n_classes", "n_samples_data"))
def _rank_based_uc_dist_core_jit(
    key,
    probabilities,
    true_labels_one_hot,
    num_utility_samples: int,
    n_classes: int,
    n_samples_data: int,
):
    """
    Evaluate the UC error for a batch of randomly drawn rank-based utilities.

    ``num_utility_samples`` vectors are drawn uniformly from
    ``[-1, 1)^n_classes`` using ``key`` and each is sorted into descending
    order, so the highest-ranked class receives the largest utility.

    Returns:
        Tuple ``(sorted_errors, max_error, worst_theta_vector, v_min, v_max)``:
        the errors in ascending order, the largest error, the ordered theta
        vector attaining it, and the bounds of its worst interval.
    """

    def _interval_stats(theta_vec_ordered):
        # (error, v_min, v_max) of one ordered theta vector on the full data.
        realized_u, vectorized_u = _get_rank_based_util_arrays_jit(
            probabilities, true_labels_one_hot, theta_vec_ordered, n_classes
        )
        return _estimate_uc_error_interval_safe_jit(
            probabilities, realized_u, vectorized_u, n_samples_data
        )

    def _error_only(theta_vec_ordered):
        return _interval_stats(theta_vec_ordered)[0]

    unordered_thetas = jax.random.uniform(
        key, shape=(num_utility_samples, n_classes), minval=-1.0, maxval=1.0
    )
    # Negate, sort ascending, negate back: a descending sort per row.
    ordered_thetas = -jnp.sort(-unordered_thetas, axis=1)
    errors = vmap(_error_only)(ordered_thetas)

    worst_theta_vector = ordered_thetas[jnp.argmax(errors)]
    # Re-evaluate the worst vector alone to recover its interval bounds.
    _, interval_lo, interval_hi = _interval_stats(worst_theta_vector)

    return (
        jnp.sort(errors),
        jnp.max(errors),
        worst_theta_vector,
        interval_lo,
        interval_hi,
    )


# --- JITted Core functions for Top-Class, Class-Wise, Top-K (single calculation) ---
@partial(jit, static_argnames=("n_classes", "n_samples_data"))
def _top_class_uc_core_jit(
    probabilities, true_labels_one_hot, n_classes: int, n_samples_data: int
):
    """
    Compute the UC error of the top-class utility.

    Returns:
        Tuple ``(error, v_min, v_max, example_vectorized_u)``; the last entry
        is the utility vector of the first sample, or zeros when there are no
        samples.
    """
    realized, per_class = _get_top_class_util_arrays_jit(
        probabilities, true_labels_one_hot, n_classes
    )
    uc_error, interval_lo, interval_hi = _estimate_uc_error_interval_safe_jit(
        probabilities, realized, per_class, n_samples_data
    )
    # n_samples_data is static, so this branch is resolved at trace time.
    if n_samples_data > 0:
        first_sample_u = per_class[0]
    else:
        first_sample_u = jnp.zeros(n_classes, dtype=per_class.dtype)
    return uc_error, interval_lo, interval_hi, first_sample_u


@partial(jit, static_argnames=("n_cls", "n_samp"))
def _class_wise_uc_single_core_jit(
    class_to_calc: int, probs_j, thot_j, n_cls: int, n_samp: int
):
    """
    Compute the UC error of the indicator utility of class ``class_to_calc``.

    Returns:
        Tuple ``(error, v_min, v_max, param_vec)`` where ``param_vec`` is the
        one-hot vector of the evaluated class.
    """
    realized, per_class = _get_class_wise_util_arrays_jit(
        probs_j, thot_j, class_to_calc, n_cls
    )
    uc_error, interval_lo, interval_hi = _estimate_uc_error_interval_safe_jit(
        probs_j, realized, per_class, n_samp
    )
    class_one_hot = jax.nn.one_hot(
        jnp.array(class_to_calc), n_cls, dtype=probs_j.dtype
    )
    return uc_error, interval_lo, interval_hi, class_one_hot


@partial(jit, static_argnames=("n_classes",))
def _compute_ranks_jit(probabilities, n_classes: int):
    """Return 0-indexed per-sample class ranks (rank 0 = highest probability)."""
    # First argsort: classes ordered from most to least probable.
    descending_order = jnp.argsort(-probabilities, axis=1)
    # Second argsort: position of every class inside that ordering, i.e. its rank.
    return jnp.argsort(descending_order, axis=1)


@partial(jit, static_argnames=("n_classes", "n_samples_data"))
def _calculate_uc_metrics_for_k_kernel_jit(
    k_val,  # Dynamic: the value that gets vmapped over
    # Shared across all k values (in_axes=None in the vmap call)
    ranks_for_samples_0idx,
    probabilities_jnp,
    true_class_indices_jnp,
    n_classes: int,  # Static arg for JIT
    n_samples_data: int,  # Static arg for JIT
):
    """
    Compute the UC error of the top-k utility for one value of k.

    The utility of a class is 1 when its 0-indexed rank is below ``k_val``
    (i.e. it is among the k most probable classes of the sample) and 0
    otherwise. ``k_val`` is 1-indexed: ``k_val = 1`` selects rank 0 only.

    Returns:
        Tuple ``(error, v_min, v_max, example_vectorized_u)``; the last entry
        is the utility vector of the first sample, or zeros when there are no
        samples.
    """
    sample_ids = jnp.arange(n_samples_data)

    # Indicator of membership in the top-k set, built from the precomputed ranks.
    in_top_k = ranks_for_samples_0idx < k_val
    per_class_utility = in_top_k.astype(probabilities_jnp.dtype)

    # Realized utility: whether the true class landed in the top-k set.
    true_class_utility = per_class_utility[sample_ids, true_class_indices_jnp]

    uc_error, interval_lo, interval_hi = _estimate_uc_error_interval_safe_jit(
        probabilities_jnp, true_class_utility, per_class_utility, n_samples_data
    )

    # n_samples_data is static, so this branch is resolved at trace time.
    if n_samples_data > 0:
        first_sample_u = per_class_utility[0]
    else:
        first_sample_u = jnp.zeros(n_classes, dtype=per_class_utility.dtype)
    return uc_error, interval_lo, interval_hi, first_sample_u


# --- Helper function to combine results from ECDF distribution calculations over subsamples ---
def _combine_ecdf_results(results_list, n_classes_for_default_param):
    """
    Merge ECDF results computed on several subsamples or utility batches.

    Args:
        results_list: list of tuples
            ``(errors_dist_jnp, max_err, worst_param_vec, v_min, v_max)``.
            ``None`` entries are ignored when more than one entry is given.
        n_classes_for_default_param: length of the all-zero parameter vector
            used when there is nothing to combine.

    Returns:
        Tuple ``(all_errors_sorted, overall_max_error, worst_param_vec, v_min,
        v_max)``. The error distributions are concatenated and sorted; the
        remaining four entries are taken from the result with the largest
        ``max_err``. A single-element list is returned as is, and an empty
        list yields an empty distribution with zero defaults.
    """
    if not results_list:
        return (
            jnp.array([]),
            0.0,
            jnp.zeros(n_classes_for_default_param),  # Default param vector
            0.0,
            0.0,
        )

    if len(results_list) == 1:
        return results_list[0]

    valid_results = [r for r in results_list if r is not None]

    # Pool every non-empty error distribution into one sorted array.
    error_chunks = [
        r[0] for r in valid_results if r[0] is not None and r[0].size > 0
    ]
    if error_chunks:
        all_errors_dist_sorted = jnp.sort(jnp.concatenate(error_chunks))
    else:
        all_errors_dist_sorted = jnp.array([])

    if not valid_results:
        # Nothing usable: fall back to defaults.
        return (
            all_errors_dist_sorted,
            0.0,
            jnp.zeros(n_classes_for_default_param),
            0.0,
            0.0,
        )

    # Start from the first valid result so that a tuple is always selected,
    # even when no max error compares greater than -inf (e.g. NaN errors).
    best_result_tuple = valid_results[0]
    running_max = -jnp.inf
    for candidate in valid_results:
        if candidate[1] > running_max:
            running_max = candidate[1]
            best_result_tuple = candidate

    # Report the max error of the tuple actually selected, so the returned
    # error and parameters always belong together (never the -inf sentinel).
    return (
        all_errors_dist_sorted,
        best_result_tuple[1],
        best_result_tuple[2],  # worst_param_vec of the selected result
        best_result_tuple[3],  # v_min of the selected result
        best_result_tuple[4],  # v_max of the selected result
    )


# --- Python Wrappers for JAX functions (User-facing API) ---
def calculate_uc_linear_distribution(
    key,
    probabilities_np,
    true_labels_indices_np,
    num_utility_samples=100,
    # Utility batching parameters
    utility_batch_size=50,  # Number of utility vectors per batch
    # Data subsampling parameters (used if use_subsampling is True)
    use_subsampling=True,  # General flag to enable/disable data subsampling
    data_subsample_trigger_size=30000,  # Dataset size to trigger subsampling
    data_subsample_max_samples=1000,  # Size of each data subsample batch
    data_num_subsamples=5,  # Number of data subsample batches
    data_subsampling_seed=42,  # Seed for data subsampling reproducibility
):
    n_samples_total, n_classes = probabilities_np.shape

    if n_classes == 0:
        # print("Warning: n_classes is 0 in calculate_uc_linear_distribution.")
        return (jnp.array([]), 0.0, jnp.zeros(0), 0.0, 0.0)
    if n_samples_total == 0:
        # print("Warning: n_samples_total is 0 in calculate_uc_linear_distribution.")
        return (jnp.array([]), 0.0, jnp.zeros(n_classes), 0.0, 0.0)

    # Determine number of utility batches
    num_utility_batches = (
        num_utility_samples + utility_batch_size - 1
    ) // utility_batch_size

    master_key = key  # Preserve the original key for splitting
    all_results_across_utility_batches = []

    for i_util_batch in tqdm(
        range(num_utility_batches), desc="Utility Batches (Linear)"
    ):
        # Determine current utility batch size
        current_utility_samples_in_batch = utility_batch_size
        if i_util_batch == num_utility_batches - 1:  # Last batch
            current_utility_samples_in_batch = (
                num_utility_samples - i_util_batch * utility_batch_size
            )

        if current_utility_samples_in_batch <= 0:
            continue  # Should not happen with correct logic

        # Generate a unique key for this utility batch processing
        master_key, util_batch_key = jax.random.split(master_key)

        # Decide whether to use data subsampling for this utility batch
        should_subsample_data = (
            use_subsampling and n_samples_total > data_subsample_trigger_size
        )

        results_for_current_utility_batch = []

        if should_subsample_data:
            # print(f"Data subsampling enabled for utility batch {i_util_batch + 1}")
            data_subsample_list = _subsample_data(
                probabilities_np,
                true_labels_indices_np,
                max_samples=data_subsample_max_samples,
                num_subsamples=data_num_subsamples,
                seed=data_subsampling_seed
                + i_util_batch,  # Vary seed per utility batch if desired
            )

            # Generate keys for each data subsample for this utility batch
            data_subsample_keys = jax.random.split(
                util_batch_key, len(data_subsample_list)
            )

            for i_data_sub, (probs_sub, labels_sub) in enumerate(data_subsample_list):
                n_samples_data_sub, _ = probs_sub.shape
                if n_samples_data_sub == 0:
                    continue

                true_labels_jnp_sub = jnp.asarray(labels_sub)
                probabilities_jnp_sub = jnp.asarray(probs_sub)
                true_labels_one_hot_jnp_sub = _labels_to_one_hot_jax(
                    true_labels_jnp_sub, n_classes
                )

                sub_result = _linear_uc_dist_core_jit(
                    data_subsample_keys[i_data_sub],
                    probabilities_jnp_sub,
                    true_labels_one_hot_jnp_sub,
                    current_utility_samples_in_batch,  # Number of utilities for this batch
                    n_classes,
                    n_samples_data_sub,
                )
                results_for_current_utility_batch.append(sub_result)
        else:
            # No data subsampling, use full dataset for this utility batch
            # print(f"Using full dataset for utility batch {i_util_batch + 1}")
            true_labels_jnp = jnp.asarray(true_labels_indices_np)
            probabilities_jnp = jnp.asarray(probabilities_np)
            true_labels_one_hot_jnp = _labels_to_one_hot_jax(true_labels_jnp, n_classes)

            full_data_result = _linear_uc_dist_core_jit(
                util_batch_key,  # Key for this utility batch
                probabilities_jnp,
                true_labels_one_hot_jnp,
                current_utility_samples_in_batch,  # Number of utilities for this batch
                n_classes,
                n_samples_total,
            )
            results_for_current_utility_batch.append(full_data_result)

        # Combine results from data subsamples (if any) for the current utility batch
        if results_for_current_utility_batch:
            combined_for_utility_batch = _combine_ecdf_results(
                results_for_current_utility_batch, n_classes
            )
            all_results_across_utility_batches.append(combined_for_utility_batch)

    # Final combination of results from all utility batches
    if not all_results_across_utility_batches:
        # print("Warning: No results generated across any utility batches.")
        return (jnp.array([]), 0.0, jnp.zeros(n_classes), 0.0, 0.0)

    final_combined_result = _combine_ecdf_results(
        all_results_across_utility_batches, n_classes
    )
    return final_combined_result


def calculate_uc_rank_based_distribution(
    key,
    probabilities_np,
    true_labels_indices_np,
    num_utility_samples=100,
    # Utility batching parameters
    utility_batch_size=50,  # Number of utility vectors per batch
    # Data subsampling parameters (used if use_subsampling is True)
    use_subsampling=True,  # General flag to enable/disable data subsampling
    data_subsample_trigger_size=30000,  # Dataset size to trigger subsampling
    data_subsample_max_samples=1000,  # Size of each data subsample batch
    data_num_subsamples=5,  # Number of data subsample batches
    data_subsampling_seed=42,  # Seed for data subsampling reproducibility
):
    n_samples_total, n_classes = probabilities_np.shape

    if n_classes == 0:
        # print("Warning: n_classes is 0 in calculate_uc_rank_based_distribution.")
        return (jnp.array([]), 0.0, jnp.zeros(0), 0.0, 0.0)
    if n_samples_total == 0:
        # print("Warning: n_samples_total is 0 in calculate_uc_rank_based_distribution.")
        return (jnp.array([]), 0.0, jnp.zeros(n_classes), 0.0, 0.0)

    # Determine number of utility batches
    num_utility_batches = (
        num_utility_samples + utility_batch_size - 1
    ) // utility_batch_size

    master_key = key  # Preserve the original key for splitting
    all_results_across_utility_batches = []

    for i_util_batch in tqdm(
        range(num_utility_batches), desc="Utility Batches (Rank-Based)"
    ):
        # Determine current utility batch size
        current_utility_samples_in_batch = utility_batch_size
        if i_util_batch == num_utility_batches - 1:  # Last batch
            current_utility_samples_in_batch = (
                num_utility_samples - i_util_batch * utility_batch_size
            )

        if current_utility_samples_in_batch <= 0:
            continue

        master_key, util_batch_key = jax.random.split(master_key)
        should_subsample_data = (
            use_subsampling and n_samples_total > data_subsample_trigger_size
        )
        results_for_current_utility_batch = []

        if should_subsample_data:
            # print(f"Data subsampling enabled for utility batch {i_util_batch + 1} (Rank-Based)")
            data_subsample_list = _subsample_data(
                probabilities_np,
                true_labels_indices_np,
                max_samples=data_subsample_max_samples,
                num_subsamples=data_num_subsamples,
                seed=data_subsampling_seed + i_util_batch,
            )
            data_subsample_keys = jax.random.split(
                util_batch_key, len(data_subsample_list)
            )

            for i_data_sub, (probs_sub, labels_sub) in enumerate(data_subsample_list):
                n_samples_data_sub, _ = probs_sub.shape
                if n_samples_data_sub == 0:
                    continue

                true_labels_jnp_sub = jnp.asarray(labels_sub)
                probabilities_jnp_sub = jnp.asarray(probs_sub)
                true_labels_one_hot_jnp_sub = _labels_to_one_hot_jax(
                    true_labels_jnp_sub, n_classes
                )

                sub_result = _rank_based_uc_dist_core_jit(
                    data_subsample_keys[i_data_sub],
                    probabilities_jnp_sub,
                    true_labels_one_hot_jnp_sub,
                    current_utility_samples_in_batch,
                    n_classes,
                    n_samples_data_sub,
                )
                results_for_current_utility_batch.append(sub_result)
        else:
            # print(f"Using full dataset for utility batch {i_util_batch + 1} (Rank-Based)")
            true_labels_jnp = jnp.asarray(true_labels_indices_np)
            probabilities_jnp = jnp.asarray(probabilities_np)
            true_labels_one_hot_jnp = _labels_to_one_hot_jax(true_labels_jnp, n_classes)

            full_data_result = _rank_based_uc_dist_core_jit(
                util_batch_key,
                probabilities_jnp,
                true_labels_one_hot_jnp,
                current_utility_samples_in_batch,
                n_classes,
                n_samples_total,
            )
            results_for_current_utility_batch.append(full_data_result)

        if results_for_current_utility_batch:
            combined_for_utility_batch = _combine_ecdf_results(
                results_for_current_utility_batch, n_classes
            )
            all_results_across_utility_batches.append(combined_for_utility_batch)

    if not all_results_across_utility_batches:
        # print("Warning: No results generated across any utility batches (Rank-Based).")
        return (jnp.array([]), 0.0, jnp.zeros(n_classes), 0.0, 0.0)

    final_combined_result = _combine_ecdf_results(
        all_results_across_utility_batches, n_classes
    )
    return final_combined_result


def _subsample_data(
    probabilities_np,
    true_labels_indices_np,
    max_samples=1000,
    num_subsamples=5,
    seed=42,
):
    """
    Split a large dataset into several random subsamples.

    Returns:
        A list of ``(probabilities, labels)`` pairs. If the dataset has at
        most ``max_samples`` rows, the list holds the original arrays as its
        only pair. Otherwise it holds ``num_subsamples`` pairs, each with
        ``max_samples`` rows drawn without replacement using a
        ``RandomState`` seeded with ``seed``.
    """
    total_rows = probabilities_np.shape[0]
    if total_rows <= max_samples:
        return [(probabilities_np, true_labels_indices_np)]

    rng = np.random.RandomState(seed)
    # Lazy generator: each index set is drawn right before it is used.
    index_draws = (
        rng.choice(total_rows, size=max_samples, replace=False)
        for _ in range(num_subsamples)
    )
    return [
        (probabilities_np[picked], true_labels_indices_np[picked])
        for picked in index_draws
    ]


def _average_results(results_list):
    """
    Reduce per-subsample result tuples to a single tuple.

    The first entry (the error) is averaged over all subsamples; every other
    entry is copied from the subsample with the highest error. A one-element
    list is returned as its only element.
    """
    if len(results_list) == 1:
        return results_list[0]

    errors = [result[0] for result in results_list]
    mean_error = np.mean(errors)
    worst_result = results_list[np.argmax(errors)]
    return (mean_error,) + worst_result[1:]


def calculate_uc_top_class(
    probabilities_np,
    true_labels_indices_np,
    use_subsampling=True,
    max_samples=1000,
    num_subsamples=5,
    seed=42,
):
    """
    Calculate utility calibration error for top-class prediction.

    Args:
        probabilities_np: numpy array of shape (n_samples, n_classes)
        true_labels_indices_np: numpy array of shape (n_samples,)
        use_subsampling: if True, subsample large datasets (default: True)
        max_samples: maximum number of samples to use when subsampling (default: 1000)
        num_subsamples: number of subsamples to average over (default: 5)
        seed: random seed for subsampling (default: 42)

    Returns:
        Tuple ``(error, v_min, v_max, example_vectorized_u)``. With several
        subsamples the error is their mean and the other entries come from
        the subsample with the highest error.
    """

    def _evaluate(probs, labels):
        # Top-class UC result for one (sub)dataset.
        n_samples_data, n_classes = probs.shape
        labels_one_hot = _labels_to_one_hot_jax(jnp.asarray(labels), n_classes)
        return _top_class_uc_core_jit(
            jnp.asarray(probs), labels_one_hot, n_classes, n_samples_data
        )

    if not use_subsampling:
        return _evaluate(probabilities_np, true_labels_indices_np)

    data_pairs = _subsample_data(
        probabilities_np, true_labels_indices_np, max_samples, num_subsamples, seed
    )
    per_subsample = [_evaluate(probs, labels) for probs, labels in data_pairs]
    return _average_results(per_subsample)


def _class_wise_uc_worst_class(probabilities_np, true_labels_indices_np):
    """Evaluate every class on one dataset and return the worst class's result.

    Returns:
        Tuple ``(max_error, worst_class_index, v_min, v_max, param_vec)``.
    """
    n_samples_data, n_classes = probabilities_np.shape
    probabilities_jnp = jnp.asarray(probabilities_np)
    true_labels_one_hot_jnp = _labels_to_one_hot_jax(
        jnp.asarray(true_labels_indices_np), n_classes
    )

    per_class_results = [
        _class_wise_uc_single_core_jit(
            c_idx_py,
            probabilities_jnp,
            true_labels_one_hot_jnp,
            n_classes,
            n_samples_data,
        )
        for c_idx_py in range(n_classes)
    ]

    all_errors_jnp = jnp.array([result[0] for result in per_class_results])
    max_error = jnp.max(all_errors_jnp)
    # Plain Python int so it can index the per-class result list.
    worst_class_idx_val = jnp.argmax(all_errors_jnp).item()
    _, v_min, v_max, param_vec = per_class_results[worst_class_idx_val]
    return (
        max_error,
        jnp.asarray(worst_class_idx_val),
        v_min,
        v_max,
        param_vec,
    )


def calculate_uc_class_wise(
    probabilities_np,
    true_labels_indices_np,
    use_subsampling=True,
    max_samples=1000,
    num_subsamples=5,
    seed=42,
):
    """
    Calculate utility calibration error for class-wise prediction.

    Args:
        probabilities_np: numpy array of shape (n_samples, n_classes)
        true_labels_indices_np: numpy array of shape (n_samples,)
        use_subsampling: if True, subsample large datasets (default: True)
        max_samples: maximum number of samples to use when subsampling (default: 1000)
        num_subsamples: number of subsamples to average over (default: 5)
        seed: random seed for subsampling (default: 42)

    Returns:
        Tuple ``(max_error, worst_class_index, v_min, v_max, param_vec)`` for
        the class with the largest UC error. With several subsamples the
        error is the mean of the per-subsample maxima and the other entries
        come from the subsample with the highest error.
    """
    if not use_subsampling:
        return _class_wise_uc_worst_class(probabilities_np, true_labels_indices_np)

    subsamples = _subsample_data(
        probabilities_np, true_labels_indices_np, max_samples, num_subsamples, seed
    )
    results = [
        _class_wise_uc_worst_class(probs, labels) for probs, labels in subsamples
    ]
    return _average_results(results)


def _uc_top_k_overall_single(probabilities_np, true_labels_indices_np):
    """
    Evaluate the top-k utility calibration kernel for every k on one dataset.

    Args:
        probabilities_np: array of shape (n_samples, n_classes)
        true_labels_indices_np: array of shape (n_samples,)

    Returns:
        Tuple (max_error, worst_k, v_min, v_max, example_vec_u) taken from the
        k in 1..n_classes whose kernel error is largest. When n_classes is 0
        the placeholder (0.0, 0, 0.0, 0.0, empty array) is returned.
    """
    n_samples_data, n_classes = probabilities_np.shape
    if n_classes == 0:  # Nothing to rank; bail out before any JAX work
        return 0.0, 0, 0.0, 0.0, jnp.array([])

    true_labels_jnp = jnp.asarray(true_labels_indices_np)
    probabilities_jnp = jnp.asarray(probabilities_np)
    # Labels are round-tripped through a one-hot encoding exactly as before,
    # so the class indices handed to the kernel are unchanged by this refactor.
    true_labels_one_hot_jnp = _labels_to_one_hot_jax(true_labels_jnp, n_classes)
    true_class_indices_jnp = jnp.argmax(true_labels_one_hot_jnp, axis=1)

    # Pre-compute ranks once; they are shared by every k.
    ranks_for_samples_0idx = _compute_ranks_jit(probabilities_jnp, n_classes)

    # k runs over 1..n_classes inclusive.
    k_values_jnp = jnp.arange(1, n_classes + 1)

    # Only k_val (first argument) is mapped; the remaining arguments are
    # passed through unbatched (in_axes=None), so n_classes and n_samples_data
    # reach the kernel as plain Python ints.
    vmapped_kernel = jax.vmap(
        _calculate_uc_metrics_for_k_kernel_jit,
        in_axes=(0, None, None, None, None, None),
        out_axes=0,
    )
    all_errors, all_vmins, all_vmaxs, all_example_vec_u = vmapped_kernel(
        k_values_jnp,
        ranks_for_samples_0idx,
        probabilities_jnp,
        true_class_indices_jnp,
        n_classes,
        n_samples_data,
    )

    # With n_classes >= 1 and out_axes=0 the outputs always have a non-empty
    # leading axis, so no further emptiness check is required here.
    argmax_k_idx = jnp.argmax(all_errors)
    return (
        jnp.max(all_errors),
        k_values_jnp[argmax_k_idx],  # JAX scalar; callers may use .item()
        all_vmins[argmax_k_idx],
        all_vmaxs[argmax_k_idx],
        all_example_vec_u[argmax_k_idx],
    )


def calculate_uc_top_k_overall(
    probabilities_np,
    true_labels_indices_np,
    use_subsampling=True,
    max_samples=1000,
    num_subsamples=5,
    seed=42,
):
    """
    Calculate utility calibration error for top-k prediction.

    Args:
        probabilities_np: numpy array of shape (n_samples, n_classes)
        true_labels_indices_np: numpy array of shape (n_samples,)
        use_subsampling: if True, subsample large datasets (default: True)
        max_samples: maximum number of samples to use when subsampling (default: 1000)
        num_subsamples: number of subsamples to average over (default: 5)
        seed: random seed for subsampling (default: 42)

    Returns:
        Without subsampling: the tuple
        (max_error, worst_k, v_min, v_max, example_vec_u) for the k with the
        largest error. With subsampling: whatever _average_results produces
        from the list of such per-subsample tuples.
    """
    if not use_subsampling:
        return _uc_top_k_overall_single(probabilities_np, true_labels_indices_np)

    subsamples = _subsample_data(
        probabilities_np, true_labels_indices_np, max_samples, num_subsamples, seed
    )
    # Both branches now share one code path, so the zero-class guard also
    # protects every subsample.
    results = [
        _uc_top_k_overall_single(probs, labels) for probs, labels in subsamples
    ]
    return _average_results(results)


# --- Plotting Function (uses NumPy) ---
def plot_combined_uc_ecdfs(errors_dict, title_prefix):
    """
    Overlay one empirical CDF curve per scenario on a single figure and show it.

    Args:
        errors_dict: mapping from scenario name to a sequence of error values;
            scenarios are drawn in sorted-name order and empty ones are
            reported on stdout and skipped.
        title_prefix: text used as the figure title (and in the skip message).
    """
    plt.figure(figsize=(10, 7))

    for label in sorted(errors_dict):
        values = np.array(errors_dict[label])
        count = len(values)
        if count == 0:
            print(f"No errors to plot for {title_prefix} scenario: {label}.")
            continue

        # ECDF: sorted values on x, cumulative fraction i/count on y.
        sorted_values = np.sort(values)
        cumulative_fraction = np.arange(1, count + 1) / count
        curve_style = {
            "marker": ".",
            "linestyle": "-",
            "label": label,
            "markersize": 4,
            "alpha": 0.7,
        }
        plt.plot(sorted_values, cumulative_fraction, **curve_style)

    # Axis labelling and cosmetics.
    plt.xlabel("Utility Calibration Error")
    plt.ylabel("Cumulative Probability (ECDF)")
    plt.title(f"{title_prefix}")
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.legend(title="Scenario", loc="best")
    plt.xlim(left=0)
    plt.ylim(0, 1.05)
    plt.show()


# --- Main Execution Block ---
if __name__ == "__main__":
    # Demo driver: builds two synthetic scenarios, prints the specific utility
    # calibration metrics for each, then plots ECDFs of sampled linear and
    # rank-based utility calibration errors.

    # Local import: only this entry point needs a process-independent hash.
    import zlib

    master_key = jax.random.PRNGKey(0)
    N_CLASSES = 10
    N_SAMPLES = 2000
    NUM_UTILITY_SAMPLES_FOR_DIST = (
        50  # Number of utility vectors to sample for ECDF plots
    )

    configurations = {
        "Low Cal. Err Scenario (Noise 0.2, Alpha 1.0)": {
            "dirichlet_concentration": 1.0,
            "logit_noise_scale": 0.2,
        },
        "High Cal. Err Scenario (Noise 2.0, Alpha 0.1)": {
            "dirichlet_concentration": 0.1,
            "logit_noise_scale": 2.0,
        },
    }

    print(f"Testing with N_CLASSES={N_CLASSES}, N_SAMPLES={N_SAMPLES}")
    print(f"Sampling {NUM_UTILITY_SAMPLES_FOR_DIST} utility vectors for ECDF plots.\n")

    # Dictionaries to store error distributions for plotting
    all_linear_errors_for_plot = {}
    all_rank_errors_for_plot = {}

    for scenario_name, config in configurations.items():
        # The built-in hash() of a str is salted per interpreter process, so it
        # would give a different "seed" on every run. CRC32 of the name is
        # stable across runs. The modulus keeps scenario_seed + 1 (used below)
        # within the unsigned 32-bit range.
        scenario_seed = zlib.crc32(scenario_name.encode("utf-8")) % (2**32 - 1)
        print(f"\n{'='*70}")
        print(f"RUNNING: {scenario_name} (Seed: {scenario_seed})")
        print(f"  Dirichlet Concentration: {config['dirichlet_concentration']}")
        print(f"  Logit Noise Scale: {config['logit_noise_scale']}")
        print(f"{'='*70}\n")

        # Generate data
        # NOTE: generate_synthetic_data_numpy draws from NumPy's global RNG,
        # which is not seeded here, so the synthetic data itself still varies
        # between runs.
        model_logits_np, true_labels_np = generate_synthetic_data_numpy(
            N_SAMPLES,
            N_CLASSES,
            config["dirichlet_concentration"],
            config["logit_noise_scale"],
        )
        model_probabilities_np = softmax_np(model_logits_np, axis=1)
        print(f"Generated {N_SAMPLES} NumPy samples.")

        # --- Calculate and Print Specific Utility Calibration Metrics ---
        print("\nCalculating specific utility calibration metrics...")
        start_time_metrics = time.time()
        uc_top_class_err, _, _, _ = calculate_uc_top_class(
            model_probabilities_np, true_labels_np
        )
        print(f"  Utility Top-Class Error: {float(uc_top_class_err):.4f}")
        uc_class_wise_err, _, _, _, _ = calculate_uc_class_wise(
            model_probabilities_np, true_labels_np
        )
        print(f"  Utility Class-Wise Error (sup_c): {float(uc_class_wise_err):.4f}")
        uc_top_k_overall_err, _, _, _, _ = calculate_uc_top_k_overall(
            model_probabilities_np, true_labels_np
        )
        print(
            f"  Utility Top-K Overall Error (sup_K): {float(uc_top_k_overall_err):.4f}"
        )
        print(
            f"Specific metrics calculation took {time.time() - start_time_metrics:.2f} seconds."
        )

        # --- Calculate Utility Calibration Error Distributions for ECDF plots ---
        print("\nCalculating error distributions for ECDF plots...")
        start_time_dist = time.time()

        # Linear Utility Distribution
        # A fresh subkey is split off for each sampler so the two distributions
        # do not reuse the same JAX PRNG key.
        master_key, subkey_linear = jax.random.split(master_key)
        linear_errors_dist_jnp, max_lin_err, _, _, _ = calculate_uc_linear_distribution(
            subkey_linear,
            model_probabilities_np,
            true_labels_np,
            num_utility_samples=NUM_UTILITY_SAMPLES_FOR_DIST,
            utility_batch_size=50,  # Example: Can be tuned
            use_subsampling=True,  # Example: Control whether to allow data subsampling
            data_subsample_trigger_size=1000,  # Example: Lower trigger for testing
            data_subsample_max_samples=500,  # Example: Smaller subsamples for testing
            data_num_subsamples=3,  # Example: Fewer subsamples for testing
            data_subsampling_seed=scenario_seed,  # Use scenario-specific seed
        )
        all_linear_errors_for_plot[scenario_name] = np.array(linear_errors_dist_jnp)
        print(f"  Max Linear Utility Error from sampling: {float(max_lin_err):.4f}")

        # Rank-Based Utility Distribution
        master_key, subkey_rank = jax.random.split(master_key)
        rank_errors_dist_jnp, max_rank_err, _, _, _ = (
            calculate_uc_rank_based_distribution(
                subkey_rank,
                model_probabilities_np,
                true_labels_np,
                num_utility_samples=NUM_UTILITY_SAMPLES_FOR_DIST,
                utility_batch_size=50,  # Example
                use_subsampling=True,  # Example
                data_subsample_trigger_size=1000,  # Example
                data_subsample_max_samples=500,  # Example
                data_num_subsamples=3,  # Example
                data_subsampling_seed=scenario_seed
                + 1,  # Vary seed slightly for rank if desired
            )
        )
        all_rank_errors_for_plot[scenario_name] = np.array(rank_errors_dist_jnp)
        print(
            f"  Max Rank-Based Utility Error from sampling: {float(max_rank_err):.4f}"
        )
        print(
            f"Error distribution calculations took {time.time() - start_time_dist:.2f} seconds."
        )

    # --- Generate ECDF Plots ---
    print("\n\nPlotting ECDF of Utility Calibration Errors...")
    if all_linear_errors_for_plot:
        plot_combined_uc_ecdfs(
            all_linear_errors_for_plot,
            f"ECDF of Linear Utility Calibration Errors ({N_SAMPLES} samples, {N_CLASSES} classes)",
        )
    if all_rank_errors_for_plot:
        plot_combined_uc_ecdfs(
            all_rank_errors_for_plot,
            f"ECDF of Rank-Based Utility Calibration Errors ({N_SAMPLES} samples, {N_CLASSES} classes)",
        )

    print("\n\n--- Comparison Test with ECDF Plots Finished ---")
