{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b5e41b13-3e40-4198-9249-d88192686f8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.linalg import expm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a0e47461-34d2-4f3a-8c99-1e5eb946cf95",
   "metadata": {},
   "outputs": [],
   "source": [
    "m = 1\n",
    "eta = 0.1\n",
    "T = 10\n",
    "dt = 0.1\n",
    "A = np.array([[0,0,1,0],[0,0,0,1],[0,0,-eta/m,0],[0,0,0,-eta/m]])\n",
    "B = np.array([[0,0],[0,0],[1/m,0],[0,1/m]])\n",
    "Sig = np.array([[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]])\n",
    "c = np.array([1,1,1,1]).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "06692bba-818e-4553-9b6d-745433fb555c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#drift and its derivative in theta\n",
    "\n",
    "def mu(x,theta): \n",
    "    return A @ x + B@theta @ x\n",
    "def del_theta_mu(x,theta,i,j): \n",
    "    theta[i,j]*x[j]*B[:,i]\n",
    "\n",
    "#loss and its derivative in theta and x\n",
    "def del_theta_r(x,i,j):\n",
    "    return 2*(theta@x)[i]*x[j]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "0618e1c2-e1bb-4ba0-8980-94cf7114db7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# simulate del_x v(t,x)\n",
    "def simulate_x(x,theta,ts):\n",
    "    ret = ()\n",
    "    G = np.random.normal(0,np.sqrt(dt),[len(x),int(max(ts)/dt+1)])\n",
    "    time = 0\n",
    "    for t in np.arange(0,max(ts)+dt,dt):\n",
    "        if t in ts: \n",
    "            temp_count = ts.count(t)\n",
    "            while temp_count > 0:\n",
    "                ret = ret+(x,)\n",
    "                temp_count -=1\n",
    "        x = x + A @ x + B @ theta @ x + Sig @ G[:,time:time+1]\n",
    "        time += 1\n",
    "    print( ret)\n",
    "    return ret\n",
    "\n",
    "def simulate_del_x_v(t,theta,x):\n",
    "    tau = np.random.choice(np.arange(0,T-t+dt,dt))\n",
    "    Xtau,XT = simulate_x(x,theta,[tau,T-t])\n",
    "    DX_tau = expm(tau*(A+B@theta))\n",
    "    print(DX_tau)\n",
    "    DX_T_t = expm((T-t)*(A+B@theta))\n",
    "    return (T-t)*2*(theta.T@theta@ Xtau).T @ DX_tau + 2*(XT-c).T @ DX_T_t\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "53552cf2-b8bd-40e6-8026-eb93256c5e61",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(array([[-1.44378779e+10],\n",
      "       [-1.69074348e+10],\n",
      "       [-1.37340201e+10],\n",
      "       [-1.60831841e+10]]), array([[-8.57559040e+22],\n",
      "       [-1.00424201e+23],\n",
      "       [-8.15752368e+22],\n",
      "       [-9.55284432e+22]]))\n",
      "[[16.13111562  0.         15.3230985   0.        ]\n",
      " [ 0.         16.13111562  0.         15.3230985 ]\n",
      " [15.3230985   0.         14.59880577  0.        ]\n",
      " [ 0.         15.3230985   0.         14.59880577]]\n",
      "[[-3.46156365e+26 -4.05365402e+26 -3.29280972e+26 -3.85603522e+26]\n",
      " [-3.46156365e+26 -4.05365402e+26 -3.29280972e+26 -3.85603522e+26]\n",
      " [-3.46156365e+26 -4.05365402e+26 -3.29280972e+26 -3.85603522e+26]\n",
      " [-3.46156365e+26 -4.05365402e+26 -3.29280972e+26 -3.85603522e+26]]\n"
     ]
    }
   ],
   "source": [
    "theta = np.array([[1,0,0,0],[0,1,0,0]])\n",
    "x_0 = np.array([[0,0,-1,-1]]).T\n",
    "print(simulate_del_x_v(2,theta,x_0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (Spyder)",
   "language": "python3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
