import pandas as pd
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
import os
import joblib


def preprocess_data(X, y=None, strategy='mean'):
    """
    Preprocess the data by handling missing values and scaling features.

    ``X`` is converted to a float array. Positive/negative infinities are
    treated as missing, just like NaN. If any values are missing they are
    imputed with ``SimpleImputer(strategy=strategy)``. The result is then
    standardized with ``StandardScaler``.

    Both transformers are fit on ``X`` itself and are not returned. Applying
    this function separately to train and test data therefore uses different
    statistics for each.

    Parameters:
    -----------
    X : pandas.DataFrame or numpy.ndarray
        Input features. All columns must be convertible to float.
    y : pandas.Series or numpy.ndarray, optional
        Target variable
    strategy : str, default='mean'
        Strategy for imputing missing values. Options are:
        - 'mean': Replace missing values with mean
        - 'median': Replace missing values with median
        - 'most_frequent': Replace missing values with most frequent value
        - 'constant': Replace missing values with a constant value

    Returns:
    --------
    X_processed : numpy.ndarray
        Processed (imputed and standardized) float features
    y_processed : numpy.ndarray, optional
        Target variable; a pandas.Series is converted to its underlying
        array, anything else is returned unchanged (only if y was provided)
    """
    # Convert to a float array up front. np.isnan below requires a float
    # dtype, and for DataFrames this also maps pandas missing markers
    # (e.g. pd.NA) to np.nan so they are detected and imputed.
    if isinstance(X, pd.DataFrame):
        X = X.to_numpy(dtype=float, na_value=np.nan)
    else:
        X = np.asarray(X, dtype=float)
    if isinstance(y, pd.Series):
        y = y.values

    # Treat +/-inf as missing so the imputer replaces them. Genuine NaNs must
    # stay NaN here: np.nan_to_num's default nan=0.0 would silently turn them
    # into zeros and bypass the chosen imputation strategy.
    X = np.where(np.isinf(X), np.nan, X)

    # Only fit an imputer when there is something to impute.
    if np.isnan(X).any():
        imputer = SimpleImputer(strategy=strategy)
        X = imputer.fit_transform(X)

    # Standardize each feature to zero mean and unit variance.
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    if y is not None:
        return X, y
    return X



def save_processed_data(dataset_name, X_processed, y, s, alpha, *,
                        data_dir='data/processed'):
    """
    Save processed dataset to disk.

    The tuple ``(X_processed, y, s, alpha)`` is serialized with
    ``joblib.dump`` into ``<data_dir>/<dataset_name>.npy``. Despite the
    ``.npy`` suffix, the file is a joblib pickle, not a NumPy array file.
    Load it with ``joblib.load`` (as ``read_processed_data`` does), not
    ``np.load``. The suffix is kept so existing saved files keep their names.

    Parameters:
    -----------
    dataset_name : str
        Name of the dataset (used as the file's base name)
    X_processed : numpy.ndarray
        Processed features
    y : numpy.ndarray
        Labels
    s : numpy.ndarray
        Protected attributes
    alpha : float
        Protected group ratio
    data_dir : str, default='data/processed'
        Directory to write into (relative paths resolve against the current
        working directory). Created if it does not exist.
    """
    os.makedirs(data_dir, exist_ok=True)
    path = os.path.join(data_dir, f'{dataset_name}.npy')
    joblib.dump((X_processed, y, s, alpha), path)

def read_processed_data(dataset_name, *, data_dir='data/processed'):
    """
    Read processed dataset from disk.

    Loads ``<data_dir>/<dataset_name>.npy`` with ``joblib.load`` and unpacks
    it into four values. Despite the ``.npy`` suffix, the file is a joblib
    pickle, not a NumPy array file.

    Warning: ``joblib.load`` unpickles the file, which can execute arbitrary
    code. Only read files that come from a trusted source.

    Parameters:
    -----------
    dataset_name : str
        Name of the dataset (the file's base name)
    data_dir : str, default='data/processed'
        Directory to read from (relative paths resolve against the current
        working directory)

    Returns:
    --------
    X : numpy.ndarray
        Features
    y : numpy.ndarray
        Labels
    s : numpy.ndarray
        Protected attributes
    alpha : float
        Protected group ratio
    """
    path = os.path.join(data_dir, f'{dataset_name}.npy')
    # NOTE: pickle-based load -- never point this at untrusted files.
    X, y, s, alpha = joblib.load(path)
    return X, y, s, alpha