import numpy as np


def get_constant_utilty(
    n_utilities: int = 100, t_target: "list | float | np.ndarray | None" = None
):
    """Build a family of 2x2 utility matrices, one per decision threshold.

    Thresholds ``t`` are ``n_utilities`` points evenly spaced strictly inside
    (0, 1) (the endpoints 0 and 1 are excluded). They are optionally merged
    with ``t_target`` and sorted. Each matrix is ``[[1, 0], [0, 1/t - 1]]``.
    With the formula in ``get_threshold_from_utility``, this matrix maps back
    to threshold ``1 / (1 + (1/t - 1)) = t``.

    Parameters
    ----------
    n_utilities : int
        Number of evenly spaced thresholds in the open interval (0, 1).
    t_target : float or array-like or None
        Extra thresholds to include. A scalar is treated as a single
        threshold. A value of 0 yields an infinite ``1/t - 1`` entry.

    Returns
    -------
    np.ndarray of shape (m, 2, 2)
        One utility matrix per threshold, sorted by threshold. ``m`` is
        ``n_utilities`` plus the number of values in ``t_target``.
    """
    t = np.linspace(0, 1, n_utilities + 2)[1:-1]
    if t_target is not None:
        # atleast_1d lets a single float threshold be merged like a list.
        extra = np.atleast_1d(np.asarray(t_target, dtype=float))
        t = np.sort(np.concatenate([t, extra]))
    U1 = np.divide(1, t) - 1

    zeros = np.zeros_like(U1)
    ones = np.ones_like(U1)
    U = np.array([[ones, zeros], [zeros, U1]])  # (2, 2, m)
    return np.moveaxis(U, 2, 0)  # (m, 2, 2)


def get_threshold_from_utility(U: np.ndarray) -> np.ndarray:
    """Compute the optimal threshold from the utility matrix.

    For each 2x2 matrix the threshold is
    ``(U00 - U10) / ((U00 - U10) + (U11 - U01))``.

    Parameters
    ----------
    U : np.ndarray of shape (2, 2) or (k, 2, 2) or (n, k, 2, 2)
        The utility matrix.

    Returns
    -------
    np.float64 or np.ndarray of shape (k,) or (n, k)
        The optimal threshold. The leading shape of the input is preserved
        exactly, including size-1 ``n`` or ``k`` axes. A single (2, 2) matrix
        yields a NumPy scalar.
    """
    U = np.asarray(U)
    ndim_in = U.ndim
    if ndim_in == 2:
        U = U[None, None, :, :]
    elif ndim_in == 3:
        U = U[None, :, :, :]

    assert U.ndim == 4
    assert U.shape[2:] == (2, 2)

    diff0 = U[:, :, 0, 0] - U[:, :, 1, 0]
    diff1 = U[:, :, 1, 1] - U[:, :, 0, 1]
    t = np.divide(diff0, diff0 + diff1)  # (n, k)

    # Remove only the axes inserted above. A blanket squeeze() would also
    # collapse genuine size-1 n/k axes and break the documented shapes.
    if ndim_in == 2:
        return t[0, 0]
    if ndim_in == 3:
        return t[0]
    return t


def EU_emp(a: np.ndarray, y: np.ndarray, U: np.ndarray) -> np.ndarray:
    """Compute empirical utility from utility matrix, action and label.

    Each output entry is ``U[..., a, y]``: the utility of taking action ``a``
    when the true label is ``y``.

    Parameters
    ----------
    a : np.ndarray of shape (n,) or (n, k)
        The chosen actions (values in {0, 1}; integer or boolean dtype).
    y : np.ndarray of shape (n,)
        The true labels (values in {0, 1}).
    U : np.ndarray of shape (2, 2) or (k, 2, 2) or (n, k, 2, 2)
        The utility matrix. The (k, 2, 2) and (n, k, 2, 2) forms require
        ``a`` of shape (n, k).

    Returns
    -------
    np.ndarray of shape (n,) or (n, k)
        The empirical utility.
    """
    a = np.asarray(a)
    y = np.asarray(y)
    U = np.asarray(U)
    assert np.isin(a, [0, 1]).all()
    assert np.isin(y, [0, 1]).all()
    # Cast to integer indices: a boolean array would otherwise be interpreted
    # as a boolean mask instead of selecting row/column 0 or 1.
    a = a.astype(np.intp)
    y = y.astype(np.intp)
    if a.ndim == 2:
        y = y[:, None]  # (n, 1): broadcasts against the k columns of a
    assert a.ndim == y.ndim

    if U.ndim == 2:  # (2, 2)
        return U[a, y]

    if U.ndim == 3:  # (k, 2, 2)
        assert a.ndim == 2
        k = a.shape[1]
        # arange(k) (k,), a (n, k) and y (n, 1) broadcast to (n, k).
        return U[np.arange(k), a, y]

    assert U.ndim == 4  # (n, k, 2, 2)
    n, k = a.shape
    # Open-grid row/column indices pair each (i, j) with its own matrix.
    i, j = np.ogrid[:n, :k]
    return U[i, j, a, y]


