import numpy as np
from scipy.optimize import minimize
from scipy.linalg import solve
from .model_fitting import get_classifier
from .load_data import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import rbf_kernel
from .fairness_metrics import calculate_all_fairness_metrics
from .model_functions import extract_model_parameters, create_constraint_from_model
from sklearn.neural_network._base import ACTIVATIONS
from tqdm import tqdm
import os
import pandas as pd

def clear_screen():
    """Clear the terminal window.

    Windows consoles understand ``cls``; Linux and macOS shells use ``clear``.
    The choice is made from ``os.name``.
    """
    if os.name == 'nt':
        command = 'cls'
    else:
        command = 'clear'
    os.system(command)



def drune(X, A, Y, Y_pred, distances, delta, fair_measure="demographic", eps_y=1e-6, eps_g=1e-6, K_max=100, q=2, y_init=None, lambda_init=None):
    """
    Distributionally Robust Unfairness Estimator (DRUNE).

    Solves a fractional knapsack problem. Each eligible sample i has
    value omega_i and cost distances_i**q. The total capacity is N * delta**q.
    Items are taken greedily in order of decreasing value-to-cost ratio.

    Parameters
    ----------
    X : array-like, shape (N, d)
        Input features. Only the number of rows N is used.
    A : array-like, shape (N,)
        Sensitive attribute. Assumed binary 0/1; its mean is used as P(A=1).
    Y : array-like, shape (N,)
        True labels. Only used by ``"equalized_odds"``, which restricts to Y == 1.
    Y_pred : array-like, shape (N,)
        Predicted labels (compared against 0 and 1).
    distances : array-like, shape (N,)
        Per-sample distance to the decision boundary. The cost of a sample is
        ``distances**q``.
    delta : float
        Robustness radius; the knapsack capacity is ``N * delta**q``.
    fair_measure : {"demographic_parity", "demographic", "equalized_odds"}
        Which fairness notion selects and weights the eligible samples.
        ``"demographic"`` (the default) is an alias for ``"demographic_parity"``.
    eps_y, eps_g, K_max, y_init, lambda_init
        Unused. Kept only for call-site compatibility.
    q : float, optional
        Exponent applied to ``distances`` to obtain costs.

    Returns
    -------
    z : ndarray, shape (N,)
        Fraction of each sample taken (in [0, 1]); 0 for non-eligible samples.
    robust_fair : float
        Objective value ``(1/N) * sum(omega_i * z_i)``.
    distances : array-like
        The ``distances`` argument, returned unchanged.

    Raises
    ------
    ValueError
        If ``fair_measure`` is not one of the supported names.
    """
    N = np.shape(X)[0]
    A = np.asarray(A)
    Y = np.asarray(Y)
    Y_pred = np.asarray(Y_pred)

    # Group proportions. NOTE(review): if A is all 1 (or all 0) the matching
    # weight below becomes inf; confirm callers always pass both groups.
    p_1 = np.mean(A)
    p_0 = 1 - p_1

    if fair_measure in ("demographic", "demographic_parity"):
        mask1 = (A == 1) & (Y_pred == 1)
        mask2 = (A == 0) & (Y_pred == 0)
    elif fair_measure == "equalized_odds":
        # NOTE(review): weights still use the unconditional P(A); equalized
        # odds might call for P(A | Y=1) instead. Confirm intended.
        positives = Y == 1
        mask1 = (A == 1) & positives & (Y_pred == 1)
        mask2 = (A == 0) & positives & (Y_pred == 0)
    else:
        raise ValueError(
            f"Unknown fair_measure {fair_measure!r}; expected 'demographic_parity' "
            "(alias 'demographic') or 'equalized_odds'."
        )

    omega = np.zeros(N)
    omega[mask1] = 1 / p_0
    omega[mask2] = 1 / p_1

    # Float conversion so integer distances are accepted.
    costs = np.asarray(distances, dtype=float) ** q
    z = np.zeros(N)

    # Only eligible samples are candidates. The others have zero value, and
    # iterating over them would compute inf * 0 = NaN in the budget.
    eligible = np.flatnonzero(mask1 | mask2)
    with np.errstate(divide="ignore", invalid="ignore"):
        ratios = omega[eligible] / costs[eligible]  # zero cost -> inf ratio, taken first

    remaining = N * (delta ** q)
    for i in eligible[np.argsort(-ratios, kind="stable")]:
        if remaining <= 0:
            break
        cost = costs[i]
        # A zero-cost item is free, so take it fully without dividing by zero.
        take = 1.0 if cost <= 0 else min(1.0, remaining / cost)
        z[i] = take
        remaining -= cost * take

    robust_fair = np.dot(omega, z) / N
    return z, robust_fair, distances



