"""Helper functions for the xp_nn_calibration module."""
import os
import shutil

import numpy as np
import torch
from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
from sklearn.calibration import calibration_curve as sklearn_calibration_curve
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit
from torchmetrics import CalibrationError
from tqdm import tqdm
from grouping.xp_grouping.gl_induced import CEstimator, estimate_GL_induced, estimate_CL_induced


def _validate_clustering(*args):
    if len(args) == 2:
        frac_pos, counts = args
    elif len(args) == 3:
        frac_pos, counts, mean_scores = args
    else:
        raise ValueError(f'2 or 3 args must be given. Got {len(args)}.')

    if frac_pos.shape != counts.shape:
        raise ValueError(f'Shape mismatch between frac_pos {frac_pos.shape} and counts {counts.shape}')

    if len(args) == 3 and frac_pos.shape != mean_scores.shape:
        raise ValueError(f'Shape mismatch between frac_pos {frac_pos.shape} and mean_scores {mean_scores.shape}')

    if frac_pos.ndim < 2:
        raise ValueError(f'frac_pos, counts and mean_scores must bet at least '
                         f'2D. Got {frac_pos.ndim}D.')


def _validate_scores(y_scores, one_dim=False):
    if one_dim is not None and one_dim and y_scores.ndim != 1:
        raise ValueError(f'y_scores must be 1D. Got shape {y_scores.shape}.')

    if one_dim is not None and not one_dim and y_scores.ndim != 2:
        raise ValueError(f'y_scores must be 2D. Got shape {y_scores.shape}.')

    if one_dim is None and y_scores.ndim not in [1, 2]:
        raise ValueError(f'y_scores must be 1D or 2D. Got shape {y_scores.shape}.')

    if np.any(y_scores < 0) or np.any(y_scores > 1):
        raise ValueError('y_scores must take values between 0 and 1.')

    if y_scores.ndim == 2 and not np.allclose(np.sum(y_scores, axis=1), 1):
        raise ValueError('y_scores must sum to 1 class-wise when 2D.')


def _validate_labels(y_labels, binary=False):
    uniques = np.unique(y_labels)
    if binary and len(uniques) > 2:
        raise ValueError(f'y_labels must be binary. Found values: {uniques}.')


def calibration_curve(frac_pos, counts, mean_scores=None, remove_empty=True,
                      return_mean_bins=True):
    """Compute calibration curve from output of clustering.
    Result is the same as sklearn's calibration_curve.

    Parameters
    ----------
    frac_pos : (bins, n_clusters) array
        The fraction of positives in each cluster for each bin.

    counts : (bins, n_clusters) array
        The number of samples in each cluster for each bin.

    mean_scores : (bins, n_clusters) array
        The mean score of samples in each cluster for each bin. Required when
        return_mean_bins=True.

    remove_empty : bool
        Whether to remove empty bins (only applied to 2D inputs). If False,
        empty bins hold NaN.

    return_mean_bins : bool
        Whether to return mean_bins.

    Returns
    -------
    prob_bins : (bins,) arrays
        Fraction of positives in each bin.

    mean_bins : (bins,) arrays
        Mean score in each bin. Returned only if return_mean_bins=True.

    Raises
    ------
    ValueError
        If mean_scores is None while return_mean_bins=True, or if the inputs
        are inconsistent (see _validate_clustering).

    """
    # Check this before validating shapes: otherwise the validation would
    # access mean_scores.shape and fail with an unhelpful AttributeError.
    if return_mean_bins and mean_scores is None:
        raise ValueError('mean_scores cannot be None when '
                         'return_mean_bins=True.')

    if return_mean_bins:
        _validate_clustering(frac_pos, counts, mean_scores)
    else:
        _validate_clustering(frac_pos, counts)

    count_sums = np.sum(counts, axis=1, dtype=float)
    non_empty = count_sums > 0
    # Empty clusters contribute nothing to the weighted sums, even when their
    # frac_pos / mean_scores entries are undefined (e.g. NaN from a 0/0).
    has_samples = counts > 0

    pos_sums = np.sum(np.where(has_samples, frac_pos*counts, 0), axis=1)
    prob_bins = np.divide(pos_sums, count_sums,
                          where=non_empty, out=np.full_like(count_sums, np.nan))

    if return_mean_bins:
        score_sums = np.sum(np.where(has_samples, mean_scores*counts, 0), axis=1)
        mean_bins = np.divide(score_sums, count_sums,
                              where=non_empty, out=np.full_like(count_sums, np.nan))

    # The calibration_curve of sklearn removes empty bins.
    # Should do the same to give same result. With more than 2 dimensions the
    # remaining bins would have different lengths per trailing index, so empty
    # bins are kept (as NaN) in that case.
    if frac_pos.ndim == 2 and remove_empty:
        prob_bins = prob_bins[non_empty]
        if return_mean_bins:
            mean_bins = mean_bins[non_empty]

    if return_mean_bins:
        return prob_bins, mean_bins

    return prob_bins


