{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n"
     ]
    }
   ],
   "source": [
    "from sys import path\n",
    "from os import getcwd\n",
    "path.append(getcwd()+\"/../\")\n",
    "import pipelines.pipelines_1parameter as p1\n",
    "import pipelines.pipelines_2parameter as p2\n",
    "from datasets_getter.UCR import get as get_ucr\n",
    "from datasets_getter.graphs import get_graphs\n",
    "import numpy as np\n",
    "from random import choice\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook provides elementary code to time the computation of our vectorizations on the datasets.\n",
    "Note that it assumes that the datasets are located at\n",
    " - `\"$HOME_DIRECTORY/Datasets/UCR/$name_of_dataset\"` for the UCR datasets\n",
    " - `\"$HOME_DIRECTORY/Datasets/graphs/$name_of_dataset\"` for the graph datasets\n",
    "\n",
    "Also note that some parts of the code are compiled by numba and need several runs to warm up: the timings of the first runs also include compilation time.\n",
    "\n",
    "Furthermore, the convolution time depends on the format of the signed measure (sparse or dense tensor); here we take the slowest approach, namely computing the convolution on sparse signed measures.\n",
    "\n",
    "See `pipelines/convolutions/convolutions.py` for the different convolutions implementations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# UCR time-series datasets used in the benchmark.\n",
    "UCR_datasets = [\n",
    "    \"UCR/DistalPhalanxOutlineAgeGroup\", \"UCR/DistalPhalanxOutlineCorrect\", \"UCR/DistalPhalanxTW\",\n",
    "    \"UCR/ProximalPhalanxOutlineAgeGroup\", \"UCR/ProximalPhalanxOutlineCorrect\", \"UCR/ProximalPhalanxTW\",\n",
    "    \"UCR/ECG200\", \"UCR/ItalyPowerDemand\", \"UCR/MedicalImages\", \"UCR/Plane\", \"UCR/SwedishLeaf\",\n",
    "    \"UCR/GunPoint\", \"UCR/GunPointAgeSpan\", \"UCR/GunPointMaleVersusFemale\", \"UCR/GunPointOldVersusYoung\",\n",
    "    \"UCR/PowerCons\", \"UCR/SyntheticControl\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dataset = \"UCR/ItalyPowerDemand\"\n",
    "\n",
    "def get_UCR(dataset):\n",
    "    \"\"\"Load a UCR dataset with its train and test splits stacked together.\n",
    "\n",
    "    Returns `(X, Y, num_train)` where `X` / `Y` are the train samples / labels\n",
    "    followed by the test samples / labels (concatenated along axis 0), so that\n",
    "    `X[:num_train]` is the train split and `X[num_train:]` the test split.\n",
    "    \"\"\"\n",
    "    xtrain, ytrain = get_ucr(dataset=dataset, test=False)\n",
    "    # The test split needs test=True: requesting test=False twice would only\n",
    "    # duplicate the train split and time the pipeline on the wrong data.\n",
    "    xtest, ytest = get_ucr(dataset=dataset, test=True)\n",
    "    X = np.concatenate([xtrain, xtest], axis=0)\n",
    "    Y = np.concatenate([ytrain, ytest], axis=0)\n",
    "    return X, Y, len(xtrain)\n",
    "\n",
    "X, Y, num_train = get_UCR(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def time_st_ucr(dataset):\n",
    "    \"\"\"Time each stage of the UCR pipeline on `dataset`.\n",
    "\n",
    "    Returns a list of four durations in seconds, one per stage:\n",
    "    simplex tree construction, signed measure computation (degrees 0 and 1),\n",
    "    sliced Wasserstein distance (fit on the train split, transform on the test\n",
    "    split) and signed measure convolution (on all samples).\n",
    "    \"\"\"\n",
    "    from time import perf_counter_ns, sleep\n",
    "\n",
    "    def timed(stage):\n",
    "        # Pause before each stage (presumably to let the worker threads of the\n",
    "        # previous stage settle), then time only the call to `stage`.\n",
    "        sleep(0.5)\n",
    "        start = perf_counter_ns()\n",
    "        result = stage()\n",
    "        return result, (perf_counter_ns() - start) / 10**9\n",
    "\n",
    "    X, _, num_train = get_UCR(dataset)\n",
    "    simplextrees, t_simplextrees = timed(lambda: p2.RipsDensity2SimplexTree(\n",
    "        n_jobs=16, num_collapse=100, bandwidth=-0.1, progress=False).fit_transform(X))\n",
    "    signed_measures, t_signed_measures = timed(lambda: p2.SimplexTree2SignedMeasure(\n",
    "        degrees=[0, 1], filtration_quantile=0.01, normalize_filtrations=True,\n",
    "        infer_filtration_strategy=\"regular\", resolution=200, n_jobs=8).fit_transform(simplextrees))\n",
    "    _, t_sliced_wasserstein = timed(lambda: p2.SignedMeasure2SlicedWassersteinDistance(\n",
    "        num_directions=50, n_jobs=8).fit(signed_measures[:num_train]).transform(signed_measures[num_train:]))\n",
    "    _, t_convolution = timed(lambda: p2.SignedMeasure2Img(\n",
    "        resolution=50, n_jobs=4, old_implementation=False).fit_transform(signed_measures))\n",
    "    sleep(0.5)\n",
    "    return [t_simplextrees, t_signed_measures, t_sliced_wasserstein, t_convolution]\n",
    "def get_num_simplices(dataset):\n",
    "    \"\"\"Return `num_simplices()` of the first sample's simplex tree after `expansion(2)`.\n",
    "\n",
    "    NOTE(review): the complex is built with RipsDensity2SimplexTree defaults (only\n",
    "    n_jobs=8 and progress=False are set), not with the num_collapse=100,\n",
    "    bandwidth=-0.1 used in time_st_ucr -- confirm the counts match the timed complexes.\n",
    "    \"\"\"\n",
    "    series, _, _ = get_UCR(dataset)\n",
    "    [first_tree] = p2.RipsDensity2SimplexTree(n_jobs=8, progress=False).fit_transform([series[0]])\n",
    "    first_tree.expansion(2)\n",
    "    return first_tree.num_simplices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1561"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of simplices of the first sample's (expanded) complex for `dataset`.\n",
    "# To get it for every UCR dataset instead:\n",
    "#   num_simplices_ucr = [get_num_simplices(dataset) for dataset in tqdm(UCR_datasets)]\n",
    "get_num_simplices(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inferring filtration grid from simplextrees, with strategy regular...Done.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.150031166, 0.22194875, 0.6134885, 1.123133083]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Duration in seconds of each stage, in order:\n",
    "#   1. building the bifiltered simplicial complexes,\n",
    "#   2. computing the degree 0 and 1 Hilbert function decomposition, i.e. the signed measures,\n",
    "#   3. computing the sliced Wasserstein distances from the signed measures,\n",
    "#   4. convolving the signed measures.\n",
    "time_st_ucr(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# from tqdm import tqdm\n",
    "# timings = [time_st_ucr(dataset) for dataset in tqdm(UCR_datasets)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Graph datasets used in the benchmark.\n",
    "graphs_dataset = [\n",
    "    \"graphs/COX2\",\n",
    "    \"graphs/DHFR\",\n",
    "    \"graphs/IMDB-BINARY\",\n",
    "    \"graphs/IMDB-MULTI\",\n",
    "    \"graphs/MUTAG\",\n",
    "    \"graphs/PROTEINS\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def time_st_graphs(dataset):\n",
    "    \"\"\"Time each stage of the graph pipeline on `dataset`.\n",
    "\n",
    "    The first 80% of the graphs (in loading order, no shuffling) act as the train\n",
    "    split for the sliced Wasserstein step, the remaining graphs as the test split.\n",
    "\n",
    "    Returns `(timings, mean, std)`: `timings` lists the durations in seconds of\n",
    "    simplex tree construction, signed measure computation (degrees 0 and 1),\n",
    "    sliced Wasserstein distance and signed measure convolution; `mean` / `std`\n",
    "    summarize `num_simplices()` over the simplex trees (reduced along axis 0).\n",
    "    \"\"\"\n",
    "    from time import perf_counter_ns, sleep\n",
    "\n",
    "    def timed(stage):\n",
    "        # Pause before each stage, then time only the call to `stage`.\n",
    "        sleep(0.5)\n",
    "        start = perf_counter_ns()\n",
    "        result = stage()\n",
    "        return result, (perf_counter_ns() - start) / 10**9\n",
    "\n",
    "    graphs, _ = get_graphs(dataset)\n",
    "    num_train = int(.8 * len(graphs))\n",
    "    simplextrees, t_simplextrees = timed(lambda: p2.Graph2SimplexTree(\n",
    "        filtrations=[\"hks_10\", \"cc\", \"degree\"]).fit_transform(graphs))\n",
    "    # Complex sizes are collected outside of the timed regions.\n",
    "    num_simplices = [tree.num_simplices() for tree in simplextrees]\n",
    "    signed_measures, t_signed_measures = timed(lambda: p2.SimplexTree2SignedMeasure(\n",
    "        degrees=[0, 1], infer_filtration_strategy=\"exact\", resolution=100, n_jobs=8).fit_transform(simplextrees))\n",
    "    _, t_sliced_wasserstein = timed(lambda: p2.SignedMeasure2SlicedWassersteinDistance(\n",
    "        num_directions=50, n_jobs=8).fit(signed_measures[:num_train]).transform(signed_measures[num_train:]))\n",
    "    _, t_convolution = timed(lambda: p2.SignedMeasure2Img(\n",
    "        resolution=20, n_jobs=4, old_implementation=False).fit_transform(signed_measures))\n",
    "    sleep(0.5)\n",
    "    timings = [t_simplextrees, t_signed_measures, t_sliced_wasserstein, t_convolution]\n",
    "    return timings, np.mean(num_simplices, axis=0), np.std(num_simplices, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# num_data = [len(get_graphs(graph)[0]) for graph in tqdm(graphs_dataset)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# The same for graphs\n",
    "# [time_st_graphs(graph) for graph in tqdm(graphs_dataset)]\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
