import pandas as pd
import numpy as np
import argparse
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from fairlearn.metrics import false_negative_rate
import os

# Import utility functions from utils.py
from utils import (
    compute_metrics, 
    compute_metrics_with_correlation,
    compute_group_fnr_gap, 
    fairness_correction, 
    compute_descriptive_stats,
    run_exp1, 
    run_exp2, 
    run_exp3,  
    run_exp4
)

# =============================
# Main Script
# =============================

# Input file read when --data_path is not given (resolved against the CWD).
_DEFAULT_DATA_PATH = "loan_extended_with_A_candidates.csv"
# Upper bound on the number of CSV rows read.
_MAX_ROWS = 500000
# Number of leading rows (after filtering) held out as the fixed test set.
_TEST_SIZE = 100000

_FEATURES = ['loan_amnt', 'int_rate', 'installment', 'annual_inc']
_SPLIT_COL = 'A__term__is_ 36 months'
_SENSITIVE_COL = 'A__home_ownership__is_MORTGAGE'


def _parse_args(argv=None):
    """Parse command-line options; reject experiments that have no implementation."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp', required=True,
                        choices=['exp1', 'exp2', 'exp3', 'exp4', 'exp12', 'stats'])
    parser.add_argument('--train_size', type=int, required=True)
    parser.add_argument('--seed', type=int, required=True)
    parser.add_argument('--output_dir', default="results")
    parser.add_argument('--data_path', default=_DEFAULT_DATA_PATH,
                        help="CSV file with the loan data")
    args = parser.parse_args(argv)

    # 'exp12' is an accepted choice but this script has no branch for it.
    # Previously it silently ran experiment 4 and saved the output under an
    # 'exp12_...' filename; fail loudly instead of producing mislabeled results.
    if args.exp == 'exp12':
        parser.error("--exp exp12 is not implemented by this script")
    return args


def _load_loans(data_path):
    """Read the loan CSV, keep finished loans with complete features, add label 'y'.

    'y' is 1 for 'Fully Paid' rows and 0 for 'Charged Off' rows.
    """
    df = pd.read_csv(data_path, low_memory=False, nrows=_MAX_ROWS)
    df = df[df['loan_status'].isin(['Fully Paid', 'Charged Off'])]
    df = df.dropna(subset=['annual_inc', 'loan_amnt', 'int_rate', 'installment'])
    df['y'] = (df['loan_status'] == 'Fully Paid').astype(int)
    return df


def _record_pairwise_metrics(y_true, sensitive, raw_pair, fair_pair, results):
    """Fill ``results`` with the metrics shared by experiments 1-3.

    Args:
        y_true: test labels as an array.
        sensitive: test values of the sensitive column as an array.
        raw_pair: ``(r1, r2)`` predictions of the two uncorrected classifiers.
        fair_pair: ``(f1, f2)`` predictions of the two corrected classifiers.
        results: dict updated in place; insertion order defines the CSV columns.
    """
    r1, r2 = raw_pair
    f1, f2 = fair_pair

    # compute_group_fnr_gap writes its outputs into `results` under the given
    # prefix; the '<prefix>_gap' entries are read back below.
    per_model = (('fnr1_raw', r1), ('fnr1_fair', f1),
                 ('fnr2_raw', r2), ('fnr2_fair', f2))
    for prefix, preds in per_model:
        compute_group_fnr_gap(y_true, preds, sensitive, prefix, results)

    # OR-combination of the two classifiers: a + b - a*b is 1 when either
    # prediction is 1 (for 0/1 predictions).
    or_raw = r1 + r2 - r1 * r2
    or_fair = f1 + f2 - f1 * f2
    compute_group_fnr_gap(y_true, or_raw, sensitive, 'fnr_or_raw', results)
    compute_group_fnr_gap(y_true, or_fair, sensitive, 'fnr_or_fair', results)

    # Raw-minus-fair gap differences, stored explicitly for later analysis.
    for stem in ('fnr1', 'fnr2', 'fnr_or'):
        results[f'{stem}_rawMinusFair_gap'] = (
            results[f'{stem}_raw_gap'] - results[f'{stem}_fair_gap'])

    # Correlation between the two classifiers: overall, and per value of the
    # sensitive column (s0 / s1 -- presumably sensitive == 0 / == 1, i.e.
    # non-mortgage / mortgage; confirm against utils).
    for label, (p1, p2) in (('raw', (r1, r2)), ('fair', (f1, f2))):
        _, _, _, corr, corr_s0, corr_s1 = compute_metrics_with_correlation(
            y_true, p1, p2, sensitive)
        results[f'fnr_corr_{label}'] = corr
        results[f'fnr_corr_{label}_s0'] = corr_s0
        results[f'fnr_corr_{label}_s1'] = corr_s1


def _record_exp4_subgroup_fnr(y_true, split, sensitive, p_or_raw, p_or_fair, results):
    """Store the FNR of every (split value, sensitive value) cell for experiment 4.

    Cells with no test rows are skipped, so their keys are absent from
    ``results``.
    """
    # NOTE(review): the 'nonmort'/'mort' labels are driven by the *split*
    # column (_SPLIT_COL, a loan-term flag), while the _s0/_s1 suffix comes
    # from the sensitive column (_SENSITIVE_COL, the mortgage flag). The key
    # names are kept as-is for compatibility with downstream analysis --
    # confirm this labeling is intended.
    groups = (('nonmort', split == 0), ('mort', split == 1))
    for label, preds in (('raw', p_or_raw), ('fair', p_or_fair)):
        for group, group_mask in groups:
            for s_value in (0, 1):
                cell = (sensitive == s_value) & group_mask
                if cell.any():
                    results[f'fnr_{group}_{label}_s{s_value}'] = false_negative_rate(
                        y_true[cell], preds[cell])


def main(argv=None):
    """Run one experiment (or the descriptive statistics) and save a CSV.

    Args:
        argv: optional list of command-line tokens; ``None`` uses ``sys.argv``.

    Raises:
        ValueError: if ``--train_size`` is larger than the number of rows left
            for training after the test hold-out.
    """
    args = _parse_args(argv)
    os.makedirs(args.output_dir, exist_ok=True)

    rng = np.random.RandomState(args.seed)

    # Load data
    df = _load_loans(args.data_path)

    X_all = df[_FEATURES]
    y_all = df['y']
    A_all = df[_SPLIT_COL]
    S_all = df[_SENSITIVE_COL]

    # Compute and save descriptive statistics if requested
    if args.exp == 'stats':
        stats = compute_descriptive_stats(df)
        stats_df = pd.DataFrame([stats]).T
        stats_df.to_csv(os.path.join(args.output_dir, "descriptive_stats.csv"), index=True)
        print(f"Saved descriptive statistics to {args.output_dir}/descriptive_stats.csv")
        # Return instead of calling the site-provided `exit`, which is meant
        # for interactive use and is not available in every interpreter setup.
        return

    # Fixed test set: the first _TEST_SIZE filtered rows; the rest form the
    # pool that training samples are drawn from.
    test_idx = df.index[:_TEST_SIZE]
    train_pool_idx = df.index[_TEST_SIZE:]

    if args.train_size > len(train_pool_idx):
        raise ValueError(
            f"--train_size={args.train_size} exceeds the training pool of "
            f"{len(train_pool_idx)} rows left after holding out the first "
            f"{_TEST_SIZE} filtered rows for testing")

    X_test = X_all.loc[test_idx]
    y_test = y_all.loc[test_idx]
    A_test = A_all.loc[test_idx]
    S_test = S_all.loc[test_idx]

    train_idx = rng.choice(train_pool_idx, size=args.train_size, replace=False)
    X_train = X_all.loc[train_idx]
    y_train = y_all.loc[train_idx]
    A_train = A_all.loc[train_idx]
    S_train = S_all.loc[train_idx]

    results = {'exp': args.exp, 'train_size': args.train_size, 'seed': args.seed}
    y_true = y_test.values
    sensitive = S_test.values

    # Run the specified experiment
    if args.exp == 'exp1':
        raw_pair, fair_pair = run_exp1(
            X_train, y_train, X_test, y_test, S_train, S_test, args.seed)
        _record_pairwise_metrics(y_true, sensitive, raw_pair, fair_pair, results)

    elif args.exp in ('exp2', 'exp3'):
        # exp2 and exp3 take the same arguments (exp1 plus A_train).
        runner = run_exp2 if args.exp == 'exp2' else run_exp3
        raw_pair, fair_pair = runner(
            X_train, y_train, X_test, y_test, S_train, S_test, A_train, args.seed)
        _record_pairwise_metrics(y_true, sensitive, raw_pair, fair_pair, results)

    else:  # exp4 (exp12 and stats were handled above)
        # Note: run_exp4 receives the RandomState itself, not the integer seed.
        (p_or_raw, p_or_fair,
         fnr_nonmort_raw, fnr_mort_raw,
         fnr_nonmort_fair, fnr_mort_fair,
         fnr_gap_raw, fnr_gap_fair) = run_exp4(
            X_train, y_train, X_test, y_test, S_train, S_test, A_train, A_test, rng)

        # Store FNR results
        results['fnr_nonmort_raw'] = fnr_nonmort_raw
        results['fnr_mort_raw'] = fnr_mort_raw
        results['fnr_gap_raw'] = fnr_gap_raw
        results['fnr_nonmort_fair'] = fnr_nonmort_fair
        results['fnr_mort_fair'] = fnr_mort_fair
        results['fnr_gap_fair'] = fnr_gap_fair
        results['fnr_gap_improvement'] = fnr_gap_raw - fnr_gap_fair

        _record_exp4_subgroup_fnr(
            y_true, A_test.values, sensitive, p_or_raw, p_or_fair, results)

    # Save results to CSV (one row; columns follow the insertion order above)
    out_name = f"{args.exp}_size{args.train_size}_seed{args.seed}.csv"
    pd.DataFrame([results]).to_csv(os.path.join(args.output_dir, out_name), index=False)
    print(f"Saved results to {args.output_dir}/{out_name}")


if __name__ == "__main__":
    main()