def scores_to_id_bins(y_scores, bins):
    """Map each score to the index of the bin containing it.

    Parameters
    ----------
    y_scores : array
        Scores to bin.

    bins : (n_bins+1,) array
        Bin edges (assumed increasing).

    Returns
    -------
    array of int, same shape as y_scores
        Bin index in [0, n_bins-1]. Scores below the first edge land in the
        first bin; scores on or above the last edge land in the last bin.
    """
    last_bin = len(bins) - 2
    one_based = np.digitize(y_scores, bins=bins)
    return np.clip(one_based - 1, 0, last_bin)


def scores_to_pred(y_scores, threshold=0.5):
    """Turn scores into predicted labels and the score of each prediction.

    Parameters
    ----------
    y_scores : (n_samples,) or (n_samples, n_classes) array
        A 1D array holds the score of class 1 of a binary problem.

    threshold : float
        Decision threshold, only used for 1D scores (class 1 when the score
        is >= threshold).

    Returns
    -------
    y_pred : (n_samples,) int array
        Predicted labels.

    y_pred_scores : (n_samples,) array
        Score of the predicted class: the score or 1 - score for 1D input,
        the row maximum for 2D input.
    """
    _validate_scores(y_scores, one_dim=None)

    if y_scores.ndim == 2:
        labels = np.argmax(y_scores, axis=1).astype(int)
        confidence = np.max(y_scores, axis=1)
        return labels, confidence

    # 1D case: the arithmetic blend keeps y_scores where the label is 1 and
    # 1 - y_scores where it is 0.
    labels = (y_scores >= threshold).astype(int)
    confidence = labels*y_scores + (1 - labels)*(1 - y_scores)
    return labels, confidence


def binarize_multiclass_max(y_scores, y_labels):
    """Binarize a multiclass problem around the top-label prediction.

    Parameters
    ----------
    y_scores : (n_samples, n_classes) array
        Class scores, rows summing to 1.

    y_labels : (n_samples,) array
        True class indices.

    Returns
    -------
    y_pred_scores : (n_samples,) array
        Maximum score of each sample.

    y_well_guess : (n_samples,) int array
        1 where the argmax class equals the true label, 0 otherwise.
    """
    _validate_scores(y_scores, one_dim=False)
    predicted, top_scores = scores_to_pred(y_scores)
    is_correct = (predicted == y_labels).astype(int)
    return top_scores, is_correct


def binarize_multiclass_marginal(y_scores, y_labels, positive_class):
    """Binarize a multiclass problem as "class positive_class vs the rest".

    Parameters
    ----------
    y_scores : (n_samples, n_classes) array
        Class scores, rows summing to 1.

    y_labels : (n_samples,) array
        True class indices.

    positive_class : int
        Column of y_scores treated as the positive class.

    Returns
    -------
    y_pred_scores : (n_samples,) array
        Score of positive_class for each sample.

    y_well_guess : (n_samples,) int array
        1 where the true label equals positive_class, 0 otherwise.
    """
    _validate_scores(y_scores, one_dim=False)
    class_scores = y_scores[:, positive_class]
    is_positive = (positive_class == y_labels).astype(int)
    return class_scores, is_positive


def scores_to_calibrated_scores(y_scores, prob_bins, bins):
    """Replace each binary score by the calibrated value of its bin.

    Parameters
    ----------
    y_scores : (n_samples,) array
        Binary scores in [0, 1].

    prob_bins : (n_bins,) array-like
        Calibrated value (fraction of positives) of each bin. NaN entries are
        propagated to the scores falling in those bins.

    bins : (n_bins+1,) array
        Bin edges.

    Returns
    -------
    y_scores_cal : (n_samples,) array
        prob_bins value of the bin containing each score.

    Raises
    ------
    ValueError
        If y_scores is invalid or prob_bins does not have one entry per bin.
    """
    _validate_scores(y_scores, one_dim=True)
    # Accept lists too: fancy indexing below requires an ndarray.
    prob_bins = np.asarray(prob_bins)

    if len(prob_bins) != len(bins)-1:
        raise ValueError(f'prob_bins must have {len(bins)-1} elements. '
                         f'Got {len(prob_bins)}.')

    # TODO use linear interpolation of mean_bins to adjust the calibrated score,
    # e.g. piecewise_affine_mapping(y_scores, [0, *mean_bins, 1],
    # [0, *prob_bins, 1]) instead of a step function.
    y_bins = scores_to_id_bins(y_scores, bins)
    return prob_bins[y_bins]


