{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed6cfeb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "import re\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from eval import get_run_metrics, read_run_dir, get_model_from_run\n",
    "from plot_utils import basic_plot, collect_results, relevant_model_names\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload  \n",
    "%autoreload 2\n",
    "\n",
    "run_dir = \"../models\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e8d018b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "df = read_run_dir(run_dir)\n",
    "df  # list all the runs in our run_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfbbcd1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# task = \"combination_regression\"\n",
    "# task = \"linear_regression\"\n",
    "# task = \"sparse_linear_regression\"\n",
    "#task = \"decision_tree\"\n",
    "task = \"relu_2nn_regression\"\n",
    "\n",
    "run_id = \"0827192358\"\n",
    "# run_id = \"0823174809\"  # if you train more models, replace with the run_id from the table above\n",
    "# run_id = \"0825202036\"\n",
    "\n",
    "run_path = os.path.join(run_dir, task, run_id)\n",
    "recompute_metrics = False\n",
    "\n",
    "if recompute_metrics:\n",
    "    get_run_metrics(run_path)  # these are normally precomputed at the end of training\n",
    "\n",
    "def valid_row(r):\n",
    "    return r.task == task and r.run_id == run_id\n",
    "\n",
    "metrics = collect_results(run_dir, df, valid_row=valid_row)\n",
    "_, conf = get_model_from_run(run_path, only_conf=True)\n",
    "n_dims = conf.model.n_dims\n",
    "\n",
    "models = relevant_model_names[task]\n",
    "print(metrics[\"standard\"].keys())\n",
    "basic_plot(metrics[\"standard\"], models=models)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9980951",
   "metadata": {},
   "outputs": [],
   "source": [
    "task = \"combination_regression\"\n",
    "# task = \"linear_regression\"\n",
    "# task = \"sparse_linear_regression\"\n",
    "#task = \"decision_tree\"\n",
    "#task = \"relu_2nn_regression\"\n",
    "\n",
    "run_id = \"0827190756\"\n",
    "# run_id = \"0823174809\"  # if you train more models, replace with the run_id from the table above\n",
    "# run_id = \"0825202036\"\n",
    "\n",
    "run_path = os.path.join(run_dir, task, run_id)\n",
    "recompute_metrics = False\n",
    "\n",
    "if recompute_metrics:\n",
    "    get_run_metrics(run_path)  # these are normally precomputed at the end of training\n",
    "\n",
    "def valid_row(r):\n",
    "    return r.task == task and r.run_id == run_id\n",
    "\n",
    "metrics = collect_results(run_dir, df, valid_row=valid_row)\n",
    "_, conf = get_model_from_run(run_path, only_conf=True)\n",
    "n_dims = conf.model.n_dims\n",
    "\n",
    "models = relevant_model_names[task]\n",
    "print(metrics[\"standard\"].keys())\n",
    "basic_plot(metrics[\"standard\"], models=models)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e33a2f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# task = \"combination_regression\"\n",
    "# task = \"linear_regression\"\n",
    "task = \"sparse_linear_regression\"\n",
    "#task = \"decision_tree\"\n",
    "#task = \"relu_2nn_regression\"\n",
    "\n",
    "# run_id = \"0826203246\"\n",
    "# run_id = \"0823174809\"  # if you train more models, replace with the run_id from the table above\n",
    "run_id = \"0825202036\"\n",
    "\n",
    "run_path = os.path.join(run_dir, task, run_id)\n",
    "recompute_metrics = False\n",
    "\n",
    "if recompute_metrics:\n",
    "    get_run_metrics(run_path)  # these are normally precomputed at the end of training\n",
    "def valid_row(r):\n",
    "    return r.task == task and r.run_id == run_id\n",
    "\n",
    "metrics = collect_results(run_dir, df, valid_row=valid_row)\n",
    "_, conf = get_model_from_run(run_path, only_conf=True)\n",
    "n_dims = conf.model.n_dims\n",
    "\n",
    "models = relevant_model_names[task]\n",
    "print(metrics[\"standard\"].keys())\n",
    "basic_plot(metrics[\"standard\"], models=models)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6d09964",
   "metadata": {},
   "source": [
    "# Plot pre-computed metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd8e02c5",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# task = \"combination_regression\"\n",
    "task = \"linear_regression\"\n",
    "# task = \"sparse_linear_regression\"\n",
    "#task = \"relu_2nn_regression\"\n",
    "\n",
    "# run_id = \"0826203246\"\n",
    "run_id = \"0823174809\"  # if you train more models, replace with the run_id from the table above\n",
    "# run_id = \"0825202036\"\n",
    "\n",
    "run_path = os.path.join(run_dir, task, run_id)\n",
    "recompute_metrics = False\n",
    "\n",
    "if recompute_metrics:\n",
    "    get_run_metrics(run_path)  # these are normally precomputed at the end of training\n",
    "def valid_row(r):\n",
    "    return r.task == task and r.run_id == run_id\n",
    "\n",
    "metrics = collect_results(run_dir, df, valid_row=valid_row)\n",
    "_, conf = get_model_from_run(run_path, only_conf=True)\n",
    "n_dims = conf.model.n_dims\n",
    "\n",
    "models = relevant_model_names[task]\n",
    "print(metrics[\"standard\"].keys())\n",
    "basic_plot(metrics[\"standard\"], models=models)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31b4ecca",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# plot any OOD metrics\n",
    "for name, metric in metrics.items():\n",
    "    if name == \"standard\": continue\n",
    "   \n",
    "    if \"scale\" in name:\n",
    "        scale = float(name.split(\"=\")[-1])**2\n",
    "    else:\n",
    "        scale = 1.0\n",
    "    print(name)\n",
    "    print(scale)\n",
    "    trivial = 1.0 if \"noisy\" not in name else (1+1/n_dims)\n",
    "    fig, ax = basic_plot(metric, models=models, trivial=trivial * scale)\n",
    "    ax.set_title(name)\n",
    "    \n",
    "    # if \"ortho\" in name:\n",
    "    #     ax.set_xlim(-1, n_dims - 1)\n",
    "    # ax.set_ylim(-.1 * scale, 1.5 * scale)\n",
    "\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "in-context-learning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
