{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8137cece",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'Path_Char'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_4000772/2395714010.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;31m# import higherOrderKME\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;31m# from higherOrderKME import sigkernel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mPath_Char\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mPath_Char\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath_characteristic_function\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mchar_func_path\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'Path_Char'"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "import itertools\n",
    "# import higherOrderKME\n",
    "# from higherOrderKME import sigkernel\n",
    "import Path_Char\n",
    "from Path_Char.path_characteristic_function import char_func_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2ede6210",
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'cython_backend'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_3995026/465658775.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mhigherOrderKME\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mhigherOrderKME\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msigkernel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/Gitrepos/higherOrderKME/higherOrderKME/sigkernel/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mtransformers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mcuda_backend\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0msigkernel\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/Gitrepos/higherOrderKME/higherOrderKME/sigkernel/sigkernel.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mnumba\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcuda\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mcython_backend\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msig_kernel_batch_varpar\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msig_kernel_Gram_varpar\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mcuda_backend\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcompute_sig_kernel_batch_varpar_from_increments_cuda\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcompute_sig_kernel_Gram_mat_varpar_from_increments_cuda\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'cython_backend'"
     ]
    }
   ],
   "source": [
    "import higherOrderKME\n",
    "from higherOrderKME import sigkernel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4e56ce52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when one is available; fall back to CPU so the notebook still executes elsewhere.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "5625561d",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'sigkernel' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_3963534/2737718005.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mstatic_kernel_21\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msigkernel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRBFKernel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msigma\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1e-5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madd_time\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mL\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'sigkernel' is not defined"
     ]
    }
   ],
   "source": [
    "# Static RBF kernel for the signature kernel.\n",
    "# NOTE(review): depends on `L` from the experimental-setup cell below and on the `sigkernel` import\n",
    "# above (which currently fails); run those first. TODO move this cell below the setup cell.\n",
    "static_kernel_21 = sigkernel.RBFKernel(sigma=1e-5, add_time=L-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7f592e8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of path coordinates and number of time steps\n",
    "D, L = 1, 3\n",
    "\n",
    "# experimental setup: Monte-Carlo repeats, samples per repeat, discretisation parameter n\n",
    "# (X_n moves by 1/n at t = 1) and regularisation lambda_ for the order-2 MMD\n",
    "repeats, n_samples, n, lambda_ = 100, 500, 500_000, 1e-5\n",
    "\n",
    "# seed the generator so the sampled paths, and hence the MMD estimates, are reproducible\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "MMD_1 = np.zeros((repeats, 2))\n",
    "MMD_2 = np.zeros((repeats, 2))\n",
    "\n",
    "# sample paths of X_n and X, shape (repeats, n_samples, 2, L, D);\n",
    "# axis 2 holds two independent copies, and every path starts at 0\n",
    "X = np.zeros((repeats, n_samples, 2, L, D))\n",
    "X_n = np.zeros((repeats, n_samples, 2, L, D))\n",
    "\n",
    "# sample from X_n: path (0, omega/n, 0.1*omega), so the small move at t = 1 already carries the sign used at t = 2\n",
    "omega_1 = rng.choice(a=[-1, 1], size=(repeats, n_samples, 2))\n",
    "# NOTE(review): omega_2 is not used to build X or X_n; it is only displayed by a later cell\n",
    "omega_2 = rng.choice(a=[-1, 1], size=(repeats, n_samples, 2))\n",
    "X_n[:, :, :, 1, 0] = omega_1 * 1./n\n",
    "X_n[:, :, :, 2, 0] = 0.1 * omega_1\n",
    "\n",
    "# sample from X: path (0, 0, 0.1*omega) with a fresh sign, independent of X_n\n",
    "omega_X = rng.choice(a=[-1, 1], size=(repeats, n_samples, 2))\n",
    "X[:, :, :, 2, 0] = 0.1 * omega_X\n",
    "\n",
    "# sanity checks on shapes and on the magnitudes of the sampled moves\n",
    "assert X.shape == X_n.shape == (repeats, n_samples, 2, L, D)\n",
    "assert np.allclose(np.abs(X[..., 2, 0]), 0.1) and np.allclose(np.abs(X_n[..., 1, 0]), 1. / n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1776dc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second (independent) copy of the X paths in repeat 0.\n",
    "# `X_0` was never defined; assumed to mean X (cf. x0 below) - confirm.\n",
    "X[0, :, 1, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "id": "e898e50f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# repeat 0 of X (shape: repeats, n_samples, 2, L, D), moved to `device` as float64\n",
    "x0 = torch.tensor(X[0, :, 0, :, :], dtype=torch.float64, device=device)   # first copy\n",
    "x0_ = torch.tensor(X[0, :, 1, :, :], dtype=torch.float64, device=device)  # independent second copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "5d76c5fc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([500, 3, 1])"
      ]
     },
     "execution_count": 123,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x0_.size()  # (n_samples, L, D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "id": "10e21324",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'static_kernel_21' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_3963534/1897957268.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mstatic_kernel_21\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_kernel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx0_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'static_kernel_21' is not defined"
     ]
    }
   ],
   "source": [
    "# NOTE(review): `static_kernel_21` only exists if the RBFKernel cell ran successfully (it currently fails).\n",
    "static_kernel_21.batch_kernel(x0, x0_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "9bd18aef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "char_func_path(\n",
       "  (unitary_development): development_layer(\n",
       "    (projection): projection(\n",
       "      (param_map): unitary()\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# level 1: unitary development of the raw D-dimensional paths\n",
    "lie_degree_1 = 2\n",
    "num_samples_1 = 20\n",
    "input_size = D  # tie the input dimension to the path dimension from the setup cell\n",
    "add_time = False\n",
    "\n",
    "# level 2: development of the path of flattened level-1 expected developments\n",
    "num_samples_2 = 20\n",
    "lie_degree_2 = 3\n",
    "\n",
    "pcf_level_1 = char_func_path(num_samples=num_samples_1,\n",
    "                              hidden_size=lie_degree_1,\n",
    "                              input_size=input_size,\n",
    "                              add_time=add_time,\n",
    "                              include_initial=False)\n",
    "\n",
    "# NOTE(review): the level-1 developments are complex lie_degree_1 x lie_degree_1 matrices; confirm whether\n",
    "# input_size should be lie_degree_1**2 or 2 * lie_degree_1**2 (real and imaginary parts)\n",
    "pcf_level_2 = char_func_path(num_samples=num_samples_2,\n",
    "                              hidden_size=lie_degree_2,\n",
    "                              input_size=lie_degree_1**2,\n",
    "                              add_time=add_time,\n",
    "                              include_initial=False)\n",
    "\n",
    "pcf_level_1.to(device)\n",
    "pcf_level_2.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "23a4389f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# develop each sample path of repeat 0: dev_x0 has shape (n_samples, num_samples_1, lie_degree_1, lie_degree_1)\n",
    "x0 = torch.tensor(X[0, :, 0, :, :], dtype=torch.float64, device=device)\n",
    "dev_x0 = pcf_level_1.unitary_development(x0)\n",
    "# average over the sample paths\n",
    "expected_dev_x0 = torch.mean(dev_x0, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "b0d9dae4",
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "unhashable type: 'list'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_3963534/1762839126.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx0\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m     \u001b[0;32mif\u001b[0m \u001b[0mcurr_path\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtree_node\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m         \u001b[0mtree_node\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcurr_path\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx0\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: unhashable type: 'list'"
     ]
    }
   ],
   "source": [
    "# Group the sample paths by their history up to and including time t.\n",
    "# Each distinct history becomes a key (a tuple, since lists are unhashable) mapping to a\n",
    "# boolean mask over the n_samples paths that share it.\n",
    "t = 2\n",
    "history = x0[:, :t + 1].flatten(start_dim=1)  # (n_samples, (t + 1) * D)\n",
    "tree_node = {t: {}}\n",
    "for path in history:\n",
    "    key = tuple(path.tolist())\n",
    "    if key not in tree_node[t]:\n",
    "        tree_node[t][key] = torch.all(history == path, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "53dcd60e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reset the current time step and the (empty) tree\n",
    "t, tree_node = 2, {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "f80a3016",
   "metadata": {},
   "outputs": [],
   "source": [
    "tree_node[2] = dict()  # empty node for time step 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "id": "fb95ed31",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{}"
      ]
     },
     "execution_count": 119,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inspect the node stored under key 2 (raises KeyError if it was never created)\n",
    "tree_node[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "db224c8e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{2: {}}"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inspect the whole tree dictionary\n",
    "tree_node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "id": "702094d8",
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "unhashable type: 'list'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_3963534/2719578087.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtree_node\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"not in\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: unhashable type: 'list'"
     ]
    }
   ],
   "source": [
    "# dict keys must be hashable: look up a tuple such as (2,), not the list [2] (which raises TypeError)\n",
    "if (2,) not in tree_node[2]:\n",
    "    print(\"not in\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "32284fff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.0, 0.0, 0.1]\n"
     ]
    }
   ],
   "source": [
    "# take the first sample path of x0 and show it as a flat list\n",
    "i = x0[0]\n",
    "curr_path = i.flatten().tolist()\n",
    "print(curr_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "aac93911",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.0000],\n",
       "         [0.0000],\n",
       "         [0.1000]]], device='cuda:0', dtype=torch.float64)"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i.reshape(1, -1, 1)  # leading batch axis; note (L, D) collapses into one axis when D > 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "id": "fcfd5475",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([262, 3, 1])"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# paths in x0 equal to i at every time step and coordinate; flattening (L, D) first keeps the mask\n",
    "# 1-D for D > 1 (torch.all over dim=1 alone would leave a D axis and a mask of length n_samples * D)\n",
    "x0[(x0 == i).flatten(start_dim=1).all(dim=1)].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "170c8e66",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([500, 3, 1])"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.eq(x0, i).shape  # i broadcasts over the sample axis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "a1bcb26a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ True],\n",
       "         [ True],\n",
       "         [ True]],\n",
       "\n",
       "        [[ True],\n",
       "         [ True],\n",
       "         [ True]],\n",
       "\n",
       "        [[ True],\n",
       "         [ True],\n",
       "         [False]],\n",
       "\n",
       "        ...,\n",
       "\n",
       "        [[ True],\n",
       "         [ True],\n",
       "         [False]],\n",
       "\n",
       "        [[ True],\n",
       "         [ True],\n",
       "         [False]],\n",
       "\n",
       "        [[ True],\n",
       "         [ True],\n",
       "         [False]]], device='cuda:0')"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.eq(x0, i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "cb16e27d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000],\n",
       "        [0.0000],\n",
       "        [0.1000]], device='cuda:0', dtype=torch.float64)"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# current value of `i` (the sample path being inspected)\n",
    "i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "7195a87e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# recompute the level-1 development of x0 (same call as the earlier development cell)\n",
    "dev_x0 = pcf_level_1.unitary_development(x0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "cb2e09d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.9953+0.0802j, -0.0406-0.0372j],\n",
       "        [ 0.0463-0.0298j,  0.9945+0.0889j]], device='cuda:0',\n",
       "       grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# development of sample path 0 under the first of the num_samples_1 unitary maps\n",
    "dev_x0[0,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "cc3228da",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.9953+0.0802j, -0.0406-0.0372j],\n",
       "        [ 0.0463-0.0298j,  0.9945+0.0889j]], device='cuda:0',\n",
       "       grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the recorded output equals dev_x0[0,0]; presumably sample paths 0 and 1 coincide - confirm\n",
    "dev_x0[1,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "7a66ed0a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.9953-0.0802j,  0.0463+0.0298j],\n",
       "        [-0.0406+0.0372j,  0.9945-0.0889j]], device='cuda:0',\n",
       "       grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the recorded output is the conjugate transpose of dev_x0[0,0]; presumably sample 2 has the opposite sign\n",
    "dev_x0[2,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "b09e6a68",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 2, 2])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# shape (num_samples_1, lie_degree_1, lie_degree_1): one averaged development per unitary map\n",
    "expected_dev_x0.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "194ad870",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([500, 20, 2, 2])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dev_x0.size()  # (n_samples, num_samples_1, lie_degree_1, lie_degree_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bde6fc49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# As we progress in t, we need to reduce our sample space and compute the expected development wrt X, where X is of shape [N, T, D].\n",
    "# If the development measure is empirical, of shape [M, lie_degree, lie_degree, D], then the expected development wrt X is of shape\n",
    "# [M, lie_degree, lie_degree] (take the average across N).\n",
    "\n",
    "\n",
    "\n",
    "# As we progress in t, we need to reduce our sample space: we will have [N', T, D] paths, conditional on the past path at t. We compute\n",
    "# the same expected development.\n",
    "\n",
    "# This is how we create the path of shape [M, T, lie_degree, lie_degree]; call it Y.\n",
    "\n",
    "# Once we have it, we flatten it so that Y.shape = [M, T, lie_degree**2].\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5650c8ce",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e347d11b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f4182b3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62b47b3e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6fe0c42",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d319c8cb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9281304e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "2b918e15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A presumably holds the Lie-algebra parameters of the level-1 development. `pcf` was never defined;\n",
    "# the 2x2 shape in the saved output matches pcf_level_1. At is (num_samples_1, lie_degree_1, lie_degree_1, D).\n",
    "At = pcf_level_1.unitary_development.projection.A.permute(1, 2, -1, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "98c5e7c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 10 random 1-dimensional increments, cast to complex to match the (presumably complex) At\n",
    "dx = torch.randn([10, 1]).to(device=device, dtype=torch.cfloat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "3358bff1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 2, 2, 10])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(At @ dx.T).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdc53d71",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d563328",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02cc1afa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "032dc546",
   "metadata": {},
   "source": [
    "**Run the experiment**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5e420cb0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[[ 0.e+00],\n",
       "         [-2.e-06],\n",
       "         [-1.e-01]],\n",
       "\n",
       "        [[ 0.e+00],\n",
       "         [-2.e-06],\n",
       "         [-1.e-01]]],\n",
       "\n",
       "\n",
       "       [[[ 0.e+00],\n",
       "         [-2.e-06],\n",
       "         [-1.e-01]],\n",
       "\n",
       "        [[ 0.e+00],\n",
       "         [-2.e-06],\n",
       "         [-1.e-01]]],\n",
       "\n",
       "\n",
       "       [[[ 0.e+00],\n",
       "         [-2.e-06],\n",
       "         [-1.e-01]],\n",
       "\n",
       "        [[ 0.e+00],\n",
       "         [-2.e-06],\n",
       "         [-1.e-01]]],\n",
       "\n",
       "\n",
       "       [[[ 0.e+00],\n",
       "         [ 2.e-06],\n",
       "         [ 1.e-01]],\n",
       "\n",
       "        [[ 0.e+00],\n",
       "         [-2.e-06],\n",
       "         [-1.e-01]]],\n",
       "\n",
       "\n",
       "       [[[ 0.e+00],\n",
       "         [-2.e-06],\n",
       "         [-1.e-01]],\n",
       "\n",
       "        [[ 0.e+00],\n",
       "         [ 2.e-06],\n",
       "         [ 1.e-01]]],\n",
       "\n",
       "\n",
       "       [[[ 0.e+00],\n",
       "         [ 2.e-06],\n",
       "         [ 1.e-01]],\n",
       "\n",
       "        [[ 0.e+00],\n",
       "         [ 2.e-06],\n",
       "         [ 1.e-01]]],\n",
       "\n",
       "\n",
       "       [[[ 0.e+00],\n",
       "         [ 2.e-06],\n",
       "         [ 1.e-01]],\n",
       "\n",
       "        [[ 0.e+00],\n",
       "         [ 2.e-06],\n",
       "         [ 1.e-01]]],\n",
       "\n",
       "\n",
       "       [[[ 0.e+00],\n",
       "         [-2.e-06],\n",
       "         [-1.e-01]],\n",
       "\n",
       "        [[ 0.e+00],\n",
       "         [-2.e-06],\n",
       "         [-1.e-01]]],\n",
       "\n",
       "\n",
       "       [[[ 0.e+00],\n",
       "         [-2.e-06],\n",
       "         [-1.e-01]],\n",
       "\n",
       "        [[ 0.e+00],\n",
       "         [ 2.e-06],\n",
       "         [ 1.e-01]]],\n",
       "\n",
       "\n",
       "       [[[ 0.e+00],\n",
       "         [ 2.e-06],\n",
       "         [ 1.e-01]],\n",
       "\n",
       "        [[ 0.e+00],\n",
       "         [ 2.e-06],\n",
       "         [ 1.e-01]]]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): kernel_order1 and kernel_order2 are never defined in this notebook; construct them\n",
    "# before running this cell.\n",
    "for r in tqdm(range(repeats)):\n",
    "    # X and X_n have shape (repeats, n_samples, 2, L, D); axis 2 indexes two independent copies\n",
    "    x0, x0_, xn = (\n",
    "        torch.tensor(paths, dtype=torch.float64, device=device)\n",
    "        for paths in (X[r, :, 0], X[r, :, 1], X_n[r, :, 0])\n",
    "    )\n",
    "\n",
    "    # order-2 MMD (regularised by lambda_): X vs an independent copy of X, then X_n vs X\n",
    "    MMD_2[r, 0] = kernel_order2.compute_mmd(x0, x0_, lambda_=lambda_, estimator='ub', order=2)\n",
    "    MMD_2[r, 1] = kernel_order2.compute_mmd(xn, x0, lambda_=lambda_, estimator='ub', order=2)\n",
    "\n",
    "    # order-1 MMD on the same samples\n",
    "    MMD_1[r, 0] = kernel_order1.compute_mmd(x0, x0_, estimator='ub', order=1)\n",
    "    MMD_1[r, 1] = kernel_order1.compute_mmd(xn, x0, estimator='ub', order=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "92ecc336",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[-1, -1],\n",
       "        [ 1, -1],\n",
       "        [-1,  1],\n",
       "        ...,\n",
       "        [ 1, -1],\n",
       "        [ 1, -1],\n",
       "        [ 1, -1]],\n",
       "\n",
       "       [[ 1, -1],\n",
       "        [ 1, -1],\n",
       "        [-1, -1],\n",
       "        ...,\n",
       "        [-1, -1],\n",
       "        [-1, -1],\n",
       "        [-1,  1]],\n",
       "\n",
       "       [[ 1,  1],\n",
       "        [-1, -1],\n",
       "        [-1, -1],\n",
       "        ...,\n",
       "        [-1, -1],\n",
       "        [ 1, -1],\n",
       "        [ 1,  1]],\n",
       "\n",
       "       ...,\n",
       "\n",
       "       [[ 1, -1],\n",
       "        [ 1,  1],\n",
       "        [ 1, -1],\n",
       "        ...,\n",
       "        [ 1, -1],\n",
       "        [-1,  1],\n",
       "        [-1, -1]],\n",
       "\n",
       "       [[ 1, -1],\n",
       "        [-1,  1],\n",
       "        [ 1, -1],\n",
       "        ...,\n",
       "        [-1,  1],\n",
       "        [ 1,  1],\n",
       "        [-1,  1]],\n",
       "\n",
       "       [[-1, -1],\n",
       "        [ 1,  1],\n",
       "        [ 1,  1],\n",
       "        ...,\n",
       "        [ 1,  1],\n",
       "        [-1,  1],\n",
       "        [-1, -1]]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): omega_2 is sampled in the setup cell but never used to build X or X_n\n",
    "omega_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22e6c575",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "levy",
   "language": "python",
   "name": "levy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