def grouping_loss_bias(frac_pos, counts, reduce_bin=True):
    """Estimate the bias term subtracted by grouping_loss_lower_bound when
    debiased=True.

    For each bin, the term is the difference between the count-weighted
    per-cluster variance estimates ``frac_pos*(1-frac_pos)/(counts-1)`` and
    the bin-level estimate ``C*(1-C)/(n_bin-1)`` (C being the bin's fraction
    of positives), weighted by the share of samples in the bin.

    Parameters
    ----------
    frac_pos : (n_bins, n_clusters) array
        The fraction of positives in each cluster for each bin.

    counts : (n_bins, n_clusters) array
        The number of samples in each cluster for each bin.

    reduce_bin : bool
        If True, return the sum over bins (NaN bins ignored), otherwise the
        per-bin values.

    Returns
    -------
    bias : float or (n_bins,) array
        Per-bin values are NaN for bins with fewer than 2 samples.
    """
    prob_bins = calibration_curve(frac_pos, counts,
                                  remove_empty=False, return_mean_bins=False)
    n_bins = np.sum(counts, axis=1)  # number of samples in bin
    n = np.sum(n_bins)

    # Per-cluster variance estimate, only defined for clusters of >= 2 samples.
    var = np.divide(frac_pos*(1 - frac_pos), counts - 1,
                    out=np.full_like(frac_pos, np.nan, dtype=float),
                    where=counts > 1)
    # Weight each cluster by its share of the bin's samples.
    var = var*np.divide(counts, n_bins[:, None],
                        out=np.full_like(frac_pos, np.nan, dtype=float),
                        where=n_bins[:, None] > 0)
    # Bin-level estimate. Masked so bins with fewer than 2 samples yield NaN
    # without a divide-by-zero RuntimeWarning (same value as the unmasked 0/0).
    bin_var = np.divide(prob_bins*(1 - prob_bins), n_bins - 1,
                        out=np.full_like(prob_bins, np.nan),
                        where=n_bins > 1)

    bias = np.nansum(var, axis=1) - bin_var
    bias *= n_bins/n
    if reduce_bin:
        return np.nansum(bias)

    return bias


def grouping_loss_lower_bound(frac_pos, counts, scoring='brier', reduce_bin=True,
                                debiased=False, return_bias=False):
    """Compute a lower bound of the grouping loss from clustering.

    The bound is the count-weighted mean squared gap between each cluster's
    fraction of positives and the fraction of positives of its bin.

    Parameters
    ----------
    frac_pos, counts : (n_bins, n_clusters) arrays
        Clustering output (see calibration_curve).

    scoring : str
        Unused.

    reduce_bin : bool
        If True return a scalar, otherwise one contribution per bin (both
        normalized by the total number of samples).

    debiased : bool
        Whether to subtract grouping_loss_bias (NaN replaced by 0).

    return_bias : bool
        Also return the subtracted bias. Only has an effect when
        debiased=True.

    Returns
    -------
    lower_bound : float or (n_bins,) array

    bias : float or (n_bins,) array
        Only when debiased and return_bias are both True.
    """
    bin_rates = calibration_curve(frac_pos, counts,
                                  remove_empty=False, return_mean_bins=False)
    weighted_sq_gap = np.multiply(counts, np.square(frac_pos - bin_rates[:, None]))
    n_total = np.sum(counts)

    if reduce_bin:
        estimate = np.nansum(weighted_sq_gap)/n_total
    else:
        estimate = np.divide(np.nansum(weighted_sq_gap, axis=1), n_total)

    if not debiased:
        return estimate

    correction = np.nan_to_num(grouping_loss_bias(frac_pos, counts, reduce_bin=reduce_bin))
    estimate -= correction
    if return_bias:
        return estimate, correction
    return estimate


