"""
This module provides the code for generating the benchmark datasets:
 - Friedman1
 - Ishigami function
 - G-function
"""

import numpy as np
from scipy.linalg import toeplitz
from sklearn.datasets import make_friedman1, make_friedman2, make_friedman3
from sklearn.feature_selection import mutual_info_regression
from sklearn.preprocessing import StandardScaler


class GFunction:
    """
    Sampler for the G-function benchmark.

    The noiseless response is
    y = (prod_{i=1}^d (|4*x_i - 2| + a_i)) / (prod_{i=1}^d (1 + a_i))

    Parameters
    ----------
    a_i_values : list or array-like, shape (d,)
        One coefficient per feature; its length fixes the number of features d.
        The larger a coefficient, the weaker the influence of its feature.
    correlation : float, optional
        Base of the Toeplitz covariance matrix, whose first column is
        correlation ** [0, 1, ..., d - 1]. If None, the identity matrix is used.
    snr : float, optional
        Signal-to-noise ratio: the noise standard deviation is std(y) / snr.
    """

    def __init__(self, a_i_values, correlation=None, snr=1.0):
        self.a_i_values = a_i_values
        self.ai_arr = np.array(a_i_values)
        self.d = len(a_i_values)
        self.correlation = correlation
        self.snr = snr

        if correlation is None:
            # No mixing requested: covariance and its Cholesky factor are identity.
            self.cov = np.eye(self.d)
            self.cov_chol = np.eye(self.d)
            return

        first_column = correlation ** np.arange(0, self.d)
        self.cov = toeplitz(first_column)
        self.cov_chol = np.linalg.cholesky(self.cov)

    def sample(self, n_samples, random_state=None):
        """
        Draw a noisy, standardized dataset.

        Parameters
        ----------
        n_samples : int
            Number of rows to generate.
        random_state : optional
            Seed forwarded to np.random.default_rng.

        Returns
        -------
        X : ndarray, shape (n_samples, d)
            Features, standardized column-wise.
        y : ndarray, shape (n_samples,)
            Noisy response, standardized.
        """
        generator = np.random.default_rng(random_state)
        uniform_draws = generator.uniform(0, 1, size=(n_samples, self.d))
        # Mix the independent columns with the Cholesky factor of the covariance.
        features = np.dot(uniform_draws, self.cov_chol.T)

        clean_response = self.g_function(features, self.ai_arr)
        noise_scale = np.std(clean_response) / self.snr
        perturbation = generator.normal(0, noise_scale, size=n_samples)
        noisy_response = clean_response + perturbation

        target_column = noisy_response.reshape(-1, 1)
        scaled_target = StandardScaler().fit_transform(target_column).ravel()
        scaled_features = StandardScaler().fit_transform(features)
        return scaled_features, scaled_target

    @staticmethod
    def g_function(X, ai_arr):
        """
        Evaluate the noiseless G-function row by row.

        Parameters
        ----------
        X : array-like, shape (n_samples, d)
            Input samples.
        ai_arr : array-like, shape (d,)
            Coefficients for each dimension.

        Returns
        -------
        ndarray, shape (n_samples,)
            G-function value of each sample.
        """
        per_feature_terms = np.abs(4 * X - 2) + ai_arr
        normalizer = np.prod(1 + ai_arr)
        return np.prod(per_feature_terms, axis=1) / normalizer


