import numpy as np
from typing import Sequence, Union, Dict, Any
import warnings

def _validate_binary_array(arr: Sequence, name: str) -> np.ndarray:
    """Ensure array-like input is binary and return as NumPy array."""
    arr = np.asarray(arr)
    if arr.ndim != 1:
        raise ValueError(f"{name} must be one‐dimensional.")
    unique = np.unique(arr)
    if set(unique) - {0, 1}:
        raise ValueError(f"{name} must be binary (0/1). Found: {unique}")
    return arr

def _group_indices(
    sensitive: np.ndarray
) -> Dict[Any, np.ndarray]:
    """Return a dict mapping each sensitive‐group value to boolean mask."""
    groups = np.unique(sensitive)
    if len(groups) != 2:
        raise ValueError(f"sensitive_attr must have exactly two groups. Got {groups}")
    return {g: sensitive == g for g in groups}

def calculate_demographic_parity(
    y_pred: Sequence[int],
    sensitive_attr: Sequence[Union[int,str]]
) -> float:
    """
    Demographic parity difference:
      |P(ŷ=1 | A=priv) - P(ŷ=1 | A=unpriv)|

    The two groups are the two distinct values of ``sensitive_attr``. Because
    the absolute difference is symmetric, which value is "privileged" does
    not affect the result.

    Args:
        y_pred: binary (0/1) predictions, one-dimensional.
        sensitive_attr: group membership with exactly two distinct values,
            same length as y_pred.

    Returns:
        Absolute difference between the two groups' positive-prediction rates.

    Raises:
        ValueError: if y_pred is not binary/1-D, lengths differ, or
            sensitive_attr does not contain exactly two groups.
    """
    y_pred = _validate_binary_array(y_pred, "y_pred")
    sensitive = np.asarray(sensitive_attr)
    if sensitive.shape != y_pred.shape:
        raise ValueError("y_pred and sensitive_attr must be same length.")
    first, second = _group_indices(sensitive).values()
    # Each mask selects at least one element (groups come from np.unique of
    # the same array), so mean() is always defined here.
    return abs(y_pred[first].mean() - y_pred[second].mean())

def calculate_equal_opportunity(
    y_pred: Sequence[int],
    y_true: Sequence[int],
    sensitive_attr: Sequence[Union[int,str]]
) -> float:
    """
    Equal opportunity difference:
      |TPR(priv) - TPR(unpriv)|

    TPR for a group is the fraction of its actual positives (y_true == 1)
    that are predicted positive. The difference is symmetric, so which of the
    two sensitive values is "privileged" does not affect the result.

    Args:
        y_pred: binary (0/1) predictions, one-dimensional.
        y_true: binary (0/1) ground-truth labels, same length as y_pred.
        sensitive_attr: group membership with exactly two distinct values,
            same length as y_pred.

    Returns:
        The absolute TPR difference. If a group has no actual positives, its TPR
        is undefined: a warning is emitted and the result is nan.

    Raises:
        ValueError: if inputs are not binary/1-D, lengths differ, or
            sensitive_attr does not contain exactly two groups.
    """
    y_pred = _validate_binary_array(y_pred, "y_pred")
    y_true = _validate_binary_array(y_true, "y_true")
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must be same length.")
    sensitive = np.asarray(sensitive_attr)
    if sensitive.shape != y_pred.shape:
        raise ValueError("sensitive_attr must be same length as y_pred.")

    is_positive = y_true == 1  # loop-invariant; compute once
    tprs = []
    for g, mask in _group_indices(sensitive).items():
        positives = mask & is_positive
        if not positives.any():
            warnings.warn(f"No positive instances for group {g}; setting TPR to nan.")
            tprs.append(np.nan)
        else:
            tprs.append(y_pred[positives].mean())
    first, second = tprs
    return abs(first - second)

