{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import math\n",
    "from Pn_util import *\n",
    "from Pn_DataUtil import *\n",
    "from gae_score_estimation import *\n",
    "from gae_pd import *\n",
    "from gae_pd_trainer import *\n",
    "from gae_pd_score_estimation import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.lines as lines\n",
    "from mpl_toolkits.mplot3d import Axes3D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Device / RNG configuration for the whole notebook.\n",
    "GPU_ID = 0    # CUDA device made current before any later .cuda() call\n",
    "SEED = None   # None -> fresh OS entropy on every run; set an int for reproducible datasets\n",
    "\n",
    "useGPU = torch.cuda.is_available()\n",
    "#useGPU = False\n",
    "print(useGPU)\n",
    "\n",
    "np.random.seed(SEED)\n",
    "if SEED is not None:\n",
    "    torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators\n",
    "if useGPU:\n",
    "    # set_device raises on a CPU-only machine, so CUDA is only touched when it exists\n",
    "    torch.cuda.set_device(GPU_ID)\n",
    "    if SEED is None:\n",
    "        torch.cuda.seed()  # random seed for the current GPU\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate tangent-space Gaussian mixtures with equidistant means and anisotropic covariances"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Two mixture components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[3.4033, 0.0000, 0.0000],\n",
      "         [0.0000, 1.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 1.0000]],\n",
      "\n",
      "        [[1.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 3.4033, 0.0000],\n",
      "         [0.0000, 0.0000, 1.0000]]])\n",
      "tensor([[[0.3162, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.1000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.1000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.1000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.1000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1000]],\n",
      "\n",
      "        [[0.1000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.1000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.1000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.3162, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.1000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1000]]])\n",
      "tensor([3.0000], device='cuda:1')\n"
     ]
    }
   ],
   "source": [
    "N = 10000\n",
    "Nmix = 2\n",
    "\n",
    "pd_dim = 3\n",
    "vec_dim = pd_dim * (pd_dim + 1) // 2  # number of upper-triangular entries of a pd_dim x pd_dim matrix\n",
    "\n",
    "dist = math.sqrt(2*vec_dim)/2\n",
    "\n",
    "data_date = '210912m'+str(Nmix)\n",
    "\n",
    "# m[i, i] and rowcol2idx(i, i, ...) below need i < pd_dim for every component.\n",
    "assert Nmix <= pd_dim, 'Nmix must not exceed pd_dim (one diagonal entry per component)'\n",
    "\n",
    "r = math.sqrt(10) # std condition number\n",
    "c1 = math.pow(r, 1-0)  # not used in this cell\n",
    "c2 = math.pow(r, -0)   # == 1.0, so the base std is sqrt(var)\n",
    "var = 0.01\n",
    "Cov_sqrt = math.sqrt(var)*torch.eye(vec_dim)*c2\n",
    "Cov_sqrts = torch.zeros(Nmix,vec_dim,vec_dim)\n",
    "\n",
    "for i in range(Nmix):\n",
    "    Cov_sqrts[i] = Cov_sqrt.clone()\n",
    "    j = rowcol2idx(i, i, pd_dim)  # presumably the vec index of entry (i, i)\n",
    "    Cov_sqrts[i,j,j] *= r\n",
    "\n",
    "# Each Cov_sqrt is diagonal (hence symmetric), so Cov = Cov_sqrt @ Cov_sqrt; invert the whole batch at once.\n",
    "CovInvs = torch.inverse(torch.matmul(Cov_sqrts, Cov_sqrts))\n",
    "\n",
    "m = torch.zeros(Nmix, pd_dim, dtype=torch.float32)\n",
    "for i in range(Nmix):\n",
    "    m[i,i] = dist/math.sqrt(2)\n",
    "Means = torch.diag_embed(torch.exp(m))  # diagonal means diag(exp(m_i))\n",
    "\n",
    "print(Means)\n",
    "print(Cov_sqrts)\n",
    "print(squared_distance(Means[0:1].cuda(), Means[1:2].cuda()))\n",
    "\n",
    "Pndataset = PndataTangentGaussianMixture(N, Means, Cov_sqrts)\n",
    "#torch.save(Pndataset, 'P'+str(pd_dim)+'TangentGaussianMixture'+data_date+'.pth')\n",
    "#Pndataset = torch.load('P'+str(pd_dim)+'TangentGaussianMixture'+data_date+'.pth')\n",
    "\n",
    "N_testscore = 10000\n",
    "Pndataset_testscore = PndataTangentGaussianMixture(N_testscore, Means, Cov_sqrts)\n",
    "#torch.save(Pndataset_testscore, 'P'+str(pd_dim)+'TangentGaussianMixtureForTest'+data_date+'.pth')\n",
    "#Pndataset_testscore = torch.load('P'+str(pd_dim)+'TangentGaussianMixtureForTest'+data_date+'.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.5000, 0.5000], device='cuda:2')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(0, device='cuda:2'),\n",
       " tensor(0, device='cuda:2'),\n",
       " tensor(0, device='cuda:2'),\n",
       " tensor(0, device='cuda:2'))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = vec2mat(Pndataset.train_data.cuda())\n",
    "weights = torch.full((Nmix,), 1/Nmix, dtype=torch.float32, device='cuda')  # uniform mixture weights\n",
    "print(weights)\n",
    "X_gpu, Means_gpu, CovInvs_gpu = X.cuda(), Means.cuda(), CovInvs.cuda()\n",
    "logp = log_rho_tangentGaussianMixture(X_gpu, weights, Means_gpu, CovInvs_gpu)\n",
    "score_true_train, _ = geometricScore_tangentGaussianMixture(X_gpu, weights, Means_gpu, CovInvs_gpu)\n",
    "\n",
    "# Counts of inf/nan entries in logp and the score; any nonzero count fails the cell.\n",
    "nonfinite_counts = (logp.isinf().sum(), logp.isnan().sum(), score_true_train.isinf().sum(), score_true_train.isnan().sum())\n",
    "assert all(int(c) == 0 for c in nonfinite_counts), f'non-finite logp/score entries (inf, nan, inf, nan): {nonfinite_counts}'\n",
    "nonfinite_counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[0.2154, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0681, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0681]],\n",
      "\n",
      "        [[0.0681, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0681, 0.0000],\n",
      "         [0.0000, 0.0000, 0.2154]]])\n"
     ]
    }
   ],
   "source": [
    "logX = Log_mat(X)\n",
    "print(Cov_sqrts)\n",
    "# Per-entry std of Log_mat(X) over all training samples, shown as the cell result for comparison with Cov_sqrts.\n",
    "logX.std(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(5.1923) tensor(5.0006)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"\\nax = fig.add_subplot(111, projection='3d')\\ndata = Pndataset.train_data\\ndata2 = Pndataset_testscore.train_data\\nax.scatter(data[:,0], data[:,1], data[:,2], c='r', marker='.')\\nax.scatter(data2[:,0], data2[:,1], data2[:,2], c='b', marker='.')\\nprint(torch.max(torch.abs(data)), torch.max(torch.abs(data2)))\\nax.set_xlabel('X')\\nax.set_ylabel('Y')\\nax.set_zlabel('Z')\\n\""
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAnrElEQVR4nO3dfXicZZ0v8O9v3pJJi03aBErayosX2z1QCoXActmKSoGKQE+tEg4onF3ForgUUFtaV0uoL5TWFenZi6Pdoi6KsgFKKHDYisCKQVFSAuFNjkfAbSZg09IEkkwzL8/v/PHMTOflmcwkmZnnnvT7ua5ebeflee5OM9+5537u+3eLqoKIiMzlcbsBREQ0NgY1EZHhGNRERIZjUBMRGY5BTURkOF85DtrY2KjHHntsOQ5NRDQl7d69e5+qNjndV5agPvbYY9HV1VWOQxMRTUki8pd893Hog4jIcAxqIiLDMaiJiAzHoCYiMhyDmojIcGWZ9VFpHd0hbNn1GvoGwmiuD2LNsvlYsWiO280iIiqJqg/qju4Q1u94EeFoHAAQGghj/Y4XAYBhTURTQtUPfWzZ9VoqpJPC0Ti27HrNpRYREZVW1Qd130B4XLcTEVWbqg/q5vrguG4nIqo2VR/Ua5bNR9Dvzbgt6PdizbL5LrWIiKi0qv5iYvKCIWd9ENFUVfU9ak7NI6Kprqp71JyaR0SHg6J61CLypoi8KCLPi4gx9UuLmprX0w7ctgBoq7d/72mvbCOJiCZpPD3qj6rqvrK1ZAIKTs3raQceWg1EE38f3GP/HQAWtlaghUREk1fVY9QFp+Y9vvFQSCdFw/btRERVotigVgC/FJHdIrLK6QEiskpEukSkq7+/v3QtHEPBqXmDvc5PzHc7EZGBig3qJap6GoALAHxJRM7OfoCqblPVFlVtaWpy3Par5FYsmoNbVp6MOfVBCIA59UHcsvLkQxcSZ8x1fmK+24mIDFTUGLWqhhK/7xWRBwCcCeCpcjasWCsWzck/w2PphswxagDwB+3biYiqRMEetYhME5Ejkn8GcD6Al8rdsJJY2ApcvBWYMQ+A2L9fvJUXEomoqhTToz4KwAMiknz8z1X1P8raqlJa2MpgJqKqVjCoVfV1AKdUoC0lwZWKRDTVVPXKxGxcqUhEU1FVz6POxk0EiGgqmlJBzU0EiGgqquqhj+zx6Po6Pw6MRHMex00EiKiaVW1Qd3SHsObeFxC1FIA9Hu0B4PcKonFNPY6bCBBRtavaoG7b+XIqpJMsAFZaSDfU+XHTxSfxQiIRVbWqHaMeCB8a4lju6URnYDVer7kcnYHVWO7pBAAcjFpuNY+IqGSqtkedtNzTiU3+7aiTCABgruzDFv8P0Ya7UI8h7H2wCfB+h4teiKhqVV9Q97QDj2/E67W96LNmoV6GUiGdVCNx1GAIADAb/axBTURVrbqGPpIbAQzugQeKuZ59mIaDhZ/HGtREVMWqK6gfvTFnIwC7BEkRBvdwGy4iqkrVE9Q97UD4nckd46HVDGsiqjrVE9TjGLpQzXMHh0CIqApVT1AXuX3WiAZwV/zc/GHNbbiIqMpUT1Dn2T4r6gkipI2wVNBrNWJd9CrcFPssQtro+PiR4OxytpKIqOSqJ6iXbrC30UrnD+LbnquxeHQrjh+9G0siW7HTWgIA2BxrxYgGMh4+ogFsjl5aqRYTEZVE9QR1nm21fjJ0puPDd1pLsC56FfZb06Fqj1sfRAAHRiKOjyciMlV1LXjJ2larozsEwfNQ2CsU2/x3oSGx0OUdnY6HrbMQlEhqCt9MDGFT4E6gZxEXvxBR1aiqoH525w8x77ktOFL7sVea8HtcBsVZWO7pxHf92xCQWOqxs2QIV8iv4MmaZx3EqD3zg0FNRFWiaoL62Z0/xILdX0dQIoDYS8O/oT/AsCeGtb72jJBOyg7plME95W0sEVEJVc0Y9bznttghnaZOIljra0ez7BvXsWLqweJN
T6CjO1TKJhIRlUXVBPWR2u94e7PsR1+eqXj5eGClNr5lWBOR6aomqPdKk+PtfToLm2OtiGjxozgHdDoAbnxLRNWhaoJ6z2lrEM6aFx3WAH577DV4FB/CV6OrENfiKjSlF3LixrdEZLqquZh4xvKr8SyAE577Jmboe4AAPp8Pn+jdgkv8h8JWtXBFvXoMp/7MjW+JyHRVEdQd3SE8/8g2rI5uxwwZSgWxPz5i/8EhmFUBCwKv5Bb9sCB4veZyvIVG9J24FsA55Ws8EdEkGT/00dEdQucDd2Bt9A7MlCGnTHZkh7nmjF0rAJ9Y8AgwR/bhjOfWArcex/KnRGQs44N6y67XcD3uydluqxheARSKd3Q6LBXE1OMc9OF3WKuaiIxlZlD3tAO3LQDa6vHvI58f9zzpdDUSx4jW4vjRu+HBGLuSs1Y1ERnKvDHq5L6IiS235nr2wcpXW7pIc2Qf3qi5HBYE9uBHHqxVTUQGMq9H/fjGnH0RPYJJhbWI/csrmn9DASBvzWsiIjeZF9R5erUCoNeyNwgYM2wLEEGq7GkGf9CueU1EZBjzgjpPrzYOD5plHwYwbdKnEAGi8GEARyC9tjUr6hGRicwbo166IWOMGjg0pQ6wa0qXQkBi2GvVoH4jx6WJyGzm9aizd3IRb9Fzp8er2bO/TEcmIiqdooNaRLwi0i0iD5ezQQDssL7hJaBtANAxptQVoDr2RciD3OiWiKrAeHrU1wF4tVwNyWsSMzFEHFeXA7CHU+ou4LxpIjJfUUEtInMBXAhge3mb42DphpzdxMcjX4GmiH8GLx4SUVUotkf9fQBrgbGW9pXJwlZs9l8zqSl52SJSg5qLv1u6AxIRlVHBoBaRiwDsVdXdBR63SkS6RKSrv995N5aJOvXCVRhGzaSPowrE4MGe93+CvWkiqhrF9KgXA1guIm8CuAfAOSLys+wHqeo2VW1R1ZamJufdWCZqhfdpTJPRSR9HBPDBwvFv3oPRb7+fRZiIqCqIjmNMQUQ+AuCrqnrRWI9raWnRrq6uybUs3W0LxrVzePKfVGgDASRrf8yYZ8/fZi+biFwiIrtVtcXpPvMWvDjQwd5xzaUuHNCpI9u/De6xF9kADGsiMs64Fryo6n8W6k2XWkd3CH06q/wnYplTIjKUeSsTs2zZ9RpujbZiVL3lPxnLnBKRgYwP6r6BMHZaS6CVaCrLnBKRgYwfo26uD+L0dx9DDaLlPRHLnBKRoYzuUXd0hzA8GsNaX/s4LhBOAMucEpHBjO1Rd3SHsH7HiwhH45hTM/E9E4tyw0vlPT4R0SQY26Pesus1hKNx3Oz7UZnPJFz4QkRGMzao+wbCWO7pxJXeX5V32AOK/ffdgMWbnkBHd6icJyIimhDjhj46ukPYsus1KFD+semEmTKE0999DOt3RAAAKxbNKf9JiYiKZFSPOjkuHRqwt+FqljKPTSeI2B8K4WgcW3a9VpFzEhEVy6igTo5LJ/VpY8XO3Sz2tlx9A+ECjyQiqiyjgjo7JDfHWktah3rMcyeWqTfXBytzQiKiIhk1Rp1c3LLW145m2Yc+bcRvrJPwIc/LZR2rHtEANsdaEfR7sWbZ/PKdiIhoAowK6u+f+Ccs2L0dQbEv6s2VfZip75b1nKpAEBF8LXAvrjjtWJyx6GNlPR8R0XgZNfRxxp//Vyqkk+qy/l5qIvav2ejHGS/exDnVRGQco4I6X/W6SkzRA8BSp0RkJLOC2oTqdSx1SkSGMSuol26wq9hVQN7ZJCZ8WBARpTErqBe22lXsZswDIIiVsXkWBCMayLgtjABLnRKRccwKasAO6xteAtoGcHfsnLLNo/ZAsS56FXqtRlgq6LUasRFfYKlTIjKOUdPzsi31PF+2C4nJ/F8S2Zq6TQDcUp7TERFNmHk96jTlrPXhEeB2/x3oDKzGck+nfT6uSiQiA5nbo+5pR7lXj4vYi2o2+bcjoB4sWXZNmc9IRDR+ZgZ1TzvQcQ28FZo/XScRXK/34NJd
SwGwzCkRmcW8oO5pBx74AqDxwo8toWbZj9BAGOt3vAiAYU1E5jBrjLqnHXhodcVDGgAEis7AapwX/zVrUhORUcwK6sc32su4XSACzPXY49Ut7z7mShuIiJyYFdQVXr7tNEe7TiJYH7i3ou0gIhqLWUFd4eXb+eZoH4V+VtEjImOYFdQVrPUxFgHssXKGNREZwKygTtT6GPXXl30OdUEseUpEhjArqHvaMfLoBvgjgzig0xHXShWizoMlT4nIAObMo05MzauLhgEBZmIIltoX/Cq2cUA2ljwlIgOY06N2mJrncbFDHUYNS54SkRHMCWq3t+ECUj34XqsR6yKfY8lTIjKCOUMfM+YCg3tcO70q8NP4ubgp9lkAgFcEH+0OcSk5EbnOnB61y8MMIsAl3qdSJU/jqli/40V0dIdcbRcRUcGgFpFaEfmDiLwgIi+LyM1lacnCVsA/rSyHLladRNDmvyv193A0zrofROS6YnrUowDOUdVTAJwK4GMiclZZWhMdKcthx6MBQ6leNQD0DbhTe4SIKKlgUKttKPFXf+JXedajGDAdTgRY6zu0IpG7vhCR24oaoxYRr4g8D2AvgMdU9fcOj1klIl0i0tXf3z+x1hiyhHyO7Ett0bVm2Xy3m0NEh7miglpV46p6KoC5AM4UkQUOj9mmqi2q2tLU1DSx1iSWkCeqbbgmWfL01sB2rPA+7WpbiIjGNetDVQcAPAngY2VpDWDU3OUgIqz3QUSuK2bWR5OI1Cf+HARwHoA/lrVVBoxVp7DeBxG5rJge9dEAnhSRHgDPwh6jfrisrTrhfPer5yWZ9KFBRIelgisTVbUHwKIKtOWQP/3S5VHqBH/Q9YU4RETmrExM6OgOwXJxuCEOASDAjHn2hU2DxsyJ6PBkVFB3dIew5r4X0GfNcq0NA9Y0LK7dgY6P7GJIE5ERjArqLbteQzSu2BxrxYgGXGlDgwwjNBBmnQ8iMoZRQZ1crr3TWoJ10ascdwkvexvU7s2zzgcRmcKooE5frr3TWlLx80fUh82xQ8MdIdb5ICIDGBXUa5bNh99rz/dIL4xUCarAV6OrMj4gBODwBxG5TrQM4wstLS3a1dU1oed2dIdw4oPn4wTtrejuLnEVCBR92ojNsdZUYM+pD+LpdedUriFEdFgSkd2q2uJ0nzk7vCSs6PkigN6Kl/vwiv2BNVf2YZN/OxC1h19Y5pSI3GbU0AcA4I1fu90C1EkkVeqUZU6JyG3mBbUhmmU/gn4vy5wSkesY1HnslUbcsvJkbm5LRK4zK6h72gGPz5X50+lUgaPQjxW/XGK3iYjIReYEdU878NBqwIpVdLaHE5HEtczwO8ADXwBuPQ5oqwduW8DgJqKKM2fWx+MbgaiBMyw0bgc2AAzusT9MANYBIaKKMadHXS0F+qNh7vpCRBVlTlAHG9xuQfGq5UOFiKYEc4K6mnDXFyKqIGOCWsMH3G5CcbjrCxFVmDFB/Vc0ut2E/GbMA3d9ISK3GBPUt0QumdRmAY9Mq8P5c5ux8Nh5OH9uMx6ZVleSdo0Ej8bi0a047uDdWDy6FR3xxSU5LhFRsYwJ6q73nTfhzQIemVaHtsaZeMvvg4rgLb8PbY0zJx3WMW8tNgx/EqGBMBTgzi9E5ApjgnrNsvkI+CbWnNsb6nHQk/ncgx4Pbm+on3B74gp8I74K90U+mHE7d34hokozZsHLCu/TuMi/HRIf/3Pf9nnHdXsxBIJfHDzL8T6WPiWiSjKmR43HN8IXPzihp86OOad7vtuLkdw70QlLnxJRJZkT1JNYRHLdgQHUWlbGbbWWhesODEzoeKrA5lgrgn5Pzv4FLH1KRJVmzNDHSHA26sJvTei5Fw6PALDHqt/2eTE7Fsd1BwZSt4/XMGqw01oCLxTp1zYFwCdPn8PSp0RUUUYEdUd3CJ3Dn8RG2YY6iUzoGBcOj0w4mLNF4AcAxK3MKSgK4Mk/9pfkHERExTIiqLfseg2hyAcR8Vi43X+H62VO6zGc
9z5eSCSiSjNijDoZfsmdv93GC4lEZBIjgtqk8FMFHrdOdbyPFxKJyA1GBPWaZfMR9Ntzng9guqttEQEu8T6F5Z7OjNu9ItxDkYhcYURQr1g0B7esPBkNdX60Ra9ERN0dOq+TCNb6Dm25FfR78c+tpzCkicgVRgQ1YId194bzcc4l/4ivRle5vsFts+wHAMypD7InTUSuMmLWR7oV3qdxrv9Ot5uBPp0FAfD0unPcbgoRHebMCurETuTTZdTVZiRXJpp0kZOIDl9mBbVhO5F/9G+bsHjTE+gbCKO5Pog1y+ZzCISIKq7gGLWIzBORJ0XkFRF5WUSuK1trDNk0VgT4p5p7cf/uEGtRE5HrirmYGAPwFVU9EcBZAL4kIieWpTUG7UR+pPbjvPivM25jLWoickPBoFbVt1T1ucSf3wPwKoDyfP+PuTs2nU4AbPJvz5lPzSXkRFRp45qeJyLHAlgE4PcO960SkS4R6ervn2Dhomj+GhtuyJ5PDQD1dX6XWkNEh6uig1pEpgO4H8D1qvpu9v2quk1VW1S1pampqZRtdFVyPnXS0MEYx6mJqKKKCmoR8cMO6btVdUfZWhOcWbZDT1R2gaaopRynJqKKKmbWhwC4E8Crqvq9cjbm2f+2DhGd+D6HpTaiAWyOtebcznFqIqqkYnrUiwFcAeAcEXk+8evj5WjM9a+cgK9Gr8Z7Vo2rS8hVgXd0OjboKsfSq1wIQ0SVVHDBi6p2AjlbB5ZF30AYp3uAaTLqyuYBqkBIG7E51poKaL9HEE3b6YWlTomo0sxZmdjTjt/Vfg1Hab9rO7xYAJZEtmbcNr3Wh7qAj6sTicg1ZgR1osbHbIQr1Hd35gGw3NOZMdwxMBJF94bz3WsUER32zChzakiNDxHgu/5tGYtcOB5NRG4zI6gNqfEBAAGJoc1/FwDA7xWORxOR68wI6hlz3W5BhgYMoaHOjy2f4q4uROQ+M4J66QbAnznE4OYGLyJA94bzGdJEZAQzgnphK3DxVmDGPACS+N1F/mnunp+IKI0Zsz4AO6wXHloFGL+pAT6x3GmLr8ad8xIROTCjR+1gtyxwb3Vi+IBLJyYiymVsUJ8cdG/hi2kXN4no8GZsUNeF33bv5Es3uHduIqIsxga1a71a8WSMlRMRuc3coD7BpWXbp/+DO+clIsrDzKDuaQde+HnFTqcKWAqg5XPARWUtuU1ENG5mBnWFa3+IAG9LI0OaiIxkZlC7UPtjtu7nXohEZCQzg9qFC4l9Oot7IRKRkYwM6mc/cC3CqNzqQEuBzbFW7oVIREYyLqg7ukO48tljcGPkc+i1GmGplHWFoirw0/i52GktYe1pIjKScUG9ZddrCEfj2GktwZLIVvw0vrSs53tHp+Om2Gfh97D2NBGZybigTh9+WO7pxJXeX5V1KXmDDCPo92DLJaw9TURmMqd6XkJzfRChRFi3+e8qe72Pg3Wz8eqNF5T3JEREk2Bcj3rNsvkI+r1Y7ulEA4bKfr66CzaW/RyHtZ524LYFQFu9/XtPe2WeSzSFGNejTg4/nPXgP1ZmQ3LW9SheT7u9GGmw155CuXSD8+uXetwe2NvKJ64GD+4BHlpt/7nQ657YmT618Gk8zyWaYozrUQN2WM/GvvKfyO2dZKrJw18GdqxKhK8eCs7sXm4yYAf3JG7ImrITDQMPfKFw79hpdWo0bN9OdJgxrkedMmNu2pu99MKoQbDKypl2dIewZddr6BsI439O/wPW+v/dLgc7Vu+2FHraga4fwTF0H9+IjvjiVLt+V/s1zEaB+egaR+zBa/GtnS/j34bORHN9EGuWzc+8mJtvdWrW7cnXJDQQhlcEcVXMcTpelvTX0vH8yX93Md8giMpMtAyTlFtaWrSrq2tyB8n+6ltCybnTV37r/pIfe1LGCIaO7hDW73gR4Wgcyz2d2OTfjjqJpJ4a89biQesj+Lt4F5o9+7Hf24RvH7wEHfHF8Irgsr+bh5ZjZjoE/VuAeAGN298wkh9e
yXYEG4CDA4A6b4tmQXBS/B6Eo3EAwOs1l8NT5JjVOzodI1qLZtmHt9CIvtPX4ozlV9t33rbA8YP6bTThmf/+a6xYNCfjNckW9Htxy8qTHcPa6Xk5j3f8+ROg5bOsCUNlISK7VbXF8T5jgxrIHOsUT96wmIi30YTZbf9v0scpqmdWxPNb3n0MmwJ3IojR1H0WgP5ZZ+GoWAjWYC/6rFnYHGvFTb67MMuTe6HVUmSE5IgGsC56FXZaSwDY91kKx6BP8QaAeCxx9sJ6rUZsjrVira8dzbIPCsBbZFCrImNWT7K90wI+tPn/DTXRwYzHJ+8HgDb/T9Eg7wFqt1QA9KndluS/d059EE9/fF/Oh9/i/9OI0EAYyz2dqXb3aSO2Bz6Dtq/fbJ8szwcFIMDKbaXrWTt8OKd/Q5nIz1S1mux7qdpVb1BnS/xQ6+AeWBB4Em2fyBQ+hUDaBibVnKJ6Ztl62jHy6AbUht9OBe9Oawk6A6sx15M7Lp8dZsn/rmL/zem91gM6HSJAA4ZKMu0x/UenVNMoVe0BFk/Wv/kApuOh+FlY6fkNpsto3vOlP/YizzOY6RnKuCidHvbf9W9DQGIZz5UzEqVu2+qRM9STPEbwaNTd+EfH+8YVNg699pi3FuuiV+G+yAdTtwmAW//mj2gd/PHEh2Gcvq0BGbc9+4Frcf0rJ7gSlON5L03VQJ86QQ2Ubkhkxjzghpcm9NT0cdF8/t5pDBlA7MFr4YsfTD0uGRzf999R9JDBeGQHfbUa7wfUWP/umHpwEH5Ml9Gc+1SBIa1BVPyYKc7TQ1WBPjTi1mgrTvf8X3za+wS8ad9AQmk9e79X8O0PvIqz/+t/40jtx15pwp7T1iA07yKc9eCHMRv9OcfvtRqxJLI19fflns6cDxUAjvXT07+hrQ/ci6OwDxJsACJDQNzhG1SasAZwY9o3MACoD/rRtvyksgZhR3cIX2l/AXGHLJpTH8TT687JeOy4O0dVYmoFdd6vpOPgDwIXb53Q19exxkUB+011k+8uzJTMXqsCGFU/aiWa85y4CjzQKRGo1aLQB1gxHwxjPSb5zWBYa3K+ART6cLZUcH30i2nDSQKv5HmfrvzXnOsY58V/nX9oq4CYevDl6BcywtoxCEs0ZFPo/SQA3th0Yervizc94dhByg70ajRWUJs76yOfImpVpz57BLlzsYMzgQtuHXdId3SH8Pwj27A6uh2veIaAGnts1AO7B1YnEXiQfyhGANQgN6QB5H8TjtNU6T1XQqHXqZjXcazHSOJn7wiHXnudRHC7/448AyuAQHG7/4604+f/+YjtuBpfuacbXe87D8OjMYSjcawNtE8opAHAJxa2+H8IRIGd1hJ7HF/a0fzgfuA/7UB+9s0DOOW59QggEa6De6A7Po+z9Qg8Eb0CISxBaCCM9TteBIAxwzpZ2yef7EJp+SpcTvXKl1OuR51+Ae3vp/8BbdPun/T0qo7uEDofuAPfkR/kfv0kclnyLTyMWvgQQw1ik/7AVrWvbxwhBzN+5mPeWhyMqeOwEZB7ATujp+vQCz/u59Pyfgw59eTz9ai9Ivjn1vLU66nUmPjUGvpwGKO2Ev+E7Kv+2V+bxjym07S4xO3WYC8sFfikdLNOiKpVoW9u6WPsqffgw18Guu7MfKDHjzbPl/CToTNzjpEveLOHStJn7uRM8SyBSo6JT62gBuwAffRGIPwOFMABnY626JUZ42qA87hV9qfj90/8E8548abMi5P+IHDK5Yh1351x4Y+IClMFjhu1N6dOTZHc8Xnnx8LuYN0abc14/15W+wy+Ebz30MX4E84H/vRLYLAXI8HZ+Nq7K2Gp5ozFh1GD4Mp/mfD0yex8GInEcGAkd8gymS0TXXDlZFJBLSI/AnARgL2quqCYE1YkqLN61dlfufxewZZPnZKxgGHk0Q2oHXkbfXpoWtzTNasxR3KnxcXhybiST0TFsRS4PnpN6r3429rVaC5QEiKMAG6MXJUaF88OYAVyplmG
NeC4nqCoGV15LoaOdWEznQC47dJT8z4+J3+KMNmLiT8B8C8A7ir6jOXmUAeiTiJY62vHzsgSeAS49Ix5OavM6qJhQIC5sg+b/NuBKHB0nh8gj1oOVyKpGI9Mq8PtDfV42+fF7Fgc1x0YwIXDI243iyrEI8D3/D9IXZCcrfsKvpeCOPT+XevLvRia/fQ6iSCIPBdMnSYcpAdzsAEYfQ+wEj3lRN2a5/VqhKO5wzBOmuuDY14IjcYVNz/0csmGRwoWZVLVpwC8U5KzlUqemR/Nsh+A/Yl+/+7QoV3F8wS7feWdaVxKj0yrQ1vjTLzl90FF8Jbfh7bGmXhkWp3bTaMKSs4eea5mVdHvsDmyD8s9nWh2+IY7LtmbY2cXFAu/cyikk6JhfD26Fa/XXI7OwGos93TmPXzQ78WaZfMLzjRxGjKZKCOr5xWUZ5fyPp2V+nM4Gj+0q3ieYBfJPzWO09wm5vaGehz0ZP5YHfR4cHtDvTsNItfUSDxnPcFYRIBN/u0YwPSiHj8gR+Rsgh1GDa7rvxiLNz1hd9TyFRRz4BMLHgHmeuxv3Mmwrg/6Mac+CIE9Np28kFjJPVZLFtQiskpEukSkq78/d7VVSS3dYF/wSzOiAWyOZV5ASH3i5Ql2Kr23fd5x3U6Urk4iaMBQaiZXUvaltBENoC1yhX3hcMY8KAQhbcSNkc/hQcuex73m3hcw8ugGFBPSTu1Y62uHABgI2z3j2y49FU+vOyc1nJHc5CSf+qB/3OfNp2RBrarbVLVFVVuamppKdVhnC1vtlYUz5gEQvI2mjAuJSalPvCorZ1rNZsecx+zy3U7VowwTxByJ2OPcCvtXr9WIu+LnotdqhKWCXqsR66JXoet959lZcMNLWCTtWDy6NZUByz2deNJ3LYLhtybcjmbZn4r45AKe1HAq7IU8t6w8GQ11uYHs9wjalp804XNnq76ViUkLW1NTcJ7pDuGxHS8CVuZcx9Su4gtbU9P5qLyuOzCAtsaZGcMftZaF6w4MuNcoKolSDwcWmo8tADBjHn5w3C9w9zP/hZvS7gv6vbgl+f7GoV4vUKA65DikD6UCh4ZT0y8Qrlg0J1Vyt5yLYgoGtYj8AsBHADSKSC+Am1T1zrGfVVnJF2TMF+qCW/MWc+LS69JJzu7grA/KR9UuXPW4dSou8T41dqAO7sG33rgM36ztxV/RiFsil6DrfeeNGYROs0bGy2koFci/VD0Z2OVSMKhV9bKynb2ECr5QyQnwiTKpcfXAAwt9iR+YpZ7n0Sz7YcF5BSLDvHgXDo8wmKeoyb4Pstc77Lb+Bmt97Zgj+/IcV4DBPRAAs9GP26f9GPj4ImBh5kK2hjp/apbFRGaNjKoXwwiiHsPYK434gf/T2DmaO1WvkhcQ01XnysRJylexSwS4WDpxq387gmmfyCMawL3xs3Gl91dj1kFOHoPIbaXoWDht7OCBOlaAtBSIwZdRFyR5QdBKLB5Lln99quajGI1ZGe8/5+GKtI2R0zksaOnoDmHNfS8gGte8td0zBGeid9iDZtmfsQAueVanxSzlLqc6tarnlcDYQyUXAj2LELp/PY5G5n/ixd5nMBO5K6FUgd9YJ+G++IcdS5wSVdKIBtBlnYAPeV4e989het2c9G+ayfcBkLvhQkR96DjmaziucTrmPbcFR+o+7JVGPHXMF3HTGyflhl3iIlvy/Tcj6Mdv5KNYfxCHamiPtWeqw3Tb9Pf0lndbc3ZLyuAPAhfciksTO/1ka64PFjecWkGHZY+6GE697k/4nsZ3fP+asSIqGdJXRv8pdVu+fQPLsSMKVT+nb2Pj6RGr2ntXeqAZmxbc7PuRvamB2JuViUPvNAbBuzoN9RjO6JQIgE+f9X60HDMTbTtfzrlYZxdC2o+3MGvMQkiTusiWr1LmeJeIBxvs28IHMoqumbYJwdQrylQhjj9k3qdTPwAjwdnYMPzJjG2Tgn4vdk+/3t40Nkuyqtifaz8Nr8ObZjzDJ2O9kTme7h6n
117Vrh2Tee0j8bV+xjw8+4Fr8UhPH66K/CzVe3W+0CZQKAZxBOKWosEzjEH/kbhdL3OsQJc8S6qCZFoxMwAZtdmLDdSKboPltJvTJDb9cGLStl4M6jLKG+Z5ikY95v0wHj7+ARz/l3syltYme+ZNMoC/9STmaubpgUdV8PP4UnzG+6ucjWSTxzle/pq6qFKOLb6yz5mt0M4olf4gSc7JTV84kN2OmAoseMesOT5W20c0gPviZ2OltxPTYFddtAD8LH4u3qhdUFRt9PR6y+k9173SiNkrv5M3oKbszif5ShBPQQxqN6Q24nWeVvTnH1+NY95shxcW4vDg7vg5aIt9Fp8+6/341oqTHY+Vvgnp57uPw9mjT+Lbvjsx3WOPxYl4gNP/Abjoexl7530ncCfqMHro0oxTz90bQNyy4NXMzV6TMh7r8WNY6hCMvZtzIQbAoa/csBxDLbnh7hzPPsc6ENk/kunHGFUvBDL2Bg7iBTTtQnGiF3bdPd1Ykwi+ZK/VaQz2Rr9d31iyLmYlLyqf67WfI8EGjMYs+KOD6LNmYXvgM4gt+BTu3x2a8NfpiX4dN+1rPI0fg9pQrn7tyrMrdfoO6d+NX4qO+GLnjXodvi4Pj8YKLjwIowYvnfZNe0yzp91xs98NugpLPnGN/VoU2j3bafwx/f609ubrdTbU+VEX8OX+P0ywNzfZ/9eJPt+kr/E0fgxqqginXt2nAr/Fxmn354R8SrJOeOLDYXvgMzj1wlVl21KJvU4yFYOaKsb0Xp3p7aPDF4OaiMhwYwV1ddajJiI6jDCoiYgMx6AmIjIcg5qIyHAMaiIiw5Vl1oeI9AP4S8kPXJxGAJPcxnhK4+szNr4+Y+PrM7bJvD7HqKrjPoZlCWo3iUhXvikuxNenEL4+Y+PrM7ZyvT4c+iAiMhyDmojIcFMxqLe53QDD8fUZG1+fsfH1GVtZXp8pN0ZNRDTVTMUeNRHRlMKgJiIy3JQJahH5kYjsFZECu14efkRknog8KSKviMjLInKd220yjYjUisgfROSFxGt0s9ttMo2IeEWkW0QedrstJhKRN0XkRRF5XkRKWj50yoxRi8jZAIYA3KWqC9xuj0lE5GgAR6vqcyJyBIDdAFao6isuN80YIiIApqnqkIj4AXQCuE5Vn3G5acYQkS8DaAHwPlW9yO32mEZE3gTQoqolXxA0ZXrUqvoUgHfcboeJVPUtVX0u8ef3ALwKgNXy06htKPFXf+LX1OjFlICIzAVwIYDtbrflcDRlgpqKIyLHAlgE4PcuN8U4ia/2zwPYC+AxVeVrdMj3AayFvbE6OVMAvxSR3SKyqpQHZlAfRkRkOoD7AVyvqu+63R7TqGpcVU8FMBfAmSLCITQAInIRgL2qutvtthhuiaqeBuACAF9KDMeWBIP6MJEYd70fwN2qusPt9phMVQcAPAngYy43xRSLASxPjMHeA+AcEfmZu00yj6qGEr/vBfAAgDNLdWwG9WEgcaHsTgCvqur33G6PiUSkSUTqE38OAjgPwB9dbZQhVHW9qs5V1WMB/A8AT6jqZ1xullFEZFriQj1EZBqA8wGUbAbalAlqEfkFgN8BmC8ivSLyObfbZJDFAK6A3RN6PvHr4243yjBHA3hSRHoAPAt7jJrT0KhYRwHoFJEXAPwBwCOq+h+lOviUmZ5HRDRVTZkeNRHRVMWgJiIyHIOaiMhwDGoiIsMxqImIDMegJiIyHIOaiMhw/x89I4sBVnlMTwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "data = Pndataset.train_data\n",
    "data2 = Pndataset_testscore.train_data\n",
    "\n",
    "# Scatter two coordinates of the vectorised samples together with the diagonal of the component means.\n",
    "# NOTE(review): assumes columns 0 and pd_dim hold the (0, 0) and (1, 1) entries - confirm against rowcol2idx.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(data[:,0], data[:,pd_dim], label='train')\n",
    "ax.scatter(data2[:,0], data2[:,pd_dim], label='score test')\n",
    "ax.scatter(Means[:,0,0], Means[:,1,1], label='means')\n",
    "ax.set_xlabel('column 0')\n",
    "ax.set_ylabel(f'column {pd_dim}')\n",
    "ax.legend()\n",
    "print(torch.max(torch.abs(data)), torch.max(torch.abs(data2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the training and score-test datasets to the working directory.\n",
    "train_path = 'P'+str(pd_dim)+'TangentGaussianMixture'+data_date+'.pth'\n",
    "test_path = 'P'+str(pd_dim)+'TangentGaussianMixtureForTest'+data_date+'.pth'\n",
    "for dataset, path in ((Pndataset, train_path), (Pndataset_testscore, test_path)):\n",
    "    torch.save(dataset, path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Three mixture components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2.6101572156825372, 0.8254041852680184)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check of the scale factors c1, c2 for the current r.\n",
    "pd_dim = 3\n",
    "vec_dim = int((pd_dim + 1) * pd_dim / 2)\n",
    "r = math.sqrt(10)  # std condition number\n",
    "c1, c2 = math.pow(r, 1 - 0), math.pow(r, -0)\n",
    "(c1, c2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.1875"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# scratch arithmetic\n",
    "0.75 * 0.25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[3.4033, 0.0000, 0.0000],\n",
      "         [0.0000, 1.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 1.0000]],\n",
      "\n",
      "        [[1.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 3.4033, 0.0000],\n",
      "         [0.0000, 0.0000, 1.0000]],\n",
      "\n",
      "        [[1.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 1.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 3.4033]]])\n",
      "tensor([3.0000], device='cuda:2')\n"
     ]
    }
   ],
   "source": [
    "N = 10000\n",
    "Nmix = 3\n",
    "\n",
    "pd_dim = 3\n",
    "vec_dim = pd_dim * (pd_dim + 1) // 2  # number of upper-triangular entries of a pd_dim x pd_dim matrix\n",
    "\n",
    "dist = math.sqrt(2*vec_dim)/2\n",
    "\n",
    "data_date = '210912m'+str(Nmix)\n",
    "\n",
    "# m[i, i] and rowcol2idx(i, i, ...) below need i < pd_dim for every component.\n",
    "assert Nmix <= pd_dim, 'Nmix must not exceed pd_dim (one diagonal entry per component)'\n",
    "\n",
    "r = math.sqrt(10) # std condition number\n",
    "c1 = math.pow(r, 1-0)  # not used in this cell\n",
    "c2 = math.pow(r, -0)   # == 1.0, so the base std is sqrt(var)\n",
    "var = 0.01\n",
    "Cov_sqrt = math.sqrt(var)*torch.eye(vec_dim)*c2\n",
    "Cov_sqrts = torch.zeros(Nmix,vec_dim,vec_dim)\n",
    "\n",
    "for i in range(Nmix):\n",
    "    Cov_sqrts[i] = Cov_sqrt.clone()\n",
    "    j = rowcol2idx(i, i, pd_dim)  # presumably the vec index of entry (i, i)\n",
    "    Cov_sqrts[i,j,j] *= r\n",
    "\n",
    "# Each Cov_sqrt is diagonal (hence symmetric), so Cov = Cov_sqrt @ Cov_sqrt; invert the whole batch at once.\n",
    "CovInvs = torch.inverse(torch.matmul(Cov_sqrts, Cov_sqrts))\n",
    "\n",
    "m = torch.zeros(Nmix, pd_dim, dtype=torch.float32)\n",
    "for i in range(Nmix):\n",
    "    m[i,i] = dist/math.sqrt(2)\n",
    "Means = torch.diag_embed(torch.exp(m))  # diagonal means diag(exp(m_i))\n",
    "\n",
    "print(Means)\n",
    "print(squared_distance(Means[0:1].cuda(), Means[1:2].cuda()))\n",
    "\n",
    "Pndataset = PndataTangentGaussianMixture(N, Means, Cov_sqrts)\n",
    "#torch.save(Pndataset, 'P'+str(pd_dim)+'TangentGaussianMixture'+data_date+'.pth')\n",
    "#Pndataset = torch.load('P'+str(pd_dim)+'TangentGaussianMixture'+data_date+'.pth')\n",
    "\n",
    "N_testscore = 10000\n",
    "Pndataset_testscore = PndataTangentGaussianMixture(N_testscore, Means, Cov_sqrts)\n",
    "#torch.save(Pndataset_testscore, 'P'+str(pd_dim)+'TangentGaussianMixtureForTest'+data_date+'.pth')\n",
    "#Pndataset_testscore = torch.load('P'+str(pd_dim)+'TangentGaussianMixtureForTest'+data_date+'.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.3333, 0.3333, 0.3333], device='cuda:2')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(0, device='cuda:2'),\n",
       " tensor(0, device='cuda:2'),\n",
       " tensor(0, device='cuda:2'),\n",
       " tensor(0, device='cuda:2'))"
      ]
     },
     "execution_count": 117,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = vec2mat(Pndataset.train_data.cuda())\n",
    "weights = torch.full((Nmix,), 1/Nmix, dtype=torch.float32, device='cuda')  # uniform mixture weights\n",
    "print(weights)\n",
    "X_gpu, Means_gpu, CovInvs_gpu = X.cuda(), Means.cuda(), CovInvs.cuda()\n",
    "logp = log_rho_tangentGaussianMixture(X_gpu, weights, Means_gpu, CovInvs_gpu)\n",
    "score_true_train, _ = geometricScore_tangentGaussianMixture(X_gpu, weights, Means_gpu, CovInvs_gpu)\n",
    "\n",
    "# Counts of inf/nan entries in logp and the score; any nonzero count fails the cell.\n",
    "nonfinite_counts = (logp.isinf().sum(), logp.isnan().sum(), score_true_train.isinf().sum(), score_true_train.isnan().sum())\n",
    "assert all(int(c) == 0 for c in nonfinite_counts), f'non-finite logp/score entries (inf, nan, inf, nan): {nonfinite_counts}'\n",
    "nonfinite_counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(9.4716) tensor(9.8082)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAhYElEQVR4nO3df3TcdZ3v8ed7fiSdNNq0TZeSFKh6vXX5USkGYW3BsxTtYZFSOVpZFll/3eJRoYi2tK5bStcVpKwCrrj0FveuB/wRsMRWLhYo6lLPEdsSCL/s5ShKm7TQXym0SZvJzOf+MZlkkswkk+Q7M59JX49zODQz3/nMO5PkNd/5fD8/zDmHiIj4K1TqAkREZGgKahERzymoRUQ8p6AWEfGcglpExHORQjRaW1vrZs6cWYimRUTGpR07dux3zk3Ldl9BgnrmzJls3769EE2LiIxLZvaXXPep60NExHMKahERzymoRUQ8p6AWEfGcglpExHMFGfVRKE3NrazdvJO29k7qamIsWzCLRXPqS12WiEhBlU1QNzW3snLD83TGEwC0tneycsPzAAprERnXyqbrY+3mnb0hndYZT7B2884SVSQiUhxlE9Rt7Z0jul1EZLwom6Cuq4mN6HYRkfGibIJ62YJZxKLhfrfFomGWLZhVoopERIqjbC4mpi8YatSHiJxoyiaoIRXWCmYROdGUTdeHiMiJSkEtIuI5BbWIiOcU1CIinlNQi4h4TkEtIuI5BbWIiOcU1CIinlNQi4h4TkEtIuI5BbWIiOcU1CIiniuLRZm0V6KInMi8D2rtlSgiJzrvuz4G7pW4MLSVx+2LLPz5GfCdM6GlsYTViYgUnvdn1Jl7Ii4MbeW26HqqrCt1w+FdsOn61L9nLy5BdSIihed9UP9j9e/5XNf91Nl+koSIWLL/AfFO2LJGQS0i45bfQd3SyD8nvkc4FAcgRDL7cYd3F7EoEZHi8ruP+tGbCLv48MdNmlH4WkRESsTvoO48OPwhVLLtXdcVoRgRkdLwO6hzcA6SztidrOWmrs9yzbbTaGpuLXVZIiIF4XcfdWxK1rPqg66a93Wt67shmWDt5p0aVy0i45LXZ9Tb/noFXYT73Xbchbml+5pBx2YO4xMRGU+8Deqm5lau2XYaX+26lt3JWpLOOOiq6bYod0Xv4dXKq3i18iqeqVzCwtBW6mpipS5ZRKQgvO36WLt5Jx9K/IblkUbqbD+HXDWTrIMISbC+46ZwhDui63ju9JnARaUqV0SkYPI6ozazL5vZi2b2gpn92MwmFLqwhjcf57boemaE9hMymBo6MniyS48K6+bcP3630CWJiJTEsEFtZvXA9UCDc+5MIAxcWejCVlY82DdVPB+a9CIi41S+fdQRIGZmEaAKaCtcSSknsX9kD9CkFxEZp4YNaudcK3AH8BqwBzjsnHus0IXZCII3YVGYv6qA1YiIlE4+XR+TgcuBdwB1wEQzuzrLcUvMbLuZbd+3b9/YK5u/CqK5R3I4l/rvoKvmpsS1NCXmjv05RUQ8lE/Xx8XAq865fc65OLAB+MDAg5xz65xzDc65hmnTpo29stmL4bK7YdIpWe82Sz8vdHUnWbt559ifU0TEQ/kE9WvA+WZWZWYGzAdeLmxZPWYvhi+/QL/xeBmsZzTI2ui9NLz5eFFKEhEptnz6qJ8GHgKeAZ7vecy6IR8UtGH6qystwcqKB4tUjIhIceU14cU5dzNwc4FryaqpuZXEm2fwUbeLUPYTa2AUo0RERMqEt1PIIRXSWx++h0uSTw4Z0gCvU8s7VjzC3Nue1Ep6IjKueDuFHFLTyH/KT4ad+OIcbE68F4d2KReR8cfrM+q29k7qbPguDTOYH3q29+vOeEKjQERk3PA6qOtqYrS52vyOtQP9vtaypyIyXngd1MsWzOJOrqTDVQx7bDsT+32tZU9FZLzwOqgXzalnwjlX8lDiQrpdCOdyH5t5XywaZtmCWYUvUESk
