import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.pipeline import Pipeline
import cvxpy as cp
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Utility functions for fairness metrics
def demographic_parity_difference(y_pred, sensitive_features):
    """
    Return |P(y_pred = 1 | group 0) - P(y_pred = 1 | group 1)|.

    Group membership is read from ``sensitive_features`` (values 0 and 1).
    0 means both groups receive positive predictions at the same rate.
    """
    rate_by_group = [np.mean(y_pred[sensitive_features == group]) for group in (0, 1)]
    return abs(rate_by_group[0] - rate_by_group[1])

def equal_opportunity_difference(y_true, y_pred, sensitive_features):
    """
    Return the absolute gap in true-positive rate between group 0 and group 1.

    Only samples whose true label is 1 are considered. If either group has no
    such samples the gap is undefined and 0.0 is returned. 0 means both groups'
    positives are predicted positive at the same rate (equal opportunity).
    """
    positives = y_true == 1
    mask_group0 = positives & (sensitive_features == 0)
    mask_group1 = positives & (sensitive_features == 1)

    # Guard clause: no positives in one group -> nothing to compare.
    if not (mask_group0.any() and mask_group1.any()):
        return 0.0

    tpr_group0 = np.mean(y_pred[mask_group0])
    tpr_group1 = np.mean(y_pred[mask_group1])
    return abs(tpr_group0 - tpr_group1)

def disparate_impact(y_pred, sensitive_features):
    """
    Return the ratio P(y_pred = 1 | group 0) / P(y_pred = 1 | group 1).

    A value of 1 means both groups receive positive predictions at the same
    rate. If group 1 receives no positive predictions the ratio is reported
    as ``float('inf')``.
    """
    rate_group0, rate_group1 = (
        np.mean(y_pred[sensitive_features == group]) for group in (0, 1)
    )
    return float('inf') if rate_group1 == 0 else rate_group0 / rate_group1

# 1. Standard Logistic Regression (Baseline)
class StandardModel:
    """Baseline: an unconstrained logistic regression.

    ``fit`` accepts ``sensitive_features`` only so that every method in this
    file shares the same ``fit(X, y, sensitive_features)`` signature; the
    argument is not used.
    """

    def __init__(self):
        self.model = LogisticRegression(max_iter=1000)

    def fit(self, X, y, sensitive_features=None):
        # sensitive_features intentionally unused (API parity only)
        self.model.fit(X, y)

    def predict(self, X):
        """Return hard 0/1 predictions from the wrapped logistic regression."""
        return self.model.predict(X)

    def predict_proba(self, X):
        """Return the probability of the positive class (column 1)."""
        class_probabilities = self.model.predict_proba(X)
        return class_probabilities[:, 1]

# 2. Fairness through Unawareness (dropping sensitive attributes)
class FairnessUnawareness:
    """Fairness through unawareness: logistic regression trained without the
    sensitive attribute column(s).

    Parameters:
    - sensitive_columns: column(s) of ``X`` to remove before fitting and
      predicting. Use column labels when ``X`` is a pandas DataFrame and integer
      positions when it is array-like. ``None`` (the default) drops nothing,
      which is appropriate when ``X`` already excludes the sensitive attribute.
    """

    def __init__(self, sensitive_columns=None):
        self.model = LogisticRegression(max_iter=1000)
        # Normalise to a list so a single label/index can be passed directly.
        if sensitive_columns is None:
            sensitive_columns = []
        elif isinstance(sensitive_columns, (int, np.integer, str)):
            sensitive_columns = [sensitive_columns]
        self.sensitive_columns = list(sensitive_columns)

    def _drop_sensitive(self, X):
        """Return ``X`` without the configured sensitive columns (``X`` itself if none)."""
        if not self.sensitive_columns:
            return X
        if isinstance(X, pd.DataFrame):
            return X.drop(columns=self.sensitive_columns)
        return np.delete(np.asarray(X), self.sensitive_columns, axis=1)

    def fit(self, X, y, sensitive_features):
        # sensitive_features is not used for training -- that is the point of
        # "unawareness"; it is accepted for API parity with the other methods.
        # The same columns must be dropped at predict time, so the model always
        # sees the same feature layout.
        self.model.fit(self._drop_sensitive(X), y)

    def predict(self, X):
        """Return hard 0/1 predictions computed without the sensitive columns."""
        return self.model.predict(self._drop_sensitive(X))

    def predict_proba(self, X):
        """Return the positive-class probability computed without the sensitive columns."""
        return self.model.predict_proba(self._drop_sensitive(X))[:, 1]