def calculate_equalized_odds(
    y_pred: Sequence[int],
    y_true: Sequence[int],
    sensitive_attr: Sequence[Union[int,str]]
) -> float:
    """
    Equalized odds difference:
      max( |TPR(priv) - TPR(unpriv)|, |FPR(priv) - FPR(unpriv)| )

    TPR is the positive-prediction rate among actual positives (y_true == 1);
    FPR is the positive-prediction rate among actual negatives (y_true == 0).

    Args:
        y_pred: binary (0/1) predictions, one-dimensional.
        y_true: binary (0/1) ground-truth labels, same length as y_pred.
        sensitive_attr: group membership with exactly two distinct values,
            same length as y_pred.

    Returns:
        The larger of the two absolute rate differences. If either rate is
        undefined for a group (no actual positives for TPR, or no actual
        negatives for FPR), a warning is emitted and the result is nan.

    Raises:
        ValueError: if inputs are not binary/1-D, lengths differ, or
            sensitive_attr does not contain exactly two groups.
    """
    y_pred = _validate_binary_array(y_pred, "y_pred")
    y_true = _validate_binary_array(y_true, "y_true")
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must be same length.")
    sensitive = np.asarray(sensitive_attr)
    if sensitive.shape != y_pred.shape:
        raise ValueError("sensitive_attr must be same length as y_pred.")
    groups = _group_indices(sensitive)

    diffs = []
    for rate_name, true_value, kind in (("TPR", 1, "positive"), ("FPR", 0, "negative")):
        condition = y_true == true_value
        rates = []
        for g, mask in groups.items():
            sel = mask & condition
            if not sel.any():
                warnings.warn(
                    f"No {kind} instances for group {g}; setting {rate_name} to nan."
                )
                rates.append(np.nan)
            else:
                rates.append(y_pred[sel].mean())
        diffs.append(abs(rates[0] - rates[1]))
    # np.max propagates nan regardless of position. Builtin max() would
    # return nan only when the nan happened to come first.
    return np.max(diffs)

def calculate_treatment_equality(
    y_pred: Sequence[int],
    y_true: Sequence[int],
    sensitive_attr: Sequence[Union[int,str]],
    zero_div_policy: str = 'nan'
) -> float:
    """
    Treatment equality difference:
      |(FN/FP)_priv - (FN/FP)_unpriv|

    FN and FP are the per-group counts of false negatives and false
    positives. This is a ratio of counts, not of the rates FNR/FPR.

    Args:
        y_pred: binary (0/1) predictions, one-dimensional.
        y_true: binary (0/1) ground-truth labels, same length as y_pred.
        sensitive_attr: group membership with exactly two distinct values,
            same length as y_pred.
        zero_div_policy: how to handle a group with zero false positives.
            'nan' sets that group's ratio to nan, 'inf' sets it to inf, and
            'warn' emits a warning and sets it to nan.

    Returns:
        The absolute difference of the two groups' FN/FP ratios. Note that
        with 'inf', two groups that both have zero FP give inf - inf = nan.

    Raises:
        ValueError: if zero_div_policy is not one of 'nan', 'inf', 'warn'; if
            inputs are not binary/1-D or lengths differ; or if sensitive_attr
            does not contain exactly two groups.
    """
    # Reject typos up front instead of silently treating them as 'warn'.
    if zero_div_policy not in ('nan', 'inf', 'warn'):
        raise ValueError(
            "zero_div_policy must be one of 'nan', 'inf', 'warn'. "
            f"Got {zero_div_policy!r}"
        )
    y_pred = _validate_binary_array(y_pred, "y_pred")
    y_true = _validate_binary_array(y_true, "y_true")
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must be same length.")
    sensitive = np.asarray(sensitive_attr)
    if sensitive.shape != y_pred.shape:
        raise ValueError("sensitive_attr must be same length as y_pred.")

    ratios = []
    for g, mask in _group_indices(sensitive).items():
        preds, trues = y_pred[mask], y_true[mask]
        fn = ((trues == 1) & (preds == 0)).sum()
        fp = ((trues == 0) & (preds == 1)).sum()
        if fp != 0:
            ratios.append(fn / fp)
        elif zero_div_policy == 'inf':
            ratios.append(np.inf)
        else:
            if zero_div_policy == 'warn':
                warnings.warn(f"Zero FP for group {g}; setting ratio to nan.")
            ratios.append(np.nan)
    first, second = ratios
    return abs(first - second)

def calculate_predictive_parity(
    y_pred: Sequence[int],
    y_true: Sequence[int],
    sensitive_attr: Sequence[Union[int,str]]
) -> float:
    """
    Predictive parity difference (precision difference):
      |PPV(priv) - PPV(unpriv)|

    PPV for a group is the fraction of its positive predictions (y_pred == 1)
    that are actually positive (y_true == 1).

    Args:
        y_pred: binary (0/1) predictions, one-dimensional.
        y_true: binary (0/1) ground-truth labels, same length as y_pred.
        sensitive_attr: group membership with exactly two distinct values,
            same length as y_pred.

    Returns:
        The absolute PPV difference. If a group has no positive predictions,
        its PPV is undefined: a warning is emitted and the result is nan.

    Raises:
        ValueError: if inputs are not binary/1-D, lengths differ, or
            sensitive_attr does not contain exactly two groups.
    """
    y_pred = _validate_binary_array(y_pred, "y_pred")
    y_true = _validate_binary_array(y_true, "y_true")
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must be same length.")
    sensitive = np.asarray(sensitive_attr)
    if sensitive.shape != y_pred.shape:
        raise ValueError("sensitive_attr must be same length as y_pred.")

    predicted_positive = y_pred == 1  # loop-invariant; compute once
    ppvs = []
    for g, mask in _group_indices(sensitive).items():
        sel = mask & predicted_positive
        if not sel.any():
            warnings.warn(f"No positive predictions for group {g}; setting PPV to nan.")
            ppvs.append(np.nan)
        else:
            ppvs.append(np.mean(y_true[sel] == 1))
    first, second = ppvs
    return abs(first - second)

