import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from sklearn.datasets import fetch_openml
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import f_regression
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.preprocessing import MinMaxScaler
from utils import Console

# Lookup table of the datasets used in experiments.
# Keys are OpenML data ids (the value passed to `fetch_openml(data_id=...)`),
# values are the human-readable names that `Dataset.getDataName` assigns to
# `Dataset.dataName`. Ids that are not listed here fall back to "Id<id>".
# NOTE: names are not unique -- ids 42363 and 44962 both map to "Forest Fires".
datasets = {
    8: "Liver Disorders",
    9: "Automobile",
    191: "Wisconsin",
    194: "Cleveland",
    204: "Cholesterol",
    206: "Triazines",
    223: "Stock Prices",
    230: "Machine CPU",
    287: "Wine Quality",
    511: "Plasma Retinol",
    542: "Pollution",
    566: "Meta",
    1089: "US Crime",
    41021: "Moneyball",
    42225: "Diamonds",
    42352: "Student Performance",
    42363: "Forest Fires",
    42372: "Auto MPG",
    42724: "Online News Popularity",
    42726: "Abalone",
    43338: "Energy Efficiency",
    44019: "House Sales",
    44024: "California",
    44032: "Fifa",
    44042: "Black Friday",
    44133: "Pol",
    44134: "Elevator",
    44137: "Ailerons",
    44139: "House 16H",
    44141: "Brazilian Houses",
    44142: "Bike Sharing Demand",
    44143: "NYC Taxi Green",
    44145: "Sulfur",
    44146: "Medical Charges",  # Use 10 bins
    44957: "Airfoil Self Noise",
    44958: "Auction Verification",
    44959: "Concrete Compressive Strength",
    44962: "Forest Fires",
    44963: "Physicochemical Protein",
    44965: "Geographical Origin of Music",
    44966: "Solar Flare",
    44969: "Naval Propulsion Plant",
    44970: "Fish Toxicity",
    44973: "Grid Stability",
    44975: "Wave Energy",
    44983: "Miami Housing",
    44984: "CPS88 Wages",
    44987: "Socmob",
    44989: "Kings County",
    44994: "Cars",
    45950: "Heart Failure Clinical Records",
    46132: "NCI 60 Thioguanine",
    46134: "Acute Myeloid Leukemia",
    46139: "Cancer Drug Response Methylation",
    46283: "Appliances Energy Prediction",
    46286: "Communities and Crime",
    46328: "Seoul Bike Sharing",
    46337: "Conso RTE",
}