def grouping_loss_upper_bound(frac_pos, counts, y_scores, y_labels, bins,
                              scoring='brier', reduce_bin=True):
    """Compute an upper bound of the grouping loss: grouping loss + irreducible
    loss.

    The bound is the mean squared error between the binary labels and the
    bin-wise calibrated scores (each score replaced by its bin's fraction of
    positives).

    Parameters
    ----------
    frac_pos, counts : (n_bins, n_clusters) arrays
        Clustering output (see calibration_curve).

    y_scores : (n_samples,) array
        Binary scores in [0, 1].

    y_labels : (n_samples,) array
        Binary labels.

    bins : (n_bins+1,) array
        Bin edges.

    scoring : str
        Unused; the squared error (Brier) is always used.

    reduce_bin : bool
        If True return the mean over all samples, otherwise the mean within
        each bin (NaN for bins that contain no sample).

    Returns
    -------
    upper_bound : float or (n_bins,) array
    """
    _validate_scores(y_scores, one_dim=True)
    _validate_labels(y_labels, binary=True)

    prob_bins = calibration_curve(frac_pos, counts, remove_empty=False,
                                  return_mean_bins=False)
    y_scores_cal = scores_to_calibrated_scores(y_scores, prob_bins, bins)
    sq_err = np.square(y_scores_cal - y_labels)

    if reduce_bin:
        return np.mean(sq_err)

    y_bins = scores_to_id_bins(y_scores, bins)
    n_bins = len(bins) - 1
    upper_bound = np.full(n_bins, np.nan)
    for i in range(n_bins):
        in_bin = y_bins == i
        # Leave empty bins at NaN rather than calling np.mean on an empty
        # slice, which emits a RuntimeWarning.
        if np.any(in_bin):
            upper_bound[i] = np.mean(sq_err[in_bin])

    return upper_bound


def grouping_loss_upper_bound_c(frac_pos, counts,
                                scoring='brier', reduce_bin=True):
    """Compute an upper bound of the grouping loss: C(1-C).

    C is the per-bin fraction of positives given by calibration_curve.

    Parameters
    ----------
    frac_pos, counts : (n_bins, n_clusters) arrays
        Clustering output (see calibration_curve).

    scoring : str
        Unused.

    reduce_bin : bool
        If True, average the per-bin values weighted by the share of samples
        in each bin (NaN bins ignored). Otherwise return the per-bin values.

    Returns
    -------
    upper_bound : float or (n_bins,) array
    """
    calib = calibration_curve(frac_pos, counts, remove_empty=False,
                              return_mean_bins=False)
    per_bin_bound = calib*(1 - calib)

    if not reduce_bin:
        return per_bin_bound

    samples_per_bin = np.nansum(counts, axis=1)
    weights = samples_per_bin/np.nansum(counts)
    return np.nansum(weights*per_bin_bound)


def compute_calib_metrics(frac_pos, counts, y_scores, y_labels, bins):
    """Compute calibration metrics from output of clustering.

    Parameters
    ----------
    frac_pos : (n_bins, n_clusters) array
        The fraction of positives in each cluster for each bin.

    counts : (n_bins, n_clusters) array
        The number of samples in each cluster for each bin.

    y_scores : (n_samples,) array
        Binary classification scores in [0, 1] (grouping_loss_upper_bound
        requires 1D scores).

    y_labels : (n_samples,) array
        Binary true labels.

    bins : int or (n_bins+1,) array
        Number of equally spaced bins on [0, 1] or array defining the bin
        edges.

    Returns
    -------
    metrics : dict
        Grouping-loss bounds ('lower_bound', 'upper_bound', 'upper_bound_c',
        'lower_bound_debiased', 'lower_bound_bias'), clustering statistics
        ('n_samples_per_cluster', 'n_size_one_clusters',
        'n_nonzero_clusters', 'n_bins') and the induced grouping/calibration
        loss estimates ('GL_ind', 'CL_ind' and their estimator descriptions).

    """
    _validate_clustering(frac_pos, counts)

    # A scalar is a number of equally spaced bins; anything else is taken as
    # the array of bin edges.
    if np.ndim(bins) == 0:
        bins = np.linspace(0, 1, bins+1)

    lower_bound = grouping_loss_lower_bound(frac_pos, counts, reduce_bin=True)
    upper_bound = grouping_loss_upper_bound(frac_pos, counts, y_scores, y_labels, bins, reduce_bin=True)
    upper_bound_c = grouping_loss_upper_bound_c(frac_pos, counts, reduce_bin=True)

    lower_bound_debiased, bias = grouping_loss_lower_bound(frac_pos, counts, reduce_bin=True, debiased=True, return_bias=True)

    # Estimation of GL_induced
    est = CEstimator(y_scores, y_labels)
    c_hat = est.c_hat()
    GL_ind = estimate_GL_induced(c_hat, y_scores, bins)
    CL_ind = estimate_CL_induced(c_hat, y_scores, bins)

    metrics = {
        'lower_bound': lower_bound,
        'upper_bound': upper_bound,
        'upper_bound_c': upper_bound_c,
        'lower_bound_debiased': lower_bound_debiased,
        'lower_bound_bias': bias,
        'n_samples_per_cluster': np.mean(counts, where=counts > 0),
        'n_size_one_clusters': np.sum(counts == 1),
        'n_nonzero_clusters': np.sum(counts > 0),
        'n_bins': len(bins)-1,
        'GL_ind': GL_ind,
        'CL_ind': CL_ind,
        # NOTE(review): hard-coded description, presumably mirroring how
        # CEstimator is configured — keep in sync with gl_induced.
        'GL_ind_est': 'KNNRegressor(n_neighbors=2000)',
        'CL_ind_est': 'KNNRegressor(n_neighbors=2000)',
    }

    return metrics