# 3. Reweighing (pre-processing method)
class Reweighing:
    """Pre-processing reweighing followed by a weighted logistic regression.

    Each (group, label) cell receives the weight expected_count / observed_count,
    where expected_count = n_group * n_label / n_samples is the cell size if group
    and label were independent. Under these weights, group membership and label
    are independent in the training data.
    """

    def __init__(self):
        self.model = LogisticRegression(max_iter=1000)
        self.weights = None  # per-sample weights from the last fit() call

    @staticmethod
    def _compute_weights(y, sensitive_features):
        """Return the per-sample reweighing factors as a float array.

        Works for any discrete group and label values (binary 0/1 included).
        Cells with no observed samples assign nothing, since no sample needs a
        weight there.
        """
        labels = np.asarray(y).ravel()
        groups = np.asarray(sensitive_features).ravel()
        n_samples = len(labels)
        weights = np.ones(n_samples)
        if n_samples == 0:
            return weights

        for group_value in np.unique(groups):
            in_group = groups == group_value
            n_group = np.count_nonzero(in_group)
            for label_value in np.unique(labels):
                in_label = labels == label_value
                cell = in_group & in_label
                observed = np.count_nonzero(cell)
                if observed:
                    expected = n_group * np.count_nonzero(in_label) / n_samples
                    weights[cell] = expected / observed
        return weights

    def fit(self, X, y, sensitive_features):
        """Compute reweighing factors and fit the logistic regression with them."""
        self.weights = self._compute_weights(y, sensitive_features)
        self.model.fit(X, y, sample_weight=self.weights)

    def predict(self, X):
        """Return hard 0/1 predictions."""
        return self.model.predict(X)

    def predict_proba(self, X):
        """Return the probability of the positive class (column 1)."""
        return self.model.predict_proba(X)[:, 1]

# 4. Distributionally Robust Optimization (DRO) for fairness
class DROFairness:
    """Linear logistic classifier with an L2 robustness penalty and a hard
    demographic-parity constraint, solved as a convex program with CVXPY.

    Parameters:
    - radius: weight of the ``radius * ||w||_2`` penalty, which stands in for the
      worst-case loss over a Wasserstein ball of this radius
    - fairness_constraint: maximum allowed absolute gap between the mean raw
      decision scores of group 0 and group 1
    """

    def __init__(self, radius=0.1, fairness_constraint=0.05):
        self.model = None  # unused; kept for attribute parity with the sklearn-based methods
        self.radius = radius  # Wasserstein ball radius
        self.fairness_constraint = fairness_constraint
        self.scaler = StandardScaler()
        self.w = None  # learned weights (set by fit)
        self.b = None  # learned intercept (set by fit)

    @staticmethod
    def _signed_labels(y):
        """Map binary labels to the {-1, +1} encoding the logistic loss requires.

        {0, 1} labels are mapped 0 -> -1, 1 -> +1; {-1, +1} labels pass through.
        Raises ValueError for any other label set.
        """
        labels = np.asarray(y, dtype=float).ravel()
        present = set(np.unique(labels).tolist())
        if present <= {0.0, 1.0}:
            return 2.0 * labels - 1.0
        if present <= {-1.0, 1.0}:
            return labels
        raise ValueError(f"Expected binary labels in {{0, 1}} or {{-1, 1}}, got {sorted(present)}")

    @staticmethod
    def _solve(problem):
        """Solve with ECOS, falling back to CVXPY's default solver if ECOS fails.

        ECOS may be absent: it is an optional dependency in recent CVXPY releases.
        """
        try:
            problem.solve(solver=cp.ECOS)
        except cp.error.SolverError:
            problem.solve()

    def fit(self, X, y, sensitive_features):
        """Solve the constrained program and store ``self.w`` / ``self.b``.

        Parameters:
        - X: feature matrix; standardised internally with ``self.scaler``
        - y: binary labels in {0, 1} (or {-1, 1})
        - sensitive_features: group indicator; groups 0 and 1 are constrained

        Raises ValueError for non-binary labels and RuntimeError if the solver
        returns no solution.
        """
        X_scaled = self.scaler.fit_transform(X)
        n, d = X_scaled.shape
        y_signed = self._signed_labels(y)
        groups = np.asarray(sensitive_features).ravel()

        w = cp.Variable(d)
        b = cp.Variable()
        predictions = X_scaled @ w + b

        # Logistic loss log(1 + exp(-y * f(x))) is only correct for y in {-1, +1}:
        # with raw {0, 1} labels every negative example would contribute the
        # constant log(2) and be ignored by the optimiser.
        loss = cp.sum(cp.logistic(cp.multiply(-y_signed, predictions))) / n

        # Demographic-parity proxy: gap between the groups' mean raw scores.
        constraints = []
        group0 = groups == 0
        group1 = groups == 1
        n_group0 = int(np.count_nonzero(group0))
        n_group1 = int(np.count_nonzero(group1))
        if n_group0 and n_group1:  # the gap is undefined if a group is absent
            gap = cp.sum(predictions[group0]) / n_group0 - cp.sum(predictions[group1]) / n_group1
            constraints.append(cp.abs(gap) <= self.fairness_constraint)

        # L2 norm of the weights as a simple DRO regulariser
        dro_reg = self.radius * cp.norm(w, 2)

        problem = cp.Problem(cp.Minimize(loss + dro_reg), constraints)
        self._solve(problem)
        if w.value is None:
            raise RuntimeError(
                f"DROFairness optimisation did not produce a solution (status: {problem.status})"
            )

        self.w = w.value
        self.b = b.value

    def _decision_scores(self, X):
        """Return the raw linear scores ``scale(X) @ w + b``."""
        if getattr(self, "w", None) is None:
            raise RuntimeError("DROFairness must be fitted before predicting")
        return self.scaler.transform(X) @ self.w + self.b

    def predict(self, X):
        """Return hard 0/1 predictions (score > 0)."""
        return (self._decision_scores(X) > 0).astype(int)

    def predict_proba(self, X):
        """Return the sigmoid of the decision score as the positive-class probability."""
        return 1 / (1 + np.exp(-self._decision_scores(X)))

