{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dba5d04d-db66-4cc0-85ea-059a0707344f",
   "metadata": {},
   "source": [
    "# Scale Bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a2f9d6a-d7ab-41f8-b0e9-3f50756309d7",
   "metadata": {},
   "source": [
    "In this notebook, we analyze the scale bias in designing TSFMs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e7d2c74-4266-43ee-8a35-70de32a054fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "pd.set_option('display.max_rows', None)\n",
    "import random\n",
    "import statistics\n",
    "import scipy\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "    !pip install transformers weightwatcher\n",
    "\n",
    "from transformers import (\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForCausalLM,\n",
    "    AutoConfig,\n",
    "    T5Config,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "\n",
    "from chronos import ChronosPipeline, ChronosBoltPipeline\n",
    "\n",
    "from collections import defaultdict\n",
    "\n",
    "from matplotlib.colors import PowerNorm\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9b445b7-5b00-4280-b705-a210060dc650",
   "metadata": {},
   "source": [
    "## Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e854baa3-6799-411b-98bc-fb584ea3755a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_angles(mat):\n",
    "    L = mat.shape[0]\n",
    "    distance = np.zeros((L,L))\n",
    "    for i in range(L):\n",
    "        for j in range(i,L):\n",
    "            Vi = mat[i:(i+1),:].T\n",
    "            Vj = mat[j:(j+1),:].T\n",
    "            dist = np.mean(scipy.linalg.subspace_angles(Vi,Vj))\n",
    "            distance[i,j] = dist\n",
    "            distance[j,i] = dist\n",
    "    return distance\n",
    "\n",
    "def plot_angles(theta_rad, input_name = \"\", model_name = \"\"):\n",
    "    norm = PowerNorm(gamma=1, vmin=0, vmax=math.pi/2)\n",
    "    plt.imshow(theta_rad, norm=norm, cmap='viridis', aspect='equal')\n",
    "    cbar = plt.colorbar()\n",
    "    cbar.set_ticks([0,0.4,0.8,1.2,math.pi/2])\n",
    "    cbar.set_ticklabels([\"0.0\", \"0.4\", \"0.8\", \"1.2\", r\"$\\pi/2$\"])\n",
    "    plt.title(r'Angle between $\\mathbf{y}_i$ and $\\mathbf{y}_j$ (' + input_name + \")\")\n",
    "    plt.xlabel(r'$i$')\n",
    "    plt.ylabel(r'$j$')\n",
    "\n",
    "def perfect_align_context(num_patches):\n",
    "    signal = lambda t: np.cos(t+0.3)\n",
    "    gap = 2 * math.pi / num_patches / 16\n",
    "    samplers = np.arange(-2048, 64) * gap\n",
    "    return signal(samplers), gap\n",
    "\n",
    "def measure_periodicity(\n",
    "    data_array: np.ndarray,\n",
    "    period: int,\n",
    "    motif_length: int = 32,\n",
    ") -> float:\n",
    "    \"\"\"\n",
    "    Measures the periodicity of a 2D numpy array along its second axis.\n",
    "\n",
    "    This function works by comparing short segments of the sequence (\"motifs\")\n",
    "    with segments that occur exactly one period later. The similarity is\n",
    "    measured using the average cosine similarity between the corresponding\n",
    "    vectors in the motifs.\n",
    "\n",
    "    Args:\n",
    "        data_array (np.ndarray): The input array of shape (hidden_dim, seq_len).\n",
    "        period (int): The expected period to check for (e.g., 160).\n",
    "        motif_length (int, optional): The length of the sequence segment to\n",
    "            compare at each step. A longer motif is more robust to noise but\n",
    "            might miss finer details. Defaults to 32.\n",
    "        num_motifs_to_check (int, optional): The number of random motif pairs\n",
    "            to sample and average over. A higher number provides a more stable\n",
    "            score. Defaults to 50.\n",
    "\n",
    "    Returns:\n",
    "        float: A periodicity score between -1 and 1. A score closer to 1\n",
    "               indicates a stronger periodicity at the specified lag.\n",
    "    \"\"\"\n",
    "    \n",
    "    hidden_dim, seq_len = data_array.shape\n",
    "    max_start_index = seq_len - period - motif_length\n",
    "    all_similarities = []\n",
    "    \n",
    "    # starting points for the motifs\n",
    "    start_indices = np.arange(max_start_index + 1)\n",
    "\n",
    "    for start_idx in start_indices:\n",
    "        # Extract the first motif\n",
    "        motif1 = data_array[:, start_idx : start_idx + motif_length]\n",
    "        \n",
    "        # Extract the second motif, exactly one period later\n",
    "        motif2 = data_array[:, start_idx + period : start_idx + period + motif_length]\n",
    "        \n",
    "        _, _, r_score, _, _ = scipy.stats.linregress(motif1.flatten(), motif2.flatten())\n",
    "        all_similarities.append(r_score ** 2)\n",
    "        \n",
    "    return all_similarities\n",
    "\n",
    "\n",
    "def measure_distance(\n",
    "    data_array: np.ndarray,\n",
    "    period: int,\n",
    ") -> float:\n",
    "    \"\"\"\n",
    "    Measures the periodicity of a 2D numpy array along its first axis.\n",
    "\n",
    "    This function works by comparing short segments of the sequence (\"motifs\")\n",
    "    with segments that occur exactly one period later. The similarity is\n",
    "    measured using the average cosine similarity between the corresponding\n",
    "    vectors in the motifs.\n",
    "\n",
    "    Args:\n",
    "        data_array (np.ndarray): The input array of shape (hidden_dim, seq_len).\n",
    "        period (int): The expected period to check for (e.g., 160).\n",
    "        motif_length (int, optional): The length of the sequence segment to\n",
    "            compare at each step. A longer motif is more robust to noise but\n",
    "            might miss finer details. Defaults to 32.\n",
    "        num_motifs_to_check (int, optional): The number of random motif pairs\n",
    "            to sample and average over. A higher number provides a more stable\n",
    "            score. Defaults to 50.\n",
    "\n",
    "    Returns:\n",
    "        float: A periodicity score between -1 and 1. A score closer to 1\n",
    "               indicates a stronger periodicity at the specified lag.\n",
    "    \"\"\"\n",
    "    \n",
    "    seq_len, hidden_dim = data_array.shape\n",
    "\n",
    "    U, S, VT = scipy.linalg.svd(data_array)\n",
    "    total_energy = np.sum(S)\n",
    "    distance = np.zeros(seq_len - period)\n",
    "    r_score_ave = 0\n",
    "    \n",
    "    for i in range(hidden_dim):\n",
    "        distance += (np.abs((U[:(seq_len - period),i] - U[period:,i]))) * (S[i] / total_energy)\n",
    "        \n",
    "    return distance\n",
    "\n",
    "\n",
    "def smooth_step(start, end, num_points):\n",
    "    \"\"\"Smooth transition from start to end using cosine interpolation.\"\"\"\n",
    "    t = np.linspace(0, 1, num_points)\n",
    "    return start + (end - start) * 0.5 * (1 - np.cos(np.pi * t))\n",
    "\n",
    "\n",
    "def generate_smooth_plateaus(plateau_width=50, transition_width=10, num_plateaus=6, high_value=1.0, low_value=0.05):\n",
    "    values = []\n",
    "    current_value = high_value\n",
    "\n",
    "    for i in range(num_plateaus):\n",
    "        # Plateau region\n",
    "        plateau = np.full(plateau_width, current_value)\n",
    "        values.append(plateau)\n",
    "\n",
    "        # Transition region\n",
    "        next_value = low_value if current_value == high_value else high_value\n",
    "        transition = smooth_step(current_value, next_value, transition_width)\n",
    "        values.append(transition)\n",
    "\n",
    "        # Alternate for next iteration\n",
    "        current_value = next_value\n",
    "\n",
    "    # Flatten and return\n",
    "    return np.concatenate(values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6243996a-2f6a-4354-8006-37e3b8c25397",
   "metadata": {},
   "outputs": [],
   "source": [
    "chronos = ChronosPipeline.from_pretrained(f\"amazon/chronos-t5-small\")\n",
    "bolt = ChronosBoltPipeline.from_pretrained(f\"amazon/chronos-bolt-small\")\n",
    "bolt_f1 = # Add you pretrained Chronos-Bolt with a patch size of 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c94955f6-a469-422f-a1e2-9d7630a83d87",
   "metadata": {},
   "source": [
    "## Decoupled Scale Bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aa6b11a-b6c0-4342-b813-2f963cfec951",
   "metadata": {},
   "outputs": [],
   "source": [
    "context_gt, _ = perfect_align_context(1)\n",
    "context_gt -= 4\n",
    "arr = generate_smooth_plateaus(plateau_width=100, transition_width=10, num_plateaus=40, low_value=0.05)[:context_gt.shape[0]]\n",
    "context_gt *= arr\n",
    "context_gt += 4\n",
    "context = context_gt[(-512-64):-64]\n",
    "groundtruth = context_gt[-64:]\n",
    "plt.plot(context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c15ee944-eccb-411a-8b71-fe5e04a2f007",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_bolt = bolt.predict(\n",
    "                torch.from_numpy(context),\n",
    "                prediction_length=1,\n",
    "            ).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0127fd46-5b7b-4516-9862-37a24c6407ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_chronos = chronos.predict(\n",
    "                torch.from_numpy(context),\n",
    "                prediction_length=1,\n",
    "                num_samples=1,\n",
    "            ).numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e46c9e4-da9d-4b19-9acf-8dd03783baa8",
   "metadata": {},
   "source": [
    "## Show diagnostic plots for both Chronos and Chronos-Bolt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "891f40ae-5503-4083-b9a4-0c7d38fe282f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "hidden_chronos = # Load your saved hidden states here\n",
    "input_chronos = # Load your saved input here\n",
    "\n",
    "_,_,Vt = scipy.linalg.svd(input_chronos[0,...])\n",
    "proj_in = input_chronos @ Vt.T[:, 0:10]\n",
    "proj_in = proj_in[0,:]\n",
    "\n",
    "U,S_chronos,Vt = scipy.linalg.svd(hidden_chronos[0,...])\n",
    "proj = hidden_chronos @ Vt.T[:, 0:10]\n",
    "proj = proj[0,:]\n",
    "fig, ax = plt.subplots(14, 1, figsize=(10, 16), sharex=True, gridspec_kw={'height_ratios': [3] + [0.3] + [3] + [0.3] + [1]*10})\n",
    "ax[0].plot(context[:512], color=\"red\")\n",
    "ax[0].axvline(105, color=\"gray\", alpha=0.5)\n",
    "ax[0].axvline(215, color=\"gray\", alpha=0.5)\n",
    "ax[0].axvline(325, color=\"gray\", alpha=0.5)\n",
    "ax[0].axvline(435, color=\"gray\", alpha=0.5)\n",
    "ax[0].set_yticks([])\n",
    "\n",
    "ax[2].plot(scipy.linalg.norm(input_chronos[0,:512,:], axis=1), color=\"green\")\n",
    "ax[2].axvline(105, color=\"gray\", alpha=0.5)\n",
    "ax[2].axvline(215, color=\"gray\", alpha=0.5)\n",
    "ax[2].axvline(325, color=\"gray\", alpha=0.5)\n",
    "ax[2].axvline(435, color=\"gray\", alpha=0.5)\n",
    "ax[2].set_ylim(0.45,0.55)\n",
    "\n",
    "for i in range(4,14):\n",
    "    ax[i].plot(proj[:512, i-4])\n",
    "    ax[i].axvline(105, color=\"gray\", alpha=0.5)\n",
    "    ax[i].axvline(215, color=\"gray\", alpha=0.5)\n",
    "    ax[i].axvline(325, color=\"gray\", alpha=0.5)\n",
    "    ax[i].axvline(435, color=\"gray\", alpha=0.5)\n",
    "    ax[i].set_yticks([])\n",
    "\n",
    "ax[0].set_title(\"Input Context\")\n",
    "ax[1].axis('off')\n",
    "ax[1].text(175, 0, 'Norm of Embedded Vectors (Chronos)', fontsize=12)\n",
    "ax[3].axis('off')\n",
    "ax[3].text(98, 0, 'Encoded Context Projected onto the First 10 Singular Vectors (Chronos)', fontsize=12)\n",
    "\n",
    "plt.tight_layout()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3269a2aa-a1f3-4199-b55d-4ddc1662d082",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "hidden_bolt = # Load your saved hidden states here\n",
    "input_bolt = # Load your saved input here\n",
    "U,S_bolt,Vt = scipy.linalg.svd(hidden_bolt[0,...])\n",
    "proj = hidden_bolt @ Vt.T[:, 0:10]\n",
    "proj = proj[0,:]\n",
    "fig, ax = plt.subplots(14, 1, figsize=(10, 16), sharex=True, gridspec_kw={'height_ratios': [3] + [0.3] + [3] + [0.3] + [1]*10})\n",
    "ax[0].plot(context[:512], color=\"red\")\n",
    "ax[0].axvline(105, color=\"gray\", alpha=0.5)\n",
    "ax[0].axvline(215, color=\"gray\", alpha=0.5)\n",
    "ax[0].axvline(325, color=\"gray\", alpha=0.5)\n",
    "ax[0].axvline(435, color=\"gray\", alpha=0.5)\n",
    "ax[0].set_yticks([])\n",
    "\n",
    "ax[2].plot(scipy.linalg.norm(input_bolt[0,:512,:], axis=1), color=\"green\")\n",
    "ax[2].axvline(105, color=\"gray\", alpha=0.5)\n",
    "ax[2].axvline(215, color=\"gray\", alpha=0.5)\n",
    "ax[2].axvline(325, color=\"gray\", alpha=0.5)\n",
    "ax[2].axvline(435, color=\"gray\", alpha=0.5)\n",
    "ax[2].set_ylim(0,1.3)\n",
    "\n",
    "for i in range(4,14):\n",
    "    ax[i].plot(proj[:512, i-4])\n",
    "    ax[i].axvline(105, color=\"gray\", alpha=0.5)\n",
    "    ax[i].axvline(215, color=\"gray\", alpha=0.5)\n",
    "    ax[i].axvline(325, color=\"gray\", alpha=0.5)\n",
    "    ax[i].axvline(435, color=\"gray\", alpha=0.5)\n",
    "    ax[i].set_yticks([])\n",
    "\n",
    "ax[0].set_title(\"Input Context\")\n",
    "ax[1].axis('off')\n",
    "ax[1].text(182, 0, 'Norm of Embedded Vectors (Bolt)', fontsize=12)\n",
    "ax[3].axis('off')\n",
    "ax[3].text(107, 0, 'Encoded Context Projected onto the First 10 Singular Vectors (Bolt)', fontsize=12)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01cedaf4-fbba-4ad7-8e84-e628c44b8b0a",
   "metadata": {},
   "source": [
    "## Predicting a large scale one"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1a8f100-8ffb-430e-a360-d7a36f047b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "context_gt, _ = perfect_align_context(1)\n",
    "context_gt -= 4\n",
    "arr = generate_smooth_plateaus(plateau_width=100, transition_width=10, num_plateaus=40, low_value=0.05)[:context_gt.shape[0]]\n",
    "context_gt *= arr\n",
    "context_gt += 4\n",
    "context_gt = context_gt[:-10]\n",
    "context = context_gt[(-512-64):-64]\n",
    "groundtruth = context_gt[-64:]\n",
    "plt.plot(context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9921a025-2154-4cc1-8d63-0300abfe9d82",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_bolt = bolt_f1.predict(\n",
    "                torch.from_numpy(context),\n",
    "                prediction_length=64,\n",
    "            ).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30747286-e47c-4028-854e-7c6cdef30d48",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_chronos = chronos.predict(\n",
    "                torch.from_numpy(context),\n",
    "                prediction_length=64,\n",
    "                num_samples=1,\n",
    "            ).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "749fbe41-421f-4224-ad04-31d302975a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(context, label=\"context\")\n",
    "plt.plot(np.arange(context.shape[0], context.shape[0] + output_chronos.shape[-1]), output_chronos[0,0,:], label=\"forecast\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")\n",
    "plt.title(\"Chronos Forecast\")\n",
    "plt.legend(loc=\"upper left\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76d3497c-5cc1-4a73-b77a-63a5051af0d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(context, label=\"context\")\n",
    "plt.plot(np.arange(context.shape[0], context.shape[0] + output_bolt.shape[-1]), output_bolt[0,4,:], label=\"median forecast\")\n",
    "plt.fill_between(np.arange(context.shape[0], context.shape[0] + output_bolt.shape[-1]), output_bolt[0,2,:], output_bolt[0,6,:], color=\"orange\", alpha=0.4, label=\"30-70%\")\n",
    "plt.fill_between(np.arange(context.shape[0], context.shape[0] + output_bolt.shape[-1]), output_bolt[0,0,:], output_bolt[0,8,:], color=\"orange\", alpha=0.2, label=\"10-90%\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")\n",
    "plt.title(\"Bolt Forecast\")\n",
    "plt.legend(loc=\"upper left\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f1d78b7-abb0-4240-8fd8-8a36f1eb94e6",
   "metadata": {},
   "source": [
    "## Predicting a small scale one"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ebfd2fe-4133-40a0-99d3-baefaf4b37bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "context_gt, _ = perfect_align_context(1)\n",
    "context_gt -= 4\n",
    "arr = generate_smooth_plateaus(plateau_width=100, transition_width=10, num_plateaus=40, low_value=0.05)[:context_gt.shape[0]]\n",
    "context_gt *= arr\n",
    "context_gt += 4\n",
    "context_gt = context_gt[:-140]\n",
    "context = context_gt[(-512-64):-64]\n",
    "groundtruth = context_gt[-64:]\n",
    "plt.plot(context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60a873a1-b8ed-4a42-805f-c62331c533b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_bolt = bolt_f1.predict(\n",
    "                torch.from_numpy(context),\n",
    "                prediction_length=64,\n",
    "            ).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1e2675f-a66e-487d-bd33-a5fc30c626a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_chronos = chronos.predict(\n",
    "                torch.from_numpy(context),\n",
    "                prediction_length=64,\n",
    "                num_samples=1,\n",
    "            ).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f1c021d-145d-4893-a834-63cab4eeb5a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(context, label=\"context\")\n",
    "plt.plot(np.arange(context.shape[0], context.shape[0] + output_chronos.shape[-1]), output_chronos[0,0,:], label=\"forecast\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")\n",
    "plt.title(\"Chronos Forecast\")\n",
    "plt.legend(loc=\"upper left\")\n",
    "plt.ylim(3.5,4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d6002fe-f0e4-4b82-a51f-7cceffcf6562",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(context, label=\"context\")\n",
    "plt.plot(np.arange(context.shape[0], context.shape[0] + output_bolt.shape[-1]), output_bolt[0,4,:], label=\"median forecast\")\n",
    "plt.fill_between(np.arange(context.shape[0], context.shape[0] + output_bolt.shape[-1]), output_bolt[0,2,:], output_bolt[0,6,:], color=\"orange\", alpha=0.4, label=\"30-70%\")\n",
    "plt.fill_between(np.arange(context.shape[0], context.shape[0] + output_bolt.shape[-1]), output_bolt[0,0,:], output_bolt[0,8,:], color=\"orange\", alpha=0.2, label=\"10-90%\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")\n",
    "plt.title(\"Bolt Forecast\")\n",
    "plt.legend(loc=\"upper left\")\n",
    "plt.ylim(3.5,4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
