from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    matthews_corrcoef,
    mean_squared_error,
    mean_absolute_error,
    r2_score,
)
from scipy.stats import pearsonr, spearmanr, kendalltau
import numpy as np

import logging

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): this sets the logger's level to DEBUG at import time, so the
# debug messages emitted by the metric functions below are not filtered by
# this logger itself. Confirm that overriding application-level logging
# configuration here is intended.
logger.setLevel(logging.DEBUG)


def compute_classification_metrics(y_true, y_pred, y_prob=None):
    """Compute classification metrics, plus ROC AUC when scores are supplied.

    Precision, recall and F1 are weighted averages computed with
    ``zero_division=0`` so that labels without support or without
    predictions contribute 0 instead of emitting a warning.

    Args:
        y_true: Ground-truth labels (array-like).
        y_pred: Predicted labels (array-like).
        y_prob: Optional scores/probabilities (array-like). Supported shapes:
            1D or a single column ``(n, 1)`` (binary, used as the score
            directly), ``(n, 2)`` (binary, column 1 is used as the score),
            or ``(n, k)`` with any other ``k`` (scored one-vs-rest with
            weighted averaging).

    Returns:
        dict with keys ``"accuracy"``, ``"precision"``, ``"recall"``,
        ``"f1"`` and ``"matthews_corrcoef"``; ``"auc_roc"`` is added when
        ``y_prob`` is given with a supported shape.

    Exceptions raised by the underlying scikit-learn functions are not
    caught here and propagate to the caller.
    """
    # np.shape also handles plain lists, which have no ``.shape`` attribute.
    logger.debug("y_true: %s", np.shape(y_true))
    logger.debug("y_pred: %s", np.shape(y_pred))
    logger.debug("y_prob: %s", None if y_prob is None else np.shape(y_prob))

    metrics = {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(
            y_true, y_pred, average="weighted", zero_division=0
        ),
        "recall": recall_score(
            y_true, y_pred, average="weighted", zero_division=0
        ),
        "f1": f1_score(y_true, y_pred, average="weighted", zero_division=0),
        "matthews_corrcoef": matthews_corrcoef(y_true, y_pred),
    }

    if y_prob is None:
        return metrics

    y_prob = np.asarray(y_prob)

    # A single column of scores is equivalent to a 1D score vector.
    if y_prob.ndim == 2 and y_prob.shape[1] == 1:
        y_prob = y_prob[:, 0]

    if y_prob.ndim == 1:
        # Binary classification with one score per sample.
        metrics["auc_roc"] = roc_auc_score(y_true, y_prob)
    elif y_prob.ndim == 2:
        if y_prob.shape[1] == 2:
            # Binary classification: column 1 is used as the score.
            metrics["auc_roc"] = roc_auc_score(y_true, y_prob[:, 1])
        else:
            # Multi-class: one-vs-rest AUC, weighted across classes.
            metrics["auc_roc"] = roc_auc_score(
                y_true, y_prob, multi_class="ovr", average="weighted"
            )
    else:
        # Previously skipped silently; make the omission visible.
        logger.warning(
            "Skipping auc_roc: unsupported y_prob shape %s", y_prob.shape
        )

    return metrics


def compute_regression_metrics(y_true, y_pred) -> dict:
    """Compute regression error, correlation and top-k overlap metrics.

    Both inputs are converted to NumPy arrays and flattened to 1D before any
    metric is computed.

    Args:
        y_true: Ground-truth target values (array-like).
        y_pred: Predicted values (array-like).

    Returns:
        dict with keys ``"mse"``, ``"rmse"``, ``"mae"``, ``"pearson_r"``,
        ``"pearson_pval"``, ``"kendall_tau"``, ``"kendall_tau_pval"``,
        ``"spearman_r"``, ``"spearman_pval"`` and ``"topk_10"``.
        ``"topk_10"`` is the fraction of the ``k`` largest true values whose
        positions are also among the ``k`` largest predictions, where
        ``k = int(0.1 * n)``; it is ``nan`` when ``k`` is 0.
    """
    # Ensure 1D arrays
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()

    pr, pv = pearsonr(y_true, y_pred)
    sr, sv = spearmanr(y_true, y_pred)
    kendall_tau, kendall_tau_pval = kendalltau(y_true, y_pred)

    # Top-k overlap on the top 10% of samples: how many of the top k indices
    # ranked by true value are also in the top k ranked by prediction.
    k = int(len(y_true) * 0.1)
    if k > 0:
        topk_true_idx = set(np.argsort(y_true)[-k:])
        topk_pred_idx = set(np.argsort(y_pred)[-k:])
        topk_accuracy = len(topk_true_idx & topk_pred_idx) / k
    else:
        # Fewer than 10 samples: no top-10% subset exists. Skipping the sort
        # also avoids ``[-0:]``, which would slice the whole array.
        topk_accuracy = np.nan

    # Compute MSE once and derive RMSE from it.
    mse = mean_squared_error(y_true, y_pred)

    return {
        "mse": mse,
        "rmse": np.sqrt(mse),
        "mae": mean_absolute_error(y_true, y_pred),
        "pearson_r": pr,
        "pearson_pval": pv,
        "kendall_tau": kendall_tau,
        "kendall_tau_pval": kendall_tau_pval,
        "spearman_r": sr,
        "spearman_pval": sv,
        "topk_10": topk_accuracy,
    }