# 5. Adversarial Debiasing (in-processing method)
class AdversarialDebiasing:
    """In-processing adversarial debiasing.

    A classifier network predicts ``y``. An adversary network tries to recover
    the sensitive attribute from the classifier's output probability. The
    classifier is trained to minimise its own BCE loss minus
    ``adversary_weight`` times the adversary's BCE loss, i.e. to be accurate
    while leaking as little group information as possible through its output.

    Parameters:
    - epochs, batch_size: training schedule
    - learning_rate: Adam learning rate for the classifier
    - adversary_lr: Adam learning rate for the adversary
    - adversary_weight: strength of the adversarial term in the classifier loss
    """

    def __init__(self, epochs=50, batch_size=128, learning_rate=0.001, adversary_lr=0.01,
                 adversary_weight=0.1):
        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.adversary_lr = adversary_lr
        self.adversary_weight = adversary_weight
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _to_tensor(self, values, as_column=False):
        """Convert array-like input (NumPy, pandas, lists) to a float32 tensor on ``self.device``.

        Going through ``np.asarray`` lets pandas Series/DataFrames be used directly
        (a Series has no ``.reshape``).
        """
        array = np.asarray(values, dtype=np.float32)
        if as_column:
            array = array.reshape(-1, 1)
        return torch.tensor(array, device=self.device)

    def fit(self, X, y, sensitive_features):
        """Train classifier and adversary with alternating updates.

        ``y`` and ``sensitive_features`` are used as BCE targets, so they are
        expected to be 0/1 valued.
        """
        X_tensor = self._to_tensor(X)
        y_tensor = self._to_tensor(y, as_column=True)
        s_tensor = self._to_tensor(sensitive_features, as_column=True)

        dataset = TensorDataset(X_tensor, y_tensor, s_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        # Classifier and adversary networks
        input_dim = X_tensor.shape[1]
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        ).to(self.device)

        self.adversary = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        ).to(self.device)

        classifier_optimizer = optim.Adam(self.classifier.parameters(), lr=self.learning_rate)
        adversary_optimizer = optim.Adam(self.adversary.parameters(), lr=self.adversary_lr)

        bce_loss = nn.BCELoss()
        classifier_loss = adversary_loss = None  # last batch losses, for progress output

        for epoch in range(self.epochs):
            for X_batch, y_batch, s_batch in dataloader:
                y_pred = self.classifier(X_batch)

                # 1) Adversary step: learn to predict s from the (frozen) prediction.
                #    detach() keeps this update from touching the classifier.
                adversary_loss = bce_loss(self.adversary(y_pred.detach()), s_batch)
                adversary_optimizer.zero_grad()
                adversary_loss.backward()
                adversary_optimizer.step()

                # 2) Classifier step: a FRESH adversary pass on the non-detached
                #    prediction. Reusing adversary_loss would backpropagate through
                #    a graph already freed by the backward() above, and its detached
                #    input would give the classifier no adversarial gradient anyway.
                classifier_loss = bce_loss(y_pred, y_batch)
                leakage_loss = bce_loss(self.adversary(y_pred), s_batch)
                combined_loss = classifier_loss - self.adversary_weight * leakage_loss
                classifier_optimizer.zero_grad()
                combined_loss.backward()
                # This backward also deposits gradients on the adversary's parameters;
                # they are discarded by adversary_optimizer.zero_grad() next batch.
                classifier_optimizer.step()

            # Print progress every 10 epochs
            if (epoch + 1) % 10 == 0 and classifier_loss is not None:
                print(f"Epoch {epoch+1}/{self.epochs}, Classifier Loss: {classifier_loss.item():.4f}, Adversary Loss: {adversary_loss.item():.4f}")

        self.classifier.eval()

    def predict_proba(self, X):
        """Return the classifier's positive-class probability as a 1-D array."""
        with torch.no_grad():
            return self.classifier(self._to_tensor(X)).cpu().numpy().flatten()

    def predict(self, X):
        """Return hard 0/1 predictions (probability > 0.5)."""
        return (self.predict_proba(X) > 0.5).astype(int)

