import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


def calculate_score_classification(preds, y_true, *, average="weighted"):
    """Return accuracy, F1, precision and recall for a set of predictions.

    Args:
        preds: Predicted class labels.
        y_true: Ground-truth class labels.
        average: Averaging mode forwarded to scikit-learn's ``f1_score``,
            ``precision_score`` and ``recall_score`` (e.g. ``"weighted"`` or
            ``"macro"``). Defaults to ``"weighted"``, the original behaviour.

    Returns:
        ``(accuracy, f1, precision, recall)``. Scores that are ill-defined
        because of a zero denominator are reported as 0 (``zero_division=0``).
    """
    accuracy = accuracy_score(y_true, preds)
    f1 = f1_score(y_true, preds, average=average, zero_division=0)
    precision = precision_score(y_true, preds, average=average, zero_division=0)
    recall = recall_score(y_true, preds, average=average, zero_division=0)
    return accuracy, f1, precision, recall


def compute_p_target(labels):
    """Return the fraction of trials that are target trials.

    A trial counts as a target when its label equals 1. This is the same rule
    ``classify_scores`` uses, so both functions agree even when non-target
    trials are not labelled 0 (e.g. ``-1``/``1`` labelling). Summing the raw
    labels would give a wrong, possibly negative, value in that case.

    Args:
        labels: 1-D sequence of per-trial labels.

    Returns:
        ``n_target / n_total``, or 0 when ``labels`` is empty.
    """
    total_trials = len(labels)
    if total_trials == 0:
        return 0
    target_trials = np.sum(np.asarray(labels) == 1)
    return target_trials / total_trials


def classify_scores(scores, labels):
    """Split trial scores into target and impostor groups.

    Args:
        scores: 1-D array-like of per-trial scores.
        labels: 1-D array-like with one label per score. Trials whose label
            equals 1 are targets; every other label is treated as impostor.

    Returns:
        ``(tar_scores, imp_scores)`` as NumPy arrays. Each array keeps the
        original order of its trials.

    Raises:
        IndexError: if ``scores`` and ``labels`` differ in length.
    """
    scores = np.asarray(scores)
    is_target = np.asarray(labels) == 1
    # Boolean-mask indexing replaces the per-element Python loop and copies
    # the selected scores into new arrays.
    return scores[is_target], scores[~is_target]


def compute_frr_far(tar, imp):
    """Compute false-rejection and false-acceptance rates over all thresholds.

    Every distinct score in ``tar`` or ``imp`` is used as a decision
    threshold. A trial is accepted when its score is ``>= threshold``.

    Args:
        tar: 1-D array-like of target trial scores.
        imp: 1-D array-like of impostor trial scores.

    Returns:
        ``(thresholds, frr, far)``: three 1-D float arrays of equal length.
        ``thresholds`` holds the sorted unique scores followed by one extra
        threshold just above the largest score. At that extra threshold every
        trial is rejected, so ``frr == 1`` and ``far == 0``.

    If ``tar`` (resp. ``imp``) is empty, ``frr`` (resp. ``far``) is NaN (0/0)
    at the score thresholds. If both are empty, an IndexError is raised.
    """
    tar = np.asarray(tar)
    imp = np.asarray(imp)

    # Combine target and impostor scores to get the unique thresholds.
    thresholds = np.unique(np.hstack((tar, imp)))

    # On sorted scores, searchsorted(side="left") counts the scores strictly
    # below each threshold. This takes O((n + m) log(n + m)) instead of
    # rescanning both arrays for every threshold.
    n_tar_below = np.searchsorted(np.sort(tar), thresholds, side="left")
    n_imp_below = np.searchsorted(np.sort(imp), thresholds, side="left")
    frr = n_tar_below / len(tar)  # targets scoring < threshold are rejected
    far = (len(imp) - n_imp_below) / len(imp)  # impostors scoring >= threshold are accepted

    # Append a threshold above every score: the "reject everything" operating
    # point. Without it, costs such as min DCF cannot reach FRR=1, FAR=0.
    thresholds = np.hstack((thresholds, thresholds[-1] + 1e-6))
    frr = np.hstack((frr, 1.0))
    far = np.hstack((far, 0.0))

    return thresholds, frr, far


def compute_min_c(pt, tar, imp, c_miss=1, c_fa=1):
    """Find the lowest normalized detection cost over all score thresholds.

    The cost at each threshold is ``FRR + beta * FAR``, where
    ``beta = c_fa * (1 - pt) / (c_miss * pt)``. This equals the detection
    cost ``c_miss * pt * FRR + c_fa * (1 - pt) * FAR`` divided by
    ``c_miss * pt``.

    Args:
        pt: Prior probability of a target trial.
        tar: Target trial scores.
        imp: Impostor trial scores.
        c_miss: Cost of rejecting a target trial.
        c_fa: Cost of accepting an impostor trial.

    Returns:
        ``(min_c, threshold, log_beta)``: the lowest cost, the threshold
        that attains it, and the natural log of ``beta``.
    """
    thresholds, miss_rate, fa_rate = compute_frr_far(tar, imp)

    numerator = c_fa * (1 - pt)
    denominator = c_miss * pt
    beta = numerator / denominator
    log_beta = np.log(beta)

    cost = miss_rate + beta * fa_rate
    best = np.argmin(cost)
    return cost[best], thresholds[best], log_beta


def get_min_c(scores, labels, c_miss=1, c_fa=1):
    """Return only the minimum detection cost for labelled trial scores.

    The target prior is estimated from ``labels``, and the scores are split
    into target and impostor groups before the cost is minimized.
    """
    prior = compute_p_target(labels)
    target_scores, impostor_scores = classify_scores(scores, labels)
    best_cost, *_ = compute_min_c(
        prior, target_scores, impostor_scores, c_miss, c_fa
    )
    return best_cost


def compute_eer(tar, imp):
    """Estimate the equal error rate from target and impostor scores.

    Selects the threshold where FRR and FAR are closest. The EER is the mean
    of the two rates at that threshold, expressed as a percentage.

    Returns:
        ``(eer, threshold)``.
    """
    thresholds, frr, far = compute_frr_far(tar, imp)
    closest = np.argmin(np.abs(frr - far))
    rates_at_closest = (frr[closest], far[closest])
    return 100.0 * np.mean(rates_at_closest), thresholds[closest]


def get_eer(scores, labels):
    """Return only the equal error rate (as a percentage) for labelled scores."""
    eer, _threshold = compute_eer(*classify_scores(scores, labels))
    return eer
