{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ef08cad1-d920-4c46-904f-c0be54dbfe3d",
   "metadata": {},
   "source": [
    "# Locality Bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e76ed6f4-57b8-40fe-9fbe-4efbf03f52d6",
   "metadata": {},
   "source": [
    "In this notebook, we analyze the locality bias in designing TSFMs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0f3f1c4-575c-4917-a0f5-8f15cac52c98",
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "pd.set_option('display.max_rows', None)\n",
    "import random\n",
    "import statistics\n",
    "import scipy\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "    !pip install transformers weightwatcher\n",
    "\n",
    "from transformers import (\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForCausalLM,\n",
    "    AutoConfig,\n",
    "    T5Config,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "\n",
    "from chronos import ChronosPipeline, ChronosBoltPipeline\n",
    "\n",
    "from collections import defaultdict\n",
    "\n",
    "from matplotlib.colors import PowerNorm\n",
    "import math\n",
    "\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "from matplotlib.colors import TwoSlopeNorm, LinearSegmentedColormap\n",
    "\n",
    "from matplotlib.patches import Ellipse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e29ebdd6-92b3-4802-bd8d-07122db795c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_angles(mat):\n",
    "    L = mat.shape[0]\n",
    "    angle = np.zeros((L,L))\n",
    "    for i in range(L):\n",
    "        for j in range(i,L):\n",
    "            Vi = mat[i:(i+1),:].T\n",
    "            Vj = mat[j:(j+1),:].T\n",
    "            dist = np.mean(scipy.linalg.subspace_angles(Vi,Vj))\n",
    "            angle[i,j] = dist\n",
    "            angle[j,i] = dist\n",
    "    return angle\n",
    "\n",
    "def compute_distance(mat):\n",
    "    L = mat.shape[0]\n",
    "    distance = np.zeros((L,L))\n",
    "    for i in range(L):\n",
    "        for j in range(i,L):\n",
    "            Vi = mat[i:(i+1),:].T\n",
    "            Vj = mat[j:(j+1),:].T\n",
    "            dist = np.linalg.norm(Vi-Vj) / (np.linalg.norm(Vi) + np.linalg.norm(Vj))\n",
    "            distance[i,j] = dist\n",
    "            distance[j,i] = dist\n",
    "    return distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "164ddf03-25b3-463d-81a4-af14f2e54952",
   "metadata": {},
   "outputs": [],
   "source": [
    "bolt_p1 = # Add you pretrained Chronos-Bolt with a patch size of 1\n",
    "bolt = ChronosBoltPipeline.from_pretrained(f\"amazon/chronos-bolt-base\")\n",
    "chronos = ChronosPipeline.from_pretrained(f\"amazon/chronos-t5-small\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a730f68-8c07-4fad-bba7-cf2ba705ac45",
   "metadata": {},
   "source": [
    "## Embedding Distance and Angle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39011681-85f6-4705-9bce-299092d66ac5",
   "metadata": {},
   "outputs": [],
   "source": [
    "embed_mat = chronos.model.model.shared.weight.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdb49788-e52d-4ff6-9419-03ad308b6f3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "embed_mat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5bf211b-3882-400b-98a9-08bb5b65b5ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "angles_chronos = compute_angles(embed_mat[(2048-512):(2048+512),:])\n",
    "distance_chronos = compute_distance(embed_mat[(2048-512):(2048+512),:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85218e61-4038-40a6-840b-9f3bfcae2caf",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(angles_chronos, origin='lower')\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b0e6d03-203f-49dc-b4fe-44f798e61fc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(distance_chronos, origin='lower')\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b83648f-3c9a-442a-a6be-8913d67d43f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_grid = np.expand_dims(np.linspace(-15/4,15/4,1024), axis=-1)\n",
    "input_mask = np.expand_dims(np.ones(1024), axis=-1)\n",
    "input_grid = np.concatenate((input_grid, input_mask), axis=-1)\n",
    "embed_bolt = bolt_p1.model.input_patch_embedding(torch.from_numpy(input_grid).to(bolt_p1.model.input_patch_embedding.hidden_layer.weight.dtype))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b210a67-c36d-4eda-979c-9f05ea1f5d9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "angles_bolt = compute_angles(embed_bolt.detach().numpy())\n",
    "distance_bolt = compute_distance(embed_bolt.detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4e5bf42-5b57-4b00-8d0a-67629d158611",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(angles_bolt, origin='lower')\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c9d11da-6e5d-4f0a-823e-e3be2859a7c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(distance_bolt, origin='lower')\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1eaeaad5-dddb-4044-b3f8-1aad1454a5d8",
   "metadata": {},
   "source": [
    "## Is the order preserved?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc24ee8c-6a32-4a4f-9284-e9795a9690a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "a = np.random.randn(15)\n",
    "b = np.random.randn(15)\n",
    "context = []\n",
    "for i in range(10):\n",
    "    context += [a[i]] * 10\n",
    "    context += [b[i]] * 10\n",
    "    if a[i] > b[i]:\n",
    "        context += [0] * 10\n",
    "    else:\n",
    "        context += [1] * 10\n",
    "context += [a[-1]] * 10\n",
    "context += [b[-1]] * 10\n",
    "plt.plot(np.array(context))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44d3a6ce-707a-46da-9801-652f07825ec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_bolt_1 = bolt_p1.predict(\n",
    "                torch.from_numpy(np.array(context)),\n",
    "                prediction_length=64,\n",
    "            ).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3525b304-eb6a-441c-b998-4ec90dd09264",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(output_bolt_1[0,4,:])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed25480e-dc31-4744-b2fb-f1fdb1453c86",
   "metadata": {},
   "source": [
    "## A good example and a bad example"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ae1e5f8-7072-4a1b-bcdb-8ebbffc01101",
   "metadata": {},
   "source": [
    "### 1. An example that favors Chronos-Bolt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35b70d3e-2520-44b8-9bbc-1553633e6329",
   "metadata": {},
   "outputs": [],
   "source": [
    "context = np.cos(np.arange(512+64) / 7)\n",
    "scale = np.arange(context.shape[0] // 2) / 512\n",
    "context = context[0:(context.shape[0] // 2)] * ((0.1 + scale ** 2) ** -1)\n",
    "context = np.concatenate((np.flip(context), context))\n",
    "groundtruth = context[512:]\n",
    "context = torch.from_numpy(context[:512])\n",
    "\n",
    "output = chronos.predict(\n",
    "    context,\n",
    "    prediction_length=64,\n",
    "    num_samples=1,\n",
    ").numpy()\n",
    "\n",
    "output_bolt = bolt.predict(\n",
    "    context,\n",
    "    prediction_length=64,\n",
    ").numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f1e06d9-37ca-439b-b0ba-02478b3135f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(context,label=\"context\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), output[0,0,:], label=\"forecast\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), groundtruth, color=\"gray\", label=\"ground truth\", alpha=0.3)\n",
    "plt.plot(np.concatenate((np.flip((0.1 + scale ** 2) ** -1), (0.1 + scale ** 2) ** -1)), '--', color=\"gray\")\n",
    "plt.plot(np.concatenate((-np.flip((0.1 + scale ** 2) ** -1), -(0.1 + scale ** 2) ** -1)), '--', color=\"gray\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0606a296-9100-4e79-a237-aacad29683b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(context,label=\"context\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), output_bolt[0,4,:], label=\"forecast\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), groundtruth, color=\"gray\", label=\"ground truth\", alpha=0.3)\n",
    "plt.plot(np.concatenate((np.flip((0.1 + scale ** 2) ** -1), (0.1 + scale ** 2) ** -1)), '--', color=\"gray\")\n",
    "plt.plot(np.concatenate((-np.flip((0.1 + scale ** 2) ** -1), -(0.1 + scale ** 2) ** -1)), '--', color=\"gray\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b570281f-015e-423e-8f95-c11fd9a0e4a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "att1 = # load your saved attention matrices\n",
    "att2 = # load your saved attention matrices\n",
    "att3 = # load your saved attention matrices\n",
    "att4 = # load your saved attention matrices\n",
    "att5 = # load your saved attention matrices\n",
    "att6 = # load your saved attention matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a7c2e7e-b674-4118-ac1a-72fce6fcdc9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "att_all = np.concatenate((att1,att2,att3,att4,att5,att6), axis=1)\n",
    "att_all = att_all[0,:,0,:]\n",
    "scale = np.arange((512+64) // 2) / 512\n",
    "scale = np.concatenate((np.flip((0.1 + scale ** 2) ** -1), (0.1 + scale ** 2) ** -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae98431b-1938-45ce-a519-244504d92f12",
   "metadata": {},
   "outputs": [],
   "source": [
    "q10 = np.quantile(att_all, 0.10, axis=0)\n",
    "q30 = np.quantile(att_all, 0.30, axis=0)\n",
    "q50 = np.quantile(att_all, 0.50, axis=0)  # median\n",
    "q70 = np.quantile(att_all, 0.70, axis=0)\n",
    "q90 = np.quantile(att_all, 0.90, axis=0)\n",
    "\n",
    "# X-axis: index from 0 to 512\n",
    "x = np.arange(att_all.shape[1])\n",
    "\n",
    "# --- Create two subplots sharing the x-axis ---\n",
    "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7), sharex=True, gridspec_kw={'height_ratios': [3, 3]})\n",
    "\n",
    "# --- Top: quantile plot ---\n",
    "ax1.plot(x, q50, color='red', label='Median')\n",
    "ax1.fill_between(x, q30, q70, color='red', alpha=0.4, label='30–70%')\n",
    "ax1.fill_between(x, q10, q90, color='red', alpha=0.2, label='10–90%')\n",
    "ax1.set_ylabel(\"Attention Score\")\n",
    "ax1.set_ylim(-0.002,0.03)\n",
    "ax1.legend()\n",
    "ax1.set_title(\"Chronos Cross Attention Scores\")\n",
    "\n",
    "# --- Bottom: context + hidden ---\n",
    "ax2.plot(x[:context.shape[0]], context, label='Context')\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output[0,0,:], label='Forecast', color='orange')\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), groundtruth, label='Ground Truth', color='green', alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), scale, ':', color=\"magenta\", alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), -scale, ':', color=\"magenta\", alpha=0.3)\n",
    "ax2.set_xlabel(\"Time Step\")\n",
    "ax2.set_ylabel(\"Value\")\n",
    "ax2.legend()\n",
    "\n",
    "# --- Final layout ---\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a5a9fbc-5f15-486c-94b6-a3dee017ecda",
   "metadata": {},
   "outputs": [],
   "source": [
    "att1 = # load your saved attention matrices\n",
    "att2 = # load your saved attention matrices\n",
    "att3 = # load your saved attention matrices\n",
    "att4 = # load your saved attention matrices\n",
    "att5 = # load your saved attention matrices\n",
    "att6 = # load your saved attention matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07295272-a7b6-41b4-b9bf-25f224dda9d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "att_all_bolt = np.concatenate((att1,att2,att3,att4,att5,att6), axis=1)\n",
    "att_all_bolt = att_all_bolt[0,:,0,:]\n",
    "scale = np.arange((512+64) // 2) / 512\n",
    "scale = np.concatenate((np.flip((0.1 + scale ** 2) ** -1), (0.1 + scale ** 2) ** -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c083e37-f9aa-4e00-a2cf-40b452c8d888",
   "metadata": {},
   "outputs": [],
   "source": [
    "q10 = np.quantile(att_all_bolt, 0.10, axis=0)\n",
    "q30 = np.quantile(att_all_bolt, 0.30, axis=0)\n",
    "q50 = np.quantile(att_all_bolt, 0.50, axis=0)  # median\n",
    "q70 = np.quantile(att_all_bolt, 0.70, axis=0)\n",
    "q90 = np.quantile(att_all_bolt, 0.90, axis=0)\n",
    "\n",
    "# X-axis: index from 0 to 512\n",
    "x = np.arange(att_all.shape[1])\n",
    "\n",
    "# --- Create two subplots sharing the x-axis ---\n",
    "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7), sharex=True, gridspec_kw={'height_ratios': [3, 3]})\n",
    "\n",
    "# --- Top: quantile plot ---\n",
    "ax1.plot(x, q50, color='red', label='Median')\n",
    "ax1.fill_between(x, q30, q70, color='red', alpha=0.4, label='30–70%')\n",
    "ax1.fill_between(x, q10, q90, color='red', alpha=0.2, label='10–90%')\n",
    "ax1.set_ylabel(\"Attention Score\")\n",
    "ax1.set_ylim(-0.002,0.03)\n",
    "ax1.legend()\n",
    "ax1.set_title(\"Chronos Cross Attention Scores\")\n",
    "\n",
    "# --- Bottom: context + hidden ---\n",
    "ax2.plot(x[:context.shape[0]], context, label='Context')\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output_bolt[0,4,:], label='Forecast', color='orange')\n",
    "ax2.fill_between(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output_bolt[0,2,:], output_bolt[0,6,:], label='Forecast', color='orange', alpha=0.4)\n",
    "ax2.fill_between(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output_bolt[0,0,:], output_bolt[0,8,:], label='Forecast', color='orange', alpha=0.2)\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), groundtruth, label='Ground Truth', color='green', alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), scale, ':', color=\"magenta\", alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), -scale, ':', color=\"magenta\", alpha=0.3)\n",
    "ax2.set_xlabel(\"Time Step\")\n",
    "ax2.set_ylabel(\"Value\")\n",
    "ax2.legend()\n",
    "\n",
    "# --- Final layout ---\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b739bc3c-2606-4651-b7bb-773da73d2af7",
   "metadata": {},
   "source": [
    "### 2. An example that does not favor either model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5746ccc6-e350-4a26-a71f-fd876ca29702",
   "metadata": {},
   "outputs": [],
   "source": [
    "context = np.sin(np.arange(512+64) / 7.5)\n",
    "scale = np.arange(context.shape[0]) / 512 +1\n",
    "context = context * (scale ** -2)\n",
    "groundtruth = context[512:]\n",
    "context = torch.from_numpy(context[:512])\n",
    "\n",
    "output = chronos.predict(\n",
    "    context,\n",
    "    prediction_length=64,\n",
    "    num_samples=1,\n",
    ").numpy()\n",
    "\n",
    "output_bolt = bolt.predict(\n",
    "    context,\n",
    "    prediction_length=64,\n",
    ").numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9412f0e9-ba16-409a-949e-8b78cb606e87",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(context,label=\"context\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), output[0,0,:], label=\"forecast\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), groundtruth, color=\"gray\", label=\"ground truth\", alpha=0.3)\n",
    "plt.plot(scale ** -2, '--', color=\"gray\")\n",
    "plt.plot(-scale ** -2, '--', color=\"gray\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5bd3abc-0943-4ef4-ac09-af3465d68422",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(context,label=\"context\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), output_bolt[0,4,:], label=\"forecast\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), groundtruth, color=\"gray\", label=\"ground truth\", alpha=0.3)\n",
    "plt.plot(scale ** -2, '--', color=\"gray\")\n",
    "plt.plot(-scale ** -2, '--', color=\"gray\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ffb71d9-1d77-477f-8479-0c947ed27730",
   "metadata": {},
   "outputs": [],
   "source": [
    "att1 = # load your saved attention matrices\n",
    "att2 = # load your saved attention matrices\n",
    "att3 = # load your saved attention matrices\n",
    "att4 = # load your saved attention matrices\n",
    "att5 = # load your saved attention matrices\n",
    "att6 = # load your saved attention matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e3d5e54-93cf-4cdd-a8bc-83bf4515b629",
   "metadata": {},
   "outputs": [],
   "source": [
    "att_all = np.concatenate((att1,att2,att3,att4,att5,att6), axis=1)\n",
    "att_all = att_all[0,:,0,:]\n",
    "scale = scale ** -2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "977cc5ef-bfaf-4df3-9196-131c30ea6f15",
   "metadata": {},
   "outputs": [],
   "source": [
    "q10 = np.quantile(att_all, 0.10, axis=0)\n",
    "q30 = np.quantile(att_all, 0.30, axis=0)\n",
    "q50 = np.quantile(att_all, 0.50, axis=0)  # median\n",
    "q70 = np.quantile(att_all, 0.70, axis=0)\n",
    "q90 = np.quantile(att_all, 0.90, axis=0)\n",
    "\n",
    "# X-axis: index from 0 to 512\n",
    "x = np.arange(att_all.shape[1])\n",
    "\n",
    "# --- Create two subplots sharing the x-axis ---\n",
    "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7), sharex=True, gridspec_kw={'height_ratios': [3, 3]})\n",
    "\n",
    "# --- Top: quantile plot ---\n",
    "ax1.plot(x, q50, color='red', label='Median')\n",
    "ax1.fill_between(x, q30, q70, color='red', alpha=0.4, label='30–70%')\n",
    "ax1.fill_between(x, q10, q90, color='red', alpha=0.2, label='10–90%')\n",
    "ax1.set_ylabel(\"Attention Score\")\n",
    "ax1.set_ylim(-0.002,0.03)\n",
    "ax1.legend()\n",
    "ax1.set_title(\"Chronos Cross Attention Scores\")\n",
    "\n",
    "# --- Bottom: context + hidden ---\n",
    "ax2.plot(x[:context.shape[0]], context, label='Context')\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output[0,0,:], label='Forecast', color='orange')\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), groundtruth, label='Ground Truth', color='green', alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), scale, ':', color=\"magenta\", alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), -scale, ':', color=\"magenta\", alpha=0.3)\n",
    "ax2.set_xlabel(\"Time Step\")\n",
    "ax2.set_ylabel(\"Value\")\n",
    "ax2.legend()\n",
    "\n",
    "# --- Final layout ---\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b93979bd-bef6-4669-b739-f5a254a1c6e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "att1 = # load your saved attention matrices\n",
    "att2 = # load your saved attention matrices\n",
    "att3 = # load your saved attention matrices\n",
    "att4 = # load your saved attention matrices\n",
    "att5 = # load your saved attention matrices\n",
    "att6 = # load your saved attention matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0faf4366-a00c-4df0-b325-b20e55d4c27c",
   "metadata": {},
   "outputs": [],
   "source": [
    "att_all_bolt = np.concatenate((att1,att2,att3,att4,att5,att6), axis=1)\n",
    "att_all_bolt = att_all_bolt[0,:,0,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8483e9ff-fa87-4012-b104-a1608542e982",
   "metadata": {},
   "outputs": [],
   "source": [
    "q10 = np.quantile(att_all_bolt, 0.10, axis=0)\n",
    "q30 = np.quantile(att_all_bolt, 0.30, axis=0)\n",
    "q50 = np.quantile(att_all_bolt, 0.50, axis=0)  # median\n",
    "q70 = np.quantile(att_all_bolt, 0.70, axis=0)\n",
    "q90 = np.quantile(att_all_bolt, 0.90, axis=0)\n",
    "\n",
    "# X-axis: index from 0 to 512\n",
    "x = np.arange(att_all.shape[1])\n",
    "\n",
    "# --- Create two subplots sharing the x-axis ---\n",
    "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7), sharex=True, gridspec_kw={'height_ratios': [3, 3]})\n",
    "\n",
    "# --- Top: quantile plot ---\n",
    "ax1.plot(x, q50, color='red', label='Median')\n",
    "ax1.fill_between(x, q30, q70, color='red', alpha=0.4, label='30–70%')\n",
    "ax1.fill_between(x, q10, q90, color='red', alpha=0.2, label='10–90%')\n",
    "ax1.set_ylabel(\"Attention Score\")\n",
    "ax1.set_ylim(-0.002,0.03)\n",
    "ax1.legend()\n",
    "ax1.set_title(\"Chronos Cross Attention Scores\")\n",
    "\n",
    "# --- Bottom: context + hidden ---\n",
    "ax2.plot(x[:context.shape[0]], context, label='Context')\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output_bolt[0,4,:], label='Forecast', color='orange')\n",
    "ax2.fill_between(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output_bolt[0,2,:], output_bolt[0,6,:], label='Forecast', color='orange', alpha=0.4)\n",
    "ax2.fill_between(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output_bolt[0,0,:], output_bolt[0,8,:], label='Forecast', color='orange', alpha=0.2)\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), groundtruth, label='Ground Truth', color='green', alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), scale, ':', color=\"magenta\", alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), -scale, ':', color=\"magenta\", alpha=0.3)\n",
    "ax2.set_xlabel(\"Time Step\")\n",
    "ax2.set_ylabel(\"Value\")\n",
    "ax2.legend()\n",
    "\n",
    "# --- Final layout ---\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e306cb5a-c9a2-48c0-9393-bd7a37788466",
   "metadata": {},
   "source": [
    "### 3. An example that favors Chronos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "260011fb-36ad-471b-ba82-3d549906f234",
   "metadata": {},
   "outputs": [],
   "source": [
    "context = np.sin(np.arange(512+64) / 7)\n",
    "scale = np.arange(context.shape[0]) * 0 + 1\n",
    "context = context * scale\n",
    "groundtruth = context[512:]\n",
    "context = torch.from_numpy(context[:512])\n",
    "\n",
    "output = chronos.predict(\n",
    "    context,\n",
    "    prediction_length=64,\n",
    "    num_samples=1,\n",
    ").numpy()\n",
    "\n",
    "output_bolt = bolt.predict(\n",
    "    context,\n",
    "    prediction_length=64,\n",
    ").numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "229b9ccf-0851-4231-9fd7-dba3a6d9e3d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(context,label=\"context\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), output[0,0,:], label=\"forecast\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), groundtruth, color=\"gray\", label=\"ground truth\", alpha=0.3)\n",
    "plt.plot(scale, '--', color=\"gray\")\n",
    "plt.plot(-scale, '--', color=\"gray\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42c59a61-01ac-47c5-8d3b-8b3226d1e352",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(context,label=\"context\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), output_bolt[0,4,:], label=\"forecast\")\n",
    "plt.plot(np.arange(context.shape[0],context.shape[0]+output.shape[2]), groundtruth, color=\"gray\", label=\"ground truth\", alpha=0.3)\n",
    "plt.plot(scale, '--', color=\"gray\")\n",
    "plt.plot(-scale, '--', color=\"gray\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a94d9d93-ea3d-40c5-a79f-455889c20c64",
   "metadata": {},
   "outputs": [],
   "source": [
    "att1 = # load your saved attention matrices\n",
    "att2 = # load your saved attention matrices\n",
    "att3 = # load your saved attention matrices\n",
    "att4 = # load your saved attention matrices\n",
    "att5 = # load your saved attention matrices\n",
    "att6 = # load your saved attention matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5cf4bbc-cf3b-4938-94e4-d8cb4db6d239",
   "metadata": {},
   "outputs": [],
   "source": [
    "att_all = np.concatenate((att1,att2,att3,att4,att5,att6), axis=1)\n",
    "att_all = att_all[0,:,0,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe63cc44-d5b6-455c-88e4-9aca2e0aaea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "q10 = np.quantile(att_all, 0.10, axis=0)\n",
    "q30 = np.quantile(att_all, 0.30, axis=0)\n",
    "q50 = np.quantile(att_all, 0.50, axis=0)  # median\n",
    "q70 = np.quantile(att_all, 0.70, axis=0)\n",
    "q90 = np.quantile(att_all, 0.90, axis=0)\n",
    "\n",
    "# X-axis: index from 0 to 512\n",
    "x = np.arange(att_all.shape[1])\n",
    "\n",
    "# --- Create two subplots sharing the x-axis ---\n",
    "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7), sharex=True, gridspec_kw={'height_ratios': [3, 3]})\n",
    "\n",
    "# --- Top: quantile plot ---\n",
    "ax1.plot(x, q50, color='red', label='Median')\n",
    "ax1.fill_between(x, q30, q70, color='red', alpha=0.4, label='30–70%')\n",
    "ax1.fill_between(x, q10, q90, color='red', alpha=0.2, label='10–90%')\n",
    "ax1.set_ylabel(\"Attention Score\")\n",
    "ax1.set_ylim(-0.002,0.03)\n",
    "ax1.legend()\n",
    "ax1.set_title(\"Chronos Cross Attention Scores\")\n",
    "\n",
    "# --- Bottom: context + hidden ---\n",
    "ax2.plot(x[:context.shape[0]], context, label='Context')\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output[0,0,:], label='Forecast', color='orange')\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), groundtruth, label='Ground Truth', color='green', alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), scale, ':', color=\"magenta\", alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), -scale, ':', color=\"magenta\", alpha=0.3)\n",
    "ax2.set_xlabel(\"Time Step\")\n",
    "ax2.set_ylabel(\"Value\")\n",
    "ax2.legend()\n",
    "\n",
    "# --- Final layout ---\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "822ef8bf-8511-4669-a43d-55be4d91a9aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "att1 = # load your saved attention matrices\n",
    "att2 = # load your saved attention matrices\n",
    "att3 = # load your saved attention matrices\n",
    "att4 = # load your saved attention matrices\n",
    "att5 = # load your saved attention matrices\n",
    "att6 = # load your saved attention matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b373dff6-e7ca-44cc-9b51-5bc6a7c51182",
   "metadata": {},
   "outputs": [],
   "source": [
    "att_all_bolt = np.concatenate((att1,att2,att3,att4,att5,att6), axis=1)\n",
    "att_all_bolt = att_all_bolt[0,:,0,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2d8982e-05b8-4633-b18e-f44d0ce844dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "q10 = np.quantile(att_all_bolt, 0.10, axis=0)\n",
    "q30 = np.quantile(att_all_bolt, 0.30, axis=0)\n",
    "q50 = np.quantile(att_all_bolt, 0.50, axis=0)  # median\n",
    "q70 = np.quantile(att_all_bolt, 0.70, axis=0)\n",
    "q90 = np.quantile(att_all_bolt, 0.90, axis=0)\n",
    "\n",
    "# X-axis: index from 0 to 512\n",
    "x = np.arange(att_all.shape[1])\n",
    "\n",
    "# --- Create two subplots sharing the x-axis ---\n",
    "fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7), sharex=True, gridspec_kw={'height_ratios': [3, 3]})\n",
    "\n",
    "# --- Top: quantile plot ---\n",
    "ax1.plot(x, q50, color='red', label='Median')\n",
    "ax1.fill_between(x, q30, q70, color='red', alpha=0.4, label='30–70%')\n",
    "ax1.fill_between(x, q10, q90, color='red', alpha=0.2, label='10–90%')\n",
    "ax1.set_ylabel(\"Attention Score\")\n",
    "ax1.set_ylim(-0.002,0.03)\n",
    "ax1.legend()\n",
    "ax1.set_title(\"Chronos Cross Attention Scores\")\n",
    "\n",
    "# --- Bottom: context + hidden ---\n",
    "ax2.plot(x[:context.shape[0]], context, label='Context')\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output_bolt[0,4,:], label='Forecast', color='orange')\n",
    "ax2.fill_between(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output_bolt[0,2,:], output_bolt[0,6,:], label='Forecast', color='orange', alpha=0.4)\n",
    "ax2.fill_between(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), output_bolt[0,0,:], output_bolt[0,8,:], label='Forecast', color='orange', alpha=0.2)\n",
    "ax2.plot(np.arange(att_all.shape[1], att_all.shape[1]+output.shape[-1]), groundtruth, label='Ground Truth', color='green', alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), scale, ':', color=\"magenta\", alpha=0.3)\n",
    "plt.plot(np.arange(scale.shape[0]), -scale, ':', color=\"magenta\", alpha=0.3)\n",
    "ax2.set_xlabel(\"Time Step\")\n",
    "ax2.set_ylabel(\"Value\")\n",
    "ax2.legend()\n",
    "\n",
    "# --- Final layout ---\n",
    "plt.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
