{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c302d92b",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Grid of error-bar plots: one subplot row per results CSV in fnames, one column per metric.\n",
    "# Every subplot shows, for each method, the mean as a thick horizontal marker with a +/- std bar.\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# One results CSV per subplot row, in the same order as row_titles.\n",
    "fnames = [\"...\"]\n",
    "num_datasets = 8\n",
    "metric_names = [\n",
    "    'Validity', 'Proximity', 'Plausibility', 'Diversity',\n",
    "    'M. Robustness', 'In. Robustness', 'Time'\n",
    "]\n",
    "num_metrics = len(metric_names)\n",
    "\n",
    "# names: order of the method rows inside each CSV; names_sorted: left-to-right plotting order.\n",
    "names = ['dice', 'face', 'icce', 'nnce', 'ours-first', 'ours-last', 'ours-middle', 'stce']\n",
    "names_sorted = ['nnce', 'face', 'stce', 'dice', 'icce', 'ours-first', 'ours-middle', 'ours-last']\n",
    "num_methods = len(names_sorted)\n",
    "\n",
    "# Per-metric arrow appended to the column title, and per-metric y-axis scale.\n",
    "directions = ['$\\\\uparrow$', '$\\\\downarrow$', '$\\\\downarrow$', '$\\\\uparrow$', '$\\\\uparrow$', '$\\\\downarrow$', '$\\\\downarrow$']\n",
    "scales = ['linear', 'linear', 'linear', 'linear', 'linear', 'linear', 'linear']\n",
    "\n",
    "row_titles = [\n",
    "    'heloc\\nRF', 'heloc\\nNN', 'wine\\nRF', 'wine\\nNN',\n",
    "    'adult\\nRF', 'adult\\nNN', 'compas\\nRF', 'compas\\nNN',\n",
    "]\n",
    "\n",
    "# One colour per method, indexed by position in names_sorted.\n",
    "colors = [\n",
    "    '#648fff', '#708090', '#ffb000', '#fe6100', '#00b894', '#d6bcfb', '#9d76db', '#6d45b4',\n",
    "]\n",
    "\n",
    "# Fail early with a clear message instead of an IndexError part-way through plotting.\n",
    "if len(fnames) != num_datasets:\n",
    "    raise ValueError(f'Expected {num_datasets} result files in fnames, got {len(fnames)}.')\n",
    "\n",
    "fig, axes = plt.subplots(num_datasets, num_metrics, figsize=(18, 8), sharey=False)\n",
    "\n",
    "# Row index of every plotted method within the CSV data, and the shared x slots for the markers.\n",
    "method_rows = [names.index(method) for method in names_sorted]\n",
    "x_positions = np.linspace(0.15, 0.85, num_methods)\n",
    "\n",
    "for row_index in range(num_datasets):\n",
    "    results = pd.read_csv(fnames[row_index])\n",
    "    # The first row after the header is dropped (presumably a sub-header; verify against the CSV layout).\n",
    "    # float64 instead of float16: half precision keeps only about 3 significant digits and\n",
    "    # overflows to inf above 65504, which would distort the plot and make set_ylim fail.\n",
    "    results_values = results.values[1:].astype(np.float64)\n",
    "    results_sorted = results_values[method_rows]\n",
    "\n",
    "    for col_index, metric in enumerate(metric_names):\n",
    "        ax = axes[row_index, col_index]\n",
    "        # Columns come in pairs: 2 * k holds the mean and 2 * k + 1 the std of metric k.\n",
    "        means = results_sorted[:, 2 * col_index]\n",
    "        stds = np.abs(results_sorted[:, 2 * col_index + 1])\n",
    "\n",
    "        for j in range(num_methods):\n",
    "            color = colors[j]\n",
    "            ax.errorbar(x=x_positions[j], y=means[j], yerr=stds[j], fmt='none', color=color,\n",
    "                        ecolor=color, elinewidth=1, capsize=3)\n",
    "            # The marker is drawn in a separate call so that its edge width does not\n",
    "            # also set the thickness of the error-bar caps.\n",
    "            ax.errorbar(x=x_positions[j], y=means[j], yerr=None, color=color,\n",
    "                        marker='_', markersize=14, markeredgewidth=5.5)\n",
    "\n",
    "        # NaN entries are ignored when deriving the limits (same as the pandas min/max used before).\n",
    "        min_val = np.nanmin(means - stds)\n",
    "        max_val = np.nanmax(means + stds)\n",
    "\n",
    "        plot_range = max_val - min_val\n",
    "        # Handle cases where all values are the same to avoid a zero range\n",
    "        if plot_range == 0:\n",
    "            plot_range = abs(max_val) * 0.1 if max_val != 0 else 0.1\n",
    "\n",
    "        padding = plot_range * 0.10\n",
    "\n",
    "        # Set the scale before the limits so the limits are applied to the final scale.\n",
    "        ax.set_yscale(scales[col_index])\n",
    "        ax.set_ylim(min_val - padding, max_val + padding)\n",
    "        ax.set_xticks([])\n",
    "        ax.spines[['bottom', 'right', 'top']].set_visible(False)\n",
    "        ax.set_facecolor(\"#EEE8E281\")\n",
    "        ax.tick_params(axis='y', which='major', labelsize=10)\n",
    "\n",
    "        # Only the top row carries the metric name and its direction arrow.\n",
    "        if row_index == 0:\n",
    "            ax.set_title(f'{metric}{directions[col_index]}', fontsize=14, pad=15)\n",
    "\n",
    "    # Row label placed to the left of the first subplot of the row, in axes coordinates.\n",
    "    axes[row_index, 0].text(-0.28, 0.5, row_titles[row_index], transform=axes[row_index, 0].transAxes,\n",
    "                            ha='right', va='center', fontsize=15)\n",
    "\n",
    "plt.subplots_adjust(left=0.2, right=0.98, top=0.95, bottom=0.05, hspace=0.1, wspace=0.25)\n",
    "plt.savefig(\"visualisation.png\", dpi=300)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
