import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

# Fairness metrics
def demographic_parity_difference(y_true, y_pred, sensitive_feature):
    """Return the largest gap in positive-prediction rate between groups.

    For each distinct value of ``sensitive_feature`` the mean of ``y_pred``
    over that group is computed (the positive-prediction rate when
    predictions are 0/1). The result is ``max(rate) - min(rate)``. With two
    groups this equals ``abs(rate_a - rate_b)``.

    Args:
        y_true: unused; accepted so all metrics share one signature.
        y_pred: per-sample predictions (array-like).
        sensitive_feature: per-sample group labels (array-like).

    Returns:
        The rate gap, or 0 when fewer than two groups are present.
    """
    y_pred = np.asarray(y_pred)
    sensitive_feature = np.asarray(sensitive_feature)

    groups = np.unique(sensitive_feature)
    if groups.size < 2:
        return 0

    # Every group from np.unique is non-empty, so each mean is well defined.
    rates = [np.mean(y_pred[sensitive_feature == g]) for g in groups]
    # Range over all groups rather than comparing only the first two.
    return max(rates) - min(rates)

def equalized_odds_difference(y_true, y_pred, sensitive_feature):
    """Return the TPR gap plus the FPR gap across groups.

    For every distinct value of ``sensitive_feature`` the true-positive rate
    (share of ``y_true == 1`` rows predicted 1) and false-positive rate (share
    of ``y_true == 0`` rows predicted 1) are computed. The result is
    ``(max TPR - min TPR) + (max FPR - min FPR)``. With two groups this equals
    ``|TPR_a - TPR_b| + |FPR_a - FPR_b|``.

    Args:
        y_true: binary ground-truth labels (array-like).
        y_pred: binary predictions (array-like).
        sensitive_feature: per-sample group labels (array-like).

    Returns:
        The summed gap, or 0 when fewer than two groups are present.

    NOTE(review): a group with no positives (or no negatives) is assigned a
    TPR (or FPR) of 0, which can inflate the gap for tiny groups.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    sensitive_feature = np.asarray(sensitive_feature)

    groups = np.unique(sensitive_feature)
    if groups.size < 2:
        return 0

    def _positive_rate(preds, mask):
        # Share of the masked rows predicted 1; 0 when the mask selects nothing.
        n = np.sum(mask)
        return np.sum(preds[mask] == 1) / n if n > 0 else 0

    tprs, fprs = [], []
    for g in groups:
        in_group = sensitive_feature == g
        y_t = y_true[in_group]
        y_p = y_pred[in_group]
        tprs.append(_positive_rate(y_p, y_t == 1))
        fprs.append(_positive_rate(y_p, y_t == 0))

    # Ranges over all groups rather than comparing only the first two.
    return (max(tprs) - min(tprs)) + (max(fprs) - min(fprs))

def predictive_parity_difference(y_true, y_pred, sensitive_feature):
    """Return the largest gap in positive predictive value (PPV) between groups.

    For every distinct value of ``sensitive_feature`` the PPV is the share of
    rows predicted 1 whose true label is also 1. The result is
    ``max(PPV) - min(PPV)``. With two groups this equals ``|PPV_a - PPV_b|``.

    Args:
        y_true: binary ground-truth labels (array-like).
        y_pred: binary predictions (array-like).
        sensitive_feature: per-sample group labels (array-like).

    Returns:
        The PPV gap, or 0 when fewer than two groups are present.

    NOTE(review): a group with no predicted positives is assigned PPV 0,
    which can inflate the gap.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    sensitive_feature = np.asarray(sensitive_feature)

    groups = np.unique(sensitive_feature)
    if groups.size < 2:
        return 0

    ppvs = []
    for g in groups:
        in_group = sensitive_feature == g
        pred_pos = y_pred[in_group] == 1
        n_pred_pos = np.sum(pred_pos)
        if n_pred_pos > 0:
            # Rows are already restricted to predicted positives, so only y_true matters.
            ppvs.append(np.sum(y_true[in_group][pred_pos] == 1) / n_pred_pos)
        else:
            ppvs.append(0)

    # Range over all groups rather than comparing only the first two.
    return max(ppvs) - min(ppvs)

