{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "All results processing functions are put together in this script, including the metric processing for multiple runs, prediction results plot, training loss plot, etc.\n",
    "'''\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import matthews_corrcoef as mcc_fn\n",
    "from sklearn.metrics import accuracy_score as acc_fn\n",
    "\n",
    "exp_dir = os.path.abspath(\"\")\n",
    "print(exp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process the metric from multiple runs and save the mean+std results\n",
    "\n",
    "def compute_metrics(preds, labels, hr_threshold, num_classes=7):\n",
    "    # Compute runtime (rt) and halt rate (hr) based on different halt rate thresholds.\n",
    "    if labels.shape[1] > 7:\n",
    "        runtime = labels[:, 2:-1]\n",
    "    else:\n",
    "        runtime = labels\n",
    "\n",
    "    if num_classes == 5:   # Remove kissat3 and bulky\n",
    "        runtime = torch.concat([runtime[:, 1:5], runtime[:, 6:]], dim=-1)\n",
    "\n",
    "    base_solver_idx = runtime.mean(dim=0).argmin()\n",
    "    # print(f'Base solver idx: {base_solver_idx}')\n",
    "\n",
    "    base_time = runtime[:, base_solver_idx]\n",
    "    min_time, _ = runtime.min(dim=1) \n",
    "\n",
    "    len_data = preds.shape[0]\n",
    "\n",
    "    base_hr_cnt = (base_time > hr_threshold).sum().float()\n",
    "    min_hr_cnt = (min_time > hr_threshold).sum().float()\n",
    "    hr_min = 100. * (base_hr_cnt - min_hr_cnt) / base_hr_cnt \n",
    "\n",
    "    pred_top1 = preds.argmax(dim=1) # Predicted top1 solver\n",
    "    pred_time_top1 = runtime[torch.arange(len_data), pred_top1] \n",
    "\n",
    "    # Compute the better time among predicted top2 solvers\n",
    "    _, pred_idx_top2 = torch.topk(preds, 2, dim=1)\n",
    "    tmp = torch.zeros((len_data, 2))\n",
    "    tmp[:, 0] = runtime[torch.arange(len_data), pred_idx_top2[:, 0]]\n",
    "    tmp[:, 1] = runtime[torch.arange(len_data), pred_idx_top2[:, 1]]\n",
    "    pred_time_top2, _ = tmp.min(dim=1)\n",
    "\n",
    "    hr_cnt_top1 = (pred_time_top1 > hr_threshold).sum().float()\n",
    "    hr_top1 = 100. * (base_hr_cnt - hr_cnt_top1) / base_hr_cnt\n",
    "    hr_cnt_top2 = (pred_time_top2 > hr_threshold).sum().float()\n",
    "    hr_top2 = 100. * (base_hr_cnt - hr_cnt_top2) / base_hr_cnt\n",
    "\n",
    "    rt_top1 = pred_time_top1.mean()\n",
    "    rt_top2 = pred_time_top2.mean()\n",
    "    rt_min = min_time.mean()\n",
    "    rt_base = base_time.mean()\n",
    "\n",
    "    # Compute mcc and acc\n",
    "    tgt_idx = runtime.argmin(dim=1)\n",
    "    mcc = mcc_fn(tgt_idx, pred_top1)\n",
    "    acc = acc_fn(tgt_idx, pred_top1)\n",
    "\n",
    "    # Compute the ratio of proved SAT instance for the given cutoff threshold\n",
    "    ratio_proved_top1 = (pred_time_top1 <= hr_threshold).sum().float() / len_data\n",
    "    ratio_proved_top2 = (pred_time_top2 <= hr_threshold).sum().float() / len_data\n",
    "\n",
    "    return {'hr_min': hr_min, 'hr_top1': hr_top1, 'hr_top2': hr_top2, 'base_solver_idx': base_solver_idx,\n",
    "            'rt_base': rt_base, 'rt_min': rt_min, 'rt_top1': rt_top1, 'rt_top2': rt_top2, \n",
    "            'ratio_proved_top1': ratio_proved_top1, 'ratio_proved_top2': ratio_proved_top2, 'mcc': mcc, 'acc': acc}\n",
    "\n",
    "def process_metrics(num_classes=7):\n",
    "    hr_options = [100, 200, 300, 400, 500]\n",
    "    data = []\n",
    "    for split_idx in range(5):\n",
    "        run_dir = os.path.join(exp_dir, 'seed_604_split_'+str(split_idx))\n",
    "        test_label_file = os.path.join(run_dir, 'test_labels.csv')\n",
    "        test_pred_file = os.path.join(run_dir, 'test_pred_probs.csv')\n",
    "\n",
    "        if not os.path.isfile(test_pred_file):\n",
    "            continue\n",
    "\n",
    "        labels = pd.read_csv(test_label_file).to_numpy()\n",
    "        preds = pd.read_csv(test_pred_file).to_numpy()\n",
    "\n",
    "        labels = torch.tensor(labels)\n",
    "        preds = torch.tensor(preds)\n",
    "\n",
    "        test_metrics_list = []\n",
    "        for hr_threshold in hr_options:\n",
    "            metrics = compute_metrics(preds, labels, hr_threshold, num_classes)\n",
    "            test_metrics_list.append(metrics)\n",
    "\n",
    "        test_metrics = pd.DataFrame.from_records(test_metrics_list, index=hr_options).astype(float)\n",
    "        save_path = os.path.join(run_dir, 'test_metrics_v2.csv')\n",
    "        test_metrics.to_csv(save_path)\n",
    "\n",
    "        data.append(test_metrics)\n",
    "\n",
    "    data_arr = np.dstack([d.to_numpy() for d in data])\n",
    "\n",
    "    avg = data_arr.mean(axis=2).round(3)\n",
    "    std = data_arr.std(axis=2).round(3)\n",
    "\n",
    "    new_df = pd.DataFrame(columns=data[0].columns)\n",
    "    \n",
    "    for j, col in enumerate(new_df.columns):\n",
    "        dat = []\n",
    "        for i in range(avg.shape[0]):\n",
    "            avg_std = str(avg[i, j]) + str(u\"\\u00B1\") + str(std[i, j])\n",
    "            dat.append(avg_std)\n",
    "        new_df[col] = dat\n",
    "    \n",
    "    exp_name = exp_dir.split('/')[-1]\n",
    "    save_path = os.path.join(exp_dir, f'{exp_name}_all_runs.csv')\n",
    "    new_df.to_csv(save_path, encoding='utf-8-sig')\n",
    "    print(new_df)\n",
    "\n",
    "num_classes=7\n",
    "process_metrics(num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot prediction vs labels\n",
    "\n",
    "for split_idx in range(5):\n",
    "    run_dir = os.path.join(exp_dir, 'seed_604_split_'+str(split_idx))\n",
    "    pred_filepath = os.path.join(run_dir, 'test_pred_probs.csv')\n",
    "    label_filepath = os.path.join(run_dir, 'test_labels.csv')\n",
    "\n",
    "    if not os.path.isfile(pred_filepath):\n",
    "        continue\n",
    "\n",
    "    preds = pd.read_csv(pred_filepath, index_col=False).to_numpy()\n",
    "    # Note that the labels saved in sat_selection_light has variables other than runtime\n",
    "    labels = pd.read_csv(label_filepath, index_col=False).to_numpy()\n",
    "    if labels.shape[1] == 10:\n",
    "        labels = labels[:, 2:-1]\n",
    "\n",
    "    if num_classes == 5:   # Remove kissat3 and bulky\n",
    "        labels = np.concatenate([labels[:, 1:5], labels[:, 6:]], axis=-1)\n",
    "\n",
    "    pred_idx =  preds.argmax(axis=1)\n",
    "    tgt_idx = labels.argmin(axis=1)\n",
    "\n",
    "    plt.figure()\n",
    "\n",
    "    precision = []\n",
    "    recall = []\n",
    "    for i in range(7):\n",
    "        _idx = (tgt_idx==i)\n",
    "        TP = (pred_idx[_idx]==i).sum() \n",
    "        pred_cnt = (pred_idx==i).sum()\n",
    "        tgt_cnt = _idx.sum()\n",
    "        if pred_cnt == 0:\n",
    "            p = 0\n",
    "        else:\n",
    "            p = TP/pred_cnt\n",
    "        precision.append(p)\n",
    "        recall.append(TP/tgt_cnt)\n",
    "\n",
    "    tgts, tgts_cnt = np.unique(tgt_idx, return_counts=True)\n",
    "    preds, preds_cnt = np.unique(pred_idx, return_counts=True)\n",
    "    preds_cnt_padded = np.zeros_like(tgts_cnt)   # Pad the zero predicted solver with zeros\n",
    "    preds_cnt_padded[preds] = preds_cnt\n",
    "\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    bar_preds = ax.bar(tgts-0.15, preds_cnt_padded, width=0.30, color='tab:grey', edgecolor='black', alpha=0.6, label='predictions')\n",
    "    bar_tgts = ax.bar(tgts+0.15, tgts_cnt, width=0.30, color='tab:blue', edgecolor='black', alpha=0.6, label='targets')\n",
    "    # Add precision and recall values above the bar\n",
    "    for i, rect in enumerate(bar_preds):\n",
    "        x = rect.get_x() + rect.get_width()/2\n",
    "        y = rect.get_height() + 5\n",
    "        p_r = f\"p:{precision[i]:.2f}\\nr:{recall[i]:.2f}\"\n",
    "        ax.text(x, y, p_r, ha='center', va='bottom')\n",
    "    # ax.bar_label(bar_preds, fmt='%.2g', padding=2)\n",
    "    # ax.bar_label(bar_tgts, fmt='%.2g', padding=2)\n",
    "\n",
    "    plt.title(f\"Top1 Prediction vs Targets - split_{str(split_idx)}\")\n",
    "    plt.legend(loc='best')\n",
    "    plt.xlabel('Solver Index')\n",
    "    plt.ylabel('Count')\n",
    "    plot_savepath = pred_filepath.split('.')[0]+'.png' \n",
    "    plt.savefig(plot_savepath)\n",
    "    plt.show()\n",
    "    # plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot training loss\n",
    "\n",
    "rolling_window_size = 100\n",
    "for split_idx in range(5):\n",
    "    run_dir = os.path.join(exp_dir, 'seed_604_split_'+str(split_idx))\n",
    "    step_metric_file = os.path.join(run_dir, 'step_metrics.csv')\n",
    "    epoch_metric_file = os.path.join(run_dir, 'epoch_metrics.csv')\n",
    "    if not os.path.isfile(step_metric_file):\n",
    "        continue\n",
    "    \n",
    "    step_metric = pd.read_csv(step_metric_file, index_col=0)\n",
    "    epoch_metric = pd.read_csv(epoch_metric_file, index_col=0)\n",
    "\n",
    "    total_steps = step_metric.shape[0]\n",
    "    total_epoch = epoch_metric.shape[0]\n",
    "\n",
    "    val_loss_step = np.arange(total_epoch) * int(total_steps/total_epoch)\n",
    "    val_loss = epoch_metric['val_loss'].values\n",
    "\n",
    "    fig = plt.figure()\n",
    "    plt.plot(step_metric.rolling(rolling_window_size).mean()['train_loss'], label='train_loss')\n",
    "    plt.plot(val_loss_step, val_loss, color='red', label='val_loss')\n",
    "    plt.legend(loc='best')\n",
    "    plt.title(f'Train / val loss in training - split_{split_idx}')\n",
    "    plot_savepath = os.path.join(run_dir, 'step_loss.png')\n",
    "    plt.savefig(plot_savepath)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the cost of not predicting the best solver\n",
    "\n",
    "for split_idx in range(5):\n",
    "    run_dir = os.path.join(exp_dir, 'seed_604_split_'+str(split_idx))\n",
    "    pred_filepath = os.path.join(run_dir, 'test_pred_probs.csv')\n",
    "    label_filepath = os.path.join(run_dir, 'test_labels.csv')\n",
    "\n",
    "    if not os.path.isfile(pred_filepath):\n",
    "        continue\n",
    "\n",
    "    preds = pd.read_csv(pred_filepath, index_col=False).to_numpy()\n",
    "    # Note that the labels saved in sat_selection_light has variables other than runtime\n",
    "    labels = pd.read_csv(label_filepath, index_col=False).to_numpy()\n",
    "    if labels.shape[1] == 10:\n",
    "        runtime = labels[:, 2:-1]\n",
    "\n",
    "    if num_classes == 5:   # Remove kissat3 and bulky\n",
    "        runtime = np.concatenate([labels[:, 1:5], labels[:, 6:]], axis=-1)\n",
    "\n",
    "    pred_idx =  preds.argmax(axis=1)\n",
    "    tgt_idx = runtime.argmin(axis=1)\n",
    "\n",
    "    len_data = preds.shape[0]\n",
    "    pred_rt = runtime[torch.arange(len_data), pred_idx] \n",
    "    tgt_rt = runtime.min(axis=1)\n",
    "\n",
    "    rt_delta = pred_rt - tgt_rt\n",
    "\n",
    "    rt_delta_cost = rt_delta[rt_delta>0]\n",
    "    cost_mean = rt_delta_cost.mean()\n",
    "    cost_std = rt_delta_cost.std()\n",
    "    print(rt_delta_cost)\n",
    "\n",
    "    plt.figure()\n",
    "\n",
    "    # fig, axes = plt.subplots()\n",
    "    ax = sns.histplot(rt_delta_cost)\n",
    "    ax.text(1000, 1000, f\"mean={cost_mean:.3f}\\nstd={cost_std:.3f}\")\n",
    "\n",
    "    plt.title(f\"Cost of predicting wrong - split_{str(split_idx)}\")\n",
    "    plt.xlabel('Diff to min runtime')\n",
    "    plt.ylabel('Count')\n",
    "    # plot_savepath = pred_filepath.split('.')[0]+'.png' \n",
    "    # plt.savefig(plot_savepath)\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot runtime distribution of solvers\n",
    "\n",
    "for split_idx in range(5):\n",
    "    run_dir = os.path.join(exp_dir, 'seed_604_split_'+str(split_idx))\n",
    "    pred_filepath = os.path.join(run_dir, 'test_pred_probs.csv')\n",
    "    label_filepath = os.path.join(run_dir, 'test_labels.csv')\n",
    "\n",
    "    if not os.path.isfile(pred_filepath):\n",
    "        continue\n",
    "\n",
    "    # preds = pd.read_csv(pred_filepath, index_col=False).to_numpy()\n",
    "    # Note that the labels saved in sat_selection_light has variables other than runtime\n",
    "    labels = pd.read_csv(label_filepath, index_col=False)\n",
    "    \n",
    "    if labels.shape[1] == 10:\n",
    "        col_names = ['#var', '#clause', 'base', 'HyWalk', 'MOSS', 'mabgb', 'ESA', 'bulky', 'UCB', 'MIN']\n",
    "        labels.columns = col_names\n",
    "        runtime = labels.iloc[:, 2:-1]\n",
    "    \n",
    "    ax = sns.boxplot(runtime)\n",
    "\n",
    "    means = runtime.mean(axis=0) \n",
    "    stds = runtime.std(axis=0) \n",
    "    for xtick in ax.get_xticks():\n",
    "        ax.text(xtick, means[xtick],f\"{means[xtick]:.1f}({stds[xtick]:.1f})\", \n",
    "                horizontalalignment='center',size='x-small',color='k')\n",
    "\n",
    "    plt.title(f\"Runtime by each solver\")\n",
    "    plt.ylabel('Runtime (s)')\n",
    "    # plot_savepath = pred_filepath.split('.')[0]+'.png' \n",
    "    # plt.savefig(plot_savepath)\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "a78e89a48c699a6ba54266b96068dd7b0b69dfec9a3e27db3a261c177d599c1b"
  },
  "kernelspec": {
   "display_name": "Python 3.8.2 64-bit ('sat': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