CLy+mAgQeeEhPhH+Vc4lTtMm21EA6mtiLFswSxcSRWTc8Dqom5pbua5rPZWhxLDHJjE+Vf17Vq+4pQiViYgUj7ddH03NrSx76Dmm2JG8jo9Ykq+7/4CWxgJXJiJSXN4G9drNO4knhuiUziKSOAZb1hSoIhGR0vA2qNvaO1kY2jryB2qnFxEZZ7zto54Ui7I80di7nGn+D9ROLyIyvnh7Rm1GXrMS+wlXaKcXERl3vA3q9o543rMSe1VUp9awFhEZR7wN6rqaGFuSZ5McyfXEzoOwuga+c6ZGf4jIuOFtUN95+issDv/3sMubDubg8C7YdL3CWkTGBW+D+tw/fpfYMMubDineqaF6IjIueBvUgQyz01A9ERkH/A3q2OSxt6GheiIyDvgb1CM0aGW9aExD9URkXPA2qF3noREdbwbdLkTSGXuZBpfdraF6IjIueDsz8XVqmc6+ET0mhOOdxx/AgFdnX1qYwkREiszbM+pbuz6e184umdK7vGh3FxEZT7wN6u1v/xAr4p9jd7IW5yDhbMgdXiDVT63dXURkvPE2qJctmMWm5Dzmdd3N0vgX2OOmDvuYyaGj3HrFWdrdRUTGFW+DetGcehywMLSVO6LrmBHaP+xKem8wwrVBRETKgLcXEwFCBrdH11Fh3cMe6xz8v+RfsezB5wB0Vi0i44a3Z9S0NNJS8RkqGT6kITU8b17oRS7hKVZvfLHAxYmIFI+fQd3SCJuup9qOjWjjgJDB8kgj7Z3xwtUmIlJkfgb1ljWpRZVGoc4OBFyMiEhp+RfULY2pZUpHKYnxp8qrtCa1iIwbfl1M7OnyGC3nIGLJ1BfpNalBU8lFpKz5dUY9hi4PYHB/ttakFpFxwK+gLsT60VqTWkTKnF9BXYj1o7UmtYiUOb+Cev4qYMSbJA7B4N0fDrA9EZHi8yuoZy+Ghs8E2KCD536k0R8iUtbyCmozqzGzh8zsD2b2spn9TcEq+si3oeGzw66UlzddUBSRMpfv8Ly7gF865z5mZhVAVQFrglceG9GMxGHpgqKIlLFhg9rMJgEXAp8CcM51AV0FrSroYNUFRREpY/l0fbwD2Af8p5k1m9l6M5s48CAzW2Jm281s+759I9tCa5Agg1Wb3IpImcsnqCPAOcD3nXNzgKPAioEHOefWOecanHMN06ZNG1tV81eNuY/aOSA2RZvcikjZyyeodwO7nXNP93z9EKngLpzZi3kqecaYwtoMqJiokBaRsjdsUDvn9gK7zCy9EeF84KVCFtX4g3/jnfb62BvSRUQRGQfyHfVxHfBAz4iPPwGfLlhFv7iRj/3lPkJBjPDWRUQRGQfyCmrn3LNAQ2FLITUxZfsPCAUwNK+TSmK6iCgi44BfMxO3rAHGPtMl4eCmrs/SlJg79ppERErMr/WoA+pTDgEbk/N4VBvdisg44NcZdcB9yvGkY+3mnYG2KSJSbH4F9fxVEIqOuZmjVPb+u6199BsRiIj4wK+gnr2Y4+FBkx5HrIIEC0NbAairiY25PRGRUvKrjxqIxg+PuY0K6+bb0f8gnDA+uOCLAVQlIlI6fp1RA4eSYz+jhtQmt2sr72NR+LeBtCciUireBXU4iEHUPSKJY1qLWkTKnndBPYkjwTaoaeQiUua8C2oLetq3ppGLSJnzLqiZv4pkUNtwYVqLWkTKnn9BPXsx9ycuDmTPRIfTMqciUva8C+qm5la2J/9nIG29zhg3MBAR8YB3Qb12805ujvxwzJvbOgePdb83mKJERErIu6BuePNxptjYR36YwYcjzwVQkYhIaXk3M3FlxYMENZL6JPYH1JKISOl4d0YdZLgGPtRPRKQEvAvq4MJVQ/NEZHzwLqiZv0pD80REMngV1E3NrcxpquGp5BljDuuEC6X2YBQRKXPeXExsam5l5Ybn6YwnOKPiL2MenhexJGy6PvWFzqxFpIx5c0a9dvNOOuOpBf+DGJ4HQLxTq+eJSNnzJqjb2jtZGNrKHdF1Yz6bzuS0ep6IlDlv
grquJsbq6A+psO5A232d2kDbExEpNm+CetmCWUwOeC3qhINbuz4eaJsiIsXmTVAvmlMfeJshYHJVReDtiogUkzdBDXCI6kDbM4Pl0Z8G2qaISLF5FdSr49cEMtklU1Xn3mAbFBEpMq+CesfbPxR8o7HJwbcpIlJEXgX1sgWzaA14lMbx7kSg7YmIFJtXQb1oTj173rc80O6PaPzN4BoTESkBr4IaoPWUjwTaXrubGGh7IiLF5l1Q//rBfy91CSIiXvEqqL/e9DxfjTQGOoW8hqPBNSYiUgJeBfWPn95FnQW7fdYepgbanohIsXkV1AnnSAZY0nEX5ltxLXEqIuUt71Q0s7CZNZvZLwpVTNiMMMnA2nMUaGy2iEgRjeT0dSnwcqEKAfj7804hGdge5FBJgjtPfyWw9kRESiGvoDazGcClwPpCFvONRWcRsuAGUZvBuS/fFlh7IiKlkO8Z9Z3AcsjdL2FmS8xsu5lt37dvXxC1BcJ1HqSpubXUZYiIjNqwQW1mHwHecM7tGOo459w651yDc65h2rRpo6umpTHVsRwkB1sfvkdhLSJlK58z6rnAQjP7M/AT4CIzu78g1WxZE+gYakh1f9zAT1i7eWewDYuIFMmwQe2cW+mcm+GcmwlcCTzpnLu6INUUaH/DOjtAW3tnQdoWESk0r8ZRd8SmF6TdNjeVuppYQdoWESm0EQW1c+7XzrlgV03K8Oix9wa+cUCHq+BOrmTZglnBNiwiUiRenVGfl9geaB+1c3CzW8K8j36hIHsyiogUgzdB3dTcGvg6HwBnJf8QeJsiIsXkTVCv3byTNhfs7i5mcHX4CQ3PE5Gy5k1Qt7V3cnv34sD7qEManiciZc6boK6ribExOS/AJZky2tbwPBEpY94E9bIFs4iGjfsTF+c8q35kYhUfnlHH7Jmn8OEZdTwysSqvtjU8T0TKmTdBvWhOPRMrItzc/Zms9z8ysYrVtVPYE43gzNgTjbC6dsqwYX3chTU8T0TKmjdBDXC4M57zvrsm13As1L/cY6EQd02uGbLNo8Q0PE9EyppXQT0pFs15395IeES3p03myJhqEhEpNa+C+kOJ3/BC5aez3je9OzGi2zMlNt44prpERErJm6DetvFebrXvUW3Hs85OXHqonQnJ/mNCJiSTLD3UPmS7ZnB58rEAKxURKa5IqQuA1KzEc3fcTmSI3V0uPdoBpPqq90bCTO9OsPRQe+/tQwlbIQb9iYgUhxdBvXbzTp5i+Onjlx7tyCuYB3IWCnAnRhGR4vKi66OtvTPw6eNpzkFo5gUFaVtEpBi8COq6mlhBpo9Dqo+avc8H37CISJF4EdTLFsziUbsg8O0Se3UeLFTLIiIF50VQp2clFqr7Q0SknHkR1JCalVio7g8RkXLmTVCnV887Tu7ZiaMWmxJ8myIiReJNUC9bMItYNEwludf7GJVwBVzyrWDbFBEpIi/GUQN9iyY1Bddm0kHo8u/B7MXBNSoiUmTenFEDLAr/liBnphjAljXQ0hhcoyIiReZVULNlTaAzCM2Aw7tg0/UKaxEpW34F9eFdhWk33pk6sxYRKUNeBXWykCtyHN5duLZFRArIm6Buam7FCjmIetKMwrUtIlJA3gT12s07C9d4NAbzVxWufRGRAvImqNvaOzlEdeDtJi0El92tIXoiUra8Ceq6mhibEucHP4XcOYW0iJQ1b4J62YJZXBx+Nus2XGPRlpwabIMiIkXmTVAvmlPPyXYg0Dadg/UVVwfapohIsXkT1BDs2a9z8FTyDLrP/FhgbYqIlIJXQb2+4mo6XEVg7T2U+CA/29FKU3NrYG2KiBSbV0F99qVLeNh9MJALimawOvpDOuOJwg79ExEpMK+CetGcej468YXALihO5ggLQ1tpa+8MpkERkRLwKqgBYp17AmvLDJZHGqmriQXWpohIsXkX1ImAS6q3/Ww9dgV850ytoCciZWnYVDSzU8zsV2b2kpm9aGZLC1qQSwbanhkYTsudikjZ
yuf0tRv4inPudOB84ItmdnqhCnrDphWqaS13KiJladigds7tcc490/Pvt4CXgfpCFbTrnGV0uQLuEKblTkWkzIwoEc1sJjAHeDrLfUuAJQCnnnrqyCtpaYQtazj38G7i4Qm4RHfg08kBLXcqImUn7yt3ZlYN/Ay4wTn35sD7nXPrnHMNzrmGadNG2H3R0pjqPz68C3BEk51jDukuFyFhA96HtNypiJShvILazKKkQvoB59yGwKvYsibVfxyQhDOee983CX/0+zDpFMBS/9dypyJShobt+jAzA+4DXnbOfbsgVQTcbxwyx7kLr019oWAWkTKXzxn1XOCTwEVm9mzPf38XaBVB9xs7tL6HiIwbw55RO+e2QiF3nSXVb7zp+kC7P77S+ByQmpYuIlLOCjgObgTS3RMPfx5cIpAmE86xcsPzgMJaRMqbP1PIZy+GgGclauU8ERkP/AlqCKSv2jn4g+s7g9bKeSJS7vwK6vmrUmOdx8AM3sbx3q+1cp6IlDs/+qjTevuqrx1TN0hdz96LBvztewq4doiISBH4dUYNqbCeUDOmJtpcau9FB9qKS0TKnl9n1Gmdh0b9UOfg9u6+SS7pC4on2siPpuZW1m7eSVt7J3U1MZYtmOXta1BOtYqUgp9BPWlGz7ofI3fQVbMxOa/fbb5dUBxpMDU1t/LsI+v4XNf91IUOcCw2napL1vSbdZnZ5qRYlKNd3cQTjoWhrSzvaKSu6QAdjw1+3Eht23gvpzyzlr9y+3jDprHrnGV9s0Bz6Vlwi8O7Uz/b+at6a2hqbmXlhufpjKeGZba2dwY3rDLjeTti07k9/gn+68j7T6g3A70Jjg/mgthJdoCGhga3ffv2UT9+28Z7OWPH16myrt7bnGPYhZqcg0NUszp+Tb+wrq+J8dsVF426noHy+eXftvFe6p+5neluP22ulu+FruK8yz8P0BtMC0NbuTnyQ6bYESBV+xOn3cjiz3yFpuZWbtn0Ioc64iwMbeW26Pp+r0d3eAKRy78LQMejq5jQsZc2N5Xbuxf3fu9DPa4pMZdbNr3IBcd+ldquzA6w16byra7F/KbybzGD9o54v+9v28Z7OXPH14lltNfhKlgR/xyPuAu41J7iaxUPchL7MAunxsTHpkDXEUj0PaaTSl445184d+G1zFnzWEYNqddqS/JsFkSeYzr7BwV73tILfWVMokrXCnBTNPU922jbH8DHQBz4JggQi4a59YqzSl6bDGZmO5xzDVnv8y2om5pb+fJPn+W/ov/KBaEXR7WKXvpbShDigcRFvNKwmm8sOmtU9WSrL/OXf2FoK/8auY9qOz7o2Mzau1yEt9wEJoeOcig5kYl2nErig76/Lhfhq/ElAL3hZWR/k0p/n5n3DfxxZnvcW8lKYhYnTDLr4x3QwQSqOEabq+X27sVsiXyQraHPMpkjg9rbnUyF6yfDTxDK8+flHLwVejsPx9/Px8P/PeSbcmaw9+o5W3aHd9PmpvKteN8bVH1NjMftC1Rl2X/zQLKaauuk0vrCK2GR1AJePWGd/tRwkttHwkKESWKTTskZ6PkG4ljCfDSPnXvbk7Rm+TQZ6InLEJ+WxsLHN75CK6ugnnvbk7zvzce5M3pP3n/0Q3EOfm+zOW/1U723NTW38rUNLXTE+4LqH847tV+YZ/6xpifQJwjxSvJkZllrvzn1Qa+bnS2Ag25/JG1n/ork+4YRdC3JnjeQEJAeDxTO8gaVfpOZyLGctWa7/YirpDIcIpLshBzHdLgKbo9+gbMvXcKiOfW9YTIwDBeGtqbeZEP7SRIi7JLssVrWdn+Ch7vn9jtuRUUjJzP4zD4zqGqqohw51k082feDSL8RAP0C7c7TX+HcP34XDu9md7L/J6w0A1697dLUF8N0S6Xb/sfq37M8+lOqOvf2HQeDl36IxvJbpXIE3WGZ329Bw7pAbzr5Kp+gbmnkrYe+RLUdDzSk0mG95/KfAHBj47Nk/M73/mHVhw5wmGomuE4qyb5x
wUhDTspDvj/XA8lqzOj9ZHHQVfOiO40PhF7u/YQCud/QnkqewTXxf8raLeUcHLcJtJ62iAl/foKT2d/7iQbo1z10e/diHg9fyEq3nqvCTxImSRIjCUSt75e7w1XwYOJC5oee7X3s+oqrWf31W7J2D6WDtikxt18X3cBaO6kkGZ7AxMThwd/opFPgyy/kfhGHeF5mL+73SaD3Tc/284ZNY/oV3yxMeA5TUzGUR1D/4kbc9vsKtvpT+o/knfY6dbafDiYQ41jv+ESFr+QjW6CP5M17uE8n2dobrosrn2s3mcf0Xt949CboPDjo+HbeRgeVvddXYhxjamhwl1fu79tgdXvugr5zZvbBAj0B/44Vj+DIfo2lYOE5TE1pheySGSqo/Rj10dKYV0g/MrGKuybXsDcSZnp3gqWH2rn0aEdeT2FGvz7vao6NreZxZCyv64kmWzCN5E0+n2MHHhP0c0YSx3Ab/lfqvizHT3JvUWNvgcEM2z/ousdwOmLTqRrqgBzrz7vDu5l325Okn255pLF/SEPfBtX5BnW+3Rm51sTPuL2gI5SG4cWEl45HV+UV0qtrp7AnGsGZsScaYXXtFB6ZOOSvRD86ax4siNdVyo+Re+3ifN4o0rcPDPEOV8E/vXUF2zbemzpLXV2T+n9LY99BOdb0aXNT+/X319n+7E+c79DdAVv8cXgXbFgCv7hx8LG51hnKuH3t5p39+s0hY+G3lsbc328AvAjqWJar8wPdNbmGY6H+5R4Lhbhrck2Bqjox6HWVsTCDbhci6YzdyVpWxD9HIuk485l/7h+Qm67vC68sa/p0EWGCO8afKq9ia8X1LAxtpZ3qXM+aXxBm3eLPwfYfDH58tnWGBuyxmms+RsObjw9+Q8j8fgPgRVAn3PBl7I2ER3T7QAXoih8Xxvq6ioRwvPP4A8zrupuNyXksjzQSY8Bw1XSXBaS6Hi67u28/09gUnHNMDR0hZDAjtJ87out4G0dzPKPrawtyn83m3OJvwOOz1ZRlj9VcC7ytrHhw8BtC5vcbAC/6qEM2/AJM07sT7IkOLnd6dzAbDZyo9LrKWKXX1knL3WWREZyzF/eF4HfOpHLARc0K6x76SdNtDRytkT6bhaFnOGcL8cyasli2YFbWYYMnkcf3O0ZenFEfi5087DFLD7UzIdk/0Cckkyw91F6gqk4Mel1lLI67cL+1dQDaXG32g3P1A48m0NJtZeveSJ/Nzl9Fzp74Uax9v2hOPbdecRb1NTGM1MShW684KzUGPqDnyMWLM+qqS9bQueGLxOjKeUx6FMJoRiekh+aNdqbjeDaW11XKR7eDY1QykbHPUUh3Ix501dzSfc2gCTW3dy/OPqwuo7+3n5Gu7ZPZ1lCjNWYvhtd+l+qTJqPvc6hahrFoTv3gER7hLHu+juE5svEiqJm9mBf+fIizdnwt50QTSIXKSAPEOfhh4mJu7v4Mr1ZeNeRxJ2qIj+Z1lcIb7ncy3xmjmYGaa6LNSMaBL41/gUfcBSSco74mxp0LZvWbobkxOQ/iPRN0QgcIDTfLL9vm1uGK1JMl4/2PjU2BS77V11aukE+fzX7k23Dq+YWdcZhuq4DP4c+EF+DsWx5je/LjRAIKzMyQBnhp8leyrv/QTYgfdV/ExZFnOdntx0HO9TXkxDWaN/ORTEzJfMxBl5oBmV6wa6D0AlPp8E0tMrUfszDOJdhDLbd19U0fn1wV5ebLzmDRnHoaf/BvfODP91BnB2hzU9mSPJsrw7/qv/4JqX7RzJId8MbU8znpus2D6hnztO9s451h+PDzYEZhUMpjZiKpH/bWh+9hrf37mELSOTjmwtzUfS0bk/OIRUPcesVsFoV/O+wPNf0Lt8L9bz4ZfmJwD5dl7/XKNXtsPIV9odcgGeo5Bz5vvqspxjEcoX4hNNTx2WYdOuhd1e+y8O9S08dz/R4AFqpInQlOmsG2d13HDS+9m4Y3H2dlxYOcxH6OuoqcXRBdLsI3o19KrSeS7fcVIDaF
bX+9ghteeveoZ8gNnGGXuUZIbyi+9jvY8X9SqyBaGN73qdQZap5tFm0hpRKv0RGUsglqSP2wExtv5KPJX454UabUH2aIr8Y/zxPhC/nmFbMH/6Lk8UMd8heupZHun19HJNE3s7HLKrnZXcvRru7ej3vHYtOpOuPviD/zIyKJzmGnHWf+GI5SSbQiRkW8nYQLDbuGxMDH5zPVONuPPVuNaemPz7FoiDV8v1/wdTsImRHK6AdMfyrJJelS9yfo//31FUNqqdR0OLQ0ppZz7dxLW3IqraE63k9L7ufI+Ii8beO9/I9n1lDjjgx6DY64SiZaF29YLS2x82jo+HXvOh6HqOaW+DX8PKMP1oB/OP9UvvHOl3t+jzI+ducRZtC3vvj18fVM7jljtgE19xonISTDK6ug7tXSyL6f/xO13W8MuivXmchX40t4xF3A3593SmDLmuaqbUR/PC2NHN+0jIp4O/R8rH2DSbzH+rYISy/WUz/UmUhLI/z8S5DoG6PqgNeZxq1dH2f72z+U/cxo9uJha9628V7eteNfmMxbwOALRTWxKKsXntG7LnVq84D9vGG12TcPaGmETTdAvG8sbDxcRTjRSVtyKnckPkFTYi5G32WezI/neRnhzyGvugc4EZfblNIoz6DOkH25xT2pMxiXSA1OL7MzDR8DwMeaRE4UZR/UIiLj3VBB7cWEFxERyU1BLSLiOQW1iIjnFNQiIp5TUIuIeK4goz7MbB/wl8Ab9lct5Frr8ISn12Zoen1yO9Fem9Occ9Oy3VGQoD7RmNn2XMNqTnR6bYam1yc3vTZ91PUhIuI5BbWIiOcU1MFYV+oCPKbXZmh6fXLTa9NDfdQiIp7TGbWIiOcU1CIinlNQj5KZnWJmvzKzl8zsRTNbWuqafGNmYTNrNrNflLoW35hZjZk9ZGZ/MLOXzexvSl2TL8zsyz1/Uy+Y2Y/NbEKpayo1BfXodQNfcc6dDpwPfNHMTi9xTb5ZCrxc6iI8dRfwS+fce4D3otcJADOrB64HGpxzZwJh4MrSVlV6CupRcs7tcc490/Pvt0j9oWmV/R5mNgO4FFhf6lp8Y2aTgAuB+wCcc13OufaSFuWXCBAzswhQBbSVuJ6SU1AHwMxmAnOAp0tcik/uBJZDtg0RT3jvAPYB/9nTNbTezCaWuigfOOdagTuA14A9wGHn3GOlrar0FNRjZGbVwM+AG5xzb5a6Hh+Y2UeAN5xzO0pdi6ciwDnA951zc4CjwIrSluQHM5sMXE7qzawOmGhmV5e2qtJTUI+BmUVJhfQDzrkNpa7HI3OBhWb2Z+AnwEVmdn9pS/LKbmC3cy79CewhUsEtcDHwqnNun3MuDmwAPlDimkpOQT1KZmak+hhfds59u9T1+MQ5t9I5N8M5N5PUhaAnnXMn/FlRmnNuL7DLzGb13DQfeKmEJfnkNeB8M6vq+Rubjy60Eil1AWVsLvBJ4Hkze7bntq855/5v6UqSMnId8ICZVQB/Aj5d4nq84Jx72sweAp4hNbKqGU0l1xRyERHfqetDRMRzCmoREc8pqEVEPKegFhHxnIJaRMRzCmoREc8pqEVEPPf/AV/eKo+CP86xAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Scatter two coordinates of the vectorized train / score-test samples together with the\n",
    "# (0,0) and (1,1) entries of the mixture means. NOTE(review): assumes vector column pd_dim\n",
    "# is the (1,1) matrix entry under rowcol2idx ordering -- confirm.\n",
    "train_vec = Pndataset.train_data\n",
    "test_vec = Pndataset_testscore.train_data\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(train_vec[:,0], train_vec[:,pd_dim], label='train')\n",
    "ax.scatter(test_vec[:,0], test_vec[:,pd_dim], label='score test')\n",
    "ax.scatter(Means[:,0,0], Means[:,1,1], label='means')\n",
    "ax.set(xlabel='coordinate 0', ylabel='coordinate '+str(pd_dim), title='P'+str(pd_dim)+' tangent Gaussian mixture samples')\n",
    "ax.legend()\n",
    "print(torch.max(torch.abs(train_vec)), torch.max(torch.abs(test_vec)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the train and score-test datasets; file names encode pd_dim and data_date.\n",
    "torch.save(Pndataset, f'P{pd_dim}TangentGaussianMixture{data_date}.pth')\n",
    "torch.save(Pndataset_testscore, f'P{pd_dim}TangentGaussianMixtureForTest{data_date}.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Four mixtures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[4.8605, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 1.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 1.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 1.0000]],\n",
      "\n",
      "        [[1.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 4.8605, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 1.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 1.0000]],\n",
      "\n",
      "        [[1.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 1.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 4.8605, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 1.0000]],\n",
      "\n",
      "        [[1.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 1.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 1.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 4.8605]]])\n",
      "tensor([5.], device='cuda:2')\n"
     ]
    }
   ],
   "source": [
    "# Build an Nmix-component tangent Gaussian mixture on pd_dim x pd_dim matrices.\n",
    "# Component i: diagonal mean with exp(dist/sqrt(2)) at entry (i,i) and 1 elsewhere;\n",
    "# diagonal covariance square root sqrt(var)*I with the (i,i) coordinate scaled by r.\n",
    "N = 10000\n",
    "Nmix = 4\n",
    "\n",
    "pd_dim = 4\n",
    "vec_dim = int(pd_dim*(pd_dim+1) / 2)  # free entries of a symmetric pd_dim x pd_dim matrix\n",
    "# Each component perturbs its own diagonal entry (i,i), so at most pd_dim components fit.\n",
    "assert Nmix <= pd_dim, 'Nmix must not exceed pd_dim'\n",
    "\n",
    "dist = math.sqrt(2*vec_dim)/2\n",
    "\n",
    "data_date = '210912m'+str(Nmix)\n",
    "\n",
    "r = math.sqrt(10)  # std condition number\n",
    "var = 0.01\n",
    "Cov_sqrt = math.sqrt(var)*torch.eye(vec_dim)\n",
    "Cov_sqrts = Cov_sqrt.repeat(Nmix, 1, 1)  # independent copy per component\n",
    "for i in range(Nmix):\n",
    "    j = rowcol2idx(i, i, pd_dim)\n",
    "    Cov_sqrts[i,j,j] *= r\n",
    "\n",
    "# Inverse covariances (Cov_sqrt @ Cov_sqrt)^-1, batched over the Nmix components.\n",
    "CovInvs = torch.inverse(torch.bmm(Cov_sqrts, Cov_sqrts))\n",
    "\n",
    "m = torch.zeros(Nmix, pd_dim)  # log of the mean diagonals\n",
    "for i in range(Nmix):\n",
    "    m[i,i] = dist/math.sqrt(2)\n",
    "Means = torch.diag_embed(torch.exp(m))\n",
    "\n",
    "print(Means)\n",
    "print(squared_distance(Means[0:1].cuda(), Means[1:2].cuda()))\n",
    "\n",
    "Pndataset = PndataTangentGaussianMixture(N, Means, Cov_sqrts)\n",
    "# To reuse a previously saved set instead of rebuilding it:\n",
    "#Pndataset = torch.load('P'+str(pd_dim)+'TangentGaussianMixture'+data_date+'.pth')\n",
    "\n",
    "N_testscore = 10000\n",
    "Pndataset_testscore = PndataTangentGaussianMixture(N_testscore, Means, Cov_sqrts)\n",
    "#Pndataset_testscore = torch.load('P'+str(pd_dim)+'TangentGaussianMixtureForTest'+data_date+'.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.2500, 0.2500, 0.2500, 0.2500], device='cuda:2')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(0, device='cuda:2'),\n",
       " tensor(0, device='cuda:2'),\n",
       " tensor(0, device='cuda:2'),\n",
       " tensor(0, device='cuda:2'))"
      ]
     },
     "execution_count": 121,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check on the training samples: the displayed tuple counts inf/nan entries of the\n",
    "# log_rho / geometricScore outputs and is expected to be all zeros.\n",
    "X = vec2mat(Pndataset.train_data.cuda()).cuda()  # second .cuda() is a no-op if vec2mat keeps the device\n",
    "weights = torch.full((Nmix,), 1.0/Nmix, dtype=torch.float32, device=X.device)  # uniform mixture weights\n",
    "print(weights)\n",
    "Means_gpu, CovInvs_gpu = Means.cuda(), CovInvs.cuda()\n",
    "logp = log_rho_tangentGaussianMixture(X, weights, Means_gpu, CovInvs_gpu)\n",
    "score_true_train, _ = geometricScore_tangentGaussianMixture(X, weights, Means_gpu, CovInvs_gpu)\n",
    "logp.isinf().sum(), logp.isnan().sum(), score_true_train.isinf().sum(), score_true_train.isnan().sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(21.2802) tensor(19.9224)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAh4UlEQVR4nO3df5RU9Znn8fdT3dVQgKH50RG6wZBkPOyYqKA9JhnZWRNU1CTIuJHoOInZjYOZmImZzQExySAhydHAJBmNs3GIump+OHYm2CHBBDmYnUg2P2wEQaOMxtWFboRGaBRo6B/17B91u6nuruqun1T1rc/rnD5V9b3fW/eponjq1vd+73PN3RERkfCKlDoAEREpLiV6EZGQU6IXEQk5JXoRkZBTohcRCbnqUgeQytSpU33WrFmlDkNEZNTYunXrAXevS7WsLBP9rFmzaGlpKXUYIiKjhpm9mm6Zhm5EREJOiV5EJOSU6EVEQm7ERG9mM83sl2b2BzN7zsxuDtonm9kmM3sxuJ2UZv3rgz4vmtn1hX4BIiIyvEz26HuAz7v7WcB7gZvM7CxgObDZ3c8ENgePBzCzycBtwHuAC4Db0n0hiIhIcYw468bd9wJ7g/tvmtnzQANwJXBR0O1B4H8DtwxafQGwyd0PApjZJuAy4OECxJ5S87ZW1mzcRVtHJ/W1MZYumM2iuQ3F2pyISNnLanqlmc0C5gK/A04PvgQAXgNOT7FKA7A76fGeoC3Vcy8BlgCcccYZ2YTVr3lbK7eu20lndy8ArR2d3LpuJ4CSvYhUrIwPxprZBODHwOfc/Y3kZZ6odZxXvWN3X+vuje7eWFeXcs7/iNZs3NWf5Pt0dveyZuOufEITERnVMkr0ZhYlkeR/4O7rguZ9ZjY9WD4d2J9i1VZgZtLjGUFbUbR1dGbVLiJSCTKZdWPAfcDz7v7NpEXrgb5ZNNcDP0mx+kbgUjObFByEvTRoK4r62lhW7SIilSCTPfoLgY8BHzCz7cHfFcAdwCVm9iJwcfAYM2s0s3sBgoOwXwGeCv5W9R2YLYalC2YTi1YNaItFq1i6YHaxNikiUvasHC8l2NjY6LnWutGsGxGpRGa21d0bUy0ry6Jm+Vg0t0GJXUQkiUogiIiEnBK9iEjIKdGLiIScEr2ISMgp0YuIhJwSvYhIyCnRi4iEnBK9iEjIKdGLiIScEr2ISMgp0YuIhJwSvYhIyCnRi4iEXOiqV/ZRuWIRkYRQJnpdJFxE5KRQDt3oIuEiIieNuEdvZvcDHwL2u/u7g7ZHgL7r89UCHe4+J8W6rwBvAr1AT7qrnxSaLhIuInJSJkM3DwB3Aw/1Nbj7R/vum9k3gMPDrP9+dz+Qa4C5qK+N0Zoiqesi4SJSiUYcunH3XwEpL+htZgYsBh4ucFx50UXCRUROyneM/j8D+9z9xTTLHXjczLaa2ZLhnsjMlphZi5m1tLe35xXUorkN3H7V2TTUxjCgoTbG7VedrQOxIlKR8p11cy3D783Pc/dWM3srsMnMXgh+IQzh7muBtQCNjY2eZ1y6SLiISCDnPXozqwauAh5J18fdW4Pb/cCjwAW5bk9ERHKTz9DNxcAL7r4n1UIzG29mp/XdBy4Fns1jeyIikoMRE72ZPQz8BphtZnvM7JPBomsYNGxjZvVm9ljw8HRgi5k9A/we2ODuvyhc6CIikokRx+jd/do07Z9I0dYGXBHcfxk4N8/4REQkT6E8M1ZERE5SohcRCbnwJfodTfCtd8PK2sTtjqZSRyQiUlLhql65owl++lnoDsofHN6deAxwzuLSxSUiUkLh2qPfvOpkku/T3ZloFxGpUOFK9IdTTulP3y4iUgHClegnzsiuXUSkAoQr0c9fAdFBpYijsUS7iEiFCleiP2cxfPgumDgTsMTth+/SgVgRqWjhmnUDiaSuxC4i0i9ce/QiIjKEEr2ISMgp0YuIhFz4xugHad7WypqNu2jr6KS+NsbSBbN1
