{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ea97df10-70ae-4217-876c-0078e8afc5c1",
   "metadata": {},
   "source": [
    "# Offset Bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77b96b93-5f06-4a39-aa6c-87c9e767ace4",
   "metadata": {},
   "source": [
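    "In this notebook, we analyze the offset bias of time-series foundation models (TSFMs), comparing Chronos (T5) with a Chronos-Bolt model that uses a patch size of 1."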
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80f84bb7-fd0e-42ed-99f5-e91997c5ac69",
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "pd.set_option('display.max_rows', None)\n",
    "import random\n",
    "import statistics\n",
    "import scipy\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "    !pip install transformers weightwatcher\n",
    "\n",
    "from transformers import (\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForCausalLM,\n",
    "    AutoConfig,\n",
    "    T5Config,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "\n",
    "from chronos import ChronosPipeline, ChronosBoltPipeline\n",
    "\n",
    "from collections import defaultdict\n",
    "\n",
    "from matplotlib.colors import PowerNorm\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36b14582-b959-4d25-b56b-86197fc906cc",
   "metadata": {},
   "source": [
    "## Helper Function and Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5bd0705-58df-449c-9039-c02c6c5bbd4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_smooth_tri_plateaus(plateau_width=50, transition_width=10, num_plateaus=6, high_value=1.0, low_value=-1.0):\n",
    "    \"\"\"Build a piecewise-constant signal whose levels are joined by ramps.\n",
    "\n",
    "    Each of the ``num_plateaus`` steps emits a flat run of ``plateau_width`` samples at\n",
    "    the current level, followed by ``smooth_step(level, target, transition_width)``.\n",
    "    Targets cycle through ``(low_value, 0, high_value, 0)`` indexed by the step number,\n",
    "    while the first plateau already sits at ``low_value``: the first ramp therefore goes\n",
    "    from ``low_value`` to ``low_value`` and the first two plateaus share that level.\n",
    "\n",
    "    NOTE(review): ``smooth_step`` is not defined in this notebook and must be provided\n",
    "    before calling this function; it is assumed to return a 1-D array of length\n",
    "    ``transition_width`` -- TODO confirm.\n",
    "\n",
    "    Returns:\n",
    "        1-D numpy array of ``num_plateaus`` plateau/ramp pairs concatenated in order.\n",
    "        ``num_plateaus=0`` raises, since there is nothing to concatenate.\n",
    "    \"\"\"\n",
    "    levels = (low_value, 0, high_value, 0)\n",
    "    pieces = []\n",
    "    level = levels[0]\n",
    "    for step in range(num_plateaus):\n",
    "        target = levels[step % len(levels)]\n",
    "        pieces.extend((np.full(plateau_width, level), smooth_step(level, target, transition_width)))\n",
    "        level = target\n",
    "    return np.concatenate(pieces)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "315ab196",
   "metadata": {},
   "outputs": [],
   "source": [
    "chronos = ChronosPipeline.from_pretrained(\"amazon/chronos-t5-small\")\n",
    "# Standard Chronos-Bolt; loaded for reference -- the experiments below use `bolt_f1`.\n",
    "bolt = ChronosBoltPipeline.from_pretrained(\"amazon/chronos-bolt-small\")\n",
    "\n",
    "# TODO: set this to your pretrained Chronos-Bolt checkpoint with a patch size of 1\n",
    "# (local directory or hub id), or replace the `from_pretrained` call below with your own loader.\n",
    "BOLT_F1_CHECKPOINT = None\n",
    "if BOLT_F1_CHECKPOINT is None:\n",
    "    raise ValueError(\"Set BOLT_F1_CHECKPOINT to your pretrained Chronos-Bolt checkpoint (patch size 1).\")\n",
    "bolt_f1 = ChronosBoltPipeline.from_pretrained(BOLT_F1_CHECKPOINT)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be3e0c5e-dc24-4753-8229-9b4451acbb50",
   "metadata": {},
   "source": [
    "## Compute the outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06f2072-dbd2-4cdf-a177-b00481f4f7e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic series: a scaled base signal plus smooth plateaus cycling through -4, 0, +4.\n",
    "# NOTE(review): `perfect_align_context` is not defined in this notebook -- it must be provided first.\n",
    "context_raw, _ = perfect_align_context(1)\n",
    "context_gt = context_raw * 0.3  # new array, so the helper's return value is never modified in place\n",
    "plateaus = generate_smooth_tri_plateaus(plateau_width=100, transition_width=10, num_plateaus=40)[:context_gt.shape[0]]\n",
    "context_gt += plateaus * 4\n",
    "# Trim the tail so the 64-step forecast window falls on a near-zero segment\n",
    "# (the off-zero experiment below trims 140 steps instead).\n",
    "context_gt = context_gt[:-28]\n",
    "context = context_gt[(-512-64):-64]  # the 512 steps preceding the forecast window\n",
    "groundtruth = context_gt[-64:]\n",
    "plt.plot(context)\n",
    "plt.title(\"Context (near-zero forecast window)\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8d3d1af-4f9a-4c36-9c22-f030733fbf33",
   "metadata": {},
   "source": [
    "### Predicting near-zero segments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e0007fb-9f42-4410-b689-3acda00c0b1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "context_tensor = torch.from_numpy(context)\n",
    "\n",
    "# Quantile forecast of the patch-size-1 Chronos-Bolt model; the plots below read it as [0, quantile, step].\n",
    "output_bolt = bolt_f1.predict(context_tensor, prediction_length=64).numpy()\n",
    "\n",
    "# Chronos samples its forecast, so fix the torch seed to make the single sample reproducible.\n",
    "torch.manual_seed(0)\n",
    "output_chronos = chronos.predict(context_tensor, prediction_length=64, num_samples=1).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7f5083a-7975-4b83-afc3-94e7a09e1366",
   "metadata": {},
   "outputs": [],
   "source": [
    "horizon = np.arange(context.shape[0], context.shape[0] + output_chronos.shape[-1])  # forecast time steps\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(context, label=\"context\")\n",
    "ax.plot(horizon, output_chronos[0, 0, :], label=\"forecast\")  # first (only) sample of the first series\n",
    "ax.set(xlabel=\"time step\", ylabel=\"value\", title=\"Chronos Forecast\")\n",
    "ax.legend(loc=\"upper left\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c0df8ba-a288-467c-8e2d-08879e4b4289",
   "metadata": {},
   "outputs": [],
   "source": [
    "horizon = np.arange(context.shape[0], context.shape[0] + output_bolt.shape[-1])  # forecast time steps\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(context, label=\"context\")\n",
    "# Quantile axis: index 4 is the median, 2/6 bound the 30-70% band, 0/8 bound the 10-90% band.\n",
    "ax.plot(horizon, output_bolt[0, 4, :], label=\"median forecast\")\n",
    "ax.fill_between(horizon, output_bolt[0, 2, :], output_bolt[0, 6, :], color=\"orange\", alpha=0.4, label=\"30-70%\")\n",
    "ax.fill_between(horizon, output_bolt[0, 0, :], output_bolt[0, 8, :], color=\"orange\", alpha=0.2, label=\"10-90%\")\n",
    "ax.set(xlabel=\"time step\", ylabel=\"value\", title=\"Bolt Forecast\", ylim=(-4.5, 4.5))\n",
    "ax.legend(loc=\"upper left\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d17143b8-4b57-4eb5-88c3-c9841c70cd43",
   "metadata": {},
   "source": [
    "### Predicting off-zero segments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb522e63-0387-4cb0-b0da-0885aadaca83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic series: a scaled base signal plus smooth plateaus cycling through -4, 0, +4.\n",
    "# NOTE(review): `perfect_align_context` is not defined in this notebook -- it must be provided first.\n",
    "context_raw, _ = perfect_align_context(1)\n",
    "context_gt = context_raw * 0.3  # new array, so the helper's return value is never modified in place\n",
    "plateaus = generate_smooth_tri_plateaus(plateau_width=100, transition_width=10, num_plateaus=40)[:context_gt.shape[0]]\n",
    "context_gt += plateaus * 4\n",
    "# Trim the tail so the 64-step forecast window falls on an off-zero segment\n",
    "# (the near-zero experiment above trims 28 steps instead).\n",
    "context_gt = context_gt[:-140]\n",
    "context = context_gt[(-440-64):-64]  # the 440 steps preceding the forecast window\n",
    "groundtruth = context_gt[-64:]\n",
    "plt.plot(context)\n",
    "plt.title(\"Context (off-zero forecast window)\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b891eb07-dcac-4519-9519-986bd129efac",
   "metadata": {},
   "outputs": [],
   "source": [
    "context_tensor = torch.from_numpy(context)\n",
    "\n",
    "# Quantile forecast of the patch-size-1 Chronos-Bolt model; the plots below read it as [0, quantile, step].\n",
    "output_bolt = bolt_f1.predict(context_tensor, prediction_length=64).numpy()\n",
    "\n",
    "# Chronos samples its forecast, so fix the torch seed to make the single sample reproducible.\n",
    "torch.manual_seed(0)\n",
    "output_chronos = chronos.predict(context_tensor, prediction_length=64, num_samples=1).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a81438f9-ba86-40f8-bf2f-798e8f7cfa5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "horizon = np.arange(context.shape[0], context.shape[0] + output_chronos.shape[-1])  # forecast time steps\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(context, label=\"context\")\n",
    "ax.plot(horizon, output_chronos[0, 0, :], label=\"forecast\")  # first (only) sample of the first series\n",
    "ax.set(xlabel=\"time step\", ylabel=\"value\", title=\"Chronos Forecast\")\n",
    "ax.legend(loc=\"upper left\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83e1cc09-3c83-4ff3-a8b3-b6976da1b85a",
   "metadata": {},
   "outputs": [],
   "source": [
    "horizon = np.arange(context.shape[0], context.shape[0] + output_bolt.shape[-1])  # forecast time steps\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(context, label=\"context\")\n",
    "# Quantile axis: index 4 is the median, 2/6 bound the 30-70% band, 0/8 bound the 10-90% band.\n",
    "ax.plot(horizon, output_bolt[0, 4, :], label=\"median forecast\")\n",
    "ax.fill_between(horizon, output_bolt[0, 2, :], output_bolt[0, 6, :], color=\"orange\", alpha=0.4, label=\"30-70%\")\n",
    "ax.fill_between(horizon, output_bolt[0, 0, :], output_bolt[0, 8, :], color=\"orange\", alpha=0.2, label=\"10-90%\")\n",
    "ax.set(xlabel=\"time step\", ylabel=\"value\", title=\"Bolt Forecast\")\n",
    "ax.legend(loc=\"upper left\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e09e7323-ab3f-41e1-a883-e0e38c3a7e0a",
   "metadata": {},
   "source": [
    "## Plot diagnostic measures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7fcda0b-6901-4ad4-9085-69ad75cb0fa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: fill these in with the arrays captured from Chronos for the current `context`.\n",
    "# Both are indexed below as [0, token, :] (first sequence of the batch) -- TODO confirm the layout.\n",
    "hidden_chronos = None  # encoder hidden states\n",
    "input_chronos = None   # embedded input tokens\n",
    "if hidden_chronos is None or input_chronos is None:\n",
    "    raise ValueError(\"Load `hidden_chronos` and `input_chronos` before running this cell.\")\n",
    "\n",
    "# Project the first sequence's hidden states onto its top-10 right singular vectors.\n",
    "U, S_chronos, Vt = scipy.linalg.svd(hidden_chronos[0, ...])\n",
    "proj = (hidden_chronos @ Vt.T[:, :10])[0, :]\n",
    "\n",
    "# Vertical guides drawn on every data panel; presumably the plateau boundaries of the context -- TODO confirm.\n",
    "boundaries = (101, 202, 303, 404)\n",
    "\n",
    "fig, ax = plt.subplots(14, 1, figsize=(10, 16), sharex=True,\n",
    "                       gridspec_kw={'height_ratios': [3, 0.3, 3, 0.3] + [1] * 10})\n",
    "\n",
    "ax[0].plot(context[:512], color=\"red\")\n",
    "ax[0].set_yticks([])\n",
    "ax[0].set_title(\"Input Context\")\n",
    "\n",
    "ax[2].plot(scipy.linalg.norm(input_chronos[0, :512, :], axis=1), color=\"green\")\n",
    "ax[2].set_ylim(0, 0.55)\n",
    "\n",
    "for k in range(10):\n",
    "    ax[4 + k].plot(proj[:512, k])\n",
    "    ax[4 + k].set_yticks([])\n",
    "\n",
    "for panel in (ax[0], ax[2], *ax[4:]):\n",
    "    for x in boundaries:\n",
    "        panel.axvline(x, color=\"gray\", alpha=0.5)\n",
    "\n",
    "# Rows 1 and 3 are caption spacers; captions are centred in axes coordinates so they\n",
    "# stay centred regardless of the x-range of the shared axis.\n",
    "ax[1].axis('off')\n",
    "ax[1].text(0.5, 0, 'Norm of Embedded Vectors (Chronos)', fontsize=12, ha='center', transform=ax[1].transAxes)\n",
    "ax[3].axis('off')\n",
    "ax[3].text(0.5, 0, 'Encoded Context Projected onto the First 10 Singular Vectors (Chronos)', fontsize=12, ha='center', transform=ax[3].transAxes)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4cecd8c-be68-412e-93ca-d2a1a4801519",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: fill these in with the arrays captured from the patch-size-1 Chronos-Bolt model for the current `context`.\n",
    "# Both are indexed below as [0, position, :] (first sequence of the batch) -- TODO confirm the layout.\n",
    "hidden_bolt = None  # hidden states\n",
    "input_bolt = None   # embedded inputs\n",
    "if hidden_bolt is None or input_bolt is None:\n",
    "    raise ValueError(\"Load `hidden_bolt` and `input_bolt` before running this cell.\")\n",
    "\n",
    "# hidden = U @ diag(S) @ Vt, so U[:, k] is the projection onto the k-th right singular vector\n",
    "# divided by S[k]; y-ticks are hidden below, so plotting U shows the projections up to scale.\n",
    "U, S_bolt, Vt = scipy.linalg.svd(hidden_bolt[0, ...])\n",
    "\n",
    "# Vertical guides drawn on every data panel; presumably the plateau boundaries of the context -- TODO confirm.\n",
    "boundaries = (101, 202, 303, 404)\n",
    "\n",
    "fig, ax = plt.subplots(14, 1, figsize=(10, 16), sharex=True,\n",
    "                       gridspec_kw={'height_ratios': [3, 0.3, 3, 0.3] + [1] * 10})\n",
    "\n",
    "# Every panel drops its last entry ([:-1]), presumably to exclude a position appended after\n",
    "# the context -- TODO confirm.\n",
    "ax[0].plot(context[:-1], color=\"red\")\n",
    "ax[0].set_yticks([])\n",
    "ax[0].set_title(\"Input Context\")\n",
    "\n",
    "ax[2].plot(scipy.linalg.norm(input_bolt[0, :-1], axis=1), color=\"green\")\n",
    "ax[2].set_ylim(0, 1.5)\n",
    "\n",
    "for k in range(10):\n",
    "    ax[4 + k].plot(U[:-1, k])\n",
    "    ax[4 + k].set_yticks([])\n",
    "\n",
    "for panel in (ax[0], ax[2], *ax[4:]):\n",
    "    for x in boundaries:\n",
    "        panel.axvline(x, color=\"gray\", alpha=0.5)\n",
    "\n",
    "# Rows 1 and 3 are caption spacers; captions are centred in axes coordinates so they\n",
    "# stay centred regardless of the x-range of the shared axis.\n",
    "ax[1].axis('off')\n",
    "ax[1].text(0.5, 0, 'Norm of Embedded Vectors (Bolt)', fontsize=12, ha='center', transform=ax[1].transAxes)\n",
    "ax[3].axis('off')\n",
    "ax[3].text(0.5, 0, 'Encoded Context Projected onto the First 10 Singular Vectors (Bolt)', fontsize=12, ha='center', transform=ax[3].transAxes)\n",
    "\n",
    "plt.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
