import warnings

import numpy as np
from joblib import Parallel, delayed
from scipy.stats import ConstantInputWarning, pearsonr, spearmanr
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import average_precision_score, mean_absolute_error, mean_squared_error, roc_auc_score


def bootstrap_regression(
    predictions: np.ndarray,
    targets: np.ndarray,
    n: int = 1000,
    n_jobs: int = -4,
    verbose: int = 0,
    seed: "int | None" = None,
):
    """Estimate regression metrics averaged over bootstrap resamples.

    Each of the ``n`` iterations resamples the paired ``(predictions, targets)``
    with replacement (same size as the input) and computes MAE, MSE, Spearman
    and Pearson correlation on that resample. A metric that cannot be computed
    on a resample (sklearn/scipy raise ``ValueError`` or scipy returns NaN) is
    recorded as NaN and skipped when averaging.

    Args:
        predictions: Predicted values, aligned element-wise with ``targets``.
        targets: Ground-truth values.
        n: Number of bootstrap iterations; must be at least 1.
        n_jobs: Number of workers, forwarded to ``joblib.Parallel``.
        verbose: Verbosity level, forwarded to ``joblib.Parallel``.
        seed: Seed for the resampling. When ``None`` a seed is drawn from
            NumPy's global RNG, so calling ``np.random.seed(...)`` beforehand
            makes runs reproducible regardless of ``n_jobs``.

    Returns:
        Dict mapping ``"MAE"``, ``"MSE"``, ``"Spearman"`` and ``"PCC"`` to the
        NaN-ignoring mean over iterations, as Python floats. Every value is
        NaN when the input is empty.

    Raises:
        ValueError: If ``n < 1`` or ``predictions`` and ``targets`` differ in
            length.
    """
    metric_names = ("MAE", "MSE", "Spearman", "PCC")

    def _bootstrap_single_iteration(
        sub_preds: np.ndarray,
        sub_targs: np.ndarray,
        seed_seq: np.random.SeedSequence,
    ):
        # Each iteration owns an independent RNG stream, so the resamples do
        # not depend on how joblib distributes tasks across workers.
        rng = np.random.default_rng(seed_seq)
        indices = rng.integers(0, len(sub_preds), size=len(sub_preds))
        preds_sample = sub_preds[indices]
        targs_sample = sub_targs[indices]

        # A resample can be constant; scipy then warns and returns NaN, which
        # the final nanmean skips. The filter is scoped so the caller's warning
        # configuration is left untouched. NOTE: catch_warnings is not
        # thread-safe, so under joblib's threading backend filter changes from
        # concurrent iterations may interleave.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=ConstantInputWarning)

            try:
                mae = mean_absolute_error(targs_sample, preds_sample)
            except ValueError:
                mae = np.nan

            try:
                mse = mean_squared_error(targs_sample, preds_sample)
            except ValueError:
                mse = np.nan

            try:
                spearman = spearmanr(targs_sample, preds_sample)[0]
            except ValueError:
                spearman = np.nan

            try:
                pcc = pearsonr(targs_sample, preds_sample)[0]
            except ValueError:
                pcc = np.nan

        return mae, mse, spearman, pcc

    predictions = np.asarray(predictions)
    targets = np.asarray(targets)

    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    if len(predictions) != len(targets):
        raise ValueError(
            "predictions and targets must have the same length, "
            f"got {len(predictions)} and {len(targets)}"
        )
    if len(predictions) == 0:
        # Nothing to resample: every metric is undefined.
        return dict.fromkeys(metric_names, float("nan"))

    if seed is None:
        seed = int(np.random.randint(2**31 - 1))
    seed_seqs = np.random.SeedSequence(seed).spawn(n)

    results = Parallel(n_jobs=n_jobs, verbose=verbose)(
        delayed(_bootstrap_single_iteration)(predictions, targets, seed_seq)
        for seed_seq in seed_seqs
    )
    # Transpose per-iteration tuples into one sequence per metric.
    per_metric = zip(*results)
    return {
        name: np.nanmean(values).item()
        for name, values in zip(metric_names, per_metric)
    }


