{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "sys.path.append('../..')\n",
    "sys.path.append('.')\n",
    "sys.path.append('./scripts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['axes.labelsize'] = 30\n",
    "plt.rcParams['xtick.labelsize'] = 10\n",
    "plt.rcParams['ytick.labelsize'] = 12\n",
    "plt.rcParams['axes.spines.right'] = False\n",
    "plt.rcParams['axes.spines.top'] = False\n",
    "plt.rcParams['axes.edgecolor'] = 'black'\n",
    "plt.rcParams['axes.linewidth'] = 2.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = \"rf\" # \"rf\" \"gradient_boosting\"\n",
    "\n",
    "methods = [\n",
    "    'lmdi+',\n",
    "    'LIME',\n",
    "    'Treeshap',\n",
    "]\n",
    "\n",
    "\n",
    "color_map = {\n",
    "    'LIME': '#71BEB7',\n",
    "    'Treeshap': 'orange',\n",
    "    'lmdi+': 'black',\n",
    "}\n",
    "methods_name = {\n",
    "    'LIME': 'LIME',\n",
    "    'Treeshap': 'TreeSHAP',\n",
    "    'lmdi+': 'LMDI+',\n",
    "}\n",
    "\n",
    "if model == \"rf\":\n",
    "    methods.extend(['Local MDI'])\n",
    "    color_map['Local MDI'] = '#9B5DFF'\n",
    "    methods_name['Local MDI'] = 'Local MDI'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = \"classification\" \n",
    "combined_df = pd.DataFrame()\n",
    "datasets = ['openml_361063', 'openml_361069', 'openml_361062', 'openml_9978', 'openml_361071', 'openml_43']\n",
    "split_seeds = [1,2,3,4]\n",
    "sample_seeds = [1,2,3,4,5]\n",
    "for data in datasets:\n",
    "    ablation_directory =f\"./results/mdi_local_{model}.real_data_{task}_{data}/{data}_mean/varying_sample_row_n\"\n",
    "    for split_seed in split_seeds:\n",
    "        for sample_seed in sample_seeds:\n",
    "            try:\n",
    "                df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{split_seed}_{sample_seed}/results.csv\"))\n",
    "                df[\"data\"] = data\n",
    "                combined_df = pd.concat([combined_df, df], ignore_index=True)\n",
    "            except:\n",
    "                print(f\"File not found: {os.path.join(ablation_directory, f'seed_{split_seed}_{sample_seed}/results.csv')}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_name = {\n",
    "    \"openml_43\": \"Spam\",\n",
    "    \"openml_361062\": \"Pol\",\n",
    "    \"openml_361071\": \"Jannis\",\n",
    "    \"openml_9978\": \"Ozone\",\n",
    "    \"openml_361069\": \"Higgs\",\n",
    "    \"openml_361063\": \"House 16H\"\n",
    "}\n",
    "\n",
    "feature_values = {\n",
    "    \"openml_43\": 57,\n",
    "    \"openml_361062\": 26,\n",
    "    \"openml_361071\": 54,\n",
    "    \"openml_9978\": 47,\n",
    "    \"openml_361069\": 24,\n",
    "    \"openml_361063\": 16\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = combined_df\n",
    "datasets = df[\"data\"].unique()\n",
    "\n",
    "marker_size = 7\n",
    "\n",
    "n_cols = 3\n",
    "n_rows = 2\n",
    "\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(8 * n_cols, 5.5 * n_rows),\n",
    "    sharey=False\n",
    ")\n",
    "\n",
    "axs = axs.flatten()\n",
    "\n",
    "for idx, dataset in enumerate(datasets):\n",
    "    ax = axs[idx]\n",
    "    subset = df[df[\"data\"] == dataset]\n",
    "\n",
    "    for method in methods:\n",
    "        method_data = subset[subset[\"fi\"] == method]\n",
    "\n",
    "        auroc_cols = [\n",
    "            \"AUROC_keep_0.1\", \"AUROC_keep_0.2\", \"AUROC_keep_0.3\",\n",
    "            \"AUROC_keep_0.4\", \"AUROC_keep_0.5\", \"AUROC_keep_0.6\",\n",
    "            \"AUROC_keep_0.7\", \"AUROC_keep_0.8\", \"AUROC_keep_0.9\",\n",
    "            \"AUROC_keep_1.0\"\n",
    "        ]\n",
    "        means = method_data[auroc_cols].mean(axis=0).values\n",
    "        stds = method_data[auroc_cols].std(axis=0).values\n",
    "        counts = method_data[auroc_cols].count(axis=0).values\n",
    "        sems = stds / np.sqrt(counts)\n",
    "\n",
    "        x = [0.1, 0.2, 0.3, 0.4,0.5,0.6,0.7,0.8,0.9,1.0]\n",
    "\n",
    "        if method in ['LIME', 'Treeshap', 'lmdi', 'Local MDI', 'Maple']:\n",
    "            ax.errorbar(\n",
    "                x, means, sems,\n",
    "                label=methods_name[method], linestyle='solid',\n",
    "                marker='o', markersize=marker_size, color=color_map[method], linewidth=4, alpha=0.6\n",
    "            )\n",
    "        else:\n",
    "            ax.errorbar(\n",
    "                x, means, sems,\n",
    "                label=methods_name[method], linestyle='solid',\n",
    "                marker='o', markersize=marker_size, color=color_map[method], linewidth=4\n",
    "            )\n",
    "\n",
    "    ax.set_xticks([0.1, 0.3, 0.5, 0.7, 0.9])\n",
    "    ax.set_xticklabels([\"10%\", \"30%\", \"50%\", \"70%\", \"90%\"], fontsize=22)\n",
    "    ax.tick_params(axis='y', labelsize=22)\n",
    "    if idx >= (n_rows - 1) * n_cols:\n",
    "        ax.set_xlabel(\"Percentage of Top Features Retained\", fontsize=25)\n",
    "\n",
    "    if idx % n_cols == 0:\n",
    "        ax.set_ylabel(\"AUROC\", fontsize=24)\n",
    "    else:\n",
    "        ax.set_ylabel(\"\")\n",
    "    \n",
    "    dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "    p_val = feature_values[dataset]\n",
    "    ax.set_title(f\"$\\\\mathbf{{{dataset_label} \\ (p={p_val}) }}$\", fontsize=25)\n",
    "    \n",
    "    if (idx + 1) % n_cols == 0 or idx == len(datasets) - 1:\n",
    "        ax.legend(fontsize=18, loc='lower right')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"feature_selection_classification_full_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = \"regression\"\n",
    "combined_df = pd.DataFrame()\n",
    "datasets = ['openml_361260', 'openml_361254', 'openml_361259', 'openml_361253', 'openml_361243', 'openml_361242']\n",
    "split_seeds = [1,2,3,4]\n",
    "sample_seeds = [1,2,3,4,5]\n",
    "for data in datasets:\n",
    "    ablation_directory =f\"./results/mdi_local_{model}.real_data_{task}_{data}/{data}_mean/varying_sample_row_n\"\n",
    "    for split_seed in split_seeds:\n",
    "        for sample_seed in sample_seeds:\n",
    "            df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{split_seed}_{sample_seed}/results.csv\"))\n",
    "            df[\"data\"] = data\n",
    "            combined_df = pd.concat([combined_df, df], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_name = {\n",
    "    \"openml_361260\": \"Miami Housing\",\n",
    "    \"openml_361259\": \"Puma Robot\",\n",
    "    \"openml_361253\": \"Wave Energy\",\n",
    "    \"openml_361254\": \"SARCOS\",\n",
    "    \"openml_361242\": \"Super Conductivity\",\n",
    "    \"openml_361243\": \"Geographic Origin of Music\"\n",
    "}\n",
    "\n",
    "feature_values = {\n",
    "    \"openml_361260\": 15,\n",
    "    \"openml_361259\": 32,\n",
    "    \"openml_361253\": 48,\n",
    "    \"openml_361254\": 21,\n",
    "    \"openml_361242\": 81,\n",
    "    \"openml_361243\": 72\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = combined_df\n",
    "datasets = df[\"data\"].unique()\n",
    "\n",
    "marker_size = 7\n",
    "\n",
    "n_cols = 3\n",
    "n_rows = 2\n",
    "\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(8 * n_cols, 5.5 * n_rows),\n",
    "    sharey=False\n",
    ")\n",
    "\n",
    "axs = axs.flatten()\n",
    "\n",
    "for idx, dataset in enumerate(datasets):\n",
    "    ax = axs[idx]\n",
    "    subset = df[df[\"data\"] == dataset]\n",
    "    \n",
    "    for method in methods:\n",
    "        method_data = subset[subset[\"fi\"] == method]\n",
    "        r2_col = [\"R2_keep_0.1\", \"R2_keep_0.2\", \"R2_keep_0.3\",\n",
    "            \"R2_keep_0.4\", \"R2_keep_0.5\", \"R2_keep_0.6\", \"R2_keep_0.7\",\n",
    "            \"R2_keep_0.8\", \"R2_keep_0.9\", \"R2_keep_1.0\"]\n",
    "        means = method_data[r2_col].mean(axis=0).values\n",
    "        stds = method_data[r2_col].std(axis=0).values\n",
    "        counts = method_data[r2_col].count(axis=0).values\n",
    "        sems = stds / np.sqrt(counts)\n",
    "        x = [0.1, 0.2, 0.3, 0.4,0.5,0.6,0.7,0.8,0.9,1.0]\n",
    "\n",
    "        if method in ['LIME', 'Treeshap', 'Local MDI', 'Maple']:\n",
    "            ax.errorbar(\n",
    "                x, means, sems,\n",
    "                label=methods_name[method], linestyle='solid',\n",
    "                marker='o', markersize=marker_size, color=color_map[method], linewidth=4, alpha=0.6\n",
    "            )\n",
    "        else:\n",
    "            ax.errorbar(\n",
    "                x, means, sems,\n",
    "                label=methods_name[method], linestyle='solid',\n",
    "                marker='o', markersize=marker_size, color=color_map[method], linewidth=4\n",
    "            )\n",
    "\n",
    "    ax.set_xticks([0.1, 0.3, 0.5, 0.7, 0.9])\n",
    "    ax.set_xticklabels([\"10%\", \"30%\", \"50%\", \"70%\", \"90%\"], fontsize=22)\n",
    "    ax.tick_params(axis='y', labelsize=22)\n",
    "    if idx >= (n_rows - 1) * n_cols:\n",
    "        ax.set_xlabel(\"Percentage of Top Features Retained\", fontsize=25)\n",
    "    \n",
    "    if idx % n_cols == 0:\n",
    "        ax.set_ylabel(r\"$R^2$\", fontsize=24)\n",
    "    else:\n",
    "        ax.set_ylabel(\"\")\n",
    "    \n",
    "    dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "    p_val = feature_values[dataset]\n",
    "    ax.set_title(f\"$\\\\mathbf{{{dataset_label} \\ (p={p_val}) }}$\", fontsize=25)\n",
    "    \n",
    "    if (idx + 1) % n_cols == 0 or idx == len(datasets) - 1:\n",
    "        ax.legend(fontsize=18, loc='lower right')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"feature_selection_regression_full_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdi-new",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
