"""
Utility functions for fairness experiments in loan classification.

This module provides helper functions for analyzing fairness metrics,
especially in competitive classifier settings where multiple classifiers
may be combined using logical OR operations.
"""

import numpy as np
import pandas as pd
from fairlearn.metrics import false_negative_rate
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from fairlearn.postprocessing import ThresholdOptimizer


def compute_metrics(y_true, pred1, pred2):
    """
    Compute metrics for two classifiers and their OR combination.

    Calculates the false negative rate (FNR) of each classifier and of
    their logical-OR combination, plus an overlap score describing how
    often both classifiers produce a false negative on the same sample.

    A false negative is a sample whose true label is 1 and whose
    prediction is 0. This is the same convention the FNR values follow
    when labels are encoded as 0/1 (positive label = 1).

    Args:
        y_true (array-like): Ground truth binary labels (0/1)
        pred1 (array-like): Predictions from first classifier (0/1)
        pred2 (array-like): Predictions from second classifier (0/1)

    Returns:
        tuple: (fnr1, fnr2, fnr_or, fnr_corr)
            - fnr1: False negative rate of first classifier
            - fnr2: False negative rate of second classifier
            - fnr_or: False negative rate of the OR combination
            - fnr_corr: Overlap of the two classifiers' false negatives,
              computed as (#samples both miss) / (#samples either misses).
              np.nan when neither classifier has a false negative.
    """
    # Combine predictions using logical OR (a + b - a*b is OR for 0/1 values)
    or_preds = pred1 + pred2 - pred1 * pred2

    # Calculate individual false negative rates
    fnr1 = false_negative_rate(y_true, pred1)
    fnr2 = false_negative_rate(y_true, pred2)
    fnr_or = false_negative_rate(y_true, or_preds)

    # Work on positional numpy arrays so that pandas index alignment cannot
    # silently reorder or mismatch samples when inputs are of mixed types.
    truth = np.asarray(y_true)
    is_fn_1 = (truth == 1) & (np.asarray(pred1) == 0)
    is_fn_2 = (truth == 1) & (np.asarray(pred2) == 0)

    # Overlap between the false negatives of both classifiers
    n_either = np.sum(is_fn_1 | is_fn_2)
    fnr_corr = np.sum(is_fn_1 & is_fn_2) / n_either if n_either > 0 else np.nan

    return fnr1, fnr2, fnr_or, fnr_corr


def _false_negative_overlap(is_fn_a, is_fn_b):
    """Return (#both False-negative) / (#either False-negative), or np.nan if none."""
    n_either = np.sum(is_fn_a | is_fn_b)
    if n_either == 0:
        return np.nan
    return np.sum(is_fn_a & is_fn_b) / n_either


def compute_metrics_with_correlation(y_true, pred1, pred2, sensitive_features):
    """
    Compute metrics for two classifiers with separate correlations for each demographic group.

    Calculates the false negative rate (FNR) of each classifier and of
    their logical-OR combination, and the overlap of their false negatives
    overall and separately for each sensitive feature group.

    A false negative is a sample whose true label is 1 and whose
    prediction is 0. This is the same convention the FNR values follow
    when labels are encoded as 0/1 (positive label = 1).

    Args:
        y_true (array-like): Ground truth binary labels (0/1)
        pred1 (array-like): Predictions from first classifier (0/1)
        pred2 (array-like): Predictions from second classifier (0/1)
        sensitive_features (array-like): Binary sensitive attribute (0/1)

    Returns:
        tuple: (fnr1, fnr2, fnr_or, fnr_corr, fnr_corr_s0, fnr_corr_s1)
            - fnr1: False negative rate of first classifier
            - fnr2: False negative rate of second classifier
            - fnr_or: False negative rate of the OR combination
            - fnr_corr: Overall overlap of false negatives, computed as
              (#samples both miss) / (#samples either misses)
            - fnr_corr_s0: Same overlap restricted to group S=0
            - fnr_corr_s1: Same overlap restricted to group S=1
            Each overlap is np.nan when the relevant samples contain no
            false negative from either classifier.
    """
    # Original metrics (a + b - a*b is logical OR for 0/1 values)
    or_preds = pred1 + pred2 - pred1 * pred2
    fnr1 = false_negative_rate(y_true, pred1)
    fnr2 = false_negative_rate(y_true, pred2)
    fnr_or = false_negative_rate(y_true, or_preds)

    # Work on positional numpy arrays so that pandas index alignment cannot
    # silently reorder or mismatch samples when inputs are of mixed types.
    truth = np.asarray(y_true)
    is_fn_1 = (truth == 1) & (np.asarray(pred1) == 0)
    is_fn_2 = (truth == 1) & (np.asarray(pred2) == 0)
    fnr_corr = _false_negative_overlap(is_fn_1, is_fn_2)

    # Separate overlaps by sensitive feature
    groups = np.asarray(sensitive_features)
    mask_s0 = groups == 0
    mask_s1 = groups == 1
    fnr_corr_s0 = _false_negative_overlap(is_fn_1[mask_s0], is_fn_2[mask_s0])
    fnr_corr_s1 = _false_negative_overlap(is_fn_1[mask_s1], is_fn_2[mask_s1])

    return fnr1, fnr2, fnr_or, fnr_corr, fnr_corr_s0, fnr_corr_s1


