from typing import Tuple, Dict, Any, List, Optional

def _safe_div(a: float, b: float) -> float:
    return a / b if b else 0.0

def binary_confusion(pred: bool, label: bool) -> Tuple[int, int, int, int]:
    """Map one binary (prediction, label) pair to its confusion-matrix cell.

    Args:
        pred: predicted outcome (interpreted by truthiness).
        label: true outcome (interpreted by truthiness).

    Returns:
        A one-hot tuple ordered ``(tp, fp, tn, fn)`` with a single 1 marking
        the cell this pair falls into, so tuples can be summed element-wise.
    """
    if label:
        # Positive ground truth: either a hit (tp) or a miss (fn).
        return (1, 0, 0, 0) if pred else (0, 0, 0, 1)
    # Negative ground truth: either a false alarm (fp) or a correct reject (tn).
    return (0, 1, 0, 0) if pred else (0, 0, 1, 0)

def binary_metrics(tp: int, fp: int, tn: int, fn: int) -> Dict[str, float]:
    """Derive standard binary-classification scores from confusion counts.

    Args:
        tp: true positives.
        fp: false positives.
        tn: true negatives.
        fn: false negatives.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``,
        ``specificity``, ``balanced_accuracy``, ``mcc`` and ``error_rate``.
        Any ratio whose denominator is zero is reported as 0.0; consequently,
        with all counts zero, accuracy is 0.0 and error_rate is 1.0.
    """
    predicted_pos = tp + fp
    actual_pos = tp + fn
    actual_neg = tn + fp
    predicted_neg = tn + fn
    total = tp + fp + tn + fn

    precision = _safe_div(tp, predicted_pos)
    recall = _safe_div(tp, actual_pos)
    # _safe_div already yields 0.0 when precision + recall == 0.
    f1 = _safe_div(2 * precision * recall, precision + recall)
    accuracy = _safe_div(tp + tn, total)
    specificity = _safe_div(tn, actual_neg)

    # Matthews correlation coefficient; 0.0 when any marginal count is zero
    # (the product below is then zero and _safe_div falls back to 0.0).
    mcc_denominator = (predicted_pos * actual_pos * actual_neg * predicted_neg) ** 0.5
    mcc = _safe_div(tp * tn - fp * fn, mcc_denominator)

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "specificity": specificity,
        "balanced_accuracy": (recall + specificity) / 2.0,
        "mcc": mcc,
        "error_rate": 1.0 - accuracy,
    }

def set_metrics(pred: set, label: set) -> Dict[str, float]:
    """Score a predicted item set against a gold item set.

    Args:
        pred: predicted items.
        label: gold/reference items.

    Returns:
        Dict with ``precision`` (|pred & label| / |pred|), ``recall``
        (|pred & label| / |label|), ``f1`` (harmonic mean of precision and
        recall) and ``jaccard`` (|pred & label| / |pred | label|). A zero
        denominator yields 0.0, except that an empty prediction against an
        empty label is an exact match and scores 1.0 on every metric.
    """
    p = len(pred)
    g = len(label)
    if p == 0 and g == 0:
        # Nothing predicted and nothing expected: a perfect match. Report 1.0
        # for precision/recall too, so they agree with f1 and jaccard
        # (previously P = R = 0 alongside F1 = 1, which is contradictory).
        return {"precision": 1.0, "recall": 1.0, "f1": 1.0, "jaccard": 1.0}

    inter = len(pred & label)
    # Inclusion-exclusion gives len(pred | label) without building the union.
    union = p + g - inter

    prec = _safe_div(inter, p)
    rec = _safe_div(inter, g)
    f1 = _safe_div(2 * prec * rec, prec + rec)  # 0.0 when prec + rec == 0
    jacc = _safe_div(inter, union)  # union > 0 here: not both sets are empty
    return {"precision": prec, "recall": rec, "f1": f1, "jaccard": jacc}

def aggregate_macro(records: List[Dict[str, float]]) -> Dict[str, float]:
    """Macro-average a list of per-example metric dicts.

    The averaged keys (and their order) are taken from the first record.
    A record lacking one of those keys contributes 0.0 for it, and every
    average divides by the total number of records.

    Args:
        records: one metric dict per example.

    Returns:
        Mapping from metric name to its mean across ``records``. For an empty
        list, a zero-filled dict with the set-metric keys is returned.
    """
    count = len(records)
    if count == 0:
        return {"precision":0.0,"recall":0.0,"f1":0.0,"jaccard":0.0}
    return {
        metric: sum(rec.get(metric, 0.0) for rec in records) / count
        for metric in records[0]
    }

def aggregate_micro(all_pred_sets: List[set], all_label_sets: List[set]) -> Dict[str, float]:
    """Micro-average set metrics by pooling counts across all examples.

    Intersection sizes, prediction sizes, label sizes and union sizes are
    summed over the paired examples first; precision, recall, F1 and Jaccard
    are then computed once from those totals.

    Args:
        all_pred_sets: predicted item set for each example.
        all_label_sets: gold item set for each example, aligned
            index-for-index with ``all_pred_sets``.

    Returns:
        Dict with ``precision``, ``recall``, ``f1`` and ``jaccard``. A zero
        denominator yields 0.0, except that when nothing was predicted and
        nothing was expected across all examples (including when there are
        no examples) every metric is 1.0.

    Raises:
        ValueError: if the two inputs contain a different number of sets.
    """
    # Materialize once so any iterable (e.g. a generator) can be both
    # length-checked and paired without being exhausted.
    preds = list(all_pred_sets)
    labels = list(all_label_sets)
    if len(preds) != len(labels):
        # zip() would silently drop the unpaired tail while the size totals
        # still counted it, yielding skewed scores.
        raise ValueError(
            f"aggregate_micro: {len(preds)} prediction sets vs "
            f"{len(labels)} label sets"
        )

    total_pred = total_label = inter = union = 0
    for pred, label in zip(preds, labels):
        n_inter = len(pred & label)
        total_pred += len(pred)
        total_label += len(label)
        inter += n_inter
        # len(A | B) == len(A) + len(B) - len(A & B)
        union += len(pred) + len(label) - n_inter

    if total_pred == 0 and total_label == 0:
        # Same convention as per-example scoring: empty vs. empty is exact.
        return {"precision": 1.0, "recall": 1.0, "f1": 1.0, "jaccard": 1.0}

    prec = _safe_div(inter, total_pred)
    rec = _safe_div(inter, total_label)
    f1 = _safe_div(2 * prec * rec, prec + rec)  # 0.0 when prec + rec == 0
    jacc = _safe_div(inter, union)  # union > 0 since at least one item exists
    return {"precision": prec, "recall": rec, "f1": f1, "jaccard": jacc}
