{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1868668",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e08c53cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import sklearn.datasets  # import the submodule explicitly; `import sklearn` alone does not load it\n",
    "\n",
    "import imodels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed914d16",
   "metadata": {},
   "source": [
    "# Data Check\n",
    "\n",
    "Here we check the data sets that we found through the references provided by the authors of the HS paper. For each data set, we parse the version we found and compare it with the version the authors provide in their GitHub repo."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fee6335a",
   "metadata": {},
   "source": [
    "# 1. Classification data sets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb4faa62",
   "metadata": {},
   "source": [
    "## Heart\n",
    "Link: https://archive.ics.uci.edu/ml/datasets/Statlog+%28Heart%29\n",
    "\n",
    "Same: __TRUE__\n",
    "\n",
    "The authors of HS normalized the variables and split the variable \"thal\" into three separate columns. This was verified by comparing the distribution of the \"thal\" variable with the distributions of the variables \"att_13_-1.0\", \"att_13_0.5\", and \"att_13_1.0\"."
   ]
  },
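  {
   "cell_type": "markdown",
   "id": "d4c0a001",
   "metadata": {},
   "source": [
    "One way to check a one-hot split like the one described above is to compare category frequencies with the column sums of the dummy columns. The following is a minimal sketch on toy values, not the actual data; the values and the `thal_` prefix are hypothetical.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Toy nominal column (hypothetical values standing in for \"thal\").\n",
    "thal = pd.Series([3.0, 7.0, 6.0, 3.0, 7.0, 3.0], name=\"thal\")\n",
    "\n",
    "# One-hot encode; each resulting column marks one category.\n",
    "dummies = pd.get_dummies(thal, prefix=\"thal\").astype(int)\n",
    "\n",
    "# The column sums of the dummies should equal the category frequencies.\n",
    "category_counts = sorted(thal.value_counts().tolist())\n",
    "dummy_counts = sorted(dummies.sum().tolist())\n",
    "print(category_counts == dummy_counts)\n",
    "```\n",
    "\n",
    "Matching counts are a necessary condition only; a row-wise comparison would be needed to confirm that the same rows carry the same category."
   ]
  },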
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcdf531c",
   "metadata": {},
   "outputs": [],
   "source": [
    "HEART_COLS_TRANSLATE = {\n",
    "    0: \"age\", 1: \"sex\", 2: \"chest_pain_type\", 3: \"resting_blood_pressure\",  4: \"serum_cholesterol\",\n",
    "    5: \"fasting_blood_sugar\", 6: \"resting_electrocardiographic_results\", 7: \"maximum_heart_rate_achieved\",\n",
    "    8: \"exercise_induced_angina\", 9: \"oldpeak\", 10: \"slope_of_the_peak\", 11: \"number_of_major_vessels\",\n",
    "    12: \"thal\", 13: \"heart_disease\"\n",
    "}\n",
    "HEART_COLS_REAL = [\"age\", \"resting_blood_pressure\", \"serum_cholesterol\", \"maximum_heart_rate_achieved\",\n",
    "            \"oldpeak\", \"number_of_major_vessels\"]\n",
    "HEART_COLS_BINARY = [\"sex\", \"fasting_blood_sugar\", \"exercise_induced_angina\", \"heart_disease\"]\n",
    "HEART_COLS_ORDERED = [\"slope_of_the_peak\"]\n",
    "HEART_COLS_NOMINAL = [\"chest_pain_type\", \"resting_electrocardiographic_results\", \"thal\"]\n",
    "\n",
    "\n",
    "# Read the space-separated file in one pass; values stay strings until they are cast below.\n",
    "with open(os.path.abspath(\"../../data/classification/heart.dat\"), \"r\") as f:\n",
    "    rows = [line.strip().split(\" \") for line in f if line.strip()]\n",
    "heart = pd.DataFrame(rows)\n",
    "\n",
    "heart = heart.rename(HEART_COLS_TRANSLATE, axis=1)\n",
    "for col in HEART_COLS_REAL:\n",
    "    heart[col] = heart[col].astype(float)\n",
    "for col in HEART_COLS_BINARY+HEART_COLS_NOMINAL:\n",
    "    heart[col] = heart[col].astype(float)#.astype(int)\n",
    "for col in HEART_COLS_ORDERED:\n",
    "    heart[col] = heart[col].astype(float)#.astype(int).astype(str)\n",
    "heart.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8120225",
   "metadata": {},
   "outputs": [],
   "source": [
    "for col in heart.columns:\n",
    "    _min = heart[col].min()\n",
    "    _max = heart[col].max()\n",
    "    heart[col] = 2 * (heart[col]-_min) / (_max-_min) -1\n",
    "\n",
    "for col in HEART_COLS_BINARY+HEART_COLS_ORDERED+HEART_COLS_NOMINAL:\n",
    "    heart[col] = (heart[col] + 1) / 2\n",
    "\n",
    "heart.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f7230fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"heart\", \"imodels\")\n",
    "heart_hs = pd.DataFrame(X)\n",
    "heart_hs = heart_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "heart_hs[\"heart_disease\"] = y\n",
    "heart_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38486590",
   "metadata": {},
   "outputs": [],
   "source": [
    "heart_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4dc2038",
   "metadata": {},
   "source": [
    "## Breast cancer\n",
    "Link: https://www.openml.org/search?type=data&status=active&sort=runs&id=13\n",
    "\n",
    "Same: __TRUE__\n",
    "\n",
    "At first glance the data sets seem different, but the authors of HS convert each feature to numeric values and leave out values that are listed in the attribute definitions but do not occur in the data set. The authors also dropped all rows with unknown values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f71ac291",
   "metadata": {},
   "outputs": [],
   "source": [
    "BREAST_AGE_TRANSLATE = {\n",
    "    \"10-19\": 0, \"20-29\": 1, \"30-39\": 2, \"40-49\": 3, \"50-59\": 4, \"60-69\": 5,\n",
    "    \"70-79\": 6, \"80-89\": 7, \"90-99\": 8}\n",
    "BREAST_TUMOR_TRANSLATE = {\n",
    "    \"0-4\": 0, \"5-9\": 1, \"10-14\":2 , \"15-19\": 3, \"20-24\": 4, \"25-29\": 5, \"30-34\": 6,\n",
    "    \"35-39\": 7, \"40-44\": 8, \"45-49\": 9, \"50-54\": 10, \"55-59\": 11}\n",
    "BREAST_INV_TRANSLATE = {\n",
    "    \"0-2\": 0, \"3-5\": 1, \"6-8\": 2, \"9-11\": 3, \"12-14\": 4, \"15-17\": 5, \"18-20\": 6,\n",
    "    \"21-23\": 7, \"24-26\": 8, \"27-29\": 9, \"30-32\": 10, \"33-35\": 11, \"36-39\": 12}\n",
    "\n",
    "\n",
    "header, rows = [], []\n",
    "with open(os.path.abspath(\"../../data/classification/dataset_13_breast-cancer.arff\"), \"r\") as f:\n",
    "    for line in f:\n",
    "        if line.startswith(\"@attribute\"):\n",
    "            header.append(line.split(\"'\")[1])\n",
    "        if line.startswith(\"%\") or line.startswith(\"@\"):\n",
    "            continue  # skip ARFF comments and declarations\n",
    "        rows.append(line.strip().replace(\"'\", \"\").split(\",\"))\n",
    "# Build the frame once instead of concatenating row by row.\n",
    "breast_cancer = pd.DataFrame(rows)\n",
    "\n",
    "breast_cancer = breast_cancer.rename({i: v for i, v in enumerate(header)}, axis=1)\n",
    "breast_cancer = breast_cancer.replace(\"?\", np.nan).dropna()\n",
    "breast_cancer[\"age\"] = breast_cancer[\"age\"].apply(lambda x: BREAST_AGE_TRANSLATE[x])\n",
    "breast_cancer[\"tumor-size\"] = breast_cancer[\"tumor-size\"].apply(lambda x: BREAST_TUMOR_TRANSLATE[x])\n",
    "breast_cancer[\"inv-nodes\"] = breast_cancer[\"inv-nodes\"].apply(lambda x: BREAST_INV_TRANSLATE[x])\n",
    "breast_cancer[\"node-caps\"] = (breast_cancer[\"node-caps\"] == \"yes\")*1\n",
    "breast_cancer[\"breast\"] = (breast_cancer[\"breast\"] == \"right\")*1\n",
    "breast_cancer[\"irradiat\"] = (breast_cancer[\"irradiat\"] == \"yes\")*1\n",
    "breast_cancer[\"Class\"] = (breast_cancer[\"Class\"] == \"recurrence-events\")*1\n",
    "breast_cancer = pd.get_dummies(breast_cancer, columns=[\"deg-malig\", \"menopause\", \"breast-quad\"])\n",
    "breast_cancer.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa5de4b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "breast_cancer.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f17bf323",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"breast_cancer\", \"imodels\")\n",
    "breast_cancer_hs = pd.DataFrame(X)\n",
    "breast_cancer_hs = breast_cancer_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "breast_cancer_hs[\"cancer\"] = y\n",
    "breast_cancer_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f98df41",
   "metadata": {},
   "outputs": [],
   "source": [
    "breast_cancer_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bf02cc5",
   "metadata": {},
   "source": [
    "## Haberman\n",
    "Link: https://archive.ics.uci.edu/ml/datasets/Haberman%27s+Survival\n",
    "\n",
    "Same: __TRUE__\n",
    "\n",
    "Exact match for the paper. The data set that the authors provide is the same, except that they subtract 58 from the column \"Patients_year_of_operation\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28c9a0d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "HABERMAN_COLS_TRANSLATE = {\n",
    "    0: \"age\", 1: \"year_of_operation\", 2: \"positive_axillary_nodes_detected\", 3: \"survival\"\n",
    "}\n",
    "\n",
    "# Read the comma-separated file in one pass instead of concatenating row by row.\n",
    "with open(os.path.abspath(\"../../data/classification/haberman.data\"), \"r\") as f:\n",
    "    rows = [line.strip().split(\",\") for line in f if line.strip()]\n",
    "haberman = pd.DataFrame(rows)\n",
    "\n",
    "haberman = haberman.rename(HABERMAN_COLS_TRANSLATE, axis=1)\n",
    "haberman = haberman.astype(int)\n",
    "haberman[\"year_of_operation\"] = haberman[\"year_of_operation\"]-58\n",
    "haberman[\"survival\"] = -1*haberman[\"survival\"]+2\n",
    "haberman.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7a17b30",
   "metadata": {},
   "outputs": [],
   "source": [
    "haberman.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8da97c32",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"haberman\", \"imodels\")\n",
    "haberman_hs = pd.DataFrame(X)\n",
    "haberman_hs = haberman_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "haberman_hs[\"survival\"] = y\n",
    "haberman_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cba99b5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "haberman_hs.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fb34174",
   "metadata": {},
   "source": [
    "## Ionosphere\n",
    "Link: https://archive.ics.uci.edu/ml/datasets/Ionosphere\n",
    "\n",
    "Same: __TRUE__\n",
    "\n",
    "Exact match for the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90df06c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "IONOSPHERE_COLS_TRANSLATE = {\n",
    "    i: f\"attr_{i}\" for i in range(34)\n",
    "}\n",
    "IONOSPHERE_COLS_TRANSLATE[34] = \"ionosphere\"\n",
    "\n",
    "# Read the comma-separated file in one pass instead of concatenating row by row.\n",
    "with open(os.path.abspath(\"../../data/classification/ionosphere.data\"), \"r\") as f:\n",
    "    rows = [line.strip().split(\",\") for line in f if line.strip()]\n",
    "ionosphere = pd.DataFrame(rows)\n",
    "\n",
    "ionosphere = ionosphere.rename(IONOSPHERE_COLS_TRANSLATE, axis=1)\n",
    "ionosphere[\"ionosphere\"] = ionosphere[\"ionosphere\"] == \"g\"\n",
    "ionosphere.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9422be3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "ionosphere.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04329e48",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"ionosphere\", \"pmlb\")\n",
    "ionosphere_hs = pd.DataFrame(X)\n",
    "ionosphere_hs = ionosphere_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "ionosphere_hs[\"ionosphere\"] = y\n",
    "ionosphere_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ef6d993",
   "metadata": {},
   "outputs": [],
   "source": [
    "ionosphere_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38f590a5",
   "metadata": {},
   "source": [
    "## Diabetes\n",
    "Link: https://www.kaggle.com/datasets/mathchi/diabetes-data-set\n",
    "\n",
    "Same: __TRUE__\n",
    "\n",
    "Exact match for the paper. The rows are shuffled, but the distribution of every column is the same in both data sets.\n",
    "\n",
    "Note: the data set was found on Kaggle, not on UCI as stated by the authors of the Random Forest paper."
   ]
  },
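  {
   "cell_type": "markdown",
   "id": "d4c0a002",
   "metadata": {},
   "source": [
    "The row-order-independent comparison described above can be sketched as follows. This is a minimal illustration rather than the procedure actually used here: the helper name `same_column_distributions` and the toy frames are hypothetical, and the sketch assumes that both frames use the same column names and hold numeric values.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def same_column_distributions(a, b):\n",
    "    # True if both frames have the same columns and each column holds the same values.\n",
    "    if set(a.columns) != set(b.columns) or len(a) != len(b):\n",
    "        return False\n",
    "    # Sorting each column removes the effect of row order.\n",
    "    return all(\n",
    "        np.allclose(np.sort(a[c].to_numpy(dtype=float)), np.sort(b[c].to_numpy(dtype=float)))\n",
    "        for c in a.columns\n",
    "    )\n",
    "\n",
    "\n",
    "# Toy data (hypothetical): the second frame is a shuffled copy of the first.\n",
    "original = pd.DataFrame({\"x\": [1.0, 2.0, 3.0], \"y\": [0.0, 1.0, 0.0]})\n",
    "shuffled = original.sample(frac=1.0, random_state=0).reset_index(drop=True)\n",
    "print(same_column_distributions(original, shuffled))\n",
    "```\n",
    "\n",
    "Equal per-column values are a necessary condition for two data sets to be the same, but they do not prove it, because the check ignores how values are paired across columns."
   ]
  },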
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00b62781",
   "metadata": {},
   "outputs": [],
   "source": [
    "diabetes = pd.read_csv(os.path.abspath(\"../../data/classification/diabetes.csv\"))\n",
    "diabetes.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c2d99f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "diabetes.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c53317e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"diabetes\", \"pmlb\")\n",
    "diabetes_hs = pd.DataFrame(X)\n",
    "diabetes_hs = diabetes_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "diabetes_hs[\"diabetes\"] = y\n",
    "diabetes_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fddd4749",
   "metadata": {},
   "outputs": [],
   "source": [
    "diabetes_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87b6d1d2",
   "metadata": {},
   "source": [
    "## German credit\n",
    "Link: https://archive.ics.uci.edu/ml/datasets/South+German+Credit+%28UPDATE%29\n",
    "\n",
    "Same: __FALSE__\n",
    "\n",
    "The data set has the same number of features and instances, and the features coincide with the German names. However, the distribution of most variables is different. The following variables have the same distribution:\n",
    "- duration\n",
    "- amount/credit\n",
    "- installment rate\n",
    "- age\n",
    "- number credits/existing credits\n",
    "\n",
    "The target variable is the same. Most of the other variables seem to have a different range (columns in the found data set range from 1 to 5, while those in the data set from the authors of HS range from 0 to 4). Even after we shift the data back to the same range, the distributions still differ."
   ]
  },
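  {
   "cell_type": "markdown",
   "id": "d4c0a003",
   "metadata": {},
   "source": [
    "The range comparison described above can be sketched as a check on category frequencies after shifting the codes. This is a minimal sketch under stated assumptions: the helper name `shifted_counts_match` is hypothetical and the toy values are not taken from either data set.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def shifted_counts_match(found, reference, offset=1):\n",
    "    # Shift the found codes (e.g. 1..5 -> 0..4) and compare category frequencies.\n",
    "    found_counts = (found - offset).value_counts().sort_index()\n",
    "    reference_counts = reference.value_counts().sort_index()\n",
    "    return found_counts.to_dict() == reference_counts.to_dict()\n",
    "\n",
    "\n",
    "# Toy columns (hypothetical values).\n",
    "found = pd.Series([1, 2, 2, 3, 5])\n",
    "reference_same = pd.Series([4, 1, 0, 1, 2])\n",
    "reference_diff = pd.Series([0, 0, 1, 2, 4])\n",
    "print(shifted_counts_match(found, reference_same))\n",
    "print(shifted_counts_match(found, reference_diff))\n",
    "```\n",
    "\n",
    "A column for which this check fails after the shift differs in more than its coding offset, which is the situation reported for most variables here."
   ]
  },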
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8987194",
   "metadata": {},
   "outputs": [],
   "source": [
    "GERMAN_COLS_TRANSLATE = [\n",
    "    \"status\", \"duration\", \"credit_history\", \"purpose\", \"amount\", \"savings\",\n",
    "    \"employment_duration\", \"installment_rate\", \"personal_status_sex\",\n",
    "    \"other_debtors\", \"present_residence\", \"property\", \"age\",\n",
    "    \"other_installment_plans\", \"housing\", \"number_credits\", \"job\", \"people_liable\",\n",
    "    \"telephone\", \"foreign_worker\", \"credit_risk\"]\n",
    "GERMAN_COLS_TRANSLATE = {i: v for i, v in enumerate(GERMAN_COLS_TRANSLATE)}\n",
    "GERMAN_COLS_MINUS = [\n",
    "    \"status\", \"savings\", \"employment_duration\", \"personal_status_sex\", \"other_debtors\",\n",
    "    \"property\", \"other_installment_plans\", \"housing\", \"job\", \"telephone\", \"foreign_worker\"\n",
    "]\n",
    "\n",
    "with open(os.path.abspath(\"../../data/classification/SouthGermanCredit.asc\"), \"r\") as f:\n",
    "    f.readline()  # skip the header row, which is in German\n",
    "    rows = [line.strip().split(\" \") for line in f if line.strip()]\n",
    "german_credit = pd.DataFrame(rows)\n",
    "german_credit = german_credit.rename(GERMAN_COLS_TRANSLATE, axis=1)\n",
    "german_credit = german_credit.astype(int)\n",
    "for col in GERMAN_COLS_MINUS:\n",
    "    german_credit[col] = german_credit[col] - 1\n",
    "german_credit.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "520cb873",
   "metadata": {},
   "outputs": [],
   "source": [
    "german_credit.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44ac447b",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"german\", \"pmlb\")\n",
    "german_credit_hs = pd.DataFrame(X)\n",
    "german_credit_hs = german_credit_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "german_credit_hs[\"credit\"] = y\n",
    "german_credit_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2a1a4eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "german_credit_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6556ca04",
   "metadata": {},
   "source": [
    "## Juvenile\n",
    "Link: https://www.icpsr.umich.edu/web/NACJD/studies/3986\n",
    "\n",
    "Same: __TRUE?__\n",
    "\n",
    "The data set at the link above has too many columns for us to manually bring it into the format of the data set used by the authors. However, the authors' code at `https://github.com/csinva/imodels-data/blob/master/notebooks_fetch_data/00_get_datasets_custom.ipynb` shows that they use the same data set but clean it heavily."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f94fe0fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"juvenile_clean\", \"imodels\")\n",
    "juvenile_hs = pd.DataFrame(X)\n",
    "juvenile_hs = juvenile_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "juvenile_hs[\"target\"] = y\n",
    "juvenile_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bbe253b",
   "metadata": {},
   "outputs": [],
   "source": [
    "juvenile_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7414cbd",
   "metadata": {},
   "source": [
    "## Recidivism\n",
    "Link: https://www.propublica.org/datastore/dataset/compas-recidivism-risk-score-data-and-analysis\n",
    "\n",
    "Same: __TRUE__\n",
    "\n",
    "The authors one-hot encoded some of the columns (race, age, sex, c_charge_degree)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd1e6692",
   "metadata": {},
   "outputs": [],
   "source": [
    "RECIDIVISM_COLS_KEEP = [\n",
    "    \"age\", \"priors_count\", \"days_b_screening_arrest\", \"c_jail_time\",\n",
    "    \"juv_fel_count\", \"juv_other_count\", \"juv_misd_count\", \"c_charge_degree\", \"race\", \"sex\"\n",
    "]\n",
    "\n",
    "recidivism = pd.read_csv(os.path.abspath(\"../../data/classification/compas-scores-two-years.csv\"))\n",
    "recidivism[\"c_jail_time\"] = (pd.to_datetime(recidivism[\"c_jail_out\"])-pd.to_datetime(recidivism[\"c_jail_in\"])) // np.timedelta64(1, \"D\")\n",
    "recidivism = recidivism[RECIDIVISM_COLS_KEEP]\n",
    "recidivism.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee0f7024",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"compas_two_year_clean\", \"imodels\")\n",
    "recidivism_hs = pd.DataFrame(X)\n",
    "recidivism_hs = recidivism_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "recidivism_hs[\"target\"] = y\n",
    "recidivism_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19ff09ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "recidivism_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81c8f798",
   "metadata": {},
   "source": [
    "# 2. Regression data sets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "841f99bf",
   "metadata": {},
   "source": [
    "## Friedman 1 & Friedman 3\n",
    "Link: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_friedman1.html\n",
    "\n",
    "Link: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_friedman3.html\n",
    "\n",
    "Same: __TRUE__\n",
    "\n",
    "Both data sets are synthetic, and the authors also use the same Scikit-learn functions to generate them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0930f9dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = sklearn.datasets.make_friedman1(200, 10)\n",
    "friedman1 = pd.DataFrame(X)\n",
    "friedman1[\"target\"] = y\n",
    "friedman1.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a789c2c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "friedman1.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb41d2d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"friedman1\", \"synthetic\")\n",
    "friedman1_hs = pd.DataFrame(X)\n",
    "friedman1_hs = friedman1_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "friedman1_hs[\"target\"] = y\n",
    "friedman1_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4f51555",
   "metadata": {},
   "outputs": [],
   "source": [
    "friedman1_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b5b0ad",
   "metadata": {},
   "source": [
    "## Diabetes\n",
    "\n",
    "Link: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_diabetes.html\n",
    "\n",
    "Same: __TRUE__\n",
    "\n",
    "Also part of the Scikit-learn package, and the authors provide the same data in their imodels package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe186b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "diabetes_data = sklearn.datasets.load_diabetes(as_frame=True)\n",
    "diabetes = diabetes_data[\"data\"]\n",
    "diabetes[\"diabetes\"] = diabetes_data[\"target\"]\n",
    "diabetes.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e48d02fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "diabetes.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e926f006",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"diabetes\", \"sklearn\")\n",
    "diabetes_hs = pd.DataFrame(X)\n",
    "diabetes_hs = diabetes_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "diabetes_hs[\"diabetes\"] = y\n",
    "diabetes_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7de25425",
   "metadata": {},
   "outputs": [],
   "source": [
    "diabetes_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f581d3e3",
   "metadata": {},
   "source": [
    "## Geographical music\n",
    "\n",
    "Link: https://epistasislab.github.io/pmlb/profile/4544_GeographicalOriginalofMusic.html\n",
    "\n",
    "Link: https://github.com/EpistasisLab/pmlb/tree/master/datasets/4544_GeographicalOriginalofMusic\n",
    "\n",
    "Same: __TRUE__\n",
    "\n",
    "The authors do not provide this data set in their code, but the number of features and samples match.\n",
    "\n",
    "Note: We found the data set online on PMLB and were able to read it with `imodels.util.data_util.get_clean_dataset`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e74700d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "geographical_music = pd.read_csv(os.path.abspath(\"../../data/regression/geographical_music.tsv\"), sep=\"\\t\")\n",
    "geographical_music.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88401317",
   "metadata": {},
   "outputs": [],
   "source": [
    "geographical_music.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee644d32",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"4544_GeographicalOriginalofMusic\", \"pmlb\")\n",
    "geographical_music_hs = pd.DataFrame(X)\n",
    "geographical_music_hs = geographical_music_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "geographical_music_hs[\"target\"] = y\n",
    "geographical_music_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "894e31b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "geographical_music_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b82c984a",
   "metadata": {},
   "source": [
    "## Red wine\n",
    "\n",
    "Link: https://archive.ics.uci.edu/ml/datasets/Wine+Quality\n",
    "\n",
    "Same: __TRUE__\n",
    "\n",
    "The authors do not provide this data set in their code, but the number of features and samples match.\n",
    "\n",
    "Note: We found the data set online on PMLB and were able to read it with `imodels.util.data_util.get_clean_dataset`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dda7ad15",
   "metadata": {},
   "outputs": [],
   "source": [
    "red_wine = pd.read_csv(os.path.abspath(\"../../data/regression/winequality-red.csv\"), sep=\";\")\n",
    "red_wine.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5092c0a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "red_wine.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f10765b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"wine_quality_red\", \"pmlb\")\n",
    "red_wine_hs = pd.DataFrame(X)\n",
    "red_wine_hs = red_wine_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "red_wine_hs[\"quality\"] = y\n",
    "red_wine_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2c70d42",
   "metadata": {},
   "outputs": [],
   "source": [
    "red_wine_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33351e36",
   "metadata": {},
   "source": [
    "## Abalone\n",
    "\n",
    "Link: https://archive.ics.uci.edu/ml/datasets/Abalone\n",
    "\n",
    "Same: __TRUE__\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3e35c12",
   "metadata": {},
   "outputs": [],
   "source": [
    "ABALONE_COLS_TRANSLATE = {\n",
    "    0: \"Sex\", 1: \"Length\", 2: \"Diameter\", 3: \"Height\", 4: \"Whole_weight\",\n",
    "    5: \"Shucked_weight\", 6: \"Viscera_weight\", 7: \"Shell_weight\", 8: \"Rings\"\n",
    "}\n",
    "ABALONE_SEX_TRANSLATE = {\n",
    "    \"M\": 2, \"F\": 0, \"I\": 1\n",
    "}\n",
    "\n",
    "# Read the comma-separated file in one pass instead of concatenating row by row.\n",
    "with open(os.path.abspath(\"../../data/regression/abalone.data\"), \"r\") as f:\n",
    "    rows = [line.strip().split(\",\") for line in f if line.strip()]\n",
    "abalone = pd.DataFrame(rows)\n",
    "abalone = abalone.rename(ABALONE_COLS_TRANSLATE, axis=1)\n",
    "abalone[\"Sex\"] = abalone[\"Sex\"].apply(lambda x: ABALONE_SEX_TRANSLATE[x])\n",
    "abalone = abalone.astype(float)\n",
    "abalone.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48d83ccc",
   "metadata": {},
   "outputs": [],
   "source": [
    "abalone.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f76f7ec7",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"183\", \"openml\")\n",
    "abalone_hs = pd.DataFrame(X)\n",
    "abalone_hs = abalone_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "abalone_hs[\"Rings\"] = y\n",
    "abalone_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b80fe71c",
   "metadata": {},
   "outputs": [],
   "source": [
    "abalone_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a0aa316",
   "metadata": {},
   "source": [
    "## Satellite image\n",
    "\n",
    "Link: https://epistasislab.github.io/pmlb/profile/294_satellite_image.html\n",
    "\n",
    "Link: https://github.com/EpistasisLab/pmlb/blob/master/datasets/294_satellite_image\n",
    "\n",
    "Same: __TRUE__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c109bf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "satellite = pd.read_csv(os.path.abspath(\"../../data/regression/satellite_image.tsv\"), sep=\"\\t\")\n",
    "satellite.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05fedb4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "satellite.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6e786c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"294_satellite_image\", \"pmlb\")\n",
    "satellite_hs = pd.DataFrame(X)\n",
    "satellite_hs = satellite_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "satellite_hs[\"target\"] = y\n",
    "satellite_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a0c2d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "satellite_hs.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "942f5919",
   "metadata": {},
   "source": [
    "## CA housing\n",
    "\n",
    "Link: https://www.kaggle.com/datasets/camnugent/california-housing-prices\n",
    "\n",
    "Same: __TRUE__\n",
    "\n",
    "The columns are swapped, but the data sets are the same."
   ]
  },
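  {
   "cell_type": "markdown",
   "id": "d4c0a004",
   "metadata": {},
   "source": [
    "One way to confirm that two frames hold the same data despite a different column order is to match columns by their sorted values. The following is a minimal sketch; the helper name `match_columns` and the toy frames are hypothetical, and the greedy matching assumes that no two columns contain identical values.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def match_columns(a, b):\n",
    "    # Map each column of `a` to a column of `b` holding the same sorted values.\n",
    "    mapping = {}\n",
    "    for col_a in a.columns:\n",
    "        sorted_a = np.sort(a[col_a].to_numpy(dtype=float))\n",
    "        for col_b in b.columns:\n",
    "            if col_b in mapping.values():\n",
    "                continue  # each column of `b` is used at most once\n",
    "            sorted_b = np.sort(b[col_b].to_numpy(dtype=float))\n",
    "            if len(sorted_a) == len(sorted_b) and np.allclose(sorted_a, sorted_b):\n",
    "                mapping[col_a] = col_b\n",
    "                break\n",
    "    return mapping\n",
    "\n",
    "\n",
    "# Toy frames (hypothetical): same data, different column order and names.\n",
    "a = pd.DataFrame({\"longitude\": [-122.2, -118.3], \"latitude\": [37.9, 34.1]})\n",
    "b = pd.DataFrame({\"Latitude\": [34.1, 37.9], \"Longitude\": [-118.3, -122.2]})\n",
    "print(match_columns(a, b))\n",
    "```\n",
    "\n",
    "If every column of the first frame finds a partner, the two frames agree column by column up to row order; unmatched columns point to a real difference."
   ]
  },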
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6c4d678",
   "metadata": {},
   "outputs": [],
   "source": [
    "HOUSING_COLS_TRANSLATE = {\n",
    "    0: \"longitude\",\n",
    "    1: \"latitude\",\n",
    "    2: \"housingMedianAge\",\n",
    "    3: \"totalRooms\",\n",
    "    4: \"totalBedrooms\",\n",
    "    5: \"population\",\n",
    "    6: \"households\",\n",
    "    7: \"medianIncome\",\n",
    "    8: \"medianHouseValue\"\n",
    "}\n",
    "\n",
    "# Read the comma-separated file in one pass instead of concatenating row by row.\n",
    "with open(os.path.abspath(\"../../data/regression/ca_housing.data\"), \"r\") as f:\n",
    "    rows = [line.strip().split(\",\") for line in f if line.strip()]\n",
    "ca_housing = pd.DataFrame(rows)\n",
    "ca_housing = ca_housing.astype(float)\n",
    "ca_housing = ca_housing.rename(HOUSING_COLS_TRANSLATE, axis=1)\n",
    "ca_housing.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19d074db",
   "metadata": {},
   "outputs": [],
   "source": [
    "ca_housing.describe(include=\"all\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "410681b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, cols = imodels.util.data_util.get_clean_dataset(\"california_housing\", \"sklearn\")\n",
    "ca_housing_hs = pd.DataFrame(X)\n",
    "ca_housing_hs = ca_housing_hs.rename({i: v for i, v in enumerate(cols)}, axis=1)\n",
    "ca_housing_hs[\"medianHouseValue\"] = y\n",
    "ca_housing_hs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78d58964",
   "metadata": {},
   "outputs": [],
   "source": [
    "ca_housing_hs.describe(include=\"all\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlds",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:16:33) [MSC v.1929 64 bit (AMD64)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "aaea2abca523dde69918745ab0433be939f84f870f3a8b0d9770b38003b437c2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