def compute_group_fnr_gap(y_true, y_pred, sensitive_features, prefix, results):
    """
    Record per-group false negative rates and their difference.

    The FNR is evaluated separately on the samples whose sensitive
    attribute equals 0 and on those where it equals 1. Three entries are
    written into ``results``: ``<prefix>_S0``, ``<prefix>_S1`` and
    ``<prefix>_gap`` (the S=0 rate minus the S=1 rate).

    Args:
        y_true (array-like): Ground truth binary labels
        y_pred (array-like): Model predictions
        sensitive_features (array-like): Binary sensitive attribute (0/1)
        prefix (str): Prefix used to build the keys written to ``results``
        results (dict): Mapping that receives the three values

    Returns:
        None: ``results`` is modified in-place
    """
    # Membership of every sample in each demographic group
    in_group0 = sensitive_features == 0
    in_group1 = sensitive_features == 1

    # Per-group false negative rates
    rate_group0 = false_negative_rate(y_true[in_group0], y_pred[in_group0])
    rate_group1 = false_negative_rate(y_true[in_group1], y_pred[in_group1])

    # Write the two rates and their signed difference under prefixed keys
    entries = (
        ('S0', rate_group0),
        ('S1', rate_group1),
        ('gap', rate_group0 - rate_group1),
    )
    for suffix, value in entries:
        results[f'{prefix}_{suffix}'] = value


def fairness_correction(model, X_train, y_train, sensitive_features):
    """
    Wrap an already-trained model in a fitted ThresholdOptimizer.

    The optimizer is configured with the "false_negative_rate_parity"
    constraint, reads scores through the model's ``predict_proba`` and is
    created with ``prefit=True``. It is then fitted on the supplied data
    together with the sensitive attribute.

    Args:
        model: Trained classifier (must have predict_proba method)
        X_train (array-like): Features used to fit the optimizer
        y_train (array-like): Labels used to fit the optimizer
        sensitive_features (array-like): Binary sensitive attribute (0/1)

    Returns:
        ThresholdOptimizer: The fitted post-processing wrapper
    """
    optimizer_settings = {
        "estimator": model,
        "constraints": "false_negative_rate_parity",
        "predict_method": "predict_proba",
        "prefit": True,
    }
    corrected = ThresholdOptimizer(**optimizer_settings)
    corrected.fit(X_train, y_train, sensitive_features=sensitive_features)
    return corrected


