import os
from typing import Tuple

import numpy as np
import pandas as pd
import torch
# from robustbench.data import load_cifar10c
import matplotlib.pyplot as plt
from data_utils.data_corruption.corruption_type import get_corruption_type_from_dataset_name
from data_utils.data_corruption.covariates_dimension_reducer import CovariatesDimensionReducer
from data_utils.data_corruption.data_corruption_masker import DataCorruptionIndicatorFactory, \
    OracleDataCorruptionMasker, DefaultDataCorruptionMasker
from data_utils.data_scaler import DataScaler
from data_utils.data_type import DataType
from data_utils.dataset_naming_utils import get_original_dataset_name, get_e_from_dataset_name
from data_utils.datasets.regression_dataset import RegressionDataset
from data_utils.datasets.synthetic_causal_inference_data_generator import CausalInferenceDataGenerator
from data_utils.datasets.synthetic_dataset_generator import PartiallyLinearDataGenerator, SyntheticDataGenerator
from models.LinearModel import LinearModel, LogisticLinearModel
from models.classifiers.XGBClassifier import XGBClassifier
from models.regressors.FullRegressor import FullRegressor
from utils.utils import set_seeds
from models.qr_models.RFQR import RFQR

# Indices of the raw feature columns (in the column order produced by load_real_regression_dataset) that are
# used as the proxy variables Z of each real dataset. get_real_data copies these columns into Z and removes them
# from X. For MEPS, with the column list used in load_real_regression_dataset, indices 3, 0 and 100 are
# 'K6SUM42', 'AGE' and 'ACTLIM=1'. K6SUM42 is described in the MEPS codebook:
# https://meps.ahrq.gov/mepsweb/data_stats/download_data_files_codebook.jsp?PUFId=H147
# https://meps.ahrq.gov/mepsweb/data_stats/download_data_files_codebook.jsp?PUFId=H147&varName=K6SUM42
proxy_col_dict = {
    'meps_19': [3, 0, 100],
    'meps_20': [3, 0, 100],
    'meps_21': [3, 0, 100],
    'facebook_1': [33],
    'facebook_2': [33],
    'bio': [2],
    'house': [2],  # 14,
    "blog": [9],
    # load_real_regression_dataset accepts 'blog_data' as an alias of 'blog'; mirror it here so that
    # get_real_data does not fail with a KeyError when the dataset is referred to by that name.
    "blog_data": [9],
}


def get_noised_ihdp_data(args):
    """IHDP variant whose outcomes are heavily noised for the units most likely to be masked.

    Loads the data through ``get_ihdp_data`` and fits an XGBoost classifier predicting the corruption
    mask from ``[x, z]`` on a random 80/20 split. The predicted mask probabilities are shifted so that
    their mean equals the empirical masking rate. Units at or above the 80% quantile of that score get
    Gaussian noise with standard deviation ``5 * y.std()`` added to ``y`` (in place).

    Global seeds are set to ``args.seed`` before the classifier is fit and again before returning.

    Returns:
        (x, y, z, corruption_masker, mask), with the same meaning as ``get_ihdp_data``.
    """
    x, y, z, corruption_masker, mask = get_ihdp_data(args)
    set_seeds(args.seed)

    features = torch.cat([x, z], dim=-1).to(args.device)
    targets = mask.to(args.device)
    classifier = XGBClassifier(args.dataset_name + '_y(0)', args.saved_models_path, features.shape[1], 2,
                               device=args.device, max_depth=2, n_estimators=10, seed=args.seed)

    shuffled = np.random.permutation(len(features))
    n_fit = int(len(shuffled) * 0.8)
    fit_rows, holdout_rows = shuffled[:n_fit], shuffled[n_fit:]
    classifier.fit_xy(features[fit_rows], targets[fit_rows], None,
                      features[holdout_rows], targets[holdout_rows], None,
                      epochs=1000, batch_size=args.bs, n_wait=args.wait)

    mask_score = classifier.estimate_probabilities(features).probabilities.detach()[..., 1].cpu()
    # Re-centre the scores so that their average matches the observed masking rate.
    mask_score += mask.float().mean() - mask_score.mean()
    is_likely_masked = mask_score >= torch.quantile(mask_score, q=0.8)

    noise = torch.randn_like(y) * (y.std() * 5)
    y[is_likely_masked] += noise[is_likely_masked]

    set_seeds(args.seed)
    return x, y, z, corruption_masker, mask


