{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2203f021",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69678a79",
   "metadata": {},
   "source": [
    "# Scaling Up Forecast Horizon"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e87a4b74",
   "metadata": {},
   "source": [
    "Most deep forecasting models work as a **fixed function** once trained. When the required prediction horizon differs from the one used in training, or is much longer, the most common practice is rolling forecast."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "954c4aac",
   "metadata": {},
   "source": [
    "But an essential characteristic of non-stationary data is that its distribution changes significantly over time. The forecasting error can therefore be large at incoming time points unless the model is continuously retrained. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afb51505",
   "metadata": {},
   "source": [
    "This poses two challenges for the current model application pipeline: \n",
    "- (1) reuse parameters learned from observed series; \n",
    "- (2) utilize incoming ground truth for model adaptation. \n",
    " \n",
    "This practical scenario, which we call scaling up the forecast horizon, may cause most deep models to fail but can be naturally tackled by Koopa."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b67a52d",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "<img src=\"./figures/adapt_demo.png\" height = \"150\" alt=\"\" align=center />\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bef99148",
   "metadata": {},
   "source": [
    "## Operator Adaptation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9aa48a16",
   "metadata": {},
   "source": [
    "In detail, we train a Koopa model with an initial forecast length and then apply it to a longer one. \n",
    "\n",
    "The basic approach conducts rolling forecast: the model prediction is fed back as the input of the next iteration until the desired forecast horizon is filled.\n",
    "\n",
    "Instead, we assume (as is common in practice) that after the model gives a prediction, it can use the incoming ground truth for **operator adaptation** and then continue the rolling forecast in the next iteration. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ceb20791-33ce-496e-b5fb-2851b28ecdc3",
   "metadata": {},
   "outputs": [],
   "source": [
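    "def operator_adaptation(Z, incoming_list):\n",
    "    # Z: (B, F, D) observed embeddings; incoming_list: (B, G, D) incoming embeddings\n",
    "    G = incoming_list.shape[1]\n",
    "    # Snapshot pairs: Zp holds the inputs, Zf the one-step-ahead targets\n",
    "    Zp, Zf = Z[:, :-1], Z[:, 1:]\n",
    "    # Least-squares fit of Zp @ K = Zf; the same as torch.linalg.pinv(Zp) @ Zf\n",
    "    K = torch.linalg.lstsq(Zp, Zf).solution\n",
    "    # n is the latest snapshot; the first prediction uses the initial operator\n",
    "    n, pred_list = Z[:, -1:], [Z[:, -1:] @ K]\n",
    "    # K_list = [K]\n",
    "\n",
    "    for i in range(G):\n",
    "        # The incoming ground truth n forms a new pair (m, n) with the previous snapshot m\n",
    "        m, n = n, incoming_list[:, i].unsqueeze(1)\n",
    "        Zp = torch.cat((Zp, m), dim=1)\n",
    "        Zf = torch.cat((Zf, n), dim=1)\n",
    "        # Refit the operator from scratch on all pairs seen so far\n",
    "        K = torch.linalg.lstsq(Zp, Zf).solution\n",
    "        pred_list.append(n @ K)\n",
    "        # K_list.append(K)\n",
    "\n",
    "    return torch.cat(pred_list, dim=1)  # , torch.stack(K_list, dim=0)"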
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a4a4611",
   "metadata": {},
   "source": [
    "The naive implementation shown above repeatedly performs DMD (dynamic mode decomposition) on the growing collection of embeddings to obtain new operators."
   ]
  },
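  {
   "cell_type": "markdown",
   "id": "b7e2d4a1",
   "metadata": {},
   "source": [
    "As a reading aid, the operator fitted in each iteration can be written out explicitly. This is a sketch in notation introduced here for illustration (row-vector snapshots $z_t$, stacked matrices $Z_p$ and $Z_f$ mirroring `Zp` and `Zf` in the code), assuming $Z_p$ has full row rank:\n",
    "\n",
    "$$K = Z_p^{\\dagger} Z_f \\in \\arg\\min_{K} \\lVert Z_p K - Z_f \\rVert_F^2, \\qquad \\hat{z}_{t+1} = z_t K$$\n",
    "\n",
    "When there are fewer snapshot pairs than embedding dimensions, as in the demo below ($F - 1 = 9$ rows versus $D = 1024$), the system is underdetermined, and the pseudo-inverse picks the minimum-norm solution among all exact fits."
   ]
  },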
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "44c65f65",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 11, 1024])\n"
     ]
    }
   ],
   "source": [
    "B = 64       # Batch size\n",
    "F = 10       # Number of observed snapshots\n",
    "D = 1024     # Dimension of Koopman embedding\n",
    "G = 10       # Number of successively incoming snapshots\n",
    "\n",
    "Z = torch.randn(B, F, D)                # Observed time points\n",
    "incoming_list = torch.randn(B, G, D)    # Successively incoming time points\n",
    "\n",
    "pred = operator_adaptation(Z, incoming_list) # Predicted time points (each made before its ground truth arrives)\n",
    "print(pred.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3c99262",
   "metadata": {},
   "source": [
    "Notably, we do not retrain any parameters during model adaptation, since retraining would lead to overfitting on the incoming ground truth and to **catastrophic forgetting**."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c76ea014",
   "metadata": {},
   "source": [
    "This algorithm (Koopa OA) lets the model adapt to changes in the data distribution without training and further improves forecasting performance, especially on **non-stationary time series**."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a1bc322",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "<img src=\"./figures/adaptation.png\" alt=\"\" align=center />\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "221b8d61",
   "metadata": {},
   "source": [
    "## Computational Acceleration"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07d35fa5",
   "metadata": {},
   "source": [
    "The naive implementation has a complexity of $\\mathcal{O}(LD^3)$, where $L$ is the number of incoming time points (`G` in the code) and $D$ is the dimension of the Koopman embedding. Based on the linearity of Koopman operators, we propose an **equivalent algorithm with improved complexity** of $\\mathcal{O}((L+D)D^2)$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3e470c37-39a0-4d48-a3d4-54707755539f",
   "metadata": {},
   "outputs": [],
   "source": [
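    "def operator_adaptation_accelerate(Z, incoming_list):\n",
    "    # Same inputs and outputs as operator_adaptation, with incremental operator updates\n",
    "    G = incoming_list.shape[1]\n",
    "    Zp, Zf = Z[:, :-1], Z[:, 1:]\n",
    "    # The pseudo-inverse is computed only once, on the observed snapshots\n",
    "    Zp_inv = torch.linalg.pinv(Zp)\n",
    "    K = Zp_inv @ Zf\n",
    "    # X projects onto the row space of Zp\n",
    "    X = Zp_inv @ Zp\n",
    "    n, pred_list = Z[:, -1:], [Z[:, -1:] @ K]\n",
    "    # K_list = [K]\n",
    "\n",
    "    for i in range(G):\n",
    "        m, n = n, incoming_list[:, i].unsqueeze(1)\n",
    "        mt = m.transpose(1, 2)\n",
    "        # r: the part of m outside the row space of Zp (assumed nonzero)\n",
    "        r = mt - X.transpose(1, 2) @ mt\n",
    "        b = r / r.square().sum(dim=1, keepdim=True)\n",
    "        # Rank-one updates of the operator and the projector for the new pair (m, n)\n",
    "        K = K - b @ (m @ K - n)\n",
    "        X = X - b @ (m @ X - m)\n",
    "        pred_list.append(n @ K)\n",
    "        # K_list.append(K)\n",
    "\n",
    "    return torch.cat(pred_list, dim=1)  # , torch.stack(K_list, dim=0)"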
   ]
  },
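  {
   "cell_type": "markdown",
   "id": "c3f19e58",
   "metadata": {},
   "source": [
    "The updates inside `operator_adaptation_accelerate` can be read as a rank-one update of the pseudo-inverse when one row is appended (a Greville-type formula). The derivation below is a sketch in notation introduced here for illustration. It assumes the new input snapshot $m$ is not already in the row space of $Z_p$ (so $r \\neq 0$), which holds generically while the number of snapshot pairs stays below $D$:\n",
    "\n",
    "$$X = Z_p^{\\dagger} Z_p, \\qquad r = (I - X)\\, m^{\\top}, \\qquad b = \\frac{r}{\\lVert r \\rVert_2^2}$$\n",
    "\n",
    "$$\\begin{bmatrix} Z_p \\\\ m \\end{bmatrix}^{\\dagger} = \\begin{bmatrix} Z_p^{\\dagger} - b\\, m Z_p^{\\dagger} & b \\end{bmatrix}$$\n",
    "\n",
    "Multiplying this on the right by the extended targets $\\begin{bmatrix} Z_f \\\\ n \\end{bmatrix}$, and by the extended inputs $\\begin{bmatrix} Z_p \\\\ m \\end{bmatrix}$, gives the two updates used in the loop:\n",
    "\n",
    "$$K' = K - b\\,(mK - n), \\qquad X' = X - b\\,(mX - m)$$\n",
    "\n",
    "Each step then needs only vector-matrix products and rank-one outer products, i.e. $\\mathcal{O}(D^2)$ work per incoming snapshot instead of a fresh least-squares solve."
   ]
  },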
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "21312398-bbbf-4f49-b51d-d7324f387084",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Timer(object):\n",
    "    def __enter__(self):\n",
    "        self.t0 = time.time()\n",
    "\n",
    "    def __exit__(self, exc_type, exc_val, exc_tb):\n",
    "        print('[time spent: {time:.2f}s]'.format(time = time.time() - self.t0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "97c3484b-2a53-4432-b83c-676fe44ca9ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "B = 64       # Batch size\n",
    "F = 10       # Number of observed snapshots\n",
    "D = 1024     # Dimension of Koopman embedding\n",
    "G = 10       # Number of successively incoming snapshots\n",
    "\n",
    "Z = torch.randn(B, F, D)                # Observed time points\n",
    "incoming_list = torch.randn(B, G, D)    # Successively incoming time points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "6bece2d8-9ca5-4369-9dff-ed268dfc76c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[time spent: 5.59s]\n"
     ]
    }
   ],
   "source": [
    "# Algorithm 1\n",
    "with Timer():\n",
    "    pred1 = operator_adaptation(Z, incoming_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f2683a1e-6cb3-4dc9-a337-aed85abd1b42",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[time spent: 2.68s]\n"
     ]
    }
   ],
   "source": [
    "# Algorithm 2\n",
    "with Timer():\n",
    "    pred2 = operator_adaptation_accelerate(Z, incoming_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "44c315e0-2eee-49d0-bfcc-705415401ebd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# equivalent results\n",
    "assert torch.norm(pred1-pred2) < 1e-3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddfa705d",
   "metadata": {},
   "source": [
    "Notably, we propose the accelerated version of operator adaptation for **scenarios involving horizon-length mismatch and long-term dynamics forecasting**. In the realm of Koopman methods, this long-term applicability is exceptional."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
