{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import top_k_accuracy_score, accuracy_score\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from torch import optim\n",
    "import numpy as np\n",
    "from time import process_time as time\n",
    "\n",
    "from data_utils import fast_scandir_for_folders, load_shape_data, samples_n_vertices\n",
    "from metrics_utility import my_pairwise_distances, metric_gw, metric_sgw, metric_sse, metric_min_sse, metric_distrib_sse, metric_distrib_min_sse\n",
    "from TransformNet import TransformNet, TransformLatenttoOrig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load shape data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subfolders = fast_scandir_for_folders('./data/nonrigid3d')\n",
    "data, labels = load_shape_data(subfolders)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D  \n",
    "from matplotlib import gridspec\n",
    "%matplotlib inline\n",
    "\n",
    "fig = plt.figure(figsize=(15,10))\n",
    "\n",
    "X1= data[0]\n",
    "ax = fig.add_subplot(121, projection='3d')\n",
    "ax.scatter(X1[:,0],X1[:,2],X1[:,1], marker='o', s=20, c=\"goldenrod\", alpha=0.6)\n",
    "ax.view_init(elev=0., azim=0)\n",
    "ax.set_axis_off()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Remove the class \"shark\" that includes only one sample and pack the remaining shapes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ind = np.asarray([k for k, l in enumerate(labels) if l != \"shark\"])\n",
    "\n",
    "# Pack the remaining data and save\n",
    "y = np.array([labels[k] for k  in ind])\n",
    "X = np.array([data[k] for k in ind], dtype=object)\n",
    "\n",
    "pathdata = './data/'\n",
    "filename = 'shapes'\n",
    "np.savez(pathdata+filename,\n",
    "         y=y,\n",
    "         X = X,\n",
    "         )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3D shape comparisons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters \n",
    "vector_n = [100, 250, 500, 1000] #number of vertices to be sampled on each shape 250, 500, 1000\n",
    "n_pairs = 100 # number of pairs to be considered to estimate computation time\n",
    "vector_k = [1, 3, 5] # K of top-K accuracy\n",
    "\n",
    "pathdata = './data/'\n",
    "filename = 'shapes'\n",
    "\n",
    "# pairwise distance computation\n",
    "n_jobs_pw = 3\n",
    "n_jobs_time = 1 # here we use a single CPU to monitore average computation time of n_pairs distances\n",
    "\n",
    "\n",
    "# Init recorded performances\n",
    "allperf_dist_gw = np.zeros((len(vector_n), len(vector_k)))\n",
    "allperf_dist_sgw = np.zeros((len(vector_n), len(vector_k)))\n",
    "allperf_dist_sse = np.zeros((len(vector_n), len(vector_k)))\n",
    "allperf_min_sse = np.zeros((len(vector_n), len(vector_k)))\n",
    "allperf_distrib_sse = np.zeros((len(vector_n), len(vector_k)))\n",
    "allperf_distrib_min_sse = np.zeros((len(vector_n), len(vector_k)))\n",
    "\n",
    "time_gw = np.zeros(len(vector_n))\n",
    "time_sgw = np.zeros(len(vector_n))\n",
    "time_sse = np.zeros(len(vector_n))\n",
    "time_min_sse = np.zeros(len(vector_n))\n",
    "time_distrib_sse = np.zeros(len(vector_n))\n",
    "time_distrib_min_sse = np.zeros(len(vector_n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cpu\"\n",
    "nproj = 1000\n",
    "nproj_dist = 10\n",
    "nproj_dist_d = 10\n",
    "lr = 0.005\n",
    "nb_iter = 50\n",
    "dim_latent = 5\n",
    "\n",
    "dim_s = X[0].shape[1] # = 3\n",
    "dim_t = dim_s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_samples_timing = int(np.sqrt(n_pairs)) # we will randomly select n_samples_timing shapes to evalution computation burden\n",
    "\n",
    "for i in range(len(vector_n)):\n",
    "    print(f\"n = {vector_n[i]}\")\n",
    "    \n",
    "    # in-place sampling of n vertices on each shape\n",
    "    with np.load(pathdata+filename+'.npz', allow_pickle=True) as data:\n",
    "        y = data['y']\n",
    "        X = data['X']\n",
    "    \n",
    "    X = samples_n_vertices(X, vector_n[i])\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.5, random_state=42)\n",
    "    \n",
    "        \n",
    "    # normalize each shape (otherwise distance computation for GW messes up)\n",
    "    scaler = StandardScaler()\n",
    "    for idx, shape in enumerate(X_train):\n",
    "        X_train[idx] = scaler.fit_transform(shape)\n",
    "        \n",
    "    for idx, shape in enumerate(X_test):\n",
    "        X_test[idx] = scaler.fit_transform(shape)\n",
    "        \n",
    "    # selected samples to evaluate computation time of each method\n",
    "    ind_samples = np.random.choice(X_train.shape[0], n_samples_timing, replace=False)\n",
    "        \n",
    "    # ----- Compute pairwise distance matrix using according to benchmarked metrics and perform 1-NN classification ------\n",
    "    knn  = KNeighborsClassifier(n_neighbors=1, metric=\"precomputed\")\n",
    "    \n",
    "    ## ========== GW =========\n",
    "    print(\"Gromov-Wasserstein\")\n",
    "    # 1-NN classification using pre-computed pairwise distance matrix\n",
    "    mat_gw = my_pairwise_distances(X=X_train, Y=X_train, metric=metric_gw, n_jobs=n_jobs_pw)\n",
    "    knn.fit(np.abs(mat_gw), y_train)\n",
    "    \n",
    "    # top-k accuracy\n",
    "    mat_gw_test = my_pairwise_distances(X=X_test, Y=X_train, metric=metric_gw, n_jobs=n_jobs_pw)\n",
    "    y_gw = knn.predict_proba(np.abs(mat_gw_test))\n",
    "    \n",
    "    for j in range(len(vector_k)):\n",
    "        allperf_dist_gw[i, j] = top_k_accuracy_score(y_test, y_gw, k=vector_k[j])\n",
    "        \n",
    "    # computation time\n",
    "    tic = time()\n",
    "    _ = my_pairwise_distances(X=X_train[ind_samples], Y=X_train[ind_samples], metric=metric_gw, n_jobs = n_jobs_time)\n",
    "    time_gw[i] = (time()-tic)/n_pairs\n",
    "    \n",
    "    print(f\"perf = {allperf_dist_gw[i]} and time = {time_gw[i]}\")\n",
    "    \n",
    "    \n",
    "    ## ============= Sliced GW ================\n",
    "    print(\"\\nSliced GW\")\n",
    "    # 1-NN\n",
    "    mat_sgw = my_pairwise_distances(X=X_train, Y=X_train, metric=metric_sgw, n_jobs=n_jobs_pw, nproj=nproj)\n",
    "    knn.fit(np.abs(mat_sgw), y_train)\n",
    "    \n",
    "    # top-k accuracy\n",
    "    mat_sgw_test = my_pairwise_distances(X=X_test, Y=X_train, metric=metric_sgw, n_jobs=n_jobs_pw, nproj=nproj)\n",
    "    y_sgw = knn.predict_proba(np.abs(mat_sgw_test))\n",
    "    \n",
    "    for j in range(len(vector_k)):\n",
    "        allperf_dist_sgw[i, j] = top_k_accuracy_score(y_test, y_sgw, k=vector_k[j])\n",
    "        \n",
    "    # computation time\n",
    "    tic = time()\n",
    "    _ = my_pairwise_distances(X=X_train[ind_samples], Y=X_train[ind_samples], metric=metric_sgw, n_jobs = n_jobs_time, nproj=nproj)\n",
    "    time_sgw[i] = (time()-tic)/n_pairs\n",
    "    \n",
    "    print(f\"perf = {allperf_dist_sgw[i]} and time = {time_sgw[i]}\")\n",
    "    \n",
    "    ## ========== SSE =========\n",
    "    print(\"\\nSSE\")\n",
    "    # 1-NN\n",
    "    mat_sse = my_pairwise_distances(X=X_train, Y=X_train, metric=metric_sse, n_jobs=n_jobs_pw, nproj=nproj)\n",
    "    knn.fit(np.abs(mat_sse), y_train)\n",
    "    \n",
    "    # accuracy\n",
    "    mat_sse_test = my_pairwise_distances(X=X_test, Y=X_train, metric=metric_sse, n_jobs=n_jobs_pw, nproj=nproj)\n",
    "    y_sse = knn.predict_proba(np.abs(mat_sse_test))\n",
    "    \n",
    "    for j in range(len(vector_k)):\n",
    "        allperf_dist_sse[i, j] = top_k_accuracy_score(y_test, y_sse, k=vector_k[j])\n",
    "    \n",
    "    #timing\n",
    "    tic = time()\n",
    "    _ = my_pairwise_distances(X=X_train[ind_samples], Y=X_train[ind_samples], metric=metric_sse, n_jobs = n_jobs_time, nproj=nproj)\n",
    "    time_sse[i] = (time()-tic)/n_pairs\n",
    "    \n",
    "    print(f\"perf = {allperf_dist_sse[i]} and time = {time_sse[i]}\")\n",
    "    \n",
    "    \n",
    "    ## ============= Min SSE ===========\n",
    "    print(\"\\nMin SSE\")\n",
    "    fp = TransformLatenttoOrig(dim_latent,dim_s).to(device)\n",
    "    fq = TransformLatenttoOrig(dim_latent,dim_t).to(device)\n",
    "    \n",
    "    fp_optim = optim.Adam(fp.parameters(), lr=0.001, betas=(0.5, 0.999),weight_decay=0.5)\n",
    "    fq_optim = optim.Adam(fq.parameters(), lr=0.001, betas=(0.5, 0.999),weight_decay=0.5)\n",
    "    \n",
    "    # 1-NN\n",
    "    mat_min_sse = my_pairwise_distances(X=X_train, Y=X_train, metric=metric_min_sse, n_jobs=1, s_latent2orig_net=fp,\n",
    "                                        t_latent2orig_net=fq, opt_s=fp_optim, opt_t=fq_optim, \n",
    "                                        dim_latent=dim_latent, nproj_dist=nproj_dist, max_iter=nb_iter)\n",
    "    knn.fit(np.abs(mat_min_sse), y_train)\n",
    "    \n",
    "    # accuracy\n",
    "    mat_min_sse_test = my_pairwise_distances(X=X_test, Y=X_train, metric=metric_min_sse, n_jobs=1, s_latent2orig_net=fp,\n",
    "                                        t_latent2orig_net=fq, opt_s=fp_optim, opt_t=fq_optim, \n",
    "                                        dim_latent=dim_latent, nproj_dist=nproj_dist, max_iter=nb_iter)\n",
    "    y_min_sse = knn.predict_proba(np.abs(mat_min_sse_test))\n",
    "    \n",
    "    \n",
    "    for j in range(len(vector_k)):\n",
    "        allperf_min_sse[i, j] = top_k_accuracy_score(y_test, y_min_sse, k=vector_k[j])\n",
    "        \n",
    "    # time\n",
    "    tic = time()\n",
    "    _ = my_pairwise_distances(X=X_train[ind_samples], Y=X_train[ind_samples], metric=metric_min_sse, n_jobs=n_jobs_time, s_latent2orig_net=fp,\n",
    "                                        t_latent2orig_net=fq, opt_s=fp_optim, opt_t=fq_optim, \n",
    "                                        dim_latent=dim_latent, nproj_dist=nproj_dist, max_iter=nb_iter)\n",
    "    time_min_sse[i] = (time()-tic)/n_pairs\n",
    "    \n",
    "    print(f\"perf = {allperf_min_sse[i]} and time = {time_min_sse[i]}\")\n",
    "    \n",
    "    ## =========== Distributional SSE ================\n",
    "    print(\"\\nDistributional SSE\")\n",
    "    fs = TransformNet(dim_s).to(device)\n",
    "    ft = TransformNet(dim_t).to(device)\n",
    "    \n",
    "    fs_optim = optim.Adam(fs.parameters(), lr=lr, betas=(0.5, 0.999))\n",
    "    ft_optim = optim.Adam(ft.parameters(), lr=lr, betas=(0.5, 0.999))\n",
    "    \n",
    "    # 1-NN\n",
    "    mat_distrib_sse = my_pairwise_distances(X=X_train, Y=X_train, metric=metric_distrib_sse, n_jobs=1, s_net = fs, \n",
    "                                     t_net = ft, opt_s=fs_optim, opt_t = ft_optim, nproj_dist=nproj_dist, max_iter=nb_iter)\n",
    "    \n",
    "    knn.fit(np.abs(mat_distrib_sse), y_train)\n",
    "\n",
    "    # accuracy\n",
    "    mat_distrib_sse_test = my_pairwise_distances(X=X_test, Y=X_train, metric=metric_distrib_sse, n_jobs=1, s_net = fs, \n",
    "                                     t_net = ft, opt_s=fs_optim, opt_t = ft_optim, nproj_dist=nproj_dist, max_iter=nb_iter)\n",
    "    \n",
    "    y_distrib_sse = knn.predict_proba(np.abs(mat_distrib_sse_test))\n",
    "    for j in range(len(vector_k)):\n",
    "        allperf_distrib_sse[i, j] = top_k_accuracy_score(y_test, y_distrib_sse, k=vector_k[j])\n",
    "        \n",
    "    # time\n",
    "    tic = time()\n",
    "    _ = my_pairwise_distances(X=X_train[ind_samples], Y=X_train[ind_samples],metric=metric_distrib_sse, n_jobs=1, s_net = fs, \n",
    "                              t_net = ft, opt_s=fs_optim, opt_t = ft_optim, nproj_dist=nproj_dist, max_iter=nb_iter)\n",
    "    time_distrib_sse[i] = (time()-tic)/n_pairs\n",
    "    \n",
    "    print(f\"perf = {allperf_distrib_sse[i]} and time = {time_distrib_sse[i]}\")\n",
    "    \n",
    "    \n",
    "    ## =========== Distributional Min SSE ================\n",
    "    print(\"\\nDistributional Min SSE\")\n",
    "    transf_net = TransformNet(dim_latent).to(device)\n",
    "    fp = TransformLatenttoOrig(dim_latent,dim_s).to(device)\n",
    "    fq = TransformLatenttoOrig(dim_latent,dim_t).to(device)\n",
    "    \n",
    "    transf_net_optim = optim.Adam(transf_net.parameters(), lr=0.001, betas=(0.5, 0.999),weight_decay=0.5)\n",
    "    fp_optim = optim.Adam(fp.parameters(), lr=0.001, betas=(0.5, 0.999),weight_decay=0.5)\n",
    "    fq_optim = optim.Adam(fq.parameters(), lr=0.001, betas=(0.5, 0.999),weight_decay=0.5)\n",
    "    \n",
    "    # 1-NN\n",
    "    mat_distrib_min_sse = my_pairwise_distances(X=X_train, Y=X_train, metric=metric_distrib_min_sse, n_jobs=1, \n",
    "                                                transf_net = transf_net, s_latent2orig_net = fp, t_latent2orig_net = fq, \n",
    "                                                opt_trannet = transf_net_optim, opt_s=fp_optim, opt_t = fq_optim,\n",
    "                                                dim_latent = dim_latent, nproj_dist = nproj_dist_d, \n",
    "                                               num_epochs=nb_iter, num_sup_iter = 2, num_inf_iter = 2)\n",
    "    knn.fit(np.abs(mat_distrib_min_sse), y_train)\n",
    "    \n",
    "    # accuracy\n",
    "    mat_distrib_min_sse_test = my_pairwise_distances(X=X_test, Y=X_train, metric=metric_distrib_min_sse, n_jobs=1, \n",
    "                                                transf_net = transf_net, s_latent2orig_net = fp, t_latent2orig_net = fq, \n",
    "                                                opt_trannet = transf_net_optim, opt_s=fp_optim, opt_t = fq_optim,\n",
    "                                                dim_latent = dim_latent, nproj_dist = nproj_dist_d, \n",
    "                                               num_epochs=nb_iter, num_sup_iter = 2, num_inf_iter = 2)\n",
    "    y_distrib_min_sse = knn.predict_proba(np.abs(mat_distrib_min_sse_test))\n",
    "    for j in range(len(vector_k)):\n",
    "        allperf_distrib_min_sse[i, j] = top_k_accuracy_score(y_test, y_distrib_min_sse, k=vector_k[j])\n",
    "        \n",
    "    #time\n",
    "    tic = time()\n",
    "    _ = my_pairwise_distances(X=X_train[ind_samples], Y=X_train[ind_samples],metric=metric_distrib_min_sse, n_jobs=1, \n",
    "                                                transf_net = transf_net, s_latent2orig_net = fp, t_latent2orig_net = fq, \n",
    "                                                opt_trannet = transf_net_optim, opt_s=fp_optim, opt_t = fq_optim,\n",
    "                                                dim_latent = dim_latent, nproj_dist = nproj_dist_d, \n",
    "                                               num_epochs=nb_iter, num_sup_iter = 2, num_inf_iter = 2)\n",
    "    \n",
    "    time_distrib_min_sse[i] = (time()-tic)/n_pairs\n",
    "    \n",
    "    print(f\"perf = {allperf_distrib_min_sse[i]} and time = {time_distrib_min_sse[i]}\")\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PAck and save the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
