{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "850d978d-6e26-4f4f-8d68-c5fc10485c0a",
   "metadata": {},
   "source": [
    "# Outlier Bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60e900c5-8bf7-4344-b68d-d66d1f9cba05",
   "metadata": {},
   "source": [
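    "In this notebook, we show an application of the previously discussed primitive biases: combined, they give rise to an *outlier bias*. We inject outliers into an otherwise clean periodic context and measure how the forecast error of Chronos and Chronos-Bolt changes as the outliers' magnitude and number grow."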
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bed7c51-b4f4-40e9-a2a0-ffe662be5fbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "pd.set_option('display.max_rows', None)\n",
    "import random\n",
    "import statistics\n",
    "import scipy\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "    # The `chronos` module imported below is distributed on PyPI as `chronos-forecasting`;\n",
    "    # without it a fresh Colab runtime fails at `from chronos import ...`.\n",
    "    # %pip (rather than !pip) installs into the running kernel's environment.\n",
    "    %pip install -q transformers weightwatcher chronos-forecasting\n",
    "\n",
    "from transformers import (\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForCausalLM,\n",
    "    AutoConfig,\n",
    "    T5Config,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "\n",
    "from chronos import ChronosPipeline, ChronosBoltPipeline\n",
    "\n",
    "from collections import defaultdict\n",
    "\n",
    "from matplotlib.colors import PowerNorm\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d6f7a9a-7f2c-442c-bd10-f786f4d10af2",
   "metadata": {},
   "outputs": [],
   "source": [
    "chronos = ChronosPipeline.from_pretrained(\"amazon/chronos-t5-small\")\n",
    "bolt = ChronosBoltPipeline.from_pretrained(\"amazon/chronos-bolt-small\")\n",
    "\n",
    "# Chronos-Bolt with a patch size of 1: point this at your own pretrained checkpoint\n",
    "# (a local directory or Hub id). It is assumed to load with ChronosBoltPipeline,\n",
    "# just like the patch-size-16 model above.\n",
    "BOLT_P1_CHECKPOINT = None  # e.g. \"path/to/chronos-bolt-small-p1\"\n",
    "if BOLT_P1_CHECKPOINT is None:\n",
    "    raise ValueError(\n",
    "        \"Set BOLT_P1_CHECKPOINT to your pretrained Chronos-Bolt checkpoint with a patch size of 1.\"\n",
    "    )\n",
    "bolt_p1 = ChronosBoltPipeline.from_pretrained(BOLT_P1_CHECKPOINT)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7a5c241-170d-4eb2-82e4-79b454fc1534",
   "metadata": {},
   "source": [
    "## An exploratory example involving an outlier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "170b5af5-27b6-42c7-8716-edac09b807b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def perfect_align_context(num_patches):\n",
    "    \"\"\"Sample cos(t + 0.3) so that one period spans exactly `num_patches` 16-step patches.\n",
    "\n",
    "    The step size is 2*pi / num_patches / 16, i.e. one period covers\n",
    "    16 * num_patches samples. The grid covers integer offsets -2048 .. 63\n",
    "    (2112 samples, offset 0 at index 2048); this notebook uses the last 64\n",
    "    samples as ground truth and the 512 before them as context.\n",
    "\n",
    "    Returns:\n",
    "        (values, gap): the sampled signal (1-D array) and the step size.\n",
    "    \"\"\"\n",
    "    gap = 2 * math.pi / num_patches / 16\n",
    "    offsets = np.arange(-2048, 64)\n",
    "    values = np.cos(offsets * gap + 0.3)\n",
    "    return values, gap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4162969c-fe27-4772-95bb-34f9ccc10e2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean periodic signal: the last 64 samples are the horizon, the 512 before them the context.\n",
    "context_gt, _ = perfect_align_context(10)\n",
    "groundtruth = context_gt[-64:]\n",
    "context = context_gt[-576:-64]  # a view: the edit below also shows up in context_gt\n",
    "# Level-shift outlier: raise the first 50 context steps by 50.\n",
    "context[:50] += 50\n",
    "plt.plot(context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6275b920-60e7-4b32-9521-ef55c47240c3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "output_bolt = bolt_p1.predict(\n",
    "    torch.from_numpy(context),\n",
    "    prediction_length=64,\n",
    ").numpy()\n",
    "\n",
    "# Quantile axis, matching the band labels: index 0/2/4/6/8 -> 10%/30%/50%/70%/90%.\n",
    "horizon = np.arange(context.shape[0], context.shape[0] + output_bolt.shape[2])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(np.arange(0, context.shape[0]), context, label=\"context\")\n",
    "plt.plot(horizon, output_bolt[0, 4, :], label=\"median forecast\")\n",
    "plt.fill_between(horizon, output_bolt[0, 2, :], output_bolt[0, 6, :], color='orange', alpha=0.4, label=\"30-70%\")\n",
    "plt.fill_between(horizon, output_bolt[0, 0, :], output_bolt[0, 8, :], color='orange', alpha=0.2, label=\"10-90%\")\n",
    "plt.plot(horizon, groundtruth, alpha=0.5, color='lightgray', label=\"ground truth\")\n",
    "plt.ylim(-1.5, 1.5)\n",
    "# The labels above are only rendered once a legend is drawn.\n",
    "plt.legend()\n",
    "plt.title(\"Chronos-Bolt (patch size 1) forecast\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5b9bc90-0c76-4c95-99e2-e2715287d2e9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "output_chronos = chronos.predict(\n",
    "    torch.from_numpy(context),\n",
    "    prediction_length=64,\n",
    "    num_samples=1,\n",
    ").numpy()\n",
    "\n",
    "# Horizon length comes from this cell's own forecast, not from `output_bolt`\n",
    "# of the previous cell, so this cell also works when run on its own.\n",
    "horizon = np.arange(context.shape[0], context.shape[0] + output_chronos.shape[2])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(np.arange(0, context.shape[0]), context, label=\"context\")\n",
    "plt.plot(horizon, output_chronos[0, 0, :], label=\"forecast (1 sample)\")\n",
    "plt.plot(horizon, groundtruth, color='lightgray', alpha=0.5, label=\"ground truth\")\n",
    "plt.ylim(-1.5, 1.5)\n",
    "plt.legend()\n",
    "plt.title(\"Chronos forecast\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd6ce1e7-076a-4755-ade3-db9e580c206b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "context_gt, _ = perfect_align_context(10)\n",
    "context = context_gt[(-512-64):-64]\n",
    "# Seeded so this random outlier pattern, and the forecasts below, are reproducible.\n",
    "outlier_rng = np.random.default_rng(0)\n",
    "# 30 distinct positions, each perturbed by Gaussian noise with std 100.\n",
    "outlier_positions = outlier_rng.choice(512, size=30, replace=False)\n",
    "context[outlier_positions] += outlier_rng.standard_normal(30) * 100\n",
    "groundtruth = context_gt[-64:]\n",
    "plt.plot(context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "297a529c-728c-40c5-9e15-77867e7888aa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "output_bolt = bolt_p1.predict(\n",
    "    torch.from_numpy(context),\n",
    "    prediction_length=64,\n",
    ").numpy()\n",
    "\n",
    "# Quantile axis, matching the band labels: index 0/2/4/6/8 -> 10%/30%/50%/70%/90%.\n",
    "horizon = np.arange(context.shape[0], context.shape[0] + output_bolt.shape[2])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(np.arange(0, context.shape[0]), context, label=\"context\")\n",
    "plt.plot(horizon, output_bolt[0, 4, :], label=\"median forecast\")\n",
    "plt.fill_between(horizon, output_bolt[0, 2, :], output_bolt[0, 6, :], color='orange', alpha=0.4, label=\"30-70%\")\n",
    "plt.fill_between(horizon, output_bolt[0, 0, :], output_bolt[0, 8, :], color='orange', alpha=0.2, label=\"10-90%\")\n",
    "plt.plot(horizon, groundtruth, alpha=0.5, color='lightgray', label=\"ground truth\")\n",
    "plt.ylim(-1.5, 1.5)\n",
    "# The labels above are only rendered once a legend is drawn.\n",
    "plt.legend()\n",
    "plt.title(\"Chronos-Bolt (patch size 1) forecast\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f798d06-65d7-4bd2-a3a4-222b4d93abb9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "output_chronos = chronos.predict(\n",
    "    torch.from_numpy(context),\n",
    "    prediction_length=64,\n",
    "    num_samples=1,\n",
    ").numpy()\n",
    "\n",
    "# Horizon length comes from this cell's own forecast, not from `output_bolt`\n",
    "# of the previous cell, so this cell also works when run on its own.\n",
    "horizon = np.arange(context.shape[0], context.shape[0] + output_chronos.shape[2])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(np.arange(0, context.shape[0]), context, label=\"context\")\n",
    "plt.plot(horizon, output_chronos[0, 0, :], label=\"forecast (1 sample)\")\n",
    "plt.plot(horizon, groundtruth, color='lightgray', alpha=0.5, label=\"ground truth\")\n",
    "plt.ylim(-1.5, 1.5)\n",
    "plt.legend()\n",
    "plt.title(\"Chronos forecast\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "726940e2-a4fc-4c93-9448-051a0f3de9f4",
   "metadata": {},
   "source": [
    "## Quantifying the outlier bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f20568bd-dc3e-428a-8de4-68e75cdd9270",
   "metadata": {},
   "source": [
    "### 1. Change the outlier's magnitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "615e7ccc-7f53-4e24-b645-38f1424ea698",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_outlier_mag(magnitude):\n",
    "    \"\"\"Return (context, groundtruth) with four fixed outliers scaled by `magnitude`.\n",
    "\n",
    "    The outliers always sit at positions 10, 37, 98 and 175 of the 512-step\n",
    "    context with the same signed weights, so only their size changes\n",
    "    between calls. The 64-step ground truth is left clean.\n",
    "    \"\"\"\n",
    "    full_signal, _ = perfect_align_context(10)\n",
    "    context = full_signal[-576:-64]\n",
    "    weighted_positions = ((10, -3.967), (37, 1.4934), (98, 4.2007), (175, -1.185379))\n",
    "    for position, weight in weighted_positions:\n",
    "        context[position] += magnitude * weight\n",
    "    return context, full_signal[-64:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "759f487a-d323-4b96-8b13-0a129b7085ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the outlier magnitude over 100 log-spaced values in [1e-1, 1e5] and\n",
    "# record both models' forecasts and their MSE against the clean ground truth.\n",
    "magnitudes = np.logspace(-1, 5, num=100)\n",
    "\n",
    "err_chronos, err_bolt = [], []\n",
    "outputs_chronos, outputs_bolt = [], []\n",
    "contexts, groundtruths = [], []\n",
    "\n",
    "for magnitude in magnitudes:\n",
    "    context, groundtruth = make_outlier_mag(magnitude)\n",
    "    contexts.append(context)\n",
    "    groundtruths.append(groundtruth)\n",
    "\n",
    "    output_bolt = bolt_p1.predict(torch.from_numpy(context), prediction_length=64).numpy()\n",
    "    output_chronos = chronos.predict(\n",
    "        torch.from_numpy(context), prediction_length=64, num_samples=1\n",
    "    ).numpy()\n",
    "    outputs_bolt.append(output_bolt)\n",
    "    outputs_chronos.append(output_chronos)\n",
    "\n",
    "    # Bolt is scored on quantile index 4 (the median, per the plots' band labels),\n",
    "    # Chronos on its single sample.\n",
    "    err_bolt.append(np.mean((output_bolt[0, 4, :] - groundtruth) ** 2))\n",
    "    err_chronos.append(np.mean((output_chronos[0, 0, :] - groundtruth) ** 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d07b4846-ca4c-459d-ac71-22528ef3a9b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(15, 6))\n",
    "for errors, name, colour in ((err_chronos, \"chronos\", \"red\"), (err_bolt, \"bolt\", \"green\")):\n",
    "    ax.plot(magnitudes, errors, linewidth=3, label=name, color=colour)\n",
    "# Both axes are logarithmic: the magnitudes span six decades.\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_yscale(\"log\")\n",
    "ax.set_xlabel(\"outliers' magnitudes\")\n",
    "ax.set_ylabel(\"MSE\")\n",
    "ax.set_title(\"Prediction Error as a Function of Outliers' Magnitudes\")\n",
    "ax.legend()\n",
    "ax.grid(which='both', axis='both')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17c6aa16-21f3-4b85-8229-8982db5752c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt_index = 30  # 70, 30, 20 are representative\n",
    "\n",
    "sample_context = contexts[plt_index]\n",
    "sample_forecast = outputs_chronos[plt_index][0, 0, :]\n",
    "# Offsets come from this sample's own context, not the leftover loop variable `context`.\n",
    "horizon = np.arange(sample_context.shape[0], sample_context.shape[0] + sample_forecast.shape[0])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(np.arange(0, sample_context.shape[0]), sample_context, label=\"context\")\n",
    "plt.plot(horizon, sample_forecast, label=\"forecast\")\n",
    "plt.plot(horizon, groundtruths[plt_index], color='lightgray', alpha=0.5, label=\"ground truth\")\n",
    "plt.ylim(-1.5, 1.5)\n",
    "plt.legend()\n",
    "plt.title(f\"Chronos' Forecast (Outlier Magnitude = {magnitudes[plt_index]:.2e})\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6c828b8-833f-47b1-8e58-96a79ae9947f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt_index = 70\n",
    "\n",
    "sample_context = contexts[plt_index]\n",
    "sample_output = outputs_bolt[plt_index]\n",
    "# Offsets come from this sample's own context, not the leftover loop variable `context`.\n",
    "# Quantile axis, matching the band labels: index 0/2/4/6/8 -> 10%/30%/50%/70%/90%.\n",
    "horizon = np.arange(sample_context.shape[0], sample_context.shape[0] + sample_output.shape[2])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(np.arange(0, sample_context.shape[0]), sample_context, label=\"context\")\n",
    "plt.plot(horizon, sample_output[0, 4, :], label=\"forecast\")\n",
    "plt.fill_between(horizon, sample_output[0, 2, :], sample_output[0, 6, :], color='orange', alpha=0.4, label=\"30-70%\")\n",
    "plt.fill_between(horizon, sample_output[0, 0, :], sample_output[0, 8, :], color='orange', alpha=0.2, label=\"10-90%\")\n",
    "plt.plot(horizon, groundtruths[plt_index], color='lightgray', alpha=0.5, label=\"ground truth\")\n",
    "plt.ylim(-1.5, 1.5)\n",
    "plt.legend()\n",
    "plt.title(f\"Bolt's Forecast (Outlier Magnitude = {magnitudes[plt_index]:.2e})\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ecbbabf-af4c-4a28-befa-f2ad646215a3",
   "metadata": {},
   "source": [
    "### 2. Change the number of outliers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82c16eb0-da95-4fc2-96c0-6dc2bc212404",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_outlier_num(order, magnitudes, number):\n",
    "    \"\"\"Return (context, groundtruth) with the first `number` outliers applied.\n",
    "\n",
    "    Outlier k (k = 0 .. number-1) adds magnitudes[k] at position order[k] of\n",
    "    the 512-step context, so increasing `number` keeps the earlier outliers\n",
    "    and adds new ones on top. The 64-step ground truth is left clean.\n",
    "    \"\"\"\n",
    "    full_signal, _ = perfect_align_context(10)\n",
    "    context = full_signal[-576:-64]\n",
    "    for k in range(number):\n",
    "        position, delta = order[k], magnitudes[k]\n",
    "        context[position] += delta\n",
    "    return context, full_signal[-64:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb421362-6346-449d-b5c0-911957ca9461",
   "metadata": {},
   "outputs": [],
   "source": [
    "number = np.arange(100)\n",
    "# Seeded so the outlier layout is reproducible. The phase boundaries (20 / 80)\n",
    "# marked in the next plot were presumably picked by eye from one unseeded run;\n",
    "# re-check them if the seed changes.\n",
    "outlier_rng = np.random.default_rng(0)\n",
    "# Outlier positions: a random order over the first 256 context steps.\n",
    "order = outlier_rng.permutation(256)\n",
    "# Named `outlier_magnitudes` to avoid reusing section 1's `magnitudes`\n",
    "# (the sweep values) for a different quantity.\n",
    "outlier_magnitudes = 50 * outlier_rng.standard_normal(512)\n",
    "\n",
    "err_chronos = []\n",
    "err_bolt = []\n",
    "outputs_chronos = []\n",
    "outputs_bolt = []\n",
    "contexts = []\n",
    "groundtruths = []\n",
    "\n",
    "for n_outliers in number:\n",
    "    context, groundtruth = make_outlier_num(order, outlier_magnitudes, n_outliers)\n",
    "    contexts.append(context)\n",
    "    groundtruths.append(groundtruth)\n",
    "\n",
    "    output_bolt = bolt_p1.predict(\n",
    "        torch.from_numpy(context),\n",
    "        prediction_length=64,\n",
    "    ).numpy()\n",
    "    output_chronos = chronos.predict(\n",
    "        torch.from_numpy(context),\n",
    "        prediction_length=64,\n",
    "        num_samples=1,\n",
    "    ).numpy()\n",
    "\n",
    "    outputs_bolt.append(output_bolt)\n",
    "    outputs_chronos.append(output_chronos)\n",
    "\n",
    "    err_bolt.append(np.mean((output_bolt[0, 4, :] - groundtruth) ** 2))\n",
    "    err_chronos.append(np.mean((output_chronos[0, 0, :] - groundtruth) ** 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "278d647e-2b89-4b2d-b20a-e33391e90924",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(15, 6))\n",
    "ax.plot(number, err_chronos, linewidth=3, label=\"chronos\", color=\"red\")\n",
    "ax.plot(number, err_bolt, linewidth=3, label=\"bolt\", color=\"green\")\n",
    "ax.set_yscale(\"log\")\n",
    "ax.set_xlabel(\"number of outliers\")\n",
    "ax.set_ylabel(\"MSE\")\n",
    "ax.set_title(\"Prediction Error as a Function of Number of Outliers\")\n",
    "ax.legend()\n",
    "ax.grid(which='both', axis='y')\n",
    "\n",
    "# Dashed lines split the curve into three regimes, labelled near the top of the axis.\n",
    "for boundary in (20, 80):\n",
    "    ax.axvline(x=boundary, linewidth=2, linestyle='--', color=\"blue\")\n",
    "ax.set_ylim(1e-5, 1e3)\n",
    "for x_pos, phase in ((2.5, \"Phase I\"), (45, \"Phase II\"), (86, \"Phase III\")):\n",
    "    ax.text(x_pos, 250, phase, color=\"blue\", fontsize=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfe3ccd0-02cb-4988-b9ee-f4d47ccc19c7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt_index = 46  # 10, 46, 99 are representative\n",
    "\n",
    "sample_context = contexts[plt_index]\n",
    "sample_forecast = outputs_chronos[plt_index][0, 0, :]\n",
    "# Offsets come from this sample's own context, not the leftover loop variable `context`.\n",
    "horizon = np.arange(sample_context.shape[0], sample_context.shape[0] + sample_forecast.shape[0])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(np.arange(0, sample_context.shape[0]), sample_context, label=\"context\")\n",
    "plt.plot(horizon, sample_forecast, label=\"forecast\")\n",
    "plt.plot(horizon, groundtruths[plt_index], color='lightgray', alpha=0.5, label=\"ground truth\")\n",
    "plt.ylim(-1.5, 1.5)\n",
    "plt.legend()\n",
    "plt.title(f\"Chronos' Forecast (Number of Outliers = {number[plt_index]})\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dde9d84a-4a81-4d7d-9cc9-07b28d7190bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt_index = 99\n",
    "\n",
    "sample_context = contexts[plt_index]\n",
    "sample_output = outputs_bolt[plt_index]\n",
    "# Offsets come from this sample's own context, not the leftover loop variable `context`.\n",
    "# Quantile axis, matching the band labels: index 0/2/4/6/8 -> 10%/30%/50%/70%/90%.\n",
    "horizon = np.arange(sample_context.shape[0], sample_context.shape[0] + sample_output.shape[2])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(np.arange(0, sample_context.shape[0]), sample_context, label=\"context\")\n",
    "plt.plot(horizon, sample_output[0, 4, :], label=\"forecast\")\n",
    "plt.fill_between(horizon, sample_output[0, 2, :], sample_output[0, 6, :], color='orange', alpha=0.4, label=\"30-70%\")\n",
    "plt.fill_between(horizon, sample_output[0, 0, :], sample_output[0, 8, :], color='orange', alpha=0.2, label=\"10-90%\")\n",
    "plt.plot(horizon, groundtruths[plt_index], color='lightgray', alpha=0.5, label=\"ground truth\")\n",
    "plt.ylim(-1.5, 1.5)\n",
    "plt.legend()\n",
    "plt.title(f\"Bolt's Forecast (Number of Outliers = {number[plt_index]})\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c070c103-e367-4046-9cd9-c18355869e9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-forecast contexts[46] (46 outliers) with Chronos on its own, for inspection.\n",
    "output_chronos = chronos.predict(\n",
    "    torch.from_numpy(contexts[46]), prediction_length=64, num_samples=1\n",
    ").numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fc37753-16a9-4b1b-a993-08b73a0705a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-forecast contexts[46] (46 outliers) with the patch-size-1 Chronos-Bolt, for inspection.\n",
    "output_bolt = bolt_p1.predict(torch.from_numpy(contexts[46]), prediction_length=64).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e836b8fa-6bca-4ad7-9ec6-f84d08672582",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-forecast contexts[46] (46 outliers) with the patch-size-16 Chronos-Bolt (`bolt`), for comparison.\n",
    "output_64_bolt = bolt.predict(torch.from_numpy(contexts[46]), prediction_length=64).numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bcc857b-64f5-4049-acf1-86a1aec14ed2",
   "metadata": {},
   "source": [
    "## Comparing patch sizes (1 vs. 16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5b9b6eb-c4fb-441b-9b63-8d8c6d5f0afb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-forecast every outlier-count context with both Bolt variants: patch size 1\n",
    "# (`bolt_p1`) and patch size 16 (`bolt`). Each is scored on quantile index 4.\n",
    "err_bolt_orig, err_bolt = [], []\n",
    "outputs_bolt_orig, outputs_bolt = [], []\n",
    "\n",
    "for context, groundtruth in zip(contexts, groundtruths):\n",
    "    output_bolt = bolt_p1.predict(torch.from_numpy(context), prediction_length=64).numpy()\n",
    "    output_bolt_orig = bolt.predict(torch.from_numpy(context), prediction_length=64).numpy()\n",
    "\n",
    "    outputs_bolt.append(output_bolt)\n",
    "    outputs_bolt_orig.append(output_bolt_orig)\n",
    "\n",
    "    err_bolt.append(np.mean((output_bolt[0, 4, :] - groundtruth) ** 2))\n",
    "    err_bolt_orig.append(np.mean((output_bolt_orig[0, 4, :] - groundtruth) ** 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b77f7bd6-61e6-4873-a865-73706dc2e54a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(15, 6))\n",
    "ax.plot(number, err_bolt_orig, linewidth=3, label=\"bolt (patch size = 16)\", color=\"magenta\")\n",
    "ax.plot(number, err_bolt, linewidth=3, label=\"bolt (patch size = 1)\", color=\"green\")\n",
    "# Both axes stay linear here.\n",
    "ax.set_xlabel(\"number of outliers\")\n",
    "ax.set_ylabel(\"MSE\")\n",
    "ax.set_title(\"Prediction Error as a Function of Number of Outliers\")\n",
    "ax.legend()\n",
    "ax.grid(which='both', axis='y')\n",
    "\n",
    "# Phase boundaries at 20 and 80 outliers; phase labels sit at y = 0.\n",
    "for boundary in (20, 80):\n",
    "    ax.axvline(x=boundary, linewidth=2, linestyle='--', color=\"blue\")\n",
    "for x_pos, phase in ((2.5, \"Phase I\"), (45, \"Phase II\"), (86, \"Phase III\")):\n",
    "    ax.text(x_pos, 0, phase, color=\"blue\", fontsize=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32484a9c-8387-4760-9d1c-18da08d5fc5f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