def bootstrap_clf(
    predictions: np.ndarray,
    targets: np.ndarray,
    n: int = 1000,
    n_jobs: int = -4,
    verbose: int = 0,
    seed: "int | None" = None,
    threshold: float = 0.5,
):
    """Estimate binary-classification metrics averaged over bootstrap resamples.

    Each of the ``n`` iterations resamples the paired ``(predictions, targets)``
    with replacement (same size as the input) and computes AUROC, average
    precision (AUPRC) and accuracy on that resample. A metric that sklearn
    cannot compute on a resample (it raises ``ValueError``, e.g. when only one
    class is present) is recorded as NaN and skipped when averaging.

    Args:
        predictions: Per-sample scores, aligned element-wise with ``targets``
            (presumably positive-class probabilities; confirm with callers).
        targets: Ground-truth labels. Accuracy compares them against boolean
            predicted labels, so they are expected to be 0/1 or bool.
        n: Number of bootstrap iterations; must be at least 1.
        n_jobs: Number of workers, forwarded to ``joblib.Parallel``.
        verbose: Verbosity level, forwarded to ``joblib.Parallel``.
        seed: Seed for the resampling. When ``None`` a seed is drawn from
            NumPy's global RNG, so calling ``np.random.seed(...)`` beforehand
            makes runs reproducible regardless of ``n_jobs``.
        threshold: A score ``>= threshold`` is counted as a positive
            prediction when computing accuracy.

    Returns:
        Dict mapping ``"AUROC"``, ``"AUPRC"`` and ``"Accuracy"`` to the
        NaN-ignoring mean over iterations, as Python floats. Every value is
        NaN when the input is empty.

    Raises:
        ValueError: If ``n < 1`` or ``predictions`` and ``targets`` differ in
            length.
    """
    metric_names = ("AUROC", "AUPRC", "Accuracy")

    def _bootstrap_single_iteration(
        sub_preds: np.ndarray,
        sub_targs: np.ndarray,
        seed_seq: np.random.SeedSequence,
    ):
        # Each iteration owns an independent RNG stream, so the resamples do
        # not depend on how joblib distributes tasks across workers.
        rng = np.random.default_rng(seed_seq)
        indices = rng.integers(0, len(sub_preds), size=len(sub_preds))
        preds_sample = sub_preds[indices]
        targs_sample = sub_targs[indices]

        # Degenerate resamples (e.g. a single class) make sklearn warn. The
        # filters are scoped so they do not leak into the caller's process,
        # which is where this runs when n_jobs=1. NOTE: catch_warnings is not
        # thread-safe, so under joblib's threading backend filter changes from
        # concurrent iterations may interleave.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)

            try:
                auc = roc_auc_score(targs_sample, preds_sample)
            except ValueError:
                auc = np.nan

            try:
                ap = average_precision_score(targs_sample, preds_sample)
            except ValueError:
                ap = np.nan

        pred_labels = preds_sample >= threshold
        acc = np.mean(pred_labels == targs_sample)

        return auc, ap, acc

    predictions = np.asarray(predictions)
    targets = np.asarray(targets)

    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    if len(predictions) != len(targets):
        raise ValueError(
            "predictions and targets must have the same length, "
            f"got {len(predictions)} and {len(targets)}"
        )
    if len(predictions) == 0:
        # Nothing to resample: every metric is undefined.
        return dict.fromkeys(metric_names, float("nan"))

    if seed is None:
        seed = int(np.random.randint(2**31 - 1))
    seed_seqs = np.random.SeedSequence(seed).spawn(n)

    results = Parallel(n_jobs=n_jobs, verbose=verbose)(
        delayed(_bootstrap_single_iteration)(predictions, targets, seed_seq)
        for seed_seq in seed_seqs
    )
    # Transpose per-iteration tuples into one sequence per metric.
    per_metric = zip(*results)
    return {
        name: np.nanmean(values).item()
        for name, values in zip(metric_names, per_metric)
    }