def get_ihdp_data(args):
    """Load the IHDP benchmark (first replication) and build a model-based corruption mask.

    Data source: https://www.fredjo.com/ (``ihdp_npci_1-1000.train.npz`` and ``ihdp_npci_1-1000.test.npz``,
    which are concatenated). It was used in "Estimating individual treatment effect: generalization bounds
    and algorithms" and in "Causal Effect Inference with Deep Latent-Variable Models".

    The outcome ``y`` is the factual outcome ``yf`` for units with ``t == 0`` and the counterfactual
    ``ycf`` for units with ``t == 1``. Assuming ``t`` is the treatment indicator, this is the untreated
    outcome for every unit.

    The covariate with the largest absolute Pearson correlation with ``y`` is moved from ``x`` into the
    proxy ``z``. A linear model of ``y`` on ``[x, z]`` then serves as the covariate reducer of a
    ``DefaultDataCorruptionMasker`` with marginal masking ratio 0.2.

    Returns:
        (x, y, z, corruption_masker, mask), where ``z`` has shape (n, 1) and ``mask`` is boolean.

    Raises:
        ValueError: if the selected proxy column is not column 5 (the value the original code asserted).
    """
    x_parts, t_parts, yf_parts, ycf_parts = [], [], [], []
    for file_name in ("ihdp_npci_1-1000.train.npz", "ihdp_npci_1-1000.test.npz"):
        # np.load on an .npz archive returns an NpzFile that keeps the file open until it is closed.
        with np.load(os.path.join(args.data_path, file_name)) as data:
            x_parts.append(torch.Tensor(data["x"]))
            t_parts.append(torch.Tensor(data["t"]))
            yf_parts.append(torch.Tensor(data["yf"]))
            ycf_parts.append(torch.Tensor(data["ycf"]))

    # The last axis presumably indexes the IHDP replications (file names say 1-1000); only the first is used.
    t = torch.cat(t_parts, dim=0)[..., 0].bool()
    x = torch.cat(x_parts, dim=0)[..., 0]
    yf = torch.cat(yf_parts, dim=0)[..., 0]
    ycf = torch.cat(ycf_parts, dim=0)[..., 0]
    y = yf.clone()
    y[t] = ycf[t].clone()

    correlations = [abs(np.corrcoef(y.float().numpy(), x.numpy()[:, i])[0, 1]) for i in range(x.shape[1])]
    # A constant column yields a NaN correlation, which plain argmax would select; ignore NaNs instead.
    proxy_col = int(np.nanargmax(correlations))
    if proxy_col != 5:
        raise ValueError(f"expected IHDP proxy column 5 (most correlated with y), got column {proxy_col}")
    z = x[:, proxy_col]
    if len(z.shape) == 1:
        z = z.unsqueeze(-1)
    cols_mask = np.ones(x.shape[1], dtype=bool)
    cols_mask[proxy_col] = False
    x = x[:, cols_mask]

    model = LinearModel("ihdp_x_y_model", args.saved_models_path, args.figures_dir, args.seed)
    model.fit(torch.cat([x, z], dim=-1), y)

    class LinearModelReducer(CovariatesDimensionReducer):
        """Reduces (x, z) to the fitted linear model's prediction of y."""

        def __call__(self, x: torch.Tensor, z: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            if len(z.shape) == 1:
                z = z.unsqueeze(-1)
            return model.predict(torch.cat([x, z], dim=-1)).squeeze()

    corruption_masker = DefaultDataCorruptionMasker('ihdp', LinearModelReducer(),
                                                    unscaled_full_x=x,
                                                    unscaled_full_z=z,
                                                    marginal_masking_ratio=0.2
                                                    )
    mask = corruption_masker.get_corruption_mask(unscaled_x=x, unscaled_z=z, seed=args.seed).bool()

    return x, y, z, corruption_masker, mask


def get_nlsm_data(args):
    """Build a semi-synthetic dataset from ``nslm.csv`` with an oracle corruption mask.

    Background notes on the data:
      -1: Was used in https://arxiv.org/pdf/2006.06138.pdf (Conformal Inference of Counterfactuals and
          Individual Treatment Effects).
       0: Raw data was taken from:
          https://github.com/grf-labs/grf/blob/master/experiments/acic18/synthetic_data.csv
       1: Real trial: "Assessing treatment effect variation in observational studies: Results from a data
          challenge".
       2: Synthetic data based on it: "A national experiment reveals where a growth mindset improves
          achievement".
       3: The artificial definition of the data is based on "Sensitivity Analysis of Individual Treatment
          Effects: A Robust Conformal Inference Approach", which builds on (1).

    Generation steps:
      1. Scale ``x``/``y``. Fit a regressor for ``Y`` on the control units (``Z == 0``); its predictions,
         unscaled, give ``mu_hat_0``.
      2. Fit an XGBoost classifier for ``Z`` on all units. Its probabilities, shifted so that their mean
         equals the empirical rate of ``Z``, are the propensity.
      3. Draw ``u ~ N(0, 0.2**2)``. For units whose ``u`` is in the top or bottom 10%, double the propensity
         and ``u``; then cap every propensity at 0.8.
      4. ``new_y = mu_hat_0 + tau(x) + new_u``, with a piecewise-constant effect ``tau`` on X1, X2 and C1.
      5. ``u`` is returned as the proxy ``z``. An ``OracleDataCorruptionMasker`` draws the mask from the
         new propensity.

    Global seeds are fixed to 42 during generation and reset to ``args.seed`` before returning.

    Returns:
        (x, new_y, z, corruption_masker, new_mask) with tensors moved to CPU; ``x`` is unscaled.
    """
    set_seeds(42)
    df = pd.read_csv(os.path.join(args.data_path, 'nslm.csv'))
    mask = torch.Tensor(df['Z'].values).bool().to(args.device)
    y = torch.Tensor(df['Y'].values).to(args.device)
    X_cols = ['S3', 'C1', 'C2', 'C3', 'XC', 'X1', 'X2', 'X3', 'X4', 'X5']
    x1_col = X_cols.index("X1")
    x2_col = X_cols.index("X2")
    c1_col = X_cols.index("C1")
    x = torch.Tensor(df[X_cols].values).to(args.device)
    temp_scaler = DataScaler()
    temp_scaler.initialize_scalers(x, y)
    x = temp_scaler.scale_x(x)
    y = temp_scaler.scale_y(y)

    # NOTE(review): the regressor and the classifier below share the model name
    # args.dataset_name + "_y(0)" and the same saved_models_path; confirm their saved files cannot collide.
    regressor = FullRegressor(args.dataset_name + "_y(0)", args.saved_models_path, x.shape[1], 0,
                              [32],
                              args.dropout, False, args.lr, args.wd, args.device,
                              figures_dir=args.figures_dir,
                              seed=0)
    # Outcome model: trained on the control units only, with an 80/20 split of those units.
    x_control, y_control = x[~mask], y[~mask]
    control_permutation = np.random.permutation(len(x_control))
    control_split = int(len(control_permutation) * 0.8)
    train_idx = control_permutation[:control_split]
    val_idx = control_permutation[control_split:]
    regressor.fit(x_control[train_idx], y_control[train_idx], None, x_control[val_idx], y_control[val_idx], None,
                  epochs=1000, batch_size=args.bs, n_wait=args.wait, z_train=None, z_val=None)

    classifier = XGBClassifier(args.dataset_name + '_y(0)', args.saved_models_path, x.shape[1], 2, device=args.device,
                               max_depth=2, n_estimators=10, seed=0)
    # Propensity model: trained on *all* units, so it needs its own split over range(len(x)). The control
    # indices above only cover positions [0, n_control), so reusing them would train on a fixed prefix of rows.
    full_permutation = np.random.permutation(len(x))
    full_split = int(len(full_permutation) * 0.8)
    cls_train_idx = full_permutation[:full_split]
    cls_val_idx = full_permutation[full_split:]
    classifier.fit_xy(x[cls_train_idx], mask[cls_train_idx], None, x[cls_val_idx], mask[cls_val_idx], None,
                      epochs=1000, batch_size=args.bs, n_wait=args.wait)
    propensity = classifier.estimate_probabilities(x).probabilities.detach()[..., 1]
    # Shift so the mean propensity matches the empirical treatment rate.
    correction = mask.float().mean() - propensity.mean()
    propensity += correction

    # NOTE(review): assumes predict_mean returns one value per unit with a shape that adds element-wise
    # with tau below (e.g. (n,)); confirm.
    mu_hat_0 = temp_scaler.unscale_y(regressor.predict_mean(x, None).detach())
    x = temp_scaler.unscale_x(x)
    y = temp_scaler.unscale_y(y)

    u = torch.randn_like(y) * 0.2
    is_u_extreme = (u >= torch.quantile(u, q=0.9)) | (u <= torch.quantile(u, q=0.1))
    # Double the propensity of units with extreme u, then cap every propensity at 0.8.
    new_propensity = propensity * 2 * is_u_extreme + propensity * (~is_u_extreme)
    new_propensity = torch.min(new_propensity, 0.8 * torch.ones_like(new_propensity))

    c1_is_in_set = (abs(x[:, c1_col] - 1) < 0.1) | (abs(x[:, c1_col] - 13) < 0.1) | (abs(x[:, c1_col] - 14) < 0.1)
    tau = 0.228 + 0.05 * (x[:, x1_col] < 0.07) - 0.05 * (x[:, x2_col] < -0.69) - 0.08 * c1_is_in_set
    new_u = 2 * u * is_u_extreme + u * (~ is_u_extreme)
    new_y = mu_hat_0 + tau + new_u

    z = u.unsqueeze(-1)

    corruption_masker = OracleDataCorruptionMasker(x, z, new_propensity)
    new_mask = corruption_masker.get_corruption_mask(x, z)

    set_seeds(args.seed)

    return x.cpu(), new_y.cpu(), z.cpu(), corruption_masker, new_mask.cpu()


def get_twins_data(args):
    """Load the Twins dataset and build a logistic-model-based corruption mask.

    Taken from
    https://github.com/py-why/dowhy/blob/main/docs/source/example_notebooks/dowhy_twins_example.ipynb

    Reads ``twins_x.csv``, ``twins_y.csv`` and ``twins_t.csv`` from ``<args.data_path>/twins``.
    Missing feature values are filled with column means. Only pairs whose ``dbirwt_0`` and ``dbirwt_1``
    are both <= 2000 are kept. The proxy ``z`` is ``dbirwt_0`` and the outcome ``y`` is ``mort_0``.
    A logistic model of ``y`` on ``[x, z]`` (its class-1 probability) is the covariate reducer of a
    ``DefaultDataCorruptionMasker`` with marginal masking ratio 0.2.

    Returns:
        (x, None, y, z, corruption_masker, mask). NOTE(review): this is a 6-tuple, unlike the 5-tuple of
        the other loaders, and get_real_data does not dispatch to this function; confirm the intended contract.
    """
    twins_dir = os.path.join(args.data_path, "twins")
    x_df = pd.read_csv(os.path.join(twins_dir, 'twins_x.csv'))
    y_df = pd.read_csv(os.path.join(twins_dir, 'twins_y.csv'))[["mort_0", "mort_1"]]
    t_df = pd.read_csv(os.path.join(twins_dir, 'twins_t.csv'))[["dbirwt_0", "dbirwt_1"]]
    feature_cols = ['pldel', 'birattnd', 'brstate', 'stoccfipb', 'mager8',
                    'ormoth', 'mrace', 'meduc6', 'dmar', 'mplbir', 'mpre5', 'adequacy',
                    'orfath', 'frace', 'birmon', 'gestat10', 'csex', 'anemia', 'cardiac',
                    'lung', 'diabetes', 'herpes', 'hydra', 'hemo', 'chyper', 'phyper',
                    'eclamp', 'incervix', 'pre4000', 'preterm', 'renal', 'rh', 'uterine',
                    'othermr', 'tobacco', 'alcohol', 'cigar6', 'drink5', 'crace',
                    'data_year', 'nprevistq', 'dfageq', 'feduc6', 'infant_id_0',
                    'dlivord_min', 'dtotord_min', 'bord_0',
                    'brstate_reg', 'stoccfipb_reg', 'mplbir_reg']
    lighter_features = x_df[feature_cols]
    # Rebind instead of fillna(inplace=True): in-place edits of a column selection trigger SettingWithCopyWarning.
    lighter_features = lighter_features.fillna(value=lighter_features.mean(axis="rows"))

    # Convert the pandas boolean Series to a bool tensor so that torch indexing does not depend on pandas.
    relevant_idx = torch.from_numpy((t_df <= 2000).all(axis='columns').to_numpy())
    x = torch.Tensor(lighter_features.to_numpy())[relevant_idx]
    z = torch.Tensor(t_df["dbirwt_0"].to_numpy()).unsqueeze(-1)[relevant_idx]
    y = torch.Tensor(y_df["mort_0"].to_numpy()).unsqueeze(-1)[relevant_idx]

    model = LogisticLinearModel("twins_x_y_model", args.saved_models_path, args.figures_dir, args.seed)
    model.fit(torch.cat([x, z], dim=-1), y)

    class LinearModelReducer(CovariatesDimensionReducer):
        """Reduces (x, z) to the logistic model's class-1 probability."""

        def __call__(self, x: torch.Tensor, z: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            if len(z.shape) == 1:
                z = z.unsqueeze(-1)
            return model.estimate_probabilities(torch.cat([x, z], dim=-1)).squeeze()[:, 1]

    corruption_masker = DefaultDataCorruptionMasker('twins', LinearModelReducer(),
                                                    unscaled_full_x=x,
                                                    unscaled_full_z=z,
                                                    marginal_masking_ratio=0.2)

    mask = corruption_masker.get_corruption_mask(unscaled_x=x, unscaled_z=z, seed=args.seed).bool()

    return x, None, y, z, corruption_masker, mask


def get_real_data(args):
    """Load a real dataset and its corruption masker.

    Some of the datasets are taken from: https://github.com/py-why/dowhy/blob/main/dowhy/datasets.py

    Names containing 'noised_ihdp', 'ihdp' or 'nslm' are routed to their dedicated loaders, checked in
    that order because 'noised_ihdp' also contains 'ihdp'. Any other name is loaded with
    ``load_real_regression_dataset``. The columns listed in ``proxy_col_dict`` are copied into ``Z``;
    when the name encodes a noise level ``e``, Gaussian noise with std ``0.1 * e`` is added to them.
    Those columns are then removed from ``X``. ``args.z_dim`` is only a lower bound: every listed
    proxy column is used. MEPS datasets get an RFQR interval-width covariate reducer.

    Returns:
        (X, Y, Z, corruption_masker, mask), where ``Y`` has shape (n, 1).

    Raises:
        Exception: if the dataset has fewer proxy columns than ``args.z_dim``.
    """
    dataset_name = args.dataset_name
    dedicated_loaders = (
        ('noised_ihdp', get_noised_ihdp_data),
        ('ihdp', get_ihdp_data),
        ('nslm', get_nlsm_data),
    )
    for name_fragment, loader in dedicated_loaders:
        if name_fragment in dataset_name:
            return loader(args)

    z_dim = args.z_dim
    original_data_name = get_original_dataset_name(dataset_name)
    corruption_type = get_corruption_type_from_dataset_name(dataset_name)
    raw_x, raw_y = load_real_regression_dataset(original_data_name, args.data_path)
    X = torch.Tensor(raw_x)
    y = torch.Tensor(raw_y)

    proxy_cols = proxy_col_dict[original_data_name]
    if len(proxy_cols) < z_dim:
        raise Exception("too few proxy columns")

    Z = X[:, proxy_cols]
    e = get_e_from_dataset_name(dataset_name)
    if e is not None:
        Z = (0.1 * e) * torch.randn_like(Z) + Z
    Y = y.unsqueeze(1)

    # Drop the proxy columns from the covariates.
    keep_columns = np.ones(X.shape[1], dtype=bool)
    keep_columns[proxy_cols] = False
    X = X[:, keep_columns]

    covariates_reducer = None
    if 'meps' in dataset_name.lower():
        covariates_reducer = get_uncertainty_covariates_reducer(dataset_name, args.saved_models_path, X, y, Z, args)
    corruption_masker = DataCorruptionIndicatorFactory.get_corruption_masker(
        dataset_name, corruption_type, X, Z, Y, covariates_reducer=covariates_reducer)

    corruption_mask = corruption_masker.get_corruption_mask(X, Z)
    return X, Y, Z, corruption_masker, corruption_mask


def get_uncertainty_covariates_reducer(dataset_name, saved_models_path, X, Y, Z, args):
    """Return a reducer that maps the proxy ``z`` to the width of an RFQR prediction interval.

    A random-forest quantile regressor (``alpha=0.6``, ``seed=0``) is fit to predict ``Y`` from ``Z``
    alone. It uses an 80/20 train/validation split drawn from NumPy's global RNG over ``len(X)`` rows.
    ``X`` is used only for that length and ``args`` is unused.

    The returned reducer ignores its ``x`` argument and outputs ``upper - lower`` of the model's
    uncalibrated interval for ``z``.
    """
    quantile_model = RFQR(f"{dataset_name}_x_y_model", saved_models_path, seed=0, alpha=0.6)

    shuffled = np.random.permutation(len(X))
    n_fit = int(len(shuffled) * 0.8)
    fit_rows = shuffled[:n_fit]
    holdout_rows = shuffled[n_fit:]
    features = torch.cat([Z], dim=-1)
    quantile_model.fit(features[fit_rows], Y[fit_rows], None, features[holdout_rows], Y[holdout_rows], None)

    class IntervalWidthReducer(CovariatesDimensionReducer):
        """Reduces (x, z) to the width of the quantile model's interval at z."""

        def __call__(self, x: torch.Tensor, z: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            z_2d = z.unsqueeze(-1) if len(z.shape) == 1 else z
            bounds = quantile_model.construct_uncalibrated_intervals(torch.cat([z_2d], dim=-1)).intervals
            lower, upper = bounds[:, 0], bounds[:, 1]
            return upper - lower

    return IntervalWidthReducer()


def get_data_generator(dataset_name, x_dim, z_dim) -> SyntheticDataGenerator:
    """Return the synthetic data generator matching ``dataset_name``.

    The corruption type is parsed from the name before any dispatch. 'regression_synthetic',
    'regression_synthetic_with_overcoverage' and 'pcp_fail' use ``PartiallyLinearDataGenerator``;
    'synthetic_causal' uses ``CausalInferenceDataGenerator``.

    Raises:
        Exception: for any other base dataset name.
    """
    base_name = get_original_dataset_name(dataset_name)
    corruption_type = get_corruption_type_from_dataset_name(dataset_name)
    partially_linear_names = {'regression_synthetic', 'regression_synthetic_with_overcoverage', 'pcp_fail'}

    if base_name in partially_linear_names:
        return PartiallyLinearDataGenerator(dataset_name, x_dim, z_dim, corruption_type)
    if base_name == 'synthetic_causal':
        return CausalInferenceDataGenerator(x_dim)
    raise Exception(f"unknown dataset name: {dataset_name} and cannot construct a generator for it")


def get_regression_dataset(args) -> RegressionDataset:
    """Create a ``RegressionDataset`` from a real or a synthetic source.

    Real data comes from ``get_real_data``. Synthetic data comes from the generator returned by
    ``get_data_generator``; in that case a diagnostic scatter plot of the reduced proxy against the true
    masking probability is shown.

    Prints a warning when the marginal masking rate is outside [0.05, 0.5], or when the maximal masking
    probability is outside [0.3, 0.95].

    Raises:
        Exception: if the mask ``d`` is neither 1-D nor 2-D.
    """
    dataset_name: str = args.dataset_name
    if args.data_type != DataType.Real:
        data_generator = get_data_generator(dataset_name, args.x_dim, args.z_dim)
        x, y, z, data_masker, d = data_generator.generate_data(args.data_size, args.device)
        true_probabilities = data_masker.get_corruption_probabilities(x, z)
        reduced_z = data_masker.base_corruption_masker.covariates_dimension_reducer(x, z)
        plt.scatter(reduced_z.cpu(), true_probabilities.cpu())
        plt.xlabel('z')
        plt.ylabel('missing prob.')
        plt.title("real probabilities")
        plt.show()
    else:
        x, y, z, data_masker, d = get_real_data(args)
        data_generator = None

    mask_rank = len(d.shape)
    if mask_rank == 2:
        # A row counts as corrupted when any of its entries is masked.
        marginal_mask_probability = d.any(dim=1).float().mean().item()
    elif mask_rank == 1:
        marginal_mask_probability = d.float().mean().item()
    else:
        raise Exception(f"don't know how to handle with len(d.shape)={mask_rank}")
    max_mask_probability = data_masker.get_corruption_probabilities(x, z).max().item()

    if marginal_mask_probability < 0.05 or marginal_mask_probability > 0.5:
        print(f"warning: marginal mask ratio={marginal_mask_probability}")
    if max_mask_probability < 0.3 or max_mask_probability > 0.95:
        print(f"warning: max mask ratio={max_mask_probability}")

    dataset = RegressionDataset(x, y, z, d, data_masker, dataset_name, args.training_ratio, args.validation_ratio,
                                args.calibration_ratio, args.device,
                                args.saved_models_path, args.figures_dir, args.seed,
                                data_generator=data_generator)
    print(f"data size: {x.shape[0]}, x_dim: {dataset.x_dim}, y_dim: {dataset.y_dim} z_dim: {dataset.z_dim}")
    return dataset


def _meps_feature_columns(weight_col: str) -> list:
    """Return the ordered MEPS feature columns; panels differ only in the weight column at index 4."""
    return ['AGE', 'PCS42', 'MCS42', 'K6SUM42', weight_col, 'REGION=1',
            'REGION=2', 'REGION=3', 'REGION=4', 'SEX=1', 'SEX=2', 'MARRY=1',
            'MARRY=2', 'MARRY=3', 'MARRY=4', 'MARRY=5', 'MARRY=6', 'MARRY=7',
            'MARRY=8', 'MARRY=9', 'MARRY=10', 'FTSTU=-1', 'FTSTU=1', 'FTSTU=2',
            'FTSTU=3', 'ACTDTY=1', 'ACTDTY=2', 'ACTDTY=3', 'ACTDTY=4',
            'HONRDC=1', 'HONRDC=2', 'HONRDC=3', 'HONRDC=4', 'RTHLTH=-1',
            'RTHLTH=1', 'RTHLTH=2', 'RTHLTH=3', 'RTHLTH=4', 'RTHLTH=5',
            'MNHLTH=-1', 'MNHLTH=1', 'MNHLTH=2', 'MNHLTH=3', 'MNHLTH=4',
            'MNHLTH=5', 'HIBPDX=-1', 'HIBPDX=1', 'HIBPDX=2', 'CHDDX=-1',
            'CHDDX=1', 'CHDDX=2', 'ANGIDX=-1', 'ANGIDX=1', 'ANGIDX=2',
            'MIDX=-1', 'MIDX=1', 'MIDX=2', 'OHRTDX=-1', 'OHRTDX=1', 'OHRTDX=2',
            'STRKDX=-1', 'STRKDX=1', 'STRKDX=2', 'EMPHDX=-1', 'EMPHDX=1',
            'EMPHDX=2', 'CHBRON=-1', 'CHBRON=1', 'CHBRON=2', 'CHOLDX=-1',
            'CHOLDX=1', 'CHOLDX=2', 'CANCERDX=-1', 'CANCERDX=1', 'CANCERDX=2',
            'DIABDX=-1', 'DIABDX=1', 'DIABDX=2', 'JTPAIN=-1', 'JTPAIN=1',
            'JTPAIN=2', 'ARTHDX=-1', 'ARTHDX=1', 'ARTHDX=2', 'ARTHTYPE=-1',
            'ARTHTYPE=1', 'ARTHTYPE=2', 'ARTHTYPE=3', 'ASTHDX=1', 'ASTHDX=2',
            'ADHDADDX=-1', 'ADHDADDX=1', 'ADHDADDX=2', 'PREGNT=-1', 'PREGNT=1',
            'PREGNT=2', 'WLKLIM=-1', 'WLKLIM=1', 'WLKLIM=2', 'ACTLIM=-1',
            'ACTLIM=1', 'ACTLIM=2', 'SOCLIM=-1', 'SOCLIM=1', 'SOCLIM=2',
            'COGLIM=-1', 'COGLIM=1', 'COGLIM=2', 'DFHEAR42=-1', 'DFHEAR42=1',
            'DFHEAR42=2', 'DFSEE42=-1', 'DFSEE42=1', 'DFSEE42=2',
            'ADSMOK42=-1', 'ADSMOK42=1', 'ADSMOK42=2', 'PHQ242=-1', 'PHQ242=0',
            'PHQ242=1', 'PHQ242=2', 'PHQ242=3', 'PHQ242=4', 'PHQ242=5',
            'PHQ242=6', 'EMPST=-1', 'EMPST=1', 'EMPST=2', 'EMPST=3', 'EMPST=4',
            'POVCAT=1', 'POVCAT=2', 'POVCAT=3', 'POVCAT=4', 'POVCAT=5',
            'INSCOV=1', 'INSCOV=2', 'INSCOV=3', 'RACE']


def _load_meps(csv_path: str, weight_col: str):
    """Read a preprocessed MEPS csv and return ``(X, y)`` numpy arrays; the response is ``UTILIZATION_reg``."""
    df = pd.read_csv(csv_path)
    y = df["UTILIZATION_reg"].values
    X = df[_meps_feature_columns(weight_col)].values
    return X, y


def load_real_regression_dataset(name: str, base_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load a real regression dataset from disk.

    Parameters
    ----------
    name : str
        One of 'meps_19', 'meps_20', 'meps_21', 'facebook_1', 'facebook_2', 'bio', 'house',
        'blog' (alias 'blog_data').
    base_path : str
        Directory containing the dataset files, e.g. "path/to/RCOL_datasets/directory/".

    Returns
    -------
    X : torch.Tensor
        float32 features of shape (n, p).
    y : torch.Tensor
        float32 labels of shape (n,).

    Raises
    ------
    Exception
        If ``name`` is not a known dataset.
    """
    if name == "meps_19":
        X, y = _load_meps(os.path.join(base_path, 'meps_19_reg_fix.csv'), 'PERWT15F')
    elif name == "meps_20":
        # NOTE(review): meps_20/meps_21 are read from the 'facebook/' sub-folder while meps_19 is read from
        # base_path itself; kept unchanged, but confirm this matches the on-disk layout.
        X, y = _load_meps(os.path.join(base_path, 'facebook/meps_20_reg_fix.csv'), 'PERWT15F')
    elif name == "meps_21":
        X, y = _load_meps(os.path.join(base_path, 'facebook/meps_21_reg_fix.csv'), 'PERWT16F')
    elif name == "facebook_1":
        df = pd.read_csv(os.path.join(base_path, 'facebook/Features_Variant_1.csv'))
        # Columns 0..52 are features, column 53 is the target.
        y = df.iloc[:, 53].values
        X = df.iloc[:, 0:53].values
    elif name == "facebook_2":
        df = pd.read_csv(os.path.join(base_path, 'facebook/Features_Variant_2.csv'))
        y = df.iloc[:, 53].values
        X = df.iloc[:, 0:53].values
    elif name == "bio":
        # https://github.com/joefavergel/TertiaryPhysicochemicalProperties/blob/master/RMSD-ProteinTertiaryStructures.ipynb
        df = pd.read_csv(os.path.join(base_path, 'CASP.csv'))
        # First column is the target, the remaining columns are features.
        y = df.iloc[:, 0].values
        X = df.iloc[:, 1:].values
    elif name == 'house':
        df = pd.read_csv(os.path.join(base_path, 'kc_house_data.csv'))
        y = np.array(df['price'])
        X = (df.drop(['id', 'date', 'price'], axis=1)).values
    elif name == 'blog_data' or name == 'blog':
        # https://github.com/xinbinhuang/feature-selection_blogfeedback
        df = pd.read_csv(os.path.join(base_path, 'blogData_train.csv'), header=None)
        X = df.iloc[:, 0:280].values
        y = df.iloc[:, -1].values
    else:
        raise Exception(f"invalid dataset: {name}")

    X = torch.from_numpy(X).float()
    y = torch.from_numpy(y).float()
    return X, y
