{
 "cells": [
  {
   "cell_type": "code",
   "id": "ed6cfeb1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-15T03:50:13.339729Z",
     "start_time": "2024-04-15T03:50:12.474815Z"
    }
   },
   "source": [
    "from collections import OrderedDict\n",
    "import re\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from eval import get_run_metrics, read_run_dir, get_model_from_run\n",
    "from plot_utils import basic_plot, collect_results, relevant_model_names\n",
    "\n",
    "\n",
    "# Render figures inline and reload modules automatically before each execution.\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Seaborn theme and colour palette.\n",
    "sns.set_theme(context='notebook', style='darkgrid')\n",
    "palette = sns.color_palette('colorblind')\n",
    "\n",
    "# Directory containing the runs; individual runs live under <run_dir>/<task>/<run_id>.\n",
    "run_dir = \"../models\""
   ],
   "execution_count": 1,
   "outputs": []
  },
  {
   "cell_type": "code",
   "id": "0e8d018b",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-04-15T03:50:17.948188Z",
     "start_time": "2024-04-15T03:50:17.616476Z"
    }
   },
   "source": [
    "# List all the runs in our run_dir; the bare name on the last line displays the result.\n",
    "df = read_run_dir(run_dir)\n",
    "df"
   ],
   "execution_count": 2,
   "outputs": []
  },
  {
   "cell_type": "code",
   "id": "a9980951",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-15T03:51:07.167124Z",
     "start_time": "2024-04-15T03:51:07.140183Z"
    }
   },
   "source": [
    "# Choose the task to inspect: keep exactly one of the assignments below active.\n",
    "task = \"linear_regression\"\n",
    "#task = \"sparse_linear_regression\"\n",
    "#task = \"decision_tree\"\n",
    "#task = \"relu_2nn_regression\"\n",
    "\n",
    "# If you train more models, replace with the run_id from the table above.\n",
    "run_id = \"pretrained\"\n",
    "\n",
    "run_path = os.path.join(run_dir, task, run_id)\n",
    "\n",
    "# Metrics are normally precomputed at the end of training; set to True to recompute them.\n",
    "recompute_metrics = False\n",
    "if recompute_metrics:\n",
    "    get_run_metrics(run_path)"
   ],
   "execution_count": 3,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "f6d09964",
   "metadata": {},
   "source": [
    "# Plot pre-computed metrics"
   ]
  },
  {
   "cell_type": "code",
   "id": "cd8e02c5",
   "metadata": {
    "scrolled": false,
    "ExecuteTime": {
     "end_time": "2024-04-15T03:51:14.709442Z",
     "start_time": "2024-04-15T03:51:14.670190Z"
    }
   },
   "source": [
    "def valid_row(r):\n",
    "    \"\"\"Return whether row `r` belongs to the selected `task` and `run_id`.\"\"\"\n",
    "    same_task = r.task == task\n",
    "    return same_task and r.run_id == run_id\n",
    "\n",
    "\n",
    "# Gather the metrics of the rows accepted by valid_row.\n",
    "metrics = collect_results(run_dir, df, valid_row=valid_row)\n",
    "\n",
    "# Only the run's config is requested here (only_conf=True); n_dims is read from it.\n",
    "_, conf = get_model_from_run(run_path, only_conf=True)\n",
    "n_dims = conf.model.n_dims\n",
    "\n",
    "# Plot the \"standard\" evaluation for the models listed for this task.\n",
    "models = relevant_model_names[task]\n",
    "basic_plot(metrics[\"standard\"], models=models)\n",
    "plt.show()"
   ],
   "execution_count": 4,
   "outputs": []
  },
  {
   "cell_type": "code",
   "id": "31b4ecca",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2024-04-15T03:51:18.534907Z",
     "start_time": "2024-04-15T03:51:18.491534Z"
    }
   },
   "source": [
    "# Plot every OOD metric; the \"standard\" one is skipped here.\n",
    "for name, metric in metrics.items():\n",
    "    if name == \"standard\":\n",
    "        continue\n",
    "\n",
    "    # For \"scale\" evaluations the number after the last \"=\" is squared and used to\n",
    "    # rescale both the trivial baseline and the y-limits.\n",
    "    scale = float(name.split(\"=\")[-1])**2 if \"scale\" in name else 1.0\n",
    "    trivial = (1+1/n_dims) if \"noisy\" in name else 1.0\n",
    "\n",
    "    fig, ax = basic_plot(metric, models=models, trivial=trivial * scale)\n",
    "    ax.set_title(name)\n",
    "    if \"ortho\" in name:\n",
    "        ax.set_xlim(-1, n_dims - 1)\n",
    "    ax.set_ylim(-.1 * scale, 1.5 * scale)\n",
    "\n",
    "    plt.show()"
   ],
   "execution_count": 5,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "e6f961d4",
   "metadata": {},
   "source": [
    "# Interactive setup\n",
    "\n",
    "We will now directly load the model and measure its in-context learning ability on a batch of random inputs. (In the paper we average over multiple such batches to obtain better estimates.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "beb327ce",
   "metadata": {},
   "source": [
    "from samplers import get_data_sampler\n",
    "from tasks import get_task_sampler"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "03523b06",
   "metadata": {},
   "source": [
    "# Load the model of the selected run together with its config.\n",
    "model, conf = get_model_from_run(run_path)\n",
    "\n",
    "n_dims = conf.model.n_dims\n",
    "batch_size = conf.training.batch_size\n",
    "\n",
    "# Build the input and task samplers from the settings stored in the run's config.\n",
    "data_sampler = get_data_sampler(conf.training.data, n_dims)\n",
    "task_sampler = get_task_sampler(\n",
    "    conf.training.task, n_dims, batch_size, **conf.training.task_kwargs\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1d9da7c3",
   "metadata": {},
   "source": [
    "# Sample one batch of tasks and inputs.\n",
    "# The sampled task object gets its own name so that the `task` string chosen near the\n",
    "# top of the notebook is not overwritten; earlier cells that use `task` as the task\n",
    "# name (valid_row, relevant_model_names[task]) can then be re-run after this section.\n",
    "sampled_task = task_sampler()\n",
    "xs = data_sampler.sample_xs(b_size=batch_size, n_points=conf.training.curriculum.points.end)\n",
    "ys = sampled_task.evaluate(xs)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cb69ddda",
   "metadata": {},
   "source": [
    "with torch.no_grad():\n",
    "    pred = model(xs, ys)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2aa97fa5",
   "metadata": {},
   "source": [
    "metric = sampled_task.get_metric()\n",
    "loss = metric(pred, ys).numpy()\n",
    "\n",
    "# Value of the dashed \"zero estimator\" reference line, looked up by the trained task type.\n",
    "sparsity = conf.training.task_kwargs.sparsity if \"sparsity\" in conf.training.task_kwargs else None\n",
    "baseline = {\n",
    "    \"linear_regression\": n_dims,\n",
    "    \"sparse_linear_regression\": sparsity,\n",
    "    \"relu_2nn_regression\": n_dims,\n",
    "    \"decision_tree\": 1,\n",
    "}[conf.training.task]\n",
    "\n",
    "# Average the loss over axis 0 and plot it against the position along the remaining axis.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(loss.mean(axis=0), lw=2, label=\"Transformer\")\n",
    "ax.axhline(baseline, ls=\"--\", color=\"gray\", label=\"zero estimator\")\n",
    "ax.set_xlabel(\"# in-context examples\")\n",
    "ax.set_ylabel(\"squared error\")\n",
    "ax.legend()\n",
    "plt.show()"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "eae775a1",
   "metadata": {},
   "source": [
    "As an exploration example, let's see how robust the model is to doubling all the inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a58e04e4",
   "metadata": {},
   "source": [
    "# Re-evaluate the same sampled task on inputs scaled by 2.\n",
    "xs2 = 2 * xs\n",
    "ys2 = sampled_task.evaluate(xs2)\n",
    "with torch.no_grad():\n",
    "    pred2 = model(xs2, ys2)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7ea71ba5",
   "metadata": {},
   "source": [
    "loss2 = metric(pred2, ys2).numpy()\n",
    "\n",
    "# The doubled-input curve is divided by 4 (= 2**2) before plotting, presumably to put it\n",
    "# on the same scale as the original curve - this assumes the targets scale linearly\n",
    "# with the inputs (TODO: confirm for the non-linear tasks).\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(loss.mean(axis=0), lw=2, label=\"Transformer\")\n",
    "ax.plot(loss2.mean(axis=0) / 4, lw=2, label=\"Transformer on doubled inputs\")\n",
    "ax.axhline(baseline, ls=\"--\", color=\"gray\", label=\"zero estimator\")\n",
    "ax.set_xlabel(\"# in-context examples\")\n",
    "ax.set_ylabel(\"squared error\")\n",
    "ax.legend()\n",
    "plt.show()"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "021f118e",
   "metadata": {},
   "source": [
    "The error does increase, especially when the number of in-context examples exceeds the dimension, but the model is still relatively accurate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3dc9cc8",
   "metadata": {},
   "source": [],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
