{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.lines as lines\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import time\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sph_n_DataUtil import *\n",
    "from sph_n import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "useGPU = torch.cuda.is_available()\n",
    "# Uncomment to force CPU execution even when a GPU is present.\n",
    "#useGPU = False\n",
    "print(useGPU)\n",
    "# Reseed NumPy's global RNG from OS entropy; no fixed seed is used.\n",
    "np.random.seed()\n",
    "if useGPU:\n",
    "    # CUDA calls raise on CPU-only machines, so only select/seed the device when one is usable.\n",
    "    torch.cuda.set_device(0)\n",
    "    # Seed the current GPU's RNG with a random value.\n",
    "    torch.cuda.seed()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate Gaussian mixture data in the tangent space of S(n) with equidistant means\n",
    "## The true score is available"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Two mixture components (sph_dim = 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.9546)\n",
      "5000\n",
      "5000\n",
      "5000\n",
      "5000\n",
      "tensor(0) tensor(0) tensor(0) tensor(0)\n",
      "tensor(0) tensor(0) tensor(0) tensor(0)\n"
     ]
    }
   ],
   "source": [
    "N = 10000  # number of training samples\n",
    "Nmix = 2   # number of mixture components\n",
    "\n",
    "sph_dim = 2\n",
    "dist = 1   # tangent-space distance between every pair of component means\n",
    "\n",
    "data_date = 'test'+str(Nmix)  # tag appended to the saved file names\n",
    "\n",
    "var = 0.01  # per-coordinate variance shared by all components\n",
    "# Mean i is placed on coordinate axis i, so there must be at least as many axes as components.\n",
    "assert Nmix <= sph_dim, 'Nmix must not exceed sph_dim: each mean needs its own coordinate axis'\n",
    "\n",
    "# Every component uses the same isotropic square-root covariance sqrt(var) * I.\n",
    "Cov_sqrt = math.sqrt(var)*torch.eye(sph_dim)\n",
    "Cov_sqrts = Cov_sqrt.repeat(Nmix, 1, 1)\n",
    "# torch.matmul and torch.inverse act batch-wise over the leading (component) dimension.\n",
    "CovInvs = torch.inverse(torch.matmul(Cov_sqrts, Cov_sqrts))\n",
    "\n",
    "# Means at (dist/sqrt(2)) * e_i are pairwise exactly dist apart in the tangent space.\n",
    "means = (dist/math.sqrt(2))*torch.eye(Nmix, sph_dim)\n",
    "\n",
    "# Angle between the first two means once mapped onto the sphere\n",
    "# (assumes getPos_torch returns unit vectors -- TODO confirm).\n",
    "pos = getPos_torch(Exp_torch(means))\n",
    "print(torch.acos((pos[0]*pos[1]).sum()))\n",
    "\n",
    "Sndataset = SndataTangentGaussianMixture(N, means, Cov_sqrts)\n",
    "SnPosDataset = SnPosdataFromTh(Sndataset.train_data)\n",
    "\n",
    "# A second sample drawn from the same mixture; the true score is evaluated on it.\n",
    "N_testscore = 10000\n",
    "Sndataset_testscore = SndataTangentGaussianMixture(N_testscore, means, Cov_sqrts)\n",
    "SnPosDataset_testscore = SnPosdataFromTh(Sndataset_testscore.train_data)\n",
    "\n",
    "# Sanity check: NaN and Inf counts for the coordinate and position data (all should be 0).\n",
    "print(Sndataset.train_data.isnan().sum(), Sndataset.train_data.isinf().sum(), \n",
    "      SnPosDataset.train_data.isnan().sum(), SnPosDataset.train_data.isinf().sum())\n",
    "print(Sndataset_testscore.train_data.isnan().sum(), Sndataset_testscore.train_data.isinf().sum(), \n",
    "      SnPosDataset_testscore.train_data.isnan().sum(), SnPosDataset_testscore.train_data.isinf().sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.5 0.5]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0, 0, 0, 0)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "testinput1 = Sndataset_testscore.train_data.numpy()\n",
    "metricInv_sqrt_test = torch.FloatTensor(metricInvSqrt(testinput1))\n",
    "# Honour useGPU: an unconditional .cuda() raises on machines without a GPU.\n",
    "testPosInput = SnPosDataset_testscore.train_data.clone().to('cuda' if useGPU else 'cpu')\n",
    "traininput1 = Sndataset.train_data.numpy()\n",
    "\n",
    "# True score of the equally weighted mixture at the test and training points.\n",
    "weights = np.full((1, Nmix), 1/Nmix)  # shape (1, Nmix), every entry 1/Nmix\n",
    "score_true = geometricScore_coord0_tangentGaussianMixture(testinput1, weights, means.numpy(), \n",
    "                                                          CovInvs.numpy())\n",
    "score_true_train = geometricScore_coord0_tangentGaussianMixture(traininput1, weights, means.numpy(), \n",
    "                                                                CovInvs.numpy())\n",
    "print(weights)\n",
    "# Inf/NaN counts in the true scores (test, then train); all should be 0.\n",
    "np.isinf(score_true).sum(), np.isnan(score_true).sum(), np.isinf(score_true_train).sum(), np.isnan(score_true_train).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the training and score-evaluation datasets (pickled by torch.save; load only trusted files).\n",
    "torch.save(Sndataset, f'S{sph_dim}TangentGaussianMixture{data_date}.pth')\n",
    "torch.save(Sndataset_testscore, f'S{sph_dim}TangentGaussianMixtureForTest{data_date}.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Six mixture components (sph_dim = 6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.9546)\n",
      "1666\n",
      "1670\n",
      "1666\n",
      "1670\n",
      "tensor(0) tensor(0) tensor(0) tensor(0)\n",
      "tensor(0) tensor(0) tensor(0) tensor(0)\n"
     ]
    }
   ],
   "source": [
    "N = 10000  # number of training samples\n",
    "Nmix = 6   # number of mixture components\n",
    "\n",
    "sph_dim = 6\n",
    "dist = 1   # tangent-space distance between every pair of component means\n",
    "\n",
    "data_date = '210521m'+str(Nmix)  # tag appended to the saved file names\n",
    "\n",
    "var = 0.01  # per-coordinate variance shared by all components\n",
    "# Mean i is placed on coordinate axis i, so there must be at least as many axes as components.\n",
    "assert Nmix <= sph_dim, 'Nmix must not exceed sph_dim: each mean needs its own coordinate axis'\n",
    "\n",
    "# Every component uses the same isotropic square-root covariance sqrt(var) * I.\n",
    "Cov_sqrt = math.sqrt(var)*torch.eye(sph_dim)\n",
    "Cov_sqrts = Cov_sqrt.repeat(Nmix, 1, 1)\n",
    "# torch.matmul and torch.inverse act batch-wise over the leading (component) dimension.\n",
    "CovInvs = torch.inverse(torch.matmul(Cov_sqrts, Cov_sqrts))\n",
    "\n",
    "# Means at (dist/sqrt(2)) * e_i are pairwise exactly dist apart in the tangent space.\n",
    "means = (dist/math.sqrt(2))*torch.eye(Nmix, sph_dim)\n",
    "\n",
    "# Angle between the first two means once mapped onto the sphere\n",
    "# (assumes getPos_torch returns unit vectors -- TODO confirm).\n",
    "pos = getPos_torch(Exp_torch(means))\n",
    "print(torch.acos((pos[0]*pos[1]).sum()))\n",
    "\n",
    "Sndataset = SndataTangentGaussianMixture(N, means, Cov_sqrts)\n",
    "SnPosDataset = SnPosdataFromTh(Sndataset.train_data)\n",
    "\n",
    "# A second sample drawn from the same mixture; the true score is evaluated on it.\n",
    "N_testscore = 10000\n",
    "Sndataset_testscore = SndataTangentGaussianMixture(N_testscore, means, Cov_sqrts)\n",
    "SnPosDataset_testscore = SnPosdataFromTh(Sndataset_testscore.train_data)\n",
    "\n",
    "# Sanity check: NaN and Inf counts for the coordinate and position data (all should be 0).\n",
    "print(Sndataset.train_data.isnan().sum(), Sndataset.train_data.isinf().sum(), \n",
    "      SnPosDataset.train_data.isnan().sum(), SnPosDataset.train_data.isinf().sum())\n",
    "print(Sndataset_testscore.train_data.isnan().sum(), Sndataset_testscore.train_data.isinf().sum(), \n",
    "      SnPosDataset_testscore.train_data.isnan().sum(), SnPosDataset_testscore.train_data.isinf().sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0, 0, 0, 0)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "testinput1 = Sndataset_testscore.train_data.numpy()\n",
    "metricInv_sqrt_test = torch.FloatTensor(metricInvSqrt(testinput1))\n",
    "# Honour useGPU: an unconditional .cuda() raises on machines without a GPU.\n",
    "testPosInput = SnPosDataset_testscore.train_data.clone().to('cuda' if useGPU else 'cpu')\n",
    "traininput1 = Sndataset.train_data.numpy()\n",
    "\n",
    "# True score of the equally weighted mixture at the test and training points.\n",
    "weights = np.full((1, Nmix), 1/Nmix)  # shape (1, Nmix), every entry 1/Nmix\n",
    "score_true = geometricScore_coord0_tangentGaussianMixture(testinput1, weights, means.numpy(), \n",
    "                                                          CovInvs.numpy())\n",
    "score_true_train = geometricScore_coord0_tangentGaussianMixture(traininput1, weights, means.numpy(), \n",
    "                                                                CovInvs.numpy())\n",
    "print(weights)\n",
    "# Inf/NaN counts in the true scores (test, then train); all should be 0.\n",
    "np.isinf(score_true).sum(), np.isnan(score_true).sum(), np.isinf(score_true_train).sum(), np.isnan(score_true_train).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the training and score-evaluation datasets (pickled by torch.save; load only trusted files).\n",
    "torch.save(Sndataset, f'S{sph_dim}TangentGaussianMixture{data_date}.pth')\n",
    "torch.save(Sndataset_testscore, f'S{sph_dim}TangentGaussianMixtureForTest{data_date}.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ten mixture components (sph_dim = 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.9546)\n",
      "1000\n",
      "1000\n",
      "1000\n",
      "1000\n",
      "tensor(0) tensor(0) tensor(0) tensor(0)\n",
      "tensor(0) tensor(0) tensor(0) tensor(0)\n"
     ]
    }
   ],
   "source": [
    "N = 10000  # number of training samples\n",
    "Nmix = 10  # number of mixture components\n",
    "\n",
    "sph_dim = 10\n",
    "dist = 1   # tangent-space distance between every pair of component means\n",
    "\n",
    "data_date = '210521m'+str(Nmix)  # tag appended to the saved file names\n",
    "\n",
    "var = 0.01  # per-coordinate variance shared by all components\n",
    "# Mean i is placed on coordinate axis i, so there must be at least as many axes as components.\n",
    "assert Nmix <= sph_dim, 'Nmix must not exceed sph_dim: each mean needs its own coordinate axis'\n",
    "\n",
    "# Every component uses the same isotropic square-root covariance sqrt(var) * I.\n",
    "Cov_sqrt = math.sqrt(var)*torch.eye(sph_dim)\n",
    "Cov_sqrts = Cov_sqrt.repeat(Nmix, 1, 1)\n",
    "# torch.matmul and torch.inverse act batch-wise over the leading (component) dimension.\n",
    "CovInvs = torch.inverse(torch.matmul(Cov_sqrts, Cov_sqrts))\n",
    "\n",
    "# Means at (dist/sqrt(2)) * e_i are pairwise exactly dist apart in the tangent space.\n",
    "means = (dist/math.sqrt(2))*torch.eye(Nmix, sph_dim)\n",
    "\n",
    "# Angle between the first two means once mapped onto the sphere\n",
    "# (assumes getPos_torch returns unit vectors -- TODO confirm).\n",
    "pos = getPos_torch(Exp_torch(means))\n",
    "print(torch.acos((pos[0]*pos[1]).sum()))\n",
    "\n",
    "Sndataset = SndataTangentGaussianMixture(N, means, Cov_sqrts)\n",
    "SnPosDataset = SnPosdataFromTh(Sndataset.train_data)\n",
    "\n",
    "# A second sample drawn from the same mixture; the true score is evaluated on it.\n",
    "N_testscore = 10000\n",
    "Sndataset_testscore = SndataTangentGaussianMixture(N_testscore, means, Cov_sqrts)\n",
    "SnPosDataset_testscore = SnPosdataFromTh(Sndataset_testscore.train_data)\n",
    "\n",
    "# Sanity check: NaN and Inf counts for the coordinate and position data (all should be 0).\n",
    "print(Sndataset.train_data.isnan().sum(), Sndataset.train_data.isinf().sum(), \n",
    "      SnPosDataset.train_data.isnan().sum(), SnPosDataset.train_data.isinf().sum())\n",
    "print(Sndataset_testscore.train_data.isnan().sum(), Sndataset_testscore.train_data.isinf().sum(), \n",
    "      SnPosDataset_testscore.train_data.isnan().sum(), SnPosDataset_testscore.train_data.isinf().sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0, 0, 0, 0)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "testinput1 = Sndataset_testscore.train_data.numpy()\n",
    "metricInv_sqrt_test = torch.FloatTensor(metricInvSqrt(testinput1))\n",
    "# Honour useGPU: an unconditional .cuda() raises on machines without a GPU.\n",
    "testPosInput = SnPosDataset_testscore.train_data.clone().to('cuda' if useGPU else 'cpu')\n",
    "traininput1 = Sndataset.train_data.numpy()\n",
    "\n",
    "# True score of the equally weighted mixture at the test and training points.\n",
    "weights = np.full((1, Nmix), 1/Nmix)  # shape (1, Nmix), every entry 1/Nmix\n",
    "score_true = geometricScore_coord0_tangentGaussianMixture(testinput1, weights, means.numpy(), \n",
    "                                                          CovInvs.numpy())\n",
    "score_true_train = geometricScore_coord0_tangentGaussianMixture(traininput1, weights, means.numpy(), \n",
    "                                                                CovInvs.numpy())\n",
    "print(weights)\n",
    "# Inf/NaN counts in the true scores (test, then train); all should be 0.\n",
    "np.isinf(score_true).sum(), np.isnan(score_true).sum(), np.isinf(score_true_train).sum(), np.isnan(score_true_train).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the training and score-evaluation datasets (pickled by torch.save; load only trusted files).\n",
    "torch.save(Sndataset, f'S{sph_dim}TangentGaussianMixture{data_date}.pth')\n",
    "torch.save(Sndataset_testscore, f'S{sph_dim}TangentGaussianMixtureForTest{data_date}.pth')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
