{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import pickle\n",
    "import torch\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set device: make GPU 0 the current CUDA device\n",
    "torch.cuda.set_device(0)\n",
    "\n",
    "# set random seed for both NumPy and PyTorch\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "# the trailing ';' stops the returned torch Generator object from being echoed as cell output\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# constants: sph_dim selects the dataset file and the model input size,\n",
    "# var is used by the model cells to scale the autoencoder noise std\n",
    "sph_dim = 2\n",
    "var = 0.01\n",
    "\n",
    "# NOTE(review): torch.load unpickles the file, so only load dataset files from a trusted source\n",
    "Sndataset = torch.load('S' + str(sph_dim) + 'TangentGaussianMixture210521m2.pth')\n",
    "\n",
    "# N is the size of the first dimension of train_data (presumably the number of samples)\n",
    "N = Sndataset.train_data.shape[0]\n",
    "\n",
    "# GPU copy of the training data; the clone leaves the tensor stored in Sndataset untouched\n",
    "traininput = Sndataset.train_data.clone().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sph_n_DataUtil import (\n",
    "    metricInvSqrt_torch,\n",
    "    metricInvDeriv_torch,\n",
    "    christoffelSum_torch,\n",
    "    christoffelSumDeriv_torch,\n",
    ")\n",
    "\n",
    "# Values required for estimating scores and estimated score errors.\n",
    "# They are evaluated once on the training input and reused by the training cells.\n",
    "\n",
    "# inverse-metric terms; metricInv_train is the square of metricInv_sqrt_train\n",
    "metricInv_sqrt_train = metricInvSqrt_torch(traininput)\n",
    "metricInv_train = metricInv_sqrt_train ** 2\n",
    "metricInvDeriv_train = metricInvDeriv_torch(traininput)\n",
    "\n",
    "# terms subtracted from the model's score estimate and from its derivative\n",
    "christoffel_sum_train = christoffelSum_torch(traininput)\n",
    "christoffel_sumDeriv_train = christoffelSumDeriv_torch(traininput)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae2 import DAE, RCAE\n",
    "\n",
    "# size settings handed to the DAE constructor\n",
    "input_dim = sph_dim\n",
    "hidden_dim = 1000\n",
    "dim = [input_dim, hidden_dim]\n",
    "num_hidden_layers = 2\n",
    "\n",
    "# remaining constructor options\n",
    "useLeakyReLU = False\n",
    "initial = 'x'\n",
    "\n",
    "# noise std given to the DAE: the fraction noise_hyper_param_dae of sqrt(var)\n",
    "noise_hyper_param_dae = 0.25\n",
    "dae_noise_std = noise_hyper_param_dae * np.sqrt(var)\n",
    "\n",
    "# build the DAE and move it to the GPU\n",
    "model = DAE(dim, num_hidden_layers, dae_noise_std,\n",
    "            useLeakyReLU=useLeakyReLU, initial=initial).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train DAE (full-batch updates with Adam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam hyper-parameters for the model built in the previous code cell\n",
    "lr = 2.5e-5\n",
    "weight_decay = 1e-12\n",
    "\n",
    "optimizer = torch.optim.Adam(\n",
    "    model.parameters(),\n",
    "    lr=lr,\n",
    "    weight_decay=weight_decay,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 0.109149\n",
      "epoch: 100, loss: 0.001238\n",
      "epoch: 200, loss: 0.001208\n",
      "epoch: 300, loss: 0.001226\n",
      "epoch: 400, loss: 0.001207\n",
      "epoch: 500, loss: 0.001223\n",
      "epoch: 600, loss: 0.001191\n",
      "epoch: 700, loss: 0.001200\n",
      "epoch: 800, loss: 0.001212\n",
      "epoch: 900, loss: 0.001215\n"
     ]
    }
   ],
   "source": [
    "from gae_score_estimation import dae_estimate_score, dae_estimate_score_deriv, estimate_gscore_error\n",
    "\n",
    "max_iter_num = 1000\n",
    "checkEstErrorPeriod = 20\n",
    "\n",
    "# estimated score error recorded at every check epoch\n",
    "gscore_est_error_set = []\n",
    "\n",
    "for epoch in range(max_iter_num):\n",
    "    # one optimisation step on the whole training set\n",
    "    optimizer.zero_grad()\n",
    "    loss = model.calculate_loss(traininput)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # every checkEstErrorPeriod epochs, and on the final epoch, evaluate the estimated score error\n",
    "    if epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num - 1:\n",
    "        est_train = dae_estimate_score(traininput, model) - christoffel_sum_train\n",
    "        estDeriv_train = (\n",
    "            dae_estimate_score_deriv(traininput, model, model.noise_std**2, force_cpu=False)\n",
    "            - christoffel_sumDeriv_train\n",
    "        )\n",
    "        cur_error = estimate_gscore_error(est_train, estDeriv_train, metricInv_train,\n",
    "                                          metricInv_sqrt_train, metricInvDeriv_train,\n",
    "                                          christoffel_sum_train, diagonal_metric=True)\n",
    "        gscore_est_error_set.append(cur_error)\n",
    "\n",
    "        # snapshot the weights whenever the error is the lowest so far (ties keep the later epoch);\n",
    "        # epoch 0 is always a check epoch, so min_val is assigned before it is ever compared\n",
    "        if epoch == 0 or cur_error <= min_val:\n",
    "            best_model = copy.deepcopy(model.state_dict())\n",
    "            min_val = cur_error\n",
    "            min_epoch = epoch\n",
    "\n",
    "    # progress report: loss divided by the number of training rows\n",
    "    if epoch % 100 == 0:\n",
    "        print(\"epoch: {:d}, loss: {:.6f}\".format(epoch, loss.item()/traininput.shape[0]))\n",
    "\n",
    "# The RCAE cells further down rebind model, best_model, min_val, min_epoch and\n",
    "# gscore_est_error_set. Keep DAE-specific references so the DAE results stay\n",
    "# available after the whole notebook has run.\n",
    "dae_model = model\n",
    "dae_best_model = best_model\n",
    "dae_min_val = min_val\n",
    "dae_min_epoch = min_epoch\n",
    "dae_gscore_est_error_set = gscore_est_error_set"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RCAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae2 import DAE, RCAE\n",
    "\n",
    "# size settings handed to the RCAE constructor\n",
    "# (hidden_dim is not set here; it keeps the value assigned in the DAE model cell)\n",
    "input_dim = sph_dim\n",
    "dim = [input_dim, hidden_dim]\n",
    "num_hidden_layers = 2\n",
    "\n",
    "# remaining constructor options\n",
    "useLeakyReLU = False\n",
    "initial = 'x'\n",
    "\n",
    "# noise std given to the RCAE: the fraction noise_hyper_param_rcae of sqrt(var)\n",
    "noise_hyper_param_rcae = 0.25\n",
    "rcae_noise_std = noise_hyper_param_rcae * np.sqrt(var)\n",
    "\n",
    "# build the RCAE and move it to the GPU; this rebinds the name `model`\n",
    "model = RCAE(dim, num_hidden_layers, rcae_noise_std,\n",
    "             useLeakyReLU=useLeakyReLU, initial=initial).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train RCAE (mini-batch updates with Adam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam hyper-parameters for the model built in the previous code cell\n",
    "lr = 2.5e-5\n",
    "weight_decay = 1e-12\n",
    "\n",
    "optimizer = torch.optim.Adam(\n",
    "    model.parameters(),\n",
    "    lr=lr,\n",
    "    weight_decay=weight_decay,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1000\n",
    "\n",
    "# shuffled mini-batches drawn from Sndataset, loaded by two worker processes\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    Sndataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 0.034078\n",
      "epoch: 10, loss: 0.001240\n",
      "epoch: 20, loss: 0.001231\n",
      "epoch: 30, loss: 0.001223\n",
      "epoch: 40, loss: 0.001217\n",
      "epoch: 50, loss: 0.001211\n",
      "epoch: 60, loss: 0.001207\n",
      "epoch: 70, loss: 0.001204\n",
      "epoch: 80, loss: 0.001203\n",
      "epoch: 90, loss: 0.001201\n"
     ]
    }
   ],
   "source": [
    "from gae_score_estimation import dae_estimate_score, dae_estimate_score_deriv, estimate_gscore_error\n",
    "\n",
    "max_iter_num = 100\n",
    "checkEstErrorPeriod = 20\n",
    "\n",
    "# estimated score error recorded at every check epoch\n",
    "gscore_est_error_set = []\n",
    "\n",
    "for epoch in range(max_iter_num):\n",
    "    # one pass over the mini-batches, summing the per-batch loss values\n",
    "    cur_loss = 0\n",
    "    for ii, (data, _, _) in enumerate(trainloader):\n",
    "        optimizer.zero_grad()\n",
    "        loss = model.calculate_loss(data.cuda())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        cur_loss += loss.item()\n",
    "\n",
    "    # on the final epoch and every checkEstErrorPeriod epochs, evaluate the estimated score error\n",
    "    if epoch == max_iter_num - 1 or epoch % checkEstErrorPeriod == 0:\n",
    "        est_train = dae_estimate_score(traininput, model) - christoffel_sum_train\n",
    "        estDeriv_train = (\n",
    "            dae_estimate_score_deriv(traininput, model, model.noise_std**2, force_cpu=False)\n",
    "            - christoffel_sumDeriv_train\n",
    "        )\n",
    "        cur_error = estimate_gscore_error(est_train, estDeriv_train, metricInv_train,\n",
    "                                          metricInv_sqrt_train, metricInvDeriv_train,\n",
    "                                          christoffel_sum_train, diagonal_metric=True)\n",
    "        gscore_est_error_set.append(cur_error)\n",
    "\n",
    "        # snapshot the weights whenever the error is the lowest so far (ties keep the later epoch);\n",
    "        # epoch 0 is always a check epoch, so min_val is assigned before it is ever compared\n",
    "        if epoch == 0 or cur_error <= min_val:\n",
    "            best_model = copy.deepcopy(model.state_dict())\n",
    "            min_val = cur_error\n",
    "            min_epoch = epoch\n",
    "\n",
    "    # progress report: summed loss divided by the dataset size\n",
    "    if epoch % 10 == 0:\n",
    "        print(\"epoch: {:d}, loss: {:.6f}\".format(epoch, cur_loss/len(trainloader.dataset)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
