{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f018401a828>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# set device: only touch CUDA when a GPU is actually available\n",
    "useGPU = torch.cuda.is_available()\n",
    "if useGPU:\n",
    "    torch.cuda.set_device(0)\n",
    "\n",
    "# set random seeds for reproducibility\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Training Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from DTI_dataset import N_n_DataSet\n",
    "\n",
    "# image used for training (its first 10 characters name the subject directory)\n",
    "file_prefix = '033_S_4179_20180125_131417'\n",
    "filelist = ['../DTI_data_1mm_axialcrop/{}/{}_1mm_6_tensor_cropped.nii.gz'.format(file_prefix[:10], file_prefix)]\n",
    "\n",
    "# dataset options\n",
    "shuffle = False\n",
    "randPos = True\n",
    "roi_range = [0, 143, 0, 181, 0, 13]  # presumably [x0, x1, y0, y1, z0, z1] bounds -- TODO confirm\n",
    "dx = 1\n",
    "weight_mode = None  # passed to the trainer further below\n",
    "data_num = 1\n",
    "data_prefix = '_fix'\n",
    "\n",
    "traindataset = N_n_DataSet(\n",
    "    filelist,\n",
    "    match_num_per_subject=data_num,\n",
    "    shuffle=shuffle,\n",
    "    randPos=randPos,\n",
    "    dx=dx,\n",
    "    roi_range=roi_range,\n",
    "    data_prefix=data_prefix,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# large training batches; the number of loader workers never exceeds the dataset length\n",
    "train_batch_size = 60000\n",
    "n_loader_workers = min(6, len(traindataset))\n",
    "trainloader = torch.utils.data.DataLoader(\n",
    "    traindataset,\n",
    "    batch_size=train_batch_size,\n",
    "    shuffle=True,\n",
    "    num_workers=n_loader_workers,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0018)\n",
      "tensor(0.1939)\n"
     ]
    }
   ],
   "source": [
    "### constants for setting noise_std and covMetricCoeff so the noise scale is comparable to the dae case\n",
    "# multiplying a batch of 3x3 matrices by the identity mask and summing gives the total of their traces\n",
    "eye3 = torch.eye(3).view(1, 3, 3)\n",
    "\n",
    "train_covInv = traindataset.covInv\n",
    "posMetric_sq = torch.bmm(train_covInv, train_covInv)\n",
    "const1 = torch.sum(train_covInv * eye3) / torch.sum(posMetric_sq * eye3)\n",
    "print(const1)\n",
    "\n",
    "train_cov_sqrt = traindataset.cov_sqrt\n",
    "posMetricInv = torch.bmm(train_cov_sqrt.permute(0, 2, 1), train_cov_sqrt)\n",
    "posMetricInv_sq = torch.bmm(posMetricInv, posMetricInv)\n",
    "const2 = torch.sum(posMetricInv_sq * eye3) / torch.sum(posMetricInv * eye3)\n",
    "print(const2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GDAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae2 import GDAE_N_n\n",
    "from tensor_data_util import posMetric_func_N_n\n",
    "\n",
    "# network / noise hyperparameters\n",
    "dim = [9, 1000]\n",
    "num_hidden_layers = 2\n",
    "noise_std = 0.01 / np.sqrt(2.0)\n",
    "covCoeff = 1.0 / const1.numpy() * 2.0  # const1 comes from the training-set constants cell above\n",
    "approx_order = 1\n",
    "use_exp_map_sqrt = True\n",
    "use_exp_map_corrupt = False\n",
    "gdae = GDAE_N_n(dim, num_hidden_layers, noise_std, covCoeff=covCoeff,\n",
    "                posMetricFunc=posMetric_func_N_n, useLeakyReLU=False, approx_order=approx_order,\n",
    "                use_exp_map_sqrt=use_exp_map_sqrt, use_exp_map_corrupt=use_exp_map_corrupt)\n",
    "\n",
    "# Divide the input-feature columns of autoencoder[0]'s weight (presumably the first Linear layer,\n",
    "# shape (out, in) -- TODO confirm) by each feature's RMS over the training data, which rescales the\n",
    "# inputs to a comparable range. The clamp avoids inf/NaN weights for an all-zero feature.\n",
    "data_rms = torch.sqrt(torch.mean(traindataset.posAndCov**2, dim=0)).view(1, -1)\n",
    "with torch.no_grad():\n",
    "    gdae.autoencoder[0].weight.div_(data_rms.clamp(min=1e-12))\n",
    "\n",
    "# move to the GPU only when one is available (useGPU is set in the device cell)\n",
    "if useGPU:\n",
    "    gdae = gdae.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train GDAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim.lr_scheduler import MultiStepLR\n",
    "\n",
    "max_iter_num = 40000\n",
    "optimizer = torch.optim.Adam(gdae.parameters(), lr=1e-4, weight_decay=1e-6)\n",
    "# single LR drop by a factor of 100 (gamma=0.01) at max_iter_num // 2 scheduler steps\n",
    "lr_drop_iter = max_iter_num // 2\n",
    "scheduler = MultiStepLR(optimizer, milestones=[lr_drop_iter], gamma=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter: 0 ---- time 1.5\n",
      "loss: 0.023485\n",
      "iter: 100 ---- time 66.7\n",
      "loss: 0.000584\n",
      "iter: 200 ---- time 132.8\n",
      "loss: 0.000412\n",
      "iter: 300 ---- time 198.6\n",
      "loss: 0.000368\n",
      "iter: 400 ---- time 264.3\n",
      "loss: 0.000349\n",
      "iter: 500 ---- time 333.8\n",
      "loss: 0.000369\n",
      "iter: 600 ---- time 401.3\n",
      "loss: 0.000453\n",
      "iter: 700 ---- time 465.9\n",
      "loss: 0.000469\n",
      "iter: 800 ---- time 530.3\n",
      "loss: 0.000353\n",
      "iter: 900 ---- time 595.2\n",
      "loss: 0.000323\n",
      "iter: 1000 ---- time 658.7\n",
      "loss: 0.000317\n",
      "iter: 1100 ---- time 725.2\n",
      "loss: 0.000366\n",
      "iter: 1200 ---- time 792.1\n",
      "loss: 0.000413\n",
      "iter: 1300 ---- time 858.5\n",
      "loss: 0.000391\n",
      "iter: 1400 ---- time 922.4\n",
      "loss: 0.000450\n",
      "iter: 1500 ---- time 988.9\n",
      "loss: 0.000334\n",
      "iter: 1600 ---- time 1054.5\n",
      "loss: 0.000495\n",
      "iter: 1700 ---- time 1120.9\n",
      "loss: 0.000312\n",
      "iter: 1800 ---- time 1185.9\n",
      "loss: 0.000303\n",
      "iter: 1900 ---- time 1250.2\n",
      "loss: 0.000438\n",
      "iter: 2000 ---- time 1314.3\n",
      "loss: 0.000309\n",
      "iter: 2100 ---- time 1381.3\n",
      "loss: 0.000501\n",
      "iter: 2200 ---- time 1449.1\n",
      "loss: 0.000326\n",
      "iter: 2300 ---- time 1512.4\n",
      "loss: 0.000734\n",
      "iter: 2400 ---- time 1574.3\n",
      "loss: 0.000310\n",
      "iter: 2500 ---- time 1639.9\n",
      "loss: 0.000512\n",
      "iter: 2600 ---- time 1706.3\n",
      "loss: 0.000301\n",
      "iter: 2700 ---- time 1772.1\n",
      "loss: 0.000298\n",
      "iter: 2800 ---- time 1836.4\n",
      "loss: 0.000317\n",
      "iter: 2900 ---- time 1904.7\n",
      "loss: 0.000318\n",
      "iter: 3000 ---- time 1969.3\n",
      "loss: 0.000296\n",
      "iter: 3100 ---- time 2034.2\n",
      "loss: 0.000303\n",
      "iter: 3200 ---- time 2097.8\n",
      "loss: 0.000376\n",
      "iter: 3300 ---- time 2162.7\n",
      "loss: 0.000298\n",
      "iter: 3400 ---- time 2230.2\n",
      "loss: 0.000553\n",
      "iter: 3500 ---- time 2295.2\n",
      "loss: 0.000309\n",
      "iter: 3600 ---- time 2363.7\n",
      "loss: 0.000374\n",
      "iter: 3700 ---- time 2427.5\n",
      "loss: 0.000306\n",
      "iter: 3800 ---- time 2493.7\n",
      "loss: 0.000974\n",
      "iter: 3900 ---- time 2558.9\n",
      "loss: 0.000334\n",
      "iter: 4000 ---- time 2625.3\n",
      "loss: 0.000338\n",
      "iter: 4100 ---- time 2689.6\n",
      "loss: 0.000311\n",
      "iter: 4200 ---- time 2753.0\n",
      "loss: 0.000354\n",
      "iter: 4300 ---- time 2815.4\n",
      "loss: 0.000308\n",
      "iter: 4400 ---- time 2879.4\n",
      "loss: 0.000297\n",
      "iter: 4500 ---- time 2945.8\n",
      "loss: 0.000294\n",
      "iter: 4600 ---- time 3012.7\n",
      "loss: 0.000371\n",
      "iter: 4700 ---- time 3077.6\n",
      "loss: 0.000303\n",
      "iter: 4800 ---- time 3142.4\n",
      "loss: 0.000805\n",
      "iter: 4900 ---- time 3208.2\n",
      "loss: 0.000302\n",
      "iter: 5000 ---- time 3272.8\n",
      "loss: 0.000290\n",
      "iter: 5100 ---- time 3339.1\n",
      "loss: 0.000490\n",
      "iter: 5200 ---- time 3401.8\n",
      "loss: 0.000303\n",
      "iter: 5300 ---- time 3465.1\n",
      "loss: 0.000323\n",
      "iter: 5400 ---- time 3530.5\n",
      "loss: 0.000320\n",
      "iter: 5500 ---- time 3597.4\n",
      "loss: 0.000380\n",
      "iter: 5600 ---- time 3661.7\n",
      "loss: 0.000324\n",
      "iter: 5700 ---- time 3729.5\n",
      "loss: 0.000651\n",
      "iter: 5800 ---- time 3795.1\n",
      "loss: 0.000294\n",
      "iter: 5900 ---- time 3861.3\n",
      "loss: 0.000362\n",
      "iter: 6000 ---- time 3923.8\n",
      "loss: 0.000312\n",
      "iter: 6100 ---- time 3990.1\n",
      "loss: 0.000419\n",
      "iter: 6200 ---- time 4055.9\n",
      "loss: 0.000292\n",
      "iter: 6300 ---- time 4122.1\n",
      "loss: 0.000316\n",
      "iter: 6400 ---- time 4186.7\n",
      "loss: 0.000377\n",
      "iter: 6500 ---- time 4252.1\n",
      "loss: 0.000331\n",
      "iter: 6600 ---- time 4316.6\n",
      "loss: 0.000303\n",
      "iter: 6700 ---- time 4383.4\n",
      "loss: 0.000298\n",
      "iter: 6800 ---- time 4449.1\n",
      "loss: 0.000328\n",
      "iter: 6900 ---- time 4516.3\n",
      "loss: 0.000388\n",
      "iter: 7000 ---- time 4580.1\n",
      "loss: 0.000322\n",
      "iter: 7100 ---- time 4644.5\n",
      "loss: 0.000299\n",
      "iter: 7200 ---- time 4709.0\n",
      "loss: 0.000290\n",
      "iter: 7300 ---- time 4776.9\n",
      "loss: 0.000295\n",
      "iter: 7400 ---- time 4843.9\n",
      "loss: 0.000491\n",
      "iter: 7500 ---- time 4907.9\n",
      "loss: 0.000320\n",
      "iter: 7600 ---- time 4970.5\n",
      "loss: 0.000571\n",
      "iter: 7700 ---- time 5037.5\n",
      "loss: 0.000296\n",
      "iter: 7800 ---- time 5106.0\n",
      "loss: 0.000328\n",
      "iter: 7900 ---- time 5172.4\n",
      "loss: 0.000290\n",
      "iter: 8000 ---- time 5236.7\n",
      "loss: 0.000329\n",
      "iter: 8100 ---- time 5301.2\n",
      "loss: 0.000293\n",
      "iter: 8200 ---- time 5368.6\n",
      "loss: 0.000319\n",
      "iter: 8300 ---- time 5433.5\n",
      "loss: 0.000292\n",
      "iter: 8400 ---- time 5495.8\n",
      "loss: 0.000325\n",
      "iter: 8500 ---- time 5561.3\n",
      "loss: 0.000301\n",
      "iter: 8600 ---- time 5626.0\n",
      "loss: 0.000590\n",
      "iter: 8700 ---- time 5692.7\n",
      "loss: 0.000302\n",
      "iter: 8800 ---- time 5757.9\n",
      "loss: 0.000291\n",
      "iter: 8900 ---- time 5826.2\n",
      "loss: 0.000375\n",
      "iter: 9000 ---- time 5892.3\n",
      "loss: 0.000345\n",
      "iter: 9100 ---- time 5958.6\n",
      "loss: 0.000339\n",
      "iter: 9200 ---- time 6023.5\n",
      "loss: 0.000304\n",
      "iter: 9300 ---- time 6090.7\n",
      "loss: 0.000352\n",
      "iter: 9400 ---- time 6155.8\n",
      "loss: 0.000302\n",
      "iter: 9500 ---- time 6219.8\n",
      "loss: 0.000298\n",
      "iter: 9600 ---- time 6284.4\n",
      "loss: 0.000356\n",
      "iter: 9700 ---- time 6352.0\n",
      "loss: 0.000290\n",
      "iter: 9800 ---- time 6418.6\n",
      "loss: 0.000334\n",
      "iter: 9900 ---- time 6485.8\n",
      "loss: 0.000464\n",
      "iter: 10000 ---- time 6549.8\n",
      "loss: 0.000322\n",
      "iter: 10100 ---- time 6605.7\n",
      "loss: 0.000546\n",
      "iter: 10200 ---- time 6668.1\n",
      "loss: 0.000356\n",
      "iter: 10300 ---- time 6729.8\n",
      "loss: 0.000324\n",
      "iter: 10400 ---- time 6796.5\n",
      "loss: 0.000339\n",
      "iter: 10500 ---- time 6858.0\n",
      "loss: 0.000341\n",
      "iter: 10600 ---- time 6917.9\n",
      "loss: 0.000315\n",
      "iter: 10700 ---- time 6981.4\n",
      "loss: 0.000299\n",
      "iter: 10800 ---- time 7037.5\n",
      "loss: 0.000293\n",
      "iter: 10900 ---- time 7094.8\n",
      "loss: 0.000287\n",
      "iter: 11000 ---- time 7160.3\n",
      "loss: 0.000297\n",
      "iter: 11100 ---- time 7220.6\n",
      "loss: 0.000328\n",
      "iter: 11200 ---- time 7286.5\n",
      "loss: 0.000306\n",
      "iter: 11300 ---- time 7347.9\n",
      "loss: 0.000289\n",
      "iter: 11400 ---- time 7410.8\n",
      "loss: 0.000329\n",
      "iter: 11500 ---- time 7476.0\n",
      "loss: 0.000298\n",
      "iter: 11600 ---- time 7539.3\n",
      "loss: 0.000335\n",
      "iter: 11700 ---- time 7598.0\n",
      "loss: 0.000297\n",
      "iter: 11800 ---- time 7656.1\n",
      "loss: 0.000356\n",
      "iter: 11900 ---- time 7720.2\n",
      "loss: 0.000301\n",
      "iter: 12000 ---- time 7782.4\n",
      "loss: 0.000313\n",
      "iter: 12100 ---- time 7842.4\n",
      "loss: 0.000316\n",
      "iter: 12200 ---- time 7907.8\n",
      "loss: 0.000421\n",
      "iter: 12300 ---- time 7974.7\n",
      "loss: 0.000294\n",
      "iter: 12400 ---- time 8040.2\n",
      "loss: 0.000381\n",
      "iter: 12500 ---- time 8103.7\n",
      "loss: 0.000316\n",
      "iter: 12600 ---- time 8169.2\n",
      "loss: 0.000306\n",
      "iter: 12700 ---- time 8235.3\n",
      "loss: 0.000303\n",
      "iter: 12800 ---- time 8296.6\n",
      "loss: 0.000288\n",
      "iter: 12900 ---- time 8354.3\n",
      "loss: 0.000320\n",
      "iter: 13000 ---- time 8415.5\n",
      "loss: 0.000298\n",
      "iter: 13100 ---- time 8479.6\n",
      "loss: 0.000319\n",
      "iter: 13200 ---- time 8539.1\n",
      "loss: 0.000522\n",
      "iter: 13300 ---- time 8603.7\n",
      "loss: 0.000313\n",
      "iter: 13400 ---- time 8660.0\n",
      "loss: 0.000289\n",
      "iter: 13500 ---- time 8722.5\n",
      "loss: 0.000291\n",
      "iter: 13600 ---- time 8785.1\n",
      "loss: 0.000416\n",
      "iter: 13700 ---- time 8846.7\n",
      "loss: 0.000329\n",
      "iter: 13800 ---- time 8904.9\n",
      "loss: 0.000303\n",
      "iter: 13900 ---- time 8967.6\n",
      "loss: 0.000295\n",
      "iter: 14000 ---- time 9028.1\n",
      "loss: 0.000316\n",
      "iter: 14100 ---- time 9089.1\n",
      "loss: 0.000328\n",
      "iter: 14200 ---- time 9155.2\n",
      "loss: 0.000384\n",
      "iter: 14300 ---- time 9211.1\n",
      "loss: 0.000331\n",
      "iter: 14400 ---- time 9271.3\n",
      "loss: 0.000334\n",
      "iter: 14500 ---- time 9336.3\n",
      "loss: 0.000331\n",
      "iter: 14600 ---- time 9400.6\n",
      "loss: 0.000292\n",
      "iter: 14700 ---- time 9462.5\n",
      "loss: 0.000298\n",
      "iter: 14800 ---- time 9517.5\n",
      "loss: 0.000290\n",
      "iter: 14900 ---- time 9577.1\n",
      "loss: 0.000312\n",
      "iter: 15000 ---- time 9636.4\n",
      "loss: 0.000318\n",
      "iter: 15100 ---- time 9700.2\n",
      "loss: 0.000291\n",
      "iter: 15200 ---- time 9765.6\n",
      "loss: 0.000306\n",
      "iter: 15300 ---- time 9827.7\n",
      "loss: 0.000290\n",
      "iter: 15400 ---- time 9892.9\n",
      "loss: 0.000294\n",
      "iter: 15500 ---- time 9957.5\n",
      "loss: 0.000323\n",
      "iter: 15600 ---- time 10021.0\n",
      "loss: 0.000349\n",
      "iter: 15700 ---- time 10078.4\n",
      "loss: 0.000324\n",
      "iter: 15800 ---- time 10143.7\n",
      "loss: 0.000331\n",
      "iter: 15900 ---- time 10205.2\n",
      "loss: 0.000337\n",
      "iter: 16000 ---- time 10269.3\n",
      "loss: 0.000321\n",
      "iter: 16100 ---- time 10326.4\n",
      "loss: 0.000345\n",
      "iter: 16200 ---- time 10383.3\n",
      "loss: 0.000296\n",
      "iter: 16300 ---- time 10445.9\n",
      "loss: 0.000329\n",
      "iter: 16400 ---- time 10506.7\n",
      "loss: 0.000295\n",
      "iter: 16500 ---- time 10570.2\n",
      "loss: 0.000320\n",
      "iter: 16600 ---- time 10635.6\n",
      "loss: 0.000290\n",
      "iter: 16700 ---- time 10698.9\n",
      "loss: 0.000353\n",
      "iter: 16800 ---- time 10761.1\n",
      "loss: 0.000300\n",
      "iter: 16900 ---- time 10823.3\n",
      "loss: 0.000293\n",
      "iter: 17000 ---- time 10879.1\n",
      "loss: 0.000295\n",
      "iter: 17100 ---- time 10941.0\n",
      "loss: 0.000467\n",
      "iter: 17200 ---- time 11001.4\n",
      "loss: 0.000312\n",
      "iter: 17300 ---- time 11060.1\n",
      "loss: 0.000310\n",
      "iter: 17400 ---- time 11121.0\n",
      "loss: 0.000287\n",
      "iter: 17500 ---- time 11178.5\n",
      "loss: 0.000320\n",
      "iter: 17600 ---- time 11243.8\n",
      "loss: 0.000301\n",
      "iter: 17700 ---- time 11305.5\n",
      "loss: 0.000315\n",
      "iter: 17800 ---- time 11364.2\n",
      "loss: 0.000358\n",
      "iter: 17900 ---- time 11426.2\n",
      "loss: 0.000297\n",
      "iter: 18000 ---- time 11486.2\n",
      "loss: 0.000302\n",
      "iter: 18100 ---- time 11545.8\n",
      "loss: 0.000290\n",
      "iter: 18200 ---- time 11608.8\n",
      "loss: 0.000359\n",
      "iter: 18300 ---- time 11675.8\n",
      "loss: 0.000316\n",
      "iter: 18400 ---- time 11739.7\n",
      "loss: 0.000308\n",
      "iter: 18500 ---- time 11803.1\n",
      "loss: 0.000348\n",
      "iter: 18600 ---- time 11860.6\n",
      "loss: 0.000293\n",
      "iter: 18700 ---- time 11921.9\n",
      "loss: 0.000289\n",
      "iter: 18800 ---- time 11982.3\n",
      "loss: 0.000297\n",
      "iter: 18900 ---- time 12042.0\n",
      "loss: 0.000303\n",
      "iter: 19000 ---- time 12107.3\n",
      "loss: 0.000309\n",
      "iter: 19100 ---- time 12163.7\n"
     ]
    }
   ],
   "source": [
    "from gae2_trainer import gae_N_n_trainer_batchall\n",
    "from tensor_data_util import posMetric_func_N_n\n",
    "\n",
    "# main training run (saveModel=True); the returned weights are kept in model_wts\n",
    "model_wts = gae_N_n_trainer_batchall(\n",
    "    trainloader, gdae, optimizer, scheduler, max_iter_num,\n",
    "    use_gpu=useGPU, saveModel=True, printEpochPeriod=100, weight_mode=weight_mode,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check Score Estimation Error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boundaries [0, Ntrunc, 2*Ntrunc, ..., N] handed to gae_N_n_estimate_score_error_truncated below\n",
    "# (presumably evaluated chunk by chunk). range() with a step never repeats N, so there is no\n",
    "# empty trailing chunk when N is an exact multiple of Ntrunc.\n",
    "Ntrunc = 10000\n",
    "num_train = traindataset.posAndCov.shape[0]\n",
    "idxSet = list(range(0, num_train, Ntrunc)) + [num_train]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Post-training fine-tuning: keep the GDAE weights with the lowest score-estimation error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "from gae2_score_estimation import gae_N_n_estimate_score_error_truncated\n",
    "\n",
    "# Fine-tune for post_rounds rounds of post_iter_num iterations at a small constant LR and keep the\n",
    "# weights whose score-estimation error is lowest.\n",
    "post_rounds = 10\n",
    "post_iter_num = 100\n",
    "post_optimizer = torch.optim.Adam(gdae.parameters(), lr=1e-6, weight_decay=1e-6)\n",
    "# The scheduler must wrap the optimizer that is actually stepped (the old `scheduler` is bound to the\n",
    "# main-training optimizer); a constant factor of 1.0 keeps the LR at 1e-6.\n",
    "post_scheduler = torch.optim.lr_scheduler.LambdaLR(post_optimizer, lambda _: 1.0)\n",
    "\n",
    "score_error_set = []\n",
    "best_model = None\n",
    "min_val = None\n",
    "for i in range(post_rounds):\n",
    "    model_wts = gae_N_n_trainer_batchall(trainloader, gdae, post_optimizer, post_scheduler, post_iter_num,\n",
    "                                         use_gpu=useGPU, saveModel=True, printEpochPeriod=100,\n",
    "                                         weight_mode=weight_mode)\n",
    "    score_error = gae_N_n_estimate_score_error_truncated(traindataset.posAndCov, traindataset.covInv,\n",
    "                                                         traindataset.cov_sqrt, idxSet, gdae,\n",
    "                                                         printError=False)\n",
    "    score_error_set.append(score_error)\n",
    "    if min_val is None or score_error < min_val:\n",
    "        # deep-copy in case the trainer returns live references (e.g. gdae.state_dict()),\n",
    "        # which later rounds would otherwise overwrite in place\n",
    "        best_model = copy.deepcopy(model_wts)\n",
    "        min_val = score_error\n",
    "    print(score_error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# restore the fine-tuning weights that gave the lowest score-estimation error\n",
    "gdae.load_state_dict(best_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data to Filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# re-load the same file as for training, with randPos=False and match_num_per_subject=None\n",
    "shuffle = False\n",
    "randPos = False\n",
    "roi_range = [0, 143, 0, 181, 0, 13]\n",
    "data_prefix_test = '_fix'\n",
    "\n",
    "testdataset = N_n_DataSet(\n",
    "    filelist,\n",
    "    match_num_per_subject=None,\n",
    "    shuffle=shuffle,\n",
    "    randPos=randPos,\n",
    "    dx=dx,\n",
    "    roi_range=roi_range,\n",
    "    data_prefix=data_prefix_test,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mean Shift for an Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f018401a828>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Select the samples of file number `input_num`; split_nums is used as consecutive segment\n",
    "# lengths. An invalid index raises instead of silently leaving a stale `input` behind.\n",
    "input_num = 0\n",
    "split_nums = testdataset.split_nums\n",
    "if not 0 <= input_num < len(split_nums):\n",
    "    raise IndexError('input_num {} out of range for {} segments'.format(input_num, len(split_nums)))\n",
    "start_num = sum(split_nums[:input_num])\n",
    "end_num = start_num + split_nums[input_num]\n",
    "input = testdataset.posAndCov[start_num:end_num].cuda()\n",
    "cur_posMetric = testdataset.covInv[start_num:end_num]\n",
    "cur_posMetric_sqrt = testdataset.covInv_sqrt[start_num:end_num]\n",
    "cur_posMetricInv_sqrt = testdataset.cov_sqrt[start_num:end_num]\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0018)\n",
      "tensor(0.1939)\n"
     ]
    }
   ],
   "source": [
    "### get constant for threshold (same trace ratios as the training cell, on the selected samples)\n",
    "# multiplying a batch of 3x3 matrices by the identity mask and summing gives the total of their traces\n",
    "eye3 = torch.eye(3).view(1, 3, 3)\n",
    "\n",
    "posMetric_sq = torch.bmm(cur_posMetric, cur_posMetric)\n",
    "const1 = torch.sum(cur_posMetric * eye3) / torch.sum(posMetric_sq * eye3)\n",
    "print(const1)\n",
    "\n",
    "posMetricInv = torch.bmm(cur_posMetricInv_sqrt.permute(0, 2, 1), cur_posMetricInv_sqrt)\n",
    "posMetricInv_sq = torch.bmm(posMetricInv, posMetricInv)\n",
    "const2 = torch.sum(posMetricInv_sq * eye3) / torch.sum(posMetricInv * eye3)\n",
    "print(const2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1.6753e-05, device='cuda:0')\n",
      "tensor(0.4123, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "from DTI_meanShift2 import N_n_meanShift_dae\n",
    "from torch_batch_svd import svd\n",
    "from data_util import vector2tensor_1dim\n",
    "\n",
    "# step size and thresholds passed to N_n_meanShift_dae (names suggest convergence, grouping and\n",
    "# elimination criteria -- confirm in DTI_meanShift2); position thresholds scale with sqrt(const1)\n",
    "# from the previous cell\n",
    "step_size = 0.1\n",
    "pos_threshold = 0.002*step_size*float(torch.sqrt(const1))\n",
    "cov_threshold = 0.0001*step_size\n",
    "pos_group_threshold = 10*float(torch.sqrt(const1))\n",
    "cov_group_threshold = 0.5\n",
    "eliminate_threshold = 10\n",
    "# per-coordinate min/max of the position part (first 3 columns) of the selected samples\n",
    "lowerbound_pos, _ = torch.min(input[:,:3],dim=0)\n",
    "upperbound_pos, _ = torch.max(input[:,:3],dim=0)\n",
    "\n",
    "# batched SVD of the remaining 6 columns; the global min/max singular values bound the tensor part.\n",
    "# NOTE(review): assumes vector2tensor_1dim builds 3x3 matrices from those 6 components -- confirm.\n",
    "# NOTE(review): U, V and alpha are not used in this cell -- confirm whether later cells need them.\n",
    "U,S,V = svd(vector2tensor_1dim(input[:,3:]))\n",
    "alpha = 0.1\n",
    "lowerbound_s = torch.min(S)\n",
    "upperbound_s = torch.max(S)\n",
    "print(lowerbound_s)\n",
    "print(upperbound_s)\n",
    "# 4-element box bounds: 3 position bounds followed by one singular-value bound\n",
    "lowerbound = torch.cat((lowerbound_pos.view(-1), lowerbound_s.view(-1)), 0)\n",
    "upperbound = torch.cat((upperbound_pos.view(-1), upperbound_s.view(-1)), 0)\n",
    "pos_metric_choice = 'riemannian'\n",
    "ms_gdae = N_n_meanShift_dae(gdae, step_size, pos_threshold = pos_threshold, cov_threshold = cov_threshold,\n",
    "                pos_group_threshold = pos_group_threshold, cov_group_threshold = cov_group_threshold, \n",
    "                           eliminate_threshold = eliminate_threshold, \n",
    "                            lowerBound = lowerbound, upperBound = upperbound, pos_metric = pos_metric_choice)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([274595, 9])\n",
      "iter: 1, num points to shift: 274595\n",
      "tensor(0.0029, device='cuda:0')\n",
      "tensor(2.5477e-05, device='cuda:0')\n",
      "iter: 2, num points to shift: 274595\n",
      "tensor(0.0029, device='cuda:0')\n",
      "tensor(2.5477e-05, device='cuda:0')\n",
      "iter: 3, num points to shift: 274595\n",
      "tensor(0.0028, device='cuda:0')\n",
      "tensor(2.5462e-05, device='cuda:0')\n",
      "iter: 4, num points to shift: 274595\n",
      "tensor(0.0028, device='cuda:0')\n",
      "tensor(2.5451e-05, device='cuda:0')\n",
      "iter: 5, num points to shift: 274595\n",
      "tensor(0.0027, device='cuda:0')\n",
      "tensor(2.5455e-05, device='cuda:0')\n",
      "iter: 6, num points to shift: 274595\n",
      "tensor(0.0027, device='cuda:0')\n",
      "tensor(2.5441e-05, device='cuda:0')\n",
      "iter: 7, num points to shift: 274595\n",
      "tensor(0.0026, device='cuda:0')\n",
      "tensor(2.5431e-05, device='cuda:0')\n",
      "iter: 8, num points to shift: 274595\n",
      "tensor(0.0026, device='cuda:0')\n",
      "tensor(2.5442e-05, device='cuda:0')\n",
      "iter: 9, num points to shift: 274595\n",
      "tensor(0.0026, device='cuda:0')\n",
      "tensor(2.5413e-05, device='cuda:0')\n",
      "iter: 10, num points to shift: 274595\n",
      "tensor(0.0026, device='cuda:0')\n",
      "tensor(2.5418e-05, device='cuda:0')\n",
      "iter: 11, num points to shift: 274595\n",
      "tensor(0.0026, device='cuda:0')\n",
      "tensor(2.5417e-05, device='cuda:0')\n",
      "iter: 12, num points to shift: 274595\n",
      "tensor(0.0025, device='cuda:0')\n",
      "tensor(2.5411e-05, device='cuda:0')\n",
      "iter: 13, num points to shift: 274595\n",
      "tensor(0.0025, device='cuda:0')\n",
      "tensor(2.5391e-05, device='cuda:0')\n",
      "iter: 14, num points to shift: 274595\n",
      "tensor(0.0025, device='cuda:0')\n",
      "tensor(2.5383e-05, device='cuda:0')\n",
      "iter: 15, num points to shift: 274595\n",
      "tensor(0.0025, device='cuda:0')\n",
      "tensor(2.5380e-05, device='cuda:0')\n",
      "iter: 16, num points to shift: 274595\n",
      "tensor(0.0025, device='cuda:0')\n",
      "tensor(2.5377e-05, device='cuda:0')\n",
      "iter: 17, num points to shift: 274595\n",
      "tensor(0.0025, device='cuda:0')\n",
      "tensor(2.5382e-05, device='cuda:0')\n",
      "iter: 18, num points to shift: 274595\n",
      "tensor(0.0024, device='cuda:0')\n",
      "tensor(2.5360e-05, device='cuda:0')\n",
      "iter: 19, num points to shift: 274595\n",
      "tensor(0.0024, device='cuda:0')\n",
      "tensor(2.5359e-05, device='cuda:0')\n",
      "iter: 20, num points to shift: 274595\n",
      "tensor(0.0024, device='cuda:0')\n",
      "tensor(2.5372e-05, device='cuda:0')\n",
      "iter: 21, num points to shift: 274595\n",
      "tensor(0.0024, device='cuda:0')\n",
      "tensor(2.5364e-05, device='cuda:0')\n",
      "iter: 22, num points to shift: 274595\n",
      "tensor(0.0024, device='cuda:0')\n",
      "tensor(2.5351e-05, device='cuda:0')\n",
      "iter: 23, num points to shift: 274595\n",
      "tensor(0.0023, device='cuda:0')\n",
      "tensor(2.5341e-05, device='cuda:0')\n",
      "iter: 24, num points to shift: 274595\n",
      "tensor(0.0023, device='cuda:0')\n",
      "tensor(2.5338e-05, device='cuda:0')\n",
      "iter: 25, num points to shift: 274595\n",
      "tensor(0.0023, device='cuda:0')\n",
      "tensor(2.5325e-05, device='cuda:0')\n",
      "iter: 26, num points to shift: 274595\n",
      "tensor(0.0023, device='cuda:0')\n",
      "tensor(2.5329e-05, device='cuda:0')\n",
      "iter: 27, num points to shift: 274595\n",
      "tensor(0.0023, device='cuda:0')\n",
      "tensor(2.5337e-05, device='cuda:0')\n",
      "iter: 28, num points to shift: 274595\n",
      "tensor(0.0023, device='cuda:0')\n",
      "tensor(2.5326e-05, device='cuda:0')\n",
      "iter: 29, num points to shift: 274595\n",
      "tensor(0.0022, device='cuda:0')\n",
      "tensor(2.5323e-05, device='cuda:0')\n",
      "iter: 30, num points to shift: 274595\n",
      "tensor(0.0022, device='cuda:0')\n",
      "tensor(2.5321e-05, device='cuda:0')\n",
      "iter: 31, num points to shift: 274595\n",
      "tensor(0.0022, device='cuda:0')\n",
      "tensor(2.5309e-05, device='cuda:0')\n",
      "iter: 32, num points to shift: 274595\n",
      "tensor(0.0022, device='cuda:0')\n",
      "tensor(2.5310e-05, device='cuda:0')\n",
      "iter: 33, num points to shift: 274595\n",
      "tensor(0.0022, device='cuda:0')\n",
      "tensor(2.5319e-05, device='cuda:0')\n",
      "iter: 34, num points to shift: 274595\n",
      "tensor(0.0022, device='cuda:0')\n",
      "tensor(2.5061e-05, device='cuda:0')\n",
      "iter: 35, num points to shift: 274595\n",
      "tensor(0.0021, device='cuda:0')\n",
      "tensor(2.4291e-05, device='cuda:0')\n",
      "iter: 36, num points to shift: 274595\n",
      "tensor(0.0021, device='cuda:0')\n",
      "tensor(2.3537e-05, device='cuda:0')\n",
      "iter: 37, num points to shift: 274595\n",
      "tensor(0.0021, device='cuda:0')\n",
      "tensor(2.2778e-05, device='cuda:0')\n",
      "iter: 38, num points to shift: 274595\n",
      "tensor(0.0021, device='cuda:0')\n",
      "tensor(2.2069e-05, device='cuda:0')\n",
      "iter: 39, num points to shift: 274595\n",
      "tensor(0.0021, device='cuda:0')\n",
      "tensor(2.1386e-05, device='cuda:0')\n",
      "iter: 40, num points to shift: 274595\n",
      "tensor(0.0021, device='cuda:0')\n",
      "tensor(2.0754e-05, device='cuda:0')\n",
      "iter: 41, num points to shift: 274595\n",
      "tensor(0.0021, device='cuda:0')\n",
      "tensor(2.0076e-05, device='cuda:0')\n",
      "iter: 42, num points to shift: 274595\n",
      "tensor(0.0020, device='cuda:0')\n",
      "tensor(1.9525e-05, device='cuda:0')\n",
      "iter: 43, num points to shift: 274595\n",
      "tensor(0.0020, device='cuda:0')\n",
      "tensor(1.8936e-05, device='cuda:0')\n",
      "iter: 44, num points to shift: 274595\n",
      "tensor(0.0020, device='cuda:0')\n",
      "tensor(1.8389e-05, device='cuda:0')\n",
      "iter: 45, num points to shift: 274595\n",
      "tensor(0.0020, device='cuda:0')\n",
      "tensor(1.7853e-05, device='cuda:0')\n",
      "iter: 46, num points to shift: 274595\n",
      "tensor(0.0020, device='cuda:0')\n",
      "tensor(1.7325e-05, device='cuda:0')\n",
      "iter: 47, num points to shift: 274595\n",
      "tensor(0.0020, device='cuda:0')\n",
      "tensor(1.6861e-05, device='cuda:0')\n",
      "iter: 48, num points to shift: 274595\n",
      "tensor(0.0019, device='cuda:0')\n",
      "tensor(1.6415e-05, device='cuda:0')\n",
      "iter: 49, num points to shift: 274595\n",
      "tensor(0.0019, device='cuda:0')\n",
      "tensor(1.6005e-05, device='cuda:0')\n",
      "iter: 50, num points to shift: 274595\n",
      "tensor(0.0019, device='cuda:0')\n",
      "tensor(1.5556e-05, device='cuda:0')\n",
      "iter: 51, num points to shift: 274595\n",
      "tensor(0.0019, device='cuda:0')\n",
      "tensor(1.5138e-05, device='cuda:0')\n",
      "iter: 52, num points to shift: 274595\n",
      "tensor(0.0019, device='cuda:0')\n",
      "tensor(1.4741e-05, device='cuda:0')\n",
      "iter: 53, num points to shift: 274595\n",
      "tensor(0.0019, device='cuda:0')\n",
      "tensor(1.4361e-05, device='cuda:0')\n",
      "iter: 54, num points to shift: 274595\n",
      "tensor(0.0019, device='cuda:0')\n",
      "tensor(1.4015e-05, device='cuda:0')\n",
      "iter: 55, num points to shift: 274595\n",
      "tensor(0.0018, device='cuda:0')\n",
      "tensor(1.3673e-05, device='cuda:0')\n",
      "iter: 56, num points to shift: 274595\n",
      "tensor(0.0018, device='cuda:0')\n",
      "tensor(1.3349e-05, device='cuda:0')\n",
      "iter: 57, num points to shift: 274595\n",
      "tensor(0.0018, device='cuda:0')\n",
      "tensor(1.3068e-05, device='cuda:0')\n",
      "iter: 58, num points to shift: 274595\n",
      "tensor(0.0018, device='cuda:0')\n",
      "tensor(1.2734e-05, device='cuda:0')\n",
      "iter: 59, num points to shift: 274595\n",
      "tensor(0.0018, device='cuda:0')\n",
      "tensor(1.2485e-05, device='cuda:0')\n",
      "iter: 60, num points to shift: 274595\n",
      "tensor(0.0018, device='cuda:0')\n",
      "tensor(1.2204e-05, device='cuda:0')\n",
      "iter: 61, num points to shift: 274595\n",
      "tensor(0.0018, device='cuda:0')\n",
      "tensor(1.1965e-05, device='cuda:0')\n",
      "iter: 62, num points to shift: 274595\n",
      "tensor(0.0018, device='cuda:0')\n",
      "tensor(1.1714e-05, device='cuda:0')\n",
      "iter: 63, num points to shift: 274595\n",
      "tensor(0.0017, device='cuda:0')\n",
      "tensor(1.1485e-05, device='cuda:0')\n",
      "iter: 64, num points to shift: 274595\n",
      "tensor(0.0017, device='cuda:0')\n",
      "tensor(1.1251e-05, device='cuda:0')\n",
      "iter: 65, num points to shift: 274595\n",
      "tensor(0.0017, device='cuda:0')\n",
      "tensor(1.1072e-05, device='cuda:0')\n",
      "iter: 66, num points to shift: 274595\n",
      "tensor(0.0017, device='cuda:0')\n",
      "tensor(1.0888e-05, device='cuda:0')\n",
      "iter: 67, num points to shift: 274595\n",
      "tensor(0.0017, device='cuda:0')\n",
      "tensor(1.0695e-05, device='cuda:0')\n",
      "iter: 68, num points to shift: 274595\n",
      "tensor(0.0017, device='cuda:0')\n",
      "tensor(1.0544e-05, device='cuda:0')\n",
      "iter: 69, num points to shift: 274595\n",
      "tensor(0.0017, device='cuda:0')\n",
      "tensor(1.0378e-05, device='cuda:0')\n",
      "iter: 70, num points to shift: 274595\n",
      "tensor(0.0017, device='cuda:0')\n",
      "tensor(1.0221e-05, device='cuda:0')\n",
      "iter: 71, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(1.0075e-05, device='cuda:0')\n",
      "iter: 72, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(9.9544e-06, device='cuda:0')\n",
      "iter: 73, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(9.8183e-06, device='cuda:0')\n",
      "iter: 74, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(9.6792e-06, device='cuda:0')\n",
      "iter: 75, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(9.5914e-06, device='cuda:0')\n",
      "iter: 76, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(9.5031e-06, device='cuda:0')\n",
      "iter: 77, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(9.4090e-06, device='cuda:0')\n",
      "iter: 78, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(9.3070e-06, device='cuda:0')\n",
      "iter: 79, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(9.2232e-06, device='cuda:0')\n",
      "iter: 80, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(9.1215e-06, device='cuda:0')\n",
      "iter: 81, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(9.0676e-06, device='cuda:0')\n",
      "iter: 82, num points to shift: 274595\n",
      "tensor(0.0016, device='cuda:0')\n",
      "tensor(8.9857e-06, device='cuda:0')\n",
      "iter: 83, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(8.9346e-06, device='cuda:0')\n",
      "iter: 84, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(8.8774e-06, device='cuda:0')\n",
      "iter: 85, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(8.6888e-06, device='cuda:0')\n",
      "iter: 86, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(8.4974e-06, device='cuda:0')\n",
      "iter: 87, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(8.3155e-06, device='cuda:0')\n",
      "iter: 88, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(8.1266e-06, device='cuda:0')\n",
      "iter: 89, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(7.9894e-06, device='cuda:0')\n",
      "iter: 90, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(7.7993e-06, device='cuda:0')\n",
      "iter: 91, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(7.6311e-06, device='cuda:0')\n",
      "iter: 92, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(7.4744e-06, device='cuda:0')\n",
      "iter: 93, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(7.3348e-06, device='cuda:0')\n",
      "iter: 94, num points to shift: 274595\n",
      "tensor(0.0015, device='cuda:0')\n",
      "tensor(7.1527e-06, device='cuda:0')\n",
      "iter: 95, num points to shift: 274595\n",
      "tensor(0.0014, device='cuda:0')\n",
      "tensor(7.0329e-06, device='cuda:0')\n",
      "iter: 96, num points to shift: 274595\n",
      "tensor(0.0014, device='cuda:0')\n",
      "tensor(6.9072e-06, device='cuda:0')\n",
      "iter: 97, num points to shift: 274595\n",
      "tensor(0.0014, device='cuda:0')\n",
      "tensor(6.7664e-06, device='cuda:0')\n",
      "iter: 98, num points to shift: 274595\n",
      "tensor(0.0014, device='cuda:0')\n",
      "tensor(6.6606e-06, device='cuda:0')\n",
      "iter: 99, num points to shift: 274595\n",
      "tensor(0.0014, device='cuda:0')\n",
      "tensor(6.5267e-06, device='cuda:0')\n",
      "iter: 100, num points to shift: 274595\n",
      "tensor(0.0014, device='cuda:0')\n",
      "tensor(6.3766e-06, device='cuda:0')\n",
      "mean shift terminated !!! elapsed time: 3086.6\n"
     ]
    }
   ],
   "source": [
    "max_iter = 100\n",
    "test_num = 1\n",
    "# Error curves collected from every run; each is stacked below into a (test_num, n) tensor, one row per run.\n",
    "error_names = ['le_errors', 'le_wt_errors', 'ai_errors', 'ai_wt_errors']\n",
    "error_rows = {name: [] for name in error_names}\n",
    "print(input.shape)\n",
    "for trial in range(test_num):\n",
    "    # Fresh copies per run, so nothing run_meanShift does to its tensor arguments leaks back into `input`.\n",
    "    noised_input = input.clone()\n",
    "    filteringResults = ms_gdae.run_meanShift(noised_input, max_iter, save_prefix = file_prefix, \n",
    "            pos_metric = pos_metric_choice, save_iter=None, cleanInput = input.clone(), error_weight = None)\n",
    "    for name in error_names:\n",
    "        error_rows[name].append(getattr(filteringResults, name).view(1, -1))\n",
    "# NOTE(review): cleanInput is a copy of the same tensor the run starts from, so these errors measure drift\n",
    "# from the input rather than error against a separate clean reference - confirm this is intended.\n",
    "le_errorsSet = torch.cat(error_rows['le_errors'], 0)\n",
    "le_wt_errorsSet = torch.cat(error_rows['le_wt_errors'], 0)\n",
    "ai_errorsSet = torch.cat(error_rows['ai_errors'], 0)\n",
    "ai_wt_errorsSet = torch.cat(error_rows['ai_wt_errors'], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f001816b668>]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAD4CAYAAADo30HgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAnfklEQVR4nO3dd5gUVdr+8e/TXc2Qo4qSBBFRTKgjhjWtETFg4BXWuK677r6ru+aA62tAUTEhSlBgyAiShJEVJJsRBlHJMgLKEAYkSRyY7vP7o8v9zcwOMsAMNd19f67La7pOnap6zlU493Sd6i5zziEiIvKrUNAFiIhI+aJgEBGRQhQMIiJSiIJBREQKUTCIiEghXtAFlIbDDjvMNW7cOOgyREQSypw5c352zh1etD0pgqFx48ZkZWUFXYaISEIxsx+La9elJBERKUTBICIihSgYRESkEAWDiIgUomAQEZFCFAwiIlKIgkFERApRMIiIJKDtefk8k7mAX3btKfV9KxhERBLMuq27uPvtCZyVdT/fLVpS6vtPik8+i4ikimXrt/FExji67HyWBpEthGtsKPVjKBhERBLE3J828fqAYfSKvUi1Sh7hW/8NDdJL/TgKBhGRBDBlYS4jh/Wld7gbXvW6eLe/D4cdWybHUjCIiJRzQ7/6kXmZb9EzkkHsiJOI3DYKqtUts+MpGEREyinnHF0nLSH0aRdeiowheszFRNoPgrRqZXpcBYOISDm0JxrjiVFzOWNeJzp4M4idejPha9+EcKTMj61gEBEpZ7bl5XP/oM+45aen+L33Le78hwld/CSYHZLjKxhERMqR3F928UDGJDpueooTvZ/g6m7YGX88pDUoGEREyomluVt5KmMsL+d1ol5kK6H2w+C4Kw55HQoGEZFyYOayDfQcNIRevEzVSmmEb/s31D8jkFoUDCIiAcv8djWTRvamr9cdq9kQ7/bRUPuYwOpRMIiIBMQ5xzufLGPdpK68GRlCrN4ZeLeMgCp1Aq1LwSAiEoD8aIxOmfM5es4LPBWZQLT51Xg39oEKlYMuTcEgInKo7didz0NDZ3LNsmdp483Ctfor4dYvQigcdGmAgkFE5JBavzWP+/tP4YGfn+GM8FK4vDN27r1Bl1WIgkFE5BDJXreVJzIyeXnXczT0NmA3DoATrwu6rP+iYBAROQS+WraBtwa9yzu8TLWKYcK3jIdGZwVdVrEUDCIiZWzcN6uYPKovGV53QjXq4d0+Buo0DbqsvVIwiIiUEeccPadns3HqG7wZGerfjvoeVDks6NJ+k4JBRKQM7InGeOr9bznumxe5J/IR0eOvxruxL0QqBV3aPikYRERK2dZde3hg8Oe0/+lZLvO+xp1zL+HLnoNQKOjSSkTBICJSilZv3slD/SbRcfMznBReAW1exVr9Jeiy9ouCQUSklMxftYXn+4/m9T2dqRvZTuimd6H5lUGXtd8UDCIipWDqolzeHTaQvqGuVKxSjfCtE6Fey6DLOiAKBhGRgzTg8+Us+rAn70QyoE4zvNtGQc2GQZd1wEo0E2Jmrc1siZllm9njxaxPM7P3/PVfmVnjAus6+u1LzOwKv62hmU03s4VmtsDM7ivQv7aZTTazpf7PWqUwThGRUheNOZ7NnMe2CU/TJdIba3I+3l8mJXQoQAmCwczCQA/gSqAF8Acza1Gk213AJufcsUBXoIu/bQugA3Ai0Bro6e8vH3jIOdcCOBu4p8A+HwemOueaAVP9ZRGRcmV7Xj73DPyC02Y/wr3eOGKn30H41lFQsUbQpR20krxjaAVkO+eWOed2A8OBtkX6tAUG+q9HAZeYmfntw51zec655UA20Mo5t8Y59zWAc24rsAioX8y+BgLXHdDIRETKyNotu/hzr4n8efn9XBv+Ei59htA13SAcCbq0UlGSOYb6wMoCyzlA0S/4+E8f51y+mW0B6vjtM4tsW7/ghv5lp9OAr/ymus65Nf7rtUDd4ooy
s7uBuwEaNWpUgmGIiBy8Bau30Kn/WF7Z3Zn6kc1ww8By+UV4ByPQyWczqwqMBu53zv1SdL1zzpmZK25b51xvoDdAenp6sX1ERErT1EW5DB42mD6h16lcuRLhWz6EBulBl1XqShIMq4CCMykN/Lbi+uSYmQfUADb81rZmFiEeCkOdc2MK9Mk1s6Occ2vM7Chg3X6MR0SkTPT37zzqG8mA2k3xbhsJtRoHXVaZKMkcw2ygmZk1MbMKxCeTM4v0yQTu8F+3A6Y555zf3sG/a6kJ0AyY5c8/ZACLnHOv/8a+7gDG7e+gRERKS340xtNjv2PHhKd4OdIba3we3l8mJ20oQAneMfhzBvcCHwFhoJ9zboGZdQKynHOZxH/JDzazbGAj8fDA7zcCWEj8TqR7nHNRMzsPuA2YZ2bf+Id6wjn3IfASMMLM7gJ+BG4qxfGKiJTY1l17eHDIl1z343Nc5c2K33l01WtJM8m8Nxb/wz6xpaenu6ysrKDLEJEkkrNpBw/3m0zHLc9ySmgZdvlzcM69YBZ0aaXGzOY45/5rkkSffBYRKeLrnzbRZeBoukZfpG5kO9ZuCJxwddBlHTIKBhGRAj74djWZowbQP9yNClVqEL5lAtQ7LeiyDikFg4gI8aetvTUtmw3T3uKdyGBiR5yId+sIqF4v6NIOOQWDiKS8XXuidBz1Nact6MI/I5OJNr8K78Y+UKFK0KUFQsEgIint52153D9gBnfnPscF3jzcuf8kfOkzEAoHXVpgFAwikrIWr/2F/+s3nhfznucYLxeu6Y6dflvQZQVOwSAiKWnqolwGDhtKn9DrVK0YItRhLDQ5P+iyygUFg4ikFOccGZ8tZ/HEt8mIZGC1jsa7dSTUaRp0aeWGgkFEUsbu/BhPvf8tjb99jVcjHxBtfCHh9gOhkp4HVpCCQURSwsbtu7l/0Kfctrozl3lzcGf8iXCbl5P+6y0OhIJBRJLe0tytPNH/Q57b+TzHeavgylexVn8JuqxyS8EgIklt+pJ1ZLw7nLftVWqkxQjdNBKOvSTosso1BYOIJCXnHP0+X8HCCW/TP9IXq9kQ75YRcPhxQZdW7ikYRCTp5OVH/Unm13kt8gHRo88n3H4QVK4ddGkJQcEgIknl5215PDDoU25f05nLvK81yXwAFAwikjQWrv6Fpwb8m855nWnmrdYk8wFSMIhIUpg4fy3vjniXvqHXqZpmhNqPhqa/D7qshKRgEJGE5pyj+7RsVk/rRb/IAKh9DN4t7+mTzAdBwSAiCWvn7iiPjpzDGYte4cXIJKLHXEL4pv5QsUbQpSU0BYOIJKTVm3fy4IBp/GNjZ37nLcCdcy/hyzql9NdllxYFg4gknKwVG+kyeByv5b9IA28jXNsLa3lz0GUlDQWDiCSU92b/xLRxAxng9SCtclVCN0+AhmcGXVZSUTCISELYE43RefxCKs/qRq/ISGJ1T8G7eRjUqB90aUlHwSAi5d6m7bt5YMgX3JjzItdEZhI7qR1e2+4QqRR0aUlJwSAi5dritb/wrwET6LTzBVqEf4RLnyH0u/vBLOjSkpaCQUTKrYnz1zB0xHD6hF6nRgWH/c8IOO7yoMtKegoGESl3YjHHG1O+Z8PHb9M/MhBqNSZ8y3twWLOgS0sJCgYRKVe27trDI+9lccHSl3kwMo1o08sIt+sLlWoGXVrKUDCISLmxbP02Hhs4hcd+eYF0bwnuvAcJX/ykPrR2iCkYRKRcmL54Hb2Hj6Q7r3JYZCdc3w876cagy0pJCgYRCZRzjp4zfmDZlD4MjGQQqn4k4Zsz4ciTgy4tZSkYRCQw2/PyeXTEHNKXvMZrkY+INr6Q8E0D9KS1gCkYRCQQK37ezqMDp/Dglhc521uEO/ue+JfghfVrKWihknQys9ZmtsTMss3s8WLWp5nZe/76r8yscYF1Hf32JWZ2RYH2fma2zszmF9nXM2a2ysy+8f9rcxDjE5FyaPridTzWfRDdtj7AmZFlcEMfrPULCoVy
Yp9nwczCQA/gMiAHmG1mmc65hQW63QVscs4da2YdgC5AezNrAXQATgTqAVPM7DjnXBQYAHQHBhVz2K7OuVcPYlwiUg7FYo6eM7JZPrUvgyMZhKodQfgPY6Bey6BLkwJK8o6hFZDtnFvmnNsNDAfaFunTFhjovx4FXGJm5rcPd87lOeeWA9n+/nDOfQJsLIUxiEgC2LprD/cM/oqq057gtcjbhI8+G+9vnygUyqGSBEN9YGWB5Ry/rdg+zrl8YAtQp4TbFudeM/vOv9xUq7gOZna3mWWZWdb69etLsEsRCcoP67dx51vjufOH+/ijNwl39t8J3z4WqhwWdGlSjBLNMRxivYCmQEtgDfBacZ2cc72dc+nOufTDDz/8EJYnIvtj0oK1/F/3/vTc/iBnRFbADX2x1i9qPqEcK8mZWQU0LLDcwG8rrk+OmXlADWBDCbctxDmX++trM+sDjC9BjSJSzkRjjq6TlrDh094MjAzEqtfT5xMSREneMcwGmplZEzOrQHwyObNIn0zgDv91O2Cac8757R38u5aaAM2AWb91MDM7qsDi9cD8vfUVkfJp847d/LX/5zT47DFejGQQOuYCvL99rFBIEPt8x+Ccyzeze4GPgDDQzzm3wMw6AVnOuUwgAxhsZtnEJ5Q7+NsuMLMRwEIgH7jHvyMJMxsGXAQcZmY5wNPOuQzgZTNrCThgBfDXUhyviJSxBau38NSgiTy94yVO8ZbB+Q8T/v0T+r6jBGLxP+wTW3p6usvKygq6DJGUN3pODplj3+WN8FtUj8QI3/AOnHB10GXJXpjZHOdcetF2zf6IyEHbnR/juQ8WUDmrB/0j7xGr04zwH4bq+QkJSsEgIgdlzZadPDT4M27NfZk2kVnETrgO77oekFY16NLkACkYROSAffHDz3R9N5Mu+a/Q2MuFy54ndM69eh5zglMwiMh+c87xzifLmDdpAIMivYlUrkqofSY0Pi/o0qQUKBhEZL/8smsPj42YQ/r3b9AjMoFo/TMJtx8E1esFXZqUEgWDiJTY4rW/8OSgyTy2rQtnektwre4mfHln8CoEXZqUIgWDiJTI+3NzGDPmPd4Jv0nNCnnQNgM7uV3QZUkZUDCIyG/Ky4/6t6L2ZEDkPVytYwj/YQgccULQpUkZUTCIyF6t3LiDR4Z8yh/Xv0LryGxiJ7QlfF0PSKsWdGlShhQMIlKs6YvX0Wv4+7zKazTwNsDlLxA6+++6FTUFKBhEpJBozNF18vfkftKXwZEBhKvUJtT+39Do7KBLk0NEwSAi/7Fu6y4eGTqTq3Je5+HIx0QbX0C4XT+oqmeepBIFg4gAMHPZBl5+99+8sOcVjvd+ggseJXzR4/pW1BSkYBBJcbGY4+1PfmDB5IEMjvQhrWIatBsFzS4LujQJiIJBJIVt2r6bR0dkce4Pb9Aj8hHReumEbxoANRvuc1tJXgoGkRT19U+b6DzkI57a9TKnej/gzvpfwpd10qeYRcEgkmqcc/T7fAVfTRhC/8jbVEkzuH4Q1qJt0KVJOaFgEEkhW3bs4fGRczh16Vv0jownWvcUwu0HQu1jgi5NyhEFg0iK+GblZjoNmcSTO1/mdG8pLv0uwle8AJGKQZcm5YyCQSTJ/XrpaObEofT3elE1DWjbDzvpxqBLk3JKwSCSxDbv2M3jI7+m5dK36OONJ/+Ik+LPTqjTNOjSpBxTMIgkqa9/2sTzQz/i/3a+ymneUtwZf8Jr/QJEKgVdmpRzCgaRJBOLOfp8uoy5k4YyIPIOVdKAtv2xk24IujRJEAoGkSSyYVsej42YzbnL3uLtyETyjzyV8P/016Uj2S8KBpEk8eUPG3hl+ASezXuNk71luFZ/xbv8OfDSgi5NEoyCQSTBRWOON6cuZfmMgQyOZFCxYgSuH4KdcE3QpUmCUjCIJLA1W3byyLCZXJXTjQci04nWb0X4fzKgZqOgS5MEpmAQSVBTFuby9shMXoq9QVNvFZz/EOGLOkI4EnRpkuAUDCIJ
ZteeKC99uIg9szIYGhlCuEpN7Mb3oenvgy5NkoSCQSSBZK/bxhNDP+HOja9zZWQ2sWMuJnRDbz1hTUqVgkEkATjnGJmVw7jM0XQLv0ldbwtc2onQOf+AUCjo8iTJKBhEyrktO/fw5JhvaLLwbQZHxhCr0YjQTSOh/hlBlyZJSsEgUo5lrdjIC8Mm03Hna5wZWULs5Jvwrn4d0qoFXZoksRK9BzWz1ma2xMyyzezxYtanmdl7/vqvzKxxgXUd/fYlZnZFgfZ+ZrbOzOYX2VdtM5tsZkv9n7UOYnwiCSk/GuONKd/Tv88bDMx7gNMr5MD1vQnd2EehIGVun8FgZmGgB3Al0AL4g5m1KNLtLmCTc+5YoCvQxd+2BdABOBFoDfT09wcwwG8r6nFgqnOuGTDVXxZJGSs37uCOd2Zw5IxH6BHpRuWjjiP898/g1PZBlyYpoiTvGFoB2c65Zc653cBwoOgzANsCA/3Xo4BLzMz89uHOuTzn3HIg298fzrlPgI3FHK/gvgYC15V8OCKJbdw3q3io20Cez/077b2P4bwHCf95sp6wJodUSeYY6gMrCyznAGftrY9zLt/MtgB1/PaZRbatv4/j1XXOrfFfrwXqFtfJzO4G7gZo1Eif8pTEtnXXHp4ZO4/D5vVmaGQkVuVwrN0H0OT8oEuTFFSuJ5+dc87M3F7W9QZ6A6SnpxfbRyQRzPlxI52HTeHhHW9wbmQBsRPaErrmDahcO+jSJEWVJBhWAQ0LLDfw24rrk2NmHlAD2FDCbYvKNbOjnHNrzOwoYF0JahRJOPnRGG9Oy+aHGUMYEMmgSoUYXNWDUMtbwCzo8iSFlWSOYTbQzMyamFkF4pPJmUX6ZAJ3+K/bAdOcc85v7+DftdQEaAbM2sfxCu7rDmBcCWoUSSgrft7OrT2n0vDjh+gR6UaVo5oR/t/P4LRbFQoSuH2+Y/DnDO4FPgLCQD/n3AIz6wRkOecygQxgsJllE59Q7uBvu8DMRgALgXzgHudcFMDMhgEXAYeZWQ7wtHMuA3gJGGFmdwE/AjeV6ohFAuScY0TWSsZ+MJZXQt1p4K2HCx4hfOFj+vI7KTcs/od9YktPT3dZWVlBlyHymzZsy+PJMXM5/vt3+Ic3lli1enjt+sDR5wZdmqQoM5vjnEsv2l6uJ59FksX0xet4a+REnt7zBqd6P+BO6YDX5mWoWCPo0kT+i4JBpAzt2J1P5/ELsTn9eDcylEilSnBNf+ykG4IuTWSvFAwiZWTuT5t4fvh07tnajYsj3xA95veEr+sJ1esFXZrIb1IwiJSyPdEYb03LJnvGUDIiGVSvsBsuf5nwmX/RV2RLQlAwiJSipblb+dfwz2n/c3cejHxK9MhTCd3YFw4/LujSREpMwSBSCmIxR7/Pl/PppFF0C79DXW8TnP8o4Qsf1W2oknAUDCIHaeXGHXQcMYuLc3ox0JtIfq2mhG4cAQ30IB1JTAoGkQPknOO92SsZO34cL1hPjvFW4878C95lnaBC5aDLEzlgCgaRA5D7yy7+NSqLlst68673AbGqR8IN47BjLgq6NJGDpmAQ2Q/OOcZ9s5oh48bzvOvO8d5PuFNvxrvyJX1YTZKGgkGkhNZvzeP/xnzDsd/3YVjkfaxKbWg7HGt+ZdCliZQqBYPIPjjnGP/dGgaO/ZCnYz04ObKM2Ik3ErrqVT0zQZKSgkHkN/y8LY+n3/+Woxf3ZVhkNKFK1eGagYROvC7o0kTKjIJBpBi/vksYMHYCT0d7cErkB2LHX0vo6teh6uFBlydSphQMIkWs3xp/l9B4SV+GR0ZjlavD1f0InXiDHqIjKUHBIOL7zx1HmR/ydKxnfC7hhLaErnpN7xIkpSgYRIC1W3bx9PtzOW5pX4ZFxmKVa8DVAwideH3QpYkccgoGSWnOOUZm5TDq3+N51vXihMiP8TuO2rwCVeoEXZ5IIBQMkrJWbtzBU6PnkP5jb4Z543FVDoNrhhI64eqgSxMJlIJBUk405hj4
xQqmfDSO50PvxL/jqOUt2BWdoVKtoMsTCZyCQVLK97lbeXrkTK5Y25sh3mRi1epD2zHYsZcEXZpIuaFgkJSQlx+l5/QfmP/xKF73MjjS2wCt7sa75ClIqxp0eSLlioJBkl7Wio28OOozbtvSiwe8L8iv0xy7bhg0bBV0aSLlkoJBktbWXXvoMmERO2cPpV+FIVSL5MEFHfHOewC8tKDLEym3FAySlCbOX0vvcVN4MK8X51WYT7R+K0Jt34Ijjg+6NJFyT8EgSWXNlp08O/Zbmnzfn2GRMXhpaXD5a4TP+BOEQkGXJ5IQFAySFKIxx6AvVzDpow94xvrQPPITseOvIdTmZaheL+jyRBKKgkES3rycLXQe/SVXre/DUG8qsapHwdXDCB3fJujSRBKSgkES1tZde3jtoyVsmDWc7pHB1PF+gbP+hnfxvyCtWtDliSQsBYMkHOccH85bS0bmFO7f/Q4XROaRf+Sp2LVvQr2WQZcnkvAUDJJQVvy8nU7j5nLisv4Mj4wjnJYGl76Cd+ZdEAoHXZ5IUlAwSELYtSfK2x//wDcfv88zof40jqwh1uJ6Qq1fhOpHBV2eSFJRMEi5N33xOt4a+wl3bu/D/eGZ5NdsAlePIaTvNxIpEwoGKbdWbtzBCx98R/3vBzEkMpq0Cg7OfwLvd/dBpGLQ5YkkrRJ94sfMWpvZEjPLNrPHi1mfZmbv+eu/MrPGBdZ19NuXmNkV+9qnmQ0ws+Vm9o3/X8uDG6Ikml17orw5dSlPdO3Jg8vu4snIUCo2u5DwPV/BRY8pFETK2D7fMZhZGOgBXAbkALPNLNM5t7BAt7uATc65Y82sA9AFaG9mLYAOwIlAPWCKmR3nb/Nb+3zEOTeqFMYnCWbqolx6ZX7C7dsy+Gf4S/KrN4I2+kyCyKFUkktJrYBs59wyADMbDrQFCgZDW+AZ//UooLuZmd8+3DmXByw3s2x/f5Rgn5JClv+8nRcy53LsD4MZHBlLhQoOzu/oXzaqFHR5IimlJMFQH1hZYDkHOGtvfZxz+Wa2Bajjt88ssm19//Vv7bOzmT0FTAUe94OlEDO7G7gboFGjRiUYhpRH2/Py6T49m+zPxvCv8KD43UbHtSHU+gWo3STo8kRSUnmcfO4IrAUqAL2Bx4BORTs553r760lPT3eHskA5eM45xn6ziqEfTuN/d/XjMW8u+bWawlWjCR17adDliaS0kgTDKqBhgeUGfltxfXLMzANqABv2sW2x7c65NX5bnpn1Bx4uQY2SQL5ZuZmXx83mgrUDGOZNJFSxIlz0HN5ZfwOvQtDliaS8kgTDbKCZmTUh/su7A3BzkT6ZwB3Al0A7YJpzzplZJvCumb1OfPK5GTALsL3t08yOcs6t8ecorgPmH9wQpbzI/WUXr0xYSOi7d3krMpI63mZcy1uwS56GanWDLk9EfPsMBn/O4F7gIyAM9HPOLTCzTkCWcy4TyAAG+5PLG4n/osfvN4L4pHI+cI9zLgpQ3D79Qw41s8OJh8c3wN9KbbQSiJ27o/T5dBmzZ3zAY6FBnBRZTrT+mdCmC1b/jKDLE5EizLnEvzyfnp7usrKygi5DiojFHB98t5rBH87gTzsH0CY8i/yq9fAu7wQntwOzoEsUSWlmNsc5l160vTxOPksSmL1iI10zZ3HBusHxeYS0CJz/L7xz7oUKlYMuT0R+g4JBStXyn7fz6oT51Fk8lJ6RMdTwtsGpHbBLntKT1EQShIJBSsXG7bt5c8r3rJ01mse8YTSJrCF69PlY685w1KlBlyci+0HBIAdl5+4o/T5fzuczJvKAG8SZkSXk1z4OruhK+LjWmkcQSUAKBjkg+dEYY75exchJM/jjrkHcE55FfuXD4ZKueKfdDmH90xJJVPq/V/aLc47JC3PpM2Em12wewjBvGpaWBud1jE8sp1UNukQROUgKBimxmcs20P3DOZy1diiDvAmkRaLYGX/ELnocqh4RdHkiUkoUDLJP81dtoeuE72i6fCg9
Ih9Qw9tG7MQbCF38JNRpGnR5IlLKFAyyV0tzt9Jt0gKqLx7Bi5ExHBHZRPSYi+HSpwnVaxl0eSJSRhQM8l9W/Lydt6YshnkjeMwbQ8PIOvLrnwmXPUO48XlBlyciZUzBIP+xcuMOuk9dws5vx3BfeBRNI6vJP+IkuLQ7XrPLdeupSIpQMAg5m3bQc9r3bP56LPd5o2jurSS/TnO4eCDeCddCqESPBheRJKFgSGErN+6g1/SlbJn7Pv8Ij+b4yE/k1zoWLs7AO/F6CIWDLlFEAqBgSEE/bthOz2lL2Pnt+/w9PJbjvZ/iT0/7fR+8k25UIIikOAVDCvk+dytvT1uMmz+Ge72xNPVWk1+7GVykQBCR/0/BkAK+y9nMO9MWUm3JaO73PqBRJJf8OsfD7zvjtWirQBCRQhQMSco5x5c/bCBj2nwa/ziSpyIfUjeykfwjW8KFr+E1v0qTyiJSLAVDkonGHBPnr2XYjK85c90oXvcmUSOyjfyG58KFD+M1vVi3nYrIb1IwJIkdu/MZNSeHDz/+gjbbxpDhfUyat5vocVfB+Q/gNTwz6BJFJEEoGBJc7i+7GPTFcubNnEKHaCZDw1lYhTCc2gHO/Qfhw5sHXaKIJBgFQ4L6LmczAz5dSv6CTO4MfcgjoWzyK9Ug3Op+OOuvUO3IoEsUkQSlYEgge6IxJsxfy+jPvqPF6vd51JvMkd4G9tRoAr97Fa/lzVChStBlikiCUzAkgNxfdjF81kqyZk7nml3j6e19SVpkN/mNL4Rz/k6k2eW6w0hESo2CoZz69XbT977MJrJkLDeHJnNfKJtoWiVCLW+Fs+7GO+KEoMsUkSSkYChnft6Wx+g5OXw280vO3/ohz3ofU9Pbxp6aTeGsFwm3vBkq1Qy6TBFJYgqGciAac3zy/XrGfLWUCkvHc1NoGn8NLSYW8XDN20CrPxNpcoE+fyAih4SCIUDZ67YyKiuHpV9P55Jdk3nBm0k1bwe7qx8NrZ4h1PIWPUtZRA45BcMhtn5rHuO/W83nc+bSPHcCN3mfcoytIZpWEVpcB2fcRoVG52oyWUQCo2A4BLbs3MOkBWuZMXcxtX/8kGtDn3Fn6HuIwO4G58Lp/yLcoi1UrB50qSIiCoaysmXHHiYvyuWzbxdTddlELreZdAsvwPNi5NVqDqc/BSe1o0Kto4MuVUSkEAVDKcrZtIOpi9Yxd948aq+czGU2m1fDi+NhUL0x4VPuh5NuIK3uSZpIFpFyS8FwEPZEY8z5cROfLFnDmgWf03Tz51wSmssdoZ/Ag101mxE++UFocS1pR56iMBCRhKBg2A+xmGNJ7la+yP6ZZYvmUnHVF5wZ+5a/hRZQ3XYQi4TJO6oVnPhnaH4VFQ87NuiSRUT2W4mCwcxaA92AMNDXOfdSkfVpwCDgDGAD0N45t8Jf1xG4C4gC/3TOffRb+zSzJsBwoA4wB7jNObf74IZ5YLbs2MO8VVuY/+NaNiydTYXcr2kRXcy1ocUcbr+Awc5q9Qg3uwGaXUKo6cVU0ofPRCTB7TMYzCwM9AAuA3KA2WaW6ZxbWKDbXcAm59yxZtYB6AK0N7MWQAfgRKAeMMXMjvO32ds+uwBdnXPDzextf9+9SmOwxYnGHOu27mL1ph2syV3Lhpyl7MxdRmjTD9TdtYzmtpK7bDURiwKwvepR2NGXQ7MLoPF5VKp9jC4RiUhSKck7hlZAtnNuGYCZDQfaAgWDoS3wjP96FNDdzMxvH+6cywOWm1m2vz+K26eZLQIuBm72+wz091smwTCj98M0XvUB1djOKezgDP+X/6+2Vj6SPXVOINroBiJHt4IG6VTR11mLSJIrSTDUB1YWWM4BztpbH+dcvpltIX4pqD4ws8i29f3Xxe2zDrDZOZdfTP9CzOxu4G6ARo0alWAY/63OUUezM+8UdlepybYqtalc83Bq1m9GpE5jqNWYahVrHNB+RUQSWcJOPjvnegO9AdLT092B7OPka/4B/KM0yxIRSXgl+d6F
VUDDAssN/LZi+5iZB9QgPgm9t2331r4BqOnvY2/HEhGRMlSSYJgNNDOzJmZWgfhkcmaRPpnAHf7rdsA055zz2zuYWZp/t1EzYNbe9ulvM93fB/4+xx348EREZH/t81KSP2dwL/AR8VtL+znnFphZJyDLOZcJZACD/cnljcR/0eP3G0F8ojofuMc5FwUobp/+IR8DhpvZ88Bcf98iInKIWPyP9MSWnp7usrKygi5DRCShmNkc51x60XZ9t7OIiBSiYBARkUIUDCIiUoiCQURECkmKyWczWw/8eICbHwb8XIrlJIpUHHcqjhlSc9ypOGbY/3Ef7Zw7vGhjUgTDwTCzrOJm5ZNdKo47FccMqTnuVBwzlN64dSlJREQKUTCIiEghCgb/i/hSUCqOOxXHDKk57lQcM5TSuFN+jkFERArTOwYRESlEwSAiIoWkdDCYWWszW2Jm2Wb2eND1lAUza2hm081soZktMLP7/PbaZjbZzJb6P2sFXWtpM7Owmc01s/H+chMz+8o/3+/5X/meVMysppmNMrPFZrbIzM5J9nNtZg/4/7bnm9kwM6uYjOfazPqZ2Tozm1+grdhza3Fv+uP/zsxO359jpWwwmFkY6AFcCbQA/mBmLYKtqkzkAw8551oAZwP3+ON8HJjqnGsGTPWXk819wKICy12Ars65Y4FNwF2BVFW2ugETnXPHA6cSH3/Snmszqw/8E0h3zp1E/Gv8O5Cc53oA0LpI297O7ZXEn3/TjPgjkHvtz4FSNhiAVkC2c26Zc243MBxoG3BNpc45t8Y597X/eivxXxT1iY91oN9tIHBdIAWWETNrAFwF9PWXDbgYGOV3ScYx1wAuwH+GiXNut3NuM0l+rok/V6aS/+THysAakvBcO+c+If68m4L2dm7bAoNc3EziT8Y8qqTHSuVgqA+sLLCc47clLTNrDJwGfAXUdc6t8VetBeoGVVcZeQN4FIj5y3WAzc65fH85Gc93E2A90N+/hNbXzKqQxOfaObcKeBX4iXggbAHmkPzn+ld7O7cH9fstlYMhpZhZVWA0cL9z7peC6/xHqibNfctmdjWwzjk3J+haDjEPOB3o5Zw7DdhOkctGSXiuaxH/67gJUA+own9fbkkJpXluUzkYVgENCyw38NuSjplFiIfCUOfcGL8599e3lv7PdUHVVwZ+B1xrZiuIXyK8mPi195r+5QZIzvOdA+Q4577yl0cRD4pkPteXAsudc+udc3uAMcTPf7Kf61/t7dwe1O+3VA6G2UAz/+6FCsQnrDIDrqnU+dfWM4BFzrnXC6zKBO7wX98BjDvUtZUV51xH51wD51xj4ud1mnPuFmA60M7vllRjBnDOrQVWmllzv+kS4s9bT9pzTfwS0tlmVtn/t/7rmJP6XBewt3ObCdzu3510NrClwCWnfUrpTz6bWRvi16LDQD/nXOdgKyp9ZnYe8Ckwj/9/vf0J4vMMI4BGxL+y/CbnXNGJrYRnZhcBDzvnrjazY4i/g6gNzAVudc7lBVheqTOzlsQn3CsAy4A7if8BmLTn2syeBdoTvwNvLvBn4tfTk+pcm9kw4CLiX62dCzwNjKWYc+uHZHfil9V2AHc657JKfKxUDgYREflvqXwpSUREiqFgEBGRQhQMIiJSiIJBREQKUTCIiEghCgYRESlEwSAiIoX8P+JEOdO65sdoAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Error curves of the first run (row 0), one point per entry of the error vector.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(le_errorsSet[0].cpu().numpy(), label='le_errors')\n",
    "ax.plot(ai_errorsSet[0].cpu().numpy(), label='ai_errors')\n",
    "ax.set(xlabel='mean-shift iteration index', ylabel='error', title='Filtering error vs. iteration (run 0)')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.6675e-06, 5.8438e-06, 1.1797e-05, 1.9131e-05, 2.7599e-05, 3.7036e-05,\n",
       "        4.7327e-05, 5.8382e-05, 7.0134e-05, 8.2527e-05, 9.5518e-05, 1.0907e-04,\n",
       "        1.2315e-04, 1.3772e-04, 1.5278e-04, 1.6828e-04, 1.8422e-04, 2.0058e-04,\n",
       "        2.1734e-04, 2.3448e-04, 2.5200e-04, 2.6988e-04, 2.8810e-04, 3.0667e-04,\n",
       "        3.2556e-04, 3.4477e-04, 3.6429e-04, 3.8412e-04, 4.0424e-04, 4.2464e-04,\n",
       "        4.4532e-04, 4.6628e-04, 4.8750e-04, 5.0899e-04, 5.3073e-04, 5.5272e-04,\n",
       "        5.7495e-04, 5.9743e-04, 6.2013e-04, 6.4307e-04, 6.6624e-04, 6.8962e-04,\n",
       "        7.1323e-04, 7.3705e-04, 7.6108e-04, 7.8531e-04, 8.0975e-04, 8.3439e-04,\n",
       "        8.5923e-04, 8.8426e-04, 9.0949e-04, 9.3490e-04, 9.6050e-04, 9.8628e-04,\n",
       "        1.0122e-03, 1.0384e-03, 1.0647e-03, 1.0912e-03, 1.1178e-03, 1.1447e-03,\n",
       "        1.1717e-03, 1.1988e-03, 1.2262e-03, 1.2537e-03, 1.2813e-03, 1.3091e-03,\n",
       "        1.3371e-03, 1.3652e-03, 1.3935e-03, 1.4219e-03, 1.4505e-03, 1.4792e-03,\n",
       "        1.5081e-03, 1.5371e-03, 1.5662e-03, 1.5955e-03, 1.6250e-03, 1.6546e-03,\n",
       "        1.6843e-03, 1.7141e-03, 1.7441e-03, 1.7743e-03, 1.8045e-03, 1.8349e-03,\n",
       "        1.8655e-03, 1.8961e-03, 1.9269e-03, 1.9579e-03, 1.9889e-03, 2.0201e-03,\n",
       "        2.0514e-03, 2.0829e-03, 2.1144e-03, 2.1461e-03, 2.1779e-03, 2.2098e-03,\n",
       "        2.2419e-03, 2.2741e-03, 2.3064e-03, 2.3388e-03], device='cuda:0')"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "le_errorsSet[0, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./filtering_1mm_axial_results/gdae_N_nTanh_filtering_allresult_subject033_S_4179_20180125_131417_std0.0071_covcoeff_stepsize0.1_maxiter100.pickle\n"
     ]
    }
   ],
   "source": [
    "# Pickle the complete `filteringResults` object. Disabled by default; set the flag to True to write it.\n",
    "save_all_results = False\n",
    "if save_all_results:\n",
    "    savefilename = './filtering_1mm_axial_results/gdae_N_nTanh_filtering_allresult'+'_subject'+file_prefix+'_std'+'{:.4f}'.format(noise_std)+'_covcoeff'+'_stepsize'+str(step_size)+'_maxiter'+str(max_iter)+'.pickle'\n",
    "    print(savefilename)\n",
    "    with open(savefilename, 'wb') as handle:\n",
    "        pickle.dump(filteringResults, handle, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save Filtered Volumes (NIfTI)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NOTE: The code below requires the nibabel library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "dx = 1\n",
    "data_dim = [roi_range[1] - roi_range[0], roi_range[3] - roi_range[2], roi_range[5] - roi_range[4]]\n",
    "# 4x4 matrix passed as `mat2` when saving the .nii.gz results (presumably the NIfTI affine):\n",
    "# x axis flipped, 1 mm spacing, translation shifted by the ROI start so a cropped ROI stays aligned.\n",
    "# With roi_range starting at 0 this reproduces the original offsets (72, -106, 10).\n",
    "mat2 = np.diag([-1.0, 1.0, 1.0, 1.0])\n",
    "mat2[0, 3] = 72 - roi_range[0]\n",
    "mat2[1, 3] = -106 + roi_range[2]\n",
    "mat2[2, 3] = 10 + roi_range[4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1.6675e-06, 2.3448e-04, 6.4307e-04, 1.1447e-03, 1.7141e-03, 2.3388e-03],\n",
      "       device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "from DTI_ms_utils import save_DTI_shiftedTensor_result_file\n",
    "\n",
    "# Iterations to export (0-based): the first one and every 20th after it, i.e. [0, 19, 39, ...] below max_iter.\n",
    "index = [0] + list(range(19, max_iter, 20))\n",
    "print(le_errorsSet[0][index])\n",
    "# NOTE(review): relies on `testdataset` (for posMean) being defined in an earlier cell - confirm.\n",
    "for i in index:\n",
    "    shiftedPoints = filteringResults.shiftedPointsSet[i]\n",
    "    savefilename = ('./filtering_1mm_axial_results/gdae_N_nTanh_filtering_result'\n",
    "                    '_subject{}_iter{}_std{:.4f}_covcoeff_stepsize{}.nii.gz').format(\n",
    "                        file_prefix, i + 1, noise_std, step_size)\n",
    "    shiftedTensor = save_DTI_shiftedTensor_result_file(data_dim, input.cuda(), shiftedPoints.cuda(), dx, \n",
    "            savefilename, mat2 = mat2, posMean = testdataset.posMean.cuda(), use_logvec = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