def compute_descriptive_stats(
    df,
    split_col='A__term__is_ 36 months',
    sensitive_col='A__home_ownership__is_MORTGAGE',
    target_col='y',
):
    """
    Calculate descriptive statistics for the dataset.

    Computes group sizes, group intersections, repayment rates per group
    and per intersection, and the per-group difference in repayment rate
    between the two values of the split attribute.

    Args:
        df (pandas.DataFrame): DataFrame containing loan data
        split_col (str): Name of the binary (0/1) split attribute A.
            Defaults to the 36-month loan term indicator.
        sensitive_col (str): Name of the binary (0/1) sensitive attribute S.
            Defaults to the mortgage home-ownership indicator.
        target_col (str): Name of the outcome column (1 = loan fully paid).

    Returns:
        dict: Dictionary of descriptive statistics with keys
            - 'total', '<g>_count', '<g>_percentage' for g in s0, s1, a0, a1
            - '<s>_<a>_count' for every S/A intersection
            - 'repayment_overall', 'repayment_<g>', 'repayment_<s>_<a>'
            - 'repayment_s0_term_diff', 'repayment_s1_term_diff'
              (rate at A=1 minus rate at A=0 within each S group)
            Counts are Python ints. Percentages are NaN for an empty
            DataFrame, and the mean of an empty selection is NaN.
    """
    total = len(df)
    target = df[target_col]

    # Boolean membership masks, built once and reused everywhere below
    groups = {
        's0': df[sensitive_col] == 0,
        's1': df[sensitive_col] == 1,
        'a0': df[split_col] == 0,
        'a1': df[split_col] == 1,
    }
    intersections = {
        f'{s}_{a}': groups[s] & groups[a]
        for s in ('s0', 's1')
        for a in ('a0', 'a1')
    }

    def count(mask):
        # Vectorised count; int() keeps the plain Python int return type
        return int(mask.sum())

    def percentage(n):
        # Guard against an empty DataFrame instead of raising ZeroDivisionError
        return n / total * 100 if total else float('nan')

    # Overall counts and percentages (S groups first, then A groups)
    results = {'total': total}
    for pair in (('s0', 's1'), ('a0', 'a1')):
        for name in pair:
            results[f'{name}_count'] = count(groups[name])
        for name in pair:
            results[f'{name}_percentage'] = percentage(results[f'{name}_count'])

    # Sizes of the S x A intersections
    for name, mask in intersections.items():
        results[f'{name}_count'] = count(mask)

    # Repayment rates: overall, per group, then per intersection
    results['repayment_overall'] = target.mean()
    for name, mask in groups.items():
        results[f'repayment_{name}'] = target[mask].mean()
    for name, mask in intersections.items():
        results[f'repayment_{name}'] = target[mask].mean()

    # Within each S group, how much does the repayment rate change with A?
    for s in ('s0', 's1'):
        results[f'repayment_{s}_term_diff'] = (
            results[f'repayment_{s}_a1'] - results[f'repayment_{s}_a0']
        )

    return results


def run_exp1(X_train, y_train, X_test, y_test, sens_train, sens_test, seed):
    """
    Run experiment 1: Two different models, each with fairness correction.

    Trains a logistic regression and a depth-5 decision tree on the same
    training data, wraps each with ``fairness_correction`` and returns the
    raw and the fairness-corrected predictions on the test set.

    Args:
        X_train (array-like): Training features
        y_train (array-like): Training labels
        X_test (array-like): Test features
        y_test (array-like): Test labels (not used by this function)
        sens_train (array-like): Training sensitive attributes
        sens_test (array-like): Test sensitive attributes
        seed (int): Random seed for reproducibility. Seeds both models and
            the randomized fair predictions.

    Returns:
        tuple: ((p1r, p2r), (p1f, p2f))
            - p1r: Raw predictions from model 1 (logistic regression)
            - p2r: Raw predictions from model 2 (decision tree)
            - p1f: Fair predictions from model 1
            - p2f: Fair predictions from model 2
    """
    # Train logistic regression and decision tree models
    m1 = LogisticRegression(max_iter=1000, random_state=seed).fit(X_train, y_train)
    m2 = DecisionTreeClassifier(max_depth=5, random_state=seed).fit(X_train, y_train)

    # Generate raw predictions
    p1r = m1.predict(X_test)
    p2r = m2.predict(X_test)

    # Apply fairness correction
    post1 = fairness_correction(m1, X_train, y_train, sens_train)
    post2 = fairness_correction(m2, X_train, y_train, sens_train)

    # ThresholdOptimizer.predict can randomize its output, so it must be
    # seeded for the run to be reproducible. One generator is shared by both
    # calls so the two classifiers consume different draws; passing the same
    # integer to both would hand them identical random numbers.
    pred_rng = seed if isinstance(seed, np.random.RandomState) else np.random.RandomState(seed)

    # Generate fair predictions
    p1f = post1.predict(X_test, sensitive_features=sens_test, random_state=pred_rng)
    p2f = post2.predict(X_test, sensitive_features=sens_test, random_state=pred_rng)

    return (p1r, p2r), (p1f, p2f)