# Load and preprocess Adult dataset
def load_adult_data():
    """Fetch the Adult Income data (OpenML 'adult', version 2) and encode it.

    Twelve columns are kept as model features: five numeric ones
    (standardized) and seven categorical ones (one-hot encoded, first level of
    each dropped, unseen levels ignored). ``sex`` is not a feature; it is
    returned separately as the sensitive attribute. Missing values in each
    kept column are filled with that column's first mode.

    Returns:
        X_processed: encoded feature matrix, numeric block first, then one-hot.
        y: int array, 1 where the target equals '>50K', else 0.
        sensitive: int array, 1 where sex == 'Male', else 0.
        preprocessor: the already-fitted ColumnTransformer.

    NOTE(review): the transformer is fitted on every row before any
    train/test split, so scaling statistics and category sets include rows
    that later serve as test data (mild leakage).
    """
    from sklearn.datasets import fetch_openml

    bunch = fetch_openml('adult', version=2, as_frame=True, parser='auto')
    raw = bunch.data

    labels = (bunch.target == '>50K').astype(int)
    is_male = (raw['sex'] == 'Male').astype(int)

    numeric_cols = ['age', 'education-num', 'hours-per-week', 'capital-gain', 'capital-loss']
    categorical_cols = ['workclass', 'education', 'marital-status', 'occupation',
                        'relationship', 'race', 'native-country']

    features = raw[numeric_cols + categorical_cols]
    # Per-column imputation: mode() yields one row per tied mode; take the first.
    features = features.fillna(features.mode().iloc[0])

    one_hot = OneHotEncoder(drop='first', sparse_output=False, handle_unknown='ignore')
    preprocessor = ColumnTransformer(transformers=[
        ('num', StandardScaler(), numeric_cols),
        ('cat', one_hot, categorical_cols),
    ])

    return preprocessor.fit_transform(features), labels.values, is_male.values, preprocessor

# Train models with different fairness objectives
def train_fairness_models_adult(X_train, y_train, sensitive_train):
    """Fit the five logistic-regression variants compared in the experiment.

    Every variant is ``LogisticRegression(max_iter=1000, solver='lbfgs')``
    and differs only as follows:

    - 'Accuracy': no extra settings, unweighted fit.
    - 'Demographic_Parity': rows with ``sensitive_train == 0`` get sample
      weight 2.5 (all others 1.0), C=1.0.
    - 'Equalized_Odds': C=0.1, unweighted.
    - 'Predictive_Parity': class_weight='balanced', unweighted by group.
    - 'Balanced': rows with ``sensitive_train == 0`` get sample weight 1.5,
      C=0.5.

    Only 'Demographic_Parity' and 'Balanced' use the sensitive attribute.
    'Equalized_Odds' and 'Predictive_Parity' apply no group-aware constraint,
    so their names describe the intended stakeholder lens rather than an
    enforced fairness criterion.

    NOTE(review): scikit-learn documents ``random_state`` as used only by the
    sag/saga/liblinear solvers, so the distinct seeds likely have no effect
    under lbfgs. Confirm this before relying on them for variety.

    Returns:
        dict mapping the names above (in that order) to fitted models.
    """
    # sensitive_train == 0 is female under load_adult_data's encoding.
    group_zero = sensitive_train == 0

    def _group_weights(weight_for_group_zero):
        # Unit weights everywhere except the sensitive_train == 0 rows.
        weights = np.ones(len(y_train))
        weights[group_zero] = weight_for_group_zero
        return weights

    # (model name, extra LogisticRegression kwargs, group-zero weight or None)
    variants = [
        ('Accuracy', {'random_state': 42}, None),
        ('Demographic_Parity', {'random_state': 43, 'C': 1.0}, 2.5),
        ('Equalized_Odds', {'random_state': 44, 'C': 0.1}, None),
        ('Predictive_Parity', {'random_state': 45, 'class_weight': 'balanced'}, None),
        ('Balanced', {'random_state': 46, 'C': 0.5}, 1.5),
    ]

    fitted = {}
    for name, extra_kwargs, group_weight in variants:
        estimator = LogisticRegression(max_iter=1000, solver='lbfgs', **extra_kwargs)
        if group_weight is None:
            estimator.fit(X_train, y_train)
        else:
            estimator.fit(X_train, y_train, sample_weight=_group_weights(group_weight))
        fitted[name] = estimator
    return fitted

