{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import pickle\n",
    "import torch\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select GPU 0 as the current CUDA device; later cells move the model and\n",
    "# data to the GPU with .cuda().\n",
    "torch.cuda.set_device(0)\n",
    "\n",
    "# Fix the NumPy and PyTorch RNG seeds for reproducibility.\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "# The trailing ';' suppresses the returned torch Generator repr, which embeds\n",
    "# a run-specific memory address and only adds noise to the saved notebook.\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Pn_DataUtil import PndataTangentGaussianMixtureExpanded\n",
    "\n",
    "# Experiment settings; pd_dim and mix_num also select the dataset file below.\n",
    "pd_dim, mix_num = 2, 2\n",
    "# presumably the variance used when the data were generated -- TODO confirm\n",
    "var = 0.01\n",
    "# pd_dim*(pd_dim+1)/2: number of independent entries of a symmetric\n",
    "# pd_dim x pd_dim matrix (the product of consecutive integers is always even).\n",
    "vec_dim = pd_dim * (pd_dim + 1) // 2\n",
    "\n",
    "# Load the stored dataset object and wrap it in the expanded dataset class.\n",
    "Pndataset = torch.load(f'P{pd_dim}TangentGaussianMixture210912m{mix_num}.pth')\n",
    "Pndataset2 = PndataTangentGaussianMixtureExpanded(Pndataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GDAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae_pd import GDAE_P_n_fromLog\n",
    "\n",
    "# Network size: [input width, hidden width] and number of hidden layers.\n",
    "input_dim = vec_dim\n",
    "hidden_dim = 1000\n",
    "dim = [input_dim, hidden_dim]\n",
    "num_hidden_layers = 2\n",
    "\n",
    "# Noise std passed to the model: a fraction of sqrt(var).\n",
    "noise_hyper_param_gae = 0.25\n",
    "gae_noise_std = noise_hyper_param_gae * np.sqrt(var)\n",
    "\n",
    "useLeakyReLU = False\n",
    "initial = 'xavier'\n",
    "# Approximation orders: 1 ~ 4 selects the 1st~4th order approximation,\n",
    "# any other value means no approximation.\n",
    "exp_approx = 2\n",
    "log_approx = 2         # same convention, applied to the loss\n",
    "# Divisor applied to the freshly initialised weights and biases (1 = leave as is).\n",
    "initDiv = 2\n",
    "\n",
    "model = GDAE_P_n_fromLog(\n",
    "    dim,\n",
    "    num_hidden_layers,\n",
    "    gae_noise_std,\n",
    "    useLeakyReLU=useLeakyReLU,\n",
    "    initial=initial,\n",
    "    exp_approx=exp_approx,\n",
    "    log_approx=log_approx,\n",
    ")\n",
    "\n",
    "if initDiv != 1:\n",
    "    # Scale down the parameters of the modules at even positions\n",
    "    # 0, 2, ..., 2*num_hidden_layers of model.autoencoder.\n",
    "    for layer_idx in range(0, 2 * num_hidden_layers + 1, 2):\n",
    "        layer = model.autoencoder[layer_idx]\n",
    "        layer.weight.data /= initDiv\n",
    "        layer.bias.data /= initDiv\n",
    "\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train GDAE (batch gradient descent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the quantities stored on Pndataset2 to the GPU once, before training.\n",
    "from Pn_util import vec2mat, mat2vec, batch_eigsym, Log_mat, metricInv_sqrt_P_n\n",
    "# x, X, X_sqrt and X_invsqrt are the inputs of model.calculate_loss in the training loop.\n",
    "# NOTE(review): vec2mat presumably turns the vectorised training data into matrix form -- confirm in Pn_util.\n",
    "X = vec2mat(Pndataset2.train_data.cuda())\n",
    "X_sqrt = Pndataset2.train_data_sqrt.cuda()\n",
    "X_invsqrt = Pndataset2.train_data_invsqrt.cuda()\n",
    "x = Pndataset2.logx.cuda()\n",
    "\n",
    "# values required for estimating scores and estimated score errors\n",
    "metric_train = Pndataset2.metric.cuda()\n",
    "# computed here from X rather than read from the dataset object\n",
    "metricInv_sqrt_train = metricInv_sqrt_P_n(X)\n",
    "X_sqrt_dirderiv_set = Pndataset2.X_sqrt_dirderiv_set.cuda()\n",
    "dLog_xdx = Pndataset2.dLog_xdx.cuda()\n",
    "christoffel_sum_train = Pndataset2.christoffel_sum.cuda()\n",
    "# passed as the other_quantities_at_x keyword of gae_P_n_estimate_score_error\n",
    "other_quantities_at_x = [dLog_xdx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### NOTE: these hyperparameters may differ from those used in the experiments in the paper\n",
    "max_iter_num = 500_000\n",
    "lr_schedule_num = 1\n",
    "lr = 2.5e-5\n",
    "weight_decay = 1e-12\n",
    "\n",
    "optimizer = torch.optim.Adam(\n",
    "    model.parameters(),\n",
    "    lr=lr,\n",
    "    weight_decay=weight_decay,\n",
    ")\n",
    "# StepLR multiplies the learning rate by gamma every step_size calls to\n",
    "# scheduler.step(); the run is split into lr_schedule_num + 1 equal stages.\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(\n",
    "    optimizer,\n",
    "    step_size=max_iter_num // (lr_schedule_num + 1),\n",
    "    gamma=0.1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0,  time: 0.1,  loss: 0.004468\n",
      "epoch: 1000,  time: 25.5,  loss: 0.001859\n",
      "epoch: 2000,  time: 51.0,  loss: 0.001856\n",
      "epoch: 3000,  time: 76.3,  loss: 0.001826\n",
      "epoch: 4000,  time: 100.7,  loss: 0.001819\n",
      "epoch: 5000,  time: 124.9,  loss: 0.001823\n",
      "epoch: 6000,  time: 149.4,  loss: 0.001832\n",
      "epoch: 7000,  time: 175.3,  loss: 0.001842\n",
      "epoch: 8000,  time: 201.2,  loss: 0.001823\n",
      "epoch: 9000,  time: 227.2,  loss: 0.001818\n",
      "epoch: 10000,  time: 252.9,  loss: 0.001841\n",
      "epoch: 11000,  time: 278.6,  loss: 0.001813\n",
      "epoch: 12000,  time: 304.4,  loss: 0.001839\n",
      "epoch: 13000,  time: 330.0,  loss: 0.001846\n",
      "epoch: 14000,  time: 355.8,  loss: 0.001826\n",
      "epoch: 15000,  time: 380.9,  loss: 0.001855\n",
      "epoch: 16000,  time: 406.3,  loss: 0.001829\n",
      "epoch: 17000,  time: 431.9,  loss: 0.001817\n",
      "epoch: 18000,  time: 457.5,  loss: 0.001811\n",
      "epoch: 19000,  time: 483.1,  loss: 0.001819\n",
      "epoch: 20000,  time: 508.7,  loss: 0.001823\n",
      "epoch: 21000,  time: 534.0,  loss: 0.001826\n",
      "epoch: 22000,  time: 559.0,  loss: 0.001839\n",
      "epoch: 23000,  time: 584.4,  loss: 0.001786\n",
      "epoch: 24000,  time: 610.0,  loss: 0.001804\n",
      "epoch: 25000,  time: 635.5,  loss: 0.001837\n",
      "epoch: 26000,  time: 660.9,  loss: 0.001804\n",
      "epoch: 27000,  time: 686.1,  loss: 0.001821\n",
      "epoch: 28000,  time: 711.3,  loss: 0.001814\n",
      "epoch: 29000,  time: 736.3,  loss: 0.001830\n",
      "epoch: 30000,  time: 761.3,  loss: 0.001816\n",
      "epoch: 31000,  time: 787.2,  loss: 0.001827\n",
      "epoch: 32000,  time: 813.3,  loss: 0.001864\n",
      "epoch: 33000,  time: 839.3,  loss: 0.001834\n",
      "epoch: 34000,  time: 865.3,  loss: 0.001799\n",
      "epoch: 35000,  time: 891.3,  loss: 0.001827\n",
      "epoch: 36000,  time: 917.2,  loss: 0.001846\n",
      "epoch: 37000,  time: 942.9,  loss: 0.001798\n",
      "epoch: 38000,  time: 968.6,  loss: 0.001829\n",
      "epoch: 39000,  time: 993.9,  loss: 0.001833\n",
      "epoch: 40000,  time: 1019.1,  loss: 0.001817\n",
      "epoch: 41000,  time: 1044.6,  loss: 0.001827\n",
      "epoch: 42000,  time: 1070.0,  loss: 0.001825\n",
      "epoch: 43000,  time: 1095.1,  loss: 0.001824\n",
      "epoch: 44000,  time: 1120.7,  loss: 0.001840\n",
      "epoch: 45000,  time: 1146.5,  loss: 0.001813\n",
      "epoch: 46000,  time: 1172.4,  loss: 0.001792\n",
      "epoch: 47000,  time: 1197.6,  loss: 0.001785\n",
      "epoch: 48000,  time: 1222.7,  loss: 0.001815\n",
      "epoch: 49000,  time: 1247.7,  loss: 0.001831\n",
      "epoch: 50000,  time: 1273.4,  loss: 0.001797\n",
      "epoch: 51000,  time: 1299.2,  loss: 0.001802\n",
      "epoch: 52000,  time: 1324.9,  loss: 0.001808\n",
      "epoch: 53000,  time: 1350.6,  loss: 0.001813\n",
      "epoch: 54000,  time: 1375.8,  loss: 0.001806\n",
      "epoch: 55000,  time: 1400.9,  loss: 0.001810\n",
      "epoch: 56000,  time: 1426.3,  loss: 0.001811\n",
      "epoch: 57000,  time: 1451.6,  loss: 0.001806\n",
      "epoch: 58000,  time: 1476.9,  loss: 0.001805\n",
      "epoch: 59000,  time: 1502.1,  loss: 0.001814\n",
      "epoch: 60000,  time: 1527.4,  loss: 0.001812\n",
      "epoch: 61000,  time: 1553.6,  loss: 0.001807\n",
      "epoch: 62000,  time: 1579.5,  loss: 0.001843\n",
      "epoch: 63000,  time: 1605.4,  loss: 0.001803\n",
      "epoch: 64000,  time: 1631.3,  loss: 0.001784\n",
      "epoch: 65000,  time: 1657.2,  loss: 0.001792\n",
      "epoch: 66000,  time: 1683.0,  loss: 0.001799\n",
      "epoch: 67000,  time: 1708.4,  loss: 0.001794\n",
      "epoch: 68000,  time: 1733.8,  loss: 0.001788\n",
      "epoch: 69000,  time: 1759.2,  loss: 0.001818\n",
      "epoch: 70000,  time: 1784.7,  loss: 0.001817\n",
      "epoch: 71000,  time: 1810.4,  loss: 0.001808\n",
      "epoch: 72000,  time: 1836.2,  loss: 0.001813\n",
      "epoch: 73000,  time: 1862.0,  loss: 0.001787\n",
      "epoch: 74000,  time: 1887.7,  loss: 0.001812\n",
      "epoch: 75000,  time: 1913.5,  loss: 0.001819\n",
      "epoch: 76000,  time: 1939.2,  loss: 0.001804\n",
      "epoch: 77000,  time: 1964.2,  loss: 0.001798\n",
      "epoch: 78000,  time: 1989.1,  loss: 0.001827\n",
      "epoch: 79000,  time: 2014.0,  loss: 0.001792\n",
      "epoch: 80000,  time: 2038.9,  loss: 0.001823\n",
      "epoch: 81000,  time: 2063.8,  loss: 0.001778\n",
      "epoch: 82000,  time: 2088.7,  loss: 0.001804\n",
      "epoch: 83000,  time: 2113.8,  loss: 0.001823\n",
      "epoch: 84000,  time: 2138.9,  loss: 0.001794\n",
      "epoch: 85000,  time: 2164.3,  loss: 0.001811\n",
      "epoch: 86000,  time: 2189.3,  loss: 0.001785\n",
      "epoch: 87000,  time: 2215.6,  loss: 0.001793\n",
      "epoch: 88000,  time: 2242.0,  loss: 0.001822\n",
      "epoch: 89000,  time: 2268.2,  loss: 0.001789\n",
      "epoch: 90000,  time: 2294.5,  loss: 0.001809\n",
      "epoch: 91000,  time: 2319.9,  loss: 0.001820\n",
      "epoch: 92000,  time: 2345.4,  loss: 0.001800\n",
      "epoch: 93000,  time: 2371.0,  loss: 0.001813\n",
      "epoch: 94000,  time: 2396.4,  loss: 0.001795\n",
      "epoch: 95000,  time: 2421.4,  loss: 0.001806\n",
      "epoch: 96000,  time: 2446.5,  loss: 0.001781\n",
      "epoch: 97000,  time: 2471.5,  loss: 0.001794\n",
      "epoch: 98000,  time: 2496.4,  loss: 0.001808\n",
      "epoch: 99000,  time: 2521.3,  loss: 0.001787\n",
      "epoch: 100000,  time: 2546.4,  loss: 0.001766\n",
      "epoch: 101000,  time: 2571.4,  loss: 0.001809\n",
      "epoch: 102000,  time: 2596.4,  loss: 0.001811\n",
      "epoch: 103000,  time: 2621.6,  loss: 0.001793\n",
      "epoch: 104000,  time: 2647.0,  loss: 0.001827\n",
      "epoch: 105000,  time: 2673.0,  loss: 0.001791\n",
      "epoch: 106000,  time: 2699.0,  loss: 0.001836\n",
      "epoch: 107000,  time: 2724.8,  loss: 0.001800\n",
      "epoch: 108000,  time: 2750.6,  loss: 0.001774\n",
      "epoch: 109000,  time: 2776.4,  loss: 0.001791\n",
      "epoch: 110000,  time: 2802.1,  loss: 0.001790\n",
      "epoch: 111000,  time: 2827.9,  loss: 0.001817\n",
      "epoch: 112000,  time: 2853.6,  loss: 0.001812\n",
      "epoch: 113000,  time: 2879.3,  loss: 0.001804\n",
      "epoch: 114000,  time: 2905.0,  loss: 0.001799\n",
      "epoch: 115000,  time: 2930.8,  loss: 0.001795\n",
      "epoch: 116000,  time: 2956.9,  loss: 0.001817\n",
      "epoch: 117000,  time: 2983.0,  loss: 0.001816\n",
      "epoch: 118000,  time: 3008.8,  loss: 0.001819\n",
      "epoch: 119000,  time: 3034.0,  loss: 0.001823\n",
      "epoch: 120000,  time: 3059.6,  loss: 0.001790\n",
      "epoch: 121000,  time: 3085.2,  loss: 0.001823\n",
      "epoch: 122000,  time: 3110.6,  loss: 0.001801\n",
      "epoch: 123000,  time: 3136.3,  loss: 0.001786\n",
      "epoch: 124000,  time: 3162.1,  loss: 0.001792\n",
      "epoch: 125000,  time: 3187.7,  loss: 0.001806\n",
      "epoch: 126000,  time: 3212.9,  loss: 0.001805\n",
      "epoch: 127000,  time: 3238.4,  loss: 0.001800\n",
      "epoch: 128000,  time: 3264.4,  loss: 0.001798\n",
      "epoch: 129000,  time: 3290.4,  loss: 0.001805\n",
      "epoch: 130000,  time: 3315.8,  loss: 0.001785\n",
      "epoch: 131000,  time: 3341.5,  loss: 0.001819\n",
      "epoch: 132000,  time: 3367.2,  loss: 0.001827\n",
      "epoch: 133000,  time: 3393.0,  loss: 0.001804\n",
      "epoch: 134000,  time: 3418.8,  loss: 0.001784\n",
      "epoch: 135000,  time: 3444.2,  loss: 0.001794\n",
      "epoch: 136000,  time: 3469.8,  loss: 0.001812\n",
      "epoch: 137000,  time: 3495.4,  loss: 0.001806\n",
      "epoch: 138000,  time: 3521.0,  loss: 0.001796\n",
      "epoch: 139000,  time: 3546.5,  loss: 0.001806\n",
      "epoch: 140000,  time: 3571.8,  loss: 0.001823\n",
      "epoch: 141000,  time: 3597.4,  loss: 0.001817\n",
      "epoch: 142000,  time: 3623.8,  loss: 0.001796\n",
      "epoch: 143000,  time: 3650.1,  loss: 0.001807\n",
      "epoch: 144000,  time: 3676.0,  loss: 0.001796\n",
      "epoch: 145000,  time: 3701.8,  loss: 0.001800\n",
      "epoch: 146000,  time: 3727.6,  loss: 0.001807\n",
      "epoch: 147000,  time: 3753.4,  loss: 0.001818\n",
      "epoch: 148000,  time: 3779.2,  loss: 0.001785\n",
      "epoch: 149000,  time: 3804.9,  loss: 0.001793\n",
      "epoch: 150000,  time: 3831.0,  loss: 0.001800\n",
      "epoch: 151000,  time: 3856.7,  loss: 0.001798\n",
      "epoch: 152000,  time: 3882.2,  loss: 0.001808\n",
      "epoch: 153000,  time: 3907.7,  loss: 0.001804\n",
      "epoch: 154000,  time: 3933.1,  loss: 0.001783\n",
      "epoch: 155000,  time: 3958.8,  loss: 0.001794\n",
      "epoch: 156000,  time: 3984.4,  loss: 0.001761\n",
      "epoch: 157000,  time: 4010.5,  loss: 0.001790\n",
      "epoch: 158000,  time: 4036.7,  loss: 0.001803\n",
      "epoch: 159000,  time: 4062.8,  loss: 0.001806\n",
      "epoch: 160000,  time: 4088.9,  loss: 0.001839\n",
      "epoch: 161000,  time: 4114.8,  loss: 0.001770\n",
      "epoch: 162000,  time: 4140.5,  loss: 0.001766\n",
      "epoch: 163000,  time: 4164.8,  loss: 0.001802\n",
      "epoch: 164000,  time: 4189.1,  loss: 0.001823\n",
      "epoch: 165000,  time: 4214.4,  loss: 0.001798\n",
      "epoch: 166000,  time: 4240.4,  loss: 0.001792\n",
      "epoch: 167000,  time: 4266.6,  loss: 0.001794\n",
      "epoch: 168000,  time: 4292.6,  loss: 0.001818\n",
      "epoch: 169000,  time: 4318.6,  loss: 0.001799\n",
      "epoch: 170000,  time: 4344.4,  loss: 0.001807\n",
      "epoch: 171000,  time: 4370.4,  loss: 0.001782\n",
      "epoch: 172000,  time: 4396.4,  loss: 0.001797\n",
      "epoch: 173000,  time: 4422.4,  loss: 0.001803\n",
      "epoch: 174000,  time: 4448.5,  loss: 0.001805\n",
      "epoch: 175000,  time: 4474.4,  loss: 0.001794\n",
      "epoch: 176000,  time: 4500.0,  loss: 0.001779\n",
      "epoch: 177000,  time: 4525.5,  loss: 0.001820\n",
      "epoch: 178000,  time: 4550.8,  loss: 0.001790\n",
      "epoch: 179000,  time: 4576.3,  loss: 0.001783\n",
      "epoch: 180000,  time: 4602.3,  loss: 0.001804\n",
      "epoch: 181000,  time: 4628.5,  loss: 0.001782\n",
      "epoch: 182000,  time: 4654.2,  loss: 0.001826\n",
      "epoch: 183000,  time: 4679.5,  loss: 0.001802\n",
      "epoch: 184000,  time: 4704.7,  loss: 0.001798\n",
      "epoch: 185000,  time: 4730.0,  loss: 0.001781\n",
      "epoch: 186000,  time: 4755.7,  loss: 0.001797\n",
      "epoch: 187000,  time: 4781.6,  loss: 0.001807\n",
      "epoch: 188000,  time: 4807.4,  loss: 0.001823\n",
      "epoch: 189000,  time: 4832.6,  loss: 0.001804\n",
      "epoch: 190000,  time: 4857.7,  loss: 0.001787\n",
      "epoch: 191000,  time: 4882.7,  loss: 0.001801\n",
      "epoch: 192000,  time: 4907.7,  loss: 0.001805\n",
      "epoch: 193000,  time: 4932.6,  loss: 0.001796\n",
      "epoch: 194000,  time: 4957.8,  loss: 0.001791\n",
      "epoch: 195000,  time: 4983.1,  loss: 0.001796\n",
      "epoch: 196000,  time: 5008.3,  loss: 0.001798\n",
      "epoch: 197000,  time: 5033.7,  loss: 0.001808\n",
      "epoch: 198000,  time: 5058.6,  loss: 0.001817\n",
      "epoch: 199000,  time: 5082.8,  loss: 0.001801\n",
      "epoch: 200000,  time: 5108.1,  loss: 0.001807\n",
      "epoch: 201000,  time: 5134.0,  loss: 0.001786\n",
      "epoch: 202000,  time: 5159.9,  loss: 0.001800\n",
      "epoch: 203000,  time: 5186.2,  loss: 0.001813\n",
      "epoch: 204000,  time: 5212.3,  loss: 0.001824\n",
      "epoch: 205000,  time: 5238.1,  loss: 0.001808\n",
      "epoch: 206000,  time: 5263.9,  loss: 0.001806\n",
      "epoch: 207000,  time: 5289.7,  loss: 0.001820\n",
      "epoch: 208000,  time: 5315.2,  loss: 0.001814\n",
      "epoch: 209000,  time: 5340.5,  loss: 0.001811\n",
      "epoch: 210000,  time: 5366.0,  loss: 0.001795\n",
      "epoch: 211000,  time: 5391.8,  loss: 0.001790\n",
      "epoch: 212000,  time: 5417.5,  loss: 0.001791\n",
      "epoch: 213000,  time: 5443.1,  loss: 0.001784\n",
      "epoch: 214000,  time: 5468.2,  loss: 0.001774\n",
      "epoch: 215000,  time: 5493.3,  loss: 0.001795\n",
      "epoch: 216000,  time: 5518.6,  loss: 0.001805\n",
      "epoch: 217000,  time: 5544.3,  loss: 0.001791\n",
      "epoch: 218000,  time: 5570.1,  loss: 0.001765\n",
      "epoch: 219000,  time: 5595.7,  loss: 0.001826\n",
      "epoch: 220000,  time: 5620.9,  loss: 0.001821\n",
      "epoch: 221000,  time: 5645.9,  loss: 0.001811\n",
      "epoch: 222000,  time: 5671.0,  loss: 0.001803\n",
      "epoch: 223000,  time: 5696.6,  loss: 0.001805\n",
      "epoch: 224000,  time: 5722.3,  loss: 0.001825\n",
      "epoch: 225000,  time: 5747.7,  loss: 0.001783\n",
      "epoch: 226000,  time: 5773.0,  loss: 0.001811\n",
      "epoch: 227000,  time: 5798.3,  loss: 0.001815\n",
      "epoch: 228000,  time: 5823.6,  loss: 0.001811\n",
      "epoch: 229000,  time: 5849.1,  loss: 0.001776\n",
      "epoch: 230000,  time: 5874.6,  loss: 0.001813\n",
      "epoch: 231000,  time: 5900.1,  loss: 0.001797\n",
      "epoch: 232000,  time: 5924.8,  loss: 0.001811\n",
      "epoch: 233000,  time: 5950.1,  loss: 0.001788\n",
      "epoch: 234000,  time: 5976.3,  loss: 0.001808\n",
      "epoch: 235000,  time: 6002.5,  loss: 0.001825\n",
      "epoch: 236000,  time: 6028.6,  loss: 0.001797\n",
      "epoch: 237000,  time: 6054.5,  loss: 0.001791\n",
      "epoch: 238000,  time: 6080.3,  loss: 0.001805\n",
      "epoch: 239000,  time: 6106.0,  loss: 0.001809\n",
      "epoch: 240000,  time: 6131.9,  loss: 0.001795\n",
      "epoch: 241000,  time: 6157.7,  loss: 0.001802\n",
      "epoch: 242000,  time: 6183.3,  loss: 0.001795\n",
      "epoch: 243000,  time: 6208.8,  loss: 0.001794\n",
      "epoch: 244000,  time: 6234.6,  loss: 0.001809\n",
      "epoch: 245000,  time: 6259.7,  loss: 0.001809\n",
      "epoch: 246000,  time: 6285.1,  loss: 0.001785\n",
      "epoch: 247000,  time: 6310.5,  loss: 0.001795\n",
      "epoch: 248000,  time: 6336.4,  loss: 0.001804\n",
      "epoch: 249000,  time: 6362.3,  loss: 0.001776\n",
      "epoch: 250000,  time: 6388.3,  loss: 0.001807\n",
      "epoch: 251000,  time: 6414.3,  loss: 0.001808\n",
      "epoch: 252000,  time: 6440.3,  loss: 0.001795\n",
      "epoch: 253000,  time: 6466.3,  loss: 0.001790\n",
      "epoch: 254000,  time: 6492.4,  loss: 0.001769\n",
      "epoch: 255000,  time: 6517.6,  loss: 0.001801\n",
      "epoch: 256000,  time: 6542.9,  loss: 0.001782\n",
      "epoch: 257000,  time: 6567.9,  loss: 0.001782\n",
      "epoch: 258000,  time: 6593.1,  loss: 0.001775\n",
      "epoch: 259000,  time: 6618.4,  loss: 0.001793\n",
      "epoch: 260000,  time: 6643.7,  loss: 0.001809\n",
      "epoch: 261000,  time: 6669.7,  loss: 0.001796\n",
      "epoch: 262000,  time: 6695.5,  loss: 0.001793\n",
      "epoch: 263000,  time: 6720.7,  loss: 0.001800\n",
      "epoch: 264000,  time: 6745.8,  loss: 0.001801\n",
      "epoch: 265000,  time: 6771.4,  loss: 0.001798\n",
      "epoch: 266000,  time: 6797.1,  loss: 0.001793\n",
      "epoch: 267000,  time: 6822.8,  loss: 0.001775\n",
      "epoch: 268000,  time: 6848.1,  loss: 0.001801\n",
      "epoch: 269000,  time: 6873.5,  loss: 0.001819\n",
      "epoch: 270000,  time: 6899.2,  loss: 0.001787\n",
      "epoch: 271000,  time: 6925.3,  loss: 0.001790\n",
      "epoch: 272000,  time: 6950.7,  loss: 0.001784\n",
      "epoch: 273000,  time: 6976.2,  loss: 0.001792\n",
      "epoch: 274000,  time: 7001.8,  loss: 0.001780\n",
      "epoch: 275000,  time: 7026.9,  loss: 0.001789\n",
      "epoch: 276000,  time: 7052.1,  loss: 0.001786\n",
      "epoch: 277000,  time: 7077.4,  loss: 0.001796\n",
      "epoch: 278000,  time: 7102.6,  loss: 0.001780\n",
      "epoch: 279000,  time: 7128.5,  loss: 0.001801\n",
      "epoch: 280000,  time: 7154.8,  loss: 0.001800\n",
      "epoch: 281000,  time: 7181.1,  loss: 0.001802\n",
      "epoch: 282000,  time: 7207.2,  loss: 0.001831\n",
      "epoch: 283000,  time: 7233.1,  loss: 0.001797\n",
      "epoch: 284000,  time: 7258.4,  loss: 0.001824\n",
      "epoch: 285000,  time: 7283.4,  loss: 0.001778\n",
      "epoch: 286000,  time: 7308.6,  loss: 0.001766\n",
      "epoch: 287000,  time: 7333.7,  loss: 0.001803\n",
      "epoch: 288000,  time: 7359.5,  loss: 0.001781\n",
      "epoch: 289000,  time: 7385.8,  loss: 0.001784\n",
      "epoch: 290000,  time: 7412.1,  loss: 0.001805\n",
      "epoch: 291000,  time: 7438.2,  loss: 0.001815\n",
      "epoch: 292000,  time: 7463.0,  loss: 0.001823\n",
      "epoch: 293000,  time: 7488.3,  loss: 0.001774\n",
      "epoch: 294000,  time: 7514.4,  loss: 0.001780\n",
      "epoch: 295000,  time: 7540.7,  loss: 0.001788\n",
      "epoch: 296000,  time: 7566.6,  loss: 0.001826\n",
      "epoch: 297000,  time: 7592.3,  loss: 0.001814\n",
      "epoch: 298000,  time: 7617.7,  loss: 0.001798\n",
      "epoch: 299000,  time: 7643.1,  loss: 0.001775\n",
      "epoch: 300000,  time: 7668.7,  loss: 0.001770\n",
      "epoch: 301000,  time: 7694.5,  loss: 0.001787\n",
      "epoch: 302000,  time: 7720.3,  loss: 0.001787\n",
      "epoch: 303000,  time: 7746.1,  loss: 0.001813\n",
      "epoch: 304000,  time: 7771.8,  loss: 0.001793\n",
      "epoch: 305000,  time: 7797.6,  loss: 0.001813\n",
      "epoch: 306000,  time: 7823.3,  loss: 0.001777\n",
      "epoch: 307000,  time: 7849.1,  loss: 0.001790\n",
      "epoch: 308000,  time: 7874.7,  loss: 0.001805\n",
      "epoch: 309000,  time: 7900.8,  loss: 0.001802\n",
      "epoch: 310000,  time: 7926.9,  loss: 0.001819\n",
      "epoch: 311000,  time: 7952.9,  loss: 0.001774\n",
      "epoch: 312000,  time: 7978.2,  loss: 0.001812\n",
      "epoch: 313000,  time: 8003.4,  loss: 0.001791\n",
      "epoch: 314000,  time: 8028.7,  loss: 0.001801\n",
      "epoch: 315000,  time: 8054.0,  loss: 0.001801\n",
      "epoch: 316000,  time: 8079.3,  loss: 0.001768\n",
      "epoch: 317000,  time: 8104.6,  loss: 0.001808\n",
      "epoch: 318000,  time: 8130.4,  loss: 0.001782\n",
      "epoch: 319000,  time: 8156.1,  loss: 0.001808\n",
      "epoch: 320000,  time: 8181.8,  loss: 0.001809\n",
      "epoch: 321000,  time: 8207.7,  loss: 0.001797\n",
      "epoch: 322000,  time: 8233.3,  loss: 0.001781\n",
      "epoch: 323000,  time: 8258.9,  loss: 0.001790\n",
      "epoch: 324000,  time: 8284.1,  loss: 0.001796\n",
      "epoch: 325000,  time: 8309.9,  loss: 0.001795\n",
      "epoch: 326000,  time: 8335.4,  loss: 0.001821\n",
      "epoch: 327000,  time: 8360.8,  loss: 0.001808\n",
      "epoch: 328000,  time: 8386.2,  loss: 0.001808\n",
      "epoch: 329000,  time: 8412.1,  loss: 0.001790\n",
      "epoch: 330000,  time: 8438.0,  loss: 0.001788\n",
      "epoch: 331000,  time: 8464.2,  loss: 0.001792\n",
      "epoch: 332000,  time: 8490.3,  loss: 0.001768\n",
      "epoch: 333000,  time: 8516.3,  loss: 0.001792\n",
      "epoch: 334000,  time: 8542.3,  loss: 0.001806\n",
      "epoch: 335000,  time: 8568.0,  loss: 0.001815\n",
      "epoch: 336000,  time: 8593.8,  loss: 0.001782\n",
      "epoch: 337000,  time: 8619.5,  loss: 0.001792\n",
      "epoch: 338000,  time: 8644.8,  loss: 0.001819\n",
      "epoch: 339000,  time: 8670.1,  loss: 0.001827\n",
      "epoch: 340000,  time: 8695.6,  loss: 0.001786\n",
      "epoch: 341000,  time: 8720.8,  loss: 0.001798\n",
      "epoch: 342000,  time: 8745.9,  loss: 0.001820\n",
      "epoch: 343000,  time: 8771.0,  loss: 0.001799\n",
      "epoch: 344000,  time: 8796.6,  loss: 0.001781\n",
      "epoch: 345000,  time: 8822.2,  loss: 0.001784\n",
      "epoch: 346000,  time: 8848.0,  loss: 0.001846\n",
      "epoch: 347000,  time: 8873.9,  loss: 0.001806\n",
      "epoch: 348000,  time: 8899.5,  loss: 0.001793\n",
      "epoch: 349000,  time: 8925.3,  loss: 0.001792\n",
      "epoch: 350000,  time: 8951.1,  loss: 0.001781\n",
      "epoch: 351000,  time: 8976.7,  loss: 0.001788\n",
      "epoch: 352000,  time: 9002.5,  loss: 0.001799\n",
      "epoch: 353000,  time: 9028.3,  loss: 0.001798\n",
      "epoch: 354000,  time: 9054.2,  loss: 0.001799\n",
      "epoch: 355000,  time: 9080.0,  loss: 0.001815\n",
      "epoch: 356000,  time: 9106.0,  loss: 0.001827\n",
      "epoch: 357000,  time: 9131.6,  loss: 0.001772\n",
      "epoch: 358000,  time: 9156.9,  loss: 0.001796\n",
      "epoch: 359000,  time: 9182.0,  loss: 0.001809\n",
      "epoch: 360000,  time: 9207.1,  loss: 0.001786\n",
      "epoch: 361000,  time: 9232.3,  loss: 0.001808\n",
      "epoch: 362000,  time: 9257.7,  loss: 0.001786\n",
      "epoch: 363000,  time: 9283.1,  loss: 0.001799\n",
      "epoch: 364000,  time: 9309.0,  loss: 0.001806\n",
      "epoch: 365000,  time: 9335.0,  loss: 0.001791\n",
      "epoch: 366000,  time: 9360.8,  loss: 0.001781\n",
      "epoch: 367000,  time: 9386.6,  loss: 0.001810\n",
      "epoch: 368000,  time: 9412.8,  loss: 0.001780\n",
      "epoch: 369000,  time: 9438.6,  loss: 0.001810\n",
      "epoch: 370000,  time: 9464.3,  loss: 0.001798\n",
      "epoch: 371000,  time: 9490.1,  loss: 0.001781\n",
      "epoch: 372000,  time: 9515.6,  loss: 0.001798\n",
      "epoch: 373000,  time: 9541.1,  loss: 0.001819\n",
      "epoch: 374000,  time: 9566.5,  loss: 0.001787\n",
      "epoch: 375000,  time: 9592.3,  loss: 0.001773\n",
      "epoch: 376000,  time: 9618.4,  loss: 0.001793\n",
      "epoch: 377000,  time: 9644.5,  loss: 0.001806\n",
      "epoch: 378000,  time: 9670.3,  loss: 0.001774\n",
      "epoch: 379000,  time: 9696.2,  loss: 0.001791\n",
      "epoch: 380000,  time: 9722.0,  loss: 0.001804\n",
      "epoch: 381000,  time: 9747.7,  loss: 0.001789\n",
      "epoch: 382000,  time: 9773.5,  loss: 0.001799\n",
      "epoch: 383000,  time: 9799.2,  loss: 0.001775\n",
      "epoch: 384000,  time: 9824.9,  loss: 0.001813\n",
      "epoch: 385000,  time: 9850.6,  loss: 0.001754\n",
      "epoch: 386000,  time: 9876.2,  loss: 0.001795\n",
      "epoch: 387000,  time: 9901.7,  loss: 0.001808\n",
      "epoch: 388000,  time: 9927.8,  loss: 0.001798\n",
      "epoch: 389000,  time: 9953.7,  loss: 0.001795\n",
      "epoch: 390000,  time: 9979.2,  loss: 0.001796\n",
      "epoch: 391000,  time: 10004.8,  loss: 0.001779\n",
      "epoch: 392000,  time: 10030.2,  loss: 0.001772\n",
      "epoch: 393000,  time: 10055.3,  loss: 0.001807\n",
      "epoch: 394000,  time: 10080.6,  loss: 0.001784\n",
      "epoch: 395000,  time: 10105.9,  loss: 0.001823\n",
      "epoch: 396000,  time: 10131.5,  loss: 0.001793\n",
      "epoch: 397000,  time: 10157.3,  loss: 0.001800\n",
      "epoch: 398000,  time: 10182.8,  loss: 0.001817\n",
      "epoch: 399000,  time: 10208.5,  loss: 0.001781\n",
      "epoch: 400000,  time: 10234.1,  loss: 0.001804\n",
      "epoch: 401000,  time: 10259.9,  loss: 0.001810\n",
      "epoch: 402000,  time: 10285.5,  loss: 0.001784\n",
      "epoch: 403000,  time: 10311.1,  loss: 0.001797\n",
      "epoch: 404000,  time: 10337.2,  loss: 0.001802\n",
      "epoch: 405000,  time: 10363.3,  loss: 0.001817\n",
      "epoch: 406000,  time: 10389.3,  loss: 0.001781\n",
      "epoch: 407000,  time: 10414.7,  loss: 0.001812\n",
      "epoch: 408000,  time: 10439.8,  loss: 0.001784\n",
      "epoch: 409000,  time: 10464.9,  loss: 0.001794\n",
      "epoch: 410000,  time: 10490.1,  loss: 0.001802\n",
      "epoch: 411000,  time: 10515.5,  loss: 0.001798\n",
      "epoch: 412000,  time: 10541.1,  loss: 0.001810\n",
      "epoch: 413000,  time: 10566.5,  loss: 0.001790\n",
      "epoch: 414000,  time: 10592.3,  loss: 0.001802\n",
      "epoch: 415000,  time: 10617.9,  loss: 0.001769\n",
      "epoch: 416000,  time: 10643.2,  loss: 0.001809\n",
      "epoch: 417000,  time: 10668.4,  loss: 0.001787\n",
      "epoch: 418000,  time: 10693.5,  loss: 0.001799\n",
      "epoch: 419000,  time: 10718.7,  loss: 0.001787\n",
      "epoch: 420000,  time: 10744.1,  loss: 0.001796\n",
      "epoch: 421000,  time: 10769.8,  loss: 0.001788\n",
      "epoch: 422000,  time: 10795.2,  loss: 0.001767\n",
      "epoch: 423000,  time: 10820.5,  loss: 0.001772\n",
      "epoch: 424000,  time: 10845.9,  loss: 0.001800\n",
      "epoch: 425000,  time: 10871.1,  loss: 0.001796\n",
      "epoch: 426000,  time: 10896.3,  loss: 0.001808\n",
      "epoch: 427000,  time: 10921.5,  loss: 0.001792\n",
      "epoch: 428000,  time: 10946.4,  loss: 0.001814\n",
      "epoch: 429000,  time: 10972.0,  loss: 0.001789\n",
      "epoch: 430000,  time: 10997.2,  loss: 0.001808\n",
      "epoch: 431000,  time: 11022.8,  loss: 0.001804\n",
      "epoch: 432000,  time: 11048.6,  loss: 0.001793\n",
      "epoch: 433000,  time: 11074.3,  loss: 0.001774\n",
      "epoch: 434000,  time: 11099.8,  loss: 0.001790\n",
      "epoch: 435000,  time: 11125.2,  loss: 0.001786\n",
      "epoch: 436000,  time: 11150.2,  loss: 0.001780\n",
      "epoch: 437000,  time: 11175.6,  loss: 0.001793\n",
      "epoch: 438000,  time: 11200.7,  loss: 0.001798\n",
      "epoch: 439000,  time: 11226.2,  loss: 0.001748\n",
      "epoch: 440000,  time: 11251.4,  loss: 0.001796\n",
      "epoch: 441000,  time: 11276.4,  loss: 0.001806\n",
      "epoch: 442000,  time: 11301.4,  loss: 0.001793\n",
      "epoch: 443000,  time: 11326.7,  loss: 0.001804\n",
      "epoch: 444000,  time: 11352.0,  loss: 0.001801\n",
      "epoch: 445000,  time: 11377.8,  loss: 0.001777\n",
      "epoch: 446000,  time: 11403.8,  loss: 0.001799\n",
      "epoch: 447000,  time: 11429.7,  loss: 0.001799\n",
      "epoch: 448000,  time: 11455.1,  loss: 0.001794\n",
      "epoch: 449000,  time: 11480.6,  loss: 0.001782\n",
      "epoch: 450000,  time: 11505.8,  loss: 0.001786\n",
      "epoch: 451000,  time: 11531.2,  loss: 0.001786\n",
      "epoch: 452000,  time: 11557.0,  loss: 0.001811\n",
      "epoch: 453000,  time: 11582.7,  loss: 0.001785\n",
      "epoch: 454000,  time: 11608.6,  loss: 0.001770\n",
      "epoch: 455000,  time: 11634.5,  loss: 0.001793\n",
      "epoch: 456000,  time: 11660.4,  loss: 0.001801\n",
      "epoch: 457000,  time: 11686.4,  loss: 0.001781\n",
      "epoch: 458000,  time: 11712.1,  loss: 0.001802\n",
      "epoch: 459000,  time: 11737.4,  loss: 0.001788\n",
      "epoch: 460000,  time: 11763.1,  loss: 0.001792\n",
      "epoch: 461000,  time: 11788.9,  loss: 0.001769\n",
      "epoch: 462000,  time: 11814.7,  loss: 0.001794\n",
      "epoch: 463000,  time: 11840.5,  loss: 0.001787\n",
      "epoch: 464000,  time: 11866.2,  loss: 0.001779\n",
      "epoch: 465000,  time: 11892.0,  loss: 0.001770\n",
      "epoch: 466000,  time: 11917.7,  loss: 0.001763\n",
      "epoch: 467000,  time: 11943.4,  loss: 0.001808\n",
      "epoch: 468000,  time: 11969.0,  loss: 0.001803\n",
      "epoch: 469000,  time: 11995.0,  loss: 0.001818\n",
      "epoch: 470000,  time: 12021.2,  loss: 0.001798\n",
      "epoch: 471000,  time: 12047.0,  loss: 0.001803\n",
      "epoch: 472000,  time: 12072.5,  loss: 0.001799\n",
      "epoch: 473000,  time: 12097.9,  loss: 0.001818\n",
      "epoch: 474000,  time: 12123.4,  loss: 0.001793\n",
      "epoch: 475000,  time: 12149.0,  loss: 0.001799\n",
      "epoch: 476000,  time: 12174.7,  loss: 0.001791\n",
      "epoch: 477000,  time: 12200.9,  loss: 0.001810\n",
      "epoch: 478000,  time: 12227.1,  loss: 0.001778\n",
      "epoch: 479000,  time: 12252.8,  loss: 0.001771\n",
      "epoch: 480000,  time: 12277.9,  loss: 0.001798\n",
      "epoch: 481000,  time: 12303.0,  loss: 0.001784\n",
      "epoch: 482000,  time: 12328.2,  loss: 0.001805\n",
      "epoch: 483000,  time: 12353.7,  loss: 0.001798\n",
      "epoch: 484000,  time: 12379.9,  loss: 0.001793\n",
      "epoch: 485000,  time: 12405.6,  loss: 0.001824\n",
      "epoch: 486000,  time: 12431.5,  loss: 0.001789\n",
      "epoch: 487000,  time: 12457.3,  loss: 0.001807\n",
      "epoch: 488000,  time: 12482.5,  loss: 0.001778\n",
      "epoch: 489000,  time: 12507.9,  loss: 0.001804\n",
      "epoch: 490000,  time: 12533.2,  loss: 0.001805\n",
      "epoch: 491000,  time: 12558.7,  loss: 0.001818\n",
      "epoch: 492000,  time: 12584.1,  loss: 0.001805\n",
      "epoch: 493000,  time: 12609.5,  loss: 0.001804\n",
      "epoch: 494000,  time: 12634.8,  loss: 0.001792\n",
      "epoch: 495000,  time: 12660.1,  loss: 0.001811\n",
      "epoch: 496000,  time: 12685.2,  loss: 0.001778\n",
      "epoch: 497000,  time: 12710.3,  loss: 0.001784\n",
      "epoch: 498000,  time: 12735.4,  loss: 0.001802\n",
      "epoch: 499000,  time: 12760.5,  loss: 0.001781\n"
     ]
    }
   ],
   "source": [
    "from gae_pd_score_estimation import gae_P_n_estimate_score, gae_P_n_estimate_score_error\n",
    "\n",
    "# The score-estimation error is evaluated every `checkEstErrorPeriod` iterations and on the final one.\n",
    "checkEstErrorPeriod = 20\n",
    "gscore_est_error_set = []\n",
    "\n",
    "start = time.time()\n",
    "for epoch in range(max_iter_num):\n",
    "    # One optimiser step followed by one scheduler step.\n",
    "    optimizer.zero_grad()\n",
    "    loss = model.calculate_loss(x, X, X_sqrt, X_invsqrt)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "\n",
    "    if (epoch % checkEstErrorPeriod == 0 or epoch == max_iter_num-1):\n",
    "        est_train = gae_P_n_estimate_score(x, X_sqrt, metric_train, model)\n",
    "        cur_error = gae_P_n_estimate_score_error(x, X_sqrt, est_train, model, model.noise_std**2,\n",
    "                         metricInv_sqrt_train, X_sqrt_dirderiv_set, christoffel_sum_train,\n",
    "                             diagonal_metric=False, other_quantities_at_x = other_quantities_at_x)\n",
    "        gscore_est_error_set.append(cur_error)\n",
    "\n",
    "        # Checkpoint bookkeeping happens only on evaluation iterations, so that `best_model`,\n",
    "        # `min_val` and `min_epoch` always describe the same set of weights. Doing the comparison\n",
    "        # on every iteration would re-test a stale error against itself (always `<=`) and keep\n",
    "        # overwriting the checkpoint with weights whose error was never measured.\n",
    "        # Iteration 0 is always an evaluation iteration, so it initialises the checkpoint.\n",
    "        if epoch == 0 or cur_error <= min_val:\n",
    "            best_model = copy.deepcopy(model.state_dict())\n",
    "            min_val = cur_error\n",
    "            min_epoch = epoch\n",
    "\n",
    "    if epoch % 1000 == 0:\n",
    "        # Progress log: elapsed wall-clock seconds and the loss divided by x.shape[0].\n",
    "        print(\"epoch: {:d},  time: {:.1f},  loss: {:.6f}\".format(epoch, time.time() - start, loss.item()/x.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-205.65348691406143"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Lowest value of the score-estimation error recorded in gscore_est_error_set during training.\n",
    "min_val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Score estimation error on test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Pn_util import metric_P_n, geometricScore_tangentGaussianMixture\n",
    "\n",
    "#### Load test data\n",
    "# The test file exposes its samples through the same `train_data` / `train_data_sqrt` attributes.\n",
    "Pndataset_testscore = torch.load('P'+str(pd_dim)+'TangentGaussianMixtureForTest210912m'+str(mix_num)+'.pth')\n",
    "x_test = Pndataset_testscore.train_data.cuda()\n",
    "X_test = vec2mat(Pndataset_testscore.train_data)\n",
    "X_test_sqrt = Pndataset_testscore.train_data_sqrt.cuda()\n",
    "metric_test = metric_P_n(X_test)\n",
    "metricInv_sqrt_test = metricInv_sqrt_P_n(X_test)\n",
    "\n",
    "#### Calculate true score for TangentGaussianMixture data used in the experiments\n",
    "# Mixture parameters are taken from the training dataset object loaded earlier.\n",
    "Means = Pndataset.Means\n",
    "Cov_sqrts = Pndataset.Cov_sqrts\n",
    "\n",
    "Nmix = Means.shape[0]\n",
    "# Equal mixture weights, created directly on the current CUDA device as float32.\n",
    "# torch.tensor(..., device='cuda') replaces the legacy torch.cuda.FloatTensor constructor.\n",
    "weights = torch.tensor([1/Nmix]*Nmix, dtype=torch.float32, device='cuda')\n",
    "# CovInvs[i] = (Cov_sqrts[i]^T Cov_sqrts[i])^{-1}\n",
    "CovInvs = torch.zeros(Nmix, vec_dim, vec_dim)\n",
    "for i in range(Nmix):\n",
    "    CovInvs[i] = torch.inverse(torch.mm(Cov_sqrts[i].t(), Cov_sqrts[i]))\n",
    "score_true, _ = geometricScore_tangentGaussianMixture(X_test.cuda(), weights, Means.cuda(), CovInvs.cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "score mse on test data after training: 6.027549\n"
     ]
    }
   ],
   "source": [
    "from gae_score_estimation import compareScores\n",
    "\n",
    "# Model input for the test points: mat2vec(Log_mat(X)) -- presumably the vectorised matrix logarithm; confirm in the helpers.\n",
    "x_test = mat2vec(Log_mat(X_test.cuda()))\n",
    "# NOTE(review): the score is estimated with the final-iteration `model`; the `best_model`\n",
    "# checkpoint tracked during training is not loaded here -- confirm this is intended.\n",
    "est_test = gae_P_n_estimate_score(x_test, X_test_sqrt, metric_test, model)\n",
    "error = compareScores(score_true, est_test, metricInv_sqrt_test.cuda())\n",
    "\n",
    "# Only the first entry returned by compareScores is reported, as a Python scalar.\n",
    "score_mse = error[0].item()\n",
    "print(\"score mse on test data after training: {:f}\".format(score_mse))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
