import numpy as np


def get_sorted(values, reference):
    """
    Sort one array according to the sorted order of another array.

    Both inputs are flattened to 1-D before the ordering is applied.

    Parameters:
    values : array-like
        Values to reorder.
    reference : array-like
        Reference array whose ascending order defines the permutation.

    Returns:
    tuple of np.ndarray
        ``values`` reordered by the reference order, and the sorted
        reference (both 1-D).
    """
    # np.asarray accepts plain lists/tuples as well as ndarrays, so both
    # arguments are handled the same way.
    flat_reference = np.asarray(reference).flatten()
    order = np.argsort(flat_reference)
    return np.asarray(values).flatten()[order], flat_reference[order]


# =================================
# Binary Classification Metrics
# =================================

def expected_calibration_error(probabilities, labels, n_bins=10):
    """
    Compute ECE for binary classification.

    Samples are grouped into ``n_bins`` equal-width bins over [0, 1] by
    their predicted probability. For every non-empty bin the absolute gap
    between the mean label and the mean predicted probability is weighted
    by the fraction of samples in that bin, and the weighted gaps are summed.

    Parameters:
    probabilities : np.ndarray
        Predicted probabilities (n_samples,).
    labels : np.ndarray
        True binary labels (n_samples,).
    n_bins : int
        Number of bins for calibration.

    Returns:
    float
        Expected Calibration Error (ECE). 0.0 when there are no samples.
    """
    probabilities = np.asarray(probabilities).flatten()
    labels = np.asarray(labels).flatten()
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_indices = np.digitize(probabilities, bin_edges, right=True)
    # With right=True a value x gets index i when
    # bin_edges[i-1] < x <= bin_edges[i], so x == 0.0 gets index 0, which
    # the loop below never visits. Fold index 0 into the first bin so those
    # samples contribute instead of being silently dropped.
    bin_indices[bin_indices == 0] = 1

    ece = 0.0
    n = len(probabilities)

    for i in range(1, n_bins + 1):
        mask = (bin_indices == i)
        bin_count = np.sum(mask)
        if bin_count > 0:
            avg_confidence = np.mean(probabilities[mask])
            avg_accuracy = np.mean(labels[mask])
            ece += (bin_count / n) * np.abs(avg_accuracy - avg_confidence)

    return ece

