{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Relative directories appended to the module search path, in this order.\n",
    "# NOTE(review): no cell in this notebook imports a project-local module; confirm these are still needed.\n",
    "for extra_path in ['..', '../..', '.', './scripts']:\n",
    "    sys.path.append(extra_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib style shared by every figure below.\n",
    "plt.rcParams.update({\n",
    "    'axes.labelsize': 30,\n",
    "    'xtick.labelsize': 10,\n",
    "    'ytick.labelsize': 12,\n",
    "    # hide the top/right spines; draw the remaining ones black and thick\n",
    "    'axes.spines.right': False,\n",
    "    'axes.spines.top': False,\n",
    "    'axes.edgecolor': 'black',\n",
    "    'axes.linewidth': 2.0,\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Feature Ranking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model whose saved results are plotted: \"rf\" or \"gradient_boosting\".\n",
    "model = 'rf'\n",
    "\n",
    "# Feature-importance methods, in the order they are drawn and listed in legends.\n",
    "methods = ['lmdi+', 'LIME', 'Treeshap', 'Local MDI']\n",
    "\n",
    "# Line colour for each method.\n",
    "color_map = {\n",
    "    'LIME': '#71BEB7',\n",
    "    'Treeshap': 'orange',\n",
    "    'Local MDI': '#9B5DFF',\n",
    "    'lmdi+': 'black',\n",
    "}\n",
    "\n",
    "# Display name of each dataset, keyed by its OpenML identifier.\n",
    "data_name = {\n",
    "    \"openml_361260\": \"Miami Housing\",\n",
    "    \"openml_361259\": \"Puma Robot\",\n",
    "    \"openml_361253\": \"Wave Energy\",\n",
    "    \"openml_361254\": \"SARCOS\",\n",
    "    \"openml_361242\": \"Super Conductivity\",\n",
    "    \"openml_361243\": \"Geographic Origin of Music\",\n",
    "    \"openml_43\": \"Spam\",\n",
    "    \"openml_361062\": \"Pol\",\n",
    "    \"openml_361071\": \"Jannis\",\n",
    "    \"openml_9978\": \"Ozone\",\n",
    "    \"openml_361069\": \"Higgs\",\n",
    "    \"openml_361063\": \"House 16H\",\n",
    "}\n",
    "\n",
    "# Number of features p of each dataset: printed as \"p=...\" in panel labels and\n",
    "# used to normalise the stability curves.\n",
    "feature_values = {\n",
    "    \"openml_361260\": 15,\n",
    "    \"openml_361259\": 32,\n",
    "    \"openml_361253\": 48,\n",
    "    \"openml_361254\": 21,\n",
    "    \"openml_361242\": 81,\n",
    "    \"openml_361243\": 72,\n",
    "    \"openml_43\": 57,\n",
    "    \"openml_361062\": 26,\n",
    "    \"openml_361071\": 54,\n",
    "    \"openml_9978\": 47,\n",
    "    \"openml_361069\": 24,\n",
    "    \"openml_361063\": 16,\n",
    "}\n",
    "\n",
    "# Legend label for each method.\n",
    "methods_name = {\n",
    "    'LIME': 'LIME',\n",
    "    'Local MDI': 'Local MDI',\n",
    "    'Treeshap': 'TreeSHAP',\n",
    "    'lmdi+': 'LMDI+',\n",
    "}\n",
    "\n",
    "# Column title for each data-generating process (DGP) key.\n",
    "data_generator_label_map = {\n",
    "    'linear': 'Linear',\n",
    "    'interaction': 'Interaction',\n",
    "    'linear_lss': 'Linear + LSS',\n",
    "    'logistic_linear': 'Logistic',\n",
    "    'logistic_interaction': 'Logistic Interaction',\n",
    "    'logistic_linear_lss': 'Logistic + LSS',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature-ranking results (regression): one CSV per (feature seed, sample seed) for each\n",
    "# dataset/DGP pair, aggregated into mean/std/count and SEM of train and test AUROC.\n",
    "ranking_sample_row_n = 300  # sample_row_n value plotted in the main-paper figure\n",
    "datasets = [\"openml_361260\"]\n",
    "feature_seeds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n",
    "sample_seeds = [1, 2, 3]\n",
    "\n",
    "# Collect the per-run frames and concatenate once; calling pd.concat inside the loop\n",
    "# re-copies the accumulated frame on every iteration (quadratic).\n",
    "frames = []\n",
    "for dgp in [\"linear\", \"interaction\", \"linear_lss\"]:\n",
    "    for data in datasets:\n",
    "        ablation_directory = f\"./results/mdi_local_{model}.real_data_regression_{data}_{dgp}/{data}_{dgp}/varying_heritability_sample_row_n\"\n",
    "        for sample_seed in sample_seeds:\n",
    "            for feature_seed in feature_seeds:\n",
    "                df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{feature_seed}_{sample_seed}/results.csv\"))\n",
    "                df[\"data\"] = data\n",
    "                df[\"dgp\"] = dgp\n",
    "                frames.append(df)\n",
    "combined_df = pd.concat(frames, ignore_index=True)\n",
    "\n",
    "agg_df = combined_df.groupby(['sample_row_n', 'heritability', 'dgp', 'fi', 'data'])[\n",
    "    [\"auroc_train\", \"auroc_test\"]\n",
    "].agg(['mean', 'std', 'count']).reset_index()\n",
    "# Flatten (metric, stat) column pairs, e.g. ('auroc_test', 'mean') -> 'auroc_test_mean'.\n",
    "agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]\n",
    "# Standard error of the mean across runs.\n",
    "agg_df[\"auroc_train_sem\"] = agg_df[\"auroc_train_std\"] / np.sqrt(agg_df[\"auroc_train_count\"])\n",
    "agg_df[\"auroc_test_sem\"] = agg_df[\"auroc_test_std\"] / np.sqrt(agg_df[\"auroc_test_count\"])\n",
    "df_regression = agg_df[agg_df[\"sample_row_n\"] == ranking_sample_row_n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature-ranking results (classification): one CSV per (feature seed, sample seed) for each\n",
    "# dataset/DGP pair, aggregated into mean/std/count and SEM of train and test AUROC.\n",
    "ranking_sample_row_n = 300  # sample_row_n value plotted in the main-paper figure\n",
    "datasets = [\"openml_361069\"]\n",
    "feature_seeds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n",
    "sample_seeds = [1, 2, 3]\n",
    "\n",
    "# Collect the per-run frames and concatenate once; calling pd.concat inside the loop\n",
    "# re-copies the accumulated frame on every iteration (quadratic).\n",
    "frames = []\n",
    "for dgp in [\"logistic_linear\", \"logistic_interaction\", \"logistic_linear_lss\"]:\n",
    "    for data in datasets:\n",
    "        ablation_directory = f\"./results/mdi_local_{model}.real_data_classification_{data}_{dgp}/{data}_{dgp}/varying_frac_label_corruption_sample_row_n\"\n",
    "        for sample_seed in sample_seeds:\n",
    "            for feature_seed in feature_seeds:\n",
    "                df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{feature_seed}_{sample_seed}/results.csv\"))\n",
    "                df[\"data\"] = data\n",
    "                df[\"dgp\"] = dgp\n",
    "                frames.append(df)\n",
    "combined_df = pd.concat(frames, ignore_index=True)\n",
    "\n",
    "agg_df = combined_df.groupby(['sample_row_n', 'frac_label_corruption', 'dgp', 'fi', 'data'])[\n",
    "    [\"auroc_train\", \"auroc_test\"]\n",
    "].agg(['mean', 'std', 'count']).reset_index()\n",
    "# Flatten (metric, stat) column pairs, e.g. ('auroc_test', 'mean') -> 'auroc_test_mean'.\n",
    "agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]\n",
    "# Standard error of the mean across runs.\n",
    "agg_df[\"auroc_train_sem\"] = agg_df[\"auroc_train_std\"] / np.sqrt(agg_df[\"auroc_train_count\"])\n",
    "agg_df[\"auroc_test_sem\"] = agg_df[\"auroc_test_std\"] / np.sqrt(agg_df[\"auroc_test_count\"])\n",
    "df_classification = agg_df[agg_df[\"sample_row_n\"] == ranking_sample_row_n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature-ranking figure: one row per dataset (regression rows first, then classification)\n",
    "# and one column per DGP; each panel shows test AUROC (mean ± SEM) for every method.\n",
    "regression_datasets = df_regression[\"data\"].unique()\n",
    "classification_datasets = df_classification[\"data\"].unique()\n",
    "regression_dgp = ['linear', 'interaction', 'linear_lss']\n",
    "classification_dgp = ['logistic_linear', 'logistic_interaction', 'logistic_linear_lss']\n",
    "baseline_methods = ['LIME', 'Treeshap', 'Local MDI']  # drawn semi-transparent so LMDI+ stands out\n",
    "marker_size = 7\n",
    "\n",
    "# Size the grid from the loaded data instead of hard-coding it, so every dataset gets a row.\n",
    "n_cols = max(len(regression_dgp), len(classification_dgp))\n",
    "n_rows = len(regression_datasets) + len(classification_datasets)\n",
    "# squeeze=False always returns a 2-D array, so axs[row, col] works for any grid shape.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(8 * n_cols, 6.2 * n_rows),\n",
    "    sharey=False,\n",
    "    squeeze=False,\n",
    ")\n",
    "\n",
    "\n",
    "def plot_ranking_panel(ax, subset, x_col):\n",
    "    \"\"\"Draw test AUROC (mean ± SEM) against column `x_col` of `subset`, one line per method.\"\"\"\n",
    "    for method in methods:\n",
    "        method_data = subset[subset[\"fi\"] == method]\n",
    "        extra_style = {'alpha': 0.6} if method in baseline_methods else {}\n",
    "        ax.errorbar(\n",
    "            method_data[x_col], method_data['auroc_test_mean'], yerr=method_data[\"auroc_test_sem\"],\n",
    "            linestyle='solid', marker='o', markersize=marker_size,\n",
    "            label=methods_name[method], color=color_map[method], linewidth=4, **extra_style\n",
    "        )\n",
    "\n",
    "\n",
    "def label_ranking_row(ax, dataset):\n",
    "    \"\"\"Set the row's y-label: bold dataset name with its feature count p, then the metric.\"\"\"\n",
    "    dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "    p_val = feature_values[dataset]\n",
    "    # Escaped backslash for the mathtext space; a bare backslash-space is an invalid Python escape.\n",
    "    ax.set_ylabel(f\"$\\\\mathbf{{{dataset_label}\\\\ (p={p_val})}}$\\nAUROC (↑)\", fontsize=30)\n",
    "\n",
    "\n",
    "for row_idx, dataset in enumerate(regression_datasets):\n",
    "    for col_idx, data_generator in enumerate(regression_dgp):\n",
    "        ax = axs[row_idx, col_idx]\n",
    "        subset = df_regression[(df_regression[\"data\"] == dataset) & (df_regression[\"dgp\"] == data_generator)]\n",
    "        plot_ranking_panel(ax, subset, \"heritability\")\n",
    "        ax.set_xticks([0.1, 0.2, 0.4, 0.8])\n",
    "        ax.set_xticklabels([\"0.1\", \"0.2\", \"0.4\", \"0.8\"], fontsize=25)\n",
    "        ax.tick_params(axis='y', labelsize=25)\n",
    "        ax.set_xlabel(\"PVE\", fontsize=32)\n",
    "        if col_idx == 0:\n",
    "            label_ranking_row(ax, dataset)\n",
    "        if row_idx == 0:\n",
    "            ax.set_title(data_generator_label_map[data_generator], fontsize=32, fontweight='bold')\n",
    "        if col_idx == n_cols - 1:\n",
    "            ax.legend(fontsize=22, loc='lower right')\n",
    "\n",
    "first_classification_row = len(regression_datasets)\n",
    "for row_idx, dataset in enumerate(classification_datasets, start=first_classification_row):\n",
    "    for col_idx, data_generator in enumerate(classification_dgp):\n",
    "        ax = axs[row_idx, col_idx]\n",
    "        subset = df_classification[(df_classification[\"data\"] == dataset) & (df_classification[\"dgp\"] == data_generator)]\n",
    "        plot_ranking_panel(ax, subset, \"frac_label_corruption\")\n",
    "        # Ticks sit at corruption fractions but are labelled with the uncorrupted percentage;\n",
    "        # inverting the axis puts 0 (100% uncorrupted) on the right.\n",
    "        ax.set_xticks([0.15, 0.1, 0.05, 0])\n",
    "        ax.set_xticklabels([\"85%\", \"90%\", \"95%\", \"100%\"], fontsize=25)\n",
    "        ax.invert_xaxis()\n",
    "        ax.tick_params(axis='y', labelsize=25)\n",
    "        ax.set_xlabel(\"Percentage Uncorrupted\", fontsize=32)\n",
    "        if col_idx == 0:\n",
    "            label_ranking_row(ax, dataset)\n",
    "        if row_idx == first_classification_row:\n",
    "            ax.set_title(data_generator_label_map[data_generator], fontsize=32, fontweight='bold')\n",
    "        if col_idx == n_cols - 1:\n",
    "            ax.legend(fontsize=22, loc='lower right')\n",
    "plt.tight_layout(rect=[0, 0, 1, 1])\n",
    "plt.subplots_adjust(hspace=0.4)\n",
    "plt.savefig(f\"main_paper_feature_ranking_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Stability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stability results: for each task and dataset, one CSV per (split seed, sample seed),\n",
    "# restricted to the sample_row_n value plotted in the main-paper figure.\n",
    "stability_datasets = {\n",
    "    \"regression\": [\"openml_361254\", \"openml_361259\"],\n",
    "    \"classification\": [\"openml_361069\", \"openml_9978\"],\n",
    "}\n",
    "split_seeds = [1, 2, 3]\n",
    "sample_seeds = [1, 2, 3, 4, 5]\n",
    "stability_sample_row_n = 2000\n",
    "\n",
    "stability_results = {}\n",
    "for task, datasets in stability_datasets.items():\n",
    "    # Collect the per-run frames and concatenate once (pd.concat inside the loop is quadratic).\n",
    "    frames = []\n",
    "    for data in datasets:\n",
    "        ablation_directory = f\"./results/mdi_local_{model}.real_data_{task}_{data}_stability/{data}_stability/varying_sample_row_n\"\n",
    "        for split_seed in split_seeds:\n",
    "            for sample_seed in sample_seeds:\n",
    "                df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{split_seed}_{sample_seed}/results.csv\"))\n",
    "                df[\"data\"] = data\n",
    "                frames.append(df)\n",
    "    combined_df = pd.concat(frames, ignore_index=True)\n",
    "    stability_results[task] = combined_df[combined_df[\"sample_row_n\"] == stability_sample_row_n]\n",
    "\n",
    "df_regression = stability_results[\"regression\"]\n",
    "df_classification = stability_results[\"classification\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stability figure: per dataset, the avg_{10,20,30,40}_features columns divided by the\n",
    "# dataset's feature count p (\"Unique Feature Ratio\", lower is better), mean ± SEM per method.\n",
    "regression_datasets = list(df_regression[\"data\"].unique())\n",
    "classification_datasets = list(df_classification[\"data\"].unique())\n",
    "baseline_methods = ['LIME', 'Treeshap', 'Local MDI']  # drawn semi-transparent so LMDI+ stands out\n",
    "marker_size = 18\n",
    "stability_x = [0.1, 0.2, 0.3, 0.4]\n",
    "stability_cols = [\"avg_10_features\", \"avg_20_features\", \"avg_30_features\", \"avg_40_features\"]\n",
    "\n",
    "# Alternate regression and classification panels from left to right.\n",
    "plot_order = [\n",
    "    (\"regression\", regression_datasets[0]),\n",
    "    (\"classification\", classification_datasets[0]),\n",
    "    (\"regression\", regression_datasets[1]),\n",
    "    (\"classification\", classification_datasets[1]),\n",
    "]\n",
    "n_rows = 1\n",
    "n_cols = len(plot_order)\n",
    "\n",
    "# squeeze=False keeps axs an array even for a single panel, so flatten() always works.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(12 * n_cols, 11.25 * n_rows),\n",
    "    sharey=True,\n",
    "    squeeze=False,\n",
    ")\n",
    "axs = axs.flatten()\n",
    "\n",
    "for plot_idx, (dtype, dataset) in enumerate(plot_order):\n",
    "    ax = axs[plot_idx]\n",
    "    task_df = df_regression if dtype == \"regression\" else df_classification\n",
    "    subset = task_df[task_df[\"data\"] == dataset]\n",
    "    n_features = feature_values[dataset]\n",
    "\n",
    "    for method in methods:\n",
    "        normalized_data = subset[subset[\"fi\"] == method][stability_cols] / n_features\n",
    "        means = normalized_data.mean(axis=0).values\n",
    "        sems = normalized_data.std(axis=0).values / np.sqrt(normalized_data.count(axis=0).values)\n",
    "        extra_style = {'alpha': 0.6} if method in baseline_methods else {}\n",
    "        ax.errorbar(\n",
    "            stability_x, means, sems,\n",
    "            linestyle='solid', marker='o', markersize=marker_size,\n",
    "            label=methods_name[method], color=color_map[method], linewidth=8, **extra_style\n",
    "        )\n",
    "\n",
    "    ax.set_xticks(stability_x)\n",
    "    ax.set_xticklabels([\"10%\", \"20%\", \"30%\", \"40%\"], fontsize=50)\n",
    "    ax.tick_params(axis='both', which='both', direction='in', length=10, width=6, labelsize=50)\n",
    "    ax.set_xlabel(\"Top Features Retained\", fontsize=55)\n",
    "    if plot_idx == 0:\n",
    "        ax.set_ylabel(\"Unique Feature Ratio (↓)\", fontsize=55)\n",
    "    dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "    p_val = feature_values[dataset]\n",
    "    # Escaped backslash for the mathtext space; a bare backslash-space is an invalid Python escape.\n",
    "    ax.set_title(f\"$\\\\mathbf{{{dataset_label} \\\\ (p={p_val}) }}$\", fontsize=55)\n",
    "\n",
    "axs[-1].legend(fontsize=45, loc='upper left')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"main_paper_feature_stability_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Feature Selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature-selection results: for each task and dataset, one CSV per (split seed, sample seed).\n",
    "selection_datasets = {\n",
    "    \"regression\": [\"openml_361254\", \"openml_361259\"],\n",
    "    \"classification\": [\"openml_361069\", \"openml_9978\"],\n",
    "}\n",
    "split_seeds = [1, 2, 3, 4]\n",
    "sample_seeds = [1, 2, 3, 4, 5]\n",
    "\n",
    "selection_results = {}\n",
    "for task, datasets in selection_datasets.items():\n",
    "    # Collect the per-run frames and concatenate once (pd.concat inside the loop is quadratic).\n",
    "    frames = []\n",
    "    for data in datasets:\n",
    "        ablation_directory = f\"./results/mdi_local_{model}.real_data_{task}_{data}/{data}_mean/varying_sample_row_n\"\n",
    "        for split_seed in split_seeds:\n",
    "            for sample_seed in sample_seeds:\n",
    "                df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{split_seed}_{sample_seed}/results.csv\"))\n",
    "                df[\"data\"] = data\n",
    "                frames.append(df)\n",
    "    selection_results[task] = pd.concat(frames, ignore_index=True)\n",
    "\n",
    "df_regression = selection_results[\"regression\"]\n",
    "df_classification = selection_results[\"classification\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature-selection figure: test R^2 (regression) or AUROC (classification) after keeping\n",
    "# only the top fraction of features ranked by each method; mean ± SEM across runs.\n",
    "regression_datasets = df_regression[\"data\"].unique()\n",
    "classification_datasets = df_classification[\"data\"].unique()\n",
    "baseline_methods = ['LIME', 'Treeshap', 'Local MDI']  # drawn semi-transparent so LMDI+ stands out\n",
    "marker_size = 7\n",
    "# Larger LMDI+ markers on selected panels (SARCOS). Applied per line: reassigning the shared\n",
    "# marker_size inside the loop would leak the larger size into every later line and panel.\n",
    "lmdi_marker_sizes = {\"openml_361254\": 15}\n",
    "keep_fractions = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]\n",
    "n_rows = 1\n",
    "n_cols = 4\n",
    "\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(12 * n_cols, 12 * n_rows),\n",
    "    sharey=False,\n",
    "    squeeze=False,\n",
    ")\n",
    "axs = axs.flatten()\n",
    "\n",
    "\n",
    "def plot_selection_panel(ax, subset, dataset, metric):\n",
    "    \"\"\"Plot `<metric>_keep_<frac>` (mean ± SEM across runs) against the kept fraction,\n",
    "    one line per method, then style the panel and title it with the dataset name and p.\"\"\"\n",
    "    metric_cols = [f\"{metric}_keep_{frac:.1f}\" for frac in keep_fractions]\n",
    "    for method in methods:\n",
    "        scores = subset[subset[\"fi\"] == method][metric_cols]\n",
    "        means = scores.mean(axis=0).values\n",
    "        sems = scores.std(axis=0).values / np.sqrt(scores.count(axis=0).values)\n",
    "        if method in baseline_methods:\n",
    "            style = {'markersize': marker_size, 'alpha': 0.6}\n",
    "        else:\n",
    "            style = {'markersize': lmdi_marker_sizes.get(dataset, marker_size)}\n",
    "        ax.errorbar(\n",
    "            keep_fractions, means, sems,\n",
    "            label=methods_name[method], linestyle='solid',\n",
    "            marker='o', color=color_map[method], linewidth=8, **style\n",
    "        )\n",
    "\n",
    "    ax.set_xticks([0.1, 0.3, 0.5, 0.7, 0.9])\n",
    "    ax.set_xticklabels([\"10%\", \"30%\", \"50%\", \"70%\", \"90%\"], fontsize=50)\n",
    "    ax.tick_params(axis='y', labelsize=50)\n",
    "    ax.set_xlabel(\"Top Features Retained\", fontsize=55)\n",
    "    ax.xaxis.set_label_coords(0.48, -0.1)\n",
    "    dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "    p_val = feature_values[dataset]\n",
    "    # Escaped backslash for the mathtext space; a bare backslash-space is an invalid Python escape.\n",
    "    ax.set_title(f\"$\\\\mathbf{{{dataset_label} \\\\ (p={p_val}) }}$\", fontsize=55)\n",
    "\n",
    "\n",
    "for plot_idx, dataset in enumerate(regression_datasets):\n",
    "    ax = axs[plot_idx]\n",
    "    plot_selection_panel(ax, df_regression[df_regression[\"data\"] == dataset], dataset, \"R2\")\n",
    "    if plot_idx == 0:\n",
    "        ax.set_ylabel(\"$R^2$ (↑)\", fontsize=55)\n",
    "\n",
    "# Classification panels start right after the regression ones.\n",
    "first_classification_idx = len(regression_datasets)\n",
    "for plot_idx, dataset in enumerate(classification_datasets, start=first_classification_idx):\n",
    "    ax = axs[plot_idx]\n",
    "    plot_selection_panel(ax, df_classification[df_classification[\"data\"] == dataset], dataset, \"AUROC\")\n",
    "    if plot_idx == first_classification_idx:\n",
    "        ax.set_ylabel(\"AUROC (↑)\", fontsize=55)\n",
    "\n",
    "axs[-1].legend(fontsize=45, loc='lower right')\n",
    "\n",
    "# Group headers above the regression and classification panels.\n",
    "fig.text(0.27, 0.9, \"Regression\", ha='center', va='bottom', fontsize=65, fontweight='bold')\n",
    "fig.text(0.76, 0.9, \"Classification\", ha='center', va='bottom', fontsize=65, fontweight='bold')\n",
    "\n",
    "plt.tight_layout(rect=[0, 0, 1, 0.88])\n",
    "plt.subplots_adjust(wspace=0.15)\n",
    "plt.savefig(f\"main_paper_feature_selection_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdi-new",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
