{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import itertools\n",
    "import numpy as np\n",
    "# Explicit names instead of `from sympy import *`, which would shadow names such as N, I, E and S.\n",
    "# Only the commented-out symbolic derivation at the end of this cell uses them.\n",
    "from sympy import expand, symbols\n",
    "import torch\n",
    "\n",
    "k = 2         # largest combination order: all index subsets of size 2..k are used\n",
    "seq_len = 20  # number of positions in the sequence\n",
    "\n",
    "# All index combinations of size 2..k over positions 0..seq_len-1, grouped by size\n",
    "# and in itertools.combinations (lexicographic) order within each size.\n",
    "ind = range(seq_len)\n",
    "comb = []\n",
    "for comb_length in range(2, k + 1):\n",
    "    comb.extend(itertools.combinations(ind, comb_length))\n",
    "\n",
    "# Fix the torch seed so u_f is reproducible.\n",
    "torch.manual_seed(0)\n",
    "# Random input tensor of shape (3, seq_len).\n",
    "u_f = torch.randn(3, seq_len)\n",
    "\n",
    "# u_f_corr[..., i] is the product of u_f over *every* position in comb[i], giving shape\n",
    "# (u_f.shape[0], len(comb)). Multiplying all indices (not only the first two) keeps the\n",
    "# result correct when k > 2 and comb contains triples or larger subsets.\n",
    "u_f_corr = torch.zeros(u_f.shape[0], len(comb))\n",
    "for i, idx in enumerate(comb):\n",
    "    u_f_corr[..., i] = u_f[..., list(idx)].prod(dim=-1)\n",
    "\n",
    "# Reverse u_f_corr along its last (combination) dimension.\n",
    "u_f_corr_flip = torch.flip(u_f_corr, [-1])\n",
    "\n",
    "# Scratch: symbolic expansion of the recurrence x_t = B*u_t + (A + B*u_t)*x_{t-1}, x_0 = B*u_0.\n",
    "# A, B, u0, u1, u2, u3, u4, u5, u6, u7, u8 = symbols('A B u0 u1 u2 u3 u4 u5 u6 u7 u8')\n",
    "\n",
    "# expr0 = B * u0\n",
    "\n",
    "# expr1 = B * u1  + (A + B * u1) * expr0\n",
    "# xx = expand(expr1)\n",
    "\n",
    "# expr2 = B * u2  + (A + B * u2) * xx\n",
    "# expand(expr2)\n",
    "\n",
    "# # expr3 = B * u3  + (A + B * u3) * xx\n",
    "# # xx = expand(expr3)\n",
    "\n",
    "# # expr4 = B * u4  + (A + B * u4) * xx\n",
    "# # xx = expand(expr4)\n",
    "\n",
    "# # expr5 = B * u5  + (A + B * u5) * xx\n",
    "# # xx = expand(expr5)\n",
    "\n",
    "# # expr6 = B * u6  + (A + B * u6) * xx\n",
    "# # xx = expand(expr6)\n",
    "\n",
    "# # expr7 = B * u7  + (A + B * u7) * xx\n",
    "# # xx = expand(expr7)\n",
    "\n",
    "# # expr8 = B * u8  + (A + B * u8) * xx\n",
    "# # expand(expr8)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adding path:  /home/ramin/state-spaces/src/models/sequence/ss\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ramin/miniconda3/envs/state-space/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/tmp/ipykernel_701559/1005520754.py:30: DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead\n",
      "  log.warn(\n",
      "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import pathlib\n",
    "\n",
    "# Make the notebook's working directory and the state-spaces repo importable.\n",
    "# Set the STATE_SPACES_ROOT environment variable to point at a different checkout\n",
    "# instead of editing the hard-coded default path.\n",
    "p = pathlib.Path().absolute()\n",
    "print(\"Adding path: \", p)\n",
    "sys.path.append(str(p))\n",
    "STATE_SPACES_ROOT = os.environ.get(\"STATE_SPACES_ROOT\", \"/home/ramin/state-spaces/\")\n",
    "sys.path.append(STATE_SPACES_ROOT)\n",
    "\n",
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import scipy.fft\n",
    "from einops import rearrange, repeat\n",
    "from opt_einsum import contract, contract_expression\n",
    "from omegaconf import DictConfig\n",
    "\n",
    "import src.models.hippo.hippo as hippo\n",
    "from src.models.functional.krylov import krylov, power\n",
    "\n",
    "import src.utils.train\n",
    "\n",
    "log = src.utils.train.get_logger()\n",
    "\n",
    "# Optional CUDA kernel for Cauchy multiplication. `except Exception` (rather than a bare\n",
    "# `except:`) still tolerates any failure to import or load the extension, but no longer\n",
    "# swallows KeyboardInterrupt or SystemExit.\n",
    "try:\n",
    "    from extensions.cauchy.cauchy import cauchy_mult\n",
    "    has_cauchy_extension = True\n",
    "except Exception:\n",
    "    # Logger.warn is deprecated (it emitted a DeprecationWarning); use warning().\n",
    "    log.warning(\n",
    "        \"CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%\"\n",
    "    )\n",
    "    has_cauchy_extension = False\n",
    "\n",
    "# Prefer the pykeops-based Cauchy kernel; otherwise import the slow fallback (cauchy_slow).\n",
    "try:\n",
    "    import pykeops\n",
    "    from src.models.functional.cauchy import cauchy_conj\n",
    "    has_pykeops = True\n",
    "except ImportError:\n",
    "    has_pykeops = False\n",
    "    from src.models.functional.cauchy import cauchy_slow\n",
    "    if not has_cauchy_extension:\n",
    "        log.error(\n",
    "            \"Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency.\"\n",
    "        )\n",
    "\n",
    "def _isnan(x):\n",
    "    \"\"\"Return a 0-dim bool tensor that is True if any element of `x` is NaN.\"\"\"\n",
    "    return torch.isnan(x).any()\n",
    "\n",
    "def _isinf(x):\n",
    "    \"\"\"Return a 0-dim bool tensor that is True if any element of `x` is +/-inf.\"\"\"\n",
    "    return torch.isinf(x).any()\n",
    "\n",
    "def _conj(x):\n",
    "    \"\"\"Concatenate `x` with its complex conjugate along the last dimension.\"\"\"\n",
    "    return torch.cat([x, x.conj()], dim=-1)\n",
    "\n",
    "# Aliases: complex tensor <-> real view with a trailing dimension of size 2.\n",
    "_c2r = torch.view_as_real\n",
    "_r2c = torch.view_as_complex\n",
    "\n",
    "# _resolve_conj(x) returns the complex conjugate of x. From torch 1.10 on, conj() may\n",
    "# return a lazy view, so resolve_conj() is called to materialize it.\n",
    "if tuple(int(part) for part in torch.__version__.split('.')[:2]) >= (1, 10):\n",
    "    def _resolve_conj(x):\n",
    "        return x.conj().resolve_conj()\n",
    "else:\n",
    "    def _resolve_conj(x):\n",
    "        return x.conj()\n",
    "\n",
    "def bilinear(dt, A, B=None):\n",
    "    \"\"\"Discretize a linear system with the bilinear (Tustin) transform.\n",
    "\n",
    "    dA = (I - dt/2 A)^{-1} (I + dt/2 A)\n",
    "    dB = (I - dt/2 A)^{-1} dt B\n",
    "\n",
    "    dt: (...) timescales; any batch shape, including a 0-dim tensor\n",
    "    A: (... N N) state matrix, broadcast against dt\n",
    "    B: (... N) input vector, or None to skip computing dB\n",
    "\n",
    "    Returns (dA, dB): dA is (... N N); dB is (... N), or None when B is None.\n",
    "    \"\"\"\n",
    "    N = A.shape[-1]\n",
    "    I = torch.eye(N).to(A)\n",
    "    # `...` instead of `:` so dt may have any batch shape (the docstring's `(...)`),\n",
    "    # matching the dt[..., None] used for dB below. For 1-D dt this is identical.\n",
    "    half_dtA = dt[..., None, None] / 2 * A\n",
    "    A_backwards = I - half_dtA\n",
    "    A_forwards = I + half_dtA\n",
    "\n",
    "    if B is None:\n",
    "        dB = None\n",
    "    else:\n",
    "        # dt * (I - dt/2 A)^{-1} B, computed with linalg.solve rather than an explicit inverse\n",
    "        dB = dt[..., None] * torch.linalg.solve(\n",
    "            A_backwards, B.unsqueeze(-1)\n",
    "        ).squeeze(-1)  # (... N)\n",
    "\n",
    "    dA = torch.linalg.solve(A_backwards, A_forwards)  # (... N N)\n",
    "    return dA, dB\n",
    "\n",
    "# Fix the seed so the random system below is reproducible.\n",
    "torch.manual_seed(0)\n",
    "NUM_DT = 10     # number of timescales to discretize with\n",
    "STATE_DIM = 10  # size N of the state matrix A\n",
    "# NOTE(review): torch.randn yields timescales of both signs; use .abs() if only dt > 0 is intended.\n",
    "dt = 0.0001 * torch.randn(NUM_DT)\n",
    "# Random STATE_DIM x STATE_DIM state matrix and a length-STATE_DIM input vector.\n",
    "A = torch.randn(STATE_DIM, STATE_DIM)\n",
    "B = torch.randn(STATE_DIM)\n",
    "\n",
    "dA, dB = bilinear(dt, A, B)\n",
    "# One discretization per timescale.\n",
    "assert dA.shape == (NUM_DT, STATE_DIM, STATE_DIM)\n",
    "assert dB.shape == (NUM_DT, STATE_DIM)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "4f1e135a9fa02f708c5c5958d74a0ccbece3238f1d8aa5b69d3d654254c2197d"
  },
  "kernelspec": {
   "display_name": "Python 3.8.8 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