def brier_multi(y_scores, y_labels):
    """Multiclass Brier score.

    Mean over samples of the squared L2 distance between the score vector and
    the one-hot encoding of the true label.

    Parameters
    ----------
    y_scores : (n_samples, n_classes) or (n_samples,) array
        Class scores. A 1D array is read as the score of class 1 of a binary
        problem and expanded to ``[1 - s, s]``, as elsewhere in this module.

    y_labels : (n_samples,) array-like of int
        True class indices.

    Returns
    -------
    float
    """
    y_scores = np.asarray(y_scores)
    if y_scores.ndim == 1:
        y_scores = np.stack([1 - y_scores, y_scores], axis=1)

    y_labels = np.array(y_labels, dtype=int)
    y_binary = np.zeros_like(y_scores, dtype=int)
    y_binary[np.arange(len(y_labels)), y_labels] = 1
    return np.mean(np.sum(np.square(y_scores - y_binary), axis=1))


def compute_multi_classif_metrics(y_scores, y_labels):
    """Compute classification and top-label calibration metrics.

    Parameters
    ----------
    y_scores : (n_samples,) or (n_samples, n_classes) array
        Class scores. A 1D array is the score of class 1 of a binary problem.

    y_labels : (n_samples,) array
        True class indices.

    Returns
    -------
    metrics : dict
        'acc', 'auc' (None when the multiclass roc_auc_score raises
        ValueError), 'brier_multi', and the torchmetrics CalibrationError of
        the max score vs. top-label correctness with norm l1 ('max_ece'),
        max ('max_mce') and l2 ('max_rmsce'), plus 'max_msce' = max_rmsce**2.
    """
    _validate_scores(y_scores, one_dim=None)
    if y_scores.ndim == 1:
        # Expand binary scores to the two-column [1 - s, s] form.
        _y_scores = np.stack([1-y_scores, y_scores], axis=1)
    else:
        _y_scores = y_scores

    y_pred, _ = scores_to_pred(y_scores)

    y_pred_scores, y_well_guess = binarize_multiclass_max(_y_scores, y_labels)

    y_well_guess = torch.from_numpy(y_well_guess)
    y_pred_scores = torch.from_numpy(y_pred_scores)

    if y_scores.ndim == 2 and y_scores.shape[1] == 2:
        auc = roc_auc_score(y_labels, y_scores[:, 1])
    else:
        try:
            auc = roc_auc_score(y_labels, y_scores, multi_class='ovr')
        except ValueError:
            auc = None

    # NOTE(review): the `compute_on_step` argument is only accepted by older
    # torchmetrics releases — confirm the pinned version before upgrading.
    metrics = {
        'acc': accuracy_score(y_labels, y_pred),
        'auc': auc,
        # Use the 2D form: brier_multi indexes one-hot columns, which fails on
        # raw 1D binary scores.
        'brier_multi': brier_multi(_y_scores, y_labels),
        'max_ece': CalibrationError(norm='l1', compute_on_step=True).forward(y_pred_scores, y_well_guess).item(),
        'max_mce': CalibrationError(norm='max', compute_on_step=True).forward(y_pred_scores, y_well_guess).item(),
        'max_rmsce': CalibrationError(norm='l2', compute_on_step=True).forward(y_pred_scores, y_well_guess).item(),
    }
    metrics['max_msce'] = np.square(metrics['max_rmsce'])

    return metrics


def affine_interpolation(x, x1, y1, x2, y2):
    """Evaluate at x the affine function through (x1, y1) and (x2, y2).

    Inputs broadcast element-wise, so each x can use its own pair of points.

    Parameters
    ----------
    x : float or (n,) array
    x1 : float or (n,) array
    y1 : float or (n,) array
    x2 : float or (n,) array
    y2 : float or (n,) array

    Returns
    -------
    y : float or (n,) array

    Raises
    ------
    AssertionError
        If some x lies outside [x1, x2] or if x1 == x2.

    """
    assert np.all(x1 <= x) and np.all(x <= x2) and np.all(x1 != x2)

    slope = np.divide(y2 - y1, x2 - x1)
    intercept = y1 - np.multiply(slope, x1)
    return np.multiply(slope, x) + intercept