# Create preference profile
def create_fairness_preference_profile(models, X_test, y_test, sensitive_test):
    """Score each model from five stakeholder viewpoints and aggregate by pairwise majority vote.

    Each model is evaluated once on the test split. The evaluation yields
    accuracy plus the demographic-parity, equalized-odds and predictive-parity
    gaps defined in this module. Each stakeholder maps those four numbers to a
    scalar utility with a fixed linear formula.

    Args:
        models: mapping of model name -> fitted estimator exposing ``predict``.
        X_test: held-out features.
        y_test: held-out binary labels.
        sensitive_test: group label per test row.

    Returns:
        preference_matrix: (k, k) array. Entry [i, j] equals
            (#stakeholders scoring model i strictly above j
             - #stakeholders scoring j strictly above i) / #stakeholders.
            It is antisymmetric with a zero diagonal and values in [-1, 1].
        model_names: model names in row/column order.
        scores_matrix: (n_stakeholders, k) utility table.
        stakeholders: dict of stakeholder name -> utility function.
    """
    model_names = list(models.keys())
    k = len(model_names)

    # Utility = f(accuracy, dp_gap, eo_gap, pp_gap); fairness gaps enter negatively.
    stakeholders = {
        'Business': lambda acc, dp, eo, pp: acc,  # Only cares about accuracy
        'Civil_Rights': lambda acc, dp, eo, pp: -dp - 0.5*eo,  # Focus on group fairness
        'Equal_Opportunity': lambda acc, dp, eo, pp: -eo + 0.3*acc,  # Balance opportunity
        'Calibration': lambda acc, dp, eo, pp: -pp + 0.2*acc,  # Care about predictive parity
        'Regulator': lambda acc, dp, eo, pp: 0.6*acc - 0.2*dp - 0.2*eo  # Balanced view
    }

    # Metrics depend only on the model, not the stakeholder. Compute them once
    # per model instead of once per (stakeholder, model) pair.
    model_metrics = []
    for name in model_names:
        pred = models[name].predict(X_test)
        model_metrics.append((
            accuracy_score(y_test, pred),
            demographic_parity_difference(y_test, pred, sensitive_test),
            equalized_odds_difference(y_test, pred, sensitive_test),
            predictive_parity_difference(y_test, pred, sensitive_test),
        ))

    scores_matrix = np.zeros((len(stakeholders), k))
    for i, utility_fn in enumerate(stakeholders.values()):
        for j, metrics in enumerate(model_metrics):
            scores_matrix[i, j] = utility_fn(*metrics)

    # Pairwise Majority Vote: wins[s, i, j] is True when stakeholder s scores
    # model i strictly above model j (ties count for neither side).
    wins = scores_matrix[:, :, None] > scores_matrix[:, None, :]
    win_counts = wins.sum(axis=0)
    preference_matrix = (win_counts - win_counts.T) / len(stakeholders)

    return preference_matrix, model_names, scores_matrix, stakeholders

