{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example use of VGAN (with and without kernel learning) on a Normal Population"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.vgan import VGAN_no_kl, VGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_516716/1230951515.py:5: RuntimeWarning: covariance is not symmetric positive-semidefinite.\n",
      "  X_data = np.random.multivariate_normal(mean, cov, 2000)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "mean = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]  # zero mean in each of the 10 dimensions\n",
    "cov = [[1, 0, 0, 0, 0, 0, 0, 0, 500, 500], [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\n",
    "           [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], [500, 0, 0, 0, 0, 0, 0, 0, 1, 500], [500, 0, 0, 0, 0, 0, 0, 0, 500, 1]]  # not positive semi-definite, hence the RuntimeWarning\n",
    "X_data = np.random.multivariate_normal(mean, cov, 2000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both networks are used in the same simple way. First, let's fit them to the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/i40/cribeiro/Software/Python/V-GAN/src/vgan.py:588: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:83.)\n",
      "  noise_tensor = torch.cuda.FloatTensor(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.556676626205444\n",
      "Epoch 1 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.326139688491821\n",
      "Epoch 2 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.559314966201782\n",
      "Epoch 3 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.552696943283081\n",
      "Epoch 4 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.55057966709137\n",
      "Epoch 5 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.565077781677246\n",
      "Epoch 6 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.556469559669495\n",
      "Epoch 7 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.565714240074158\n",
      "Epoch 8 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.547932386398315\n",
      "Epoch 9 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.551324725151062\n",
      "Epoch 10 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.549830436706543\n",
      "Epoch 11 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.555707335472107\n",
      "Epoch 12 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.552123188972473\n",
      "Epoch 13 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.559509992599487\n",
      "Epoch 14 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch: 4.560412883758545\n",
      "Epoch 0 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: nan\n",
      "Average loss in the epoch Detector: 30.823333740234375\n",
      "Epoch 1 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 3.5762786865234375e-07\n",
      "Average loss in the epoch Detector: 30.823333740234375\n",
      "Epoch 2 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 1.1920928955078125e-07\n",
      "Average loss in the epoch Detector: 30.823333740234375\n",
      "Epoch 3 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 3.5762786865234375e-07\n",
      "Average loss in the epoch Detector: 30.823333740234375\n",
      "Epoch 4 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 3.5762786865234375e-07\n",
      "Average loss in the epoch Detector: 30.823333740234375\n",
      "Epoch 5 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 5.960464477539062e-07\n",
      "Average loss in the epoch Detector: 30.823333740234375\n",
      "Epoch 6 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 5.960464477539062e-07\n",
      "Average loss in the epoch Detector: 30.60318374633789\n",
      "Epoch 7 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 3.5762786865234375e-07\n",
      "Average loss in the epoch Detector: 30.60318374633789\n",
      "Epoch 8 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 5.960464477539062e-07\n",
      "Average loss in the epoch Detector: 30.60318374633789\n",
      "Epoch 9 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 4.76837158203125e-07\n",
      "Average loss in the epoch Detector: 30.60318374633789\n",
      "Epoch 10 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 0.0\n",
      "Average loss in the epoch Detector: 30.60318374633789\n",
      "Epoch 11 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: -3.5762786865234375e-07\n",
      "Average loss in the epoch Detector: 30.60318374633789\n",
      "Epoch 12 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: -3.5762786865234375e-07\n",
      "Average loss in the epoch Detector: 30.533544540405273\n",
      "Epoch 13 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 7.152557373046875e-07\n",
      "Average loss in the epoch Detector: 30.533544540405273\n",
      "Epoch 14 of 15\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average loss in the epoch Generator: 5.960464477539062e-07\n",
      "Average loss in the epoch Detector: 30.533544540405273\n"
     ]
    }
   ],
   "source": [
    "model = VGAN_no_kl(epochs=15, lr=0.001)\n",
    "model.fit(X_data)\n",
    "model_kl = VGAN(epochs=15)\n",
    "model_kl.fit(X_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once VGAN has been fitted, one can generate the subspaces (i.e., the operators) simply by using:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(array([False, False, False, False, False,  True,  True,  True,  True,\n",
      "        True]), array([False,  True,  True,  True,  True,  True, False,  True, False,\n",
      "       False]))\n"
     ]
    }
   ],
   "source": [
    "u_1 = model_kl.generate_subspaces(10)  # by default, each axis-parallel subspace is a boolean mask\n",
    "u_2 = model.generate_subspaces(10)\n",
    "print((u_1[0].to('cpu').numpy(), u_2[0].to('cpu').numpy()))  # subspaces are returned as tensors by default"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A dataset can then be projected onto a single subspace by indexing its columns with the boolean mask:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ -2.05672058,  -0.81552988,  -0.46335229,  -1.0589473 ,\n",
       "        -10.04305584],\n",
       "       [ -1.16834401,   0.55466862,   0.19290648,  17.74766961,\n",
       "          4.62679968],\n",
       "       [  0.18583889,   1.64861358,  -1.37171922, -10.71782789,\n",
       "        -22.99980958],\n",
       "       ...,\n",
       "       [ -0.46791208,   0.95610433,   1.45573305,  42.29484227,\n",
       "         44.07136549],\n",
       "       [  1.18069009,   1.19694152,   0.45473216,  -7.80522535,\n",
       "          5.02193906],\n",
       "       [  0.48067956,  -1.18935593,  -0.67588351,  17.92311814,\n",
       "         11.02612502]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_data[:,u_1[0].to('cpu').numpy()]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Useful functions\n",
    "\n",
    "As the operator space is discrete in our case, its distribution can be approximated by sampling enough subspaces. VGAN provides a method to do so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_kl.approx_subspace_dist()\n",
    "model.approx_subspace_dist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calling the method `self.approx_subspace_dist()` samples a number of subspaces (500 by default) and uses them to approximate the distribution of $\\mathbf{U}$. The subspaces are stored in `self.subspaces` and their corresponding probabilities in `self.proba`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Distribution of the operator obtained by using VGAN with kl:\n",
      "           0      1      2      3      4      5      6      7      8      9\n",
      "0.458  False  False  False  False  False   True   True   True   True   True\n",
      "0.542   True   True   True   True   True  False  False  False  False  False\n",
      "\n",
      " Distribution of the operator obtained by using VGAN without kl:\n",
      "           0      1      2      3     4      5      6      7      8      9\n",
      "0.282  False  False   True   True  True  False  False  False  False  False\n",
      "0.314  False  False   True   True  True  False  False   True  False  False\n",
      "0.082  False  False   True   True  True   True  False   True  False  False\n",
      "0.046  False   True  False  False  True   True  False   True  False  False\n",
      "0.060  False   True   True  False  True   True  False   True  False  False\n",
      "0.212  False   True   True   True  True   True  False   True  False  False\n",
      "0.002   True   True  False  False  True   True  False  False  False  False\n",
      "0.002   True   True  False  False  True   True  False   True  False  False\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Each row is a sampled subspace (boolean mask); the index holds its estimated probability.\n",
    "print(\"Distribution of the operator obtained by using VGAN with kl:\")\n",
    "print(pd.DataFrame(model_kl.subspaces, index=model_kl.proba))\n",
    "print(\"\\n Distribution of the operator obtained by using VGAN without kl:\")\n",
    "print(pd.DataFrame(model.subspaces, index=model.proba))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can check whether the approximated distribution is myopic. The check is based on the MMD goodness-of-fit (GoF) test, as described in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "       0.01  recommended bandwidth\n",
      "p-val   1.0                    0.0\n",
      "       0.01  recommended bandwidth\n",
      "p-val   1.0                    0.0\n"
     ]
    }
   ],
   "source": [
    "print(model.check_if_myopic(X_data))\n",
    "print(model_kl.check_if_myopic(X_data))  # Not a lens operator!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "subsel_torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
