{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 2/5000 [00:00<04:32, 18.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====================\n",
      "75450\n",
      "Combined loss:0: tensor(26.1243, grad_fn=<AddBackward0>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 107/5000 [00:03<02:34, 31.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Combined loss:100: tensor(17.0094, grad_fn=<AddBackward0>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 207/5000 [00:06<02:16, 35.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Combined loss:200: tensor(7.8833, grad_fn=<AddBackward0>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 219/5000 [00:06<02:29, 32.01it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-750150580587>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m    161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1500\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-1-750150580587>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(hidden_size, num_layers)\u001b[0m\n\u001b[1;32m    131\u001b[0m             \u001b[0mloss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    132\u001b[0m             \u001b[0mstart\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtimeit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdefault_timer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 133\u001b[0;31m             \u001b[0meval_nl\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlinear_val_ode2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtt_tors\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mt_d\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    134\u001b[0m             \"\"\"\n\u001b[1;32m    135\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxy_d_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-1-750150580587>\u001b[0m in \u001b[0;36mlinear_val_ode2\u001b[0;34m(init_v, t_d)\u001b[0m\n\u001b[1;32m     85\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mlinear_val_ode2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_v\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mt_d\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     86\u001b[0m             \u001b[0minit_v_in\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrev_inn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_v\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m             \u001b[0meval_lin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorchdiffeq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0modeint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0minit_v_in\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mt_d\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'euler'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;31m#options={'step_size':0.01}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     88\u001b[0m             \u001b[0meval_out\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfor_inn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_lin\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     89\u001b[0m             \u001b[0;32mreturn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torchdiffeq/_impl/odeint.py\u001b[0m in \u001b[0;36modeint\u001b[0;34m(func, y0, t, rtol, atol, method, options)\u001b[0m\n\u001b[1;32m     63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     64\u001b[0m     \u001b[0msolver\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSOLVERS\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my0\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrtol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrtol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0matol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0matol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m     \u001b[0msolution\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mintegrate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     67\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mshapes\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torchdiffeq/_impl/solvers.py\u001b[0m in \u001b[0;36mintegrate\u001b[0;34m(self, t)\u001b[0m\n\u001b[1;32m     85\u001b[0m         \u001b[0my0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     86\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mt0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt1\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtime_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m             \u001b[0mdy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_step_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mt0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     88\u001b[0m             \u001b[0my1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my0\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/torchdiffeq/_impl/fixed_grid.py\u001b[0m in \u001b[0;36m_step_func\u001b[0;34m(self, func, t, dt, y)\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_step_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mdt\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as pl\n",
    "import torchdiffeq\n",
    "import torch.nn as n\n",
    "import FrEIA.framework as Ff\n",
    "import FrEIA.modules as Fm\n",
    "\n",
    "# Configuration: compute device, global seeds, and the observation time grid.\n",
    "device = 'cpu'\n",
    "\n",
    "# Seed every RNG in use so the synthetic observation noise is reproducible.\n",
    "torch.manual_seed(0)\n",
    "import random\n",
    "random.seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "# Observation times: 80 evenly spaced points on [0, 2].\n",
    "t_d=torch.linspace(0,2,80)\n",
    "\n",
    "def test_fun(t,x):\n",
    "    \"\"\"Right-hand side of the Lorenz system with sigma=10, rho=28, beta=8/3.\n",
    "\n",
    "    `t` is unused (the system is autonomous). `x` is indexed as x[0], x[1], x[2];\n",
    "    the derivative is written into a new (3, 1) tensor of torch's default dtype.\n",
    "    \"\"\"\n",
    "    sigma, rho, beta = 10., 28., 8/3\n",
    "    px, py, pz = x[0], x[1], x[2]\n",
    "    deriv = torch.zeros((3,1))\n",
    "    deriv[0] = sigma*(py-px)\n",
    "    deriv[1] = px*(rho-pz)-py\n",
    "    deriv[2] = px*py-beta*pz\n",
    "    return deriv\n",
    "\n",
    "# Synthetic observations: integrate test_fun from the first three coordinates of\n",
    "# each initial state and add Gaussian noise (mean 0, std 0.01) to every sample.\n",
    "tt_tors=torch.tensor([[.15,.15,.15,0.,0.,0.]])\n",
    "xy_d_list=[]\n",
    "for i, init_state in enumerate(tt_tors):\n",
    "    noise_c=torch.normal(torch.zeros((len(t_d),3)),0.01*torch.ones((len(t_d),3)))\n",
    "    xy_d=torchdiffeq.odeint(test_fun,init_state[:3][None].T,t_d).reshape((-1,3))+noise_c\n",
    "    xy_d_list.append(xy_d.clone().detach())\n",
    "\n",
    "\n",
    "    \n",
    "def loss_er(x_pred,x_gt):\n",
    "    \"\"\"Mean over rows of the Euclidean norm of x_pred - x_gt (norm taken over dim=1).\n",
    "\n",
    "    Returns a 0-dim tensor. Raises ValueError when there are no rows to average.\n",
    "    \"\"\"\n",
    "    row_dist=torch.norm(x_pred-x_gt,dim=1)\n",
    "    if row_dist.numel()==0:\n",
    "        raise ValueError('loss_er: empty input')\n",
    "    # Vectorised mean instead of a Python accumulation loop.\n",
    "    return row_dist.mean()\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "print('='*20)\n",
    "# Re-seed every RNG right before building the networks so their initial weights are reproducible.\n",
    "seed = 2\n",
    "torch.manual_seed(seed)\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "# Vector field of the learned ODE (wrapped by fx): 6 -> 30 -> 30 -> 30 -> 6 MLP with tanh.\n",
    "f_x=n.Sequential(\n",
    "    n.Linear(6, 30),\n",
    "    n.Tanh(),\n",
    "    n.Linear(30, 30),\n",
    "    n.Tanh(),\n",
    "    n.Linear(30, 30),\n",
    "    n.Tanh(),\n",
    "    n.Linear(30, 6),\n",
    ")\n",
    "def fx(t,x):\n",
    "    \"\"\"ODE right-hand side for torchdiffeq: ignores t and applies the MLP f_x to x.\"\"\"\n",
    "    velocity=f_x(x)\n",
    "    return velocity\n",
    "\n",
    "def for_inn(x):\n",
    "    \"\"\"Run `inn` in the forward direction and return the first element of its output.\n",
    "\n",
    "    NOTE(review): presumably the remaining element is FrEIA's log-Jacobian determinant,\n",
    "    which is discarded here -- confirm against the FrEIA version in use.\n",
    "    \"\"\"\n",
    "    inn_output=inn(x)\n",
    "    return inn_output[0]\n",
    "def rev_inn(x):\n",
    "    \"\"\"Run `inn` in the inverse direction (rev=True) and return the first element of its output.\"\"\"\n",
    "    inn_output=inn(x,rev=True)\n",
    "    return inn_output[0]\n",
    "def rev_mse_inn_eig(rf,x_gt):\n",
    "    \"\"\"Mean over rows of the Euclidean (L2) distance between rf and x_gt -- not a squared error.\"\"\"\n",
    "    residual=rf-x_gt\n",
    "    return residual.norm(dim=1).mean()\n",
    "def linear_val_ode(w_vec,init_v,t_d):\n",
    "    \"\"\"Map init_v with rev_inn, evolve it with eigen_ode__, and map the result back with for_inn.\n",
    "\n",
    "    The evolved tensor is flattened to 2-D for for_inn and then restored to its shape.\n",
    "    NOTE(review): eigen_ode__ is not defined anywhere in this notebook, so calling this\n",
    "    function raises NameError; it is only referenced from disabled code.\n",
    "    \"\"\"\n",
    "    z0=rev_inn(init_v)\n",
    "    z_path=eigen_ode__(w_vec,z0,t_d)\n",
    "    full_shape=z_path.shape\n",
    "    flat_out=for_inn(z_path.reshape(-1,full_shape[-1]))\n",
    "    return flat_out.reshape(full_shape)\n",
    "def linear_val_ode2(init_v,t_d):\n",
    "    \"\"\"Predict a trajectory by integrating fx on the rev_inn image of init_v.\n",
    "\n",
    "    init_v is mapped with rev_inn, integrated with torchdiffeq's 'euler' method on the\n",
    "    time grid t_d, and only batch entry 0 of the solution is mapped back with for_inn.\n",
    "    (A finer step can be requested via options={'step_size':0.01}.)\n",
    "    \"\"\"\n",
    "    z0=rev_inn(init_v)\n",
    "    # The solution is indexed [time, batch, feature]; keep the first batch entry.\n",
    "    z_path=torchdiffeq.odeint(fx,z0,t_d,method='euler')\n",
    "    return for_inn(z_path[:,0,:])\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Invertible network: num_layers FrEIA AllInOneBlocks on N_DIM features, each\n",
    "# using a one-hidden-layer ReLU subnet of width hidden_size.\n",
    "hidden_size=1500\n",
    "num_layers=5\n",
    "N_DIM = 6\n",
    "\n",
    "def subnet_fc(dims_in, dims_out):\n",
    "    \"\"\"Coupling subnet: Linear(dims_in -> hidden_size), ReLU, Linear(hidden_size -> dims_out).\"\"\"\n",
    "    layers = [\n",
    "        n.Linear(dims_in, hidden_size),\n",
    "        n.ReLU(),\n",
    "        n.Linear(hidden_size, dims_out),\n",
    "    ]\n",
    "    return n.Sequential(*layers)\n",
    "\n",
    "inn = Ff.SequenceINN(N_DIM)\n",
    "for k in range(num_layers):\n",
    "    inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc, permute_soft=True)\n",
    "\n",
    "\n",
    "# Jointly optimise the vector field f_x and the invertible network inn.\n",
    "LEARNING_RATE = 1e-4\n",
    "N_ITERATIONS = 5000\n",
    "LOG_EVERY = 100\n",
    "optimizer_comb = torch.optim.Adam([\n",
    "    {'params': f_x.parameters(), 'lr': LEARNING_RATE},\n",
    "    {'params': inn.parameters(), 'lr': LEARNING_RATE},\n",
    "])\n",
    "print(sum(p.numel() for p in inn.parameters()))\n",
    "\n",
    "import timeit\n",
    "import tqdm\n",
    "from tqdm import trange\n",
    "\n",
    "# Put the data and BOTH networks on the target device (f_x as well as inn).\n",
    "tt_tors = tt_tors.to(device)\n",
    "t_d = t_d.to(device)\n",
    "f_x.to(device)\n",
    "inn.to(device)\n",
    "xy_d_list = torch.stack(xy_d_list).to(device)\n",
    "# linear_val_ode2 returns only batch entry 0, so the target is the first observed trajectory.\n",
    "xy_target = xy_d_list[0]\n",
    "\n",
    "epoch_time=[]\n",
    "try:\n",
    "    for i in trange(0, N_ITERATIONS):\n",
    "        optimizer_comb.zero_grad()\n",
    "        start = timeit.default_timer()\n",
    "        eval_nl=linear_val_ode2(tt_tors,t_d)\n",
    "        # Mean Euclidean error over time on the three observed coordinates.\n",
    "        loss = rev_mse_inn_eig(eval_nl[:,:3],xy_target)\n",
    "        loss.backward()\n",
    "        optimizer_comb.step()\n",
    "        epoch_time.append(timeit.default_timer()-start)\n",
    "        if i % LOG_EVERY == 0:\n",
    "            print(f'Combined loss:{i}: {loss.item():.4f}')\n",
    "except KeyboardInterrupt:\n",
    "    # Interrupting the kernel ends training early; the weights learned so far are still saved below.\n",
    "    print(f'Training interrupted after {len(epoch_time)} iterations')\n",
    "ep_time=np.array(epoch_time)\n",
    "torch.save(f_x.state_dict(),'f_x_base_save_good_eod2.tar')\n",
    "torch.save(inn.state_dict(),'inn2_save_good_eod2.tar')\n",
    "        \n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'linear_val_ode2' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-9ad1809fce1c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0meval_nl\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlinear_val_ode2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtt_tors\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mt_d\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mpl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxy_d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mpl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_nl\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'r'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'linear_val_ode2' is not defined"
     ]
    }
   ],
   "source": [
    "# Compare the model trajectory (red) with the noisy observations of the first trajectory.\n",
    "with torch.no_grad():\n",
    "    eval_nl=linear_val_ode2(tt_tors,t_d)\n",
    "fig, ax = pl.subplots()\n",
    "ax.plot(t_d.cpu(), xy_d_list[0].cpu())\n",
    "ax.plot(t_d.cpu(), eval_nl[:,:3].cpu(), c='r')\n",
    "ax.set(xlabel='t', ylabel='state', title='Observations vs. learned model (red)');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([80, 3])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xy_d_list[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