def run_exp2(X_full, y_full, X_test, y_test, sens_train, sens_test, A_train, seed):
    """
    Run experiment 2: Two models trained on different data splits based on attribute A.

    Trains separate logistic regression models on the A=1 and A=0
    subgroups of the training data, wraps each with ``fairness_correction``
    (fitted on that model's own subgroup) and returns the raw and the
    fairness-corrected predictions on the full test set.

    Args:
        X_full (array-like): Full training features
        y_full (array-like): Full training labels
        X_test (array-like): Test features
        y_test (array-like): Test labels (not used by this function)
        sens_train (array-like): Training sensitive attributes
        sens_test (array-like): Test sensitive attributes
        A_train (array-like): Training split attribute
        seed (int): Random seed for reproducibility. Seeds both models and
            the randomized fair predictions.

    Returns:
        tuple: ((p1r, p2r), (p1f, p2f))
            - p1r: Raw predictions from model trained on A=1
            - p2r: Raw predictions from model trained on A=0
            - p1f: Fair predictions from model trained on A=1
            - p2f: Fair predictions from model trained on A=0
    """
    # Split data based on attribute A (masks computed once and reused)
    in_a1 = A_train == 1
    in_a0 = A_train == 0
    X1, y1 = X_full[in_a1], y_full[in_a1]
    X0, y0 = X_full[in_a0], y_full[in_a0]

    # Train separate models on each subgroup
    m1 = LogisticRegression(max_iter=1000, random_state=seed).fit(X1, y1)
    m2 = LogisticRegression(max_iter=1000, random_state=seed).fit(X0, y0)

    # Generate raw predictions
    p1r = m1.predict(X_test)
    p2r = m2.predict(X_test)

    # Apply fairness correction to each model on its own training subgroup
    post1 = fairness_correction(m1, X1, y1, sens_train[in_a1])
    post2 = fairness_correction(m2, X0, y0, sens_train[in_a0])

    # ThresholdOptimizer.predict can randomize its output, so it must be
    # seeded for the run to be reproducible. One generator is shared by both
    # calls so the two classifiers consume different draws; passing the same
    # integer to both would hand them identical random numbers.
    pred_rng = seed if isinstance(seed, np.random.RandomState) else np.random.RandomState(seed)

    # Generate fair predictions
    p1f = post1.predict(X_test, sensitive_features=sens_test, random_state=pred_rng)
    p2f = post2.predict(X_test, sensitive_features=sens_test, random_state=pred_rng)

    return (p1r, p2r), (p1f, p2f)



def run_exp3(X_full, y_full, X_test, y_test, sens_train, sens_test, A_train, seed):
    """
    Run experiment 3: Two different model types trained on different data splits.

    Trains a logistic regression on the A=1 subgroup and a depth-5 decision
    tree on the A=0 subgroup of the training data, wraps each with
    ``fairness_correction`` (fitted on that model's own subgroup) and
    returns the raw and the fairness-corrected predictions on the full
    test set.

    Args:
        X_full (array-like): Full training features
        y_full (array-like): Full training labels
        X_test (array-like): Test features
        y_test (array-like): Test labels (not used by this function)
        sens_train (array-like): Training sensitive attributes
        sens_test (array-like): Test sensitive attributes
        A_train (array-like): Training split attribute
        seed (int): Random seed for reproducibility. Seeds both models and
            the randomized fair predictions.

    Returns:
        tuple: ((p1r, p2r), (p1f, p2f))
            - p1r: Raw predictions from the logistic regression (A=1)
            - p2r: Raw predictions from the decision tree (A=0)
            - p1f: Fair predictions from the logistic regression (A=1)
            - p2f: Fair predictions from the decision tree (A=0)
    """
    # Split data based on attribute A (masks computed once and reused)
    in_a1 = A_train == 1
    in_a0 = A_train == 0
    X1, y1 = X_full[in_a1], y_full[in_a1]
    X0, y0 = X_full[in_a0], y_full[in_a0]

    # Train a different model type on each subgroup
    m1 = LogisticRegression(max_iter=1000, random_state=seed).fit(X1, y1)
    m2 = DecisionTreeClassifier(max_depth=5, random_state=seed).fit(X0, y0)

    # Generate raw predictions
    p1r = m1.predict(X_test)
    p2r = m2.predict(X_test)

    # Apply fairness correction to each model on its own training subgroup
    post1 = fairness_correction(m1, X1, y1, sens_train[in_a1])
    post2 = fairness_correction(m2, X0, y0, sens_train[in_a0])

    # ThresholdOptimizer.predict can randomize its output, so it must be
    # seeded for the run to be reproducible. One generator is shared by both
    # calls so the two classifiers consume different draws; passing the same
    # integer to both would hand them identical random numbers.
    pred_rng = seed if isinstance(seed, np.random.RandomState) else np.random.RandomState(seed)

    # Generate fair predictions
    p1f = post1.predict(X_test, sensitive_features=sens_test, random_state=pred_rng)
    p2f = post2.predict(X_test, sensitive_features=sens_test, random_state=pred_rng)

    return (p1r, p2r), (p1f, p2f)