5SkRqSihTvTN21q5dd1OOrt7AWjt6OTWdTsBlOxFpGKEeuhmzcZd/Um+T2d3L2s27ipRRCIip15oE33ztlZaOzpTLmtL0y4iEkaZXErwfjPbb2bPJrWtNLNWM9se/F2RZt3LzGyXmb1kZssLGfhw+oZs0qmvjaVdJiISNpns0T8AXJai/VvuPif4e2zwQjOrAv4ZuBw4C7jWzM7KJ9hMpRqy6ROLVrF0wexTEYaISFkYMdG7+6+Agzk89wXAS+7+srt3Af8KXJnD82RtuKGZ2686WwdiRaSi5DNG/xkz2xEM7UxKsbwB2J30eE/QlpKZLTGzFjNraW9vzyOs9EMzDbUxJXkRqTi5JvrvAO8E5gB7gW/kG4i7r3X3RndvrKury+u5li6YTSxaNaBNQzYiUqlymkfv7vv67pvZd4GfpejWCsxMejwjaCu6vr12nSglIpJjojez6e6+N3j4l8CzKbo9BZxpZm8nkeCvAf4qpyhzsGhugxK7iAgZJHozexi4CJhqZnuA24CLzGwO4MArwI1B33rgXne/wt17zOwzwEagCrjf3Z8rxotIR+UPRETA3L3UMQzR2NjoLS0teT1H31z6S3r/nWXVTdTbAfYylbbzl/FnC28sUKQiIuXBzLa6e2OqZaGtdbNm4y4u6f137ojeyzjrAqCBA0x++h9g1iRdnEREKkZoSyC0dXSyrLqpP8n3iXECNq8qUVQiIqdeaBN9fW2MejuQeuHhPac2GBGREgptov+ns14knu7lTZxxaoMRESmhcCb6HU3MfWYF1RYfsqinaizMX1GCoERESiOciX7zKqp7jw9p7vEIX7VP6UCsiFSUcCb6w7tTNkeI8+CRC05xMCIipRWqRN+8rZUL73iCHk/9suJEVIteRCpOaObRJ18ftmrM0LF5gCriKmwmIhUnNHv0yRcbafWpKfukaxcRCbPQJPrki42s7lnMMa8ZsPyY17C6Z7EuDC4iFSc0Qzf1tbH+i4Gvj8+DboIaN6/T5lNY3bOY9fF5mC4MLiIVJjSJfumC2fz9I9vpK9G2Pj6P9V3zhvTTwVgRqTShGbpZNLeBTOpwHj3RQ/O2U3L9ExGRshCaPXqAKjN6B5VdXhjZ0l+muM2nsvrEYm5dl5iVo9r0IlIJQpXoe90HJPZDPoHT7Dg11gPADDvAHdF7oRvWbKxRoheRihCaoRuAT0z4PXdE72VG5AARgymRI/1Jvs8462JZddOAWToiImE2YqI3s/vNbL+ZPZvUtsbMXjCzHWb2qJnVpln3FTPbaWbbzSy/S0ZlYFn0kSH151Opt9d1UFZEKkYme/QPAJcNatsEvNvdzwH+A7h1mPXf7+5z0l3iqpDGdb6WUb+9TNEZsiJSMUZM9O7+K+DgoLbH3b1vTOS3QHkUeI9NGrGLA8dnXazxeRGpGIUYo//vwM/TLHPgcTPbamZLhnsSM1tiZi1m1tLe3l6AsNJsB3hnx6+L9vwiIuUmr0RvZl8EeoAfpOkyz93PAy4HbjKzv0j3XO6+1t0b3b2xrq4ut4A6D2XWT5cSFJEKknOiN7NPAB8CrnP3lOcquXtrcLsfeBQobjH4TC8RmMEQj4hIWOSU6M3sMmAZsNDdj6XpM97MTuu7D1wKPJuqb8HMX0HqAsWDdB2BHU1FDUVEpFxkMr3yYeA3wGwz22NmnwTuBk4DNgVTJ+8J+tab2WPBqqcDW8zsGeD3wAZ3/0VRXkWgufdCOnzCyB17u2DzqmKGIiJSNkY8M9bdr03RfF+avm3AFcH9l4Fz84ouS2s27mIhRzLrrHF6EakQoToztq2jk0OZ7NFD5uP5IiKjXKgSfX1tDLOR+/VUjYX5K4ofkIhIGQhVol+6YDa1IwzduMPy7hto7r3wFEUlIlJaoUr0AB0MP3QTx/i3rj/XJQVFpGKEKtFv37CW8QxflTJCopRxq6pXikiFCFWiv6Hr+4yx3mH7mMFt1Q9RlclgvohICIQq0ddHXs+o32Q7MuRKVCIi
YRWqRH88Ni3jvtqjF5FKEapEP+5dV5DJjvohJmiPXkQqRqgS/bHnHhtxHr07rOz+OLWx6KkJSkSkxEKT6Ju3tTL22N6M+x/t6qF5W2sRIxIRKQ+hSfRrNu4insHLMYNl1U1097rm0otIRQhNom/r6KQqsyLF1Nvr/euIiIRdaBJ9fW2MgxkWNOtgfP86IiJhF5pEv3TB7IwKmgGMp5MrI1tYumB2cYMSESkDoUn0i+Y2jFjQrM8Y62VpdROL5jYUOSoRkdLLKNGb2f1mtt/Mnk1qm2xmm8zsxeA25YVYzez6oM+LZnZ9oQJPZaSCZskyPYtWRGS0y3SP/gHgskFty4HN7n4msDl4PICZTQZuA95D4sLgt6X7QsjbjiYmWOYHV7M5i1ZEZDTLKNG7+6+Ag4OarwQeDO4/CCxKseoCYJO7H3T3Q8Amhn5hFMbmVdQwfEGzPj1VYxl3ua4ZKyKVIZ8x+tPdve8MpddIXAx8sAZgd9LjPUFb4WV4DVh3qL7y23DO4qKEISJSbgpyMNbdHcireIyZLTGzFjNraW9vz/4JYpmNCB1iAl96+U+zf34RkVEqn0S/z8ymAwS3+1P0aQVmJj2eEbQN4e5r3b3R3Rvr6uryCGt4Ue/mh7/7f0V7fhGRcpNPol8P9M2iuR74SYo+G4FLzWxScBD20qCt8DoPZdRtgp0g7qjOjYhUjEynVz4M/AaYbWZ7zOyTwB3AJWb2InBx8BgzazSzewHc/SDwFeCp4G9V0FZ4E2dk1f3WdTuV7EWkIpiXYV32xsZGb2lpyW6lHU34j/8mozLFN3d/mvXxeTTUxvj18g/kHqiISJkws63u3phqWWjOjOWcxRnVujGDb0bvYWFki4qaiUhFCE+iB77c8/GMrjBVbXGWVTepqJmIVIRQJfr18XkZ962311XUTEQqQmgSffO2Vr5cfX/G/V+zKSpqJiIVobrUARRC87ZWbl23k51VT2RUqtgdOmddXPzARETKQCj26Nds3EVnd2/GV5gyg7ftaYYdTcUNTESkDIQi0ffNnunN4uVU9x6HzSpsJiLhF4pE3zd75v/E/zSjWTf9MiyEJiIymoUi0S9dMJtYtIoLIi9kfDlBIOuzaUVERqNQHIztmz0zpjmzevQAPRalev6KYoUkIlI2QrFHD2Q9VfKN+BjVpBeRihCaRJ9tgbJajhYpEhGR8hKaRL9m4y56yXyA/pCPL2I0IiLlIzSJvvGNTUSyuMhVVgdtRURGsVAcjAW4paYpq28tDd2ISKUIzR79NF7Pqn+bTylSJCIi5SU0ib4tnnnidofVPZpxIyKVIedEb2azzWx70t8bZva5QX0uMrPDSX2KNnH93pq/zu6sWBGRCpFzonf3Xe4+x93nAOcDx4BHU3R9sq+fuxetuMycDy7JuK8Z3BJVQTMRqQyFGrqZD/zR3V8t0PMVXT0HSh2CiMgpUahEfw3wcJpl7zOzZ8zs52b2rnRPYGZLzKzFzFra29uzDmDNxl3ZrWCoTLGIVIS8E72Z1QALgR+lWPw08DZ3Pxf4NtCc7nncfa27N7p7Y11dXdZxtHV0ZjGLPpHnVaZYRCpBIfboLweedvd9gxe4+xvufiS4/xgQNbOpBdjmEPW1sSzOiw0c3l2MUEREykohEv21pBm2MbNpZolzUM3sgmB72U14z9D/qvpa9itZVeEDEREpM3mdGWtm44FLgBuT2j4F4O73AB8B/tbMeoBO4Br34kyCPPNoS/ZlDTzzssYiIqNVXone3Y8CUwa13ZN0/27g7ny2kXkwkPXYjfboRaQChObM2Fy49uhFpAKEJtE/GX9X1mfGtsanZl3HXkRktAlNov949xez6n/Ma1jdszj7+fciIqNMaMoUf7n6/oz7ukNL/EzWx+dhHZ1FjEpEpPRCs0d/XdUTGc+6MYMLI38AEvPvRUTCLDSJvop4Vv0jOLFoFUsXzC5SRCIi5SE0iT67NJ/wX89vYNHchoLHIiJSTkKT6LuJZr3O
L1/IvniaiMhoE5pEP4burPrHMdp0IFZEKkBoEn22IjgTY9n/ChARGW1Ck+hzKaCTdW0cEZFRKDSJPpec3XEsu+EeEZHRKDSJPpdMf/2E3xc+DhGRMhOeRJ8lM1gWfaTUYYiIFF14En0Og/TjOl8rfBwiImUmPIk+l0H6iTMKHoaISLkJT6LPkgPMX1HqMEREii7vRG9mr5jZTjPbbmYtKZabmd1lZi+Z2Q4zOy/fbRaCA5yzuNRhiIgUXaH26N/v7nPcvTHFssuBM4O/JcB3CrTNvJiji46ISEU4FUM3VwIPecJvgVozm17wrQw6GLth/DgunVHPObNmcumMejaMHzdkle0b1hY8DBGRclOIRO/A42a21cyWpFjeAOxOerwnaBvAzJaYWYuZtbS351dsbMP4caycOpm90WrcjL3RalZOnTwg2ZvBDV3fz2s7IiKjQSES/Tx3P4/EEM1NZvYXuTyJu69190Z3b6yrq8t+/aT7d06q5Xhk4Es7Holw56TaAW31kddziFREZHTJO9G7e2twux94FLhgUJdWYGbS4xlBW0Elz658rboqZZ/B7R0+vtBhiIiUnbwSvZmNN7PT+u4DlwLPDuq2Hvh4MPvmvcBhd9+bz3ZHMq2nN6t2EZEwy3eP/nRgi5k9A/we2ODuvzCzT5nZp4I+jwEvAy8B3wU+nec2R3TzoQ7Gxgdec2psPM7NhzoGtNVytNihiIiUXHU+K7v7y8C5KdrvSbrvwE35bCdbHzx6DEiM1b9WXcW0nl5uPtTR396nzaegc2NFJOzySvTlxBk4Tv/Bo8eGJPYB/R1W9yzmrqJHJiJSWqEpgaBriIiIpBaORL+jKetMbwZfr7k3sa6ISIiFI9FvXpXTHn2MLti8quDhiIiUk3Ak+sO7R+6Thh/eU8BARETKTzgSvaU+QSoT+5hawEBERMpPKBK9e24nQrnD7V1XFzgaEZHyEopEH89xzk0caHnLJYUNRkSkzIQi0Uc8hwvGkpios3TB7MIGIyJSZkKR6HMWibBo7pCKySIioRKKRN9hE3JaL+LxkTuJiIxyoUj0L523gpxGb2KTCx6LiEi5CUWi/7OFN5Y6BBGRshWKRA/Q6jnMh+88VPhARETKTGgS/eqexdkP30xUkWIRCb/QJPpsuQPzV5Q6DBGRogtNol8ZfQjL9rypcxYXJRYRkXKSc6I3s5lm9ksz+4OZPWdmN6foc5GZHTaz7cFf0XahJ3Ekq/7dqmAvIhUinytM9QCfd/engwuEbzWzTe7+h0H9nnT3D+WxnaKIktvZtCIio03Oe/Tuvtfdnw7uvwk8D5TsNNN41uM2IiKVoSBj9GY2C5gL/C7F4veZ2TNm9nMze9cwz7HEzFrMrKW9vT3rGCJZ7qFrf15EKkXeid7MJgA/Bj7n7m8MWvw08DZ3Pxf4NtCc7nncfa27N7p7Y11dXdZx7CO7dQxo3taa9XZEREabvBK9mUVJJPkfuPu6wcvd/Q13PxLcfwyImllRrvRxe9fVWc+jX7NxVzFCEREpK/nMujHgPuB5d/9mmj7Tgn6Y2QXB9l7PdZvDmTSuJqv+B30CbR2dxQhFRKSs5DPr5kLgY8BOM9setH0BOAPA3e8BPgL8rZn1AJ3ANe45Fo8fwbLoI1hPZn3d4Tl/G/W1sWKEIiJSVnJO9O6+BYafjO7udwN357qNbIzrfC3jvmbw55HnddEREakIoTkzNtu6NVUW10VHRKQihCfRz8+uJr1ZVfFiEREpI+FJ9OcsHmEg6SQHOP8TRQxGRKR85HMwdtRxh14iPDXlSt73oZQThU6Z5m2trNm4i9aOTqrM6HWnoTbG0gWzsxtS2tEEm1fB4T2J4av5K1IXa8u0X9hU6usWSWJFmgSTl8bGRm9pacl+xa/VQ/fRtIv7Xmoc45VZH+Wd/+1fhn++dEliRxP8/BboPHiyb2wyXP71xP3Nq+DwbhwDHByOMoYuotTaUfb6FL7evZj18XkALIxsYWX0oZOF
2YLVAA4xgZXdHwdgeU0T0ziAu508E9gG/5Ax9k15Dz0HXmI6BzjkExhj3Yy3EwP6ucNRG0M3USb6EfZbHV4zgWld/zexbYPfcDbXHb+V+hRfQE33f4MLX/2fTOcAHUxgTHUV43sPg1WB98LEmf3vV9+XWltHJ7XjorjD4c7uAffra2O8/z/V8csX2mnr6By6zeR/i9ikRFvnoRG/3Hp+8ndU9x7vb+qpGkv1ld8eMdk3b2tl+4a13ND1feojr3M8No1xl686+e+vLw8pM2a21d0bUy4LS6J/av2/0Lh1Wcalih2wt/8XmPvXA/7T/rH2Qsa/upm3ejs2KIkmv1WpttO3PJMYBr/tw62TzfP29c+l9M/g9Ub6aIy0DXeIM3R8MA50E2UM3QPa27yWaXa4/0vMg20cYwzj/ETa7bnDCaoZQw+e/G+W5n1whw5OI+7OJDsy5IvKvZdDPoEJHKcmac7uCa+ix6KM4/jAL9dojD/WX9n/udnLVL7evZitb7mkf2ZX3xfdgC+w/i+M3f1fkMdi01nd/VEePHJByi9YkXRCn+ibt7WysPksIlkmt2DHdWBbjklS8jea3/vBsccdvtd7MV+NfxIMunsH/j+7duxv+UrV2gG/Nvoc8xqWd9/A+vg8YtEqbr/q7NS/bEb4NZH8S6okXxqZxHoKfx2V/P0ostAn+qUrvshqu3vAf7SvTq7lR285rX+P8uo33uRLBzsKHWpJbRg/jjsn1fJadRXTenq5+VAHHzx6rNRhSSDu8LnuT/cP0SXbUvNZZkQOpF13T3wq87ruAqDKjG8sPpdFVb9OPRQ19zp48fEBybK590JuXbeTzu7eAc9bG4uycuG7ip/gdjTBTz8L3Ulnn0dj8OG7TibyTPoUSPO21iHvx5Av0VEu9Il+z4p3DvhP89XJtTzyltOGjEN8NETJfsP4caycOpnjkZMDI2PjcVYeOKhkX0YO+gTOO7GWhZEtLKtuot4O0OZTabADw/56ibvxjhM/6H9swJNpvhyG/DKNxljpN/LAkQtSPnfWCS6Xve5vvTsxJDXYxJnw989m3qdALrzjCVpTlDxpqI3x6+UfKOi2clGIXxvDJfpQTK9ssIEf/h8NTvIAZon2kLhzUu2AJA9wPBLhzkm1pQlIUprEEb5cfT93RO9lRuQAEYMZkQMjlslu8ykDHjtQb6l/AQz5vuju5Iau76d97s7u3swL+vXtdR/enYji8O7E4x1Nw693eM/I7Zn0KZB0da3Kod5V36+N1o5OHGjt6OTWdTsLWl03FIl+sHiW7aPRa9WpT/hK1y6lYQbXVT3BOOsa0B6xxNBOKse8htU9Q/eY2zzzwq/1keFrB2ac4DavGji0AonHm1cNv166M9WT2zPpUyDp6lqVQ72rNRt3DRliy+rLOAOhTPTpXlSYXuy0nt6s2qV0qiz9Lsae+FTiDj0ewR2OxabzhZ6/STmuv7pnMcd8YJXWdF8Wx2PTiEXTf+lnnOBy3euevyIx3p4sGku0Z9OnQJYumD3k/YhFq8qi3tWp+LURitzng4Ywrn7jzaFzA90T7YPXLb9DFBm5+VAHY+MDE8jYeJybD3WUJiBJK125jTZPHHB9x4kf8icnvs8/zN3CuFte4KKP3JQySa+Pz2N59w3Bl4OxJz6Vh/0SeqrGDuwYjTHu8lXcftXZTBoXHfI8WSW4XPe6z1mcOKg6cSZgidvBB1kz6VMgi+Y2cPtVZ9NQG8NIjM2Xy4HYU/FrIxQHY/nZ/8Cfui+nWTfHPUoN3VlPzSwHlTbrJtvzCXJ9npGmebrDUcbSRTWT7CgWm0T3iU6q451Dx8uramDux+CZHw4YAun0Gm4JplCawXXvOYOvLjq7f3nymdNJ588l4g4e959JXfXrYQ+W5nWg7xTOjKlUhZoRFPpZNwCH15zHW478Mask0GtRnp77NTbsaOPvuu5lciRxZqpZBPc4cU98SaQ6iWhwWyYnDyU7FfPFs5mXns3rSte3T6Yn
XQ23rcGP++aWnx/5D66reoKqFEdc4iTOGO67Tbed7/VeQkv8TFbVfI+JBL/ygu4HfQI/i7+X+ZHt1NsB4kSoIk6vJW73UcftXVfTEpwMNeA/4uAzpvvOls7zbNqSz//WmcBFV+xZN6FJ9AD7vr2Atx747bB7agAYdMamnzylPY20p8EDh376JSZ27afNp7A5PoeLqxKJwS0CHudQfAIRMyZyhDafwuqegSUPllU3DZktdIJqxlgv1n+K/8H+MzWP+hhidA1IYL19SYgIEeJ0Vk/E3RnX+2Z/XFdFnmSCnUj5+uIYhtPmU1nds5jP1KznTPac3CudMJ0TJ45R0324fw5fV3QiYz78j4nlg/7zN/demL5swGCDk8eZlw6cCz74cYpyCun+QzRva+XLP32OQ8e6WRjZwvKaJqbzOqYkJSFWMYkeyK0miojIKDdcos+reqWZXQbcCVQB97r7HYOWjwEeAs4nca3Yj7r7K/lsc0TnLFYyFxFJks/FwauAfwYuB84CrjWzswZ1+yRwyN3/BPgW8PVctyciIrnJZ3rlBcBL7v6yu3cB/wpcOajPlcCDwf1/A+abjdayVSIio1M+ib4BSC5UsSdoS9nH3XuAw8AUUjCzJWbWYmYt7e3teYQlIiLJyuaEKXdf6+6N7t5YV1dX6nBEREIjn0TfCsxMejwjaEvZx8yqgYkkDsqKiMgpks+sm6eAM83s7SQS+jXAXw3qsx64HvgN8BHgCc9gPufWrVsPmNmrSU1TgfTFu6WP3qfM6H3KjN6nzJTL+/S2dAtyTvTu3mNmnwE2kpheeb+7P2dmq4AWd18P3Ad8z8xeAg6S+DLI5LkHjN2YWUu6+aFykt6nzOh9yozep8yMhvcpr3n07v4Y8NigthVJ948DV+ezDRERyU/ZHIwVEZHiGC2Jfm2pAxgl9D5lRu9TZvQ+Zabs36eyrHUjIiKFM1r26EVEJEdK9CIiIVfWid7MLjOzXWb2kpktL3U85czMXjGznWa23cxyrPEcPmZ2v5ntN7Nnk9omm9kmM3sxuJ1UyhjLQZr3aaWZtQafqe1mdkUpYywHZjbTzH5pZn8ws+fM7Oagvaw/U2Wb6DOsjikDvd/d55T7nN5T7AHgskFty4HN7n4msDl4XOkeYOj7BPCt4DM1J5hOXel6gM+7+1nAe4GbgrxU1p+psk30ZFYdU2RY7v4rEifrJUuuqvogsOhUxlSO0rxPMoi773X3p4P7bwLPkyjeWNafqXJO9JlUx5STHHjczLaa2ZJSB1PmTnf3vcH914DTSxlMmfuMme0IhnbKajii1MxsFjAX+B1l/pkq50Qv2Znn7ueRGOq6ycz+otQBjQZB7SXNMU7tO8A7gTnAXuAbJY2mjJjZBODHwOfc/Y3kZeX4mSrnRJ9JdUwJuHtrcLsfeJTE0Jekts/MpgMEt/tLHE9Zcvd97t7r7nHgu+gzBYCZRUkk+R+4+7qguaw/U+Wc6PurY5pZDYmCaOtLHFNZMrPxZnZa333gUuDZ4deqaH1VVQluf1LCWMpWX+IK/CX6TBFcIe8+4Hl3/2bSorL+TJX1mbHBdK5/4mR1zK+VNqLyZGbvILEXD4lCdT/Ue5VgZg8DF5EoJbsPuA1oBpqAM4BXgcXuXtEHItO8TxeRGLZx4BXgxqRx6IpkZvOAJ4GdQDxo/gKJcfqy/UyVdaIXEZH8lfPQjYiIFIASvYhIyCnRi4iEnBK9iEjIKdGLiIScEr2ISMgp0YuIhNz/B9HJVlfepanoAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Scatter two coordinates of the vectorized train / score-test samples together with the\n",
    "# (0,0) and (1,1) entries of the mixture means. NOTE(review): assumes vector column pd_dim\n",
    "# is the (1,1) matrix entry under rowcol2idx ordering -- confirm.\n",
    "train_vec = Pndataset.train_data\n",
    "test_vec = Pndataset_testscore.train_data\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(train_vec[:,0], train_vec[:,pd_dim], label='train')\n",
    "ax.scatter(test_vec[:,0], test_vec[:,pd_dim], label='score test')\n",
    "ax.scatter(Means[:,0,0], Means[:,1,1], label='means')\n",
    "ax.set(xlabel='coordinate 0', ylabel='coordinate '+str(pd_dim), title='P'+str(pd_dim)+' tangent Gaussian mixture samples')\n",
    "ax.legend()\n",
    "print(torch.max(torch.abs(train_vec)), torch.max(torch.abs(test_vec)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the train and score-test datasets; file names encode pd_dim and data_date.\n",
    "torch.save(Pndataset, f'P{pd_dim}TangentGaussianMixture{data_date}.pth')\n",
    "torch.save(Pndataset_testscore, f'P{pd_dim}TangentGaussianMixtureForTest{data_date}.pth')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
