{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import tqdm\n",
    "import numpy as np\n",
    "from numpy import linalg as la\n",
    "from sklearn import svm\n",
    "\n",
    "from functions import tda\n",
    "from functions import pwk  # persistence weighted kernel\n",
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded persistence diagrams.\n"
     ]
    }
   ],
   "source": [
    "# Load the 'mnist' train/test split through the project-local `utils` helper.\n",
    "# The cell output reports that persistence diagrams were loaded, so the x_* values\n",
    "# are presumably per-filtration collections of diagrams rather than raw images\n",
    "# -- TODO confirm against utils.get_data_images.\n",
    "x_train, y_train, x_test, y_test = utils.get_data_images('mnist')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nearestPD(A):\n",
    "    \"\"\"Return a positive-definite matrix close to ``A``.\n",
    "\n",
    "    NumPy port of John D'Errico's `nearestSPD` MATLAB routine [1], which in turn\n",
    "    credits Higham [2]. Adapted from\n",
    "    https://stackoverflow.com/questions/43238173/python-convert-matrix-to-positive-semi-definite\n",
    "\n",
    "    The input is symmetrised, averaged with the matrix ``V.T @ diag(s) @ V``\n",
    "    built from its SVD, and symmetrised once more. While the Cholesky test in\n",
    "    ``isPD`` rejects the result, a growing multiple of the identity is added.\n",
    "\n",
    "    [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd\n",
    "\n",
    "    [2] N.J. Higham, \"Computing a nearest symmetric positive semidefinite\n",
    "    matrix\" (1988): https://doi.org/10.1016/0024-3795(88)90223-6\n",
    "    \"\"\"\n",
    "    sym = (A + A.T) / 2\n",
    "    _, sing_vals, vh = la.svd(sym)\n",
    "\n",
    "    # la.svd returns the right factor already transposed, hence vh.T on the left.\n",
    "    scaled = np.dot(np.diag(sing_vals), vh)\n",
    "    polar = np.dot(vh.T, scaled)\n",
    "\n",
    "    candidate = (sym + polar) / 2\n",
    "    candidate = (candidate + candidate.T) / 2  # re-symmetrise after the averaging\n",
    "\n",
    "    # Shift the spectrum upwards until the Cholesky test passes; the loop body is\n",
    "    # skipped entirely when the candidate is already positive-definite.\n",
    "    eps = np.spacing(la.norm(A))\n",
    "    identity = np.eye(A.shape[0])\n",
    "    attempt = 1\n",
    "    while not isPD(candidate):\n",
    "        smallest = np.min(np.real(la.eigvals(candidate)))\n",
    "        candidate += identity * (-smallest * attempt**2 + eps)\n",
    "        attempt += 1\n",
    "\n",
    "    return candidate\n",
    "\n",
    "def isPD(B):\n",
    "    \"\"\"Report whether ``B`` is positive-definite.\n",
    "\n",
    "    The check is an attempted Cholesky factorisation: success yields True, a\n",
    "    ``LinAlgError`` yields False.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        la.cholesky(B)\n",
    "    except la.LinAlgError:\n",
    "        return False\n",
    "    return True\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_gram(list_pd, val_p=5, name_rkhs=\"Linear\", approx=True):\n",
    "    \"\"\"Build the persistence weighted kernel Gram matrix for a list of diagrams.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    list_pd : list\n",
    "        Persistence diagrams, forwarded unchanged to the ``tda`` helpers and to\n",
    "        ``pwk.Kernel``. The weight functions read ``vec_bd[0]`` and ``vec_bd[1]``,\n",
    "        so points are presumably (birth, death) pairs -- TODO confirm.\n",
    "    val_p : int or float, default 5\n",
    "        Exponent applied to the persistence term inside the weight function.\n",
    "    name_rkhs : str, default \"Linear\"\n",
    "        RKHS name handed to ``pwk.Kernel``.\n",
    "    approx : bool, default True\n",
    "        Approximation flag handed to ``pwk.Kernel``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Whatever ``pwk.Kernel(...).gram()`` returns for ``list_pd``.\n",
    "    \"\"\"\n",
    "\n",
    "    def function_weight(_name_weight, _val_c=1.0, _val_p=1):\n",
    "        # Factory for the per-point weight, a function of death - birth.\n",
    "        if _name_weight == \"arctan\":\n",
    "            def _func_weight(vec_bd):\n",
    "                return np.arctan(\n",
    "                    np.power((vec_bd[1] - vec_bd[0]) / _val_c, _val_p))\n",
    "        else:  # p-persistence\n",
    "            def _func_weight(vec_bd):\n",
    "                return np.power(vec_bd[1] - vec_bd[0], _val_p)\n",
    "        return _func_weight\n",
    "\n",
    "    def function_kernel(_name_kernel, _val_sigma=1.0):\n",
    "        # Factory for the point-to-point kernel.\n",
    "        if _name_kernel == \"Gaussian\":\n",
    "            def _func_kernel(vec_bd_1, vec_bd_2):\n",
    "                val_dist = np.power(np.linalg.norm(vec_bd_1 - vec_bd_2, 2), 2)\n",
    "                # The bandwidth is the factory argument, not a variable picked up\n",
    "                # from the enclosing scope, so the factory is self-contained.\n",
    "                return np.exp(-1.0 * val_dist / (2.0 * np.power(_val_sigma, 2)))\n",
    "        else:  # linear kernel\n",
    "            def _func_kernel(vec_bd_1, vec_bd_2):\n",
    "                return np.dot(vec_bd_1, vec_bd_2)\n",
    "        return _func_kernel\n",
    "\n",
    "    # Data-driven parameters come from the project-local tda helpers.\n",
    "    val_sigma = tda.parameter_sigma(list_pd)\n",
    "    val_c = tda.parameter_birth_death_pers(list_pd)[2]\n",
    "    func_kernel = function_kernel(\"Gaussian\", val_sigma)\n",
    "    func_weight = function_weight(\"arctan\", val_c, val_p)\n",
    "\n",
    "    # NOTE(review): the point kernel built above is Gaussian while name_rkhs\n",
    "    # defaults to \"Linear\"; confirm against pwk.Kernel that this pairing is intended.\n",
    "    class_pwk = pwk.Kernel(\n",
    "        list_pd, func_kernel, func_weight, val_sigma, name_rkhs, approx)\n",
    "    return class_pwk.gram()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def process_data(x, max_samples=1000):\n",
    "    '''\n",
    "        Compute the Gram matrix for the cubical filtration.\n",
    "\n",
    "        Only the first entry of `x` (the cubical filtration) is processed, because\n",
    "        handling every filtration is expensive in both CPU and memory. Other\n",
    "        filtrations could be tested individually if needed, but on its own the\n",
    "        cubical filtration gives the best results.\n",
    "\n",
    "        Parameters:\n",
    "            x: sequence of filtrations; each filtration is a sequence of arrays\n",
    "               that can be reshaped to (2 * arr.shape[1], 2).\n",
    "            max_samples: number of leading diagrams kept per filtration\n",
    "               (default 1000; None keeps all of them).\n",
    "\n",
    "        Returns:\n",
    "            A list with one entry per processed filtration: the output of\n",
    "            get_gram passed through nearestPD. Empty when `x` is empty.\n",
    "    '''\n",
    "    grams = []\n",
    "\n",
    "    for xtf in x[:1]:  # Get cubical filtration\n",
    "        list_pds = []\n",
    "        for pd_img in xtf[:max_samples]:\n",
    "            n = pd_img.shape[1]\n",
    "\n",
    "            # Flatten across axis 0 into rows of two values. The shape is passed\n",
    "            # positionally because the `newshape` keyword is deprecated in NumPy.\n",
    "            pd_img = np.reshape(pd_img, (2 * n, 2))\n",
    "\n",
    "            # Remove rows where both coordinates are zero (presumably padding).\n",
    "            both_zero = (pd_img[:, 0] == 0) & (pd_img[:, 1] == 0)\n",
    "            pd_img = pd_img[~both_zero, :]\n",
    "\n",
    "            # Add the square-rooted diagram to the list\n",
    "            list_pds.append(np.sqrt(pd_img))\n",
    "\n",
    "        # Gram matrix for the current filtration, projected to positive-definite\n",
    "        gram = nearestPD(get_gram(list_pds))\n",
    "        grams.append(gram)\n",
    "    return grams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0:00:07.873235\n"
     ]
    }
   ],
   "source": [
    "import datetime\n",
    "\n",
    "# Time the Gram-matrix construction for both splits.\n",
    "begin_time = datetime.datetime.now()\n",
    "K_train = process_data(x_train)\n",
    "K_test = process_data(x_test)\n",
    "end_time = datetime.datetime.now()\n",
    "# Report the interval that was measured instead of sampling the clock again.\n",
    "print(end_time - begin_time)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit an SVM on a precomputed kernel: the first (and only) Gram matrix returned\n",
    "# by process_data is used as the kernel. The label slice [:1000] must stay in\n",
    "# sync with the 1000-diagram cap applied inside process_data.\n",
    "clf = svm.SVC(kernel='precomputed')\n",
    "clf.fit(K_train[0],y_train[:1000])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): with kernel='precomputed', scikit-learn's SVC expects the kernel\n",
    "# between the test samples and the *training* samples, shape (n_test, n_train).\n",
    "# K_test[0] is built from x_test alone (test-vs-test), so this score is likely not\n",
    "# meaningful even though the shapes line up -- confirm and compute a cross-Gram.\n",
    "clf.score(K_test[0],y_test[:1000])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}