# 6. Your Custom Method (Example implementation - replace with your actual method)
class YourFairnessMethod:
    """Example custom method: penalised fair logistic regression solved with CVXPY.

    The objective is the logistic loss plus ``radius * ||w||_2`` (a DRO-style
    regulariser) plus ``fairness_weight`` times the sum of three soft penalties.
    Each penalty is an absolute gap in mean raw score between group 0 and group 1:
    - over all samples (demographic parity)
    - over positive samples (true-positive-rate proxy)
    - over negative samples (false-positive-rate proxy)

    Parameters:
    - radius: weight of the L2 (DRO) penalty
    - fairness_weight: weight of the summed fairness penalties
    """

    def __init__(self, radius=0.2, fairness_weight=0.3):
        self.model = None  # unused; kept for attribute parity with the sklearn-based methods
        self.radius = radius  # DRO radius parameter
        self.fairness_weight = fairness_weight  # Weight for fairness constraints
        self.scaler = StandardScaler()
        self.w = None  # learned weights (set by fit)
        self.b = None  # learned intercept (set by fit)

    @staticmethod
    def _signed_labels(y):
        """Map binary labels to the {-1, +1} encoding the logistic loss requires.

        {0, 1} labels are mapped 0 -> -1, 1 -> +1; {-1, +1} labels pass through.
        Raises ValueError for any other label set.
        """
        labels = np.asarray(y, dtype=float).ravel()
        present = set(np.unique(labels).tolist())
        if present <= {0.0, 1.0}:
            return 2.0 * labels - 1.0
        if present <= {-1.0, 1.0}:
            return labels
        raise ValueError(f"Expected binary labels in {{0, 1}} or {{-1, 1}}, got {sorted(present)}")

    @staticmethod
    def _mean_gap(predictions, mask_a, mask_b):
        """Return |mean(predictions[mask_a]) - mean(predictions[mask_b])| as a CVXPY
        expression, or 0 when either mask selects no samples (gap undefined)."""
        n_a = int(np.count_nonzero(mask_a))
        n_b = int(np.count_nonzero(mask_b))
        if n_a == 0 or n_b == 0:
            return 0
        return cp.abs(cp.sum(predictions[mask_a]) / n_a - cp.sum(predictions[mask_b]) / n_b)

    @staticmethod
    def _solve(problem):
        """Solve with ECOS, falling back to CVXPY's default solver if ECOS fails.

        ECOS may be absent: it is an optional dependency in recent CVXPY releases.
        """
        try:
            problem.solve(solver=cp.ECOS)
        except cp.error.SolverError:
            problem.solve()

    def fit(self, X, y, sensitive_features):
        """Solve the penalised program and store ``self.w`` / ``self.b``.

        Raises ValueError for non-binary labels and RuntimeError if the solver
        returns no solution.
        """
        X_scaled = self.scaler.fit_transform(X)
        n, d = X_scaled.shape
        y_signed = self._signed_labels(y)
        groups = np.asarray(sensitive_features).ravel()

        w = cp.Variable(d)
        b = cp.Variable()
        predictions = X_scaled @ w + b

        # Logistic loss log(1 + exp(-y * f(x))) needs y in {-1, +1}; with {0, 1}
        # labels every negative example would contribute a constant log(2).
        loss = cp.sum(cp.logistic(cp.multiply(-y_signed, predictions))) / n

        group0 = groups == 0
        group1 = groups == 1
        positives = y_signed > 0
        negatives = y_signed < 0

        dp_violation = self._mean_gap(predictions, group0, group1)
        tpr_violation = self._mean_gap(predictions, group0 & positives, group1 & positives)
        fpr_violation = self._mean_gap(predictions, group0 & negatives, group1 & negatives)

        # DRO regularization based on Wasserstein distance (L2 surrogate)
        dro_reg = self.radius * cp.norm(w, 2)

        objective = cp.Minimize(
            loss + dro_reg + self.fairness_weight * (dp_violation + tpr_violation + fpr_violation)
        )
        problem = cp.Problem(objective, [])
        self._solve(problem)
        if w.value is None:
            raise RuntimeError(
                f"YourFairnessMethod optimisation did not produce a solution (status: {problem.status})"
            )

        self.w = w.value
        self.b = b.value

    def _decision_scores(self, X):
        """Return the raw linear scores ``scale(X) @ w + b``."""
        if getattr(self, "w", None) is None:
            raise RuntimeError("YourFairnessMethod must be fitted before predicting")
        return self.scaler.transform(X) @ self.w + self.b

    def predict(self, X):
        """Return hard 0/1 predictions (score > 0)."""
        return (self._decision_scores(X) > 0).astype(int)

    def predict_proba(self, X):
        """Return the sigmoid of the decision score as the positive-class probability."""
        return 1 / (1 + np.exp(-self._decision_scores(X)))