def wdf_calculate(dataset_name, classifier_name, n_samples=1000, n_experiments=1000,
                  test_size=0.2, random_seed=42, delta=0.001, q=2):
    """
    Calculate Wasserstein Distributional Fairness metrics for a given dataset and classifier.

    Each experiment does the following:
    - draw a subsample without replacement, separately from the A == 1 and A == 0 groups;
    - fit a fresh classifier on it;
    - compute each sampled point's distance to the decision boundary;
    - run DRUNE with ``fair_measure="equalized_odds"``.
    The DRUNE objective is recorded together with the in-sample and
    full-dataset equalized-odds metrics.

    Parameters
    -----------
    dataset_name : str
        Name of the dataset to use
    classifier_name : str
        Type of classifier to use. "logistic" and "linear_svm" use a closed-form
        hyperplane distance; any other name uses ``calculate_model_distances``.
    n_samples : int, optional (default=1000)
        Number of samples to use in each experiment
    n_experiments : int, optional (default=1000)
        Number of experiments to run
    test_size : float, optional (default=0.2)
        Currently unused; kept for backward compatibility.
    random_seed : int, optional (default=42)
        Seed for the random generator used to draw the subsamples.
    delta : float, optional (default=0.001)
        Parameter for DRUNE algorithm
    q : int, optional (default=2)
        Order of distance for DRUNE algorithm (2 for Euclidean)

    Returns
    -------
    results_df : pandas.DataFrame
        One row per experiment with columns 'Sample_fairness', 'true_fairness'
        and 'regularizer'. Also written to
        ``results/fairness_metrics_{dataset_name}_{classifier_name}_wdf.csv``.
    """
    X, y, A, alpha = load_dataset(dataset_name)
    clear_screen()

    # Hölder conjugate of q: the q-norm distance to the hyperplane w.x + b = 0
    # is |w.x + b| / ||w||_p. For q == 1 the conjugate is inf (q/(q-1) would
    # divide by zero).
    p = np.inf if q == 1 else q / (q - 1)

    rng = np.random.default_rng(random_seed)

    privileged_indices = np.where(A == 1)[0]
    unprivileged_indices = np.where(A == 0)[0]

    # Group sample sizes do not change between experiments. Each group is
    # capped at 80% of its size. NOTE(review): alpha is presumably the target
    # share of A == 1 samples; confirm against load_dataset.
    n_privileged = min(int(n_samples * alpha), int(0.8 * len(privileged_indices)))
    n_unprivileged = min(n_samples - n_privileged, int(0.8 * len(unprivileged_indices)))

    fairness = []
    total_fairness = []
    regularizer = []

    for _ in tqdm(range(n_experiments)):
        sampled_indices = np.concatenate([
            rng.choice(privileged_indices, size=n_privileged, replace=False),
            rng.choice(unprivileged_indices, size=n_unprivileged, replace=False),
        ])
        X_sampled = X[sampled_indices]
        y_sampled = y[sampled_indices]
        A_sampled = A[sampled_indices]

        model = get_classifier(classifier_name)
        model.fit(X_sampled, y_sampled)
        y_pred = model.predict(X_sampled)

        if classifier_name in ("logistic", "linear_svm"):
            # Linear boundary: closed-form distance. Assumes a binary model
            # whose coef_ holds a single weight row.
            w = model.coef_.ravel()
            distances = np.abs(model.decision_function(X_sampled)) / np.linalg.norm(w, ord=p)
        else:
            distances, _ = calculate_model_distances(
                X=X_sampled,
                model=model,
                classifier_name=classifier_name,
                q=q,
                K_max=100,
                eps_y=1e-6,
                eps_g=1e-6,
            )

        _, obj_value, _ = drune(
            X_sampled, A_sampled, y_sampled, y_pred, distances, delta,
            fair_measure="equalized_odds",
            eps_y=1e-6, eps_g=1e-6, K_max=100,
            q=q,
        )

        metrics = calculate_all_fairness_metrics(y_pred, y_sampled, A_sampled)
        total_metrics = calculate_all_fairness_metrics(model.predict(X), y, A)

        fairness.append(metrics['equalized_odds'])
        regularizer.append(obj_value)
        total_fairness.append(total_metrics['equalized_odds'])

        print(f"Regularizer: {obj_value:.6f}")
        print(f"Fairness: {metrics['equalized_odds']:.6f}")
        print(f"Total Fairness: {total_metrics['equalized_odds']:.6f}")
        print(f"Robust Fairness: {metrics['equalized_odds'] + obj_value:.6f}")

    results_df = pd.DataFrame({
        'Sample_fairness': fairness,
        'true_fairness': total_fairness,
        'regularizer': regularizer,
    })

    # Create the output directory so to_csv does not fail on a fresh checkout.
    os.makedirs('results', exist_ok=True)
    results_df.to_csv(f'results/fairness_metrics_{dataset_name}_{classifier_name}_wdf.csv', index=False)

    return results_df