def piecewise_affine_mapping(x, xs, ys):
    """Give image of x through piecewise affine function defined by xs and ys.

    The function interpolates linearly between the knots (xs[i], ys[i]), which
    need not be given in sorted order. Values of x below min(xs) (resp. above
    max(xs)) are mapped to the ys value of the smallest (resp. largest) knot.
    Repeated knot abscissae (e.g. a knot at 1 appended after an existing knot
    at 1) do not raise.

    Parameters
    ----------
    x : float or (n,) array
    xs : (k,) array
    ys : (k,) array

    Returns
    -------
    y : float or (n,) array

    Raises
    ------
    ValueError
        If xs and ys do not have the same shape.

    """
    xs = np.asarray(xs)
    ys = np.asarray(ys)
    if xs.shape != ys.shape:
        raise ValueError(f'xs and ys must have the same shape. '
                         f'Got {xs.shape} and {ys.shape}.')

    # np.interp needs increasing abscissae; a stable sort keeps the given
    # relative order of knots sharing the same abscissa.
    order = np.argsort(xs, kind='stable')
    return np.interp(x, xs[order], ys[order])


class MyMaxCalibrator(BaseEstimator):
    """Calibrate only the top-label (maximum) score of a classifier.

    fit estimates, with sklearn's calibration_curve, the reliability curve of
    "the argmax prediction is correct" against the max score. predict maps the
    max score through the piecewise-affine interpolation of that curve
    (anchored at (0, 0) and (1, 1)) and spreads the remaining probability mass
    uniformly over the other classes.

    Parameters
    ----------
    n_bins : int
        Number of bins passed to sklearn's calibration_curve.
    """

    def __init__(self, n_bins=15):
        super().__init__()
        self.n_bins = n_bins
        self.prob_true = np.array([])
        self.prob_pred = np.array([])

    def fit(self, y_scores, y):
        """Fit the reliability curve of the max score.

        Parameters
        ----------
        y_scores : (n_samples,) or (n_samples, n_classes) array
            Class scores; a 1D array is the score of class 1 (binary).

        y : (n_samples,) array
            True class indices.

        Returns
        -------
        self
        """
        if y_scores.ndim == 1:
            y_scores = np.stack([1 - y_scores, y_scores], axis=1)

        y_pred = np.argmax(y_scores, axis=1)
        y_scores_max = np.max(y_scores, axis=1)
        y_binarized = (y_pred == y).astype(int)

        prob_true, prob_pred = sklearn_calibration_curve(y_binarized, y_scores_max,
                                                         n_bins=self.n_bins)
        self.prob_true = prob_true
        self.prob_pred = prob_pred
        return self

    def predict(self, y_scores):
        """Return calibrated class scores.

        Parameters
        ----------
        y_scores : (n_samples,) or (n_samples, n_classes) array
            Class scores; a 1D array is the score of class 1 (binary).

        Returns
        -------
        y_scores_cal : (n_samples, n_classes) array
            The argmax class receives the calibrated max score; every other
            class receives an equal share of the remainder, so rows sum to 1.
        """
        if y_scores.ndim == 1:
            y_scores = np.stack([1 - y_scores, y_scores], axis=1)

        # Anchor the learned curve at (0, 0) and (1, 1) so every score in
        # [0, 1] can be mapped.
        xs = np.concatenate([[0], self.prob_pred, [1]])
        ys = np.concatenate([[0], self.prob_true, [1]])

        y_scores_max = np.max(y_scores, axis=1)
        y_pred = np.argmax(y_scores, axis=1)
        y_scores_max = piecewise_affine_mapping(y_scores_max, xs, ys)

        n, K = y_scores.shape
        y_scores_cal = np.zeros_like(y_scores)
        y_scores_cal[:, :] = ((1 - y_scores_max)/(K - 1))[:, None]
        y_scores_cal[np.arange(n), y_pred] = y_scores_max

        assert np.allclose(np.sum(y_scores_cal, axis=1), 1)

        return y_scores_cal


