{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "from torch import optim\n",
    "import numpy as np\n",
    "from time import process_time as time\n",
    "import pylab as plt\n",
    "\n",
    "from data_utils import fast_scandir_for_folders, load_shape_data, samples_n_vertices\n",
    "from metrics_utility import my_pairwise_distances, metric_gw, metric_sgw, metric_distrib_min_sse\n",
    "from TransformNet import TransformNet, TransformLatenttoOrig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computation burden\n",
    "* Sample $npairs = 100$ pairs of shapes $(x_i, x_j)$ with x_i,  x_j formed of $n$ 3d-vertices\n",
    "* Monitore the average computation time of $distance(x_i, x_j)$ where distance for GW, SGW and DMinSSE distances\n",
    "* Proceed so for different values of $n$  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_n = [100, 250, 500, 1000, 1500, 2000] #number of vertices to be sampled on each shape\n",
    "n_pairs = 100 # number of pairs to be considered to estimate computation time\n",
    "n_jobs_time = 1 # here we use a single CPU to monitore average computation time of n_pairs distances\n",
    "\n",
    "device = \"cpu\"\n",
    "nproj = 1000 # number of  unit random vectors to sample for SGW\n",
    "nproj_dist_d = 10\n",
    "max_epoch = 50\n",
    "n_iter_inner = 1\n",
    "n_repeat = 10\n",
    "dim_latent = 5\n",
    "err_gw = 0\n",
    "\n",
    "# Init recorded performances\n",
    "time_gw = np.zeros((len(vector_n), n_repeat))\n",
    "time_sgw = np.zeros((len(vector_n), n_repeat))\n",
    "time_distrib_min_sse = np.zeros((len(vector_n), n_repeat))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pathdata = './data/'\n",
    "filename = 'shapes'\n",
    "\n",
    "expe = \"timing\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_samples_timing = int(np.sqrt(n_pairs)) # we will randomly select n_samples_timing shapes to evalution computation burden\n",
    "\n",
    "for i in range(len(vector_n)):\n",
    "    print(f\"\\nn = {vector_n[i]}\")\n",
    "    \n",
    "    # load shapes dataset\n",
    "    with np.load(pathdata+filename+'.npz', allow_pickle=True) as data:\n",
    "        y = data['y']\n",
    "        X = data['X']\n",
    "    \n",
    "    # draw randomly the indices of samples used to evaluate computation time of each method\n",
    "    ind_samples = np.random.choice(X.shape[0], n_samples_timing, replace=False)\n",
    "    X = X[ind_samples]\n",
    "    \n",
    "    dim_s = X[0].shape[1] # = 3\n",
    "    dim_t = dim_s \n",
    "    \n",
    "    for j in range(n_repeat):\n",
    "        print(f\"\\nj = {j}\")\n",
    "        # random sampling of n vertices on each shape\n",
    "        Xred = samples_n_vertices(X, vector_n[i])\n",
    "           \n",
    "        # normalize each shape (otherwise distance computation for GW messes up)\n",
    "        scaler = StandardScaler()\n",
    "        for idx, shape in enumerate(Xred):\n",
    "            Xred[idx] = scaler.fit_transform(shape)\n",
    "\n",
    "        ## ========== GW =========\n",
    "        print(\"GW\")\n",
    "        tic = time()\n",
    "        try:\n",
    "            _ = my_pairwise_distances(X=Xred, Y=Xred, metric=metric_gw, n_jobs = n_jobs_time)\n",
    "            time_gw[i,j] = (time()-tic)/n_pairs\n",
    "        except ZeroDivisionError as err:\n",
    "            print('Handling run-time error:', err)\n",
    "            err_gw += 1\n",
    "            \n",
    "            \n",
    "\n",
    "        ## ============= Sliced GW ================\n",
    "        print(\"SGW\")\n",
    "        tic = time()\n",
    "        _ = my_pairwise_distances(X=Xred, Y=Xred, metric=metric_sgw, n_jobs = n_jobs_time, nproj=nproj)\n",
    "        time_sgw[i,j] = (time()-tic)/n_pairs\n",
    "\n",
    "        #print(f\"time = {time_sgw[i]}\")\n",
    "\n",
    "        ## =========== Distributional Min SSE ================\n",
    "        print(\"DMinSSE\")\n",
    "        transf_net = TransformNet(dim_latent).to(device)\n",
    "        fp = TransformLatenttoOrig(dim_latent,dim_s).to(device)\n",
    "        fq = TransformLatenttoOrig(dim_latent,dim_t).to(device)\n",
    "\n",
    "        transf_net_optim = optim.Adam(transf_net.parameters(), lr=0.001, betas=(0.5, 0.999),weight_decay=0.5)\n",
    "        fp_optim = optim.Adam(fp.parameters(), lr=0.001, betas=(0.5, 0.999),weight_decay=0.5)\n",
    "        fq_optim = optim.Adam(fq.parameters(), lr=0.001, betas=(0.5, 0.999),weight_decay=0.5)\n",
    "\n",
    "        tic = time()\n",
    "        _ = my_pairwise_distances(X=Xred, Y=Xred,metric=metric_distrib_min_sse, n_jobs=1,\n",
    "                                  transf_net = transf_net, s_latent2orig_net = fp, t_latent2orig_net = fq, \n",
    "                                  opt_trannet = transf_net_optim, opt_s=fp_optim, opt_t = fq_optim,\n",
    "                                  dim_latent = dim_latent, nproj_dist = nproj_dist_d, \n",
    "                                  num_epochs=max_epoch, num_sup_iter = n_iter_inner, num_inf_iter = n_iter_inner)\n",
    "\n",
    "        time_distrib_min_sse[i,j] = (time()-tic)/n_pairs    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = f\"shape_expe_{expe}_nproj_{nproj}_nproj_d_{nproj_dist_d}_nb{max_epoch}_iterinner{n_iter_inner:d}_latent{dim_latent:d}\"\n",
    "pathres='./result/shapes/'\n",
    "\n",
    "np.savez(pathres+filename,\n",
    "         time_gw=time_gw,c\n",
    "         time_sgw=time_sgw,\n",
    "         time_distrib_min_sse = time_distrib_min_sse,\n",
    "         vector_n = vector_n,\n",
    "         n_pairs = n_pairs,\n",
    "         n_repeat = n_repeat\n",
    "         )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_gw_m, time_gw_std = time_gw.mean(axis=1), time_gw.std(axis=1)\n",
    "time_sgw_m, time_sgw_std = time_sgw.mean(axis=1), time_sgw.std(axis=1)\n",
    "time_dmsse_m, time_dmsse_std= time_distrib_min_sse.mean(axis=1), time_distrib_min_sse.std(axis=1)\n",
    "\n",
    "plt.errorbar(vector_n, time_gw_m, yerr=time_gw_std, fmt='o-.', label=\"GW\")\n",
    "plt.errorbar(vector_n, time_sgw_m, yerr=time_sgw_std, fmt='s--', label=f\"SGW {nproj}\")\n",
    "plt.errorbar(vector_n, time_dmsse_m, yerr=time_dmsse_std, fmt='p-', label=f\"Distrib min SSE {nproj_dist_d}\")\n",
    "plt.legend()\n",
    "\n",
    "plt.yscale('log')\n",
    "plt.xscale('log')\n",
    "\n",
    "\n",
    "plt.savefig(f\"./figure/{filename}_2.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_gw_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_gw_m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