class Dataset:
    """Fetch an OpenML dataset and convert it to binary features and scaled labels.

    ``setup`` runs the whole pipeline: ``getData`` -> ``cleanData`` ->
    ``binarizeData`` -> ``rescaleLabels``.

    Attributes populated by the pipeline (they do not exist before the
    corresponding step has run):
        rawData, rawLabels, nbRawFeatures: set by ``getData``.
        nbMissingValues: set by ``cleanData``.
        X, nbCatFeatures, nbNumFeatures, nbBinFeatures: set by ``binarizeData``.
        Y: set by ``rescaleLabels``.
    """

    def __init__(
        self,
        data_id=-1,
        data_name="",
        data_version=1,
        nb_bins=4,
        verbose=True,
        data_home=None,
    ):
        """Store the dataset selection and preprocessing options.

        Args:
            data_id: OpenML data id; ``-1`` means "select by name and version".
            data_name: OpenML dataset name, used when ``data_id`` is ``-1``.
            data_version: OpenML dataset version, used together with ``data_name``.
            nb_bins: number of bins used to discretize numerical features.
            verbose: forwarded to the project ``Console`` helper.
            data_home: directory forwarded to ``fetch_openml`` as ``data_home``;
                ``None`` lets scikit-learn use its default location.
        """
        self.console = Console(verbose=verbose)
        self.dataId = data_id
        self.dataName = data_name
        self.dataVersion = data_version
        self.nbBins = nb_bins
        self.data_home = data_home

    def __str__(self):
        """Return a multi-line summary of the dataset.

        Reads attributes that only exist after ``getData``, ``cleanData`` and
        ``binarizeData`` have run, so call ``setup`` first.
        """
        fmt = self.console.string
        parts = [
            fmt("Dataset Information", endl=True),
            fmt("ID", self.dataId, endl=True),
            fmt("Data Set Name", self.dataName, endl=True),
            fmt("# Missing Values", self.nbMissingValues, endl=True),
            fmt("# Raw Features", self.nbRawFeatures, endl=True),
            fmt("# Categorical Features", self.nbCatFeatures, endl=True),
            fmt("# Numerical Features", self.nbNumFeatures, endl=True),
            fmt("# Binary Features", self.nbBinFeatures, endl=True),
            fmt("# Instances", self.X.shape[0], endl=True),
        ]
        return "".join(parts)

    def setup(self) -> str:
        """Run the full pipeline and return the dataset name."""
        self.getData()
        self.cleanData()
        self.binarizeData()
        self.rescaleLabels()
        return self.dataName

    def getDataName(self) -> None:
        """Set ``dataName`` from ``dataId`` using the module-level ``datasets`` table.

        Ids missing from the table get the fallback name ``"Id<dataId>"``.
        """
        self.dataName = datasets.get(self.dataId, f"Id{self.dataId}")

    def _fetch(self, as_frame, **selector):
        """Call ``fetch_openml`` with the shared options and return ``(data, labels)``.

        ``selector`` is either ``data_id=...`` or ``name=..., version=...``.
        """
        # ``data_home=None`` is fetch_openml's own default, so it can be
        # forwarded unconditionally instead of duplicating the call.
        return fetch_openml(
            parser="auto",
            as_frame=as_frame,
            return_X_y=True,
            data_home=self.data_home,
            **selector,
        )

    def getData(self) -> None:
        """Load ``rawData``/``rawLabels`` by id, or by name and version when id is -1.

        When loading by id, sparse (CSR) data is reduced with a truncated SVD
        and converted to a DataFrame, ndarray labels are wrapped in a Series,
        and ``dataName`` is resolved from the id.
        """
        if self.dataId != -1:
            self.console.log("Load dataset from ID", self.dataId)
            if self.data_home is not None:
                self.console.log("Load dataset from the cache", self.data_home)
            self.rawData, self.rawLabels = self._fetch("auto", data_id=self.dataId)

            # Sparse input: reduce dimensionality and switch to a dense DataFrame.
            # NOTE(review): n_components is fixed at 200; presumably this fails
            # for sparse datasets with fewer features -- confirm.
            if isinstance(self.rawData, csr_matrix):
                svd = TruncatedSVD(n_components=200, random_state=32)
                reduced = svd.fit_transform(self.rawData)
                feature_names = [f"feature_{i + 1}" for i in range(reduced.shape[1])]
                self.rawData = pd.DataFrame(reduced, columns=feature_names)

            # The rest of the pipeline calls ``.to_numpy()`` on the labels.
            if isinstance(self.rawLabels, np.ndarray):
                self.rawLabels = pd.Series(self.rawLabels)

            self.getDataName()
        else:
            self.console.log(
                "Load dataset from name and version",
                f"{self.dataName},v{self.dataVersion}",
            )
            self.rawData, self.rawLabels = self._fetch(
                True, name=self.dataName, version=self.dataVersion
            )
        self.nbRawFeatures = self.rawData.shape[1]

    def cleanData(self) -> None:
        """Normalise missing-value markers and impute them with the column mode.

        All string cells are lower-cased (this is kept in ``rawData``), the
        tokens listed below are turned into NaN, ``nbMissingValues`` is set to
        the NaN count, and NaNs are replaced by the most frequent value of
        their column. Columns without any observed value are dropped.
        """
        self.console.log("Check for missing values")

        missing_values = ["?", "", ".", "nan", "na", "none", "null"]
        # Lower-case first so the token match is case-insensitive.
        lowered = self.rawData.apply(
            lambda col: col.map(lambda v: v.lower() if isinstance(v, str) else v)
        )
        self.rawData = lowered.replace(missing_values, np.nan)

        self.nbMissingValues = self.rawData.isnull().values.sum()
        if self.nbMissingValues > 0:
            # SimpleImputer discards columns that contain no observed value,
            # which would make the imputed array narrower than the column list
            # used to rebuild the DataFrame; drop such columns explicitly.
            self.rawData = self.rawData.dropna(axis="columns", how="all")
            imputer = SimpleImputer(strategy="most_frequent")
            imputed = imputer.fit_transform(self.rawData)
            # Rebuild with the original labels and restore the column dtypes.
            self.rawData = pd.DataFrame(
                imputed, columns=self.rawData.columns, index=self.rawData.index
            ).astype(self.rawData.dtypes.to_dict())

    def binarizeData(self) -> None:
        """Build the binary feature matrix ``X`` from ``rawData``.

        Constant columns are dropped. Non-numerical columns and numerical
        columns with at most ``nbBins`` distinct values are one-hot encoded;
        the remaining numerical columns are discretized into uniform-width
        bins and one-hot encoded. Also sets ``nbCatFeatures``,
        ``nbNumFeatures`` and ``nbBinFeatures``.
        """
        self.console.log("Binarize features")

        rawData = self.rawData.copy()
        encoded_parts = []

        # Remove columns with a single value
        single_cols = [col for col in rawData.columns if rawData[col].nunique() == 1]
        rawData = rawData.drop(columns=single_cols)

        # Binarize categorical features
        df_cat_data = rawData.select_dtypes(exclude=["number"])
        self.nbCatFeatures = df_cat_data.shape[1]
        if not df_cat_data.empty:
            encoded_parts.append(pd.get_dummies(df_cat_data).astype("int64"))
            rawData = rawData.drop(columns=df_cat_data.columns)

        # Binarize numerical features with nb_values <= nb_bins
        sparse_cols = [
            col for col in rawData.columns if rawData[col].nunique() <= self.nbBins
        ]
        if sparse_cols:
            df_sparse_raw = rawData[sparse_cols].astype("category")
            encoded_parts.append(pd.get_dummies(df_sparse_raw).astype("int64"))
            rawData = rawData.drop(columns=sparse_cols)

        # Binarize numerical features with nb_values > nb_bins
        df_num_data = rawData
        self.nbNumFeatures = df_num_data.shape[1]
        if not df_num_data.empty:
            discretizer = KBinsDiscretizer(
                n_bins=self.nbBins,
                encode="onehot-dense",
                strategy="uniform",
                subsample=None,
            )
            bin_num_data = discretizer.fit_transform(
                df_num_data.to_numpy(), self.rawLabels.to_numpy()
            )
            # Name columns from the fitted per-feature bin counts: the
            # discretizer may keep fewer than ``nbBins`` bins for a feature.
            bin_num_cols = [
                f"{col}_Bin{i}"
                for col, n_bins in zip(df_num_data.columns, discretizer.n_bins_)
                for i in range(n_bins)
            ]
            # Reuse the source index so the column-wise concat below aligns
            # rows with the dummy-encoded parts instead of a fresh RangeIndex.
            encoded_parts.append(
                pd.DataFrame(
                    bin_num_data,
                    columns=bin_num_cols,
                    index=df_num_data.index,
                    dtype="int64",
                )
            )

        self.X = pd.concat(encoded_parts, axis=1) if encoded_parts else pd.DataFrame()
        self.nbBinFeatures = self.X.shape[1]

    def rescaleLabels(self) -> None:
        """Min-max scale ``rawLabels`` to [0, 1] and store them as ``Y`` ("Label")."""
        self.console.log("Rescale labels")

        scaler = MinMaxScaler(feature_range=(0, 1))
        # The scaler expects a 2-D array: one column holding every label.
        numLabels = scaler.fit_transform(self.rawLabels.to_numpy().reshape(-1, 1))
        self.Y = pd.DataFrame(numLabels, columns=["Label"], dtype="float64")