def bin_train_test_split(y_scores, test_size=0.5, n_splits=10, bins=15,
                         random_state=0, stratss=False):
    """Generate train/test splits preserving the proportion of each score bin.

    Parameters
    ----------
    y_scores : (n_samples,) array
        Scores used to assign each sample to a bin.

    test_size : float
        Proportion of samples of each bin put in the test set.

    n_splits : int
        Number of train/test splits to generate.

    bins : int or (n_bins+1,) array
        Number of equally spaced bins on [0, 1] or array of bin edges.

    random_state : int or None
        Seed passed to the underlying (Stratified)ShuffleSplit.

    stratss : bool
        If True, use StratifiedShuffleSplit with the bin ids as strata.
        Otherwise split each bin independently with ShuffleSplit; bins too
        small to keep at least one training sample are left out of both the
        train and test sets.

    Returns
    -------
    split : iterator of (train_idx, test_idx) int arrays
    """
    # A scalar is a number of equally spaced bins; anything else is taken as
    # the array of bin edges.
    if np.ndim(bins) == 0:
        bins = np.linspace(0, 1, bins+1)

    # Preserve proportion of scores of each bin in train and test
    n_samples = y_scores.shape[0]
    y_bins = scores_to_id_bins(y_scores, bins)
    n_bins = len(bins) - 1  # y_bins takes values in [0, n_bins-1]

    if stratss:
        cv = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size,
                                    random_state=random_state)
        return cv.split(np.zeros(n_samples), y_bins)

    def per_bin_split():
        indices = np.arange(n_samples)
        cv = ShuffleSplit(n_splits=n_splits, test_size=test_size,
                          random_state=random_state)

        # Sample indices of each bin, and one split iterator per bin.
        bin_indices = [indices[y_bins == i] for i in range(n_bins)]
        bin_splits = [cv.split(idx) for idx in bin_indices]

        # For each split iterate through bins to collect samples
        for _ in range(n_splits):
            train_idx = []
            test_idx = []

            for idx, bin_split in zip(bin_indices, bin_splits):
                n_samples_bin = len(idx)
                if n_samples_bin - np.ceil(test_size * n_samples_bin) <= 0:
                    continue  # skip bins with not enough points
                train_pos, test_pos = next(bin_split)
                train_idx.extend(idx[train_pos])
                test_idx.extend(idx[test_pos])

            # dtype=int keeps empty results usable as index arrays.
            yield np.array(train_idx, dtype=int), np.array(test_idx, dtype=int)

    return per_bin_split()


def calibrate_scores(y_scores, y_labels, test_size=0.5, method='isotonic',
                     max_calibration=False):
    """Calibrate the output of a classifier.

    Fit calibrator on training set and predict on training+test set.

    Parameters
    ----------
    y_scores : (n_samples, n_classes) or (n_samples) array

    y_labels: (n_samples,) array

    test_size : float or array like or None
        Proportion of samples to use as test set, or indices of test set.
        If None, the calibrator is fitted on all samples, which are all
        reported as test samples.

    method : str
        Available: 'isotonic', 'sigmoid' (through CalibratedClassifierCV)
        and 'map' (MyMaxCalibrator).

    max_calibration : bool
        Whether to calibrate only the maximum scores. If True, output scores
        will be of shape (n_samples, 2).

    Returns
    -------
    y_scores_cal : (n_samples, n_classes) array
        Calibrated scores (containing both training and test samples),
        clipped to [1e-16, 1 - 1e-16].
    test_idx : (n_test,) int array
        Sorted indices of test samples.
    """
    y_scores = np.array(y_scores)
    y_labels = np.array(y_labels)

    n_samples = y_scores.shape[0]

    if test_size is None:
        # Test on the training set
        train_idx = np.ones(n_samples, dtype=bool)
        test_idx = np.ones(n_samples, dtype=bool)

    elif hasattr(test_size, '__len__') and not isinstance(test_size, str):
        # array like given: create train/test split from it
        test_size = np.array(test_size)
        if test_size.size != 0 and not np.can_cast(test_size, int, casting='safe'):
            raise ValueError('Values of test_size should be safely castable '
                             'to int.')
        test_size = test_size.astype(int)

        if np.any(test_size >= n_samples) or np.any(test_size < 0):
            raise ValueError(f'test_size is an array with values out of range '
                             f'[0, {n_samples-1}].')

        test_idx = np.zeros(n_samples, dtype=bool)
        test_idx[test_size] = True
        train_idx = np.logical_not(test_idx)
        assert np.all(np.logical_or(train_idx, test_idx))
        assert np.sum(test_idx) == len(test_size)

    else:
        # scalar given: ShuffleSplit yields integer index arrays (not masks)
        cv = ShuffleSplit(n_splits=1, test_size=test_size, random_state=0)
        split = cv.split(np.zeros(n_samples))
        train_idx, test_idx = next(split)

    if max_calibration:
        # Calibrate only the maximum confidence score (weakest def of calibration)
        y_scores, y_labels = binarize_multiclass_max(y_scores, y_labels)
        y_scores = np.stack([1 - y_scores, y_scores], axis=1)

    y_scores_train = y_scores[train_idx]
    y_labels_train = y_labels[train_idx]

    if method == 'map':
        estimator = MyMaxCalibrator()
        estimator.fit(y_scores_train, y_labels_train)
        y_scores_cal = estimator.predict(y_scores)
    else:
        class DummyClassifier(BaseEstimator):
            """Prefit classifier whose predicted probabilities are its input."""

            def __init__(self):
                self.classes_ = np.unique(y_labels)

            def fit(self, X, y):
                return self

            def predict_proba(self, X):
                return X

        estimator = DummyClassifier()
        calibrated_clf = CalibratedClassifierCV(estimator, method=method, cv='prefit')
        calibrated_clf.fit(y_scores_train, y_labels_train)
        y_scores_cal = calibrated_clf.predict_proba(y_scores)

    eps = 1e-16
    y_scores_cal = np.clip(y_scores_cal, eps, 1-eps)

    # The None and array-like branches build boolean masks, whereas
    # ShuffleSplit yields integer indices: np.where on the latter would return
    # the positions of the non-zero indices instead of the indices themselves.
    test_idx = np.asarray(test_idx)
    if test_idx.dtype == bool:
        test_idx = np.flatnonzero(test_idx)
    else:
        test_idx = np.sort(test_idx)
    return y_scores_cal, test_idx


