{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import pickle\n",
    "import torch\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f6e79fac780>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pin all CUDA work in this notebook to GPU 0.\n",
    "torch.cuda.set_device(0)\n",
    "\n",
    "# Seed NumPy and PyTorch so that runs are reproducible.\n",
    "# The trailing ';' suppresses the Generator object returned by torch.manual_seed,\n",
    "# which would otherwise be displayed as this cell's output.\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Pn_DataUtil import PndataTangentGaussianMixtureExpanded\n",
    "\n",
    "pd_dim = 2   # matrix size n\n",
    "var = 0.01   # presumably the variance used to generate the dataset; also scales the GAE noise std\n",
    "# n(n+1)/2 = number of independent entries of a symmetric n x n matrix\n",
    "vec_dim = pd_dim * (pd_dim + 1) // 2\n",
    "\n",
    "Pndataset = torch.load('P' + str(pd_dim) + 'TangentGaussianMixture210912m2.pth')\n",
    "Pndataset2 = PndataTangentGaussianMixtureExpanded(Pndataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GDAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae_pd import GDAE_P_n_fromLog\n",
    "\n",
    "# network shape\n",
    "input_dim = vec_dim\n",
    "hidden_dim = 1000\n",
    "dim = [input_dim, hidden_dim]\n",
    "num_hidden_layers = 2\n",
    "useLeakyReLU = False\n",
    "initial = 'xavier'   # weight initialisation scheme\n",
    "\n",
    "# denoising noise level: std = 0.25 * sqrt(var)\n",
    "noise_hyper_param_gae = 0.25\n",
    "gae_noise_std = noise_hyper_param_gae * np.sqrt(var)\n",
    "\n",
    "# approximation orders -- 1 ~ 4: 1st~4th order approx., anything else: no approximation\n",
    "exp_approx = 2\n",
    "log_approx = 2       # order used for the loss\n",
    "\n",
    "model = GDAE_P_n_fromLog(dim, num_hidden_layers, gae_noise_std,\n",
    "                         useLeakyReLU=useLeakyReLU,\n",
    "                         initial=initial,\n",
    "                         exp_approx=exp_approx,\n",
    "                         log_approx=log_approx).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train GDAE (full-batch gradient descent with Adam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Pn_util import vec2mat, mat2vec, batch_eigsym, Log_mat, metricInv_sqrt_P_n\n",
    "\n",
    "# The whole training set is moved to the GPU once; the GDAE loss uses all of it at every step.\n",
    "x = Pndataset2.logx.cuda()\n",
    "X = vec2mat(Pndataset2.train_data.cuda())   # training points converted with vec2mat\n",
    "X_sqrt = Pndataset2.train_data_sqrt.cuda()\n",
    "X_invsqrt = Pndataset2.train_data_invsqrt.cuda()\n",
    "\n",
    "# Quantities consumed by the score estimate and its error below.\n",
    "metric_train = Pndataset2.metric.cuda()\n",
    "metricInv_sqrt_train = metricInv_sqrt_P_n(X)\n",
    "X_sqrt_dirderiv_set = Pndataset2.X_sqrt_dirderiv_set.cuda()\n",
    "christoffel_sum_train = Pndataset2.christoffel_sum.cuda()\n",
    "dLog_xdx = Pndataset2.dLog_xdx.cuda()\n",
    "other_quantities_at_x = [dLog_xdx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam with a small learning rate and near-zero L2 weight decay.\n",
    "lr, weight_decay = 2.5e-5, 1e-12\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 0.091563\n",
      "epoch: 100, loss: 0.001881\n",
      "epoch: 200, loss: 0.001831\n",
      "epoch: 300, loss: 0.001864\n",
      "epoch: 400, loss: 0.001829\n",
      "epoch: 500, loss: 0.001837\n",
      "epoch: 600, loss: 0.001812\n",
      "epoch: 700, loss: 0.001818\n",
      "epoch: 800, loss: 0.001846\n",
      "epoch: 900, loss: 0.001853\n"
     ]
    }
   ],
   "source": [
    "from gae_pd_score_estimation import gae_P_n_estimate_score, gae_P_n_estimate_score_error\n",
    "\n",
    "max_iter_num = 1000        # number of full-batch Adam steps\n",
    "checkEstErrorPeriod = 20   # evaluate the score-estimation error every this many epochs\n",
    "\n",
    "gscore_est_error_set = []  # one entry per evaluation epoch\n",
    "\n",
    "for epoch in range(max_iter_num):\n",
    "    # one gradient step on the whole training set\n",
    "    optimizer.zero_grad()\n",
    "    loss = model.calculate_loss(x, X, X_sqrt, X_invsqrt)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num - 1:\n",
    "        est_train = gae_P_n_estimate_score(x, X_sqrt, metric_train, model)\n",
    "        cur_error = gae_P_n_estimate_score_error(\n",
    "            x, X_sqrt, est_train, model, model.noise_std**2,\n",
    "            metricInv_sqrt_train, X_sqrt_dirderiv_set, christoffel_sum_train,\n",
    "            diagonal_metric=False, other_quantities_at_x=other_quantities_at_x)\n",
    "        gscore_est_error_set.append(cur_error)\n",
    "\n",
    "        # Best-model selection must live inside the evaluation branch. On other epochs\n",
    "        # gscore_est_error_set[-1] is stale: whenever the latest evaluation set a new\n",
    "        # minimum, `stale <= min_val` holds, so the unevaluated weights of every following\n",
    "        # epoch would overwrite best_model and min_epoch. Ties go to the later epoch.\n",
    "        if epoch == 0 or cur_error <= min_val:\n",
    "            best_model = copy.deepcopy(model.state_dict())\n",
    "            min_val = cur_error\n",
    "            min_epoch = epoch\n",
    "\n",
    "    if epoch % 100 == 0:\n",
    "        # NOTE(review): dividing by x.shape[0] assumes calculate_loss returns a sum over samples\n",
    "        print(\"epoch: {:d}, loss: {:.6f}\".format(epoch, loss.item()/x.shape[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GRCAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae_pd import GRCAE_P_n_fromLog\n",
    "\n",
    "# network shape (same as the GDAE above)\n",
    "input_dim = vec_dim\n",
    "hidden_dim = 1000\n",
    "dim = [input_dim, hidden_dim]\n",
    "num_hidden_layers = 2\n",
    "useLeakyReLU = False\n",
    "initial = 'xavier'   # weight initialisation scheme\n",
    "\n",
    "# noise level: std = 0.25 * sqrt(var)\n",
    "noise_hyper_param_gae = 0.25\n",
    "gae_noise_std = noise_hyper_param_gae * np.sqrt(var)\n",
    "\n",
    "# 1 ~ 4: 1st~4th order approx., anything else: no approximation\n",
    "exp_approx = 2\n",
    "\n",
    "model = GRCAE_P_n_fromLog(dim, num_hidden_layers, gae_noise_std,\n",
    "                          useLeakyReLU=useLeakyReLU,\n",
    "                          initial=initial,\n",
    "                          exp_approx=exp_approx).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train GRCAE (mini-batch stochastic gradient descent with Adam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Pn_util import vec2mat, mat2vec, batch_eigsym, Log_mat, metricInv_sqrt_P_n\n",
    "\n",
    "# Same full-training-set tensors as in the GDAE section, rebuilt so this section is self-contained.\n",
    "# They are used only for the periodic score-error evaluation; training itself reads trainloader.\n",
    "x = Pndataset2.logx.cuda()\n",
    "X = vec2mat(Pndataset2.train_data.cuda())\n",
    "X_sqrt = Pndataset2.train_data_sqrt.cuda()\n",
    "X_invsqrt = Pndataset2.train_data_invsqrt.cuda()\n",
    "\n",
    "# Quantities consumed by the score estimate and its error.\n",
    "metric_train = Pndataset2.metric.cuda()\n",
    "metricInv_sqrt_train = metricInv_sqrt_P_n(X)\n",
    "X_sqrt_dirderiv_set = Pndataset2.X_sqrt_dirderiv_set.cuda()\n",
    "christoffel_sum_train = Pndataset2.christoffel_sum.cuda()\n",
    "dLog_xdx = Pndataset2.dLog_xdx.cuda()\n",
    "other_quantities_at_x = [dLog_xdx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same optimiser settings as the GDAE: Adam, small learning rate, near-zero weight decay.\n",
    "lr, weight_decay = 2.5e-5, 1e-12\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batches of 1000 samples, reshuffled every epoch, loaded by 2 worker processes.\n",
    "batch_size = 1000\n",
    "trainloader = torch.utils.data.DataLoader(Pndataset2,\n",
    "                                          batch_size=batch_size,\n",
    "                                          shuffle=True,\n",
    "                                          num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 0.445487\n",
      "epoch: 10, loss: 0.002312\n",
      "epoch: 20, loss: 0.001924\n",
      "epoch: 30, loss: 0.001881\n",
      "epoch: 40, loss: 0.001872\n",
      "epoch: 50, loss: 0.001868\n",
      "epoch: 60, loss: 0.001866\n",
      "epoch: 70, loss: 0.001864\n",
      "epoch: 80, loss: 0.001862\n",
      "epoch: 90, loss: 0.001860\n"
     ]
    }
   ],
   "source": [
    "from gae_pd_score_estimation import gae_P_n_estimate_score, gae_P_n_estimate_score_error\n",
    "\n",
    "max_iter_num = 100         # epochs (full passes over trainloader)\n",
    "checkEstErrorPeriod = 20   # evaluate the score-estimation error every this many epochs\n",
    "\n",
    "gscore_est_error_set = []  # one entry per evaluation epoch\n",
    "\n",
    "for epoch in range(max_iter_num):\n",
    "    cur_loss = 0\n",
    "    for data in trainloader:\n",
    "        # Each batch is a 12-field tuple; keep only the fields the GRCAE loss needs.\n",
    "        (_, cur_logx, cur_X_sqrt, _, _, cur_metricInv_train, _,\n",
    "         cur_X_sqrt_dirderiv_set, cur_dLog_xdx, _, _, _) = data\n",
    "        optimizer.zero_grad()\n",
    "        loss = model.calculate_loss(\n",
    "            cur_logx.cuda(), cur_X_sqrt.cuda(), cur_metricInv_train.cuda(),\n",
    "            X_sqrt_dirderiv_set=cur_X_sqrt_dirderiv_set.cuda(),\n",
    "            other_quantities_for_loss_at_x=[cur_dLog_xdx.cuda()])\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        cur_loss += loss.item()\n",
    "\n",
    "    if epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num - 1:\n",
    "        # Evaluate on the full training set prepared above.\n",
    "        est_train = gae_P_n_estimate_score(x, X_sqrt, metric_train, model)\n",
    "        cur_error = gae_P_n_estimate_score_error(\n",
    "            x, X_sqrt, est_train, model, model.noise_std**2,\n",
    "            metricInv_sqrt_train, X_sqrt_dirderiv_set, christoffel_sum_train,\n",
    "            diagonal_metric=False, other_quantities_at_x=other_quantities_at_x)\n",
    "        gscore_est_error_set.append(cur_error)\n",
    "\n",
    "        # Keep the weights with the lowest error seen so far (ties go to the later epoch).\n",
    "        if epoch == 0 or cur_error <= min_val:\n",
    "            best_model = copy.deepcopy(model.state_dict())\n",
    "            min_val = cur_error\n",
    "            min_epoch = epoch\n",
    "\n",
    "    if epoch % 10 == 0:\n",
    "        print(\"epoch: {:d}, loss: {:.6f}\".format(epoch, cur_loss/len(trainloader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