def delta_calculate(dataset_name, classifier_name, n_samples=1000, n_experiments=10,
                    test_size=0.2, random_seed=42, delta_steps=1000, q=2, delta_step=0.001):
    """
    Calculate Wasserstein Distributional Fairness metrics for different delta values.

    The sweep uses ``delta = delta_step * k`` for ``k = 0 .. delta_steps - 1``.
    For each delta, ``n_experiments`` runs are performed. Each run:
    - draws a subsample without replacement, separately from the A == 1 and A == 0 groups;
    - fits a fresh classifier;
    - computes boundary distances;
    - runs DRUNE with ``fair_measure="equalized_odds"``.
    The mean DRUNE objective over the runs is recorded for that delta.

    Parameters
    -----------
    dataset_name : str
        Name of the dataset to use
    classifier_name : str
        Type of classifier to use. "logistic" and "linear_svm" use a closed-form
        hyperplane distance; any other name uses ``calculate_model_distances``.
    n_samples : int, optional (default=1000)
        Number of samples to use in each experiment
    n_experiments : int, optional (default=10)
        Number of experiments to run for each delta value
    test_size : float, optional (default=0.2)
        Currently unused; kept for backward compatibility.
    random_seed : int, optional (default=42)
        Seed for the random generator used to draw the subsamples.
    delta_steps : int, optional (default=1000)
        Number of delta values to test
    q : int, optional (default=2)
        Order of distance for DRUNE algorithm (2 for Euclidean)
    delta_step : float, optional (default=0.001)
        Spacing between consecutive delta values.

    Returns
    -------
    results_df : pandas.DataFrame
        Columns 'delta' and 'regularizer' (mean DRUNE objective per delta).
        Also written to
        ``results/fairness_metrics_{dataset_name}_{classifier_name}_delta.csv``.
    """
    X, y, A, alpha = load_dataset(dataset_name)
    clear_screen()

    privileged_indices = np.where(A == 1)[0]
    unprivileged_indices = np.where(A == 0)[0]

    # Hölder conjugate of q (see the closed-form hyperplane distance below);
    # q == 1 pairs with p = inf instead of dividing by zero.
    p = np.inf if q == 1 else q / (q - 1)

    rng = np.random.default_rng(random_seed)

    # Each group is capped at 80% of its size.
    n_privileged = min(int(n_samples * alpha), int(0.8 * len(privileged_indices)))
    n_unprivileged = min(n_samples - n_privileged, int(0.8 * len(unprivileged_indices)))

    regularizer = []
    delta_list = []

    for step in range(delta_steps):
        delta = delta_step * step
        wdf = []
        for _ in tqdm(range(n_experiments)):
            sampled_indices = np.concatenate([
                rng.choice(privileged_indices, size=n_privileged, replace=False),
                rng.choice(unprivileged_indices, size=n_unprivileged, replace=False),
            ])
            X_sampled = X[sampled_indices]
            y_sampled = y[sampled_indices]
            A_sampled = A[sampled_indices]

            model = get_classifier(classifier_name)
            model.fit(X_sampled, y_sampled)
            y_pred = model.predict(X_sampled)

            if classifier_name in ("logistic", "linear_svm"):
                # Linear boundary: closed-form distance. Assumes a binary
                # model whose coef_ holds a single weight row.
                w = model.coef_.ravel()
                distances = np.abs(model.decision_function(X_sampled)) / np.linalg.norm(w, ord=p)
            else:
                distances, _ = calculate_model_distances(
                    X=X_sampled,
                    model=model,
                    classifier_name=classifier_name,
                    q=q,
                    K_max=100,
                    eps_y=1e-6,
                    eps_g=1e-6,
                )

            _, obj_value, _ = drune(
                X_sampled, A_sampled, y_sampled, y_pred, distances, delta,
                fair_measure="equalized_odds",
                eps_y=1e-6, eps_g=1e-6, K_max=100,
                q=q,
            )
            wdf.append(obj_value)

        delta_list.append(delta)
        regularizer.append(np.mean(wdf))

    results_df = pd.DataFrame({
        'delta': delta_list,
        'regularizer': regularizer,
    })

    # Create the output directory so to_csv does not fail on a fresh checkout.
    os.makedirs('results', exist_ok=True)
    results_df.to_csv(f'results/fairness_metrics_{dataset_name}_{classifier_name}_delta.csv', index=False)

    return results_df


