{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import pickle\n",
    "import torch\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make GPU 0 the current CUDA device; later cells move tensors and models there with .cuda().\n",
    "torch.cuda.set_device(0)\n",
    "\n",
    "# Seed NumPy and PyTorch for reproducible runs. torch.manual_seed seeds the CPU and\n",
    "# all CUDA devices; it returns a torch.Generator, so the trailing ';' keeps that\n",
    "# object's repr out of the cell output.\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sph_n_DataUtil import SnPosdataFromTh\n",
    "\n",
    "# Dimension of the sphere S^sph_dim the data lives on; also selects the data file.\n",
    "sph_dim = 2\n",
    "\n",
    "# NOTE(review): torch.load unpickles the stored object, which can run arbitrary code, so only\n",
    "# load trusted files. PyTorch >= 2.6 defaults to weights_only=True and would likely need\n",
    "# weights_only=False to load a non-tensor object like this one (it has a .train_data attribute).\n",
    "Sndataset = torch.load('S{}TangentGaussianMixture210521m2.pth'.format(sph_dim))\n",
    "\n",
    "# Wrap the raw training positions as a Dataset (iterated by the DataLoader further down) and\n",
    "# keep a private GPU copy of the full training set for full-batch training and evaluation.\n",
    "SnPosDataset = SnPosdataFromTh(Sndataset.train_data)\n",
    "trainPosInput = SnPosDataset.train_data.clone().cuda()\n",
    "\n",
    "N = Sndataset.train_data.shape[0]  # presumably the number of training samples (rows)\n",
    "var = 0.01  # variance scale; sqrt(var) scales the GAE noise std in the model cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sph_n_DataUtil import (getCoord_torch, metricInvSqrt_torch,\n",
    "                            christoffelSum_torch, getPosJacobianFromPos_torch)\n",
    "\n",
    "# Coordinates of the training positions (presumably intrinsic sphere coordinates - confirm\n",
    "# against sph_n_DataUtil); the metric and Christoffel helpers below take these as input.\n",
    "traininput = getCoord_torch(trainPosInput)\n",
    "\n",
    "# Geometric quantities computed once and reused by the score estimate and score-error\n",
    "# estimate inside both training loops.\n",
    "metricInv_sqrt_train = metricInvSqrt_torch(traininput)\n",
    "christoffel_sum_train = christoffelSum_torch(traininput)\n",
    "dx_dxth_train = getPosJacobianFromPos_torch(trainPosInput, eps=1e-6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GDAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae_sph_n_ambient import GDAE_sph_ambient, GRCAE_sph_ambient\n",
    "\n",
    "# Architecture: the network reads points in R^(sph_dim + 1).\n",
    "input_dim = sph_dim + 1\n",
    "hidden_dim = 1000\n",
    "num_hidden_layers = 2\n",
    "dim = [input_dim, hidden_dim]\n",
    "\n",
    "# Noise std handed to the model: a fixed fraction of sqrt(var).\n",
    "noise_hyper_param_gae = 0.25\n",
    "gae_noise_std = noise_hyper_param_gae*np.sqrt(var)\n",
    "\n",
    "# Activation and weight-initialisation options forwarded to the constructor.\n",
    "useLeakyReLU = False\n",
    "initial = 'xavier'\n",
    "\n",
    "model = GDAE_sph_ambient(dim, num_hidden_layers, gae_noise_std,\n",
    "                         useLeakyReLU=useLeakyReLU, initial=initial).cuda()"
   ]
  },
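  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train GDAE (full-batch gradient descent with Adam)\n",
    "\n",
    "Every `checkEstErrorPeriod` epochs (and on the last epoch) the estimated score error on the training set is appended to `gscore_est_error_set`; the weights with the lowest error so far are kept in `best_model`, and the epoch they came from in `min_epoch`."
   ]
  },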
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over all model parameters, with a small learning rate and a tiny L2 weight decay.\n",
    "lr, weight_decay = 2.5e-5, 1e-12\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 1.466001\n",
      "epoch: 100, loss: 0.001587\n",
      "epoch: 200, loss: 0.001439\n",
      "epoch: 300, loss: 0.001357\n",
      "epoch: 400, loss: 0.001296\n",
      "epoch: 500, loss: 0.001270\n",
      "epoch: 600, loss: 0.001259\n",
      "epoch: 700, loss: 0.001237\n",
      "epoch: 800, loss: 0.001220\n",
      "epoch: 900, loss: 0.001197\n"
     ]
    }
   ],
   "source": [
    "from gae_sph_n_ambient_score_estimation import gae_sph_n_amb_estimate_score, gae_sph_n_amb_estimate_score_error\n",
    "\n",
    "max_iter_num = 1000\n",
    "checkEstErrorPeriod = 20\n",
    "\n",
    "# Score-error history and best-checkpoint bookkeeping, initialised before the loop so the\n",
    "# cell never depends on values left in the kernel by another cell. Starting min_val at\n",
    "# +inf (instead of copying the first error) means a NaN error can never become the\n",
    "# permanent 'best': NaN <= x is always False, so NaN checks are simply skipped.\n",
    "gscore_est_error_set = []\n",
    "best_model = None\n",
    "min_val = float('inf')\n",
    "min_epoch = None\n",
    "\n",
    "for epoch in range(max_iter_num):\n",
    "    # One full-batch update on all training positions.\n",
    "    optimizer.zero_grad()\n",
    "    loss = model.calculate_loss(trainPosInput)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Periodically (and on the final epoch) estimate the score error and keep a copy of the\n",
    "    # weights with the lowest error, in case the best model rather than the last is wanted.\n",
    "    if epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1:\n",
    "        est_train = gae_sph_n_amb_estimate_score(trainPosInput, model, dx_dxth_train)\n",
    "        cur_error = gae_sph_n_amb_estimate_score_error(trainPosInput, est_train, model, \n",
    "                         metricInv_sqrt_train, christoffel_sum_train, \n",
    "                             dx_dxth = dx_dxth_train)\n",
    "        gscore_est_error_set.append(cur_error)\n",
    "\n",
    "        # '<=' lets a later checkpoint win ties.\n",
    "        if cur_error <= min_val:\n",
    "            best_model = copy.deepcopy(model.state_dict())\n",
    "            min_val = cur_error\n",
    "            min_epoch = epoch\n",
    "    if epoch % 100 == 0:\n",
    "        # calculate_loss appears to return a sum over samples; report the per-sample mean.\n",
    "        print(\"epoch: {:d}, loss: {:.6f}\".format(epoch, loss.item()/trainPosInput.shape[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GRCAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae_sph_n_ambient import GDAE_sph_ambient, GRCAE_sph_ambient\n",
    "\n",
    "# Same architecture settings as the GDAE model: inputs live in R^(sph_dim + 1).\n",
    "input_dim = sph_dim + 1\n",
    "hidden_dim = 1000\n",
    "num_hidden_layers = 2\n",
    "dim = [input_dim, hidden_dim]\n",
    "\n",
    "# Noise std handed to the model: a fixed fraction of sqrt(var).\n",
    "noise_hyper_param_gae = 0.25\n",
    "gae_noise_std = noise_hyper_param_gae*np.sqrt(var)\n",
    "\n",
    "# Activation and weight-initialisation options forwarded to the constructor.\n",
    "useLeakyReLU = False\n",
    "initial = 'xavier'\n",
    "\n",
    "# Replaces the GDAE model bound to `model`; the GDAE weights survive only in best_model.\n",
    "model = GRCAE_sph_ambient(dim, num_hidden_layers, gae_noise_std,\n",
    "                          useLeakyReLU=useLeakyReLU, initial=initial).cuda()"
   ]
  },
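  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train GRCAE (mini-batch stochastic gradient descent with Adam)\n",
    "\n",
    "Each epoch is one shuffled pass over `trainloader`. Every `checkEstErrorPeriod` epochs (and on the last epoch) the estimated score error on the full training set is appended to `gscore_est_error_set`; the weights with the lowest error so far are kept in `best_model`, and the epoch they came from in `min_epoch`."
   ]
  },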
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh Adam optimizer for the GRCAE parameters (same settings as for GDAE).\n",
    "lr, weight_decay = 2.5e-5, 1e-12\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reshuffled mini-batches of training positions each epoch; two worker subprocesses load them.\n",
    "batch_size = 1000\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    SnPosDataset, batch_size=batch_size, shuffle=True, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 0.001996\n",
      "epoch: 10, loss: 0.001593\n",
      "epoch: 20, loss: 0.001363\n",
      "epoch: 30, loss: 0.001264\n",
      "epoch: 40, loss: 0.001228\n",
      "epoch: 50, loss: 0.001213\n",
      "epoch: 60, loss: 0.001206\n",
      "epoch: 70, loss: 0.001201\n",
      "epoch: 80, loss: 0.001197\n",
      "epoch: 90, loss: 0.001194\n"
     ]
    }
   ],
   "source": [
    "from gae_sph_n_ambient_score_estimation import gae_sph_n_amb_estimate_score, gae_sph_n_amb_estimate_score_error\n",
    "\n",
    "max_iter_num = 100\n",
    "checkEstErrorPeriod = 20\n",
    "\n",
    "# Reset the score-error history and best-checkpoint bookkeeping here: the GDAE training\n",
    "# cell defines the same names, and nothing from that run should leak into this one.\n",
    "# Starting min_val at +inf means a NaN error is never selected as the 'best' checkpoint.\n",
    "gscore_est_error_set = []\n",
    "best_model = None\n",
    "min_val = float('inf')\n",
    "min_epoch = None\n",
    "\n",
    "for epoch in range(max_iter_num):\n",
    "    # One shuffled pass over the mini-batches; cur_loss sums the per-batch losses.\n",
    "    cur_loss = 0\n",
    "    for batch in trainloader:\n",
    "        optimizer.zero_grad()\n",
    "        loss = model.calculate_loss(batch.cuda())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        cur_loss += loss.item()\n",
    "\n",
    "    # Periodically (and on the final epoch) estimate the score error on the full training\n",
    "    # set and keep a copy of the weights with the lowest error; '<=' lets later ties win.\n",
    "    if epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1:\n",
    "        est_train = gae_sph_n_amb_estimate_score(trainPosInput, model, dx_dxth_train)\n",
    "        cur_error = gae_sph_n_amb_estimate_score_error(trainPosInput, est_train, model, \n",
    "                         metricInv_sqrt_train, christoffel_sum_train, \n",
    "                             dx_dxth = dx_dxth_train)\n",
    "        gscore_est_error_set.append(cur_error)\n",
    "\n",
    "        if cur_error <= min_val:\n",
    "            best_model = copy.deepcopy(model.state_dict())\n",
    "            min_val = cur_error\n",
    "            min_epoch = epoch\n",
    "    if epoch % 10 == 0:\n",
    "        # calculate_loss appears to return a per-batch sum; report the per-sample mean.\n",
    "        print(\"epoch: {:d}, loss: {:.6f}\".format(epoch, cur_loss/len(trainloader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
