{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c8497b0e-1827-47d8-8d68-2e56e420db20",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import pickle\n",
    "import torch\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "92727ffc-37ca-46ea-ac54-fd7de460e3c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fd555736780>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# set device\n",
    "torch.cuda.set_device(0)\n",
    "\n",
    "# set random seed\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bff0199-e640-4592-b41a-108820aae39f",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "52c5cc40-6368-40ed-b3bb-442da9d7c526",
   "metadata": {},
   "outputs": [],
   "source": [
    "#datasetname = 'S49NewsgroupDocs210525c1.pth'      # 'crypt' task in Table 3 (or Appendix H.1.)\n",
    "#datasetname = 'S49NewsgroupDocs210525c2.pth'     # 'electronics' task in Table 3 (or Appendix H.1.)\n",
    "#datasetname = 'S49NewsgroupDocs210525c3.pth'     # 'med' task in Table 3 (or Appendix H.1.)\n",
    "datasetname = 'S49NewsgroupDocs210525c4.pth'     # 'space' task in Table 3 (or Appendix H.1.)\n",
    "SnPosDataset = torch.load(datasetname)\n",
    "\n",
    "trainPosInput = SnPosDataset.train_data.clone().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bc67d75e-311b-46ec-82bf-ff99a86f2321",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sph_n_DataUtil import getCoord_torch, metricInvSqrt_torch, christoffelSum_torch, getPosJacobianFromPos_torch\n",
    "traininput = getCoord_torch(trainPosInput)\n",
    "\n",
    "# values required for estimating scores and estimated score errors\n",
    "metricInv_sqrt_train = metricInvSqrt_torch(traininput)\n",
    "christoffel_sum_train = christoffelSum_torch(traininput)\n",
    "dx_dxth_train = getPosJacobianFromPos_torch(trainPosInput, eps=1e-6)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "354d5c6f-6622-4222-a1f2-8acf6a3cc2fe",
   "metadata": {},
   "source": [
    "# GDAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3603b3fc-aa9d-40db-bf38-9bdafa541f30",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae_sph_n_ambient import GDAE_sph_ambient\n",
    "\n",
    "sph_dim = 49\n",
    "input_dim = sph_dim + 1\n",
    "hidden_dim = 1000\n",
    "\n",
    "dim = [input_dim, hidden_dim]\n",
    "num_hidden_layers = 2\n",
    "gae_noise_std = 0.025\n",
    "useLeakyReLU = False          # use Tanh if False\n",
    "initial = 'xavier'\n",
    "\n",
    "model = GDAE_sph_ambient(dim, num_hidden_layers, gae_noise_std, useLeakyReLU = useLeakyReLU, initial = initial)\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52ad379f-7a5a-4562-9490-3b70792a4dfa",
   "metadata": {},
   "source": [
    "# Train GDAE (batch gradient descent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1e5e2b5b-2461-4b12-a31d-8f23a820adc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### NOTE: these hyperparameters may differ from those used in the experiments in the paper\n",
    "lr = 2.5e-3\n",
    "lr_schedule_num = 1\n",
    "weight_decay = 1e-12\n",
    "max_iter_num = 500000\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay = weight_decay)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, max_iter_num//(lr_schedule_num + 1), gamma=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9156c412-ff77-4037-b3f9-10df270327a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0,  time: 8.5,  loss: 0.910179\n",
      "epoch: 10000,  time: 162.0,  loss: 0.015048\n",
      "epoch: 20000,  time: 316.4,  loss: 0.014411\n",
      "epoch: 30000,  time: 480.0,  loss: 0.013726\n",
      "epoch: 40000,  time: 631.2,  loss: 0.013245\n",
      "epoch: 50000,  time: 794.3,  loss: 0.012885\n",
      "epoch: 60000,  time: 948.0,  loss: 0.012726\n",
      "epoch: 70000,  time: 1101.7,  loss: 0.012432\n",
      "epoch: 80000,  time: 1263.9,  loss: 0.012311\n",
      "epoch: 90000,  time: 1418.3,  loss: 0.012130\n",
      "epoch: 100000,  time: 1580.6,  loss: 0.011963\n",
      "epoch: 110000,  time: 1734.1,  loss: 0.011826\n",
      "epoch: 120000,  time: 1888.4,  loss: 0.011856\n",
      "epoch: 130000,  time: 2051.1,  loss: 0.011734\n",
      "epoch: 140000,  time: 2204.8,  loss: 0.011615\n",
      "epoch: 150000,  time: 2367.1,  loss: 0.011517\n",
      "epoch: 160000,  time: 2521.3,  loss: 0.011468\n",
      "epoch: 170000,  time: 2675.9,  loss: 0.011393\n",
      "epoch: 180000,  time: 2839.2,  loss: 0.011358\n",
      "epoch: 190000,  time: 2993.5,  loss: 0.011304\n",
      "epoch: 200000,  time: 3156.7,  loss: 0.011314\n",
      "epoch: 210000,  time: 3311.0,  loss: 0.011230\n",
      "epoch: 220000,  time: 3467.3,  loss: 0.011232\n",
      "epoch: 230000,  time: 3630.8,  loss: 0.011187\n",
      "epoch: 240000,  time: 3787.3,  loss: 0.011151\n",
      "epoch: 250000,  time: 3951.5,  loss: 0.011137\n",
      "epoch: 260000,  time: 4106.1,  loss: 0.010991\n",
      "epoch: 270000,  time: 4260.5,  loss: 0.011004\n",
      "epoch: 280000,  time: 4422.9,  loss: 0.010994\n",
      "epoch: 290000,  time: 4576.1,  loss: 0.010971\n",
      "epoch: 300000,  time: 4741.0,  loss: 0.011012\n",
      "epoch: 310000,  time: 4896.5,  loss: 0.010974\n",
      "epoch: 320000,  time: 5050.7,  loss: 0.011004\n",
      "epoch: 330000,  time: 5213.6,  loss: 0.010981\n",
      "epoch: 340000,  time: 5368.6,  loss: 0.010981\n",
      "epoch: 350000,  time: 5537.4,  loss: 0.010938\n",
      "epoch: 360000,  time: 5692.1,  loss: 0.010974\n",
      "epoch: 370000,  time: 5846.1,  loss: 0.010970\n",
      "epoch: 380000,  time: 6009.3,  loss: 0.010980\n",
      "epoch: 390000,  time: 6163.2,  loss: 0.010967\n",
      "epoch: 400000,  time: 6325.8,  loss: 0.010916\n",
      "epoch: 410000,  time: 6480.6,  loss: 0.010920\n",
      "epoch: 420000,  time: 6635.0,  loss: 0.010911\n",
      "epoch: 430000,  time: 6799.1,  loss: 0.010940\n",
      "epoch: 440000,  time: 6953.8,  loss: 0.010887\n",
      "epoch: 450000,  time: 7117.1,  loss: 0.010908\n",
      "epoch: 460000,  time: 7271.1,  loss: 0.010909\n",
      "epoch: 470000,  time: 7425.3,  loss: 0.010926\n",
      "epoch: 480000,  time: 7589.0,  loss: 0.010929\n",
      "epoch: 490000,  time: 7743.4,  loss: 0.010906\n"
     ]
    }
   ],
   "source": [
    "from gae_sph_n_ambient_score_estimation import gae_sph_n_amb_estimate_score, gae_sph_n_amb_estimate_score_error\n",
    "\n",
    "checkEstErrorPeriod = max_iter_num // 20\n",
    "gscore_est_error_set = []\n",
    "\n",
    "start = time.time()\n",
    "for epoch in range(max_iter_num):\n",
    "    optimizer.zero_grad()\n",
    "    loss = model.calculate_loss(trainPosInput)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    \n",
    "    ### below codes are for the case if we want to save models with the minimum estimated score error\n",
    "    if (epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1):\n",
    "        est_train = gae_sph_n_amb_estimate_score(trainPosInput, model, dx_dxth_train)\n",
    "        cur_error = gae_sph_n_amb_estimate_score_error(trainPosInput, est_train, model, \n",
    "                         metricInv_sqrt_train, christoffel_sum_train, \n",
    "                             dx_dxth = dx_dxth_train)\n",
    "        gscore_est_error_set.append(cur_error)\n",
    "        \n",
    "        if epoch == 0:\n",
    "            best_model = copy.deepcopy(model.state_dict())\n",
    "            min_val = gscore_est_error_set[-1]\n",
    "            min_epoch = epoch\n",
    "        elif gscore_est_error_set[-1] <= min_val:\n",
    "            best_model = copy.deepcopy(model.state_dict())\n",
    "            min_val = gscore_est_error_set[-1]\n",
    "            min_epoch = epoch\n",
    "    if epoch % 10000 == 0:\n",
    "        print(\"epoch: {:d},  time: {:.1f},  loss: {:.6f}\".format(epoch, time.time() - start, \n",
    "                                                                 loss.item()/trainPosInput.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "344f507b-68ab-4cbe-a20b-6048e55b0095",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-106644.0267851555"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "min_val"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5f8a90c-789e-4957-80fd-01fd8ffcf9f1",
   "metadata": {},
   "source": [
    "# Mean shift clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "50f1a684-9345-4cdd-87b9-52bc534212e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_iter = 10000\n",
    "step_size = 1.0\n",
    "mode = 'large_memory'      # 'small_memory': use smaller memory but a bit slow, 'large_memory': use larger memory but faster\n",
    "dist_tolerance = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8e469993-f1c9-44b8-98fe-42098568b7c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter: 1000 ---- time 1.7 ----  num points to shift: 5\n",
      "iter: 2000 ---- time 3.3 ----  num points to shift: 5\n",
      "iter: 3000 ---- time 4.9 ----  num points to shift: 5\n",
      "iter: 4000 ---- time 6.4 ----  num points to shift: 5\n",
      "iter: 5000 ---- time 8.0 ----  num points to shift: 5\n",
      "iter: 6000 ---- time 9.6 ----  num points to shift: 5\n",
      "iter: 7000 ---- time 11.1 ----  num points to shift: 5\n",
      "iter: 8000 ---- time 12.7 ----  num points to shift: 5\n",
      "iter: 9000 ---- time 14.3 ----  num points to shift: 5\n",
      "iter: 10000 ---- time 15.8 ----  num points to shift: 5\n",
      "shifting exceeded max_iter... 5 points not converged with max, mean, min distances 0.010902, 0.006756, 0.000846\n",
      "\n",
      " start point grouping...\n",
      "time 0.0 --- calculate pairwise distances using gpu\n",
      "time 7.5 --- calculation finished\n",
      "time 7.5 --- 0-th point grouping... currently having 0 groups\n",
      "time 10.4 --- 1000-th point grouping... currently having 167 groups\n",
      "time 17.6 --- 2000-th point grouping... currently having 276 groups\n",
      "time 28.9 --- 3000-th point grouping... currently having 408 groups\n",
      "time 44.7 --- 4000-th point grouping... currently having 536 groups\n",
      "time 64.9 --- 5000-th point grouping... currently having 662 groups\n",
      "time 89.1 --- 6000-th point grouping... currently having 779 groups\n",
      "time 117.0 --- 7000-th point grouping... currently having 895 groups\n",
      "time 149.3 --- 8000-th point grouping... currently having 1031 groups\n",
      "time 186.1 --- 9000-th point grouping... currently having 1164 groups\n",
      "time 226.9 --- 10000-th point grouping... currently having 1267 groups\n",
      "time 271.4 --- 11000-th point grouping... currently having 1384 groups\n",
      "time 319.6 --- 12000-th point grouping... currently having 1478 groups\n",
      "time 371.1 --- 13000-th point grouping... currently having 1588 groups\n",
      "time 426.6 --- 14000-th point grouping... currently having 1691 groups\n",
      "time 485.5 --- 15000-th point grouping... currently having 1799 groups\n",
      "adjusted random index: 0.4437836593517063\n"
     ]
    }
   ],
   "source": [
    "from mean_shift_sph import MeanShift_gdae_sph_ambient\n",
    "from sklearn import metrics\n",
    "\n",
    "##### perform mean shift clustering\n",
    "ms_gdae = MeanShift_gdae_sph_ambient(model, step_size = step_size)\n",
    "\n",
    "result = ms_gdae.cluster(SnPosDataset.train_data.cuda(), max_iter, printEpochPeriod = 1000, loggingFileName = None, \n",
    "                         mode = mode, distTol = dist_tolerance)\n",
    "\n",
    "labels_true = SnPosDataset.true_labels\n",
    "labels_pred = result.cluster_ids\n",
    "ari = metrics.adjusted_rand_score(labels_true, labels_pred)\n",
    "print(\"adjusted random index: \" + str(ari))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff8fea73-3afd-4b62-8cf6-120fa944d790",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