def decision_from_score(
    S: np.ndarray,
    t: "float | np.ndarray",
) -> np.ndarray:
    """Turn scores into binary decisions by thresholding (``S >= t``).

    Parameters
    ----------
    S : float or np.ndarray of shape (n,) or (n, k)
        Scores (e.g. estimated probabilities).
    t : float or np.ndarray of shape (), (k,), (n, k) or (n, k1, k2)
        Threshold(s). A 1-D ``t`` is broadcast along a new last axis of a 1-D
        ``S``, or across the rows of a 2-D ``S``. A 2-D or 3-D ``t`` is
        compared against ``S`` broadcast along the trailing threshold axes.

    Returns
    -------
    np.ndarray of int, shape like ``S`` (scalar ``t``), (n, k) or (n, k1, k2)
        1 where the score reaches the threshold, 0 otherwise. A scalar ``S``
        with a scalar ``t`` yields a NumPy integer scalar.

    Raises
    ------
    ValueError
        If ``t`` has more than 3 dimensions.
    """
    S = np.asarray(S)
    t_ndim = np.ndim(t)
    if t_ndim == 0:
        # Covers Python/NumPy scalars and 0-d arrays (np.isscalar rejects the latter).
        a_test = S >= t  # same shape as S
    elif t_ndim == 1 and S.ndim == 2:
        a_test = S >= np.asarray(t)[None, :]  # (n, k)
    elif t_ndim == 1:
        a_test = S[:, None] >= np.asarray(t)[None, :]  # (n, k)
    elif t_ndim == 2 and S.ndim == 2:
        a_test = S >= np.asarray(t)  # (n, k)
    elif t_ndim == 2:
        a_test = S[:, None] >= np.asarray(t)  # (n, k)
    elif t_ndim == 3:
        a_test = S[:, None, None] >= np.asarray(t)  # (n, k1, k2)
    else:
        raise ValueError(
            f"t must be scalar or array of shape (n_test,) or (n_test, k). Got {t}."
        )
    return a_test.astype(int)


def u_emp_from_score(
    S: np.ndarray,
    y: np.ndarray,
    t: np.ndarray,
    U: np.ndarray,
    return_action: bool = False,
) -> np.ndarray:
    """Compute empirical utility.

    Scores are first thresholded with ``decision_from_score``, then each
    decision is scored against its label with ``EU_emp``.

    Parameters
    ----------
    S : np.ndarray of shape (n,)
        The estimated probabilities.
    y : np.ndarray of shape (n,)
        The true labels.
    t : np.ndarray of shape (k,) or (n, k) or (n, k1, k2)
        The threshold.
    U : np.ndarray of shape (k, 2, 2) or (n, k, 2, 2)
        The utility matrix. For a 3-D ``t``, ``U[i]`` is used as the (2, 2)
        matrix for every threshold ``t[:, i, j]``.
    return_action : bool, default False
        If True, also return the decisions that were scored.

    Returns
    -------
    np.ndarray of shape (n,) or (n, k) or (n, k1, k2)
        The estimated utility. If ``return_action`` is True, a tuple
        ``(utility, actions)`` is returned instead, where ``actions`` has the
        same shape as ``utility``.
    """
    a_test = decision_from_score(S, t)
    if a_test.ndim == 2:
        assert U.ndim in [3, 4]
        u_test = EU_emp(a_test, y, U)
    elif a_test.ndim == 3:
        # One EU_emp call per (i, j) threshold pair with the i-th (2, 2)
        # utility. The stacked (k2, k1, n) result is transposed to (n, k1, k2).
        u_test = np.array(
            [
                [EU_emp(a_test[:, i, j], y, U[i, :, :]) for i in range(a_test.shape[1])]
                for j in range(a_test.shape[2])
            ]
        ).transpose(2, 1, 0)
    else:
        u_test = EU_emp(a_test, y, U)

    # Every branch, including the 2-D one, now honours return_action.
    if return_action:
        return u_test, a_test
    return u_test


def test_u_emp_from_score():
    """Check output shapes of u_emp_from_score for per-sample thresholds/utilities."""
    n, k = 4, 3
    gen = np.random.default_rng(0)
    # Draw order matters: it fixes the RNG stream for every array below.
    scores = gen.uniform(size=n)
    labels = gen.choice([0, 1], size=n)
    thresholds = gen.uniform(size=(n, k))
    utilities = gen.uniform(size=(n, k, 2, 2))

    assert u_emp_from_score(scores, labels, thresholds, utilities).shape == (n, k)

    # Scores given as a full (n, k) array or as a single (n,) column must
    # both yield one utility per (sample, threshold) pair.
    for score_input in (thresholds, thresholds[:, 0]):
        result = u_emp_from_score(score_input, labels, thresholds, utilities)
        assert result.shape == thresholds.shape