def calculate_accuracy_equality(
    y_pred: Sequence[int],
    y_true: Sequence[int],
    sensitive_attr: Sequence[Union[int,str]]
) -> float:
    """
    Overall accuracy difference:
      |Acc(priv) - Acc(unpriv)|

    Acc for a group is the fraction of its rows where y_pred == y_true.

    Args:
        y_pred: binary (0/1) predictions, one-dimensional.
        y_true: binary (0/1) ground-truth labels, same length as y_pred.
        sensitive_attr: group membership with exactly two distinct values,
            same length as y_pred.

    Returns:
        Absolute difference between the two groups' accuracies.

    Raises:
        ValueError: if inputs are not binary/1-D, lengths differ, or
            sensitive_attr does not contain exactly two groups.
    """
    y_pred = _validate_binary_array(y_pred, "y_pred")
    y_true = _validate_binary_array(y_true, "y_true")
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must be same length.")
    sensitive = np.asarray(sensitive_attr)
    if sensitive.shape != y_pred.shape:
        raise ValueError("sensitive_attr must be same length as y_pred.")

    correct = y_pred == y_true  # per-row correctness, shared by both groups
    # Each mask selects at least one row, so mean() is always defined.
    first, second = (
        np.mean(correct[mask]) for mask in _group_indices(sensitive).values()
    )
    return abs(first - second)

def calculate_model_accuracy(
    y_pred: Sequence[int],
    y_true: Sequence[int]
) -> float:
    """
    Return the fraction of rows where the prediction matches the label:
      Acc = (TP + TN) / (TP + TN + FP + FN)

    Args:
        y_pred: binary (0/1) predictions, one-dimensional.
        y_true: binary (0/1) ground-truth labels, same length as y_pred.

    Raises:
        ValueError: if either input is not binary/1-D or the lengths differ.
    """
    predictions = _validate_binary_array(y_pred, "y_pred")
    targets = _validate_binary_array(y_true, "y_true")
    if predictions.shape == targets.shape:
        matches = predictions == targets
        return np.mean(matches)
    raise ValueError("y_true and y_pred must be same length.")

def calculate_all_fairness_metrics(
    y_pred: Sequence[int],
    y_true: Sequence[int],
    sensitive_attr: Sequence[Union[int,str]],
    zero_div_policy: str = 'nan'
) -> Dict[str, float]:
    """
    Compute a suite of group‐fairness metrics.

    Args:
        y_pred: binary (0/1) predictions, one-dimensional.
        y_true: binary (0/1) ground-truth labels, same length as y_pred.
        sensitive_attr: group membership with exactly two distinct values,
            same length as y_pred.
        zero_div_policy: forwarded to calculate_treatment_equality
            ('nan' | 'inf' | 'warn'). The default matches that function's
            default.

    Returns a dict with keys:
      - model_accuracy
      - demographic_parity
      - equal_opportunity
      - equalized_odds
      - treatment_equality
      - predictive_parity
      - accuracy_equality

    Raises:
        ValueError: propagated from the individual metric functions on invalid
            input.
    """
    # Convert once so every metric reuses the same arrays instead of
    # re-converting the original sequences. np.asarray on an existing
    # ndarray inside the metric functions is then a no-op.
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)
    sensitive_attr = np.asarray(sensitive_attr)
    return {
        'model_accuracy': calculate_model_accuracy(y_pred, y_true),
        'demographic_parity': calculate_demographic_parity(y_pred, sensitive_attr),
        'equal_opportunity': calculate_equal_opportunity(y_pred, y_true, sensitive_attr),
        'equalized_odds': calculate_equalized_odds(y_pred, y_true, sensitive_attr),
        'treatment_equality': calculate_treatment_equality(
            y_pred, y_true, sensitive_attr, zero_div_policy=zero_div_policy
        ),
        'predictive_parity': calculate_predictive_parity(y_pred, y_true, sensitive_attr),
        'accuracy_equality': calculate_accuracy_equality(y_pred, y_true, sensitive_attr),
    }
