import wandb
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from sklearn.metrics import (
    roc_auc_score,
    matthews_corrcoef,
    average_precision_score,
    r2_score,
)

from torch.utils.data import DataLoader


def report_metrics(name, metrics, use_wandb, is_print):
    """Report a group of metric values to stdout and/or Weights & Biases.

    Args:
        name: Label for the group; used as the printed prefix and as the
            key under which the metrics are logged to wandb.
        metrics: Mapping from metric name to a numeric value.
        use_wandb: When truthy, log ``{name: metrics}`` through ``wandb.log``.
        is_print: When truthy, print one line with every metric formatted
            to three decimal places.
    """
    if is_print:
        formatted = [f"{key}: {value:.3f}" for key, value in metrics.items()]
        print(f"[{name}] " + ", ".join(formatted))

    if use_wandb:
        wandb.log({name: metrics})


def minmax(x):
    """Rescale ``x`` linearly so its minimum maps to 0 and its maximum to 1.

    Args:
        x: Numeric array-like (e.g. a NumPy array or pandas Series).

    Returns:
        ``(x - min(x)) / (max(x) - min(x))``. When every value in ``x`` is
        identical the range is zero; float zeros of the same shape are
        returned in that case instead of dividing by zero, which would
        produce NaN and a RuntimeWarning.
    """
    lo = np.min(x)
    span = np.max(x) - lo
    shifted = x - lo
    if span == 0:
        # Constant input: ``shifted`` is already all zeros; multiplying by
        # 0.0 only promotes it to float so the dtype matches the normal path.
        return shifted * 0.0
    return shifted / span


def _dcg_at_k(gains, ranks, k):
    """Return the discounted cumulative gain of items with 1-based rank <= k.

    Zero-gain items are dropped up front; they contribute nothing to the sum.
    """
    keep = (ranks <= k) & (gains != 0)
    return np.sum(gains[keep] / np.log2(ranks[keep] + 1))


def calc_ndcg(y_true, y_score, **kwargs):
    """
    Calculate Normalized Discounted Cumulative Gain (NDCG)

    The true scores are min-max scaled to [0, 1] and used as gains. Items
    are ranked by ``y_score`` (highest first) and only the top ``k`` ranks
    contribute; the result is divided by the DCG of the ideal ordering
    (items ranked by their true gain).

    Inputs:
        y_true: an array-like of the true scores where higher score is better
        y_score: an array-like of the predicted scores where higher score is
            better
    Options:
        quantile: If True (default), uses the top k quantile of the
            distribution. If False, ``top`` is used directly as k.
        top: under the quantile setting this is the top quantile to
            keep in the gains calc. This is a PERCENTAGE (i.e input 10 for
            top 10%, the default). k is floored, so inputs with fewer than
            ``100 / top`` items give k = 0 and a score of 0. With
            ``quantile=False`` this is the cutoff k itself, or the string
            "all" to keep every item.

    Returns:
        The NDCG as a float. 0.0 is returned when no item with a non-zero
        gain is ranked in the top k, and when all true scores are identical
        (the gains are then undefined, so the score is defined as 0 rather
        than propagating NaN).
    """
    quantile = kwargs.get("quantile", True)
    top = kwargs.get("top", 10)

    # Accept lists and pandas Series as well as NumPy arrays.
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    if quantile:
        k = int(np.floor(y_true.shape[0] * (top / 100)))
    elif isinstance(top, str) and top == "all":
        k = y_score.shape[0]
    else:
        k = top

    # Min-max scale the true scores into gains in [0, 1].
    lo = np.min(y_true)
    span = np.max(y_true) - lo
    if span == 0:
        # Constant true scores: scaling would divide by zero.
        return 0.0
    gains = (y_true - lo) / span

    # 1-based rank of each item when sorted by descending predicted score.
    ranks = np.argsort(np.argsort(-y_score)) + 1
    dcg = _dcg_at_k(gains, ranks, k)

    # Nothing with a positive gain made the top k (this also covers k == 0,
    # where the ideal DCG would be zero as well).
    if dcg == 0:
        return 0.0

    # Ideal DCG: the same sum with items ranked by their true gain.
    ideal_ranks = np.argsort(np.argsort(-gains)) + 1
    idcg = _dcg_at_k(gains, ideal_ranks, k)

    return float(dcg / idcg)


