{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import pickle\n",
    "import torch\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility: a single seed shared by numpy and torch; CUDA work is pinned to GPU 0.\n",
    "seed = 0\n",
    "torch.cuda.set_device(0)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pd_dim: matrix size n (presumably n x n SPD matrices, given the Pn_util helpers).\n",
    "# A symmetric n x n matrix has n(n+1)/2 free entries -- the length of its vectorized form.\n",
    "pd_dim = 2\n",
    "vec_dim = pd_dim * (pd_dim + 1) // 2\n",
    "\n",
    "# Used below as np.sqrt(var) to scale the DAE/RCAE noise std;\n",
    "# presumably the variance of the generating Gaussian mixture -- TODO confirm.\n",
    "var = 0.01\n",
    "\n",
    "# NOTE: torch.load unpickles the file; only load datasets from a trusted source.\n",
    "dataset_path = 'P' + str(pd_dim) + 'TangentGaussianMixture210912m2.pth'\n",
    "Pndataset = torch.load(dataset_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae2 import DAE\n",
    "\n",
    "# Network shape: vec_dim inputs, hidden width 1000, num_hidden_layers = 2.\n",
    "input_dim = vec_dim\n",
    "hidden_dim = 1000\n",
    "num_hidden_layers = 2\n",
    "dim = [input_dim, hidden_dim]\n",
    "\n",
    "# Activation / initialisation options forwarded to DAE.\n",
    "useLeakyReLU = False\n",
    "initial = 'xavier'\n",
    "\n",
    "# Noise std handed to the DAE, set as a fraction of sqrt(var).\n",
    "noise_hyper_param_dae = 0.25\n",
    "dae_noise_std = noise_hyper_param_dae * np.sqrt(var)\n",
    "\n",
    "model = DAE(dim, num_hidden_layers, dae_noise_std,\n",
    "            useLeakyReLU=useLeakyReLU, initial=initial).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train DAE (full-batch training with Adam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Pn_util import (vec2mat, mat2vec, metricInv_sqrt_P_n, metricInv_P_n, metricInvDeriv_P_n,\n",
    "                     christoffelSum_P_n, christoffelSumDeriv_P_n)\n",
    "\n",
    "# Whole training set on the GPU; the DAE is trained on it as a single batch.\n",
    "traininput = Pndataset.train_data.cuda()\n",
    "\n",
    "# Quantities evaluated at every training point, computed once here and reused\n",
    "# whenever the score-estimation error is measured during training.\n",
    "X = vec2mat(Pndataset.train_data.cuda())  # vectorized samples -> matrices\n",
    "X_inv = torch.inverse(X)\n",
    "metricInv_sqrt_train = metricInv_sqrt_P_n(X, X_inv)\n",
    "metricInv_train = metricInv_P_n(X)\n",
    "metricInvDeriv_train = metricInvDeriv_P_n(X)\n",
    "christoffel_sum_train = christoffelSum_P_n(X, X_inv)\n",
    "christoffel_sumDeriv_train = christoffelSumDeriv_P_n(X, X_inv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam with a small step size and a tiny (1e-12) L2 weight decay.\n",
    "lr, weight_decay = 2.5e-5, 1e-12\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 2.897554\n",
      "epoch: 100, loss: 0.003064\n",
      "epoch: 200, loss: 0.002212\n",
      "epoch: 300, loss: 0.002011\n",
      "epoch: 400, loss: 0.001942\n",
      "epoch: 500, loss: 0.001944\n",
      "epoch: 600, loss: 0.001870\n",
      "epoch: 700, loss: 0.001870\n",
      "epoch: 800, loss: 0.001876\n",
      "epoch: 900, loss: 0.001853\n"
     ]
    }
   ],
   "source": [
    "from gae_score_estimation import dae_estimate_score, dae_estimate_score_deriv, estimate_gscore_error\n",
    "\n",
    "max_iter_num = 1000\n",
    "checkEstErrorPeriod = 20\n",
    "\n",
    "# Full-batch training. Every checkEstErrorPeriod epochs (and on the last epoch) the\n",
    "# score-estimation error on the training set is recorded, and the state_dict with the\n",
    "# lowest recorded error is kept in best_model (ties go to the later epoch).\n",
    "gscore_est_error_set = []\n",
    "best_model, min_val, min_epoch = None, None, None\n",
    "\n",
    "for epoch in range(max_iter_num):\n",
    "    optimizer.zero_grad()\n",
    "    loss = model.calculate_loss(traininput)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num - 1:\n",
    "        est_train = dae_estimate_score(traininput, model) - christoffel_sum_train\n",
    "        estDeriv_train = (dae_estimate_score_deriv(traininput, model, model.noise_std**2, force_cpu=False)\n",
    "                          - christoffel_sumDeriv_train)\n",
    "        cur_error = estimate_gscore_error(est_train, estDeriv_train, metricInv_train, metricInv_sqrt_train,\n",
    "                                          metricInvDeriv_train, christoffel_sum_train, diagonal_metric=False)\n",
    "        gscore_est_error_set.append(cur_error)\n",
    "\n",
    "        # Select the checkpoint only on epochs where the error was just measured;\n",
    "        # comparing a stale entry on other epochs would replace best_model with\n",
    "        # weights that were never evaluated.\n",
    "        if min_epoch is None or cur_error <= min_val:\n",
    "            best_model = copy.deepcopy(model.state_dict())\n",
    "            min_val = cur_error\n",
    "            min_epoch = epoch\n",
    "\n",
    "    if epoch % 100 == 0:\n",
    "        # Divided by the number of training samples (assumes calculate_loss returns a sum -- TODO confirm).\n",
    "        print(\"epoch: {:d}, loss: {:.6f}\".format(epoch, loss.item()/traininput.shape[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RCAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae2 import RCAE\n",
    "\n",
    "# Network shape: vec_dim inputs, hidden width 1000, num_hidden_layers = 2 (same as the DAE).\n",
    "input_dim = vec_dim\n",
    "hidden_dim = 1000\n",
    "num_hidden_layers = 2\n",
    "dim = [input_dim, hidden_dim]\n",
    "\n",
    "# Activation / initialisation options forwarded to RCAE.\n",
    "useLeakyReLU = False\n",
    "initial = 'xavier'\n",
    "\n",
    "# Noise std handed to the RCAE, set as a fraction of sqrt(var).\n",
    "noise_hyper_param_rcae = 0.25\n",
    "rcae_noise_std = noise_hyper_param_rcae * np.sqrt(var)\n",
    "\n",
    "model = RCAE(dim, num_hidden_layers, rcae_noise_std,\n",
    "             useLeakyReLU=useLeakyReLU, initial=initial).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train RCAE (mini-batch training with Adam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Pn_util import (vec2mat, mat2vec, metricInv_sqrt_P_n, metricInv_P_n, metricInvDeriv_P_n,\n",
    "                     christoffelSum_P_n, christoffelSumDeriv_P_n)\n",
    "\n",
    "# Recompute the training tensors and geometric quantities exactly as in the DAE section\n",
    "# (same values; presumably repeated so this section can be re-run on its own).\n",
    "traininput = Pndataset.train_data.cuda()\n",
    "X = vec2mat(Pndataset.train_data.cuda())  # vectorized samples -> matrices\n",
    "X_inv = torch.inverse(X)\n",
    "metricInv_sqrt_train = metricInv_sqrt_P_n(X, X_inv)\n",
    "metricInv_train = metricInv_P_n(X)\n",
    "metricInvDeriv_train = metricInvDeriv_P_n(X)\n",
    "christoffel_sum_train = christoffelSum_P_n(X, X_inv)\n",
    "christoffel_sumDeriv_train = christoffelSumDeriv_P_n(X, X_inv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam with a small step size and a tiny (1e-12) L2 weight decay, as for the DAE.\n",
    "lr, weight_decay = 2.5e-5, 1e-12\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffled mini-batches of batch_size samples, loaded by 2 DataLoader workers.\n",
    "batch_size = 1000\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    Pndataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 0.680584\n",
      "epoch: 10, loss: 0.002793\n",
      "epoch: 20, loss: 0.002150\n",
      "epoch: 30, loss: 0.001927\n",
      "epoch: 40, loss: 0.001867\n",
      "epoch: 50, loss: 0.001853\n",
      "epoch: 60, loss: 0.001848\n",
      "epoch: 70, loss: 0.001845\n",
      "epoch: 80, loss: 0.001842\n",
      "epoch: 90, loss: 0.001840\n"
     ]
    }
   ],
   "source": [
    "from gae_score_estimation import dae_estimate_score, dae_estimate_score_deriv, estimate_gscore_error\n",
    "\n",
    "max_iter_num = 100\n",
    "checkEstErrorPeriod = 20\n",
    "\n",
    "# Mini-batch training. Every checkEstErrorPeriod epochs (and on the last epoch) the\n",
    "# score-estimation error on the full training set is recorded, and the state_dict with\n",
    "# the lowest recorded error is kept in best_model (ties go to the later epoch).\n",
    "# NOTE(review): the DAE score estimators are reused for the RCAE here -- confirm they\n",
    "# are valid for RCAE reconstructions.\n",
    "gscore_est_error_set = []\n",
    "best_model, min_val, min_epoch = None, None, None\n",
    "\n",
    "for epoch in range(max_iter_num):\n",
    "    cur_loss = 0\n",
    "    for data, _, _ in trainloader:\n",
    "        optimizer.zero_grad()\n",
    "        loss = model.calculate_loss(data.cuda())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        cur_loss += loss.item()\n",
    "\n",
    "    if epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num - 1:\n",
    "        est_train = dae_estimate_score(traininput, model) - christoffel_sum_train\n",
    "        estDeriv_train = (dae_estimate_score_deriv(traininput, model, model.noise_std**2, force_cpu=False)\n",
    "                          - christoffel_sumDeriv_train)\n",
    "        cur_error = estimate_gscore_error(est_train, estDeriv_train, metricInv_train, metricInv_sqrt_train,\n",
    "                                          metricInvDeriv_train, christoffel_sum_train, diagonal_metric=False)\n",
    "        gscore_est_error_set.append(cur_error)\n",
    "\n",
    "        # Select the checkpoint only on epochs where the error was just measured;\n",
    "        # comparing a stale entry on other epochs would replace best_model with\n",
    "        # weights that were never evaluated.\n",
    "        if min_epoch is None or cur_error <= min_val:\n",
    "            best_model = copy.deepcopy(model.state_dict())\n",
    "            min_val = cur_error\n",
    "            min_epoch = epoch\n",
    "\n",
    "    if epoch % 10 == 0:\n",
    "        # Summed batch losses divided by dataset size (assumes calculate_loss returns a sum -- TODO confirm).\n",
    "        print(\"epoch: {:d}, loss: {:.6f}\".format(epoch, cur_loss/len(trainloader.dataset)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
