"""Compute the regrets RCL and RGL, and bounds on RGL."""

import warnings

import numpy as np
from scipy.stats import norm, beta, truncnorm


def regret_over(r, c, t, sigma):
    """Compute the expected excess ``E[(X - t)^+]`` for ``X ~ N(r + c, sigma**2)``.

    Uses the closed form ``g * Phi(g / sigma) + sigma * phi(g / sigma)`` with
    ``g = r + c - t``, where ``Phi`` / ``phi`` are the standard normal CDF / PDF.

    Parameters
    ----------
    r, c : float or np.ndarray
        Summed to give the mean of the normal distribution.
    t : float or np.ndarray
        The threshold.
    sigma : float or np.ndarray
        The standard deviation of the normal distribution. ``sigma == 0`` is
        treated as a point mass at ``r + c``, giving ``max(r + c - t, 0)``.

    Returns
    -------
    regret : float or np.ndarray
        The expected excess, broadcast over all inputs.
    """
    gap = np.asarray(r + c - t, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    # sigma == 0 makes the closed form divide by zero (0/0 = NaN when gap == 0);
    # silence those warnings here, the entries are replaced just below.
    with np.errstate(divide="ignore", invalid="ignore"):
        z = gap / sigma
        smooth = gap * norm.cdf(z) + sigma * norm.pdf(z)
    regret = np.where(sigma == 0, np.maximum(gap, 0.0), smooth)
    return regret[()]  # 0-d array -> NumPy scalar for scalar inputs

def regret_under(r, c, t, sigma):
    """Compute the expected shortfall ``E[(t - X)^+]`` for ``X ~ N(r + c, sigma**2)``.

    Uses the closed form ``g * Phi(g / sigma) + sigma * phi(g / sigma)`` with
    ``g = t - (r + c)``, where ``Phi`` / ``phi`` are the standard normal CDF / PDF.

    Parameters
    ----------
    r, c : float or np.ndarray
        Summed to give the mean of the normal distribution.
    t : float or np.ndarray
        The threshold.
    sigma : float or np.ndarray
        The standard deviation of the normal distribution. ``sigma == 0`` is
        treated as a point mass at ``r + c``, giving ``max(t - (r + c), 0)``.

    Returns
    -------
    regret : float or np.ndarray
        The expected shortfall, broadcast over all inputs.
    """
    gap = np.asarray(t - (r + c), dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    # sigma == 0 makes the closed form divide by zero (0/0 = NaN when gap == 0);
    # silence those warnings here, the entries are replaced just below.
    with np.errstate(divide="ignore", invalid="ignore"):
        z = gap / sigma
        smooth = gap * norm.cdf(z) + sigma * norm.pdf(z)
    regret = np.where(sigma == 0, np.maximum(gap, 0.0), smooth)
    return regret[()]  # 0-d array -> NumPy scalar for scalar inputs

def regret_over_truncated(r, c, t, sigma):
    """Compute the expected excess above ``t`` under a normal truncated to [0, 1].

    ``X`` follows ``N(r + c, sigma**2)`` truncated to ``[0, 1]``, and the
    returned value is ``E[(X - t) * 1{X >= t}]``. This is an unconditional
    expectation: ``truncnorm.expect`` integrates from ``lb=t`` to the upper
    end of the support.

    Parameters
    ----------
    r : float
        The estimated probability of the sample.
    c : float
        The calibrated score of the sample.
    t : float
        The threshold t.
    sigma : float
        The standard deviation of the (untruncated) normal distribution.
        ``sigma == 0`` is treated as the limiting point mass at ``r + c``
        clipped into ``[0, 1]``.

    Returns
    -------
    regret : float
        The expected excess above ``t``.
    """
    loc = r + c
    lower, upper = 0.0, 1.0

    if sigma == 0:
        # As sigma -> 0 the truncated normal concentrates on loc clipped into
        # [lower, upper]; numerical integration cannot resolve that spike.
        point = min(max(loc, lower), upper)
        return max(point - t, 0.0)

    # sigma already is a standard deviation, so it is the scale as-is.
    return truncnorm.expect(
        lambda x: x - t,
        args=((lower - loc) / sigma, (upper - loc) / sigma),  # standardized bounds
        loc=loc,
        scale=sigma,
        lb=t,
    )


def regret_under_truncated(r, c, t, sigma):
    """Compute the expected shortfall below ``t`` under a normal truncated to [0, 1].

    ``X`` follows ``N(r + c, sigma**2)`` truncated to ``[0, 1]``, and the
    returned value is ``E[(t - X) * 1{X <= t}]``. This is an unconditional
    expectation: ``truncnorm.expect`` integrates from the lower end of the
    support up to ``ub=t``.

    Parameters
    ----------
    r : float
        The estimated probability of the sample.
    c : float
        The calibrated score of the sample.
    t : float
        The threshold t.
    sigma : float
        The standard deviation of the (untruncated) normal distribution.
        ``sigma == 0`` is treated as the limiting point mass at ``r + c``
        clipped into ``[0, 1]``.

    Returns
    -------
    regret : float
        The expected shortfall below ``t``.
    """
    # The truncation is from 0 to 1, so the standardized bounds depend on loc and sigma.
    loc = r + c
    lower, upper = 0.0, 1.0

    if sigma == 0:
        # As sigma -> 0 the truncated normal concentrates on loc clipped into
        # [lower, upper]; numerical integration cannot resolve that spike.
        point = min(max(loc, lower), upper)
        return max(t - point, 0.0)

    # sigma already is a standard deviation, so it is the scale as-is.
    return truncnorm.expect(
        lambda x: t - x,
        args=((lower - loc) / sigma, (upper - loc) / sigma),  # standardized bounds
        loc=loc,
        scale=sigma,
        ub=t,
    )

def regret_under_beta(t, alpha, beta_param):
    """Compute the expected shortfall ``E[(t - X)^+]`` for ``X ~ Beta(alpha, beta_param)``.

    Splits the expectation as ``t * P(X <= t) - E[X * 1{X <= t}]``. The
    partial first moment of a beta distribution equals its mean times the CDF
    of ``Beta(alpha + 1, beta_param)``, so everything reduces to two CDF calls.

    Parameters
    ----------
    t : float or np.ndarray
        The threshold t.
    alpha : float or np.ndarray
        The alpha parameter of the beta distribution.
    beta_param : float or np.ndarray
        The beta parameter of the beta distribution.

    Returns
    -------
    regret : float or np.ndarray
        The expected shortfall below ``t`` (broadcast over the inputs).
    """
    below_mass = beta.cdf(t, alpha, beta_param)
    mean = alpha / (alpha + beta_param)
    shifted_mass = beta.cdf(t, alpha + 1, beta_param)
    return t * below_mass - mean * shifted_mass


def regret_over_beta(t, alpha, beta_param):
    """Compute the expected excess ``E[(X - t)^+]`` for ``X ~ Beta(alpha, beta_param)``.

    Splits the expectation as ``E[X * 1{X > t}] - t * P(X > t)``. The partial
    first moment is the mean times the survival function of
    ``Beta(alpha + 1, beta_param)``.

    Parameters
    ----------
    t : float or np.ndarray
        The threshold t.
    alpha : float or np.ndarray
        The alpha parameter of the beta distribution.
    beta_param : float or np.ndarray
        The beta parameter of the beta distribution.

    Returns
    -------
    regret : float or np.ndarray
        The expected excess above ``t`` (broadcast over the inputs).
    """
    mean = alpha / (alpha + beta_param)
    upper_moment = mean * beta.sf(t, alpha + 1, beta_param)
    upper_mass = beta.sf(t, alpha, beta_param)
    return upper_moment - t * upper_mass

def beta_params_from_mean_var(mean, variance):
    """Derive beta distribution parameters by matching a mean and a variance.

    Method of moments. Writing ``s = alpha + beta``, a beta distribution has
    ``mean = alpha / s`` and ``variance = mean * (1 - mean) / (s + 1)``. Hence
    ``s = mean * (1 - mean) / variance - 1``, ``alpha = mean * s`` and
    ``beta = (1 - mean) * s``.

    The mean is first clipped to ``[0.02, 0.98]`` and ``1e-13`` is added to the
    variance so the division is always defined. The parameters are positive
    only when the variance is below ``mean * (1 - mean)`` (clipped mean, up to
    that small offset). Otherwise they are zero or negative, which is not a
    valid beta distribution.

    Parameters
    ----------
    mean : float or np.ndarray
        The mean of the beta distribution.
    variance : float or np.ndarray
        The variance of the beta distribution.

    Returns
    -------
    alpha : float or np.ndarray
        The alpha parameter of the beta distribution.
    beta : float or np.ndarray
        The beta parameter of the beta distribution.
    """
    margin = 2 * 1e-2
    # Keep the mean strictly inside (0, 1) so mean * (1 - mean) stays positive.
    clipped_mean = np.clip(mean, margin, 1 - margin)
    concentration = clipped_mean * (1 - clipped_mean) / (variance + 10e-14) - 1
    alpha = clipped_mean * concentration
    beta_param = (1 - clipped_mean) * concentration
    return alpha, beta_param


def compute_regret_CL(C: np.ndarray, t: np.ndarray, a: np.ndarray) -> np.ndarray:
    """Compute RCL.

    The reference action for a sample and threshold is ``1`` when the
    calibrated score reaches the threshold (``C >= t``) and ``0`` otherwise.
    The regret is ``|C - t|`` where the taken action differs from that
    reference, and ``0`` where they agree.

    Parameters
    ----------
    C : np.ndarray of shape (n,)
        The calibrated scores of each samples.
    t : np.ndarray of shape (k,)
        The thresholds t* derived from the utilities.
    a : np.ndarray of shape (n, k)
        The action taken on each sample.

    Returns
    -------
    RCL : np.ndarray of shape (n, k)
        The regret of the estimated probabilities to the calibrated scores
        (always float64).
    """
    C = np.atleast_1d(C)  # (n,)
    t = np.atleast_1d(t)  # (k,)
    a = np.atleast_2d(a)  # (n, k)

    assert C.ndim == 1
    assert t.ndim == 1
    assert a.shape == (C.shape[0], t.shape[0])

    scores = C[:, None]  # (n, 1)
    thresholds = t[None, :]  # (1, k)
    reference_action = (scores >= thresholds).astype(int)  # (n, k)
    distance = np.abs(scores - thresholds)  # (n, k)
    disagree = a != reference_action  # (n, k)
    return np.where(disagree, distance, 0.0).astype(np.float64)  # (n, k)


def compute_regret_CL_beta(C, t, a, v, r):
    """Compute RCL with beta distribution assumption.

    For each sample ``i``, ``r[i] + C[i]`` (as the mean) and ``v[i]`` (as the
    variance) are converted into beta parameters by
    ``beta_params_from_mean_var``. For each threshold ``t[j]`` the regret is:

    - ``regret_under_beta``, i.e. ``E[(t - X)^+]``, when ``C[i] >= t[j]``;
    - ``regret_over_beta``, i.e. ``E[(X - t)^+]``, otherwise.

    Parameters
    ----------
    C : np.ndarray of shape (n,)
        The calibrated scores of each samples.
    t : np.ndarray of shape (k,)
        The thresholds t* derived from the utilities.
    a : np.ndarray of shape (n, k)
        The action taken on each sample. Only its shape is validated; the
        values do not enter the computation.
    v : np.ndarray of shape (n,)
        The variance for each sample.
    r : np.ndarray of shape (n,)
        The estimated probabilities for each sample.

    Returns
    -------
    RCL : np.ndarray of shape (n, k)
        The regret of the estimated probabilities to the calibrated scores.

    Warns
    -----
    RuntimeWarning
        If any entry of the result is NaN. For instance, a variance too large
        for the mean yields non-positive beta parameters.
    """
    C = np.atleast_1d(C)  # (n,)
    t = np.atleast_1d(t)  # (k,)
    a = np.atleast_2d(a)  # (n, k)
    v = np.atleast_1d(v)  # (n,)
    r = np.atleast_1d(r)  # (n,)

    assert C.ndim == 1
    assert t.ndim == 1
    assert a.shape == (C.shape[0], t.shape[0])
    assert v.shape == C.shape
    assert r.shape == C.shape

    alpha, beta_param = beta_params_from_mean_var(r + C, v)  # (n,), (n,)

    # Samples along rows, thresholds along columns. The SciPy beta CDF/SF
    # broadcast, so the whole (n, k) grid is computed in one call per branch.
    t_row = t[None, :]  # (1, k)
    alpha_col = alpha[:, None]  # (n, 1)
    beta_col = beta_param[:, None]  # (n, 1)
    RCL = np.where(
        C[:, None] >= t_row,
        regret_under_beta(t_row, alpha_col, beta_col),
        regret_over_beta(t_row, alpha_col, beta_col),
    )  # (n, k)

    nan_rows, nan_cols = np.nonzero(np.isnan(RCL))
    if nan_rows.size:
        i, j = nan_rows[0], nan_cols[0]
        warnings.warn(
            f"NaN regret for {nan_rows.size} (sample, threshold) pair(s); "
            f"first at sample {i}, threshold {j}: alpha={alpha[i]}, "
            f"beta={beta_param[i]}, t={t[j]}, C={C[i]}, "
            f"mean={r[i] + C[i]}, variance={v[i]}",
            RuntimeWarning,
            stacklevel=2,
        )
    return RCL  # (n, k)

def compute_regret_CL_normal(C, t, a, v, r):
    """Compute RCL with normal distribution assumption.

    Each sample is modelled as ``X_i ~ N(r[i] + C[i], v[i])``. For each
    threshold ``t[j]`` the regret is:

    - ``regret_under``, i.e. ``E[(t - X)^+]``, when ``C[i] >= t[j]``;
    - ``regret_over``, i.e. ``E[(X - t)^+]``, otherwise.

    Parameters
    ----------
    C : np.ndarray of shape (n,)
        The calibrated scores of each samples.
    t : np.ndarray of shape (k,)
        The thresholds t* derived from the utilities.
    a : np.ndarray of shape (n, k)
        The action taken on each sample. Only its shape is validated; the
        values do not enter the computation.
    v : np.ndarray of shape (n,)
        The variance for each sample.
    r : np.ndarray of shape (n,)
        The estimated probabilities for each sample.

    Returns
    -------
    RCL : np.ndarray of shape (n, k)
        The regret of the estimated probabilities to the calibrated scores.
    """
    C = np.atleast_1d(C)  # (n,)
    t = np.atleast_1d(t)  # (k,)
    a = np.atleast_2d(a)  # (n, k)
    v = np.atleast_1d(v)  # (n,)
    r = np.atleast_1d(r)  # (n,)

    assert C.ndim == 1
    assert t.ndim == 1
    assert a.shape == (C.shape[0], t.shape[0])
    assert v.shape == C.shape
    assert r.shape == C.shape

    sigma = np.sqrt(v)  # (n,)

    # Samples along rows, thresholds along columns. regret_under / regret_over
    # are plain NumPy/SciPy expressions, so one broadcast call per branch
    # replaces n * k scalar calls.
    r_col = r[:, None]  # (n, 1)
    C_col = C[:, None]  # (n, 1)
    sigma_col = sigma[:, None]  # (n, 1)
    t_row = t[None, :]  # (1, k)

    RCL = np.where(
        C_col >= t_row,
        regret_under(r_col, C_col, t_row, sigma_col),
        regret_over(r_col, C_col, t_row, sigma_col),
    )
    return np.asarray(RCL, dtype=float)  # (n, k)

def compute_regret_CL_truncated_normal(C, t, a, v, r):
    """Compute RCL with truncated normal distribution assumption.

    For each sample ``i`` and threshold ``t[j]``, ``regret_under_truncated`` is
    used when ``C[i] >= t[j]`` and ``regret_over_truncated`` otherwise. Both
    receive ``r[i]``, ``C[i]``, ``t[j]`` and ``sqrt(v[i])``. The helpers work
    on scalars, so the grid is filled element by element.

    Parameters
    ----------
    C : np.ndarray of shape (n,)
        The calibrated scores of each samples.
    t : np.ndarray of shape (k,)
        The thresholds t* derived from the utilities.
    a : np.ndarray of shape (n, k)
        The action taken on each sample. Only its shape is validated.
    v : np.ndarray of shape (n,)
        The variance for each sample.
    r : np.ndarray of shape (n,)
        The estimated probabilities for each sample.

    Returns
    -------
    RCL : np.ndarray of shape (n, k)
        The regret of the estimated probabilities to the calibrated scores.
    """
    C = np.atleast_1d(C)  # (n,)
    t = np.atleast_1d(t)  # (k,)
    a = np.atleast_2d(a)  # (n, k)
    v = np.atleast_1d(v)  # (n,)
    r = np.atleast_1d(r)  # (n,)

    assert C.ndim == 1
    assert t.ndim == 1
    assert a.shape == (C.shape[0], t.shape[0])
    assert v.shape == C.shape
    assert r.shape == C.shape

    std = np.sqrt(v)  # (n,)
    regrets = np.zeros((C.shape[0], t.shape[0]))  # (n, k)

    for i, (r_i, c_i, std_i) in enumerate(zip(r, C, std)):
        for j, t_j in enumerate(t):
            regret_fn = regret_under_truncated if c_i >= t_j else regret_over_truncated
            regrets[i, j] = regret_fn(r_i, c_i, t_j, std_i)

    return regrets  # (n, k)

def compute_V_min(C: np.ndarray, t: np.ndarray):
    """Compute the minimum variance.

    For each bin score ``C`` and threshold ``t`` the value is:

    - ``C * (t - C)`` when ``C < t``;
    - ``(1 - C) * (C - t)`` when ``C >= t``.

    Parameters
    ----------
    C : np.ndarray of shape (n_bins,)
        The calibrated scores within each bin.
    t : np.ndarray of shape (k,)
        The thresholds t* derived from the utilities.

    Returns
    -------
    V_min : np.ndarray of shape (n_bins, k)
        The minimum variance within each bin.
    """
    C = np.atleast_1d(C)  # (n_bins,)
    t = np.atleast_1d(t)  # (k,)

    assert C.ndim == 1
    assert t.ndim == 1

    scores = C[:, None]  # (n_bins, 1)
    thresholds = t[None, :]  # (1, k)

    # 0/1 masks select which branch applies to each (bin, threshold) pair.
    at_or_above = np.asarray(scores >= thresholds).astype(int)  # (n_bins, k)
    below = 1 - at_or_above
    below_term = scores * (thresholds - scores) * below
    above_term = (1 - scores) * (scores - thresholds) * at_or_above
    return below_term + above_term  # (n_bins, k)


def compute_regret_GL_LB(
    C: np.ndarray,
    V: np.ndarray,
    t: np.ndarray,
    bin_counts: np.ndarray | None = None,
):
    """Compute RGL lower bound.

    Per bin and threshold the bound is ``max(V - V_min, 0)``, with ``V_min``
    given by ``compute_V_min(C, t)``.

    Parameters
    ----------
    C : np.ndarray of shape (n_bins,)
        The calibrated scores within each bin.
    V : np.ndarray of shape (n_bins,)
        The variance within each bin.
    t : np.ndarray of shape (k,)
        The thresholds t* derived from the utilities.
    bin_counts : array-like of shape (n_bins,), optional
        The number of samples within each bin, by default None. Plain lists
        are accepted as well as arrays. If given, the output is the average
        over the bins weighted by these counts, of shape (k,).

    Returns
    -------
    RGL_LB : np.ndarray of shape (n_bins, k) or (k,)
        The lower bound on RGL, per bin, or count-weighted over bins when
        ``bin_counts`` is given.
    """
    C = np.atleast_1d(C)  # (n_bins,)
    V = np.atleast_1d(V)  # (n_bins,)
    t = np.atleast_1d(t)  # (k,)

    assert C.ndim == 1
    assert C.shape == V.shape
    assert t.ndim == 1

    V_min = compute_V_min(C, t)  # (n_bins, k)
    RGL_LB = np.clip(V[:, None] - V_min, 0, None)  # (n_bins, k)

    if bin_counts is not None:
        # Coerce so list inputs work too (lists have no `.shape`).
        bin_counts = np.asarray(bin_counts)
        assert bin_counts.shape == C.shape
        RGL_LB = (bin_counts @ RGL_LB) / np.sum(bin_counts)  # (k,)

    return RGL_LB


def compute_regret_GL_UB(
    C: np.ndarray,
    V: np.ndarray,
    t: np.ndarray,
    bin_counts: np.ndarray | None = None,
):
    """Compute RGL upper bound.


    Parameters
    ----------
    C : np.ndarray of shape (n_bins,)
        The calibrated scores within each bin.
    V : np.ndarray of shape (n_bins,)
        The variance within each bin.
    t : np.ndarray of shape (k,)
        The thresholds t* derived from the utilities.
    bin_counts : np.ndarray of shape (n_bins,), optional
        The number of samples within each bin, by default None.
        If given, the output will be averaged over the bins and of shape (k,).

    Returns
    -------
    RGL_LB : np.ndarray of shape (n_bins, k) or (k,)
        The regret of the estimated probabilities to the calibrated scores.

    """
    C = np.atleast_1d(C)  # (n_bins,)
    V = np.atleast_1d(V)  # (n_bins,)
    t = np.atleast_1d(t)  # (k,)

    assert C.ndim == 1
    assert C.shape == V.shape
    assert t.ndim == 1

    V = V[:, None]
    C = C[:, None]
    t = t[None, :]

    RGL_UB = 0.5 * (np.sqrt(V + np.square(C - t)) - np.abs(C - t))

    if bin_counts is not None:
        assert bin_counts.ndim == 1
        assert bin_counts.shape[0] == C.shape[0]
        RGL_UB = np.inner(bin_counts, RGL_UB.T) / np.sum(bin_counts)

    return RGL_UB


def compute_accuracy(y: np.ndarray, S: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Compute the accuracy of thresholded predictions for each threshold.

    A sample is predicted positive when ``S >= t``. The accuracy is the
    fraction of samples whose prediction equals the label. The boolean
    predictions are compared to ``y`` directly, which presumably expects
    ``y`` to be encoded as 0/1 or bool (confirm with callers).

    Parameters
    ----------
    y : np.ndarray of shape (n,)
        The true binary labels.
    S : np.ndarray of shape (n,)
        The estimated probabilities of each samples.
    t : np.ndarray of shape (k,)
        The thresholds t.

    Returns
    -------
    acc : np.ndarray of shape (k,)
        The accuracy for each of the thresholds.
    """
    y = np.atleast_1d(y)  # (n,)
    S = np.atleast_1d(S)  # (n,)
    t = np.atleast_1d(t)  # (k,)

    assert y.ndim == 1
    assert S.shape == y.shape
    assert t.ndim == 1

    predicted = np.greater_equal.outer(S, t)  # (n, k), True where S >= t
    correct = predicted == y[:, None]  # (n, k)
    return correct.mean(axis=0)  # (k,)


# def compute_regret_residuals():
    