{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "899c1a16",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "# Passing None removes the cap on how many columns pandas shows when displaying a frame.\n",
    "pd.set_option('display.max_columns', None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa39d071",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib defaults for every figure drawn below: label/tick font sizes,\n",
    "# no top or right spine, and thick black axis lines.\n",
    "plt.rcParams.update({\n",
    "    'axes.labelsize': 30,\n",
    "    'xtick.labelsize': 10,\n",
    "    'ytick.labelsize': 12,\n",
    "    'axes.spines.right': False,\n",
    "    'axes.spines.top': False,\n",
    "    'axes.edgecolor': 'black',\n",
    "    'axes.linewidth': 2.0,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4aa9bc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model whose result files are read below; the two values used are \"rf\" and \"gradient_boosting\".\n",
    "model = \"rf\"\n",
    "\n",
    "# One entry per feature-importance method: identifier -> (legend label, line colour).\n",
    "# Insertion order is the order in which the methods are drawn and listed in the legend.\n",
    "_method_styles = {\n",
    "    'lmdi+': ('LMDI+', 'black'),\n",
    "    'LIME': ('LIME', '#71BEB7'),\n",
    "    'Treeshap': ('TreeSHAP', 'orange'),\n",
    "}\n",
    "\n",
    "# Local MDI is only added when the random-forest model is selected.\n",
    "if model == \"rf\":\n",
    "    _method_styles['Local MDI'] = ('Local MDI', '#9B5DFF')\n",
    "\n",
    "methods = list(_method_styles)\n",
    "methods_name = {key: style[0] for key, style in _method_styles.items()}\n",
    "color_map = {key: style[1] for key, style in _method_styles.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e795187d",
   "metadata": {},
   "outputs": [],
   "source": [
    "task = \"classification\"\n",
    "datasets = ['openml_361063', 'openml_361069', 'openml_361062', 'openml_9978', 'openml_361071', 'openml_43']\n",
    "split_seeds = [1, 2, 3]\n",
    "sample_seeds = [1, 2, 3, 4, 5]\n",
    "\n",
    "# Read one results.csv per (dataset, split seed, sample seed) and tag its rows with the dataset id.\n",
    "# The frames are collected in a list and concatenated once at the end: calling pd.concat\n",
    "# inside the loop would re-copy every previously loaded row on each iteration.\n",
    "frames = []\n",
    "for data in datasets:\n",
    "    ablation_directory = f\"./results/mdi_local_{model}.real_data_{task}_{data}_stability/{data}_stability/varying_sample_row_n\"\n",
    "    for split_seed in split_seeds:\n",
    "        for sample_seed in sample_seeds:\n",
    "            df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{split_seed}_{sample_seed}/results.csv\"))\n",
    "            df[\"data\"] = data\n",
    "            frames.append(df)\n",
    "\n",
    "combined_df = pd.concat(frames, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0336956f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per classification dataset: (OpenML id, display name, value stored in feature_values).\n",
    "# The third field is what the plotting cell divides by and prints as p in the row label.\n",
    "_dataset_info = [\n",
    "    (\"openml_43\", \"Spam\", 57),\n",
    "    (\"openml_361062\", \"Pol\", 26),\n",
    "    (\"openml_361071\", \"Jannis\", 54),\n",
    "    (\"openml_9978\", \"Ozone\", 47),\n",
    "    (\"openml_361069\", \"Higgs\", 24),\n",
    "    (\"openml_361063\", \"House 16H\", 16),\n",
    "]\n",
    "\n",
    "data_name = {row[0]: row[1] for row in _dataset_info}\n",
    "feature_values = {row[0]: row[2] for row in _dataset_info}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "237b53af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of panels: one row per dataset, one column per sample_row_n value.\n",
    "# Each panel shows, per method, the mean of the avg_*_features columns divided by the\n",
    "# dataset's feature_values entry, with standard-error bars.\n",
    "df = combined_df\n",
    "datasets = df[\"data\"].unique()\n",
    "sample_size_all = df[\"sample_row_n\"].unique()\n",
    "marker_size = 7\n",
    "\n",
    "n_cols = len(sample_size_all)\n",
    "n_rows = len(datasets)\n",
    "\n",
    "# squeeze=False makes plt.subplots always return a 2-D array of Axes,\n",
    "# so axs[row_idx, col_idx] is valid even with a single row and/or column.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(8 * n_cols, 5 * n_rows),\n",
    "    sharey=False,\n",
    "    squeeze=False,\n",
    ")\n",
    "\n",
    "# Loop-invariant inputs: x positions and the result column plotted at each position.\n",
    "x = [0.1, 0.2, 0.3, 0.4]\n",
    "cols = [\"avg_10_features\", \"avg_20_features\", \"avg_30_features\", \"avg_40_features\"]\n",
    "# These methods are drawn semi-transparent; any other method is drawn fully opaque.\n",
    "faded_methods = {'LIME', 'Treeshap', 'Local MDI'}\n",
    "\n",
    "for row_idx, dataset in enumerate(datasets):\n",
    "    n_features = feature_values[dataset]\n",
    "    for col_idx, sample_size in enumerate(sample_size_all):\n",
    "        ax = axs[row_idx, col_idx]\n",
    "        subset = df[(df[\"data\"] == dataset) & (df[\"sample_row_n\"] == sample_size)]\n",
    "\n",
    "        for method in methods:\n",
    "            method_data = subset[subset[\"fi\"] == method]\n",
    "\n",
    "            # Normalise by the dataset's feature count, then mean and standard error per column.\n",
    "            normalized_data = method_data[cols] / n_features\n",
    "            means = normalized_data.mean(axis=0).values\n",
    "            stds = normalized_data.std(axis=0).values\n",
    "            counts = normalized_data.count(axis=0).values\n",
    "            sems = stds / np.sqrt(counts)\n",
    "\n",
    "            line_style = dict(\n",
    "                linestyle='solid', marker='o', markersize=marker_size,\n",
    "                label=methods_name[method], color=color_map[method], linewidth=4,\n",
    "            )\n",
    "            if method in faded_methods:\n",
    "                line_style['alpha'] = 0.6\n",
    "            ax.errorbar(x, means, sems, **line_style)\n",
    "\n",
    "        ax.set_xticks(x)\n",
    "        ax.set_xticklabels([\"10%\", \"20%\", \"30%\", \"40%\"], fontsize=22)\n",
    "        ax.tick_params(axis='y', labelsize=22)\n",
    "        if row_idx == n_rows - 1:\n",
    "            ax.set_xlabel(\"Top-k% Most Important Features\", fontsize=25)\n",
    "\n",
    "        if col_idx == 0:\n",
    "            # Spaces need a mathtext escape inside mathbf; the backslashes in the f-string are\n",
    "            # doubled so Python does not see an invalid escape sequence.\n",
    "            dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "            p_val = feature_values[dataset]\n",
    "            ax.set_ylabel(f\"$\\\\mathbf{{{dataset_label}\\\\ (p={p_val})}}$\\nInstability Score (↓)\", fontsize=25)\n",
    "        else:\n",
    "            ax.set_ylabel(\"\")\n",
    "\n",
    "        if row_idx == 0:\n",
    "            ax.set_title(rf\"$N = {sample_size}$\", fontsize=35)\n",
    "\n",
    "        # Only the right-most panel of each row carries a legend.\n",
    "        if col_idx == n_cols - 1:\n",
    "            ax.legend(fontsize=22, loc='lower right')\n",
    "\n",
    "plt.tight_layout(rect=[0.01, 0, 1, 1])\n",
    "plt.savefig(f\"stability_classification_full_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "252da903",
   "metadata": {},
   "outputs": [],
   "source": [
    "task = \"regression\"\n",
    "datasets = ['openml_361260', 'openml_361254', 'openml_361259', 'openml_361253', 'openml_361243', 'openml_361242']\n",
    "# Seed lists are defined once here instead of being rebuilt for every dataset.\n",
    "split_seeds = [1, 2, 3]\n",
    "sample_seeds = [1, 2, 3, 4, 5]\n",
    "\n",
    "# Read one results.csv per (dataset, split seed, sample seed) and tag its rows with the dataset id.\n",
    "# The frames are collected in a list and concatenated once at the end: calling pd.concat\n",
    "# inside the loop would re-copy every previously loaded row on each iteration.\n",
    "frames = []\n",
    "for data in datasets:\n",
    "    ablation_directory = f\"./results/mdi_local_{model}.real_data_{task}_{data}_stability/{data}_stability/varying_sample_row_n\"\n",
    "    for split_seed in split_seeds:\n",
    "        for sample_seed in sample_seeds:\n",
    "            df = pd.read_csv(os.path.join(ablation_directory, f\"seed_{split_seed}_{sample_seed}/results.csv\"))\n",
    "            df[\"data\"] = data\n",
    "            frames.append(df)\n",
    "\n",
    "combined_df = pd.concat(frames, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d5e7520",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per regression dataset: (OpenML id, display name, value stored in feature_values).\n",
    "# The third field is what the plotting cell divides by and prints as p in the row label.\n",
    "_dataset_info = [\n",
    "    (\"openml_361260\", \"Miami Housing\", 15),\n",
    "    (\"openml_361259\", \"Puma Robot\", 32),\n",
    "    (\"openml_361253\", \"Wave Energy\", 48),\n",
    "    (\"openml_361254\", \"SARCOS\", 21),\n",
    "    (\"openml_361242\", \"Super Conductivity\", 81),\n",
    "    (\"openml_361243\", \"Geographic Origin of Music\", 72),\n",
    "]\n",
    "\n",
    "data_name = {row[0]: row[1] for row in _dataset_info}\n",
    "feature_values = {row[0]: row[2] for row in _dataset_info}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a4435c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of panels: one row per dataset, one column per sample_row_n value.\n",
    "# Each panel shows, per method, the mean of the avg_*_features columns divided by the\n",
    "# dataset's feature_values entry, with standard-error bars.\n",
    "df = combined_df\n",
    "datasets = df[\"data\"].unique()\n",
    "sample_size_all = df[\"sample_row_n\"].unique()\n",
    "marker_size = 7\n",
    "\n",
    "n_cols = len(sample_size_all)\n",
    "n_rows = len(datasets)\n",
    "\n",
    "# squeeze=False makes plt.subplots always return a 2-D array of Axes,\n",
    "# so axs[row_idx, col_idx] is valid even with a single row and/or column.\n",
    "fig, axs = plt.subplots(\n",
    "    nrows=n_rows,\n",
    "    ncols=n_cols,\n",
    "    figsize=(8 * n_cols, 5 * n_rows),\n",
    "    sharey=False,\n",
    "    squeeze=False,\n",
    ")\n",
    "\n",
    "# Loop-invariant inputs: x positions and the result column plotted at each position.\n",
    "x = [0.1, 0.2, 0.3, 0.4]\n",
    "cols = [\"avg_10_features\", \"avg_20_features\", \"avg_30_features\", \"avg_40_features\"]\n",
    "# These methods are drawn semi-transparent; any other method is drawn fully opaque.\n",
    "faded_methods = {'LIME', 'Treeshap', 'Local MDI'}\n",
    "\n",
    "for row_idx, dataset in enumerate(datasets):\n",
    "    n_features = feature_values[dataset]\n",
    "    for col_idx, sample_size in enumerate(sample_size_all):\n",
    "        ax = axs[row_idx, col_idx]\n",
    "        subset = df[(df[\"data\"] == dataset) & (df[\"sample_row_n\"] == sample_size)]\n",
    "\n",
    "        for method in methods:\n",
    "            method_data = subset[subset[\"fi\"] == method]\n",
    "\n",
    "            # Normalise by the dataset's feature count, then mean and standard error per column.\n",
    "            normalized_data = method_data[cols] / n_features\n",
    "            means = normalized_data.mean(axis=0).values\n",
    "            stds = normalized_data.std(axis=0).values\n",
    "            counts = normalized_data.count(axis=0).values\n",
    "            sems = stds / np.sqrt(counts)\n",
    "\n",
    "            line_style = dict(\n",
    "                linestyle='solid', marker='o', markersize=marker_size,\n",
    "                label=methods_name[method], color=color_map[method], linewidth=4,\n",
    "            )\n",
    "            if method in faded_methods:\n",
    "                line_style['alpha'] = 0.6\n",
    "            ax.errorbar(x, means, sems, **line_style)\n",
    "\n",
    "        ax.set_xticks(x)\n",
    "        ax.set_xticklabels([\"10%\", \"20%\", \"30%\", \"40%\"], fontsize=22)\n",
    "        ax.tick_params(axis='y', labelsize=22)\n",
    "        if row_idx == n_rows - 1:\n",
    "            ax.set_xlabel(\"Top-k% Most Important Features\", fontsize=25)\n",
    "\n",
    "        if col_idx == 0:\n",
    "            p_val = feature_values[dataset]\n",
    "            # Backslashes in the f-strings are doubled so Python does not see an invalid\n",
    "            # escape sequence; the text handed to matplotlib is unchanged.\n",
    "            if dataset == \"openml_361243\":\n",
    "                # This dataset's bold name is split over two lines (presumably because it is\n",
    "                # too wide for one line at this font size).\n",
    "                bold_label = f\"$\\\\mathbf{{Geographic\\\\ Origin\\\\ of}}$\\n$\\\\mathbf{{Music\\\\ (p={p_val})}}$\"\n",
    "            else:\n",
    "                dataset_label = data_name[dataset].replace(' ', r'\\ ')\n",
    "                bold_label = f\"$\\\\mathbf{{{dataset_label}\\\\ (p={p_val})}}$\"\n",
    "            # Every row plots the same normalised quantity, so every row gets the same metric label.\n",
    "            ax.set_ylabel(f\"{bold_label}\\nInstability Score (↓)\", fontsize=25)\n",
    "        else:\n",
    "            ax.set_ylabel(\"\")\n",
    "\n",
    "        if row_idx == 0:\n",
    "            ax.set_title(rf\"$N = {sample_size}$\", fontsize=35)\n",
    "\n",
    "        # Only the right-most panel of each row carries a legend.\n",
    "        if col_idx == n_cols - 1:\n",
    "            ax.legend(fontsize=22, loc='lower right')\n",
    "\n",
    "plt.tight_layout(rect=[0.01, 0, 1, 1])\n",
    "plt.savefig(f\"stability_regression_full_{model}.pdf\", format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mdi-new",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
