# pylint: disable=no-member
"""
    File for loading datasets.
"""
from sklearn.datasets import fetch_openml
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import MinMaxScaler

import numpy as np
import pandas as pd
class DatasetLoader:
    """ Dataset Loader Class

    Loads one registered dataset into ``X`` (features), ``y`` (labels) and
    ``feat_name`` (feature names) and provides stratified K-fold generators
    whose Min-Max normalization is fit on training rows only.

    Args:
        dataset_name (str, optional): Key of ``mapper`` to load on
            construction; ``""`` loads nothing. Defaults to "".
        random_state (int, optional): Seed used for every split and
            subsample. Defaults to 42.
        fairness_mode (int, optional): Column index of the sensitive feature
            in ``X``. When set, the outer-fold generator removes that column
            from the features and yields it separately. Defaults to None.
    """

    # Directory holding the CSV datasets (relative to the working directory).
    # Override on the class or on an instance to read the files from elsewhere.
    data_dir = "module/data"

    def __init__(self, dataset_name="", random_state=42, fairness_mode=None):
        # Dataset name -> loader method. Commented-out entries are disabled.
        self.mapper = {
            # "heart-disease": self.load_heart_disease,
            "haberman": self.load_haberman,
            "blood-transfusion": self.load_blood_transfusion,
            # "climate-simulation": self.load_climate_simulation,
            "sonar": self.load_sonar,
            "parkinsons": self.load_parkinsons,
            "banknote": self.load_banknote_authentication,
            "breast-cancer": self.load_breast_cancer,
            # "cylinder-bands": self.load_cylinder_bands,  # non-numeric features
            "diabetes": self.load_diabetes,
            # "ionosphere": self.load_ionosphere,
            "planning-relax": self.load_planning_replax,  # Bad accuracy
            "spambase": self.load_spambase,
            "spectf": self.load_spectf,
            "wine-quality": self.load_wine_quality,
            "compas": self.load_compas,
            "compas-bin": self.load_compas_bin,
            "fico": self.load_fico,
            "mimic": self.load_mimic,
            "adult": self.load_adult,
            "adult-orig": self.load_adult_orig,
            "bank": self.load_bank,
            "census": self.load_census,
            "oulad": self.load_oulad,
            "german-credit": self.load_german_credit_bin,
            "default-credit": self.load_default_credit,
            "": lambda: None,  # placeholder: construct without loading data
        }
        self.X = None
        self.y = None
        self.rs = random_state
        self.feat_name = None
        self.raw_data = None
        # Sensitive-feature column index used when fairness mode is enabled;
        # None disables it. TODO: revisit this.
        self.fairness_mode = fairness_mode
        self.dataset_name = dataset_name
        self.load_dataset(self.dataset_name)

    def get_all(self) -> dict:
        """ Get all data

        Returns:
            dict: X, y, feature names, original data
        """
        return {
            'X': self.X,
            'y': self.y,
            'feature': self.feat_name,
            'original': self.raw_data,
        }

    def set_all(self, X, y, feat_name, openml_data):
        """ Set all data (rows of X containing NaN are dropped first)

        Args:
            X (numpy.ndarray): Feature data
            y (numpy.ndarray): Target data
            feat_name (numpy.ndarray): Feature names
            openml_data (sklearn.utils.Bunch | pandas.DataFrame): Original data
        """
        X, y = self.remove_nan(X, y)
        self.X = X
        self.y = y
        self.feat_name = feat_name
        self.raw_data = openml_data

    def kfold_data(self, kfold=5):
        """ Stratified K-Fold over the whole loaded dataset

        Args:
            kfold (int, optional): Number of folds. Defaults to 5.

        Returns:
            generator: (train_idx, test_idx) index arrays, one pair per fold
        """
        return self.kfold_X_y(self.X, self.y, kfold)

    def kfold_X_y(self, X, y, kfold=5):
        """ Stratified, shuffled K-Fold over the given X and y

        Args:
            X (np.ndarray): Feature matrix, one row per sample
            y (np.ndarray): Labels used for stratification
            kfold (int, optional): Number of folds. Defaults to 5.

        Returns:
            generator: (train_idx, test_idx) index arrays, one pair per fold
        """
        return StratifiedKFold(n_splits=kfold, shuffle=True, random_state=self.rs).split(X, y)

    def kfold_normalized_generator(self, train_test_kfold=5,
                                   train_val_kfold=5, nested=True, select_size=0.1):
        """ K-Fold Normalized Generator

        For each outer fold, optionally carves a stratified "select" split out
        of the training rows, fits a MinMaxScaler on the remaining training
        rows, and applies it to train/select/test.

        Args:
            train_test_kfold (int, optional): Number of folds for train-test. Defaults to 5.
            train_val_kfold (int, optional): Number of folds for train-val. Defaults to 5.
            nested (bool, optional): Enable nested cross validations. Defaults to True.
            select_size (float, optional): Fraction of the train set used as the
                select set; 0 disables it. Defaults to 0.1.

        Yields:
            tuple: ``(fold_id, nested_cv, train, select, test)``. ``train``,
            ``select`` and ``test`` are ``(X, y)`` tuples, or
            ``(X, y, X_sensitive)`` when ``fairness_mode`` is set. ``select``
            holds ``None`` values when ``select_size <= 0``. ``nested_cv`` is a
            generator of ``(nested_fold_id, (X_train, y_train), (X_val, y_val))``
            over the unscaled outer training rows, or None when ``nested`` is False.
        """
        for fold_id, (train_idx, test_idx) in enumerate(self.kfold_data(train_test_kfold)):
            # Integer-array indexing already returns copies.
            X_train, X_test = self.X[train_idx], self.X[test_idx]
            y_train, y_test = self.y[train_idx], self.y[test_idx]
            X_select, y_select, X_select_sensitive = None, None, None
            if select_size > 0:
                X_train, X_select, y_train, y_select = train_test_split(
                    X_train, y_train, test_size=select_size,
                    random_state=self.rs, stratify=y_train)

            # The nested generator receives the *unscaled* outer training rows
            # and normalizes each inner fold on its own.
            # TODO: incorporate fairness mode in nested CV.
            nested_cv = (self._nested_kfold_normalized(X_train, y_train, train_val_kfold)
                         if nested else None)

            scaler = MinMaxScaler().fit(X_train)
            X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)
            if X_select is not None:
                X_select = scaler.transform(X_select)

            if self.fairness_mode is None:
                yield (fold_id, nested_cv, (X_train, y_train),
                       (X_select, y_select), (X_test, y_test))
                continue

            X_train, X_train_sensitive = self._split_sensitive(X_train)
            X_test, X_test_sensitive = self._split_sensitive(X_test)
            if X_select is not None:
                X_select, X_select_sensitive = self._split_sensitive(X_select)
            yield (fold_id, nested_cv, (X_train, y_train, X_train_sensitive),
                   (X_select, y_select, X_select_sensitive),
                   (X_test, y_test, X_test_sensitive))

    def _nested_kfold_normalized(self, X, y, kfold):
        """ Yield normalized inner (train, validation) folds of X / y.

        The scaler is fit on each inner training split only, so validation
        rows never influence the normalization (no leakage).
        """
        for nested_fold_id, (inner_idx, val_idx) in enumerate(self.kfold_X_y(X, y, kfold)):
            X_inner, X_val = X[inner_idx], X[val_idx]
            y_inner, y_val = y[inner_idx], y[val_idx]
            scaler = MinMaxScaler().fit(X_inner)
            yield (nested_fold_id, (scaler.transform(X_inner), y_inner),
                   (scaler.transform(X_val), y_val))

    def _split_sensitive(self, X):
        """ Return (X without the sensitive column(s), the sensitive column(s)). """
        return np.delete(X, self.fairness_mode, axis=1), X[:, self.fairness_mode]

    def load_dataset(self, name: str):
        """ Load dataset

        Args:
            name (str): Dataset name (a key of ``self.mapper``)

        Returns:
            dict: X, y, feature names, original data (see ``get_all``)

        Raises:
            KeyError: If ``name`` is not a registered dataset.
        """
        try:
            loader = self.mapper[name]
        except KeyError:
            raise KeyError(
                f"Unknown dataset {name!r}; available: {self.list_all_dataset()}"
            ) from None
        self.dataset_name = name
        loader()
        return self.get_all()

    def list_all_dataset(self) -> list:
        """ List all datasets

        Returns:
            list: List of dataset names (the empty placeholder key excluded)
        """
        return [name for name in self.mapper if name != ""]

    def remove_nan(self, X, y):
        """ Remove rows of X containing NaN, and the matching entries of y

        Args:
            X (np.ndarray): feature data (numeric, 2-D)
            y (np.ndarray): target data

        Returns:
            tuple: Cleaned X and y
        """
        keep = ~np.isnan(X).any(axis=1)
        return X[keep], y[keep]

    #region Datasets
        #region OpenML Datasets
    def _load_openml(self, data_id, to_label=None):
        """ Fetch OpenML dataset ``data_id`` and store it via ``set_all``.

        Features are cast to float32. ``to_label`` maps the raw target array
        to labels; by default the raw target is cast to int.
        """
        data = fetch_openml(data_id=data_id)
        y = data.target.to_numpy()
        y = y.astype(int) if to_label is None else to_label(y)
        self.set_all(data.data.to_numpy().astype(np.float32),
                     np.asarray(y).astype(int),
                     data.data.columns.to_numpy(), data)

    def load_heart_disease(self):
        """ Heart Disease Dataset
        https://www.openml.org/search?type=data&status=active&id=43398
        """
        data = fetch_openml(data_id=43398)
        X = data.data.drop(columns=["target"])
        y = data.data['target']
        self.set_all(X.to_numpy(), y.to_numpy(), X.columns.to_numpy(), data)

    def load_haberman(self):
        """ Haberman Survival Dataset
        https://www.openml.org/search?type=data&status=active&id=43
        Raw target 1 = patient survived 5 years or longer (label 0)
        Raw target 2 = patient died within 5 years (label 1)
        """
        self._load_openml(43, lambda y: np.where(y.astype(int) == 2, 1, 0))

    def load_blood_transfusion(self):
        """ Blood Transfusion Service Center Dataset
        https://www.openml.org/search?type=data&status=active&id=1464
        Raw target 2 = donated blood (label 1), 1 = did not donate (label 0)
        """
        self._load_openml(1464, lambda y: np.where(y.astype(int) == 2, 1, 0))

    def load_climate_simulation(self):
        """ Climate Model Simulation Crash Dataset
        https://www.openml.org/search?type=data&status=active&id=1467
        """
        self._load_openml(1467)

    def load_sonar(self):
        """ Sonar Dataset (label 1 = "Mine")
        https://www.openml.org/search?type=data&status=active&id=40
        """
        self._load_openml(40, lambda y: np.where(y == 'Mine', 1, 0))

    def load_parkinsons(self):
        """ Parkinson Dataset (raw target 2 -> label 1)
        https://www.openml.org/search?type=data&status=active&id=1488
        """
        self._load_openml(1488, lambda y: np.where(y.astype(int) == 2, 1, 0))

    def load_banknote_authentication(self):
        """ Banknote Authentication Dataset (raw target 2 -> label 1)
        https://www.openml.org/search?type=data&status=active&id=1462
        """
        self._load_openml(1462, lambda y: np.where(y.astype(int) == 2, 1, 0))

    def load_breast_cancer(self):
        """ Breast Cancer Wisconsin (Original) Data Set (label 1 = "malignant")
        https://www.openml.org/search?type=data&status=active&id=15
        """
        self._load_openml(15, lambda y: np.where(y == "malignant", 1, 0))

    def load_cylinder_bands(self):
        """ Cylinder-Bands Dataset (label 1 = "band")
        https://www.openml.org/search?type=data&status=active&id=6332
        """
        self._load_openml(6332, lambda y: np.where(y == "band", 1, 0))

    def load_diabetes(self):
        """ Diabetes Dataset (label 1 = "tested_positive")
        https://www.openml.org/search?type=data&status=active&id=37
        """
        self._load_openml(37, lambda y: np.where(y == "tested_positive", 1, 0))

    def load_ionosphere(self):
        """ Ionosphere Dataset (label 1 = "g")
        https://www.openml.org/search?type=data&status=active&id=59
        """
        self._load_openml(59, lambda y: np.where(y == "g", 1, 0))

    def load_planning_replax(self):
        """ Planning Relax Dataset (label 1 = raw target "2")
        https://www.openml.org/search?type=data&status=active&id=1490
        """
        self._load_openml(1490, lambda y: np.where(y == "2", 1, 0))

    def load_spambase(self):
        """ Spambase Dataset
        https://www.openml.org/search?type=data&status=active&id=44
        """
        self._load_openml(44)

    def load_spectf(self):
        """ SPECTF Dataset
        https://www.openml.org/search?type=data&status=active&id=1600
        """
        self._load_openml(1600)

    def load_wine_quality(self):
        """ Wine Quality Dataset (label 1 = quality >= 6, i.e. a good wine)
        https://www.openml.org/search?type=data&status=active&id=287
        """
        self._load_openml(287, lambda y: np.where(y.astype(int) >= 6, 1, 0))

        #endregion

        #region CSV Datasets
    def _load_csv(self, filename, target, dtype=int):
        """ Load ``<data_dir>/<filename>`` and store it via ``set_all``.

        ``target`` is the label column; every other column is a feature.
        Both X and y are cast to ``dtype``. Feature names exclude the target,
        so they line up with the columns of X.
        """
        data = pd.read_csv(f"{self.data_dir}/{filename}")
        X_df = data.drop(columns=[target])
        X = X_df.to_numpy().astype(dtype)
        y = data[target].to_numpy().astype(dtype)
        self.set_all(X, y, X_df.columns.to_numpy(), data)

    def load_compas(self):
        """ COMPAS Dataset
        """
        self._load_csv("compas.csv", "two_year_recid")

    def load_fico(self):
        """ FICO Dataset
        """
        self._load_csv("fico.csv", "RiskPerformance")

    def load_compas_bin(self):
        """ COMPAS Dataset (Binary)
        """
        self._load_csv("compas-bin.csv", "two_year_recid")

    def load_mimic(self):
        """ MIMIC-II Dataset (features and labels as float32)
        """
        self._load_csv("mimic2.csv", "HospitalMortality", dtype=np.float32)

    def load_adult_orig(self):
        """ Adult Original Dataset
        https://www.openml.org/search?type=data&id=1590&sort=runs&status=active

        Rows with missing values ("?" or NaN) are dropped, up to 10000 rows
        are subsampled, and categorical columns are one-hot encoded.
        Label 1 = ">50K".
        """
        data = fetch_openml(data_id=1590)
        X = data.data
        y = data.target
        target_name = y.name if y.name is not None else "target"
        y = y.rename(target_name)

        combined_df = pd.concat([X, y], axis=1)
        combined_df.replace("?", np.nan, inplace=True)
        initial_rows = len(combined_df)
        cleaned_df = combined_df.dropna()
        print(f"Dropped {initial_rows - len(cleaned_df)} rows due to NaNs.")
        # Cap the subsample so a smaller cleaned frame does not make sample() raise.
        cleaned_df = cleaned_df.sample(n=min(10000, len(cleaned_df)), random_state=self.rs)

        X_processed_df = cleaned_df.drop(columns=[target_name])
        y_processed_df = cleaned_df[target_name]
        categorical_cols = X_processed_df.select_dtypes(include=['object', 'category']).columns

        if len(categorical_cols) > 0:
            X_ohe_df = pd.get_dummies(X_processed_df, columns=categorical_cols,
                                      dummy_na=False, drop_first=False)
        else:
            X_ohe_df = X_processed_df

        feat_name = X_ohe_df.columns.to_numpy()
        X = X_ohe_df.to_numpy().astype(np.float32)
        y = np.where(y_processed_df == ">50K", 1, 0).astype(int)
        self.set_all(X, y, feat_name, data)

    def load_adult(self):
        """ Adult dataset
        """
        self._load_csv("adult.csv", "Class:>50K")

    def load_bank(self):
        """ Bank Dataset
        """
        self._load_csv("bank.csv", "class:yes")

    def load_census(self):
        """ Census Income Dataset
        """
        self._load_csv("census-income.csv", "class: 50000+")

    def load_oulad(self):
        """ OULAD Dataset
        """
        self._load_csv("oulad.csv", "final_result:Pass")

    def load_german_credit_bin(self):
        """ German Credit Dataset
        """
        self._load_csv("german-credit-bin.csv", "Creditability")

    def load_default_credit(self):
        """ Default Credit Dataset
        """
        # The target column is "DEFAULT_PAYMENT" (previously misspelled on read).
        self._load_csv("default-credit.csv", "DEFAULT_PAYMENT")
        #endregion

    #endregion
