import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import os
from .load_data import load_dataset
from src.fairness_metrics import calculate_all_fairness_metrics
from src.model_fitting import get_classifier


def fit_calculate(dataset='adult', classifier_name='logistic', n_samples = 1000, n_experiments = 10, 
         random_seed=42, **classifier_params):
    """Refit a classifier on repeated group-stratified samples and record fairness metrics.

    The dataset is split 80/20 (fixed ``random_state=42``) and only the 20%
    "pool" is used. For each experiment, a sample is drawn from the pool so
    that it contains ``n_privileged`` rows with sensitive attribute 1 and
    ``n_unprivileged`` rows with sensitive attribute 0 (each group drawn
    without replacement within that experiment). A fresh classifier is fitted
    on the sample, and the fairness metrics are computed on that classifier's
    predictions for the same sample (in-sample evaluation).

    Args:
        dataset: Dataset name forwarded to ``load_dataset``.
        classifier_name: Classifier name forwarded to ``get_classifier``.
        n_samples: Target sample size per experiment. The privileged count is
            ``int(n_samples * alpha)`` and the unprivileged count is the
            remainder. Each count is capped at 80% of that group's size in the
            pool, so the real sample can be smaller than ``n_samples``.
        n_experiments: Number of resample-and-refit iterations.
        random_seed: Seed for NumPy's global RNG, which drives the sampling.
            It also appears in the output CSV name.
        **classifier_params: Extra keyword arguments forwarded to ``get_classifier``.

    Returns:
        pandas.DataFrame with one row per experiment and one column per
        fairness metric. The same frame is written to
        ``results/fairness_metrics_{dataset}_{classifier_name}_{random_seed}.csv``.
    """
    # Output CSV column -> key in the dict returned by calculate_all_fairness_metrics.
    metric_columns = {
        'Demographic Parity': 'demographic_parity',
        'Equal Opportunity': 'equal_opportunity',
        'Equalized Odds': 'equalized_odds',
        'Treatment Equality': 'treatment_equality',
        'Predictive Parity': 'predictive_parity',
        'Accuracy Equality': 'accuracy_equality',
        'Model Accuracy': 'model_accuracy',
    }

    # Create directories if they don't exist
    os.makedirs('results', exist_ok=True)
    os.makedirs('plots', exist_ok=True)

    # Seed NumPy's global RNG; np.random.choice below draws from it.
    np.random.seed(random_seed)

    # Load dataset, sensitive attribute and the privileged-share ratio alpha
    X, y, sensitive_attr, alpha = load_dataset(dataset)

    # Only the 20% pool is used here. The split is pinned to random_state=42
    # (not random_seed), so every seed draws from the same pool.
    _, X_pool, _, y_pool, _, sa_pool = train_test_split(
        X, y, sensitive_attr, test_size=0.2, random_state=42
    )

    # Positional indices of each group within the pool
    privileged_indices = np.where(sa_pool == 1)[0]
    unprivileged_indices = np.where(sa_pool == 0)[0]

    # Aim for a privileged share of alpha, but never take more than 80% of
    # either group's members in the pool.
    n_privileged = min(int(n_samples * alpha), int(0.8 * len(privileged_indices)))
    n_unprivileged = min(n_samples - n_privileged, int(0.8 * len(unprivileged_indices)))

    print(f"Running {n_experiments} experiments with {n_samples} samples each")
    print(f"Each sample will contain {n_privileged} privileged and {n_unprivileged} unprivileged samples (ratio = {alpha:.4f})")

    results = {column: [] for column in metric_columns}

    for _ in tqdm(range(n_experiments)):
        # Stratified draw: privileged rows first, then unprivileged rows.
        sampled_indices = np.concatenate([
            np.random.choice(privileged_indices, size=n_privileged, replace=False),
            np.random.choice(unprivileged_indices, size=n_unprivileged, replace=False),
        ])

        # Assumes X is row-indexable by an integer array (e.g. NumPy array or
        # SciPy sparse matrix), not a DataFrame. TODO confirm against load_dataset.
        X_sampled = X_pool[sampled_indices]
        y_sampled = y_pool.iloc[sampled_indices]
        sa_sampled = sa_pool.iloc[sampled_indices]

        # Fit a fresh model on this sample and evaluate it on the same rows.
        sample_clf = get_classifier(classifier_name, **classifier_params)
        sample_clf.fit(X_sampled, y_sampled)
        y_pred = sample_clf.predict(X_sampled)

        metrics = calculate_all_fairness_metrics(y_pred, y_sampled, sa_sampled)
        for column, key in metric_columns.items():
            results[column].append(metrics[key])

    results_df = pd.DataFrame(results)
    results_df.to_csv(f'results/fairness_metrics_{dataset}_{classifier_name}_{random_seed}.csv', index=False)

    return results_df


