{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "# Extend the module search path with the two parent directories, the current\n",
    "# directory and ./scripts, in that order.\n",
    "sys.path.extend(['..', '../..', '.', './scripts'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib styling applied to the figures created below.\n",
    "plt.rcParams.update({\n",
    "    'axes.labelsize': 30,\n",
    "    'xtick.labelsize': 10,\n",
    "    'ytick.labelsize': 12,\n",
    "    # Hide the right and top spines; draw the remaining ones thick and black.\n",
    "    'axes.spines.right': False,\n",
    "    'axes.spines.top': False,\n",
    "    'axes.edgecolor': 'black',\n",
    "    'axes.linewidth': 2.0,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model whose results are plotted; either \"rf\" or \"gradient_boosting\".\n",
    "model = \"rf\"\n",
    "\n",
    "# Feature-importance methods, in the order in which they are drawn.\n",
    "methods = ['lmdi+', 'LIME', 'Treeshap']\n",
    "\n",
    "# Line colour of each method.\n",
    "color_map = {\n",
    "    'LIME': '#71BEB7',\n",
    "    'Treeshap': 'orange',\n",
    "    'lmdi+': 'black',\n",
    "}\n",
    "\n",
    "# Legend label of each method.\n",
    "methods_name = {\n",
    "    'LIME': 'LIME',\n",
    "    'Treeshap': 'TreeSHAP',\n",
    "    'lmdi+': 'LMDI+',\n",
    "}\n",
    "\n",
    "# Local MDI is only added to the comparison when the model is the random forest.\n",
    "if model == \"rf\":\n",
    "    methods.append('Local MDI')\n",
    "    color_map.update({'Local MDI': '#9B5DFF'})\n",
    "    methods_name.update({'Local MDI': 'Local MDI'})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-seed results of the linear DGP and aggregate AUROC across seeds.\n",
    "dgp = \"linear\"\n",
    "datasets = ['openml_361260', 'openml_361254', 'openml_361259', 'openml_361253', 'openml_361243', 'openml_361242']\n",
    "feature_seeds = [1,2,3,4,5,6,7,8,9,10]\n",
    "sample_seeds = [1,2,3]\n",
    "\n",
    "# Collect the frames in a list and concatenate once: calling pd.concat inside\n",
    "# the loop copies every previously loaded row again on each iteration.\n",
    "seed_frames = []\n",
    "for data in datasets:\n",
    "    ablation_directory = f\"./results/mdi_local_{model}.real_data_regression_{data}_{dgp}/{data}_{dgp}/varying_heritability_sample_row_n\"\n",
    "    for sample_seed in sample_seeds:\n",
    "        for feature_seed in feature_seeds:\n",
    "            try:\n",
    "                seed_df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{feature_seed}_{sample_seed}/results.csv\"))\n",
    "            except Exception as err:\n",
    "                # Best effort: report the run that could not be read and move on.\n",
    "                # Unlike a bare `except:`, this lets KeyboardInterrupt propagate.\n",
    "                print(f\"Missing: {data}, feature seed: {feature_seed}, sample seed: {sample_seed} ({err})\")\n",
    "                continue\n",
    "            seed_df[\"data\"] = data\n",
    "            seed_frames.append(seed_df)\n",
    "\n",
    "if not seed_frames:\n",
    "    raise FileNotFoundError(f\"No results.csv could be read for dgp={dgp!r} and model={model!r}\")\n",
    "combined_df = pd.concat(seed_frames, ignore_index=True)\n",
    "\n",
    "# Mean, standard deviation and count of train/test AUROC for every\n",
    "# (sample_row_n, heritability, fi, data) combination.\n",
    "agg_df = combined_df.groupby(['sample_row_n', 'heritability', 'fi', 'data'])[\n",
    "    [\"auroc_train\", \"auroc_test\"]\n",
    "].agg(['mean', 'std', 'count']).reset_index()\n",
    "# Flatten the two-level column index, e.g. ('auroc_test', 'mean') -> 'auroc_test_mean'.\n",
    "agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]\n",
    "# Standard error of the mean = std / sqrt(number of runs).\n",
    "agg_df[\"auroc_train_sem\"] = agg_df[\"auroc_train_std\"] / np.sqrt(agg_df[\"auroc_train_count\"])\n",
    "agg_df[\"auroc_test_sem\"] = agg_df[\"auroc_test_std\"] / np.sqrt(agg_df[\"auroc_test_count\"])\n",
    "\n",
    "# Rows with sample_row_n == 800 are left out of the frame that is plotted.\n",
    "df = agg_df[agg_df[\"sample_row_n\"] != 800]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (dataset key, display name, p) for each regression dataset; p is the value\n",
    "# printed as \"p=...\" in the row labels (presumably the number of features).\n",
    "_regression_datasets = [\n",
    "    (\"openml_361260\", \"Miami Housing\", 15),\n",
    "    (\"openml_361259\", \"Puma Robot\", 32),\n",
    "    (\"openml_361253\", \"Wave Energy\", 48),\n",
    "    (\"openml_361254\", \"SARCOS\", 21),\n",
    "    (\"openml_361242\", \"Super Conductivity\", 81),\n",
    "    (\"openml_361243\", \"Geographic Origin of Music\", 72),\n",
    "]\n",
    "\n",
    "# Dataset key -> display name.\n",
    "data_name = {key: name for key, name, _ in _regression_datasets}\n",
    "\n",
    "# Dataset key -> p.\n",
    "feature_values = {key: p for key, _, p in _regression_datasets}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-AUROC curves on a grid: one row per dataset, one column per PVE level.\n",
    "# The levels are taken in the reverse of the order in which they appear in `df`.\n",
    "heritability_all = df[\"heritability\"].unique()[::-1]\n",
    "marker_size = 7\n",
    "# These methods are drawn semi-transparent; any other method stays fully opaque.\n",
    "faded_methods = ['LIME', 'Treeshap', 'Local MDI']\n",
    "\n",
    "n_cols = len(heritability_all)\n",
    "n_rows = len(datasets)\n",
    "\n",
    "# squeeze=False keeps `axs` two-dimensional even for a single row or column.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(8 * n_cols, 6.5 * n_rows),\n",
    "    sharey=False,\n",
    "    squeeze=False\n",
    ")\n",
    "\n",
    "for row_idx, dataset in enumerate(datasets):\n",
    "    for col_idx, heritability in enumerate(heritability_all):\n",
    "        ax = axs[row_idx, col_idx]\n",
    "        subset = df[(df[\"data\"] == dataset) & (df[\"heritability\"] == heritability)]\n",
    "\n",
    "        for method in methods:\n",
    "            method_data = subset[subset[\"fi\"] == method]\n",
    "            line_style = dict(\n",
    "                linestyle='solid', marker='o', markersize=marker_size,\n",
    "                label=methods_name[method], color=color_map[method], linewidth=4\n",
    "            )\n",
    "            if method in faded_methods:\n",
    "                line_style['alpha'] = 0.6\n",
    "            # Mean test AUROC over seeds, with the standard error as error bar.\n",
    "            ax.errorbar(\n",
    "                method_data[\"sample_row_n\"], method_data['auroc_test_mean'],\n",
    "                yerr=method_data[\"auroc_test_sem\"], **line_style\n",
    "            )\n",
    "\n",
    "        ax.set_xticks([300,500, 1000])\n",
    "        ax.set_xticklabels([\"300\", \"500\", \"1000\"], fontsize=25)\n",
    "        ax.tick_params(axis='y', labelsize=25)\n",
    "        if row_idx == n_rows - 1:\n",
    "            ax.set_xlabel(\"Sample Size\", fontsize=30)\n",
    "\n",
    "        if col_idx == 0:\n",
    "            p_val = feature_values[dataset]\n",
    "            # Spaces are written as backslash-space so they survive in math mode.\n",
    "            # Raw strings keep the backslashes literal without invalid escapes.\n",
    "            dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "            if dataset == \"openml_361243\":\n",
    "                # This label is split over two lines by hand (presumably because\n",
    "                # the dataset name is long).\n",
    "                ylabel_lines = [\n",
    "                    r\"$\\mathbf{Geographic\\ Origin\\ of}$\",\n",
    "                    rf\"$\\mathbf{{Music\\ (p={p_val})}}$\",\n",
    "                    \"AUROC\",\n",
    "                ]\n",
    "            else:\n",
    "                ylabel_lines = [rf\"$\\mathbf{{{dataset_label} \\ (p={p_val})}}$\", \"AUROC\"]\n",
    "            ax.set_ylabel(\"\\n\".join(ylabel_lines), fontsize=30)\n",
    "        else:\n",
    "            ax.set_ylabel(\"\")\n",
    "\n",
    "        if row_idx == 0:\n",
    "            ax.set_title(r\"$\\bf{PVE}$=\" + rf\"$\\bf{{{heritability}}}$\", fontsize=30)\n",
    "\n",
    "        if col_idx == n_cols - 1:\n",
    "            ax.legend(fontsize=22, loc='lower right')\n",
    "\n",
    "# Leave the top 5% of the figure free for the overall title.\n",
    "fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "# NOTE: usetex=True needs a LaTeX installation on the machine rendering the figure.\n",
    "fig.suptitle(r\"\\textbf{Linear}\", fontsize=50, usetex=True)\n",
    "fig.savefig(f\"feature_ranking_linear_full_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-seed results of the interaction DGP and aggregate AUROC across seeds.\n",
    "dgp = \"interaction\"\n",
    "datasets = ['openml_361260', 'openml_361254', 'openml_361259', 'openml_361253', 'openml_361243', 'openml_361242']\n",
    "feature_seeds = [1,2,3,4,5,6,7,8,9,10]\n",
    "sample_seeds = [1,2,3]\n",
    "\n",
    "# Collect the frames in a list and concatenate once: calling pd.concat inside\n",
    "# the loop copies every previously loaded row again on each iteration.\n",
    "seed_frames = []\n",
    "for data in datasets:\n",
    "    ablation_directory = f\"./results/mdi_local_{model}.real_data_regression_{data}_{dgp}/{data}_{dgp}/varying_heritability_sample_row_n\"\n",
    "    for sample_seed in sample_seeds:\n",
    "        for feature_seed in feature_seeds:\n",
    "            try:\n",
    "                seed_df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{feature_seed}_{sample_seed}/results.csv\"))\n",
    "            except Exception as err:\n",
    "                # Best effort: report the run that could not be read and move on.\n",
    "                # Unlike a bare `except:`, this lets KeyboardInterrupt propagate.\n",
    "                print(f\"Missing: {data}, feature seed: {feature_seed}, sample seed: {sample_seed} ({err})\")\n",
    "                continue\n",
    "            seed_df[\"data\"] = data\n",
    "            seed_frames.append(seed_df)\n",
    "\n",
    "if not seed_frames:\n",
    "    raise FileNotFoundError(f\"No results.csv could be read for dgp={dgp!r} and model={model!r}\")\n",
    "combined_df = pd.concat(seed_frames, ignore_index=True)\n",
    "\n",
    "# Mean, standard deviation and count of train/test AUROC for every\n",
    "# (sample_row_n, heritability, fi, data) combination.\n",
    "agg_df = combined_df.groupby(['sample_row_n', 'heritability', 'fi', 'data'])[\n",
    "    [\"auroc_train\", \"auroc_test\"]\n",
    "].agg(['mean', 'std', 'count']).reset_index()\n",
    "# Flatten the two-level column index, e.g. ('auroc_test', 'mean') -> 'auroc_test_mean'.\n",
    "agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]\n",
    "# Standard error of the mean = std / sqrt(number of runs).\n",
    "agg_df[\"auroc_train_sem\"] = agg_df[\"auroc_train_std\"] / np.sqrt(agg_df[\"auroc_train_count\"])\n",
    "agg_df[\"auroc_test_sem\"] = agg_df[\"auroc_test_std\"] / np.sqrt(agg_df[\"auroc_test_count\"])\n",
    "\n",
    "# Rows with sample_row_n == 800 are left out of the frame that is plotted.\n",
    "df = agg_df[agg_df[\"sample_row_n\"] != 800]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-AUROC curves on a grid: one row per dataset, one column per PVE level.\n",
    "# The levels are taken in the reverse of the order in which they appear in `df`.\n",
    "heritability_all = df[\"heritability\"].unique()[::-1]\n",
    "marker_size = 7\n",
    "# These methods are drawn semi-transparent; any other method stays fully opaque.\n",
    "faded_methods = ['LIME', 'Treeshap', 'Local MDI']\n",
    "\n",
    "n_cols = len(heritability_all)\n",
    "n_rows = len(datasets)\n",
    "\n",
    "# squeeze=False keeps `axs` two-dimensional even for a single row or column.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(8 * n_cols, 6.5 * n_rows),\n",
    "    sharey=False,\n",
    "    squeeze=False\n",
    ")\n",
    "\n",
    "for row_idx, dataset in enumerate(datasets):\n",
    "    for col_idx, heritability in enumerate(heritability_all):\n",
    "        ax = axs[row_idx, col_idx]\n",
    "        subset = df[(df[\"data\"] == dataset) & (df[\"heritability\"] == heritability)]\n",
    "\n",
    "        for method in methods:\n",
    "            method_data = subset[subset[\"fi\"] == method]\n",
    "            line_style = dict(\n",
    "                linestyle='solid', marker='o', markersize=marker_size,\n",
    "                label=methods_name[method], color=color_map[method], linewidth=4\n",
    "            )\n",
    "            if method in faded_methods:\n",
    "                line_style['alpha'] = 0.6\n",
    "            # Mean test AUROC over seeds, with the standard error as error bar.\n",
    "            ax.errorbar(\n",
    "                method_data[\"sample_row_n\"], method_data['auroc_test_mean'],\n",
    "                yerr=method_data[\"auroc_test_sem\"], **line_style\n",
    "            )\n",
    "\n",
    "        ax.set_xticks([300,500, 1000])\n",
    "        ax.set_xticklabels([\"300\", \"500\", \"1000\"], fontsize=25)\n",
    "        ax.tick_params(axis='y', labelsize=25)\n",
    "        if row_idx == n_rows - 1:\n",
    "            ax.set_xlabel(\"Sample Size\", fontsize=30)\n",
    "\n",
    "        if col_idx == 0:\n",
    "            p_val = feature_values[dataset]\n",
    "            # Spaces are written as backslash-space so they survive in math mode.\n",
    "            # Raw strings keep the backslashes literal without invalid escapes.\n",
    "            dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "            if dataset == \"openml_361243\":\n",
    "                # This label is split over two lines by hand (presumably because\n",
    "                # the dataset name is long).\n",
    "                ylabel_lines = [\n",
    "                    r\"$\\mathbf{Geographic\\ Origin\\ of}$\",\n",
    "                    rf\"$\\mathbf{{Music\\ (p={p_val})}}$\",\n",
    "                    \"AUROC\",\n",
    "                ]\n",
    "            else:\n",
    "                ylabel_lines = [rf\"$\\mathbf{{{dataset_label} \\ (p={p_val})}}$\", \"AUROC\"]\n",
    "            ax.set_ylabel(\"\\n\".join(ylabel_lines), fontsize=30)\n",
    "        else:\n",
    "            ax.set_ylabel(\"\")\n",
    "\n",
    "        if row_idx == 0:\n",
    "            ax.set_title(r\"$\\bf{PVE}$=\" + rf\"$\\bf{{{heritability}}}$\", fontsize=30)\n",
    "\n",
    "        if col_idx == n_cols - 1:\n",
    "            ax.legend(fontsize=22, loc='lower right')\n",
    "\n",
    "# Leave the top 5% of the figure free for the overall title.\n",
    "fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "# NOTE: usetex=True needs a LaTeX installation on the machine rendering the figure.\n",
    "fig.suptitle(r\"\\textbf{Interaction}\", fontsize=50, usetex=True)\n",
    "fig.savefig(f\"feature_ranking_interaction_full_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-seed results of the linear + LSS DGP and aggregate AUROC across seeds.\n",
    "dgp = \"linear_lss\"\n",
    "datasets = ['openml_361260', 'openml_361254', 'openml_361259', 'openml_361253', 'openml_361243', 'openml_361242']\n",
    "feature_seeds = [1,2,3,4,5,6,7,8,9,10]\n",
    "sample_seeds = [1,2,3]\n",
    "\n",
    "# Collect the frames in a list and concatenate once: calling pd.concat inside\n",
    "# the loop copies every previously loaded row again on each iteration.\n",
    "seed_frames = []\n",
    "for data in datasets:\n",
    "    ablation_directory = f\"./results/mdi_local_{model}.real_data_regression_{data}_{dgp}/{data}_{dgp}/varying_heritability_sample_row_n\"\n",
    "    for sample_seed in sample_seeds:\n",
    "        for feature_seed in feature_seeds:\n",
    "            try:\n",
    "                seed_df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{feature_seed}_{sample_seed}/results.csv\"))\n",
    "            except Exception as err:\n",
    "                # Best effort: report the run that could not be read and move on.\n",
    "                # Unlike a bare `except:`, this lets KeyboardInterrupt propagate.\n",
    "                print(f\"Missing: {data}, feature seed: {feature_seed}, sample seed: {sample_seed} ({err})\")\n",
    "                continue\n",
    "            seed_df[\"data\"] = data\n",
    "            seed_frames.append(seed_df)\n",
    "\n",
    "if not seed_frames:\n",
    "    raise FileNotFoundError(f\"No results.csv could be read for dgp={dgp!r} and model={model!r}\")\n",
    "combined_df = pd.concat(seed_frames, ignore_index=True)\n",
    "\n",
    "# Mean, standard deviation and count of train/test AUROC for every\n",
    "# (sample_row_n, heritability, fi, data) combination.\n",
    "agg_df = combined_df.groupby(['sample_row_n', 'heritability', 'fi', 'data'])[\n",
    "    [\"auroc_train\", \"auroc_test\"]\n",
    "].agg(['mean', 'std', 'count']).reset_index()\n",
    "# Flatten the two-level column index, e.g. ('auroc_test', 'mean') -> 'auroc_test_mean'.\n",
    "agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]\n",
    "# Standard error of the mean = std / sqrt(number of runs).\n",
    "agg_df[\"auroc_train_sem\"] = agg_df[\"auroc_train_std\"] / np.sqrt(agg_df[\"auroc_train_count\"])\n",
    "agg_df[\"auroc_test_sem\"] = agg_df[\"auroc_test_std\"] / np.sqrt(agg_df[\"auroc_test_count\"])\n",
    "\n",
    "# Rows with sample_row_n == 800 are left out of the frame that is plotted.\n",
    "df = agg_df[agg_df[\"sample_row_n\"] != 800]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-AUROC curves on a grid: one row per dataset, one column per PVE level.\n",
    "# The levels are taken in the reverse of the order in which they appear in `df`.\n",
    "heritability_all = df[\"heritability\"].unique()[::-1]\n",
    "marker_size = 7\n",
    "# These methods are drawn semi-transparent; any other method stays fully opaque.\n",
    "faded_methods = ['LIME', 'Treeshap', 'Local MDI']\n",
    "\n",
    "n_cols = len(heritability_all)\n",
    "n_rows = len(datasets)\n",
    "\n",
    "# squeeze=False keeps `axs` two-dimensional even for a single row or column.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(8 * n_cols, 6.5 * n_rows),\n",
    "    sharey=False,\n",
    "    squeeze=False\n",
    ")\n",
    "\n",
    "for row_idx, dataset in enumerate(datasets):\n",
    "    for col_idx, heritability in enumerate(heritability_all):\n",
    "        ax = axs[row_idx, col_idx]\n",
    "        subset = df[(df[\"data\"] == dataset) & (df[\"heritability\"] == heritability)]\n",
    "\n",
    "        for method in methods:\n",
    "            method_data = subset[subset[\"fi\"] == method]\n",
    "            line_style = dict(\n",
    "                linestyle='solid', marker='o', markersize=marker_size,\n",
    "                label=methods_name[method], color=color_map[method], linewidth=4\n",
    "            )\n",
    "            if method in faded_methods:\n",
    "                line_style['alpha'] = 0.6\n",
    "            # Mean test AUROC over seeds, with the standard error as error bar.\n",
    "            ax.errorbar(\n",
    "                method_data[\"sample_row_n\"], method_data['auroc_test_mean'],\n",
    "                yerr=method_data[\"auroc_test_sem\"], **line_style\n",
    "            )\n",
    "\n",
    "        ax.set_xticks([300,500, 1000])\n",
    "        ax.set_xticklabels([\"300\", \"500\", \"1000\"], fontsize=25)\n",
    "        ax.tick_params(axis='y', labelsize=25)\n",
    "        if row_idx == n_rows - 1:\n",
    "            ax.set_xlabel(\"Sample Size\", fontsize=30)\n",
    "\n",
    "        if col_idx == 0:\n",
    "            p_val = feature_values[dataset]\n",
    "            # Spaces are written as backslash-space so they survive in math mode.\n",
    "            # Raw strings keep the backslashes literal without invalid escapes.\n",
    "            dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "            if dataset == \"openml_361243\":\n",
    "                # This label is split over two lines by hand (presumably because\n",
    "                # the dataset name is long).\n",
    "                ylabel_lines = [\n",
    "                    r\"$\\mathbf{Geographic\\ Origin\\ of}$\",\n",
    "                    rf\"$\\mathbf{{Music\\ (p={p_val})}}$\",\n",
    "                    \"AUROC\",\n",
    "                ]\n",
    "            else:\n",
    "                ylabel_lines = [rf\"$\\mathbf{{{dataset_label} \\ (p={p_val})}}$\", \"AUROC\"]\n",
    "            ax.set_ylabel(\"\\n\".join(ylabel_lines), fontsize=30)\n",
    "        else:\n",
    "            ax.set_ylabel(\"\")\n",
    "\n",
    "        if row_idx == 0:\n",
    "            ax.set_title(r\"$\\bf{PVE}$=\" + rf\"$\\bf{{{heritability}}}$\", fontsize=30)\n",
    "\n",
    "        if col_idx == n_cols - 1:\n",
    "            ax.legend(fontsize=22, loc='lower right')\n",
    "\n",
    "# Leave the top 5% of the figure free for the overall title.\n",
    "fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "# NOTE: usetex=True needs a LaTeX installation on the machine rendering the figure.\n",
    "fig.suptitle(r\"\\textbf{Linear + LSS}\", fontsize=50, usetex=True)\n",
    "fig.savefig(f\"feature_ranking_linear_lss_full_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-seed results of the logistic (linear) DGP and aggregate AUROC across seeds.\n",
    "dgp = \"logistic_linear\"\n",
    "datasets = ['openml_361063', 'openml_361069', 'openml_361062', 'openml_9978', 'openml_361071', 'openml_43']\n",
    "feature_seeds = [1,2,3,4,5,6,7,8,9,10]\n",
    "sample_seeds = [1,2,3]\n",
    "\n",
    "# Collect the frames in a list and concatenate once: calling pd.concat inside\n",
    "# the loop copies every previously loaded row again on each iteration.\n",
    "seed_frames = []\n",
    "for data in datasets:\n",
    "    ablation_directory = f\"./results/mdi_local_{model}.real_data_classification_{data}_{dgp}/{data}_{dgp}_threshold_05/varying_frac_label_corruption_sample_row_n\"\n",
    "    for sample_seed in sample_seeds:\n",
    "        for feature_seed in feature_seeds:\n",
    "            try:\n",
    "                seed_df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{feature_seed}_{sample_seed}/results.csv\"))\n",
    "            except Exception as e:\n",
    "                # Same best-effort handling as the other classification load cells:\n",
    "                # one unreadable run is reported and skipped instead of aborting.\n",
    "                print(f\"Error reading {ablation_directory}/seed_{feature_seed}_{sample_seed}/results.csv: {e}\")\n",
    "                continue\n",
    "            seed_df[\"data\"] = data\n",
    "            seed_frames.append(seed_df)\n",
    "\n",
    "if not seed_frames:\n",
    "    raise FileNotFoundError(f\"No results.csv could be read for dgp={dgp!r} and model={model!r}\")\n",
    "combined_df = pd.concat(seed_frames, ignore_index=True)\n",
    "\n",
    "# Mean, standard deviation and count of train/test AUROC for every\n",
    "# (sample_row_n, frac_label_corruption, fi, data) combination.\n",
    "agg_df = combined_df.groupby(['sample_row_n', 'frac_label_corruption', 'fi', 'data'])[\n",
    "    [\"auroc_train\", \"auroc_test\"]\n",
    "].agg(['mean', 'std', 'count']).reset_index()\n",
    "# Flatten the two-level column index, e.g. ('auroc_test', 'mean') -> 'auroc_test_mean'.\n",
    "agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]\n",
    "# Standard error of the mean = std / sqrt(number of runs).\n",
    "agg_df[\"auroc_train_sem\"] = agg_df[\"auroc_train_std\"] / np.sqrt(agg_df[\"auroc_train_count\"])\n",
    "agg_df[\"auroc_test_sem\"] = agg_df[\"auroc_test_std\"] / np.sqrt(agg_df[\"auroc_test_count\"])\n",
    "\n",
    "# Rows with sample_row_n == 800 are left out of the frame that is plotted.\n",
    "df = agg_df[agg_df[\"sample_row_n\"] != 800]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (dataset key, display name, p) for each classification dataset; p is the value\n",
    "# printed as \"p=...\" in the row labels (presumably the number of features).\n",
    "_classification_datasets = [\n",
    "    (\"openml_43\", \"Spam\", 57),\n",
    "    (\"openml_361062\", \"Pol\", 26),\n",
    "    (\"openml_361071\", \"Jannis\", 54),\n",
    "    (\"openml_9978\", \"Ozone\", 47),\n",
    "    (\"openml_361069\", \"Higgs\", 24),\n",
    "    (\"openml_361063\", \"House 16H\", 16),\n",
    "]\n",
    "\n",
    "# Dataset key -> display name.\n",
    "data_name = {key: name for key, name, _ in _classification_datasets}\n",
    "\n",
    "# Dataset key -> p.\n",
    "feature_values = {key: p for key, _, p in _classification_datasets}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-AUROC curves on a grid: one row per dataset, one column per\n",
    "# label-corruption level.\n",
    "frac_label_corruption_all = df[\"frac_label_corruption\"].unique()\n",
    "marker_size = 7\n",
    "# These methods are drawn semi-transparent; any other method stays fully opaque.\n",
    "faded_methods = ['LIME', 'Treeshap', 'Local MDI']\n",
    "\n",
    "n_cols = len(frac_label_corruption_all)\n",
    "n_rows = len(datasets)\n",
    "\n",
    "# squeeze=False keeps `axs` two-dimensional even for a single row or column.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(8 * n_cols, 6.5 * n_rows),\n",
    "    sharey=False,\n",
    "    squeeze=False\n",
    ")\n",
    "\n",
    "for row_idx, dataset in enumerate(datasets):\n",
    "    for col_idx, frac_label_corruption in enumerate(frac_label_corruption_all):\n",
    "        ax = axs[row_idx, col_idx]\n",
    "        subset = df[(df[\"data\"] == dataset) & (df[\"frac_label_corruption\"] == frac_label_corruption)]\n",
    "\n",
    "        for method in methods:\n",
    "            method_data = subset[subset[\"fi\"] == method]\n",
    "            line_style = dict(\n",
    "                linestyle='solid', marker='o', markersize=marker_size,\n",
    "                label=methods_name[method], color=color_map[method], linewidth=4\n",
    "            )\n",
    "            if method in faded_methods:\n",
    "                line_style['alpha'] = 0.6\n",
    "            # Mean test AUROC over seeds, with the standard error as error bar.\n",
    "            ax.errorbar(\n",
    "                method_data[\"sample_row_n\"], method_data['auroc_test_mean'],\n",
    "                yerr=method_data[\"auroc_test_sem\"], **line_style\n",
    "            )\n",
    "\n",
    "        ax.set_xticks([300,500, 1000])\n",
    "        ax.set_xticklabels([\"300\", \"500\", \"1000\"], fontsize=25)\n",
    "        ax.tick_params(axis='y', labelsize=25)\n",
    "        if row_idx == n_rows - 1:\n",
    "            ax.set_xlabel(\"Sample Size\", fontsize=30)\n",
    "\n",
    "        if col_idx == 0:\n",
    "            p_val = feature_values[dataset]\n",
    "            # Spaces are written as backslash-space so they survive in math mode.\n",
    "            # Raw strings keep the backslashes literal without invalid escapes.\n",
    "            dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "            ax.set_ylabel(rf\"$\\mathbf{{{dataset_label}\\ (p={p_val})}}$\" + \"\\nAUROC\", fontsize=30)\n",
    "        else:\n",
    "            ax.set_ylabel(\"\")\n",
    "\n",
    "        if row_idx == 0:\n",
    "            # Round before converting: int() truncates, so a fraction such as 0.29\n",
    "            # (0.29 * 100 == 28.999999999999996) would otherwise be titled 28%.\n",
    "            corruption_pct = int(round(float(frac_label_corruption) * 100))\n",
    "            ax.set_title(rf\"$\\bf{{{corruption_pct}}} \\% \\ $\" + r\"$\\bf{Corrupted}$\", fontsize=30)\n",
    "\n",
    "        if col_idx == n_cols - 1:\n",
    "            ax.legend(fontsize=22, loc='lower right')\n",
    "\n",
    "# Leave the top 5% of the figure free for the overall title.\n",
    "fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "# NOTE: usetex=True needs a LaTeX installation on the machine rendering the figure.\n",
    "fig.suptitle(r\"\\textbf{Logistic}\", fontsize=50, usetex=True)\n",
    "fig.savefig(f\"feature_ranking_logistic_linear_full_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-seed results of the logistic interaction DGP and aggregate AUROC across seeds.\n",
    "dgp = \"logistic_interaction\"\n",
    "datasets = ['openml_361063', 'openml_361069', 'openml_361062', 'openml_9978', 'openml_361071', 'openml_43']\n",
    "feature_seeds = [1,2,3,4,5,6,7,8,9,10]\n",
    "sample_seeds = [1,2,3]\n",
    "\n",
    "# Collect the frames in a list and concatenate once: calling pd.concat inside\n",
    "# the loop copies every previously loaded row again on each iteration.\n",
    "seed_frames = []\n",
    "for data in datasets:\n",
    "    # NOTE(review): the other two classification DGPs read from a\n",
    "    # \"{data}_{dgp}_threshold_05\" folder; this path has no such suffix.\n",
    "    # Confirm that this matches the directory layout on disk.\n",
    "    ablation_directory = f\"./results/mdi_local_{model}.real_data_classification_{data}_{dgp}/{data}_{dgp}/varying_frac_label_corruption_sample_row_n\"\n",
    "    for sample_seed in sample_seeds:\n",
    "        for feature_seed in feature_seeds:\n",
    "            try:\n",
    "                seed_df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{feature_seed}_{sample_seed}/results.csv\"))\n",
    "            except Exception as e:\n",
    "                # Best effort: report the run that could not be read and move on.\n",
    "                print(f\"Error reading {ablation_directory}/seed_{feature_seed}_{sample_seed}/results.csv: {e}\")\n",
    "                continue\n",
    "            seed_df[\"data\"] = data\n",
    "            seed_frames.append(seed_df)\n",
    "\n",
    "if not seed_frames:\n",
    "    raise FileNotFoundError(f\"No results.csv could be read for dgp={dgp!r} and model={model!r}\")\n",
    "combined_df = pd.concat(seed_frames, ignore_index=True)\n",
    "\n",
    "# Mean, standard deviation and count of train/test AUROC for every\n",
    "# (sample_row_n, frac_label_corruption, fi, data) combination.\n",
    "agg_df = combined_df.groupby(['sample_row_n', 'frac_label_corruption', 'fi', 'data'])[\n",
    "    [\"auroc_train\", \"auroc_test\"]\n",
    "].agg(['mean', 'std', 'count']).reset_index()\n",
    "# Flatten the two-level column index, e.g. ('auroc_test', 'mean') -> 'auroc_test_mean'.\n",
    "agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]\n",
    "# Standard error of the mean = std / sqrt(number of runs).\n",
    "agg_df[\"auroc_train_sem\"] = agg_df[\"auroc_train_std\"] / np.sqrt(agg_df[\"auroc_train_count\"])\n",
    "agg_df[\"auroc_test_sem\"] = agg_df[\"auroc_test_std\"] / np.sqrt(agg_df[\"auroc_test_count\"])\n",
    "\n",
    "# Rows with sample_row_n == 800 are left out of the frame that is plotted.\n",
    "df = agg_df[agg_df[\"sample_row_n\"] != 800]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-AUROC curves on a grid: one row per dataset, one column per\n",
    "# label-corruption level.\n",
    "frac_label_corruption_all = df[\"frac_label_corruption\"].unique()\n",
    "marker_size = 7\n",
    "# These methods are drawn semi-transparent; any other method stays fully opaque.\n",
    "faded_methods = ['LIME', 'Treeshap', 'Local MDI']\n",
    "\n",
    "n_cols = len(frac_label_corruption_all)\n",
    "n_rows = len(datasets)\n",
    "\n",
    "# squeeze=False keeps `axs` two-dimensional even for a single row or column.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(8 * n_cols, 6.5 * n_rows),\n",
    "    sharey=False,\n",
    "    squeeze=False\n",
    ")\n",
    "\n",
    "for row_idx, dataset in enumerate(datasets):\n",
    "    for col_idx, frac_label_corruption in enumerate(frac_label_corruption_all):\n",
    "        ax = axs[row_idx, col_idx]\n",
    "        subset = df[(df[\"data\"] == dataset) & (df[\"frac_label_corruption\"] == frac_label_corruption)]\n",
    "\n",
    "        for method in methods:\n",
    "            method_data = subset[subset[\"fi\"] == method]\n",
    "            line_style = dict(\n",
    "                linestyle='solid', marker='o', markersize=marker_size,\n",
    "                label=methods_name[method], color=color_map[method], linewidth=4\n",
    "            )\n",
    "            if method in faded_methods:\n",
    "                line_style['alpha'] = 0.6\n",
    "            # Mean test AUROC over seeds, with the standard error as error bar.\n",
    "            ax.errorbar(\n",
    "                method_data[\"sample_row_n\"], method_data['auroc_test_mean'],\n",
    "                yerr=method_data[\"auroc_test_sem\"], **line_style\n",
    "            )\n",
    "\n",
    "        ax.set_xticks([300,500, 1000])\n",
    "        ax.set_xticklabels([\"300\", \"500\", \"1000\"], fontsize=25)\n",
    "        ax.tick_params(axis='y', labelsize=25)\n",
    "        if row_idx == n_rows - 1:\n",
    "            ax.set_xlabel(\"Sample Size\", fontsize=30)\n",
    "\n",
    "        if col_idx == 0:\n",
    "            p_val = feature_values[dataset]\n",
    "            # Spaces are written as backslash-space so they survive in math mode.\n",
    "            # Raw strings keep the backslashes literal without invalid escapes.\n",
    "            dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "            ax.set_ylabel(rf\"$\\mathbf{{{dataset_label}\\ (p={p_val})}}$\" + \"\\nAUROC\", fontsize=30)\n",
    "        else:\n",
    "            ax.set_ylabel(\"\")\n",
    "\n",
    "        if row_idx == 0:\n",
    "            # Round before converting: int() truncates, so a fraction such as 0.29\n",
    "            # (0.29 * 100 == 28.999999999999996) would otherwise be titled 28%.\n",
    "            corruption_pct = int(round(float(frac_label_corruption) * 100))\n",
    "            ax.set_title(rf\"$\\bf{{{corruption_pct}}} \\% \\ $\" + r\"$\\bf{Corrupted}$\", fontsize=30)\n",
    "\n",
    "        if col_idx == n_cols - 1:\n",
    "            ax.legend(fontsize=22, loc='lower right')\n",
    "\n",
    "# Leave the top 5% of the figure free for the overall title.\n",
    "fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "# NOTE: usetex=True needs a LaTeX installation on the machine rendering the figure.\n",
    "fig.suptitle(r\"\\textbf{Logistic Interaction}\", fontsize=50, usetex=True)\n",
    "fig.savefig(f\"feature_ranking_logistic_interaction_full_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-seed results of the logistic + LSS DGP and aggregate AUROC across seeds.\n",
    "dgp = \"logistic_linear_lss\"\n",
    "datasets = ['openml_361063', 'openml_361069', 'openml_361062', 'openml_9978', 'openml_361071', 'openml_43']\n",
    "feature_seeds = [1,2,3,4,5,6,7,8,9,10]\n",
    "sample_seeds = [1,2,3]\n",
    "\n",
    "# Collect the frames in a list and concatenate once: calling pd.concat inside\n",
    "# the loop copies every previously loaded row again on each iteration.\n",
    "seed_frames = []\n",
    "for data in datasets:\n",
    "    ablation_directory = f\"./results/mdi_local_{model}.real_data_classification_{data}_{dgp}/{data}_{dgp}_threshold_05/varying_frac_label_corruption_sample_row_n\"\n",
    "    for sample_seed in sample_seeds:\n",
    "        for feature_seed in feature_seeds:\n",
    "            try:\n",
    "                seed_df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{feature_seed}_{sample_seed}/results.csv\"))\n",
    "            except Exception as e:\n",
    "                # Best effort: report the run that could not be read and move on.\n",
    "                print(f\"Error reading {ablation_directory}/seed_{feature_seed}_{sample_seed}/results.csv: {e}\")\n",
    "                continue\n",
    "            seed_df[\"data\"] = data\n",
    "            seed_frames.append(seed_df)\n",
    "\n",
    "if not seed_frames:\n",
    "    raise FileNotFoundError(f\"No results.csv could be read for dgp={dgp!r} and model={model!r}\")\n",
    "combined_df = pd.concat(seed_frames, ignore_index=True)\n",
    "\n",
    "# Mean, standard deviation and count of train/test AUROC for every\n",
    "# (sample_row_n, frac_label_corruption, fi, data) combination.\n",
    "agg_df = combined_df.groupby(['sample_row_n', 'frac_label_corruption', 'fi', 'data'])[\n",
    "    [\"auroc_train\", \"auroc_test\"]\n",
    "].agg(['mean', 'std', 'count']).reset_index()\n",
    "# Flatten the two-level column index, e.g. ('auroc_test', 'mean') -> 'auroc_test_mean'.\n",
    "agg_df.columns = ['_'.join(col).strip('_') if col[1] else col[0] for col in agg_df.columns.values]\n",
    "# Standard error of the mean = std / sqrt(number of runs).\n",
    "agg_df[\"auroc_train_sem\"] = agg_df[\"auroc_train_std\"] / np.sqrt(agg_df[\"auroc_train_count\"])\n",
    "agg_df[\"auroc_test_sem\"] = agg_df[\"auroc_test_std\"] / np.sqrt(agg_df[\"auroc_test_count\"])\n",
    "\n",
    "# Rows with sample_row_n == 800 are left out of the frame that is plotted.\n",
    "df = agg_df[agg_df[\"sample_row_n\"] != 800]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-AUROC curves on a grid: one row per dataset, one column per\n",
    "# label-corruption level.\n",
    "frac_label_corruption_all = df[\"frac_label_corruption\"].unique()\n",
    "marker_size = 7\n",
    "# These methods are drawn semi-transparent; any other method stays fully opaque.\n",
    "faded_methods = ['LIME', 'Treeshap', 'Local MDI']\n",
    "\n",
    "n_cols = len(frac_label_corruption_all)\n",
    "n_rows = len(datasets)\n",
    "\n",
    "# squeeze=False keeps `axs` two-dimensional even for a single row or column.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(8 * n_cols, 6.5 * n_rows),\n",
    "    sharey=False,\n",
    "    squeeze=False\n",
    ")\n",
    "\n",
    "for row_idx, dataset in enumerate(datasets):\n",
    "    for col_idx, frac_label_corruption in enumerate(frac_label_corruption_all):\n",
    "        ax = axs[row_idx, col_idx]\n",
    "        subset = df[(df[\"data\"] == dataset) & (df[\"frac_label_corruption\"] == frac_label_corruption)]\n",
    "\n",
    "        for method in methods:\n",
    "            method_data = subset[subset[\"fi\"] == method]\n",
    "            line_style = dict(\n",
    "                linestyle='solid', marker='o', markersize=marker_size,\n",
    "                label=methods_name[method], color=color_map[method], linewidth=4\n",
    "            )\n",
    "            if method in faded_methods:\n",
    "                line_style['alpha'] = 0.6\n",
    "            # Mean test AUROC over seeds, with the standard error as error bar.\n",
    "            ax.errorbar(\n",
    "                method_data[\"sample_row_n\"], method_data['auroc_test_mean'],\n",
    "                yerr=method_data[\"auroc_test_sem\"], **line_style\n",
    "            )\n",
    "\n",
    "        ax.set_xticks([300,500, 1000])\n",
    "        ax.set_xticklabels([\"300\", \"500\", \"1000\"], fontsize=25)\n",
    "        ax.tick_params(axis='y', labelsize=25)\n",
    "        if row_idx == n_rows - 1:\n",
    "            ax.set_xlabel(\"Sample Size\", fontsize=30)\n",
    "\n",
    "        if col_idx == 0:\n",
    "            p_val = feature_values[dataset]\n",
    "            # Spaces are written as backslash-space so they survive in math mode.\n",
    "            # Raw strings keep the backslashes literal without invalid escapes.\n",
    "            dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "            ax.set_ylabel(rf\"$\\mathbf{{{dataset_label}\\ (p={p_val})}}$\" + \"\\nAUROC\", fontsize=30)\n",
    "        else:\n",
    "            ax.set_ylabel(\"\")\n",
    "\n",
    "        if row_idx == 0:\n",
    "            # Round before converting: int() truncates, so a fraction such as 0.29\n",
    "            # (0.29 * 100 == 28.999999999999996) would otherwise be titled 28%.\n",
    "            corruption_pct = int(round(float(frac_label_corruption) * 100))\n",
    "            ax.set_title(rf\"$\\bf{{{corruption_pct}}} \\% \\ $\" + r\"$\\bf{Corrupted}$\", fontsize=30)\n",
    "\n",
    "        if col_idx == n_cols - 1:\n",
    "            ax.legend(fontsize=22, loc='lower right')\n",
    "\n",
    "# Leave the top 5% of the figure free for the overall title.\n",
    "fig.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "# NOTE: usetex=True needs a LaTeX installation on the machine rendering the figure.\n",
    "fig.suptitle(r\"\\textbf{Logistic + LSS}\", fontsize=50, usetex=True)\n",
    "fig.savefig(f\"feature_ranking_logistic_linear_lss_full_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdi-new",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