def merge_dataset_sublabels(ds_path, sublabels, merged_folder='_merged',
                            verbose=0, check_subtrees=True):
    """Copy the files of several sub-label directories into one merged folder.

    Every file found under ``ds_path/<sublabel>`` is copied to
    ``ds_path/<merged_folder>/<same relative directory>/<sublabel>_<filename>``.
    The sub-label prefix keeps files with the same name in different
    sub-labels from overwriting each other.

    Parameters
    ----------
    ds_path : str
        Root directory containing one directory per sub-label.

    sublabels : sequence of str
        Names of the sub-label directories to merge.

    merged_folder : str
        Name of the output directory, created inside ds_path.

    verbose : int
        If > 0, show progress bars and print the directories read and the
        files copied.

    check_subtrees : bool
        If True, first check that all sub-labels contain the same set of
        file-containing subdirectories.

    Raises
    ------
    ValueError
        If a sub-label directory is missing, or if check_subtrees is True and
        the subtrees differ.
    """
    quiet = verbose <= 0

    # Check existence of all sublabels
    for sublabel in tqdm(sublabels, disable=quiet):
        ds_subpath = os.path.join(ds_path, sublabel)
        if not os.path.isdir(ds_subpath):
            raise ValueError(f'Directory {sublabel} does not exist at {ds_path}.')

    # Check if subtree match for all sublabel
    if check_subtrees:
        first_subpaths = None
        for sublabel in tqdm(sublabels, disable=quiet):
            ds_subpath = os.path.join(ds_path, sublabel)
            subpaths = set()
            for root, _, filenames in os.walk(ds_subpath):
                if filenames:
                    if not quiet:
                        print(f'Reading {root}...')
                    subpaths.add(os.path.relpath(root, ds_subpath))

            if first_subpaths is None:
                first_subpaths = subpaths

            elif subpaths != first_subpaths:
                raise ValueError(f'Subtrees differ between {sublabels[0]} and '
                                 f'{sublabel}. '
                                 f'Diff -: {first_subpaths - subpaths} '
                                 f'Diff +: {subpaths - first_subpaths}.'
                                 )

    for sublabel in tqdm(sublabels, disable=quiet):
        ds_subpath = os.path.join(ds_path, sublabel)
        for root, _, filenames in os.walk(ds_subpath):
            if not filenames:
                continue
            relpath = os.path.relpath(root, ds_subpath)
            destpath = os.path.join(ds_path, merged_folder, relpath)
            os.makedirs(destpath, exist_ok=True)
            for filename in filenames:
                source_filepath = os.path.join(root, filename)
                dest_filepath = os.path.join(destpath, f'{sublabel}_{filename}')
                if not quiet:
                    print(f'Copying {source_filepath} to {dest_filepath}...')
                shutil.copyfile(source_filepath, dest_filepath)


class WrapperClassifier(BaseEstimator):
    """Expose two precomputed (inputs, scores) pairs as a scikit-learn classifier.

    predict_proba recognises which of the two stored input arrays it receives
    and returns the matching stored scores. This lets precomputed scores be
    passed to CalibratedClassifierCV.
    """

    def __init__(self, X_train, y_train, X_val, y_val):
        self.X_train = X_train
        self.y_train = y_train
        self.X_val = X_val
        self.y_val = y_val

    def fit(self, X=None, y=None):
        """Record the classes found in y; nothing is learned."""
        self.classes_ = np.unique(y)
        return self

    def predict_proba(self, X):
        """Return the stored scores associated with X.

        The training inputs are checked first, then the validation inputs.
        1D stored scores are expanded to two columns ``[1 - s, s]``.

        Raises
        ------
        ValueError
            If X matches neither stored input array.
        """
        known_pairs = ((self.X_train, self.y_train), (self.X_val, self.y_val))
        for known_X, known_scores in known_pairs:
            if np.array_equal(X, known_X):
                scores = known_scores
                break
        else:
            raise ValueError(f'Unknown X of shape {X.shape}.')

        if scores.ndim == 1:
            scores = np.stack([1 - scores, scores], axis=1)

        return scores