def fixed_calculate(dataset='adult', classifier_name='logistic', n_samples=1000, n_experiments=10,
                    random_seed=42, **classifier_params):
    """Evaluate one fixed classifier on repeated fresh samples and record fairness metrics.

    The dataset is split 80/20 (fixed ``random_state=42``) and only the 20%
    pool is used:

    1. An initial group-stratified sample is drawn from the pool and one
       classifier is fitted on it.
    2. The initial sample's rows are removed from the pool.
    3. Each experiment draws a new stratified sample from the remaining rows.
       Rows are not removed after use, so different experiments can overlap.
       The fixed classifier's predictions on that sample are scored with
       ``calculate_all_fairness_metrics``.

    If either group has fewer remaining rows than the per-sample count, a
    warning is printed and the loop stops. The remaining pool does not shrink
    inside the loop, so this can only happen on the first iteration, and the
    result is then an empty frame.

    Args:
        dataset: Dataset name forwarded to ``load_dataset``.
        classifier_name: Classifier name forwarded to ``get_classifier``.
        n_samples: Target sample size. The privileged count is
            ``int(n_samples * alpha)`` and the unprivileged count is the
            remainder. Each count is capped at 80% of that group's size in the
            pool.
        n_experiments: Maximum number of evaluation samples.
        random_seed: Seed for NumPy's global RNG. It also appears in the
            output CSV name.
        **classifier_params: Extra keyword arguments forwarded to ``get_classifier``.

    Returns:
        pandas.DataFrame with one row per completed experiment and one column
        per fairness metric. It is also written to
        ``results/fairness_deviation_{dataset}_{classifier_name}_{random_seed}.csv``.
    """
    for folder in ('results', 'plots'):
        os.makedirs(folder, exist_ok=True)

    # Every np.random.choice call below draws from this globally seeded RNG.
    np.random.seed(random_seed)

    X, y, sensitive_attr, alpha = load_dataset(dataset)

    # Only the 20% pool is used; the 80% split outputs are discarded.
    _, X_pool, _, y_pool, _, sa_pool = train_test_split(
        X, y, sensitive_attr, test_size=0.2, random_state=42
    )

    priv_pos = np.where(sa_pool == 1)[0]
    unpriv_pos = np.where(sa_pool == 0)[0]

    # Aim for a privileged share of alpha, never exceeding 80% of either group.
    n_privileged = min(int(n_samples * alpha), int(0.8 * len(priv_pos)))
    n_unprivileged = min(n_samples - n_privileged, int(0.8 * len(unpriv_pos)))

    print("Generating initial sample and training fixed model...")
    print(f"Initial sample will contain {n_privileged} privileged and {n_unprivileged} unprivileged samples (ratio = {alpha:.4f})")

    def draw_stratified(priv, unpriv):
        # Privileged positions first, then unprivileged ones; each group is
        # drawn without replacement.
        return np.concatenate([
            np.random.choice(priv, size=n_privileged, replace=False),
            np.random.choice(unpriv, size=n_unprivileged, replace=False),
        ])

    initial_indices = draw_stratified(priv_pos, unpriv_pos)

    fixed_clf = get_classifier(classifier_name, **classifier_params)
    fixed_clf.fit(X_pool[initial_indices], y_pool.iloc[initial_indices])

    # Drop the training rows so that evaluation samples never include them.
    # shape[0] rather than len() so this also works if X_pool is a sparse
    # matrix (presumably why the original avoided len()) -- TODO confirm.
    keep = np.ones(X_pool.shape[0], dtype=bool)
    keep[initial_indices] = False
    X_rest = X_pool[keep]
    y_rest = y_pool.iloc[keep]
    sa_rest = sa_pool.iloc[keep]

    # Group positions are recomputed relative to the reduced pool.
    priv_pos = np.where(sa_rest == 1)[0]
    unpriv_pos = np.where(sa_rest == 0)[0]

    print(f"Running {n_experiments} experiments with {n_samples} samples each")

    metric_columns = {
        'Demographic Parity': 'demographic_parity',
        'Equal Opportunity': 'equal_opportunity',
        'Equalized Odds': 'equalized_odds',
        'Treatment Equality': 'treatment_equality',
        'Predictive Parity': 'predictive_parity',
        'Accuracy Equality': 'accuracy_equality',
        'Model Accuracy': 'model_accuracy',
    }
    collected = {column: [] for column in metric_columns}

    for i in tqdm(range(n_experiments)):
        enough_left = len(priv_pos) >= n_privileged and len(unpriv_pos) >= n_unprivileged
        if not enough_left:
            print(f"Warning: Not enough samples left in pool at experiment {i}. Stopping early.")
            break

        picked = draw_stratified(priv_pos, unpriv_pos)
        y_sampled = y_rest.iloc[picked]
        sa_sampled = sa_rest.iloc[picked]

        # The fixed model is only used for prediction here and is never refit.
        metrics = calculate_all_fairness_metrics(fixed_clf.predict(X_rest[picked]), y_sampled, sa_sampled)
        for column, key in metric_columns.items():
            collected[column].append(metrics[key])

    results_df = pd.DataFrame(collected)
    results_df.to_csv(f'results/fairness_deviation_{dataset}_{classifier_name}_{random_seed}.csv', index=False)

    print("\nSummary Statistics for Fairness Metrics:")
    print(results_df.describe())

    return results_df