# Hodge decomposition
def hodge_decomposition(preference_matrix):
    """Split a pairwise preference flow into gradient (consistent) and cycle parts.

    Only the strict upper triangle is read: entry [i, j] (i < j) is the flow
    on edge (i, j). The matrix is assumed antisymmetric. A node potential
    ``s`` is fitted by least squares so that ``s[j] - s[i]`` approximates
    each edge flow. The fitted edge values form the gradient part. The
    residual, which is orthogonal to every gradient flow, forms the cycle
    part, so ``total_norm**2 == cycle_norm**2 + ||gradient||**2``.

    Sign note: when flows are positive for "i preferred over j", preferred
    items receive a lower potential. The potential itself is not returned.

    Args:
        preference_matrix: (k, k) array-like of pairwise flows.

    Returns:
        cycle_matrix: (k, k) antisymmetric cycle component.
        gradient_matrix: (k, k) antisymmetric gradient component.
        cycle_norm: Euclidean norm of the cycle part over the k*(k-1)/2 edges.
        total_norm: Euclidean norm of the full edge flow.
        With fewer than two items, or if the least-squares solve fails to
        converge, zero matrices and 0.0 norms are returned.
    """
    pref = np.asarray(preference_matrix, dtype=float)
    k = pref.shape[0]
    if k < 2:
        # No edges: nothing to decompose.
        return np.zeros((k, k)), np.zeros((k, k)), 0.0, 0.0

    # Edges (i, j) with i < j, in row-major order.
    src, dst = np.triu_indices(k, k=1)
    edge_ids = np.arange(src.size)

    # Coboundary operator d0: maps a node potential s to edge flows s[j] - s[i].
    d0 = np.zeros((src.size, k))
    d0[edge_ids, src] = -1.0
    d0[edge_ids, dst] = 1.0

    flow = pref[src, dst]

    try:
        potential = np.linalg.lstsq(d0, flow, rcond=None)[0]
    except np.linalg.LinAlgError as err:
        # SVD failed to converge; fall back to an all-zero result as before.
        print(f"Decomposition error: {err}")
        return np.zeros((k, k)), np.zeros((k, k)), 0.0, 0.0

    gradient_flow = d0 @ potential
    cycle_flow = flow - gradient_flow

    # Reconstruct antisymmetric matrices from the edge vectors.
    gradient_matrix = np.zeros((k, k))
    gradient_matrix[src, dst] = gradient_flow
    gradient_matrix[dst, src] = -gradient_flow
    cycle_matrix = np.zeros((k, k))
    cycle_matrix[src, dst] = cycle_flow
    cycle_matrix[dst, src] = -cycle_flow

    return cycle_matrix, gradient_matrix, np.linalg.norm(cycle_flow), np.linalg.norm(flow)

# Main experiment
def _print_trial_details(model_names, stakeholders, scores_matrix, pref_matrix,
                         cycle_norm, total_norm, cycle_ratio):
    """Print the utility table, PMV matrix and Hodge norms for the first trial."""
    rule = '=' * 60
    print(f"\n{rule}")
    print("DETAILED ANALYSIS - TRIAL 1")
    print(rule)
    print(f"\nModel names: {model_names}")

    print("\nStakeholder utility scores:")
    print("(rows=stakeholders, cols=models)")
    score_table = pd.DataFrame(scores_matrix, index=list(stakeholders), columns=model_names)
    print(score_table.round(3))

    print("\nPreference matrix (PMV):")
    pmv_table = pd.DataFrame(pref_matrix, index=model_names, columns=model_names)
    print(pmv_table.round(3))

    print("\nHodge Decomposition Results:")
    print(f"  Cycle norm: {cycle_norm:.6f}")
    print(f"  Total norm: {total_norm:.6f}")
    print(f"  Cycle ratio: {cycle_ratio:.4f} ({cycle_ratio*100:.1f}%)")
    print(f"{rule}\n")