def run_exp4(X_full, y_full, X_test, y_test, sens_train, sens_test, A_train, A_test, rng):
    """
    Run modified experiment 4: Compare FNR gap before and after fairness correction.

    Trains one model on all data (with and without fairness correction)
    and one model on mortgage holders only (A=1). Competition, i.e. the
    logical OR of both classifiers, occurs only on mortgage holders in the
    test set; non-mortgage samples are decided by the all-data classifier
    alone. Compares the FNR gap between non-mortgage and mortgage samples
    before and after fairness correction.

    Args:
        X_full (array-like): Full training features
        y_full (array-like): Full training labels
        X_test (array-like): Test features
        y_test (array-like): Test labels
        sens_train (array-like): Training sensitive attributes
        sens_test (array-like): Test sensitive attributes
        A_train (array-like): Training split attribute (mortgage status)
        A_test (array-like): Test split attribute (mortgage status)
        rng (RandomState): Random number generator. Exactly two integers
            are drawn from it, one seed per trained model.

    Returns:
        tuple: (p_or_raw, p_or_fair, fnr_nonmort_raw, fnr_mort_raw, fnr_nonmort_fair, fnr_mort_fair, fnr_gap_raw, fnr_gap_fair)
            - p_or_raw: Raw OR-combined predictions
            - p_or_fair: Fair OR-combined predictions
            - fnr_nonmort_raw: Raw FNR for non-mortgage holders
            - fnr_mort_raw: Raw FNR for mortgage holders (with competition)
            - fnr_nonmort_fair: Fair FNR for non-mortgage holders
            - fnr_mort_fair: Fair FNR for mortgage holders (with competition)
            - fnr_gap_raw: Raw FNR gap between groups (non-mortgage minus mortgage)
            - fnr_gap_fair: Fair FNR gap between groups (non-mortgage minus mortgage)
    """
    # Train classifier on ALL data
    seed_all = rng.randint(0, 1e9)
    m_all = LogisticRegression(max_iter=1000, random_state=seed_all).fit(X_full, y_full)

    # Apply fairness correction to the classifier trained on all data
    post_all = fairness_correction(m_all, X_full, y_full, sens_train)

    # Raw predictions from all-data classifier
    p_all_raw = m_all.predict(X_test)

    # Fair predictions from all-data classifier. ThresholdOptimizer.predict
    # can randomize its output, so it is seeded for reproducibility. The seed
    # already drawn for m_all is reused so that no extra value is consumed
    # from the caller's rng and the raw results stay exactly as before.
    p_all_fair = post_all.predict(X_test, sensitive_features=sens_test, random_state=seed_all)

    # Train classifier only on mortgage holders
    in_mort_train = A_train == 1
    X_mort = X_full[in_mort_train]
    y_mort = y_full[in_mort_train]
    m_mort = LogisticRegression(max_iter=1000, random_state=rng.randint(0, 1e9)).fit(X_mort, y_mort)

    # Positional numpy views so pandas and plain array inputs both work
    y_test_arr = np.asarray(y_test)
    a_test_arr = np.asarray(A_test)
    mask_mort_test = a_test_arr == 1
    mask_nonmort_test = a_test_arr == 0

    # Predictions from mortgage-only classifier (only applied to mortgage holders);
    # skipped when the test set has none, to avoid predicting on an empty selection
    p_mort = np.zeros_like(y_test_arr)
    if mask_mort_test.any():
        p_mort[mask_mort_test] = m_mort.predict(X_test[mask_mort_test])

    def or_with_competition(p_all):
        # Non-mortgage samples keep the all-data prediction; mortgage samples
        # get the logical OR (a + b - a*b for 0/1 values) of both classifiers.
        combined = np.copy(p_all)
        own = p_all[mask_mort_test]
        rival = p_mort[mask_mort_test]
        combined[mask_mort_test] = own + rival - own * rival
        return combined

    # Competition only happens on mortgage holders
    p_or_raw = or_with_competition(p_all_raw)
    p_or_fair = or_with_competition(p_all_fair)

    # FNR for non-mortgage group
    fnr_nonmort_raw = false_negative_rate(y_test_arr[mask_nonmort_test], p_or_raw[mask_nonmort_test])
    fnr_nonmort_fair = false_negative_rate(y_test_arr[mask_nonmort_test], p_or_fair[mask_nonmort_test])

    # FNR for mortgage group (with competition)
    fnr_mort_raw = false_negative_rate(y_test_arr[mask_mort_test], p_or_raw[mask_mort_test])
    fnr_mort_fair = false_negative_rate(y_test_arr[mask_mort_test], p_or_fair[mask_mort_test])

    # Calculate gaps between groups
    fnr_gap_raw = fnr_nonmort_raw - fnr_mort_raw
    fnr_gap_fair = fnr_nonmort_fair - fnr_mort_fair

    return p_or_raw, p_or_fair, fnr_nonmort_raw, fnr_mort_raw, fnr_nonmort_fair, fnr_mort_fair, fnr_gap_raw, fnr_gap_fair