# Experiment Framework for Comparison
def _make_synthetic_data(random_state, n_samples=5000, n_features=10):
    """Generate a biased synthetic dataset.

    Returns ``(X, y, sensitive_features)``. ``sensitive_features`` is a
    Bernoulli(0.3) group indicator; membership (value 1) lowers the chance of
    ``y == 1``. It is NOT included as a column of ``X``.
    """
    # A local RandomState yields the same draws np.random.seed() would, without
    # mutating NumPy's global random state.
    rng = np.random.RandomState(random_state)
    X = rng.randn(n_samples, n_features)
    sensitive_features = rng.binomial(1, 0.3, size=n_samples)
    noise = 0.1 * rng.randn(n_samples)
    # Add bias: make the outcome depend on the sensitive feature
    y = (0.8 * X[:, 0] + 0.2 * X[:, 1] - 0.5 * sensitive_features + noise > 0).astype(int)
    return X, y, sensitive_features


def _load_adult_data():
    """Fetch and preprocess the OpenML Adult dataset.

    Returns ``(X, y, sensitive_features)`` as NumPy arrays (X as float64), or
    ``None`` if the dataset cannot be fetched or processed (e.g. no network).
    """
    from sklearn.datasets import fetch_openml

    try:
        data = fetch_openml(name='adult', version=2, as_frame=True)
        features = data.data
        y = (data.target == '>50K').astype(int).to_numpy()

        # Sensitive feature (sex): 1 = Male, 0 = otherwise.
        sensitive_features = (features['sex'] == 'Male').astype(int).to_numpy()

        # Drop protected attributes from the model inputs
        features = features.drop(columns=['sex', 'race'])

        # One-hot encode categoricals and force a single float dtype. Mixing bool
        # dummy columns with numeric columns would otherwise produce an
        # object-dtype array, which torch tensor constructors reject.
        features = pd.get_dummies(features, drop_first=True).astype(float)
        features = features.fillna(features.mean())
        X = features.to_numpy()
    except Exception as e:  # deliberate best-effort: caller falls back to synthetic data
        print(f"Error loading Adult dataset: {e}")
        return None
    return X, y, sensitive_features