def negative_log_likelihood(y, y_hat, classification=True, binary=True, p=1):
    """
    Compute the Negative Log-Likelihood (NLL).

    Three modes are supported:

    * binary classification (``classification=True, binary=True``): mean
      binary cross-entropy over the flattened inputs; labels are cast to int.
    * multiclass classification (``classification=True, binary=False``):
      mean over rows of ``-sum(y * log(y_hat))`` taken along axis 1.
    * regression (``classification=False``): Gaussian NLL summed over all
      samples, with the variance estimated as the residual sum of squares
      divided by ``n - p``.

    Parameters:
    y : array-like
        Ground truth labels.
    y_hat : array-like
        Predicted values or probabilities.
    classification : bool
        Whether this is a classification task.
    binary : bool
        Whether classification is binary.
    p : int
        Number of parameters (used for regression NLL).

    Returns:
    float
        Negative log-likelihood.
    """
    # Convert up front so lists/tuples work as the docstring promises.
    y, y_hat = np.asarray(y), np.asarray(y_hat)

    if classification and binary:
        y, y_hat = y.flatten(), y_hat.flatten()
        y = y.astype(int)
        # Clip predictions away from 0 and 1 so log() stays finite.
        epsilon = 1e-6
        y_hat = np.clip(y_hat, epsilon, 1 - epsilon)
        nll = -np.mean(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

    elif classification:
        epsilon = 1e-6
        y_hat = np.clip(y_hat, epsilon, 1 - epsilon)
        nll = -np.mean(np.sum(y * np.log(y_hat), axis=1))

    else:
        y, y_hat = y.flatten(), y_hat.flatten()
        n = len(y)
        residuals = y - y_hat
        sum_squared_residuals = np.sum(residuals**2)
        # Variance estimate uses n - p in the denominator.
        sigma_squared = sum_squared_residuals / (n - p)
        nll = (n / 2) * np.log(2 * np.pi * sigma_squared) + (sum_squared_residuals / (2 * sigma_squared))

    return nll

def binary_accuracy(y, p_hat):
    """
    Compute binary classification accuracy.

    A prediction counts as the positive class when its probability is
    strictly greater than 0.5. Both inputs are flattened before comparison.

    Parameters:
    y : array-like
        True binary labels.
    p_hat : array-like
        Predicted probabilities.

    Returns:
    float
        Accuracy score.
    """
    # np.asarray lets plain lists/tuples through, as the docstring promises.
    y = np.asarray(y).flatten()
    p_hat = np.asarray(p_hat).flatten()
    return np.mean((p_hat > 0.5) == y)

def binary_entropy(p_hat):
    """
    Compute the average entropy of binary predictions.

    Each probability p is treated as the two-outcome distribution
    (1 - p, p); its entropy (natural log) is computed and the mean over
    all predictions is returned.

    Parameters:
    p_hat : array-like
        Predicted probabilities.

    Returns:
    float
        Entropy of the prediction distribution.
    """
    # np.asarray lets plain lists/tuples through, as the docstring promises.
    p_hat = np.asarray(p_hat).flatten()
    # Clip away from 0 and 1 so log() stays finite.
    epsilon = 1e-6
    p_hat = np.clip(p_hat, epsilon, 1.0 - epsilon)
    probabilities = np.stack([1 - p_hat, p_hat], axis=1)
    entropy = -np.sum(probabilities * np.log(probabilities), axis=1)
    return np.mean(entropy)

def avg_prediction(p_hat):
    """
    Return the mean of all predicted probabilities.

    Parameters:
    p_hat : array-like
        Predicted probabilities; flattened before averaging.

    Returns:
    float
        Mean predicted value.
    """
    flattened = p_hat.flatten()
    return flattened.mean()

def binary_metrics(y, p_hat, p_true=None, n_bins=10):
    """
    Compute a set of evaluation metrics for binary classification.

    Parameters:
    y : array-like
        True binary labels.
    p_hat : array-like
        Predicted probabilities.
    p_true : array-like, optional
        If given, calibration is measured against these values instead of
        the labels ``y``.
    n_bins : int
        Number of bins used for the calibration error.

    Returns:
    tuple of (accuracy, ECE, NLL, entropy, average prediction)
    """
    nll = negative_log_likelihood(y, p_hat)
    binary_acc = binary_accuracy(y, p_hat)
    # expected_calibration_error takes (probabilities, labels): the
    # predictions go first so that binning is done on p_hat, and the target
    # (labels or p_true) is what each bin's mean prediction is compared to.
    calibration_target = y if p_true is None else p_true
    ece = expected_calibration_error(p_hat, calibration_target, n_bins)
    entropy = binary_entropy(p_hat)
    average = avg_prediction(p_hat)
    return binary_acc, ece, nll, entropy, average

# =================================
# Multiclass Classification Metrics
# =================================

def balance_dataset(Xs, Ys, class_weights):
    """
    Resample a dataset so its class proportions follow ``class_weights``.

    The per-class target count is the normalized weight times the original
    number of samples, rounded to the nearest integer. A class with enough
    samples is subsampled without replacement; a class with too few is
    sampled with replacement. The result is shuffled.

    Parameters:
    Xs : np.ndarray
        Feature data (e.g., images), shape (N, ...).
    Ys : np.ndarray
        One-hot encoded labels, shape (N, P).
    class_weights : np.ndarray
        Desired class distribution (length P).

    Returns:
    tuple
        Balanced X and Y arrays.
    """
    n_classes = class_weights.shape[0]
    labels = np.argmax(Ys, axis=1)
    n_total = len(Ys)
    normalized_weights = class_weights / class_weights.sum()
    per_class_target = np.round(normalized_weights * n_total).astype(int)

    picked_Xs = []
    picked_Ys = []
    for class_idx in range(n_classes):
        members = labels == class_idx
        pool_Xs, pool_Ys = Xs[members], Ys[members]
        pool_size = len(pool_Xs)
        wanted = per_class_target[class_idx]

        # Only fall back to sampling with replacement when the class does
        # not hold enough samples to meet its target.
        with_replacement = bool(pool_size < wanted)
        picks = np.random.choice(pool_size, wanted, replace=with_replacement)
        picked_Xs.append(pool_Xs[picks])
        picked_Ys.append(pool_Ys[picks])

    merged_Xs = np.concatenate(picked_Xs, axis=0)
    merged_Ys = np.concatenate(picked_Ys, axis=0)
    # One shared permutation keeps features and labels aligned.
    perm = np.random.permutation(len(merged_Xs))
    return merged_Xs[perm], merged_Ys[perm]

def multiclass_expected_calibration_error(y, p_hat, num_bins=10):
    """
    Compute Expected Calibration Error for multi-class classification.

    The confidence of a sample is its largest predicted probability.
    Samples are grouped into ``num_bins`` equal-width, half-open bins
    (lower, upper] over [0, 1]; each non-empty bin contributes its sample
    fraction times the gap between its accuracy and its mean confidence.

    Parameters:
    y : np.ndarray
        One-hot encoded true labels.
    p_hat : np.ndarray
        Softmax probabilities.
    num_bins : int
        Number of confidence bins.

    Returns:
    float
        ECE value.
    """
    true_classes = np.argmax(y, axis=1)
    predicted_classes = np.argmax(p_hat, axis=1)
    top_probabilities = np.max(p_hat, axis=1)
    n_samples = len(true_classes)

    edges = np.linspace(0, 1, num_bins + 1)
    total = 0.0
    for lower, upper in zip(edges[:-1], edges[1:]):
        members = (top_probabilities > lower) & (top_probabilities <= upper)
        count = np.sum(members)
        if count == 0:
            continue
        hit_rate = np.mean(true_classes[members] == predicted_classes[members])
        mean_confidence = np.mean(top_probabilities[members])
        total += (count / n_samples) * np.abs(hit_rate - mean_confidence)

    return total

def multiclass_accuracy(y, p_hat):
    """
    Compute accuracy for multi-class classification.

    The predicted class is the argmax of each row of ``p_hat``; the true
    class is the argmax of each row of ``y``.

    Parameters:
    y : np.ndarray
        One-hot encoded true labels.
    p_hat : np.ndarray
        Predicted softmax probabilities.

    Returns:
    float
        Fraction of rows whose predicted class matches the true class.
    """
    hits = np.argmax(y, axis=1) == np.argmax(p_hat, axis=1)
    return np.mean(hits)

def multiclass_metrics(y, p_hat, n_bins=10):
    """
    Compute evaluation metrics for multi-class classification.

    Parameters:
    y : np.ndarray
        One-hot encoded true labels.
    p_hat : np.ndarray
        Predicted class probabilities.
    n_bins : int
        Number of bins passed to the calibration error.

    Returns:
    tuple
        (Accuracy, ECE, NLL)
    """
    log_loss = negative_log_likelihood(
        y, p_hat, classification=True, binary=False
    )
    calibration_error = multiclass_expected_calibration_error(
        y, p_hat, num_bins=n_bins
    )
    accuracy = multiclass_accuracy(y, p_hat)
    return accuracy, calibration_error, log_loss

# =================================
# Regression Metrics
# =================================

def rmse_fn(y_true, y_hat):
    """
    Compute Root Mean Square Error (RMSE).

    Both inputs are flattened before the element-wise difference is taken.

    Parameters:
    y_true : np.ndarray
        Ground truth values.
    y_hat : np.ndarray
        Predicted values.

    Returns:
    float
        RMSE value.
    """
    errors = y_true.flatten() - y_hat.flatten()
    mean_squared_error = np.mean(errors ** 2)
    return np.sqrt(mean_squared_error)

def regression_metrics(y, y_hat):
    """
    Compute NLL and RMSE for regression outputs.

    Parameters:
    y : array-like
        Ground truth target values.
    y_hat : array-like
        Predicted values.

    Returns:
    tuple
        (NLL, RMSE)
    """
    # classification=False selects the regression branch of
    # negative_log_likelihood; its defaults would otherwise compute binary
    # cross-entropy on the regression outputs.
    nll = negative_log_likelihood(y, y_hat, classification=False)
    rmse = rmse_fn(y, y_hat)
    return nll, rmse