{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f72772ac828>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Device selection: only touch CUDA when a GPU exists, because\n",
    "# torch.cuda.set_device fails on CPU-only machines.\n",
    "gpu_id = 0\n",
    "useGPU = torch.cuda.is_available()\n",
    "if useGPU:\n",
    "    torch.cuda.set_device(gpu_id)\n",
    "\n",
    "# Random seeds: numpy's global RNG and torch (torch.manual_seed seeds all devices).\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from DTI_dataset import DTI_DataSet\n",
    "\n",
    "shuffle = False\n",
    "file_prefix = '033_S_4179_20180125_131417'\n",
    "filelist = ['../DTI_data_1mm_axialcrop/'+file_prefix[:10]+'/'+file_prefix+'_1mm_6_tensor_cropped.nii.gz']\n",
    "returnforDAE = True\n",
    "randPos = True\n",
    "roi_range = [0, 143, 0, 181, 0, 13]\n",
    "dx = 1\n",
    "data_num = 1\n",
    "data_prefix = '_fix'\n",
    "            \n",
    "traindataset = DTI_DataSet(filelist,\n",
    "                           match_num_per_subject = data_num, shuffle = shuffle, randPos = randPos, \n",
    "                           dx = dx, roi_range = roi_range, data_prefix = data_prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_batch_size = 60000\n",
    "trainloader = torch.utils.data.DataLoader(traindataset, batch_size=train_batch_size, \n",
    "                                          shuffle=True, num_workers = min(6, len(traindataset)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DAE Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gae2 import DAE_DTI\n",
    "\n",
    "dim = [9, 1000]\n",
    "num_hidden_layers = 2\n",
    "noise_std = 0.01\n",
    "dae = DAE_DTI(dim, num_hidden_layers, noise_std, useLeakyReLU = False, pos_dim = 3)\n",
    "\n",
    "# Fold per-feature input scaling into the first layer: divide its weight by the\n",
    "# column-wise RMS of the training samples (reshaped to one row by view(1, -1)).\n",
    "# NOTE(review): assumes posAndCov is (n_samples, n_features) and autoencoder[0]\n",
    "# is a Linear layer with weight (out, in_features) - confirm.\n",
    "# The clamp keeps an all-zero feature column from turning weights into inf.\n",
    "data_rms = torch.sqrt(torch.mean(traindataset.posAndCov**2, dim=0)).view(1,-1)\n",
    "with torch.no_grad():\n",
    "    dae.autoencoder[0].weight.div_(data_rms.clamp(min=1e-12))\n",
    "\n",
    "# Move to the GPU only when one is available (training also uses use_gpu=useGPU).\n",
    "if useGPU:\n",
    "    dae = dae.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train DAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim.lr_scheduler import MultiStepLR\n",
    "\n",
    "max_iter_num = 40000\n",
    "lr_drop_iter = max_iter_num // 2\n",
    "\n",
    "# Adam with light weight decay; the scheduler multiplies the learning rate by 0.01\n",
    "# once, at scheduler step lr_drop_iter (half of max_iter_num).\n",
    "optimizer = torch.optim.Adam(dae.parameters(), lr=1e-4, weight_decay=1e-6)\n",
    "scheduler = MultiStepLR(optimizer, milestones=[lr_drop_iter], gamma=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter: 0 ---- time 0.2\n",
      "loss: 0.043238\n",
      "iter: 100 ---- time 14.0\n",
      "loss: 0.001359\n",
      "iter: 200 ---- time 29.3\n",
      "loss: 0.000978\n",
      "iter: 300 ---- time 44.5\n",
      "loss: 0.000878\n",
      "iter: 400 ---- time 59.8\n",
      "loss: 0.000825\n",
      "iter: 500 ---- time 75.0\n",
      "loss: 0.000786\n",
      "iter: 600 ---- time 90.3\n",
      "loss: 0.000771\n",
      "iter: 700 ---- time 105.6\n",
      "loss: 0.000744\n",
      "iter: 800 ---- time 120.9\n",
      "loss: 0.000817\n",
      "iter: 900 ---- time 136.2\n",
      "loss: 0.000777\n",
      "iter: 1000 ---- time 151.5\n",
      "loss: 0.000723\n",
      "iter: 1100 ---- time 166.8\n",
      "loss: 0.000689\n",
      "iter: 1200 ---- time 182.1\n",
      "loss: 0.000706\n",
      "iter: 1300 ---- time 197.4\n",
      "loss: 0.000677\n",
      "iter: 1400 ---- time 212.7\n",
      "loss: 0.000712\n",
      "iter: 1500 ---- time 228.1\n",
      "loss: 0.000897\n",
      "iter: 1600 ---- time 243.4\n",
      "loss: 0.000693\n",
      "iter: 1700 ---- time 258.7\n",
      "loss: 0.000792\n",
      "iter: 1800 ---- time 274.0\n",
      "loss: 0.000853\n",
      "iter: 1900 ---- time 289.4\n",
      "loss: 0.000726\n",
      "iter: 2000 ---- time 304.7\n",
      "loss: 0.000794\n",
      "iter: 2100 ---- time 320.0\n",
      "loss: 0.000757\n",
      "iter: 2200 ---- time 335.4\n",
      "loss: 0.000679\n",
      "iter: 2300 ---- time 350.7\n",
      "loss: 0.000698\n",
      "iter: 2400 ---- time 366.0\n",
      "loss: 0.000873\n",
      "iter: 2500 ---- time 381.4\n",
      "loss: 0.000663\n",
      "iter: 2600 ---- time 396.7\n",
      "loss: 0.000654\n",
      "iter: 2700 ---- time 412.1\n",
      "loss: 0.001086\n",
      "iter: 2800 ---- time 427.4\n",
      "loss: 0.000636\n",
      "iter: 2900 ---- time 442.8\n",
      "loss: 0.000711\n",
      "iter: 3000 ---- time 458.1\n",
      "loss: 0.000750\n",
      "iter: 3100 ---- time 473.5\n",
      "loss: 0.000629\n",
      "iter: 3200 ---- time 488.8\n",
      "loss: 0.000689\n",
      "iter: 3300 ---- time 504.2\n",
      "loss: 0.000640\n",
      "iter: 3400 ---- time 519.5\n",
      "loss: 0.000799\n",
      "iter: 3500 ---- time 534.9\n",
      "loss: 0.000649\n",
      "iter: 3600 ---- time 550.2\n",
      "loss: 0.000769\n",
      "iter: 3700 ---- time 565.6\n",
      "loss: 0.000638\n",
      "iter: 3800 ---- time 581.0\n",
      "loss: 0.000747\n",
      "iter: 3900 ---- time 596.3\n",
      "loss: 0.000778\n",
      "iter: 4000 ---- time 611.7\n",
      "loss: 0.000695\n",
      "iter: 4100 ---- time 627.0\n",
      "loss: 0.000746\n",
      "iter: 4200 ---- time 642.4\n",
      "loss: 0.000756\n",
      "iter: 4300 ---- time 657.7\n",
      "loss: 0.000964\n",
      "iter: 4400 ---- time 673.1\n",
      "loss: 0.000766\n",
      "iter: 4500 ---- time 688.4\n",
      "loss: 0.000630\n",
      "iter: 4600 ---- time 703.8\n",
      "loss: 0.000681\n",
      "iter: 4700 ---- time 719.1\n",
      "loss: 0.000668\n",
      "iter: 4800 ---- time 734.5\n",
      "loss: 0.000785\n",
      "iter: 4900 ---- time 749.9\n",
      "loss: 0.000863\n",
      "iter: 5000 ---- time 765.2\n",
      "loss: 0.000656\n",
      "iter: 5100 ---- time 780.6\n",
      "loss: 0.000726\n",
      "iter: 5200 ---- time 795.9\n",
      "loss: 0.000706\n",
      "iter: 5300 ---- time 811.3\n",
      "loss: 0.000774\n",
      "iter: 5400 ---- time 826.6\n",
      "loss: 0.000681\n",
      "iter: 5500 ---- time 842.0\n",
      "loss: 0.000626\n",
      "iter: 5600 ---- time 857.3\n",
      "loss: 0.000674\n",
      "iter: 5700 ---- time 872.7\n",
      "loss: 0.000636\n",
      "iter: 5800 ---- time 888.1\n",
      "loss: 0.000648\n",
      "iter: 5900 ---- time 903.4\n",
      "loss: 0.000868\n",
      "iter: 6000 ---- time 918.8\n",
      "loss: 0.000896\n",
      "iter: 6100 ---- time 934.1\n",
      "loss: 0.000694\n",
      "iter: 6200 ---- time 949.5\n",
      "loss: 0.000764\n",
      "iter: 6300 ---- time 964.9\n",
      "loss: 0.000686\n",
      "iter: 6400 ---- time 980.2\n",
      "loss: 0.000629\n",
      "iter: 6500 ---- time 995.6\n",
      "loss: 0.000653\n",
      "iter: 6600 ---- time 1011.0\n",
      "loss: 0.000733\n",
      "iter: 6700 ---- time 1026.3\n",
      "loss: 0.000717\n",
      "iter: 6800 ---- time 1041.7\n",
      "loss: 0.000661\n",
      "iter: 6900 ---- time 1057.0\n",
      "loss: 0.000637\n",
      "iter: 7000 ---- time 1072.4\n",
      "loss: 0.000707\n",
      "iter: 7100 ---- time 1087.8\n",
      "loss: 0.000700\n",
      "iter: 7200 ---- time 1103.1\n",
      "loss: 0.000617\n",
      "iter: 7300 ---- time 1118.5\n",
      "loss: 0.000674\n",
      "iter: 7400 ---- time 1133.8\n",
      "loss: 0.000749\n",
      "iter: 7500 ---- time 1149.2\n",
      "loss: 0.000618\n",
      "iter: 7600 ---- time 1164.6\n",
      "loss: 0.000625\n",
      "iter: 7700 ---- time 1179.9\n",
      "loss: 0.000625\n",
      "iter: 7800 ---- time 1195.3\n",
      "loss: 0.000736\n",
      "iter: 7900 ---- time 1210.7\n",
      "loss: 0.000679\n",
      "iter: 8000 ---- time 1226.0\n",
      "loss: 0.000628\n",
      "iter: 8100 ---- time 1241.4\n",
      "loss: 0.000614\n",
      "iter: 8200 ---- time 1256.7\n",
      "loss: 0.000664\n",
      "iter: 8300 ---- time 1272.1\n",
      "loss: 0.000659\n",
      "iter: 8400 ---- time 1287.5\n",
      "loss: 0.000656\n",
      "iter: 8500 ---- time 1302.8\n",
      "loss: 0.000623\n",
      "iter: 8600 ---- time 1318.2\n",
      "loss: 0.000652\n",
      "iter: 8700 ---- time 1333.6\n",
      "loss: 0.000686\n",
      "iter: 8800 ---- time 1348.9\n",
      "loss: 0.000643\n",
      "iter: 8900 ---- time 1364.3\n",
      "loss: 0.000648\n",
      "iter: 9000 ---- time 1379.7\n",
      "loss: 0.000649\n",
      "iter: 9100 ---- time 1395.0\n",
      "loss: 0.000724\n",
      "iter: 9200 ---- time 1410.4\n",
      "loss: 0.000628\n",
      "iter: 9300 ---- time 1425.7\n",
      "loss: 0.000629\n",
      "iter: 9400 ---- time 1441.1\n",
      "loss: 0.000628\n",
      "iter: 9500 ---- time 1456.5\n",
      "loss: 0.000658\n",
      "iter: 9600 ---- time 1471.8\n",
      "loss: 0.000652\n",
      "iter: 9700 ---- time 1487.2\n",
      "loss: 0.000715\n",
      "iter: 9800 ---- time 1502.6\n",
      "loss: 0.000703\n",
      "iter: 9900 ---- time 1517.9\n",
      "loss: 0.000718\n",
      "iter: 10000 ---- time 1533.3\n",
      "loss: 0.000614\n",
      "iter: 10100 ---- time 1548.7\n",
      "loss: 0.000652\n",
      "iter: 10200 ---- time 1564.1\n",
      "loss: 0.000643\n",
      "iter: 10300 ---- time 1579.4\n",
      "loss: 0.000707\n",
      "iter: 10400 ---- time 1594.8\n",
      "loss: 0.000768\n",
      "iter: 10500 ---- time 1610.2\n",
      "loss: 0.000617\n",
      "iter: 10600 ---- time 1625.5\n",
      "loss: 0.000696\n",
      "iter: 10700 ---- time 1640.9\n",
      "loss: 0.000639\n",
      "iter: 10800 ---- time 1656.2\n",
      "loss: 0.000614\n",
      "iter: 10900 ---- time 1671.6\n",
      "loss: 0.000673\n",
      "iter: 11000 ---- time 1687.0\n",
      "loss: 0.000658\n",
      "iter: 11100 ---- time 1702.3\n",
      "loss: 0.000649\n",
      "iter: 11200 ---- time 1717.7\n",
      "loss: 0.000642\n",
      "iter: 11300 ---- time 1733.1\n",
      "loss: 0.000614\n",
      "iter: 11400 ---- time 1748.4\n",
      "loss: 0.000747\n",
      "iter: 11500 ---- time 1763.8\n",
      "loss: 0.000722\n",
      "iter: 11600 ---- time 1779.2\n",
      "loss: 0.000825\n",
      "iter: 11700 ---- time 1794.5\n",
      "loss: 0.000770\n",
      "iter: 11800 ---- time 1809.9\n",
      "loss: 0.000676\n",
      "iter: 11900 ---- time 1825.3\n",
      "loss: 0.000715\n",
      "iter: 12000 ---- time 1840.6\n",
      "loss: 0.000703\n",
      "iter: 12100 ---- time 1856.0\n",
      "loss: 0.000663\n",
      "iter: 12200 ---- time 1871.4\n",
      "loss: 0.000631\n",
      "iter: 12300 ---- time 1886.7\n",
      "loss: 0.000669\n",
      "iter: 12400 ---- time 1902.1\n",
      "loss: 0.000613\n",
      "iter: 12500 ---- time 1917.5\n",
      "loss: 0.000646\n",
      "iter: 12600 ---- time 1932.8\n",
      "loss: 0.000695\n",
      "iter: 12700 ---- time 1948.2\n",
      "loss: 0.000687\n",
      "iter: 12800 ---- time 1963.6\n",
      "loss: 0.000805\n",
      "iter: 12900 ---- time 1978.9\n",
      "loss: 0.000647\n",
      "iter: 13000 ---- time 1994.3\n",
      "loss: 0.000604\n",
      "iter: 13100 ---- time 2009.7\n",
      "loss: 0.000632\n",
      "iter: 13200 ---- time 2025.0\n",
      "loss: 0.000637\n",
      "iter: 13300 ---- time 2040.4\n",
      "loss: 0.000703\n",
      "iter: 13400 ---- time 2055.8\n",
      "loss: 0.000616\n",
      "iter: 13500 ---- time 2071.1\n",
      "loss: 0.000680\n",
      "iter: 13600 ---- time 2086.5\n",
      "loss: 0.000610\n",
      "iter: 13700 ---- time 2101.9\n",
      "loss: 0.000625\n",
      "iter: 13800 ---- time 2117.2\n",
      "loss: 0.000622\n",
      "iter: 13900 ---- time 2132.6\n",
      "loss: 0.000705\n",
      "iter: 14000 ---- time 2148.0\n",
      "loss: 0.000639\n",
      "iter: 14100 ---- time 2163.3\n",
      "loss: 0.000626\n",
      "iter: 14200 ---- time 2178.7\n",
      "loss: 0.000611\n",
      "iter: 14300 ---- time 2194.1\n",
      "loss: 0.000609\n",
      "iter: 14400 ---- time 2209.4\n",
      "loss: 0.000601\n",
      "iter: 14500 ---- time 2224.8\n",
      "loss: 0.000661\n",
      "iter: 14600 ---- time 2240.1\n",
      "loss: 0.000711\n",
      "iter: 14700 ---- time 2255.5\n",
      "loss: 0.000631\n",
      "iter: 14800 ---- time 2270.9\n",
      "loss: 0.000695\n",
      "iter: 14900 ---- time 2286.2\n",
      "loss: 0.000701\n",
      "iter: 15000 ---- time 2301.6\n",
      "loss: 0.000624\n",
      "iter: 15100 ---- time 2317.0\n",
      "loss: 0.000600\n",
      "iter: 15200 ---- time 2332.3\n",
      "loss: 0.000775\n",
      "iter: 15300 ---- time 2347.7\n",
      "loss: 0.000619\n",
      "iter: 15400 ---- time 2363.1\n",
      "loss: 0.000670\n",
      "iter: 15500 ---- time 2378.4\n",
      "loss: 0.000670\n",
      "iter: 15600 ---- time 2393.8\n",
      "loss: 0.000608\n",
      "iter: 15700 ---- time 2409.2\n",
      "loss: 0.000613\n",
      "iter: 15800 ---- time 2424.5\n",
      "loss: 0.000633\n",
      "iter: 15900 ---- time 2439.9\n",
      "loss: 0.000600\n",
      "iter: 16000 ---- time 2455.3\n",
      "loss: 0.000655\n",
      "iter: 16100 ---- time 2470.7\n",
      "loss: 0.000611\n",
      "iter: 16200 ---- time 2486.0\n",
      "loss: 0.000667\n",
      "iter: 16300 ---- time 2501.4\n",
      "loss: 0.000616\n",
      "iter: 16400 ---- time 2516.8\n",
      "loss: 0.000716\n",
      "iter: 16500 ---- time 2532.1\n",
      "loss: 0.000672\n",
      "iter: 16600 ---- time 2547.5\n",
      "loss: 0.000605\n",
      "iter: 16700 ---- time 2562.9\n",
      "loss: 0.000643\n",
      "iter: 16800 ---- time 2578.2\n",
      "loss: 0.000613\n",
      "iter: 16900 ---- time 2593.6\n",
      "loss: 0.000659\n",
      "iter: 17000 ---- time 2609.0\n",
      "loss: 0.000633\n",
      "iter: 17100 ---- time 2624.3\n",
      "loss: 0.000765\n",
      "iter: 17200 ---- time 2639.7\n",
      "loss: 0.000611\n",
      "iter: 17300 ---- time 2655.1\n",
      "loss: 0.000593\n",
      "iter: 17400 ---- time 2670.4\n",
      "loss: 0.000657\n",
      "iter: 17500 ---- time 2685.8\n",
      "loss: 0.000600\n",
      "iter: 17600 ---- time 2701.2\n",
      "loss: 0.000693\n",
      "iter: 17700 ---- time 2716.6\n",
      "loss: 0.000630\n",
      "iter: 17800 ---- time 2731.9\n",
      "loss: 0.000660\n",
      "iter: 17900 ---- time 2747.3\n",
      "loss: 0.000637\n",
      "iter: 18000 ---- time 2762.6\n",
      "loss: 0.000627\n",
      "iter: 18100 ---- time 2778.0\n",
      "loss: 0.000636\n",
      "iter: 18200 ---- time 2793.4\n",
      "loss: 0.000633\n",
      "iter: 18300 ---- time 2808.8\n",
      "loss: 0.000607\n",
      "iter: 18400 ---- time 2824.1\n",
      "loss: 0.000652\n",
      "iter: 18500 ---- time 2839.5\n",
      "loss: 0.000699\n",
      "iter: 18600 ---- time 2854.9\n",
      "loss: 0.000626\n",
      "iter: 18700 ---- time 2870.2\n",
      "loss: 0.000594\n",
      "iter: 18800 ---- time 2885.6\n",
      "loss: 0.000605\n",
      "iter: 18900 ---- time 2901.0\n",
      "loss: 0.000682\n",
      "iter: 19000 ---- time 2916.3\n",
      "loss: 0.000609\n",
      "iter: 19100 ---- time 2931.7\n",
      "loss: 0.000597\n",
      "iter: 19200 ---- time 2947.1\n",
      "loss: 0.000651\n",
      "iter: 19300 ---- time 2962.5\n",
      "loss: 0.000599\n",
      "iter: 19400 ---- time 2977.8\n",
      "loss: 0.000595\n",
      "iter: 19500 ---- time 2993.2\n",
      "loss: 0.000634\n",
      "iter: 19600 ---- time 3008.6\n",
      "loss: 0.000633\n",
      "iter: 19700 ---- time 3024.0\n",
      "loss: 0.000639\n",
      "iter: 19800 ---- time 3039.3\n",
      "loss: 0.000596\n",
      "iter: 19900 ---- time 3054.7\n",
      "loss: 0.000623\n",
      "iter: 20000 ---- time 3070.1\n",
      "loss: 0.000622\n",
      "iter: 20100 ---- time 3085.5\n",
      "loss: 0.000575\n",
      "iter: 20200 ---- time 3100.8\n",
      "loss: 0.000576\n",
      "iter: 20300 ---- time 3116.2\n",
      "loss: 0.000574\n",
      "iter: 20400 ---- time 3131.6\n",
      "loss: 0.000575\n",
      "iter: 20500 ---- time 3146.9\n",
      "loss: 0.000575\n",
      "iter: 20600 ---- time 3162.3\n",
      "loss: 0.000574\n",
      "iter: 20700 ---- time 3177.7\n",
      "loss: 0.000576\n",
      "iter: 20800 ---- time 3193.0\n",
      "loss: 0.000576\n",
      "iter: 20900 ---- time 3208.4\n",
      "loss: 0.000575\n",
      "iter: 21000 ---- time 3223.8\n",
      "loss: 0.000575\n",
      "iter: 21100 ---- time 3239.1\n",
      "loss: 0.000575\n",
      "iter: 21200 ---- time 3254.5\n",
      "loss: 0.000576\n",
      "iter: 21300 ---- time 3269.9\n",
      "loss: 0.000574\n",
      "iter: 21400 ---- time 3285.3\n",
      "loss: 0.000575\n",
      "iter: 21500 ---- time 3300.6\n",
      "loss: 0.000575\n",
      "iter: 21600 ---- time 3316.0\n",
      "loss: 0.000575\n",
      "iter: 21700 ---- time 3331.4\n",
      "loss: 0.000575\n",
      "iter: 21800 ---- time 3346.7\n",
      "loss: 0.000576\n",
      "iter: 21900 ---- time 3362.1\n",
      "loss: 0.000574\n",
      "iter: 22000 ---- time 3377.5\n",
      "loss: 0.000575\n",
      "iter: 22100 ---- time 3392.9\n",
      "loss: 0.000575\n",
      "iter: 22200 ---- time 3408.2\n",
      "loss: 0.000575\n",
      "iter: 22300 ---- time 3423.6\n",
      "loss: 0.000576\n",
      "iter: 22400 ---- time 3439.0\n",
      "loss: 0.000575\n",
      "iter: 22500 ---- time 3454.4\n",
      "loss: 0.000575\n",
      "iter: 22600 ---- time 3469.7\n",
      "loss: 0.000575\n",
      "iter: 22700 ---- time 3485.1\n",
      "loss: 0.000576\n",
      "iter: 22800 ---- time 3500.5\n",
      "loss: 0.000575\n",
      "iter: 22900 ---- time 3515.8\n",
      "loss: 0.000575\n",
      "iter: 23000 ---- time 3531.2\n",
      "loss: 0.000575\n",
      "iter: 23100 ---- time 3546.6\n",
      "loss: 0.000574\n",
      "iter: 23200 ---- time 3561.9\n",
      "loss: 0.000576\n",
      "iter: 23300 ---- time 3577.3\n",
      "loss: 0.000575\n",
      "iter: 23400 ---- time 3592.7\n",
      "loss: 0.000575\n",
      "iter: 23500 ---- time 3608.0\n",
      "loss: 0.000575\n",
      "iter: 23600 ---- time 3623.4\n",
      "loss: 0.000575\n",
      "iter: 23700 ---- time 3638.8\n",
      "loss: 0.000575\n",
      "iter: 23800 ---- time 3654.2\n",
      "loss: 0.000575\n",
      "iter: 23900 ---- time 3669.5\n",
      "loss: 0.000576\n",
      "iter: 24000 ---- time 3684.9\n",
      "loss: 0.000575\n",
      "iter: 24100 ---- time 3700.3\n",
      "loss: 0.000575\n",
      "iter: 24200 ---- time 3715.6\n",
      "loss: 0.000576\n",
      "iter: 24300 ---- time 3731.0\n",
      "loss: 0.000575\n",
      "iter: 24400 ---- time 3746.4\n",
      "loss: 0.000575\n",
      "iter: 24500 ---- time 3761.8\n",
      "loss: 0.000576\n",
      "iter: 24600 ---- time 3777.1\n",
      "loss: 0.000575\n",
      "iter: 24700 ---- time 3792.5\n",
      "loss: 0.000574\n",
      "iter: 24800 ---- time 3807.8\n",
      "loss: 0.000575\n",
      "iter: 24900 ---- time 3823.2\n",
      "loss: 0.000576\n",
      "iter: 25000 ---- time 3838.6\n",
      "loss: 0.000574\n",
      "iter: 25100 ---- time 3854.0\n",
      "loss: 0.000574\n",
      "iter: 25200 ---- time 3869.3\n",
      "loss: 0.000575\n",
      "iter: 25300 ---- time 3884.7\n",
      "loss: 0.000575\n",
      "iter: 25400 ---- time 3900.1\n",
      "loss: 0.000575\n",
      "iter: 25500 ---- time 3915.4\n",
      "loss: 0.000574\n",
      "iter: 25600 ---- time 3930.8\n",
      "loss: 0.000574\n",
      "iter: 25700 ---- time 3946.2\n",
      "loss: 0.000574\n",
      "iter: 25800 ---- time 3961.6\n",
      "loss: 0.000574\n",
      "iter: 25900 ---- time 3976.9\n",
      "loss: 0.000575\n",
      "iter: 26000 ---- time 3992.3\n",
      "loss: 0.000575\n",
      "iter: 26100 ---- time 4007.7\n",
      "loss: 0.000575\n",
      "iter: 26200 ---- time 4023.1\n",
      "loss: 0.000574\n",
      "iter: 26300 ---- time 4038.5\n",
      "loss: 0.000574\n",
      "iter: 26400 ---- time 4053.8\n",
      "loss: 0.000575\n",
      "iter: 26500 ---- time 4069.2\n",
      "loss: 0.000574\n",
      "iter: 26600 ---- time 4084.6\n",
      "loss: 0.000575\n",
      "iter: 26700 ---- time 4099.9\n",
      "loss: 0.000575\n",
      "iter: 26800 ---- time 4115.3\n",
      "loss: 0.000574\n",
      "iter: 26900 ---- time 4130.7\n",
      "loss: 0.000575\n",
      "iter: 27000 ---- time 4146.0\n",
      "loss: 0.000575\n",
      "iter: 27100 ---- time 4161.4\n",
      "loss: 0.000574\n",
      "iter: 27200 ---- time 4176.8\n",
      "loss: 0.000574\n",
      "iter: 27300 ---- time 4192.2\n",
      "loss: 0.000575\n",
      "iter: 27400 ---- time 4207.5\n",
      "loss: 0.000575\n",
      "iter: 27500 ---- time 4222.9\n",
      "loss: 0.000574\n",
      "iter: 27600 ---- time 4238.3\n",
      "loss: 0.000574\n",
      "iter: 27700 ---- time 4253.6\n",
      "loss: 0.000574\n",
      "iter: 27800 ---- time 4269.0\n",
      "loss: 0.000575\n",
      "iter: 27900 ---- time 4284.4\n",
      "loss: 0.000574\n",
      "iter: 28000 ---- time 4299.7\n",
      "loss: 0.000574\n",
      "iter: 28100 ---- time 4315.1\n",
      "loss: 0.000574\n",
      "iter: 28200 ---- time 4330.5\n",
      "loss: 0.000574\n",
      "iter: 28300 ---- time 4345.8\n",
      "loss: 0.000574\n",
      "iter: 28400 ---- time 4361.2\n",
      "loss: 0.000575\n",
      "iter: 28500 ---- time 4376.6\n",
      "loss: 0.000575\n",
      "iter: 28600 ---- time 4392.0\n",
      "loss: 0.000574\n",
      "iter: 28700 ---- time 4407.3\n",
      "loss: 0.000573\n",
      "iter: 28800 ---- time 4422.7\n",
      "loss: 0.000574\n",
      "iter: 28900 ---- time 4438.1\n",
      "loss: 0.000574\n",
      "iter: 29000 ---- time 4453.4\n",
      "loss: 0.000573\n",
      "iter: 29100 ---- time 4468.8\n",
      "loss: 0.000575\n",
      "iter: 29200 ---- time 4484.2\n",
      "loss: 0.000573\n",
      "iter: 29300 ---- time 4499.6\n",
      "loss: 0.000573\n",
      "iter: 29400 ---- time 4514.9\n",
      "loss: 0.000574\n",
      "iter: 29500 ---- time 4530.3\n",
      "loss: 0.000574\n",
      "iter: 29600 ---- time 4545.7\n",
      "loss: 0.000573\n",
      "iter: 29700 ---- time 4561.0\n",
      "loss: 0.000574\n",
      "iter: 29800 ---- time 4576.4\n",
      "loss: 0.000574\n",
      "iter: 29900 ---- time 4591.8\n",
      "loss: 0.000573\n",
      "iter: 30000 ---- time 4607.1\n",
      "loss: 0.000573\n",
      "iter: 30100 ---- time 4622.5\n",
      "loss: 0.000574\n",
      "iter: 30200 ---- time 4637.9\n",
      "loss: 0.000573\n",
      "iter: 30300 ---- time 4653.2\n",
      "loss: 0.000572\n",
      "iter: 30400 ---- time 4668.6\n",
      "loss: 0.000573\n",
      "iter: 30500 ---- time 4684.0\n",
      "loss: 0.000574\n",
      "iter: 30600 ---- time 4699.4\n",
      "loss: 0.000573\n",
      "iter: 30700 ---- time 4714.7\n",
      "loss: 0.000573\n",
      "iter: 30800 ---- time 4730.1\n",
      "loss: 0.000573\n",
      "iter: 30900 ---- time 4745.4\n",
      "loss: 0.000575\n",
      "iter: 31000 ---- time 4760.8\n",
      "loss: 0.000573\n",
      "iter: 31100 ---- time 4776.2\n",
      "loss: 0.000573\n",
      "iter: 31200 ---- time 4791.6\n",
      "loss: 0.000572\n",
      "iter: 31300 ---- time 4806.9\n",
      "loss: 0.000572\n",
      "iter: 31400 ---- time 4822.3\n",
      "loss: 0.000573\n",
      "iter: 31500 ---- time 4837.7\n",
      "loss: 0.000572\n",
      "iter: 31600 ---- time 4853.1\n",
      "loss: 0.000572\n",
      "iter: 31700 ---- time 4868.4\n",
      "loss: 0.000574\n",
      "iter: 31800 ---- time 4883.8\n",
      "loss: 0.000572\n",
      "iter: 31900 ---- time 4899.2\n",
      "loss: 0.000572\n",
      "iter: 32000 ---- time 4914.5\n",
      "loss: 0.000572\n",
      "iter: 32100 ---- time 4929.9\n",
      "loss: 0.000572\n",
      "iter: 32200 ---- time 4945.3\n",
      "loss: 0.000573\n",
      "iter: 32300 ---- time 4960.6\n",
      "loss: 0.000572\n",
      "iter: 32400 ---- time 4976.0\n",
      "loss: 0.000573\n",
      "iter: 32500 ---- time 4991.4\n",
      "loss: 0.000573\n",
      "iter: 32600 ---- time 5006.7\n",
      "loss: 0.000573\n",
      "iter: 32700 ---- time 5022.1\n",
      "loss: 0.000573\n",
      "iter: 32800 ---- time 5037.5\n",
      "loss: 0.000573\n",
      "iter: 32900 ---- time 5052.8\n",
      "loss: 0.000573\n",
      "iter: 33000 ---- time 5068.2\n",
      "loss: 0.000571\n",
      "iter: 33100 ---- time 5083.5\n",
      "loss: 0.000572\n",
      "iter: 33200 ---- time 5098.9\n",
      "loss: 0.000572\n",
      "iter: 33300 ---- time 5114.3\n",
      "loss: 0.000572\n",
      "iter: 33400 ---- time 5129.6\n",
      "loss: 0.000572\n",
      "iter: 33500 ---- time 5145.0\n",
      "loss: 0.000572\n",
      "iter: 33600 ---- time 5160.4\n",
      "loss: 0.000572\n",
      "iter: 33700 ---- time 5175.7\n",
      "loss: 0.000572\n",
      "iter: 33800 ---- time 5191.1\n",
      "loss: 0.000572\n",
      "iter: 33900 ---- time 5206.4\n",
      "loss: 0.000572\n",
      "iter: 34000 ---- time 5221.8\n",
      "loss: 0.000573\n",
      "iter: 34100 ---- time 5237.2\n",
      "loss: 0.000571\n",
      "iter: 34200 ---- time 5252.5\n",
      "loss: 0.000572\n",
      "iter: 34300 ---- time 5267.9\n",
      "loss: 0.000571\n",
      "iter: 34400 ---- time 5283.2\n",
      "loss: 0.000572\n",
      "iter: 34500 ---- time 5298.6\n",
      "loss: 0.000570\n",
      "iter: 34600 ---- time 5314.0\n",
      "loss: 0.000571\n",
      "iter: 34700 ---- time 5329.4\n",
      "loss: 0.000572\n",
      "iter: 34800 ---- time 5344.7\n",
      "loss: 0.000571\n",
      "iter: 34900 ---- time 5360.1\n",
      "loss: 0.000570\n",
      "iter: 35000 ---- time 5375.5\n",
      "loss: 0.000571\n",
      "iter: 35100 ---- time 5390.8\n",
      "loss: 0.000572\n",
      "iter: 35200 ---- time 5406.2\n",
      "loss: 0.000572\n",
      "iter: 35300 ---- time 5421.6\n",
      "loss: 0.000571\n",
      "iter: 35400 ---- time 5436.9\n",
      "loss: 0.000572\n",
      "iter: 35500 ---- time 5452.3\n",
      "loss: 0.000571\n",
      "iter: 35600 ---- time 5467.7\n",
      "loss: 0.000571\n",
      "iter: 35700 ---- time 5483.0\n",
      "loss: 0.000572\n",
      "iter: 35800 ---- time 5498.4\n",
      "loss: 0.000571\n",
      "iter: 35900 ---- time 5513.8\n",
      "loss: 0.000571\n",
      "iter: 36000 ---- time 5529.2\n",
      "loss: 0.000571\n",
      "iter: 36100 ---- time 5544.5\n",
      "loss: 0.000571\n",
      "iter: 36200 ---- time 5559.9\n",
      "loss: 0.000572\n",
      "iter: 36300 ---- time 5575.3\n",
      "loss: 0.000571\n",
      "iter: 36400 ---- time 5590.6\n",
      "loss: 0.000570\n",
      "iter: 36500 ---- time 5606.0\n",
      "loss: 0.000571\n",
      "iter: 36600 ---- time 5621.4\n",
      "loss: 0.000571\n",
      "iter: 36700 ---- time 5636.7\n",
      "loss: 0.000571\n",
      "iter: 36800 ---- time 5652.1\n",
      "loss: 0.000571\n",
      "iter: 36900 ---- time 5667.5\n",
      "loss: 0.000571\n",
      "iter: 37000 ---- time 5682.8\n",
      "loss: 0.000570\n",
      "iter: 37100 ---- time 5698.2\n",
      "loss: 0.000570\n",
      "iter: 37200 ---- time 5713.6\n",
      "loss: 0.000572\n",
      "iter: 37300 ---- time 5728.9\n",
      "loss: 0.000571\n",
      "iter: 37400 ---- time 5744.3\n",
      "loss: 0.000570\n",
      "iter: 37500 ---- time 5759.7\n",
      "loss: 0.000571\n",
      "iter: 37600 ---- time 5775.0\n",
      "loss: 0.000571\n",
      "iter: 37700 ---- time 5790.4\n",
      "loss: 0.000571\n",
      "iter: 37800 ---- time 5805.8\n",
      "loss: 0.000570\n",
      "iter: 37900 ---- time 5821.2\n",
      "loss: 0.000570\n",
      "iter: 38000 ---- time 5836.5\n",
      "loss: 0.000570\n",
      "iter: 38100 ---- time 5851.9\n",
      "loss: 0.000571\n",
      "iter: 38200 ---- time 5867.3\n",
      "loss: 0.000570\n",
      "iter: 38300 ---- time 5882.7\n",
      "loss: 0.000570\n",
      "iter: 38400 ---- time 5898.0\n",
      "loss: 0.000571\n",
      "iter: 38500 ---- time 5913.4\n",
      "loss: 0.000570\n",
      "iter: 38600 ---- time 5928.8\n",
      "loss: 0.000571\n",
      "iter: 38700 ---- time 5944.1\n",
      "loss: 0.000571\n",
      "iter: 38800 ---- time 5959.5\n",
      "loss: 0.000570\n",
      "iter: 38900 ---- time 5974.9\n",
      "loss: 0.000570\n",
      "iter: 39000 ---- time 5990.2\n",
      "loss: 0.000570\n",
      "iter: 39100 ---- time 6005.6\n",
      "loss: 0.000570\n",
      "iter: 39200 ---- time 6021.0\n",
      "loss: 0.000570\n",
      "iter: 39300 ---- time 6036.4\n",
      "loss: 0.000570\n",
      "iter: 39400 ---- time 6051.7\n",
      "loss: 0.000570\n",
      "iter: 39500 ---- time 6067.1\n",
      "loss: 0.000569\n",
      "iter: 39600 ---- time 6082.4\n",
      "loss: 0.000569\n",
      "iter: 39700 ---- time 6097.8\n",
      "loss: 0.000570\n",
      "iter: 39800 ---- time 6113.2\n",
      "loss: 0.000570\n",
      "iter: 39900 ---- time 6128.5\n",
      "loss: 0.000571\n",
      "iter: 39999 ---- time 6143.8\n",
      "loss: 0.000569\n"
     ]
    }
   ],
   "source": [
    "from gae2_trainer import dae_trainer_batchall\n",
    "\n",
    "# Main training run: max_iter_num iterations, progress printed every 100 iterations.\n",
    "model_wts = dae_trainer_batchall(\n",
    "    trainloader, dae, optimizer, scheduler, max_iter_num,\n",
    "    use_gpu=useGPU,\n",
    "    saveModel=True,\n",
    "    printEpochPeriod=100,\n",
    "    useFixedNoise=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check Score Estimation Error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Chunk boundaries [0, Ntrunc, 2*Ntrunc, ..., N] over the N training samples, later passed\n",
    "# to the truncated score-error estimator. N is appended exactly once, so no empty\n",
    "# [N, N) chunk is produced when N is a multiple of Ntrunc.\n",
    "Ntrunc = 10000\n",
    "input = traindataset.posAndCov.clone()  # NOTE: shadows the builtin input()\n",
    "n_samples = input.shape[0]\n",
    "idxSet = list(range(0, n_samples, Ntrunc)) + [n_samples]"
   ]
  },
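  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Post-training iterations for DAE\n",
    "\n",
    "Continue training for 10 short rounds of 100 iterations each at learning rate `1e-6`. After every round, evaluate the score-estimation error and keep the weights with the lowest error."
   ]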
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter: 0 ---- time 0.0\n",
      "loss: 0.000574\n",
      "iter: 99 ---- time 13.8\n",
      "loss: 0.000570\n",
      "tensor(-61018.0430, device='cuda:1')\n",
      "iter: 0 ---- time 0.0\n",
      "loss: 0.000570\n",
      "iter: 99 ---- time 13.8\n",
      "loss: 0.000571\n",
      "tensor(-61004.9180, device='cuda:1')\n",
      "iter: 0 ---- time 0.0\n",
      "loss: 0.000569\n",
      "iter: 99 ---- time 13.8\n",
      "loss: 0.000570\n",
      "tensor(-60985.1602, device='cuda:1')\n",
      "iter: 0 ---- time 0.0\n",
      "loss: 0.000571\n",
      "iter: 99 ---- time 13.8\n",
      "loss: 0.000570\n",
      "tensor(-61037.7031, device='cuda:1')\n",
      "iter: 0 ---- time 0.0\n",
      "loss: 0.000571\n",
      "iter: 99 ---- time 13.8\n",
      "loss: 0.000570\n",
      "tensor(-61083.3398, device='cuda:1')\n",
      "iter: 0 ---- time 0.0\n",
      "loss: 0.000569\n",
      "iter: 99 ---- time 13.8\n",
      "loss: 0.000569\n",
      "tensor(-61055.6719, device='cuda:1')\n",
      "iter: 0 ---- time 0.0\n",
      "loss: 0.000570\n",
      "iter: 99 ---- time 13.8\n",
      "loss: 0.000570\n",
      "tensor(-61043.8906, device='cuda:1')\n",
      "iter: 0 ---- time 0.0\n",
      "loss: 0.000570\n",
      "iter: 99 ---- time 13.8\n",
      "loss: 0.000570\n",
      "tensor(-61153.5234, device='cuda:1')\n",
      "iter: 0 ---- time 0.0\n",
      "loss: 0.000569\n",
      "iter: 99 ---- time 13.8\n",
      "loss: 0.000570\n",
      "tensor(-61052.6406, device='cuda:1')\n",
      "iter: 0 ---- time 0.0\n",
      "loss: 0.000570\n",
      "iter: 99 ---- time 13.8\n",
      "loss: 0.000569\n",
      "tensor(-61145.1016, device='cuda:1')\n"
     ]
    }
   ],
   "source": [
    "import copy\n",
    "\n",
    "from gae2_score_estimation import DTI_dae_estimate_score_error_truncated\n",
    "\n",
    "# Fine-tune the DAE in rounds of max_iter_num iterations and keep the weights of the\n",
    "# round with the lowest estimated score error.\n",
    "# NOTE(review): `scheduler` comes from an earlier cell and is presumably bound to a\n",
    "# previous optimizer; if so it will not adjust this new optimizer's lr -- confirm.\n",
    "optimizer = torch.optim.Adam(dae.parameters(), lr=1e-6, weight_decay=1e-6)\n",
    "max_iter_num = 100\n",
    "n_rounds = 10\n",
    "score_error_set = []\n",
    "model_set = []\n",
    "best_model, min_val = None, None\n",
    "for i in range(n_rounds):\n",
    "    model_wts = dae_trainer_batchall(trainloader, dae, optimizer, scheduler, max_iter_num, use_gpu=useGPU,\n",
    "                                     saveModel=True, printEpochPeriod=100, useFixedNoise=False)\n",
    "    score_error = DTI_dae_estimate_score_error_truncated(traindataset.posAndCov, idxSet, dae,\n",
    "                                                         printError=False)\n",
    "    score_error_set.append(score_error)\n",
    "    if min_val is None or score_error < min_val:\n",
    "        # deep-copy so later training rounds cannot mutate the snapshot, even if the\n",
    "        # trainer returns tensors that alias dae's live parameters\n",
    "        best_model = copy.deepcopy(model_wts)\n",
    "        min_val = score_error\n",
    "    print(score_error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# restore the snapshot with the lowest score error from the post-iteration loop\n",
    "dae.load_state_dict(state_dict=best_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data to Filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same file list as training, but with shuffling and randPos disabled and no\n",
    "# per-subject match limit.\n",
    "shuffle = False\n",
    "randPos = False\n",
    "roi_range = [0, 143, 0, 181, 0, 13]\n",
    "data_prefix_test = '_fix'\n",
    "\n",
    "testdataset = DTI_DataSet(\n",
    "    filelist,\n",
    "    match_num_per_subject=None,\n",
    "    shuffle=shuffle,\n",
    "    randPos=randPos,\n",
    "    dx=dx,\n",
    "    roi_range=roi_range,\n",
    "    data_prefix=data_prefix_test,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mean Shift Filtering for an Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f72772ac828>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Slice segment number `input_num` out of testdataset.posAndCov. split_nums holds the\n",
    "# consecutive segment lengths, so the start offset is the prefix sum before input_num.\n",
    "# Indexing directly raises IndexError for an out-of-range input_num instead of silently\n",
    "# keeping a stale `input` from an earlier cell.\n",
    "input_num = 0\n",
    "start_num = sum(testdataset.split_nums[:input_num])\n",
    "end_num = start_num + testdataset.split_nums[input_num]\n",
    "input = testdataset.posAndCov[start_num:end_num].cuda()\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from DTI_meanShift2 import DTI_meanShift_vectordae\n",
    "\n",
    "# Mean-shift hyper-parameters; pos/cov thresholds are expressed as fractions of step_size.\n",
    "step_size = 0.005\n",
    "pos_threshold, cov_threshold = 0.002 * step_size, 0.0001 * step_size\n",
    "pos_group_threshold, cov_group_threshold = 10, 0.5\n",
    "eliminate_threshold = 10\n",
    "pos_metric_choice = None\n",
    "\n",
    "# column-wise (dim=0) min / max of the selected image, passed as bounds to the filter\n",
    "lowerbound, _ = torch.min(input, dim=0)\n",
    "upperbound, _ = torch.max(input, dim=0)\n",
    "\n",
    "ms_dae = DTI_meanShift_vectordae(\n",
    "    dae, step_size,\n",
    "    pos_threshold=pos_threshold, cov_threshold=cov_threshold,\n",
    "    pos_group_threshold=pos_group_threshold, cov_group_threshold=cov_group_threshold,\n",
    "    eliminate_threshold=eliminate_threshold,\n",
    "    lowerBound=lowerbound, upperBound=upperbound, pos_metric=pos_metric_choice,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([274595, 9])\n",
      "iter: 1, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.4257e-06, device='cuda:1')\n",
      "iter: 2, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.4158e-06, device='cuda:1')\n",
      "iter: 3, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.4048e-06, device='cuda:1')\n",
      "iter: 4, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.3948e-06, device='cuda:1')\n",
      "iter: 5, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.3842e-06, device='cuda:1')\n",
      "iter: 6, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.3744e-06, device='cuda:1')\n",
      "iter: 7, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.3636e-06, device='cuda:1')\n",
      "iter: 8, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.3539e-06, device='cuda:1')\n",
      "iter: 9, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.3440e-06, device='cuda:1')\n",
      "iter: 10, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.3336e-06, device='cuda:1')\n",
      "iter: 11, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.3237e-06, device='cuda:1')\n",
      "iter: 12, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.3135e-06, device='cuda:1')\n",
      "iter: 13, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.3032e-06, device='cuda:1')\n",
      "iter: 14, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.2930e-06, device='cuda:1')\n",
      "iter: 15, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.2834e-06, device='cuda:1')\n",
      "iter: 16, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.2738e-06, device='cuda:1')\n",
      "iter: 17, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.2638e-06, device='cuda:1')\n",
      "iter: 18, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.2541e-06, device='cuda:1')\n",
      "iter: 19, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.2451e-06, device='cuda:1')\n",
      "iter: 20, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.2345e-06, device='cuda:1')\n",
      "iter: 21, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.2249e-06, device='cuda:1')\n",
      "iter: 22, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.2163e-06, device='cuda:1')\n",
      "iter: 23, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.2060e-06, device='cuda:1')\n",
      "iter: 24, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1966e-06, device='cuda:1')\n",
      "iter: 25, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1878e-06, device='cuda:1')\n",
      "iter: 26, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1785e-06, device='cuda:1')\n",
      "iter: 27, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1689e-06, device='cuda:1')\n",
      "iter: 28, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1591e-06, device='cuda:1')\n",
      "iter: 29, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1549e-06, device='cuda:1')\n",
      "iter: 30, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1457e-06, device='cuda:1')\n",
      "iter: 31, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1371e-06, device='cuda:1')\n",
      "iter: 32, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1277e-06, device='cuda:1')\n",
      "iter: 33, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1189e-06, device='cuda:1')\n",
      "iter: 34, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1097e-06, device='cuda:1')\n",
      "iter: 35, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.1008e-06, device='cuda:1')\n",
      "iter: 36, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.0917e-06, device='cuda:1')\n",
      "iter: 37, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.0824e-06, device='cuda:1')\n",
      "iter: 38, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.0744e-06, device='cuda:1')\n",
      "iter: 39, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.0659e-06, device='cuda:1')\n",
      "iter: 40, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.0568e-06, device='cuda:1')\n",
      "iter: 41, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.0481e-06, device='cuda:1')\n",
      "iter: 42, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.0394e-06, device='cuda:1')\n",
      "iter: 43, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.0311e-06, device='cuda:1')\n",
      "iter: 44, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.0223e-06, device='cuda:1')\n",
      "iter: 45, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.0135e-06, device='cuda:1')\n",
      "iter: 46, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(2.0056e-06, device='cuda:1')\n",
      "iter: 47, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9973e-06, device='cuda:1')\n",
      "iter: 48, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9881e-06, device='cuda:1')\n",
      "iter: 49, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9807e-06, device='cuda:1')\n",
      "iter: 50, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9718e-06, device='cuda:1')\n",
      "iter: 51, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9632e-06, device='cuda:1')\n",
      "iter: 52, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9562e-06, device='cuda:1')\n",
      "iter: 53, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9467e-06, device='cuda:1')\n",
      "iter: 54, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9399e-06, device='cuda:1')\n",
      "iter: 55, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9313e-06, device='cuda:1')\n",
      "iter: 56, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9231e-06, device='cuda:1')\n",
      "iter: 57, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9146e-06, device='cuda:1')\n",
      "iter: 58, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.9074e-06, device='cuda:1')\n",
      "iter: 59, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8994e-06, device='cuda:1')\n",
      "iter: 60, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8913e-06, device='cuda:1')\n",
      "iter: 61, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8838e-06, device='cuda:1')\n",
      "iter: 62, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8755e-06, device='cuda:1')\n",
      "iter: 63, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8677e-06, device='cuda:1')\n",
      "iter: 64, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8597e-06, device='cuda:1')\n",
      "iter: 65, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8524e-06, device='cuda:1')\n",
      "iter: 66, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8441e-06, device='cuda:1')\n",
      "iter: 67, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8379e-06, device='cuda:1')\n",
      "iter: 68, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8296e-06, device='cuda:1')\n",
      "iter: 69, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8219e-06, device='cuda:1')\n",
      "iter: 70, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8138e-06, device='cuda:1')\n",
      "iter: 71, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.8071e-06, device='cuda:1')\n",
      "iter: 72, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7998e-06, device='cuda:1')\n",
      "iter: 73, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7917e-06, device='cuda:1')\n",
      "iter: 74, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7845e-06, device='cuda:1')\n",
      "iter: 75, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7780e-06, device='cuda:1')\n",
      "iter: 76, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7701e-06, device='cuda:1')\n",
      "iter: 77, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7631e-06, device='cuda:1')\n",
      "iter: 78, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7562e-06, device='cuda:1')\n",
      "iter: 79, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7486e-06, device='cuda:1')\n",
      "iter: 80, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7419e-06, device='cuda:1')\n",
      "iter: 81, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7404e-06, device='cuda:1')\n",
      "iter: 82, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7335e-06, device='cuda:1')\n",
      "iter: 83, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7270e-06, device='cuda:1')\n",
      "iter: 84, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7192e-06, device='cuda:1')\n",
      "iter: 85, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7136e-06, device='cuda:1')\n",
      "iter: 86, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.7056e-06, device='cuda:1')\n",
      "iter: 87, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6985e-06, device='cuda:1')\n",
      "iter: 88, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6929e-06, device='cuda:1')\n",
      "iter: 89, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6863e-06, device='cuda:1')\n",
      "iter: 90, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6791e-06, device='cuda:1')\n",
      "iter: 91, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6729e-06, device='cuda:1')\n",
      "iter: 92, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6652e-06, device='cuda:1')\n",
      "iter: 93, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6591e-06, device='cuda:1')\n",
      "iter: 94, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6521e-06, device='cuda:1')\n",
      "iter: 95, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6458e-06, device='cuda:1')\n",
      "iter: 96, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6394e-06, device='cuda:1')\n",
      "iter: 97, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6327e-06, device='cuda:1')\n",
      "iter: 98, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6260e-06, device='cuda:1')\n",
      "iter: 99, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6197e-06, device='cuda:1')\n",
      "iter: 100, num points to shift: 274595\n",
      "tensor(0.0002, device='cuda:1')\n",
      "tensor(1.6137e-06, device='cuda:1')\n",
      "mean shift terminated !!! elapsed time: 3124.2\n"
     ]
    }
   ],
   "source": [
    "max_iter = 100\n",
    "test_num = 1\n",
    "le_errorsSet, le_wt_errorsSet, ai_errorsSet, ai_wt_errorsSet = [], [], [], []\n",
    "for i in range(test_num):\n",
    "    # hand run_meanShift fresh clones (presumably in case it shifts points in place)\n",
    "    noised_input = input.clone()\n",
    "    print(input.shape)\n",
    "    filteringResults = ms_dae.run_meanShift(\n",
    "        noised_input, max_iter,\n",
    "        save_prefix=file_prefix, save_iter=None,\n",
    "        cleanInput=input.clone(), error_weight=None,\n",
    "    )\n",
    "    # flatten each error log into a single row for this run\n",
    "    le_errorsSet.append(filteringResults.le_errors.view(1, -1))\n",
    "    le_wt_errorsSet.append(filteringResults.le_wt_errors.view(1, -1))\n",
    "    ai_errorsSet.append(filteringResults.ai_errors.view(1, -1))\n",
    "    ai_wt_errorsSet.append(filteringResults.ai_wt_errors.view(1, -1))\n",
    "# stack the runs row-wise: one row per test run\n",
    "le_errorsSet, le_wt_errorsSet, ai_errorsSet, ai_wt_errorsSet = [\n",
    "    torch.cat(rows, 0) for rows in (le_errorsSet, le_wt_errorsSet, ai_errorsSet, ai_wt_errorsSet)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"
     ]
    }
   ],
   "source": [
    "# dim-0 size of every tensor logged in ms_dae.track_log_bad_idx1; a comprehension keeps\n",
    "# the notebook-global `i` from being overwritten with a tensor\n",
    "track_log_bad_idx1_shape = [bad_idx.shape[0] for bad_idx in ms_dae.track_log_bad_idx1]\n",
    "print(track_log_bad_idx1_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"
     ]
    }
   ],
   "source": [
    "# dim-0 size of every tensor logged in ms_dae.track_log_bad_idx2; a comprehension keeps\n",
    "# the notebook-global `i` from being overwritten with a tensor\n",
    "track_log_bad_idx2_shape = [bad_idx.shape[0] for bad_idx in ms_dae.track_log_bad_idx2]\n",
    "print(track_log_bad_idx2_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f718c067da0>]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAD4CAYAAADo30HgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAnrUlEQVR4nO3dd3RVVfrG8e+b0ELv0gUVFLCgRrCLCgI2rIAKgmCbUUdHR0d0VMYyio69oyCgKCBYsINgV0pQkA6hCUgnJLT09/dHjvMLMcgVkpzcm+ezliv37FPuu5chzz1nn3u2uTsiIiK/iQu7ABERKV0UDCIishsFg4iI7EbBICIiu1EwiIjIbsqFXUBRqFu3rjdv3jzsMkREosrMmTM3uXu9gu0xEQzNmzcnKSkp7DJERKKKma0srF2XkkREZDcKBhER2Y2CQUREdqNgEBGR3SgYRERkNwoGERHZjYJBRER2o2AQEYlG6WnkfnQ7pKcW+aEVDCIi0WZTMrte7IjPeJVffppc5IdXMIiIRJPFE8l86TR2bd3A3yv+m10tOhf5W8TEIzFERGKeO9lfP07cFw+yOPdAhjZ5kPt7d6Fm5QpF/lYKBhGR0i5zB+njrqfS4gm8n3Miycc/zH+7HUV8nBXL2ykYRERKs5QV7BzZk4opi3ks9wpaX3I3tx3VuFjfUsEgIlJaLf2CjNF9ycrM5p5K/+LqftfQumH1Yn9bBYOISGnjTva3TxM3+d+syG3EkEYPcM+V5xbLeEJhFAwiIqVJ5k52jf8rCYve5eOc9iw8fjCPdju62MYTCqNgEBEpLVJWsuP1XiRsWcBTub1odel93HpkoxIvQ8EgIlIK+LIvyXirLzmZmdyV8C/6X3UdrQ6oFkotCgYRkTC5k/XtM8RPHsSq3IYMbfoUA3ufS42E8qGVpGAQEQlL5g52jvsrlRe/xyc5x5F84mP8p0s74kpwPKEwCgYRkTBsWc6Okb1I2LqIp/wyDu81iJvaNgi7KkDBICJS4jx5Mhmj+5Gdlc2/qtzLNf2vo0XdKmGX9T8KBhGRkuJO5ldPUO7LB1me25jRBz3M3Zd1o0rF0vWnuHRVIyISqzK2sX3stVRd+jEf5BzPxjMeZ1DHtpiFO55QGAWDiEhx27SE7SN6UiltOY/H9eX4K+/hvJb1wq5qjyKaj8HMuprZIjNLNrM7C1lf0czGBOunmVnzfOsGBu2LzKxL0NbUzL4ws/lmNs/Mbs63/SAzW2Nms4L/zi6CfoqIhCJnwYdkvHgaGWkbGFTjQXrdPJiTSnEoQARnDGYWDzwPdAZWAzPMbIK7z8+32QAgxd0PMbNewGCgp5m1AXoBbYFGwOdm1grIBm5z9x/NrBow08wm5Tvmk+7+36LqpIhIicvNYdekB0n44Qnm5h7Ex60f5V+XnEGl8vFhV7ZXkZwxtAeS3X2Zu2cCo4HuBbbpDowIXo8DzrS8C2fdgdHunuHuy4FkoL27r3X3HwHcfRuwACje58iKiJSUnVvYNuxiEn54grdzO7Ko21gGXtY5KkIBIguGxsCqfMur+f0f8f9t4+7ZQCpQJ5J9g8tORwPT8jXfaGY/m9kwM6tVWFFmdq2ZJZlZ0saNGyPohohICVj7M9ufPZmKq77m0XLX0erq4fQ4oWXYVf0poc75bGZVgfHALe6eFjS/CBwMtAPWAo8Xtq+7D3H3RHdPrFevdF+vE5GyIfPHt8gacibbdu7iofpPMOCWBziqWaGfbUu1SO5KWgM0zbfcJGgrbJvVZlYOqAFs/qN9zaw8eaEwyt3f+W0Dd1//22szewX4MNLOiIiEIjuT7R/8k6qzhzE1tzUzj3uce885sUQflV2UIjljmAG0NLMWZlaBvMHkCQW2mQD0DV5fAkxxdw/aewV3LbUAWgLTg/GHocACd38i/4HMrGG+xQuBuX+2UyIiJSZtLakvdaHq7GGM8HPZduk4bjjvpKgNBYjgjMHds83s
RuAzIB4Y5u7zzOx+IMndJ5D3R/51M0sGtpAXHgTbjQXmk3cn0g3unmNmJwN9gDlmNit4q7vc/WPgUTNrBziwAriuyHorIlKEcpd/x643+1Auczv/qXIHl/W/pVQ92mJfWd4H++iWmJjoSUlJYZchImWFO7u+eY4KU+5lZW593j74YW667HwqV4iu7wyb2Ux3TyzYHl29EBEJW8Z2to65nprLPmBibiKbOz3FHaccXiofbbGvFAwiIpHatITU4T2ptm0Zz5frzQlXPcBZB9YOu6oip2AQEYlA5pz3yH33L2TnxPNovUe4tt9V1KlaMeyyioWCQUTkj+Rkk/bxvVSf+Tyzcg9mWuJT3HHuKVF919HeKBhERPZk+wZSRvam1oZpjKEzdS99kuuOaLr3/aKcgkFEpBA5K35g16jeVMpM479VbuXSAf/gwDrRfytqJBQMIiL5ubP9q2ep9OUgNubW5f1WL3Njz+5R8wC8oqBgEBH5TcY2trx1HbVXfMTnuYls7/YMt5zQNuyqSpyCQUQE8A0LSB1+GTV2rOClCldyar8HaNO4ZthlhULBICJl3s4fRxP3wc1k5VbgqUaPcU3fflSvVD7sskKjYBCRsis7g83v3E6d+SNIyj2Uxac8w62dOsTUt5j3hYJBRMok3/oLW4ZfTp2tcxgVdx6HXvk4lx90QNhllQoKBhEpczIWTiT77aupkJ3JU3XvpfdVN1I3Rr/FvC8UDCJSduTmsOXjB6iZ9BTLc5vwQ+KT3HRup5j+FvO+UDCISNmwfSMbR/Sh3sYfmGAdqdXrWa5q0yzsqkolBYOIxLzM5d+R/uaVVM9M5YUat3Bh/3/SsGblsMsqtRQMIhK73En5/Emqf/cAW3Lr8fkRQ7nmovMoHx/JrMZll4JBRGLTrq2sf+NqDlgzic9pT9xFL3B1u5ZhVxUVFAwiEnOyVv3EtjeuoHb6Ol6tcjVdBtxP0zLyALyioGAQkdjhTsrXL1Hli3tI92q80+p5+vTsQcVyZecBeEVBwSAisSFjO2tHXU/DXz7gG2/HrnNf5Orj2oRdVVRSMIhI1MtaO5fUEZdTf9cvDE+4ko4D/kPzetXCLitqKRhEJKqlfD+cyhPvINcTGH7I01zRq3eZmjuhOCgYRCQ6Ze7k1zdvoNGKd5jmbUk9+0UGdDgq7KpigoJBRKJO1roFbB1xOQ12LmdUwmWc2P9ROtSvHnZZMUPBICJRZcv3I6g88XbMKzLykCfo1auvLh0VsYi+/mdmXc1skZklm9mdhayvaGZjgvXTzKx5vnUDg/ZFZtYlaGtqZl+Y2Xwzm2dmN+fbvraZTTKzJcHPWkXQTxGJdlm7WDNiALUn/o05fhA/nf0B/fr0VygUg70Gg5nFA88D3YA2wGVmVvAesAFAirsfAjwJDA72bQP0AtoCXYEXguNlA7e5exvgeOCGfMe8E5js7i2BycGyiJRhmesWsuGJE2m8fBxjEnpQ/8bP6NyhXdhlxaxIzhjaA8nuvszdM4HRQPcC23QHRgSvxwFnWt4USN2B0e6e4e7LgWSgvbuvdfcfAdx9G7AAaFzIsUYAF+xTz0QkJmz+dji5L51K/M6NvH7w41xw20scWK9G2GXFtEjGGBoDq/ItrwY67Gkbd882s1SgTtA+tcC+jfPvGFx2OhqYFjQd4O5rg9frAE2pJFIWZe5k1agbaLryHZJozbZzX6bPcbrrqCSEOvhsZlWB8cAt7p5WcL27u5n5Hva9FrgWoFkzPVNdJJZk/DqPrSMvp/GulYyt0osT+j9GYl3ddVRSIrmUtAZomm+5SdBW6DZmVg6oAWz+o33NrDx5oTDK3d/Jt816M2sYbNMQ2FBYUe4+xN0T3T2xXr16EXRDREo9dzZ8PRQfcjpxu7YwpvVTXHjbizRVKJSoSIJhBtDSzFqYWQXyBpMnFNhmAtA3eH0JMMXdPWjvFdy11AJoCUwPxh+GAgvc/Yk/OFZf4P0/2ykRiUIZ21k59ErqT7mV
ObRk8YWfcFmvfpo7IQR7vZQUjBncCHwGxAPD3H2emd0PJLn7BPL+yL9uZsnAFvLCg2C7scB88u5EusHdc8zsZKAPMMfMZgVvdZe7fww8Aow1swHASqBHEfZXREqhnb/8xPY3+tAkYzVjq/fh1P6DaVBLj8kOi+V9sI9uiYmJnpSUFHYZIvJnufPr589T97tBbPGqfHfkI1xwYU/i4yzsysoEM5vp7okF2/XNZxEJhe9KYeVrV9N8w+d8b8dQscfLXNymVdhlCQoGEQlB2pIfyBzTj8ZZGxlb+xo6DXiQ2lUrhV2WBBQMIlJycnNZ+eFgGv/4GGlei+/bv8alZ59P3v0oUlooGESkRGSnrWfVsL602PoDX8WfQL0rhnD+QfoOUmmkYBCRYrfp54nEv3ctjXK283bDW+nW7y6qViofdlmyBwoGESk+OdksHXc3LRa8zHJvyI+nD+XSjmeGXZXshYJBRIpF+qYVrB/Wm4N3zmFSxc606vciZzbUUwqigYJBRIrcmu9HU33irdT2XN49+N+cc/nfqFBO32COFgoGESkynrmDJSNvotXq8czjELaf/zIXHvu7709JKadgEJEikbpiNttHXUmrrBV8VL0HHQY8Qdsa1cIuS/aBgkFE9o87Sz95hsbTHyTTE/j0mBfodt7lxOmxFlFLwSAi+yxr2yaWD+tPq5SvmB5/NNV6vUrXloeEXZbsJwWDiOyTdbMmUW7CdTTP2cqHjW7ijH73UrlihbDLkiKgYBCRP8WzM1k4+m4OXfIKv1gDFncay7mndAq7LClCCgYRiVja2mQ2j+hD6/T5fFHlLFpf9SIn1qsbdllSxBQMIhKRpZOH0eCbu6jr8Fmbh+l06V80b0KMUjCIyB/K2rmVxcOup+2mT/g5rjXlLnmFLm2OCLssKUYKBhHZo7VzvsLevYZDczYy8YD+nHTVw1RJ0LwJsU7BICK/4znZzBtzH4cteoH1Vofpp7/BWR3PCbssKSEKBhHZTeqvS9kwsi+Hp8/hm8qnc0i/lznxgAPCLktKkIJBRP5n4aRhNP7ubhq683nrBzi9x00aYC6DFAwiQvq2FBa/dj1HbvmUuXGHUf7SV+nUWgPMZZWCQaSM++WnKVT44Dra5GxicsOrObHff0ioVDHssiRECgaRMio3O4vZo+7myGVDWGv1md15NGee3CXssqQUUDCIlEEbVy4gdVQ/js5cyHdVO3NY/5c4to6+wSx5FAwiZYk7syc8Q8ufHqK8l+Pbox/jpO7XYKYBZvl/CgaRMmLblnUsHzaAo7Z/y+zy7ah1xauc3Lxl2GVJKRTRJKxm1tXMFplZspndWcj6imY2Jlg/zcya51s3MGhfZGZd8rUPM7MNZja3wLEGmdkaM5sV/Hf2fvRPRICF37xDxjPHc+i2qXzV4hba/nMyzRQKsgd7PWMws3jgeaAzsBqYYWYT3H1+vs0GACnufoiZ9QIGAz3NrA3QC2gLNAI+N7NW7p4DDAeeA0YW8rZPuvt/96NfIgJk7NrGnNduIXHDOJZbMzZ1H8VpR58UdllSykVyxtAeSHb3Ze6eCYwGuhfYpjswIng9DjjT8i5adgdGu3uGuy8HkoPj4e5fA1uKoA8iUojlP3/L+seOJ3HDOL6reyn1b/uewxQKEoFIgqExsCrf8uqgrdBt3D0bSAXqRLhvYW40s5+Dy021CtvAzK41syQzS9q4cWMEhxQpG3Kys5g2YiBNxp9PxdydzOo4nJNufJUqVauFXZpEiYjGGErYi8DBQDtgLfB4YRu5+xB3T3T3xHr16pVgeSKl16/L5rNk8Kl0WP4Cs6udSoUbp9Ku44VhlyVRJpK7ktYATfMtNwnaCttmtZmVA2oAmyPcdzfuvv6312b2CvBhBDWKlGmem0vSu0/T5udHqEoc048dzHHnXafbUGWfRHLGMANoaWYtzKwCeYPJEwpsMwHoG7y+BJji7h609wruWmoBtASm/9GbmVnDfIsXAnP3tK2IwMZ1
vzD7sbM5bs4gVlQ8lJ0Dvqb9+dcrFGSf7fWMwd2zzexG4DMgHhjm7vPM7H4gyd0nAEOB180smbwB5V7BvvPMbCwwH8gGbgjuSMLM3gI6AnXNbDVwn7sPBR41s3aAAyuA64qwvyIxZeanIzlo6l209nSmHvoP2ve8i7j4+LDLkihneR/so1tiYqInJSWFXYZIiUlN2cyi1/5C+7TPSI4/mPKXvsqBhx0TdlkSZcxsprsnFmzXN59Foszsbz6g/uRbONY3M63ZAI7p8x/KV9B0m1J0FAwiUWLH9m3MGn4rJ20ay+q4hiw/Zzwdjj0z7LIkBikYRKLAvBlfUvXjGzjJVzOz/sW07fsUlapUD7ssiVEKBpFSLD09nekj7+LENa+xJa4WCzuN4NiTLwi7LIlxCgaRUmrhz9Ox9/7CqbnJzKrdhZZ9X6B+Tc2ZIMVPwSBSymRkZvLDqAc4YcWL7LJKzD/5Odp16hN2WVKGKBhESpElC2aTOe46OuYsYG71k2nWdwht6kbyeDGRoqNgECkFsrKz+fbNwXRY+jQ5Fs/c9o9yeLdrQd9elhAoGERCtnTJAraPuZbTs39mQdXjaHTlqxx+QPOwy5IyTMEgEpLs7By+Hvsk7Rf9lwbmzD3m3xx+3s06S5DQKRhEQrA0eRFbx/yFM7JmsrhyO+pd8QqHN2kVdlkigIJBpERlZ+fw5dvP0H7hozSyHOYd9S/adr8N4krj1ChSVikYREpI8tIlbBn9VzplTSe58pHUueIV2jY5LOyyRH5HwSBSzLKyc/ji7edov3AwTSyL+UfeRZsLbtdZgpRaCgaRYrR4yWJSxt7AWVnTWZrQltqXv0qbZm3CLkvkDykYRIpBZlYOk8c+w4mLH6OZZbHgyIG0vuB2iNMkOlL6KRhEitiCRYtIe/uvdMtOYlnlI6hzxSu0btI67LJEIqZgECki6ZnZTB79FKcsfZwWls3io++m1Xn/0FiCRB0Fg0gR+HnuHNLfvYlzcn5iWZUjqdf7VVo1OjTsskT2iYJBZD/sSM9kyqjH6PjLs8SbsyTxPlqefYvOEiSqKRhE9tGMH38k7sObOC93LkurJ9Kg9yu0POCgsMsS2W8KBpE/KXVHOl+/8RBn/voyuRbP8hMe5uCz/qJnHEnMUDCI/AnffP8dNSb9nfN8EUtrnkDjK1+mRZ0Dwy5LpEgpGEQisGHrNqa+fi9dNo0kIy6BVac9xcEd++ksQWKSgkHkD7g7n0+ZSNNvbud8VpJcvxMH9n6e6jUahF2aSLFRMIjswar1m/npjYGcnfY22+JrsK7LqxzS4dKwyxIpdgoGkQKyc3L55MNxHPHjPZxv60hu0p2DLn+auCq1wi5NpEREdLO1mXU1s0VmlmxmdxayvqKZjQnWTzOz5vnWDQzaF5lZl3ztw8xsg5nNLXCs2mY2ycyWBD/1r1FKzPzlq5j06GWc99M1VCkPmy8ayyHXjFQoSJmy12Aws3jgeaAb0Aa4zMwKPh5yAJDi7ocATwKDg33bAL2AtkBX4IXgeADDg7aC7gQmu3tLYHKwLFKsdmXmMG7Uy9QefgpnZUxk2SH9qHv7TOoc2WXvO4vEmEjOGNoDye6+zN0zgdFA9wLbdAdGBK/HAWeamQXto909w92XA8nB8XD3r4Ethbxf/mONAC6IvDsif970OQuYOvhcLllyB7kJtdnV51MO6v00VrFq2KWJhCKSMYbGwKp8y6uBDnvaxt2zzSwVqBO0Ty2wb+O9vN8B7r42eL0OOKCwjczsWuBagGbNmu29FyIFbNmewaRRj9H11+dJsExWtruVA8+7C+LLh12aSKhK9eCzu7uZ+R7WDQGGACQmJha6jUhh3J2JX39LnS/uoCfzWVX9aCpd8RIHNtA0myIQWTCsAZrmW24StBW2zWozKwfUADZHuG9B682sobuvNbOGwIYIahSJyIr1KcwYdR/np75JVlxF1p36KE1Pu0YPvRPJJ5J/DTOAlmbWwswqkDeYPKHA
NhOAvsHrS4Ap7u5Be6/grqUWQEtg+l7eL/+x+gLvR1CjyB/KzM5l3HvvkPnCyVyaNoK1DU+n8t9n0uD06xQKIgXs9YwhGDO4EfgMiAeGufs8M7sfSHL3CcBQ4HUzSyZvQLlXsO88MxsLzAeygRvcPQfAzN4COgJ1zWw1cJ+7DwUeAcaa2QBgJdCjSHssZc5Pi1awevw/uSjjM7aWq0vKOSNpfkzB+ydE5DeW98E+uiUmJnpSUlLYZUgpk7ojkw/HvEjnlU9Qx9JY07IPzS55CCpWC7s0kVLBzGa6e2LB9lI9+CyyL9ydST/MIGHSP7nCf2RdlVZk9hhHs+bHhV2aSFRQMEhMWbEhle/efJALUkYQb7Du+Hto0PkWiNevukik9K9FYkJGdg7vfjCBI2cN4gpbwep6p9Dw8udoULt52KWJRB0Fg0S9qfOXs+7dgfTI/JS08nVI7TqUJsderLkSRPaRgkGi1oa0XXw0+kXOWfM0x1kaaw+9ksYXPQiVqoddmkhUUzBI1MnJdd7/4lvqf/MvrmIW66q2JrvHeBof+LubK0RkHygYJKr8vGI9P4+5n0t2jsHjyrPppH/T4IybIC5+7zuLSEQUDBIVUndmMX78G3RcMpjecWtZ06QrjXo+SUL1RmGXJhJzFAxSqrk7H30/i3Kf/4v+/i1bEhqz44K3adz6rLBLE4lZCgYptRas2cIPowdzSdoIEiyLDcfcQv1uA6F8pbBLE4lpCgYpddLSsxj37jt0WPAQ/eNWsrbeiVTt+Sz16x0SdmkiZYKCQUoNd+eTaXPJ+uxe+vsUUivWY8fZQ2nYTt9JEClJCgYpFRb+msI3ox/n0tRhVLV0NhxxHfXPvRc0vaZIiVMwSKhSd2Ux9v336DD/P1wTt4z1dRKp3uNZ6jdoE3ZpImWWgkFCkZvrfDB1DrmTBjEgdwo7KtRmR5cXOSDxMl02EgmZgkFK3JxftvDd2/+lV9pwqtkuthx5NXXPuVePshApJRQMUmI2b89gzLvvcMqSR7g+bgUb6h6HXfoMdXXZSKRUUTBIscvOyWX8VzOp9PUD/JWvSKtYl53dhlD/6B66bCRSCikYpFj9sGQtP49/jMt3vUmCZbHl6Buo3fUu3W0kUoopGKRYrE7ZyTvjRtFt1ZNcF7eGDQ1Po+olT1C7rr6kJlLaKRikSO3MzObNz76h2YyH+FvcDLZWbkLm+W9Rv83ZYZcmIhFSMEiRcHc+mrmEjZ8Mpk/2+xAfT+oJA6l5+i16tpFIlFEwyH77eVUKX7z9PD1Sh9LQtrDp4Auoe8HDVNQjsUWikoJB9tmGtHTefG8CJyc/xs1xi9lSow05F79F3QOPD7s0EdkPCgb509KzcnhrygxqfP8wf+Mrdlaoya7OT1P7uCshLi7s8kRkPykYJGLuzqezVrLyo8fonTWOSpZN2tHXU7PLQKhUI+zyRKSIKBgkIj+vSmHi+FfosWUI3eI2srlpJ6pe+Cg16xwcdmkiUsQiOu83s65mtsjMks3szkLWVzSzMcH6aWbWPN+6gUH7IjPrsrdjmtlwM1tuZrOC/9rtXxdlf6xN3cXjw8ey85Vu/GPrQ1SvUZOc3u9T5+rxoFAQiUl7PWMws3jgeaAzsBqYYWYT3H1+vs0GACnufoiZ9QIGAz3NrA3QC2gLNAI+N7NWwT5/dMzb3X1cEfRP9tGOjGzemDSVetMf5e/2NekVapB+5mPUbN8f4nWiKRLLIvkX3h5IdvdlAGY2GugO5A+G7sCg4PU44Dkzs6B9tLtnAMvNLDk4HhEcU0KQk+u8O20Rmyc9QZ+c9ygfn8v2Y/5K9c7/1DiCSBkRSTA0BlblW14NdNjTNu6ebWapQJ2gfWqBfRsHr//omA+Z2b3AZODOIFh2Y2bXAtcCNGvWLIJuyN58s3g9U997kd47RtDQtpDS4hwqn/8Q5Wu3CLs0ESlBpfGawEBgHVABGAL8E7i/4EbuPiRYT2JiopdkgbFm4bo03nlnDOet
e47b41awtfYR+AVvUuvAE8IuTURCEEkwrAGa5ltuErQVts1qMysH1AA272XfQtvdfW3QlmFmrwH/iKBG2Qfr09IZ+cFEjl74JHfF/8i2hAZkdX2Zmkf10PcRRMqwSIJhBtDSzFqQ98e7F3B5gW0mAH2BH4BLgCnu7mY2AXjTzJ4gb/C5JTAdsD0d08wauvvaYIziAmDu/nVRCtqekc3rk2ZQY/oT/N0+J7tCArtOvodqJ98A5RPCLk9EQrbXYAjGDG4EPgPigWHuPs/M7geS3H0CMBR4PRhc3kLeH3qC7caSN6icDdzg7jkAhR0zeMtRZlaPvPCYBVxfZL0t47Jycnn7h0WkTHmaK3Peo3JcBjuP7Ee1LndDlbphlycipYS5R//l+cTERE9KSgq7jFLL3flkzhrmfvQSV6a/QQNLYWuzztQ8/2Go2zLs8kQkJGY2090TC7aXxsFnKULTlm5i4vtv0GPrK5wdt5rUuu3w89+k5oEnhl2aiJRSCoYYtXBdGmPfn0Cn1c9zT/x8tlVtRs45I6jRprvmWRaRP6RgiDGrU3by+oeTOXLxs9wbP42dlWqRefqjVOvQH+LLh12eiEQBBUOM2Lw9g+ETp9Lop2e4PW4KueUrkn787VQ+7WaoWC3s8kQkiigYotz2jGxe/2IO8T88zV/5iPLlcsk4qh9VOg2EqvXDLk9EopCCIUqlZ+Xw1neLSP3qBfrmvkst2862lt1J6DaIcrUPCrs8EYliCoYok52Ty/gZy1n5+ctcmTWWBpZCWtPT4JwHqNbwqLDLE5EYoGCIErm5zkc/r2H2J0PpvWsUPePWk1b/GDhnFNWbnxR2eSISQxQMpZy7M3HeOn745A16bhvJeXG/sK3WofjZz1C9VRfdeioiRU7BUEq5O18u2sAXH4/hoq2vMShuGdurNSO3y6tUO/xiPeRORIqNgqEU+j55E598NJ5zNw/j/riF7KjSkJxOz1K13eWaPU1Eip3+ypQiU5dt5sOP3uesDcN4IH4OOxPqkX36Y1RJ7AvlKoZdnoiUEQqGUmDass188PEHdFo/jAfjZ7OrUi2yTnuAyh2u0WOwRaTEKRhCNH35FiZ8/AFnrBvGg/Gz8gLhlEEkdLgGKlYNuzwRKaMUDCGYumwzH3zyURAIP5FeqSZZJ91LwgnXKRBEJHQKhhLi7nybvIlPP5lA540jeCh+NukVa5B18j1UOuE6Pc9IREoNBUMxy7vtdCOTP32HrptH8lD8PNIr1SLr5PuodPw1CgQRKXUUDMUkJ9f5bO5avps0nvNT3+DBuIXsSqhD9qkPUKn9AKhQJewSRUQKpWAoYpnZubz30yrmTX6Ti3aO5aG4ZexMqE9Ox0dISOynu4xEpNRTMBSRXZk5jJm2lFVfjeSyzPH0iPuVHdWakdvxKSoffbm+hyAiUUPBsJ82b8/gzW8XsnPacK7InUAT28T22ofhZz5AlTYX6JvKIhJ19FdrHy3ftIPRU2ZQc+5w+thEatoOth1wLHR6gaotz9LD7UQkaikY/gR358dfUvjg8ym0XvE6t8V9S7m4HHa26Aqn/51qzTqEXaKIyH5TMEQgIzuHj2avYfZX73JayjsMip9FVvmKZB3Rhwqn3kTVOgeHXaKISJFRMPyBdanpjPtuLhlJr3NR9idcFLeeXQl1yOwwkArHX0P5KnXCLlFEpMgpGArIzXV+WLaZL7/5koOWvUn/uG+pbBmk1j8GP/UhEtp0h3IVwi5TRKTYKBgCKzbt4JOps8mY9TanZ37J3XHLyCpXgfTWF8Mpf6GG5lMWkTIiomAws67A00A88Kq7P1JgfUVgJHAssBno6e4rgnUDgQFADvA3d//sj45pZi2A0UAdYCbQx90z96+bhduWnsXEn5ayZuo4jtryGdfEzaWc5bK1dhuyEh+k/DFXUL5y7eJ4axGRUmuvwWBm8cDzQGdgNTDDzCa4+/x8mw0AUtz9EDPrBQwGeppZG6AX0BZoBHxuZq2CffZ0
zMHAk+4+2sxeCo79YlF0tqBvX72DbhtHUdkySKvckPSjbqJq4uXUrH9YcbydiEhUiOSMoT2Q7O7LAMxsNNAdyB8M3YFBwetxwHNmZkH7aHfPAJabWXJwPAo7ppktAM4ALg+2GREct1iCod3hh7Nz7cUknNiH6k2P1zzKIiJEFgyNgVX5llcDBW/Y/9827p5tZqnkXQpqDEwtsG/j4HVhx6wDbHX37EK2342ZXQtcC9CsWbMIuvF7DTsOIO+EREREfhO1H5HdfYi7J7p7Yr169cIuR0QkZkQSDGuApvmWmwRthW5jZuWAGuQNQu9p3z21bwZqBsfY03uJiEgxiiQYZgAtzayFmVUgbzB5QoFtJgB9g9eXAFPc3YP2XmZWMbjbqCUwfU/HDPb5IjgGwTHf3/fuiYjIn7XXMYZgzOBG4DPybi0d5u7zzOx+IMndJwBDgdeDweUt5P2hJ9huLHkD1dnADe6eA1DYMYO3/Ccw2sweBH4Kji0iIiXE8j6kR7fExERPSkoKuwwRkahiZjPdPbFge9QOPouISPFQMIiIyG4UDCIispuYGGMws43Ayn3cvS6wqQjLiRbqd9lTVvuufu/Zge7+uy+CxUQw7A8zSyps8CXWqd9lT1ntu/r95+lSkoiI7EbBICIiu1EwwJCwCwiJ+l32lNW+q99/UpkfYxARkd3pjEFERHajYBARkd2U6WAws65mtsjMks3szrDrKS5mNszMNpjZ3Hxttc1skpktCX7WCrPG4mBmTc3sCzObb2bzzOzmoD2m+25mlcxsupnNDvr976C9hZlNC37fxwRPNo45ZhZvZj+Z2YfBcsz328xWmNkcM5tlZklB2z7/npfZYMg3l3U3oA1wWTBHdSwaDnQt0HYnMNndWwKTg+VYkw3c5u5tgOOBG4L/x7He9wzgDHc/CmgHdDWz4/n/+dQPAVKI3ekLbwYW5FsuK/0+3d3b5fvuwj7/npfZYCDfXNbungn8Npd1zHH3r8l7HHp+3cmbU5vg5wUlWVNJcPe17v5j8HobeX8sGhPjffc824PF8sF/Tt586uOC9pjrN4CZNQHOAV4Nlo0y0O892Off87IcDIXNZV3o/NIx6gB3Xxu8XgccEGYxxc3MmgNHA9MoA30PLqfMAjYAk4ClRDifepR7CrgDyA2WI55HPso5MNHMZprZtUHbPv+e73WiHol97u5mFrP3LZtZVWA8cIu7p+V9iMwTq30PJsRqZ2Y1gXeBw8KtqPiZ2bnABnefaWYdQy6npJ3s7mvMrD4wycwW5l/5Z3/Py/IZQyRzWcey9WbWECD4uSHkeoqFmZUnLxRGufs7QXOZ6DuAu28lb7rcE4j9+dRPAs43sxXkXRo+A3ia2O837r4m+LmBvA8C7dmP3/OyHAyRzGUdy/LP0x2Tc2sH15eHAgvc/Yl8q2K672ZWLzhTwMwSgM7kja/E9Hzq7j7Q3Zu4e3Py/j1PcfcriPF+m1kVM6v222vgLGAu+/F7Xqa/+WxmZ5N3TfK3eacfCrei4mFmbwEdyXsM73rgPuA9YCzQjLxHlvdw94ID1FHNzE4GvgHm8P/XnO8ib5whZvtuZkeSN9gYT96Hv7Hufr+ZHUTeJ+na5M2n3tvdM8KrtPgEl5L+4e7nxnq/g/69GyyWA95094fMrA77+HtepoNBRER+ryxfShIRkUIoGEREZDcKBhER2Y2CQUREdqNgEBGR3SgYRERkNwoGERHZzf8BiNGeS+6L8HgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Compare the first entries of le_errorsSet and ai_errorsSet over the first `endidx` steps.\n",
    "endidx = 50\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(le_errorsSet[0][:endidx].cpu().numpy(), label='le_errorsSet[0]')\n",
    "ax.plot(ai_errorsSet[0][:endidx].cpu().numpy(), label='ai_errorsSet[0]')\n",
    "ax.set(xlabel='index', ylabel='error')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./filtering_1mm_axial_results/vectordaeTanh_filtering_allresult_subject033_S_4179_20180125_131417_std0.0100_stepsize0.005_maxiter100.pickle\n"
     ]
    }
   ],
   "source": [
    "# Optionally pickle the complete filtering results. Disabled by default;\n",
    "# set SAVE_ALL_RESULTS = True to write the file.\n",
    "import os\n",
    "\n",
    "SAVE_ALL_RESULTS = False\n",
    "\n",
    "if SAVE_ALL_RESULTS:\n",
    "    savefilename = ('./filtering_1mm_axial_results/vectordaeTanh_filtering_allresult'\n",
    "                    + '_subject' + file_prefix\n",
    "                    + '_std' + '{:.4f}'.format(noise_std)\n",
    "                    + '_stepsize' + str(step_size)\n",
    "                    + '_maxiter' + str(max_iter) + '.pickle')\n",
    "    print(savefilename)\n",
    "    # Create the output directory in case it does not exist yet (e.g. fresh checkout).\n",
    "    os.makedirs(os.path.dirname(savefilename), exist_ok=True)\n",
    "    with open(savefilename, 'wb') as handle:\n",
    "        pickle.dump(filteringResults, handle, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save Results as NIfTI Volumes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### NOTE: The code below requires the nibabel library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed as `dx` to save_DTI_shiftedTensor_result_file below.\n",
    "dx = 1\n",
    "# ROI extent per axis; roi_range is indexed as [start0, end0, start1, end1, start2, end2].\n",
    "data_dim = [roi_range[1] - roi_range[0], roi_range[3] - roi_range[2], roi_range[5] - roi_range[4]]\n",
    "\n",
    "# Voxel-to-world affine for the saved volumes: first axis flipped, unit spacing, fixed origin.\n",
    "mat2 = np.diag([-1.0, 1.0, 1.0, 1.0])\n",
    "mat2[:3, 3] = [72, -106, 10]\n",
    "# Move the origin to the first ROI voxel (affine @ translation(roi start)), i.e.\n",
    "# (72 - roi_range[0], -106 + roi_range[2], 10 + roi_range[4]) for this affine.\n",
    "# With roi_range starting at 0 on every axis the origin is unchanged.\n",
    "roi_start = np.array([roi_range[0], roi_range[2], roi_range[4]], dtype=float)\n",
    "mat2[:3, 3] += mat2[:3, :3] @ roi_start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1.9085e-05, 3.4513e-04, 7.8909e-04, 1.2873e-03, 1.8325e-03, 2.4196e-03],\n",
      "       device='cuda:1')\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "from DTI_ms_utils import save_DTI_shiftedTensor_result_file\n",
    "\n",
    "# 0-based indices into filteringResults.shiftedPointsSet to export; the file name\n",
    "# records i+1 as the iteration number.\n",
    "index = [0, 8, 17, 26, 35, 44]\n",
    "# Show the le_errorsSet[0] values at the exported indices.\n",
    "print(le_errorsSet[0][index])\n",
    "\n",
    "out_dir = './filtering_1mm_axial_results'\n",
    "# Create the output directory in case it does not exist yet (e.g. fresh checkout).\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "for i in index:\n",
    "    shiftedPoints = filteringResults.shiftedPointsSet[i]\n",
    "    savefilename = (out_dir + '/vectordaeTanh_filtering_result'\n",
    "                    + '_subject' + file_prefix\n",
    "                    + '_iter' + str(i + 1)\n",
    "                    + '_std' + '{:.4f}'.format(noise_std)\n",
    "                    + '_stepsize' + str(step_size) + '.nii.gz')\n",
    "    shiftedTensor = save_DTI_shiftedTensor_result_file(data_dim, input.cuda(), shiftedPoints.cuda(), dx,\n",
    "            savefilename, mat2 = mat2, posMean = testdataset.posMean.cuda(), use_logvec = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