def calc_toprecall(true_scores, model_scores, top_true=10, top_model=10):
    """
    Recall of the top-scoring true items among the top-scoring predictions

    An item is in the "top" set of an array when its value is at or above
    that array's ``100 - top`` percentile.

    Inputs:
        true_scores: array of true scores
        model_scores: array of predicted scores
        top_true: percentage of top true scores to consider (default 10%)
        top_model: percentage of top predicted scores to consider (default 10%)

    Returns:
        The fraction of top-true items that are also top-model items, or 0
        when the top-true set is empty.
    """
    true_cutoff = np.percentile(true_scores, 100 - top_true)
    is_top_true = true_scores >= true_cutoff

    model_cutoff = np.percentile(model_scores, 100 - top_model)
    is_top_model = model_scores >= model_cutoff

    n_top_true = is_top_true.sum()
    if n_top_true > 0:
        hits = is_top_true & is_top_model
        return hits.sum() / n_top_true
    return 0


def calc_metrics(y_true, y_score, metric):
    """Compute a single named evaluation metric.

    Args:
        y_true: Ground-truth values (list or array). Binary labels for
            "auroc", "mcc" and "ap"; scores for "toprecall" and "r2";
            either for "spearman" and "ndcg".
        y_score: Predicted values (list or array). Must be binary for "mcc".
        metric: One of "auroc", "mcc", "ap", "toprecall", "r2", "spearman",
            "ndcg", or the "_abs" variants "auroc_abs", "ap_abs",
            "toprecall_abs", "spearman_abs", "ndcg_abs". For spearman the
            "_abs" variant is the absolute value of the correlation; for the
            others it is the better of scoring with ``y_score`` and with
            ``-y_score``.

    Returns:
        The value of the requested metric.

    Raises:
        ValueError: If ``metric`` is not one of the names above.
    """
    if isinstance(y_true, list):
        y_true = np.array(y_true)
    if isinstance(y_score, list):
        y_score = np.array(y_score)

    # Rank-based metrics take their targets as floats.
    if metric in ("spearman", "spearman_abs", "ndcg", "ndcg_abs"):
        y_true = y_true.astype(float)

    # Each entry is evaluated lazily so only the requested metric runs.
    scorers = {
        "auroc": lambda: roc_auc_score(y_true, y_score),
        "auroc_abs": lambda: max(
            roc_auc_score(y_true, y_score), roc_auc_score(y_true, -y_score)
        ),
        "mcc": lambda: matthews_corrcoef(y_true, y_score),
        "ap": lambda: average_precision_score(y_true, y_score),
        "ap_abs": lambda: max(
            average_precision_score(y_true, y_score),
            average_precision_score(y_true, -y_score),
        ),
        "toprecall": lambda: calc_toprecall(y_true, y_score),
        "toprecall_abs": lambda: max(
            calc_toprecall(y_true, y_score), calc_toprecall(y_true, -y_score)
        ),
        "r2": lambda: r2_score(y_true, y_score),
        "spearman": lambda: spearmanr(y_true, y_score)[0],
        "spearman_abs": lambda: abs(spearmanr(y_true, y_score)[0]),
        "ndcg": lambda: calc_ndcg(y_true, y_score),
        "ndcg_abs": lambda: max(
            calc_ndcg(y_true, y_score), calc_ndcg(y_true, -y_score)
        ),
    }

    if metric not in scorers:
        raise ValueError(f"Metric {metric} not found")
    return scorers[metric]()


def get_dataloader(cfg, dataset):
    """Wrap ``dataset`` in a non-shuffling ``DataLoader``.

    Args:
        cfg: Configuration object; its ``batch_size`` and ``num_workers``
            attributes are passed to the loader.
        dataset: The dataset to iterate over.

    Returns:
        A ``DataLoader`` over ``dataset`` with shuffling disabled.
    """
    loader_options = {
        "batch_size": cfg.batch_size,
        "shuffle": False,
        "num_workers": cfg.num_workers,
    }
    return DataLoader(dataset=dataset, **loader_options)


# HuggingfaceDataset was removed to keep this module lightweight; only RNAGym is supported in this release.