class IshigamiFunction:
    """
    Class to generate samples from the Ishigami function.
    y = sin(x1) + 7*sin(x2)^2 + 0.1*x3^4*sin(x1)

    Only the first three features enter the function; any further feature is
    generated the same way but does not appear in the formula.

    Parameters
    ----------
    n_features : int, optional
        Number of input features. Default is 3. Must be at least 3.
    correlation : float, optional
        Correlation coefficient for the Toeplitz covariance matrix. If None,
        features are uncorrelated.
    snr : float, optional
        Signal-to-noise ratio for the output variable. The noise standard
        deviation is std(y) / snr.
    classification : bool, optional
        If True, the output variable is converted to a binary classification
        problem based on the median value.

    Raises
    ------
    ValueError
        If n_features is smaller than 3.
    """

    def __init__(self, n_features=3, correlation=None, snr=1.0, classification=False):
        # The formula reads columns 0, 1 and 2: fail early with a clear message
        # instead of an IndexError at sampling time.
        if n_features < 3:
            raise ValueError(
                f"The Ishigami function needs at least 3 features, got {n_features}."
            )
        self.n_features = n_features
        self.correlation = correlation
        self.snr = snr
        self.classification = classification
        if correlation is not None:
            self.cov = toeplitz(correlation ** np.arange(0, self.n_features))
            self.cov_chol = np.linalg.cholesky(self.cov)
        else:
            self.cov = np.eye(self.n_features)
            self.cov_chol = np.eye(self.n_features)

    def sample(self, n_samples, random_state=None):
        """
        Draw a noisy, standardized dataset.

        Parameters
        ----------
        n_samples : int
            Number of samples to generate.
        random_state : optional
            Seed forwarded to np.random.default_rng.

        Returns
        -------
        X : ndarray, shape (n_samples, n_features)
            Features, standardized column-wise.
        y : ndarray, shape (n_samples,)
            Standardized noisy response, or 0/1 labels (1 when strictly above
            the median) if classification is True.
        """
        rng = np.random.default_rng(random_state)
        X = rng.uniform(-np.pi, np.pi, size=(n_samples, self.n_features))
        # Mix the independent columns with the Cholesky factor of the covariance;
        # after mixing the values are no longer confined to [-pi, pi].
        X = X.dot(self.cov_chol.T)

        y = self.ishigami_function(X)
        noise = rng.normal(0, np.std(y) / self.snr, size=n_samples)
        y = y + noise

        y = StandardScaler().fit_transform(y.reshape(-1, 1)).ravel()
        X = StandardScaler().fit_transform(X)
        if self.classification:
            median = np.median(y)
            y = (y > median).astype(int)
        return X, y

    @staticmethod
    def ishigami_function(X):
        """
        Evaluate the noiseless Ishigami function row by row.

        Parameters
        ----------
        X : ndarray, shape (n_samples, d) with d >= 3
            Input samples. Only the first three columns are used.

        Returns
        -------
        y : ndarray, shape (n_samples,)
            Ishigami function value of each sample.
        """
        y = (
            np.sin(X[:, 0])
            + 7 * np.sin(X[:, 1]) ** 2
            + 0.1 * X[:, 2] ** 4 * np.sin(X[:, 0])
        )
        return y


def get_dataset(
    dataset_name,
    n_samples,
    snr,
    random_state,
    n_features=None,
):
    """
    Wrapper function to generate datasets. For each dataset, the target variable is
    normalized to have zero mean and unit variance.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset to generate. Options are "friedman1", "friedman2",
        "friedman3", "g_function", "ishigami".
    n_samples : int
        Number of samples to generate.
    snr : float
        Signal-to-noise ratio for the output variable. For the Friedman datasets
        it is converted to the scikit-learn ``noise`` argument as ``2 / snr``.
    random_state : int
        Random seed for reproducibility.
    n_features : int, optional
        Number of features. Ignored by "friedman2" and "friedman3". If None,
        defaults to 10 for "friedman1", 5 for "g_function" and 3 for "ishigami".

    Returns
    -------
    X : array-like, shape (n_samples, n_features)
        Generated input features.
    y : array-like, shape (n_samples,)
        Generated target variable, normalized.
    support : array-like, shape (n_relevant_features,)
        Indices of the relevant features. For "friedman1" these are the 5
        features with the highest estimated mutual information with y (in
        increasing order of mutual information); for "friedman2" and
        "friedman3" all features; for "g_function" the first 5; for "ishigami"
        the first 3.
    support_bis : array-like, shape (n_relevant_features,)
        Same object as ``support`` (duplicate unused).

    Raises
    ------
    ValueError
        If dataset_name is not one of the supported names.
    """
    if dataset_name == "friedman1":
        X, y = make_friedman1(
            n_samples=n_samples,
            # Fall back to scikit-learn's own default instead of passing None.
            n_features=10 if n_features is None else n_features,
            noise=2 / snr,
            random_state=random_state,
        )
        # NOTE(review): the support is estimated from the data rather than taken
        # from the generator's definition, so it can vary with noise and seed.
        minfo = mutual_info_regression(X, y, random_state=random_state)
        support = np.argsort(minfo)[-5:]
    elif dataset_name == "friedman2":
        X, y = make_friedman2(
            n_samples=n_samples,
            noise=2 / snr,
            random_state=random_state,
        )
        support = np.arange(X.shape[1])
    elif dataset_name == "friedman3":
        X, y = make_friedman3(
            n_samples=n_samples,
            noise=2 / snr,
            random_state=random_state,
        )
        support = np.arange(X.shape[1])
    elif dataset_name == "g_function":
        # Five informative coefficients, padded with large (near-inert) ones.
        n_padding = 0 if n_features is None else n_features - 5
        a_i_values = [0, 1, 2, 3, 4] + [100] * n_padding
        g_func = GFunction(a_i_values, correlation=0.3, snr=snr)
        X, y = g_func.sample(n_samples=n_samples, random_state=random_state)
        support = np.arange(5)
    elif dataset_name == "ishigami":
        ishigami = IshigamiFunction(
            n_features=3 if n_features is None else n_features,
            correlation=0.3,
            snr=snr,
        )
        X, y = ishigami.sample(n_samples=n_samples, random_state=random_state)
        support = np.array([0, 1, 2])
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    support_bis = support
    y_norm = StandardScaler().fit_transform(y.reshape(-1, 1)).ravel()
    return X, y_norm, support, support_bis