def compare_fairness_methods(dataset_name="adult", test_size=0.3, random_state=42):
    """
    Compare different fairness methods on a standard dataset.

    Trains every method on the same train split, prints per-method metrics,
    shows a comparison plot (``plot_comparison``) and returns the metrics table.
    In both datasets ``X`` excludes the sensitive attribute, so "Standard LR" and
    "Fairness Unawareness" see identical inputs.

    Parameters:
    - dataset_name: "adult" (default; falls back to synthetic data if it cannot
      be loaded) or "synthetic"; any other name also uses synthetic data
    - test_size: Proportion of data to use for testing
    - random_state: Random seed for the synthetic data and the train/test split

    Returns:
    - results: DataFrame containing performance metrics for each method
    """
    data = None
    if dataset_name == "adult":
        data = _load_adult_data()
        if data is None:
            print("Falling back to synthetic data...")
    elif dataset_name != "synthetic":
        print(f"Dataset {dataset_name} not supported. Using synthetic data.")
    if data is None:
        data = _make_synthetic_data(random_state)
    X, y, sensitive_features = data

    # Split data
    X_train, X_test, y_train, y_test, s_train, s_test = train_test_split(
        X, y, sensitive_features, test_size=test_size, random_state=random_state
    )

    # Initialize models
    models = {
        "Standard LR": StandardModel(),
        "Fairness Unawareness": FairnessUnawareness(),
        "Reweighing": Reweighing(),
        "DRO Fairness": DROFairness(radius=0.1, fairness_constraint=0.05),
        "Adversarial Debiasing": AdversarialDebiasing(epochs=20),  # Reduced epochs for demonstration
        "Your Method": YourFairnessMethod(radius=0.2, fairness_weight=0.3)
    }

    # Train and evaluate each model
    results = []
    for name, model in models.items():
        print(f"\nTraining {name}...")
        model.fit(X_train, y_train, s_train)

        y_pred = model.predict(X_test)

        accuracy = accuracy_score(y_test, y_pred)
        dp_diff = demographic_parity_difference(y_pred, s_test)
        eo_diff = equal_opportunity_difference(y_test, y_pred, s_test)
        di = disparate_impact(y_pred, s_test)

        results.append({
            "Method": name,
            "Accuracy": accuracy,
            "Demographic Parity Diff": dp_diff,
            "Equal Opportunity Diff": eo_diff,
            "Disparate Impact": di
        })

        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  Demographic Parity Difference: {dp_diff:.4f}")
        print(f"  Equal Opportunity Difference: {eo_diff:.4f}")
        print(f"  Disparate Impact: {di:.4f}")

    results_df = pd.DataFrame(results)

    # Visualize results
    plot_comparison(results_df)

    return results_df

# Function to plot comparison results
def plot_comparison(results):
    """
    Create bar charts comparing the different methods across metrics.

    Draws a 2x2 grid (accuracy, demographic parity difference, equal
    opportunity difference, disparate impact) and displays it with
    ``plt.show()``. Non-finite values are not drawn as bars. An example is the
    ``inf`` that ``disparate_impact`` returns when group 1 gets no positive
    predictions; an infinite bar height cannot be rendered and can break axis
    autoscaling. Such slots are left empty and labelled with the value instead.

    Parameters:
    - results: DataFrame with columns "Method", "Accuracy",
      "Demographic Parity Diff", "Equal Opportunity Diff" and "Disparate Impact"
    """
    # (column, subplot title, y-axis label, reference line or None, y-limits or None)
    panels = [
        ("Accuracy", "Accuracy Comparison", "Accuracy", None, (0, 1)),
        ("Demographic Parity Diff", "Demographic Parity Difference", "Absolute Difference", 0, None),
        ("Equal Opportunity Diff", "Equal Opportunity Difference", "Absolute Difference", 0, None),
        ("Disparate Impact", "Disparate Impact", "Ratio", 1, None),
    ]
    methods = results["Method"]

    plt.figure(figsize=(15, 12))
    for subplot_index, (column, title, ylabel, reference, ylim) in enumerate(panels, start=1):
        plt.subplot(2, 2, subplot_index)
        values = results[column].to_numpy(dtype=float)
        finite = np.isfinite(values)
        plt.bar(methods, np.where(finite, values, np.nan))
        # Label the category slots whose value could not be drawn as a bar.
        for position in np.flatnonzero(~finite):
            plt.text(position, 0, str(values[position]), ha="center", va="bottom")
        plt.title(title)
        plt.xticks(rotation=45, ha="right")
        plt.ylabel(ylabel)
        if ylim is not None:
            plt.ylim(*ylim)
        if reference is not None:
            plt.axhline(y=reference, color='r', linestyle='-', alpha=0.3)

    plt.tight_layout()
    plt.show()

# Example usage
if __name__ == "__main__":
    # Compare all methods on Adult; compare_fairness_methods itself switches to
    # synthetic data if the Adult dataset cannot be loaded.
    summary = compare_fairness_methods(dataset_name="adult")

    print("\nSummary of Results:")
    print(summary)

    # Persist the metrics table next to the script
    summary.to_csv("fairness_methods_comparison.csv", index=False)
    print("Results saved to 'fairness_methods_comparison.csv'")