def run_adult_fairness_experiment(n_trials=10):
    """Repeat split -> train -> stakeholder vote -> Hodge decomposition on Adult data.

    Each trial uses a stratified 70/30 split seeded with ``42 + trial``. It
    trains the five model variants, builds the stakeholder PMV matrix, and
    records the cycle and total norms of its Hodge decomposition. The
    ratio is set to 0 when the total norm is at most 1e-10. Full details are
    printed for the first trial only.

    Returns:
        DataFrame with one row per trial and columns
        'trial', 'cycle_norm', 'total_norm', 'cycle_ratio'.
    """
    print("Loading Adult dataset...")
    X, y, sensitive, _preprocessor = load_adult_data()
    n_samples, n_features = X.shape
    print(f"Dataset loaded: {n_samples} samples, {n_features} features")
    print(f"Sensitive attribute (sex): {np.sum(sensitive==1)} male, {np.sum(sensitive==0)} female")
    print(f"Target (income): {np.sum(y==1)} high income, {np.sum(y==0)} low income\n")

    records = []
    for trial in range(n_trials):
        print(f"Trial {trial + 1}/{n_trials}")

        X_train, X_test, y_train, y_test, s_train, s_test = train_test_split(
            X, y, sensitive, test_size=0.3, random_state=42 + trial, stratify=y
        )
        models = train_fairness_models_adult(X_train, y_train, s_train)
        pref_matrix, model_names, scores_matrix, stakeholders = (
            create_fairness_preference_profile(models, X_test, y_test, s_test)
        )
        _, _, cycle_norm, total_norm = hodge_decomposition(pref_matrix)

        # Guard against a (near-)zero flow: report no cyclicity instead of dividing.
        if total_norm > 1e-10:
            cycle_ratio = cycle_norm / total_norm
        else:
            cycle_ratio = 0

        records.append({
            'trial': trial,
            'cycle_norm': cycle_norm,
            'total_norm': total_norm,
            'cycle_ratio': cycle_ratio,
        })

        if trial == 0:
            _print_trial_details(model_names, stakeholders, scores_matrix, pref_matrix,
                                 cycle_norm, total_norm, cycle_ratio)

    return pd.DataFrame(records)

# Run experiment
if __name__ == "__main__":
    results_df = run_adult_fairness_experiment(n_trials=20)

    # Cycle norm quoted for the standard-classification run ("Experiment 2").
    # This is a hard-coded reference figure, not something computed here.
    STANDARD_TASK_CYCLE_NORM = 1e-16

    cycle_norms = results_df['cycle_norm']
    cycle_ratios = results_df['cycle_ratio']
    # hodge_decomposition's cycle part is a least-squares residual, orthogonal
    # to the gradient part, so ||flow||^2 = ||gradient||^2 + ||cycle||^2.
    # The share of the flow that is cyclical is therefore the *squared* norm
    # ratio. The plain ratio is not a fraction of a whole.
    cycle_energy_share = (cycle_ratios ** 2).mean()

    print("\n" + "="*60)
    print("FINAL RESULTS SUMMARY")
    print("="*60)
    print(f"Mean cycle norm: {cycle_norms.mean():.6f} ± {cycle_norms.std():.6f}")
    print(f"Mean cycle ratio: {cycle_ratios.mean():.4f} ± {cycle_ratios.std():.4f}")
    print(f"Min cycle ratio: {cycle_ratios.min():.4f}")
    print(f"Max cycle ratio: {cycle_ratios.max():.4f}")

    print("\nComparison with Experiment 2 (Standard Classification):")
    print("  Standard tasks: cycle_norm ~ 10^-16")
    print(f"  Fairness tasks: cycle_norm ~ {cycle_norms.mean():.2e}")
    print(f"  Ratio: {cycle_norms.mean() / STANDARD_TASK_CYCLE_NORM:.2e}x larger")

    print("\nInterpretation:")
    print(f"  Cycle norm is {cycle_ratios.mean()*100:.1f}% of total flow norm (||cycle|| / ||flow||)")
    print(f"  {cycle_energy_share*100:.1f}% of preference-flow energy (squared norm) is cyclical")
    print("="*60)
