{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "292fae7a-e3d3-477c-bfce-65965e3e060e",
   "metadata": {},
   "source": [
    "# Frequency Bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "648969a6-e554-4dba-acd1-95c6f10e1b26",
   "metadata": {},
   "source": [
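    "In this notebook, we analyze the frequency bias that design choices introduce in time-series foundation models (TSFMs), comparing Chronos-Bolt with a patch size of 16 against a variant with a patch size of 1."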
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f06774ef-529c-46e8-b806-f911b1e6199b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "pd.set_option('display.max_rows', None)\n",
    "import random\n",
    "import statistics\n",
    "import scipy\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "    !pip install transformers weightwatcher\n",
    "\n",
    "from transformers import (\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForCausalLM,\n",
    "    AutoConfig,\n",
    "    T5Config,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "\n",
    "from chronos import ChronosPipeline, ChronosBoltPipeline\n",
    "\n",
    "from collections import defaultdict\n",
    "\n",
    "from matplotlib.colors import PowerNorm\n",
    "import math\n",
    "\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "from matplotlib.colors import TwoSlopeNorm, LinearSegmentedColormap\n",
    "\n",
    "from matplotlib.patches import Ellipse"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "318c0357-a559-4951-89d3-ce0e478bfd3a",
   "metadata": {},
   "source": [
    "## Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bb2faff-db18-4baa-94e9-0f13af2949f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_angles(mat):\n",
    "    L = mat.shape[0]\n",
    "    distance = np.zeros((L,L))\n",
    "    for i in range(L):\n",
    "        for j in range(i,L):\n",
    "            Vi = mat[i:(i+1),:].T\n",
    "            Vj = mat[j:(j+1),:].T\n",
    "            dist = np.mean(scipy.linalg.subspace_angles(Vi,Vj))\n",
    "            distance[i,j] = dist\n",
    "            distance[j,i] = dist\n",
    "    return distance\n",
    "\n",
    "def plot_angles(theta_rad, context, input_name=\"\", model_name=\"\"):\n",
    "    \"\"\"\n",
    "    Plot a heatmap of pairwise angles above the time-series context, sharing the x-axis.\n",
    "\n",
    "    Args:\n",
    "        theta_rad (np.ndarray): 2D array of angles in radians (e.g. from ``compute_angles``).\n",
    "            The colour scale is fixed to [0, pi/2].\n",
    "        context (np.ndarray): 1D time-series context drawn in the bottom panel.\n",
    "        input_name (str, optional): label appended to the heatmap title.\n",
    "        model_name (str, optional): accepted for call compatibility; currently unused.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (fig, ax), where ax[0] holds the heatmap and ax[1] the context.\n",
    "    \"\"\"\n",
    "    # Two stacked panels; sharex lines heatmap column j up with context time step j.\n",
    "    fig, ax = plt.subplots(2, 1, figsize=(10, 12), sharex=True, gridspec_kw={'height_ratios': [10, 3]})\n",
    "\n",
    "    # Top panel: heatmap. gamma=0.5 stretches the low end of the scale so small angles stay distinguishable.\n",
    "    norm = PowerNorm(gamma=0.5, vmin=0, vmax=math.pi/2)\n",
    "    im = ax[0].imshow(theta_rad, norm=norm, cmap='viridis', aspect='auto')\n",
    "    ax[0].set_title(r'Angle between $\\mathbf{x}_i$ and $\\mathbf{x}_j$ (' + input_name + \")\")\n",
    "    ax[0].set_ylabel(\"Sequence Index i\")\n",
    "\n",
    "    # Bottom panel: the context itself, on a fixed [-2, 2] range.\n",
    "    ax[1].plot(context)\n",
    "    ax[1].set_ylim(-2, 2)\n",
    "    ax[1].set_ylabel(\"Amplitude\")\n",
    "    ax[1].set_xlabel(\"Sequence Index j\")\n",
    "    ax[1].grid(True, linestyle=':')\n",
    "\n",
    "    # Colorbar in a dedicated axes to the right of both panels.\n",
    "    fig.subplots_adjust(right=0.8)\n",
    "    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n",
    "    cbar = fig.colorbar(im, cax=cbar_ax)\n",
    "    cbar.set_label(\"angle (rad)\")\n",
    "\n",
    "    return fig, ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54682f3a-c018-402c-89dd-e4c76bc94db0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Public checkpoints from the Hugging Face Hub (`bolt` is the patch-size-16 model used below).\n",
    "bolt = ChronosBoltPipeline.from_pretrained(\"amazon/chronos-bolt-small\")\n",
    "chronos = ChronosPipeline.from_pretrained(\"amazon/chronos-t5-small\")\n",
    "\n",
    "# TODO: point this at your own pretrained Chronos-Bolt checkpoint with a patch size of 1\n",
    "# (local directory or Hub id). Loading is expected to fail until this is set.\n",
    "BOLT_P1_CHECKPOINT = \"path/to/your/chronos-bolt-patch-size-1\"\n",
    "bolt_p1 = ChronosBoltPipeline.from_pretrained(BOLT_P1_CHECKPOINT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b219d537-8201-412d-816b-d71f045426f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_bolt_output(output, context, groundtruth = None):\n",
    "    \"\"\"Plot a context window followed by a Chronos-Bolt quantile forecast.\n",
    "\n",
    "    Args:\n",
    "        output (np.ndarray): forecast of shape (batch, n_quantiles, horizon); only batch item 0\n",
    "            is drawn. Quantile rows 0/2/4/6/8 are shown as 10%/30%/median/70%/90%, which\n",
    "            assumes the nine levels 0.1, ..., 0.9 (TODO confirm for custom checkpoints).\n",
    "        context (np.ndarray): 1D history, drawn at time steps 0 .. len(context) - 1.\n",
    "        groundtruth (np.ndarray, optional): 1D future values, drawn right after the context.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (fig, ax).\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    n_ctx = context.shape[0]\n",
    "    ax.plot(np.arange(n_ctx), context, label=\"context\")\n",
    "    if groundtruth is not None:\n",
    "        ax.plot(np.arange(n_ctx, n_ctx + groundtruth.shape[0]), groundtruth, alpha=0.5, color='green', label=\"ground truth\")\n",
    "    # Forecast time steps start immediately after the context.\n",
    "    horizon = np.arange(n_ctx, n_ctx + output.shape[2])\n",
    "    ax.plot(horizon, output[0, 4, :], label=\"median\")\n",
    "    ax.fill_between(horizon, output[0, 2, :], output[0, 6, :], color='orange', alpha=0.4, label=\"30-70%\")\n",
    "    ax.fill_between(horizon, output[0, 0, :], output[0, 8, :], color='orange', alpha=0.2, label=\"10-90%\")\n",
    "    ax.legend()\n",
    "    return fig, ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f7b44ec-7318-4888-87bc-a36222c0a585",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _attention_quantile(att_all, q, repeat_num, length=512):\n",
    "    \"\"\"Per-position q-quantile of `att_all` over axis 0, stretched onto time steps and cut to `length`.\"\"\"\n",
    "    # Each token is repeated over the `repeat_num` time steps it covers; dividing by\n",
    "    # repeat_num spreads its score evenly over those steps.\n",
    "    return np.repeat(np.quantile(att_all, q, axis=0), repeat_num)[:length] / repeat_num\n",
    "\n",
    "\n",
    "def _plot_top_projections(axes, features, repeat_num, vlines, **plot_kwargs):\n",
    "    \"\"\"Plot features[0] projected onto its leading right-singular vectors, one direction per axis.\"\"\"\n",
    "    _, _, Vt = scipy.linalg.svd(features[0, ...])\n",
    "    proj = (features @ Vt.T[:, 0:len(axes)])[0, :]\n",
    "    for k, axis in enumerate(axes):\n",
    "        # NOTE(review): truncation to 511 steps (vs. 512 for the attention curves) is kept\n",
    "        # from the original -- confirm whether dropping the last step is intended.\n",
    "        axis.plot(np.repeat(proj[:, k], repeat_num)[:511], **plot_kwargs)\n",
    "        for x_line in vlines:\n",
    "            axis.axvline(x=x_line, color=\"gray\", alpha = 0.5)\n",
    "\n",
    "\n",
    "def plot_all(directory, context, output, is_bolt, repeat_num = 1, savefig = None):\n",
    "    \"\"\"Summary figure: attention quantiles, context + forecast, and SVD views of the embedded/encoded context.\n",
    "\n",
    "    Args:\n",
    "        directory (str): prefix prepended to the dump file names (include the trailing separator).\n",
    "            Reads attention8/10/12/14/16/18.npy, input1.npy and hidden6.npy.\n",
    "        context (np.ndarray): 1D context; its last 512 steps are plotted.\n",
    "        output (np.ndarray): forecast array of shape (batch, n_rows, horizon). For Bolt, rows\n",
    "            0/2/4/6/8 are drawn as 10%/30%/median/70%/90%; otherwise row 0 is drawn.\n",
    "        is_bolt (bool): selects the Bolt or Chronos styling and title.\n",
    "        repeat_num (int): time steps per token (e.g. the patch size) used to stretch\n",
    "            token-level quantities onto the time axis.\n",
    "        savefig (str, optional): if given, the figure is saved to `savefig`.png and `savefig`.eps.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (fig, ax) with 15 vertically stacked axes sharing x.\n",
    "    \"\"\"\n",
    "    # Attention dumps concatenated along axis 1, then batch 0 / index 0 of axis 2 selected.\n",
    "    att_all = np.concatenate([np.load(directory + f\"attention{k}.npy\") for k in range(8, 19, 2)], axis=1)\n",
    "    att_all = att_all[0, :, 0, :]\n",
    "    embed = np.load(directory + \"input1.npy\")\n",
    "    hidden = np.load(directory + \"hidden6.npy\")\n",
    "\n",
    "    q10, q30, q50, q70, q90 = (_attention_quantile(att_all, q, repeat_num) for q in (0.10, 0.30, 0.50, 0.70, 0.90))\n",
    "    x = np.arange(512)\n",
    "    # The forecast is drawn right after the attention-covered window.\n",
    "    horizon = np.arange(q10.shape[0], q10.shape[0] + output.shape[2])\n",
    "\n",
    "    fig, ax = plt.subplots(15, 1, figsize=(10, 15), sharex=True, gridspec_kw={'height_ratios': [3, 0.5, 1.5, 0.5] + [1] * 5 + [0.5] + [1] * 5})\n",
    "\n",
    "    # ax[0]: attention-score quantiles over time.\n",
    "    ax[0].plot(x, q50, color='red', label='Median')\n",
    "    ax[0].fill_between(x, q30, q70, color='red', alpha=0.4, label='30–70%')\n",
    "    ax[0].fill_between(x, q10, q90, color='red', alpha=0.2, label='10–90%')\n",
    "    ax[0].set_ylabel(\"Attention Score\")\n",
    "    ax[0].set_ylim(-0.0005,0.01)\n",
    "    ax[0].legend()\n",
    "    ax[0].set_title(\"Cross Attention Scores (Bolt)\" if is_bolt else \"Cross Attention Scores (Chronos)\")\n",
    "\n",
    "    # ax[1] (caption) and ax[2]: context and forecast.\n",
    "    ax[1].axis('off')\n",
    "    ax[1].text(220,0,\"Context and Forecast\", fontsize=12)\n",
    "    ax[2].plot(x[-512:], context[-512:], label='Context', color='green')\n",
    "    if is_bolt:\n",
    "        ax[2].plot(horizon, output[0,4,:], label='Forecast', color='orange')\n",
    "        ax[2].fill_between(horizon, output[0,2,:], output[0,6,:], label='30-70%', alpha=0.4, color='orange')\n",
    "        ax[2].fill_between(horizon, output[0,0,:], output[0,8,:], label='10-90%', alpha=0.2, color='orange')\n",
    "    else:\n",
    "        ax[2].plot(horizon, output[0,0,:], label='Forecast', color='orange')\n",
    "    ax[2].set_ylim(-1.5,1.5)\n",
    "    ax[2].legend()\n",
    "\n",
    "    # Vertical guides presumably mark the 128-step segment boundaries of the alternating\n",
    "    # context (offset by 64) -- TODO confirm.\n",
    "    segment_edges = (64, 64+128, 64+128*2, 64+128*3)\n",
    "\n",
    "    # ax[3] (caption) and ax[4:9]: top-5 singular directions of the embedded context.\n",
    "    ax[3].axis('off')\n",
    "    ax[3].text(230,0,\"Embedded Context\", fontsize=12)\n",
    "    _plot_top_projections(ax[4:9], embed, repeat_num, vlines=(0,) + segment_edges)\n",
    "\n",
    "    # ax[9] (caption) and ax[10:15]: the same for the encoder output.\n",
    "    ax[9].axis('off')\n",
    "    ax[9].text(230,0,\"Encoded Context\", fontsize=12)\n",
    "    _plot_top_projections(ax[10:15], hidden, repeat_num, vlines=segment_edges, color=\"purple\")\n",
    "    ax[-1].set_xlabel(\"Time Step\")\n",
    "\n",
    "    fig.tight_layout()\n",
    "    if savefig is not None:\n",
    "        fig.savefig(savefig + \".png\")\n",
    "        fig.savefig(savefig + \".eps\")\n",
    "    return fig, ax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b498b47-ceb4-4513-acec-f0709aed2dc5",
   "metadata": {},
   "source": [
    "## A simple test: a superposition of two sinusoids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a3f2602-d3b8-4d44-ad04-58b10f31b950",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# 512 context steps + 64 forecast steps; one 2*pi period spans 512 steps, offset by pi/10.\n",
    "n_context, horizon = 512, 64\n",
    "samplers = np.arange(n_context + horizon) / n_context * 2 * math.pi + math.pi / 10\n",
    "# Superposition of a slow (8 cycles / 512 steps) and a fast (110 cycles / 512 steps) sine.\n",
    "signal = np.sin(samplers * 8) + np.sin(samplers * 110)\n",
    "context, groundtruth = signal[:-horizon], signal[-horizon:]\n",
    "plt.plot(context)\n",
    "plt.plot(len(context) + np.arange(len(groundtruth)), groundtruth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6e2d3d1-f02b-4a37-a5a3-c072318e6c79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forecast the same context with the patch-size-16 (`bolt`) and patch-size-1 (`bolt_p1`) models.\n",
    "output_bolt_16 = bolt.predict(torch.from_numpy(context), prediction_length=64).numpy()\n",
    "output_bolt_1 = bolt_p1.predict(torch.from_numpy(context), prediction_length=64).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23d49734-de68-4221-9dcd-6d3308ea5873",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fig, ax = plot_bolt_output(output_bolt_1, context, groundtruth)\n",
    "ax.set(title=\"Bolt's Forecast (Patch Size = 1)\", xlabel=\"time step\", ylabel=\"value\")\n",
    "# Zoom on the last 64 context steps and the 64-step forecast.\n",
    "ax.set_xlim(512-64,512+64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35faadba-b589-420b-a55f-8640cd35deec",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fig, ax = plot_bolt_output(output_bolt_16, context, groundtruth)\n",
    "ax.set(title=\"Bolt's Forecast (Patch Size = 16)\", xlabel=\"time step\", ylabel=\"value\")\n",
    "# Zoom on the last 64 context steps and the 64-step forecast.\n",
    "ax.set_xlim(512-64,512+64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de6bc95d-9c5e-44b9-9cde-db9493e1f8a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Magnitude spectra (first 32 FFT bins) of the 64-step ground truth and of both median forecasts.\n",
    "# np.abs matters: plotting the raw complex coefficients would silently drop their imaginary parts.\n",
    "spectra = {\n",
    "    \"groundtruth\": groundtruth,\n",
    "    \"Bolt (patch size = 1)\": output_bolt_1[0, 4, :],\n",
    "    \"Bolt (patch size = 16)\": output_bolt_16[0, 4, :],\n",
    "}\n",
    "for label, series in spectra.items():\n",
    "    plt.plot(np.abs(np.fft.fft(series))[:32], label=label)\n",
    "# A sine with f cycles per 512 steps falls at bin f * 64 / 512 of a 64-step window;\n",
    "# 8 and 110 are the frequencies used to build `context` above.\n",
    "plt.axvline(x=8 * 64 / 512, color=\"gray\", alpha=0.8, linestyle=\"--\", label=\"low frequency\")\n",
    "plt.axvline(x=110 * 64 / 512, color=\"red\", alpha=0.8, linestyle=\"--\", label=\"high frequency\")\n",
    "plt.legend()\n",
    "plt.ylabel(\"|Fourier coefficient|\")\n",
    "plt.xlabel(\"frequency (FFT bin)\")\n",
    "plt.title(\"FFT of the Ground Truth and Forecasts\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49e4d962-d9c1-42b6-b2a0-7442da4cb3db",
   "metadata": {},
   "source": [
    "### A more comprehensive analysis: sweeping both frequencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c3f55e4-362d-4bc9-a551-250e73df9176",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_two_modes(gt_f):\n",
    "    sort_inds = np.argsort(np.abs(gt_f[:gt_f.shape[0] // 2]))\n",
    "    if np.abs(sort_inds[-1] - sort_inds[-2]) == 1:\n",
    "        sort_inds[-2] = sort_inds[-3]\n",
    "    if sort_inds[-1] < sort_inds[-2]:\n",
    "        return sort_inds[-1], sort_inds[-2]\n",
    "    else:\n",
    "        return sort_inds[-2], sort_inds[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67abe205-8ef2-4cb6-aa38-6e478d24d016",
   "metadata": {},
   "outputs": [],
   "source": [
    "low_freq_grid = np.logspace(1, 3, num = 20, base = 2.0)\n",
    "high_freq_grid = np.logspace(6, 7, num = 20, base = 2.0)\n",
    "err_mat_low_1 = np.zeros((20,20))\n",
    "err_mat_high_1 = np.zeros((20,20))\n",
    "err_mat_low_16 = np.zeros((20,20))\n",
    "err_mat_high_16 = np.zeros((20,20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4e64e80-cd7e-4ea2-9946-8ff9b11a9723",
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs_1, outputs_16 = [], []\n",
    "gts, ctxs = [], []\n",
    "\n",
    "def relative_mode_error(fft_ref, fft_pred, k):\n",
    "    \"\"\"Relative error of the FFT magnitude at bin k: abs(|ref[k]| - |pred[k]|) / |ref[k]|.\"\"\"\n",
    "    return np.abs(np.abs(fft_ref[k]) - np.abs(fft_pred[k])) / np.abs(fft_ref[k])\n",
    "\n",
    "# Sweep every (low, high) frequency pair: forecast the superposition with both models and\n",
    "# score each forecast's spectrum at the two ground-truth modes.\n",
    "for i, f_low in enumerate(low_freq_grid):\n",
    "    gts_sub, ctxs_sub = [], []\n",
    "    for j, f_high in enumerate(high_freq_grid):\n",
    "        signal = np.sin(samplers * f_low) + np.sin(samplers * f_high)\n",
    "        context, groundtruth = signal[:-64], signal[-64:]\n",
    "        gts_sub.append(groundtruth)\n",
    "        ctxs_sub.append(context)\n",
    "\n",
    "        output_bolt_16 = bolt.predict(torch.from_numpy(context), prediction_length=64).numpy()\n",
    "        output_bolt_1 = bolt_p1.predict(torch.from_numpy(context), prediction_length=64).numpy()\n",
    "        outputs_1.append(output_bolt_1)\n",
    "        outputs_16.append(output_bolt_16)\n",
    "\n",
    "        # Spectra of the ground truth and of the median forecasts (quantile row 4).\n",
    "        fft_gt = np.fft.fft(groundtruth)[:32]\n",
    "        fft_1 = np.fft.fft(output_bolt_1[0,4,:])[:32]\n",
    "        fft_16 = np.fft.fft(output_bolt_16[0,4,:])[:32]\n",
    "\n",
    "        i1, i2 = find_two_modes(fft_gt)  # lower-index (low) mode, higher-index (high) mode\n",
    "        err_mat_low_1[i, j] = relative_mode_error(fft_gt, fft_1, i1)\n",
    "        err_mat_high_1[i, j] = relative_mode_error(fft_gt, fft_1, i2)\n",
    "        err_mat_low_16[i, j] = relative_mode_error(fft_gt, fft_16, i1)\n",
    "        err_mat_high_16[i, j] = relative_mode_error(fft_gt, fft_16, i2)\n",
    "\n",
    "    gts.append(gts_sub)\n",
    "    ctxs.append(ctxs_sub)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f36dcd8-1927-4066-aada-52e0f092a2ff",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.imshow(err_mat_low_16, vmin=0, vmax=1)\n",
    "plt.colorbar(label=\"relative error\")\n",
    "plt.title(\"Relative error at the low-frequency mode (patch size = 16)\")\n",
    "plt.xlabel(\"high-frequency grid index j\")\n",
    "plt.ylabel(\"low-frequency grid index i\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b4077fb-3ff8-42aa-97f0-222b4efe1a69",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.imshow(err_mat_low_1, vmin=0, vmax=1)\n",
    "plt.colorbar(label=\"relative error\")\n",
    "plt.title(\"Relative error at the low-frequency mode (patch size = 1)\")\n",
    "plt.xlabel(\"high-frequency grid index j\")\n",
    "plt.ylabel(\"low-frequency grid index i\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "821958e3-cf85-444b-bf45-b5e245e94566",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(err_mat_high_16, vmin=0, vmax=1)\n",
    "plt.colorbar(label=\"relative error\")\n",
    "plt.title(\"Relative error at the high-frequency mode (patch size = 16)\")\n",
    "plt.xlabel(\"high-frequency grid index j\")\n",
    "plt.ylabel(\"low-frequency grid index i\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c52a205-0db5-43fb-a033-a30b286e6f15",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.imshow(err_mat_high_1, vmin=0, vmax=1)\n",
    "plt.colorbar(label=\"relative error\")\n",
    "plt.title(\"Relative error at the high-frequency mode (patch size = 1)\")\n",
    "plt.xlabel(\"high-frequency grid index j\")\n",
    "plt.ylabel(\"low-frequency grid index i\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d778af4-41f7-4519-a9f2-61374c2880d0",
   "metadata": {},
   "source": [
    "## A different context: alternating high- and low-frequency segments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7d5fba0-1702-4067-b2e8-cb78385c8f35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternate 128-step segments of a fast (20 cycles) and a slow (4 cycles) sine, three times each.\n",
    "# NOTE(review): linspace includes its endpoint, so each segment ends on (approximately) the\n",
    "# value its successor starts with -- a one-sample repeat at every boundary; confirm intended.\n",
    "segment_t = np.linspace(0, 2 * math.pi, 128)\n",
    "context1 = np.sin(segment_t * 20)\n",
    "context2 = np.sin(segment_t * 4)\n",
    "full_signal = np.concatenate([context1, context2] * 3)\n",
    "\n",
    "# Hold out the last 64 steps as ground truth and keep the 512 steps before them as context.\n",
    "groundtruth = full_signal[-64:]\n",
    "context = full_signal[:-64][-512:]\n",
    "plt.plot(context)\n",
    "plt.plot(context.shape[0] + np.arange(groundtruth.shape[0]), groundtruth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea844101-e43a-44b4-af20-3731a0c01770",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record the input of every encoder self-attention layer of `bolt` during one forecast.\n",
    "# The hooks are removed afterwards: left attached, they would fire again (and keep\n",
    "# accumulating tensors) on every later `bolt.predict` call in this notebook.\n",
    "attention_inputs = defaultdict(list)\n",
    "def get_hook(layer_id):\n",
    "    def hook(module, args, output):\n",
    "        # args[0]: first positional input of the attention module (presumably the\n",
    "        # normalized hidden states -- TODO confirm for the installed transformers version).\n",
    "        attention_inputs[layer_id].append(args[0].detach().cpu())\n",
    "    return hook\n",
    "\n",
    "hook_handles = [\n",
    "    block.layer[0].SelfAttention.register_forward_hook(get_hook(idx))\n",
    "    for idx, block in enumerate(bolt.model.encoder.block)\n",
    "]\n",
    "try:\n",
    "    output_bolt_16 = bolt.predict(torch.from_numpy(context), prediction_length=64).numpy()\n",
    "finally:\n",
    "    for handle in hook_handles:\n",
    "        handle.remove()\n",
    "\n",
    "# Input to the first encoder layer's attention, batch item 0.\n",
    "in_bolt_16 = attention_inputs[0][0].numpy()[0, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "796f3ff8-2cc9-4f53-a792-5685ddb109db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record the input of every encoder self-attention layer of `bolt_p1` during one forecast.\n",
    "# The hooks are removed afterwards: left attached, they would fire again (and keep\n",
    "# accumulating tensors) on every later `bolt_p1.predict` call in this notebook.\n",
    "attention_inputs = defaultdict(list)\n",
    "def get_hook(layer_id):\n",
    "    def hook(module, args, output):\n",
    "        # args[0]: first positional input of the attention module (presumably the\n",
    "        # normalized hidden states -- TODO confirm for the installed transformers version).\n",
    "        attention_inputs[layer_id].append(args[0].detach().cpu())\n",
    "    return hook\n",
    "\n",
    "hook_handles = [\n",
    "    block.layer[0].SelfAttention.register_forward_hook(get_hook(idx))\n",
    "    for idx, block in enumerate(bolt_p1.model.encoder.block)\n",
    "]\n",
    "try:\n",
    "    output_bolt_1 = bolt_p1.predict(torch.from_numpy(context), prediction_length=64).numpy()\n",
    "finally:\n",
    "    for handle in hook_handles:\n",
    "        handle.remove()\n",
    "\n",
    "# Input to the first encoder layer's attention, batch item 0.\n",
    "in_bolt_1 = attention_inputs[0][0].numpy()[0, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae2bd83-f96e-497c-8ff1-fbc144481352",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fig, ax = plot_bolt_output(output_bolt_1, context, groundtruth)\n",
    "ax.set(title=\"Bolt's Forecast (Patch Size = 1)\", xlabel=\"time step\", ylabel=\"value\")\n",
    "# Zoom on the last 64 context steps and the 64-step forecast.\n",
    "ax.set_xlim(512-64,512+64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bdc18d9-9bde-4865-a64d-7fc6c4f0318d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fig, ax = plot_bolt_output(output_bolt_16, context, groundtruth)\n",
    "ax.set(title=\"Bolt's Forecast (Patch Size = 16)\", xlabel=\"time step\", ylabel=\"value\")\n",
    "# Zoom on the last 64 context steps and the 64-step forecast.\n",
    "ax.set_xlim(512-64,512+64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc2929cd-e1ae-4696-b4af-74ee2800f3f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pairwise angles between the first-layer attention inputs of the patch-size-1 model.\n",
    "distance_1 = compute_angles(mat=in_bolt_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29234655-3cf9-44f9-9670-b41b533b231a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pairwise angles between the first-layer attention inputs of the patch-size-16 model.\n",
    "distance_16 = compute_angles(mat=in_bolt_16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1da7ca45-cb50-445e-9501-73a6f24f1178",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# [:-1, :-1] drops the last token (presumably one beyond the 512 context steps -- TODO confirm)\n",
    "# so the heatmap lines up with the context; unpacking keeps the (fig, ax) tuple out of the cell output.\n",
    "fig, ax = plot_angles(distance_1[:-1,:-1], context, \"Bolt, Patch Size = 1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7c5959b-e3b3-404d-b1e3-7bb204f00bf4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Upsample the patch-level angles to time steps (16 steps per patch) so they align with the\n",
    "# context; unpacking keeps the (fig, ax) tuple out of the cell output.\n",
    "fig, ax = plot_angles(np.repeat(np.repeat(distance_16[:-1,:-1], 16, axis=0), 16, axis=1), context, \"Bolt, Patch Size = 16\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa918e6e-6f8d-4329-a853-c0ecc9202e58",
   "metadata": {},
   "source": [
    "## Finer-grained analysis of the superposition case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cf8c041-527d-4347-ae43-77ac64ab94d2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Same superposition as before, plus two single-frequency contexts for comparison.\n",
    "# NOTE(review): the single-frequency contexts use 7 and 137 cycles, not the 8 and 110 of the\n",
    "# superposition -- confirm this is intended.\n",
    "n_context, horizon = 512, 64\n",
    "samplers = np.arange(n_context + horizon) / n_context * 2 * math.pi + math.pi / 10\n",
    "signal = np.sin(samplers * 8) + np.sin(samplers * 110)\n",
    "context_low = np.sin(samplers * 7)[:-horizon]\n",
    "context_high = np.sin(samplers * 137)[:-horizon]\n",
    "context, groundtruth = signal[:-horizon], signal[-horizon:]\n",
    "plt.plot(context)\n",
    "plt.plot(len(context) + np.arange(len(groundtruth)), groundtruth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1affa4f-747a-4a8e-af80-30a64c1680f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record the input of every encoder self-attention layer of `bolt` on the low-frequency context.\n",
    "# The hooks are removed afterwards: left attached, they would fire again (and keep\n",
    "# accumulating tensors) on every later `bolt.predict` call in this notebook.\n",
    "attention_inputs = defaultdict(list)\n",
    "def get_hook(layer_id):\n",
    "    def hook(module, args, output):\n",
    "        # args[0]: first positional input of the attention module (presumably the\n",
    "        # normalized hidden states -- TODO confirm for the installed transformers version).\n",
    "        attention_inputs[layer_id].append(args[0].detach().cpu())\n",
    "    return hook\n",
    "\n",
    "hook_handles = [\n",
    "    block.layer[0].SelfAttention.register_forward_hook(get_hook(idx))\n",
    "    for idx, block in enumerate(bolt.model.encoder.block)\n",
    "]\n",
    "try:\n",
    "    _ = bolt.predict(torch.from_numpy(context_low), prediction_length=64).numpy()\n",
    "finally:\n",
    "    for handle in hook_handles:\n",
    "        handle.remove()\n",
    "\n",
    "# Input to the first encoder layer's attention, batch item 0.\n",
    "in_bolt_low = attention_inputs[0][0].numpy()[0, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5b72ff4-1418-44bf-b4d1-9a4d447611ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record the input of every encoder self-attention layer of `bolt` during one forecast.\n",
    "# The hooks are removed afterwards: left attached, they would fire again (and keep\n",
    "# accumulating tensors) on every later `bolt.predict` call in this notebook.\n",
    "attention_inputs = defaultdict(list)\n",
    "def get_hook(layer_id):\n",
    "    def hook(module, args, output):\n",
    "        # args[0]: first positional input of the attention module (presumably the\n",
    "        # normalized hidden states -- TODO confirm for the installed transformers version).\n",
    "        attention_inputs[layer_id].append(args[0].detach().cpu())\n",
    "    return hook\n",
    "\n",
    "hook_handles = [\n",
    "    block.layer[0].SelfAttention.register_forward_hook(get_hook(idx))\n",
    "    for idx, block in enumerate(bolt.model.encoder.block)\n",
    "]\n",
    "try:\n",
    "    output_bolt_16 = bolt.predict(torch.from_numpy(context), prediction_length=64).numpy()\n",
    "finally:\n",
    "    for handle in hook_handles:\n",
    "        handle.remove()\n",
    "\n",
    "# Input to the first encoder layer's attention, batch item 0.\n",
    "in_bolt_16 = attention_inputs[0][0].numpy()[0, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79f310ab-d9e2-48cd-8722-60d9672875d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record the input of every encoder self-attention layer of `bolt_p1` during one forecast.\n",
    "# The hooks are removed afterwards: left attached, they would fire again (and keep\n",
    "# accumulating tensors) on every later `bolt_p1.predict` call in this notebook.\n",
    "attention_inputs = defaultdict(list)\n",
    "def get_hook(layer_id):\n",
    "    def hook(module, args, output):\n",
    "        # args[0]: first positional input of the attention module (presumably the\n",
    "        # normalized hidden states -- TODO confirm for the installed transformers version).\n",
    "        attention_inputs[layer_id].append(args[0].detach().cpu())\n",
    "    return hook\n",
    "\n",
    "hook_handles = [\n",
    "    block.layer[0].SelfAttention.register_forward_hook(get_hook(idx))\n",
    "    for idx, block in enumerate(bolt_p1.model.encoder.block)\n",
    "]\n",
    "try:\n",
    "    output_bolt_1 = bolt_p1.predict(torch.from_numpy(context), prediction_length=64).numpy()\n",
    "finally:\n",
    "    for handle in hook_handles:\n",
    "        handle.remove()\n",
    "\n",
    "# Input to the first encoder layer's attention, batch item 0.\n",
    "in_bolt_1 = attention_inputs[0][0].numpy()[0, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6723fe5-5036-4f51-aef8-84605149b402",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fig, ax = plot_bolt_output(output_bolt_1, context, groundtruth)\n",
    "ax.set(title=\"Bolt's Forecast (Patch Size = 1)\", xlabel=\"time step\", ylabel=\"value\")\n",
    "# Zoom on the last 64 context steps and the 64-step forecast.\n",
    "ax.set_xlim(512-64,512+64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3943a62b-95f8-4329-89e4-d02e0fb8eea4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fig, ax = plot_bolt_output(output_bolt_16, context, groundtruth)\n",
    "ax.set(title=\"Bolt's Forecast (Patch Size = 16)\", xlabel=\"time step\", ylabel=\"value\")\n",
    "# Zoom on the last 64 context steps and the 64-step forecast.\n",
    "ax.set_xlim(512-64,512+64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c937829b-0aa7-4ae6-a5b3-0502c0bc670a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pairwise angles between first-layer attention inputs for both patch sizes.\n",
    "distance_1, distance_16 = compute_angles(in_bolt_1), compute_angles(in_bolt_16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22081827-1a30-4e9d-b743-b5c7ccae7646",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# [:-1, :-1] drops the last token (presumably one beyond the 512 context steps -- TODO confirm)\n",
    "# so the heatmap lines up with the context; unpacking keeps the (fig, ax) tuple out of the cell output.\n",
    "fig, ax = plot_angles(distance_1[:-1,:-1], context, \"Bolt, Patch Size = 1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b602e571-ae9f-43ef-ba11-88340d03af84",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Upsample the patch-level angles to time steps (16 steps per patch) so they align with the\n",
    "# context; unpacking keeps the (fig, ax) tuple out of the cell output.\n",
    "fig, ax = plot_angles(np.repeat(np.repeat(distance_16[:-1,:-1], 16, axis=0), 16, axis=1), context, \"Bolt, Patch Size = 16\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cc22116-7d8f-4def-85a2-9fd50f781d8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The two components of the superposition over the first 16 steps (one patch of the patch-16 model).\n",
    "context1, context2 = np.sin(samplers * 8), np.sin(samplers * 110)\n",
    "plt.plot(context1[:16])\n",
    "plt.plot(context2[:16])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "688f4059-b119-49b3-a127-369acf0a7229",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weights of Bolt's input patch-embedding block as NumPy arrays. W1 and W3 keep only their\n",
    "# first 16 input columns (one patch of values). NOTE(review): if the layer's input holds more\n",
    "# than the 16 patch values (e.g. an observation mask), those columns are dropped here -- TODO confirm.\n",
    "patch_embedding = bolt.model.input_patch_embedding\n",
    "W1 = patch_embedding.hidden_layer.weight.detach().numpy()[:, :16]\n",
    "b1 = patch_embedding.hidden_layer.bias.detach().numpy()\n",
    "W2 = patch_embedding.output_layer.weight.detach().numpy()\n",
    "b2 = patch_embedding.output_layer.bias.detach().numpy()\n",
    "W3 = patch_embedding.residual_layer.weight.detach().numpy()[:, :16]\n",
    "b3 = patch_embedding.residual_layer.bias.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68d21531-0936-438f-8366-79d28f675551",
   "metadata": {},
   "outputs": [],
   "source": [
    "def np_relu(x):\n",
    "    return x * (x > 0)\n",
    "\n",
    "def manual_embed(patch):\n",
    "    return W3 @ patch + W2 @ np_relu(W1 @ patch + b1) + b2 + b3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea96a087-392c-459e-a56f-91f431cb9f4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `in_bolt_high` is not produced by any earlier cell, so capture it here: run `bolt` on\n",
    "# `context_high` with a temporary hook on the first encoder self-attention layer.\n",
    "high_inputs = []\n",
    "first_attention = bolt.model.encoder.block[0].layer[0].SelfAttention\n",
    "handle = first_attention.register_forward_hook(\n",
    "    lambda module, args, output: high_inputs.append(args[0].detach().cpu())\n",
    ")\n",
    "try:\n",
    "    bolt.predict(torch.from_numpy(context_high), prediction_length=64)\n",
    "finally:\n",
    "    handle.remove()\n",
    "in_bolt_high = high_inputs[0].numpy()[0, :, :]\n",
    "\n",
    "distance_low = compute_angles(in_bolt_low)\n",
    "distance_high = compute_angles(in_bolt_high)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4eaa3e3d-76ff-4503-81fa-160eea3ea21a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upsample the patch-level angles to time steps (16 steps per patch) for the single\n",
    "# low-frequency context; unpacking keeps the (fig, ax) tuple out of the cell output.\n",
    "fig, ax = plot_angles(np.repeat(np.repeat(distance_low[:-1,:-1], 16, axis=0), 16, axis=1), context_low, \"Bolt, Patch Size = 16, low-frequency context\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "795d6d6b-0e6f-4409-ab8d-ed231c588b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upsample the patch-level angles to time steps (16 steps per patch) for the single\n",
    "# high-frequency context; unpacking keeps the (fig, ax) tuple out of the cell output.\n",
    "fig, ax = plot_angles(np.repeat(np.repeat(distance_high[:-1,:-1], 16, axis=0), 16, axis=1), context_high, \"Bolt, Patch Size = 16, high-frequency context\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a43c3a9b-14e8-4123-bc05-02e938e4c19c",
   "metadata": {},
   "source": [
    "## One more experiment to corroborate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a331063-d1fe-49da-910e-0564c3c30531",
   "metadata": {},
   "source": [
    "### 1. Better-aligned signals (the high-frequency period of 4 steps divides the patch size of 16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6171a05-9b2f-4de0-95bc-8f41868029b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Superposition whose fast component (128 cycles / 512 steps, i.e. a 4-step period) fits\n",
    "# evenly into 16-step patches.\n",
    "n_context, horizon = 512, 64\n",
    "samplers = np.arange(n_context + horizon) / n_context * 2 * math.pi + math.pi / 10\n",
    "signal = np.sin(samplers * 8) + np.sin(samplers * 128)\n",
    "context, groundtruth = signal[:-horizon], signal[-horizon:]\n",
    "plt.plot(context)\n",
    "plt.plot(len(context) + np.arange(len(groundtruth)), groundtruth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77d55713-f12c-4b02-9766-33feb518ed08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Patch-16 forecast of the aligned superposition, zoomed on the forecast window.\n",
    "output_bolt_16 = bolt.predict(torch.from_numpy(context), prediction_length=64).numpy()\n",
    "fig, ax = plot_bolt_output(output_bolt_16, context, groundtruth)\n",
    "ax.set(title=\"Bolt's Forecast (Patch Size = 16)\", xlabel=\"time step\", ylabel=\"value\")\n",
    "ax.set_xlim(512-64,512+64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28aa640f-85e5-46fd-a9ec-10fdb767ddb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zoom on the first 200 context steps of the aligned superposition.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(context[:200])\n",
    "ax.set(title=\"First 200 context steps\", xlabel=\"time step\", ylabel=\"value\");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