def calculate_model_distances(X, model, classifier_name, q=2, K_max=100, eps_y=1e-6, eps_g=1e-6, g_theta=None):
    """
    Calculate distances from model decision boundary using Newton-KKT method.

    For each row x_i the method solves min_y ||x_i - y||_q^q subject to
    g_theta(y) = 0. It applies Newton steps to the KKT system
        -q sign(v)|v|^(q-1) + lambda * grad g(y) = 0,   g(y) = 0,   v = x_i - y.
    Each point starts at y = x_i with lambda = 1.

    Parameters
    -----------
    X : array-like, shape (n_samples, n_features)
        Input features
    model : sklearn.base.BaseEstimator
        Trained classifier model (only used when ``g_theta`` is None)
    classifier_name : str
        Name of the classifier type (only used when ``g_theta`` is None)
    q : float, optional (default=2)
        Order of distance norm
    K_max : int, optional (default=100)
        Maximum number of Newton-KKT iterations
    eps_y : float, optional (default=1e-6)
        Tolerance for y updates
    eps_g : float, optional (default=1e-6)
        Tolerance for constraint values
    g_theta : object, optional
        Constraint with ``g_theta(y)`` (scalar), ``g_theta.grad(y)`` (shape
        (n_features,)) and ``g_theta.hessian(y)`` (shape (n_features, n_features)).
        When None, it is built with ``create_constraint_from_model(model, classifier_name)``.

    Returns
    --------
    distances : array-like, shape (n_samples,)
        q-norm distance from each sample to its projection
    Y_proj : array-like, shape (n_samples, n_features)
        Projected points on the decision boundary
    """
    X = np.asarray(X, dtype=float)
    N = X.shape[0]
    distances = np.zeros(N)
    Y_proj = np.zeros(X.shape)

    if g_theta is None:
        g_theta = create_constraint_from_model(model, classifier_name)

    for i in range(N):
        x_i = X[i]
        y_i = x_i.copy()
        lambda_i = 1.0
        dim = y_i.shape[0]

        for _ in range(K_max):
            v = x_i - y_i
            abs_v = np.abs(v)

            # Gradient of ||x - y||_q^q with respect to y.
            G_q = -q * np.sign(v) * abs_v ** (q - 1)
            # Its Hessian is q(q-1) diag(|v|^(q-2)). The factor q must match
            # G_q; otherwise the step is not a Newton step. |v| is clamped away
            # from 0 so q < 2 does not raise 0**negative.
            W_q = np.diag(q * (q - 1) * np.maximum(abs_v, 1e-10) ** (q - 2))

            nabla_g = g_theta.grad(y_i)
            r_y = G_q + lambda_i * nabla_g
            r_g = g_theta(y_i)

            # KKT Jacobian [[W + lambda*H, grad g], [grad g^T, 0]].
            J = np.zeros((dim + 1, dim + 1))
            J[:dim, :dim] = W_q + lambda_i * g_theta.hessian(y_i)
            J[:dim, dim] = nabla_g
            J[dim, :dim] = nabla_g
            rhs = np.concatenate([-r_y, [-r_g]])

            # Damped residual step when the system is non-finite or singular.
            delta_y, delta_lambda = -r_y * 0.1, -r_g * 0.1
            if np.all(np.isfinite(J)) and np.all(np.isfinite(rhs)):
                try:
                    step = solve(J, rhs)
                    delta_y, delta_lambda = step[:dim], step[dim]
                except (np.linalg.LinAlgError, ValueError):
                    pass

            y_i = y_i + delta_y
            lambda_i = lambda_i + delta_lambda

            # r_g is the constraint value before this update (as originally).
            if np.linalg.norm(delta_y) < eps_y and np.abs(r_g) < eps_g:
                break

        distances[i] = np.linalg.norm(x_i - y_i, ord=q)
        Y_proj[i] = y_i

    return distances, Y_proj




