{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bd2425b6-2497-46fb-ac71-4cf9befb1d52",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ppi_plusplus_multi import *\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "# folder_dir = \"../mexico_predictions/\"\n",
    "# site_name = \"mexico\"\n",
    "# # mod_name = \"gpt4zs\"\n",
    "# mod_name = \"bert\"\n",
    "# file_name = folder_dir + site_name + \"_\"+ mod_name + \".csv\"\n",
    "# df = pd.read_csv(file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "978a4a4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions file; site (Mexico) and model (KNN) are inferred from the filename.\n",
    "RESULTS_CSV = '../data/results/mexico_KNN.csv'\n",
    "df = pd.read_csv(RESULTS_CSV)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "39c3319e-4e55-4fe6-ae86-ca007cef6a6b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1306, 3)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Later cells read the X, Y and Y_hat columns; fail early if the CSV schema differs.\n",
    "missing_cols = {'X', 'Y', 'Y_hat'} - set(df.columns)\n",
    "assert not missing_cols, f'missing expected columns: {sorted(missing_cols)}'\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38d87bf6-9fd7-4fe6-acee-656196adea2b",
   "metadata": {},
   "source": [
    "## Mexico"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c1c01787-8470-41e6-a461-90c31201659b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split df into a labeled part (gold Y revealed) and an unlabeled part.\n",
    "#\n",
    "# Alternative, currently disabled: 20% labeled / 80% unlabeled, stratified by Y.\n",
    "# labeled_df = df.groupby('Y', group_keys=False).apply(lambda x: x.sample(frac=0.2))\n",
    "# unlabeled_df = df.drop(labeled_df.index)\n",
    "\n",
    "# Active: simple random split, LABELED_FRAC (20%) labeled, remaining 80% unlabeled.\n",
    "# df.sample gets no random_state, so pandas draws from numpy's global RNG;\n",
    "# the np.random.seed call is what makes this split reproducible.\n",
    "LABELED_FRAC = 0.2\n",
    "np.random.seed(123)\n",
    "labeled_df = df.sample(frac=LABELED_FRAC)\n",
    "unlabeled_df = df.drop(labeled_df.index)\n",
    "# The two parts must partition df (drop-by-label would over-drop on duplicate index labels).\n",
    "assert len(labeled_df) + len(unlabeled_df) == len(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "b72b9283-8860-4919-97d5-e49c971a38fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization terminated successfully.\n",
      "         Current function value: 0.961287\n",
      "         Iterations 7\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table class=\"simpletable\">\n",
       "<caption>MNLogit Regression Results</caption>\n",
       "<tr>\n",
       "  <th>Dep. Variable:</th>           <td>y</td>        <th>  No. Observations:  </th>  <td>  1306</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Model:</th>                <td>MNLogit</td>     <th>  Df Residuals:      </th>  <td>  1302</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Method:</th>                 <td>MLE</td>       <th>  Df Model:          </th>  <td>     0</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Date:</th>            <td>Wed, 20 Mar 2024</td> <th>  Pseudo R-squ.:     </th>  <td>0.05837</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Time:</th>                <td>16:18:32</td>     <th>  Log-Likelihood:    </th> <td> -1255.4</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>converged:</th>             <td>True</td>       <th>  LL-Null:           </th> <td> -1333.3</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Covariance Type:</th>     <td>nonrobust</td>    <th>  LLR p-value:       </th>  <td>   nan</td> \n",
       "</tr>\n",
       "</table>\n",
       "<table class=\"simpletable\">\n",
       "<tr>\n",
       "  <th>y=1</th>    <th>coef</th>     <th>std err</th>      <th>z</th>      <th>P>|z|</th>  <th>[0.025</th>    <th>0.975]</th>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>x1</th>  <td>    0.0125</td> <td>    0.003</td> <td>    4.846</td> <td> 0.000</td> <td>    0.007</td> <td>    0.018</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>y=2</th>    <th>coef</th>     <th>std err</th>      <th>z</th>      <th>P>|z|</th>  <th>[0.025</th>    <th>0.975]</th>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>x1</th>  <td>    0.0040</td> <td>    0.003</td> <td>    1.404</td> <td> 0.160</td> <td>   -0.002</td> <td>    0.010</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>y=3</th>    <th>coef</th>     <th>std err</th>      <th>z</th>      <th>P>|z|</th>  <th>[0.025</th>    <th>0.975]</th>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>x1</th>  <td>   -0.0294</td> <td>    0.005</td> <td>   -5.709</td> <td> 0.000</td> <td>   -0.039</td> <td>   -0.019</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>y=4</th>    <th>coef</th>     <th>std err</th>      <th>z</th>      <th>P>|z|</th>  <th>[0.025</th>    <th>0.975]</th>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>x1</th>  <td>    0.0429</td> <td>    0.002</td> <td>   19.309</td> <td> 0.000</td> <td>    0.039</td> <td>    0.047</td>\n",
       "</tr>\n",
       "</table>"
      ],
      "text/latex": [
       "\\begin{center}\n",
       "\\begin{tabular}{lclc}\n",
       "\\toprule\n",
       "\\textbf{Dep. Variable:}   &        y         & \\textbf{  No. Observations:  } &     1306    \\\\\n",
       "\\textbf{Model:}           &     MNLogit      & \\textbf{  Df Residuals:      } &     1302    \\\\\n",
       "\\textbf{Method:}          &       MLE        & \\textbf{  Df Model:          } &        0    \\\\\n",
       "\\textbf{Date:}            & Wed, 20 Mar 2024 & \\textbf{  Pseudo R-squ.:     } &  0.05837    \\\\\n",
       "\\textbf{Time:}            &     16:18:32     & \\textbf{  Log-Likelihood:    } &   -1255.4   \\\\\n",
       "\\textbf{converged:}       &       True       & \\textbf{  LL-Null:           } &   -1333.3   \\\\\n",
       "\\textbf{Covariance Type:} &    nonrobust     & \\textbf{  LLR p-value:       } &      nan    \\\\\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{ccccccc}\n",
       "\\textbf{y=1} & \\textbf{coef} & \\textbf{std err} & \\textbf{z} & \\textbf{P$> |$z$|$} & \\textbf{[0.025} & \\textbf{0.975]}  \\\\\n",
       "\\midrule\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{lcccccc}\n",
       "\\textbf{x1}  &       0.0125  &        0.003     &     4.846  &         0.000        &        0.007    &        0.018     \\\\\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{ccccccc}\n",
       "\\textbf{y=2} & \\textbf{coef} & \\textbf{std err} & \\textbf{z} & \\textbf{P$> |$z$|$} & \\textbf{[0.025} & \\textbf{0.975]}  \\\\\n",
       "\\midrule\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{lcccccc}\n",
       "\\textbf{x1}  &       0.0040  &        0.003     &     1.404  &         0.160        &       -0.002    &        0.010     \\\\\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{ccccccc}\n",
       "\\textbf{y=3} & \\textbf{coef} & \\textbf{std err} & \\textbf{z} & \\textbf{P$> |$z$|$} & \\textbf{[0.025} & \\textbf{0.975]}  \\\\\n",
       "\\midrule\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{lcccccc}\n",
       "\\textbf{x1}  &      -0.0294  &        0.005     &    -5.709  &         0.000        &       -0.039    &       -0.019     \\\\\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{ccccccc}\n",
       "\\textbf{y=4} & \\textbf{coef} & \\textbf{std err} & \\textbf{z} & \\textbf{P$> |$z$|$} & \\textbf{[0.025} & \\textbf{0.975]}  \\\\\n",
       "\\midrule\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{lcccccc}\n",
       "\\textbf{x1}  &       0.0429  &        0.002     &    19.309  &         0.000        &        0.039    &        0.047     \\\\\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "%\\caption{MNLogit Regression Results}\n",
       "\\end{center}"
      ],
      "text/plain": [
       "<class 'statsmodels.iolib.summary.Summary'>\n",
       "\"\"\"\n",
       "                          MNLogit Regression Results                          \n",
       "==============================================================================\n",
       "Dep. Variable:                      y   No. Observations:                 1306\n",
       "Model:                        MNLogit   Df Residuals:                     1302\n",
       "Method:                           MLE   Df Model:                            0\n",
       "Date:                Wed, 20 Mar 2024   Pseudo R-squ.:                 0.05837\n",
       "Time:                        16:18:32   Log-Likelihood:                -1255.4\n",
       "converged:                       True   LL-Null:                       -1333.3\n",
       "Covariance Type:            nonrobust   LLR p-value:                       nan\n",
       "==============================================================================\n",
       "       y=1       coef    std err          z      P>|z|      [0.025      0.975]\n",
       "------------------------------------------------------------------------------\n",
       "x1             0.0125      0.003      4.846      0.000       0.007       0.018\n",
       "------------------------------------------------------------------------------\n",
       "       y=2       coef    std err          z      P>|z|      [0.025      0.975]\n",
       "------------------------------------------------------------------------------\n",
       "x1             0.0040      0.003      1.404      0.160      -0.002       0.010\n",
       "------------------------------------------------------------------------------\n",
       "       y=3       coef    std err          z      P>|z|      [0.025      0.975]\n",
       "------------------------------------------------------------------------------\n",
       "x1            -0.0294      0.005     -5.709      0.000      -0.039      -0.019\n",
       "------------------------------------------------------------------------------\n",
       "       y=4       coef    std err          z      P>|z|      [0.025      0.975]\n",
       "------------------------------------------------------------------------------\n",
       "x1             0.0429      0.002     19.309      0.000       0.039       0.047\n",
       "==============================================================================\n",
       "\"\"\""
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import statsmodels.api as sm\n",
    "\n",
    "# Reference fit: multinomial logit of the gold label Y on X, using every row of df.\n",
    "# X is a single column with no sm.add_constant, so the model has no intercept.\n",
    "# NOTE(review): confirm the missing intercept is intended (summary reports Df Model 0).\n",
    "Y = df['Y'].to_numpy()\n",
    "X = df['X'].to_numpy().reshape(-1, 1)\n",
    "\n",
    "sort_idx = np.argsort(Y)\n",
    "Y_sorted, X_sorted = Y[sort_idx], X[sort_idx]\n",
    "\n",
    "mn_logit = sm.MNLogit(Y_sorted, X_sorted)\n",
    "mn_logit_res = mn_logit.fit(method='newton', full_output=True)\n",
    "mn_logit_res.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "44ea6197-93ca-4f33-8bc2-62284d785ae1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labeled part: gold labels Y, covariate X as an (n, 1) column, and predictions Yhat.\n",
    "Y = labeled_df['Y'].to_numpy()\n",
    "X = labeled_df['X'].to_numpy().reshape(-1, 1)\n",
    "Yhat = labeled_df['Y_hat'].to_numpy()\n",
    "\n",
    "# Unlabeled part: only the covariate and the predictions are used.\n",
    "X_unlabeled = unlabeled_df['X'].to_numpy().reshape(-1, 1)\n",
    "Yhat_unlabeled = unlabeled_df['Y_hat'].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "92b205fa-fba2-4dfd-8925-70995c37c274",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pooled sample: labeled rows keep their gold Y, unlabeled rows use the prediction\n",
    "# Y_hat as a stand-in label. Labeled rows come first in both arrays.\n",
    "X_full = np.vstack([X, X_unlabeled])\n",
    "Y_full = np.hstack([Y, Yhat_unlabeled])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "feaa9496-9a32-4479-85c9-334cf0da018e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sort the pooled (labeled + imputed) sample by outcome. The permutation must be\n",
    "# computed from Y_full itself: argsort over the labeled Y alone yields only\n",
    "# len(labeled_df) indices and silently drops every unlabeled row from the fit.\n",
    "sort_idx = np.argsort(Y_full, kind='stable')\n",
    "Y_full_sorted = Y_full[sort_idx]\n",
    "X_full_sorted = X_full[sort_idx]\n",
    "assert len(Y_full_sorted) == len(X_full_sorted) == len(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "b06dc18d-3ed2-4253-98f8-50785675b049",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization terminated successfully.\n",
      "         Current function value: 0.901449\n",
      "         Iterations 7\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table class=\"simpletable\">\n",
       "<caption>MNLogit Regression Results</caption>\n",
       "<tr>\n",
       "  <th>Dep. Variable:</th>           <td>y</td>        <th>  No. Observations:  </th>  <td>   261</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Model:</th>                <td>MNLogit</td>     <th>  Df Residuals:      </th>  <td>   257</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Method:</th>                 <td>MLE</td>       <th>  Df Model:          </th>  <td>     0</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Date:</th>            <td>Wed, 20 Mar 2024</td> <th>  Pseudo R-squ.:     </th>  <td>0.03177</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Time:</th>                <td>16:18:35</td>     <th>  Log-Likelihood:    </th> <td> -235.28</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>converged:</th>             <td>True</td>       <th>  LL-Null:           </th> <td> -243.00</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>Covariance Type:</th>     <td>nonrobust</td>    <th>  LLR p-value:       </th>  <td>   nan</td> \n",
       "</tr>\n",
       "</table>\n",
       "<table class=\"simpletable\">\n",
       "<tr>\n",
       "  <th>y=1</th>    <th>coef</th>     <th>std err</th>      <th>z</th>      <th>P>|z|</th>  <th>[0.025</th>    <th>0.975]</th>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>x1</th>  <td>    0.0151</td> <td>    0.006</td> <td>    2.569</td> <td> 0.010</td> <td>    0.004</td> <td>    0.027</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>y=2</th>    <th>coef</th>     <th>std err</th>      <th>z</th>      <th>P>|z|</th>  <th>[0.025</th>    <th>0.975]</th>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>x1</th>  <td>   -0.0035</td> <td>    0.007</td> <td>   -0.476</td> <td> 0.634</td> <td>   -0.018</td> <td>    0.011</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>y=3</th>    <th>coef</th>     <th>std err</th>      <th>z</th>      <th>P>|z|</th>  <th>[0.025</th>    <th>0.975]</th>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>x1</th>  <td>   -0.0204</td> <td>    0.010</td> <td>   -2.014</td> <td> 0.044</td> <td>   -0.040</td> <td>   -0.001</td>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>y=4</th>    <th>coef</th>     <th>std err</th>      <th>z</th>      <th>P>|z|</th>  <th>[0.025</th>    <th>0.975]</th>  \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>x1</th>  <td>    0.0448</td> <td>    0.005</td> <td>    8.646</td> <td> 0.000</td> <td>    0.035</td> <td>    0.055</td>\n",
       "</tr>\n",
       "</table>"
      ],
      "text/latex": [
       "\\begin{center}\n",
       "\\begin{tabular}{lclc}\n",
       "\\toprule\n",
       "\\textbf{Dep. Variable:}   &        y         & \\textbf{  No. Observations:  } &      261    \\\\\n",
       "\\textbf{Model:}           &     MNLogit      & \\textbf{  Df Residuals:      } &      257    \\\\\n",
       "\\textbf{Method:}          &       MLE        & \\textbf{  Df Model:          } &        0    \\\\\n",
       "\\textbf{Date:}            & Wed, 20 Mar 2024 & \\textbf{  Pseudo R-squ.:     } &  0.03177    \\\\\n",
       "\\textbf{Time:}            &     16:18:35     & \\textbf{  Log-Likelihood:    } &   -235.28   \\\\\n",
       "\\textbf{converged:}       &       True       & \\textbf{  LL-Null:           } &   -243.00   \\\\\n",
       "\\textbf{Covariance Type:} &    nonrobust     & \\textbf{  LLR p-value:       } &      nan    \\\\\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{ccccccc}\n",
       "\\textbf{y=1} & \\textbf{coef} & \\textbf{std err} & \\textbf{z} & \\textbf{P$> |$z$|$} & \\textbf{[0.025} & \\textbf{0.975]}  \\\\\n",
       "\\midrule\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{lcccccc}\n",
       "\\textbf{x1}  &       0.0151  &        0.006     &     2.569  &         0.010        &        0.004    &        0.027     \\\\\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{ccccccc}\n",
       "\\textbf{y=2} & \\textbf{coef} & \\textbf{std err} & \\textbf{z} & \\textbf{P$> |$z$|$} & \\textbf{[0.025} & \\textbf{0.975]}  \\\\\n",
       "\\midrule\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{lcccccc}\n",
       "\\textbf{x1}  &      -0.0035  &        0.007     &    -0.476  &         0.634        &       -0.018    &        0.011     \\\\\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{ccccccc}\n",
       "\\textbf{y=3} & \\textbf{coef} & \\textbf{std err} & \\textbf{z} & \\textbf{P$> |$z$|$} & \\textbf{[0.025} & \\textbf{0.975]}  \\\\\n",
       "\\midrule\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{lcccccc}\n",
       "\\textbf{x1}  &      -0.0204  &        0.010     &    -2.014  &         0.044        &       -0.040    &       -0.001     \\\\\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{ccccccc}\n",
       "\\textbf{y=4} & \\textbf{coef} & \\textbf{std err} & \\textbf{z} & \\textbf{P$> |$z$|$} & \\textbf{[0.025} & \\textbf{0.975]}  \\\\\n",
       "\\midrule\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "\\begin{tabular}{lcccccc}\n",
       "\\textbf{x1}  &       0.0448  &        0.005     &     8.646  &         0.000        &        0.035    &        0.055     \\\\\n",
       "\\bottomrule\n",
       "\\end{tabular}\n",
       "%\\caption{MNLogit Regression Results}\n",
       "\\end{center}"
      ],
      "text/plain": [
       "<class 'statsmodels.iolib.summary.Summary'>\n",
       "\"\"\"\n",
       "                          MNLogit Regression Results                          \n",
       "==============================================================================\n",
       "Dep. Variable:                      y   No. Observations:                  261\n",
       "Model:                        MNLogit   Df Residuals:                      257\n",
       "Method:                           MLE   Df Model:                            0\n",
       "Date:                Wed, 20 Mar 2024   Pseudo R-squ.:                 0.03177\n",
       "Time:                        16:18:35   Log-Likelihood:                -235.28\n",
       "converged:                       True   LL-Null:                       -243.00\n",
       "Covariance Type:            nonrobust   LLR p-value:                       nan\n",
       "==============================================================================\n",
       "       y=1       coef    std err          z      P>|z|      [0.025      0.975]\n",
       "------------------------------------------------------------------------------\n",
       "x1             0.0151      0.006      2.569      0.010       0.004       0.027\n",
       "------------------------------------------------------------------------------\n",
       "       y=2       coef    std err          z      P>|z|      [0.025      0.975]\n",
       "------------------------------------------------------------------------------\n",
       "x1            -0.0035      0.007     -0.476      0.634      -0.018       0.011\n",
       "------------------------------------------------------------------------------\n",
       "       y=3       coef    std err          z      P>|z|      [0.025      0.975]\n",
       "------------------------------------------------------------------------------\n",
       "x1            -0.0204      0.010     -2.014      0.044      -0.040      -0.001\n",
       "------------------------------------------------------------------------------\n",
       "       y=4       coef    std err          z      P>|z|      [0.025      0.975]\n",
       "------------------------------------------------------------------------------\n",
       "x1             0.0448      0.005      8.646      0.000       0.035       0.055\n",
       "==============================================================================\n",
       "\"\"\""
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same no-intercept MNLogit as the reference fit, but on the pooled sample where\n",
    "# unlabeled rows carry Y_hat in place of the gold label.\n",
    "mn_logit = sm.MNLogit(Y_full_sorted, X_full_sorted)\n",
    "mn_logit_res = mn_logit.fit(method='newton', full_output=True)\n",
    "mn_logit_res.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1766b796-6787-420f-b304-26013bbb4600",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:76: RuntimeWarning: overflow encountered in exp\n",
      "  loss0 +=  -(Xi @ Ey) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:80: RuntimeWarning: overflow encountered in exp\n",
      "  loss1 += -(Xi @ Eyhat) @ _theta + np.log(np.sum(np.exp(Xi @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:87: RuntimeWarning: overflow encountered in exp\n",
      "  loss2 += -(Xi_unlabeled @ Ey_unlabeled) @ _theta + np.log(np.sum(np.exp(Xi_unlabeled @ EY @ _theta)))\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:90: RuntimeWarning: invalid value encountered in scalar subtract\n",
      "  loss = 1 / n * loss0 - lhat_curr / n * loss1 + lhat_curr / N * loss2\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:107: RuntimeWarning: overflow encountered in exp\n",
      "  a = np.exp(X @ _theta_2D.T)\n",
      "/Users/adam/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:108: RuntimeWarning: invalid value encountered in divide\n",
      "  s = a / np.sum(a, axis=1, keepdims=True)\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "too many values to unpack (expected 4)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[25], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m theta_ppi_ci \u001b[38;5;241m=\u001b[39m ppi_multiclass_logistic_ci(\n\u001b[1;32m      2\u001b[0m             X,\n\u001b[1;32m      3\u001b[0m             Y,\n\u001b[1;32m      4\u001b[0m             Yhat,\n\u001b[1;32m      5\u001b[0m             X_unlabeled,\n\u001b[1;32m      6\u001b[0m             Yhat_unlabeled,\n\u001b[1;32m      7\u001b[0m             optimizer_options \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdisp\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmaxiter\u001b[39m\u001b[38;5;124m'\u001b[39m:\u001b[38;5;241m1000\u001b[39m},\n\u001b[1;32m      8\u001b[0m         )\n",
      "File \u001b[0;32m~/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:393\u001b[0m, in \u001b[0;36mppi_multiclass_logistic_ci\u001b[0;34m(X, Y, Yhat, X_unlabeled, Yhat_unlabeled, alpha, alternative, lhat, coord, optimizer_options)\u001b[0m\n\u001b[1;32m    388\u001b[0m df \u001b[38;5;241m=\u001b[39m (K\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m*\u001b[39md \n\u001b[1;32m    390\u001b[0m use_unlabeled \u001b[38;5;241m=\u001b[39m lhat \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m--> 393\u001b[0m ppi_pointest \u001b[38;5;241m=\u001b[39m ppi_multi_class_pointestimate(\n\u001b[1;32m    394\u001b[0m     X,\n\u001b[1;32m    395\u001b[0m     Y,\n\u001b[1;32m    396\u001b[0m     Yhat,\n\u001b[1;32m    397\u001b[0m     X_unlabeled,\n\u001b[1;32m    398\u001b[0m     Yhat_unlabeled,\n\u001b[1;32m    399\u001b[0m     lhat\u001b[38;5;241m=\u001b[39mlhat,\n\u001b[1;32m    400\u001b[0m     coord\u001b[38;5;241m=\u001b[39mcoord,\n\u001b[1;32m    401\u001b[0m     optimizer_options\u001b[38;5;241m=\u001b[39moptimizer_options,\n\u001b[1;32m    402\u001b[0m \n\u001b[1;32m    403\u001b[0m )\n\u001b[1;32m    405\u001b[0m grads, grads_hat, grads_hat_unlabeled, hessian, inv_hessian \u001b[38;5;241m=\u001b[39m _multiclass_ci_get_stats(\n\u001b[1;32m    406\u001b[0m     ppi_pointest,\n\u001b[1;32m    407\u001b[0m     X,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    412\u001b[0m     use_unlabeled\u001b[38;5;241m=\u001b[39muse_unlabeled,\n\u001b[1;32m    413\u001b[0m )\n\u001b[1;32m    415\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m lhat \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/Desktop/Github/va_nlp/utils/ppi_plusplus_multi.py:167\u001b[0m, in \u001b[0;36mppi_multi_class_pointestimate\u001b[0;34m(X, Y, Yhat, X_unlabeled, Yhat_unlabeled, lhat, coord, optimizer_options)\u001b[0m\n\u001b[1;32m    164\u001b[0m ppi_pointest \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mdelete(theta_2d(ppi_pointest_extra, K), \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mflatten(order \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mC\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m lhat \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 167\u001b[0m     (\n\u001b[1;32m    168\u001b[0m         grads,\n\u001b[1;32m    169\u001b[0m         grads_hat,\n\u001b[1;32m    170\u001b[0m         grads_hat_unlabeled,\n\u001b[1;32m    171\u001b[0m         inv_hessian,\n\u001b[1;32m    172\u001b[0m     ) \u001b[38;5;241m=\u001b[39m _multiclass_ci_get_stats(\n\u001b[1;32m    173\u001b[0m         ppi_pointest,\n\u001b[1;32m    174\u001b[0m         X,\n\u001b[1;32m    175\u001b[0m         Y,\n\u001b[1;32m    176\u001b[0m         Yhat,\n\u001b[1;32m    177\u001b[0m         X_unlabeled,\n\u001b[1;32m    178\u001b[0m         Yhat_unlabeled,\n\u001b[1;32m    179\u001b[0m     )\n\u001b[1;32m    181\u001b[0m     lhat \u001b[38;5;241m=\u001b[39m _calc_lhat_glm(\n\u001b[1;32m    182\u001b[0m         grads,\n\u001b[1;32m    183\u001b[0m         grads_hat,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    186\u001b[0m         clip\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m    187\u001b[0m     )\n\u001b[1;32m    189\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m ppi_multi_class_pointestimate(\n\u001b[1;32m    190\u001b[0m         X,\n\u001b[1;32m    191\u001b[0m         Y,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    197\u001b[0m         
coord\u001b[38;5;241m=\u001b[39mcoord,\n\u001b[1;32m    198\u001b[0m     )\n",
      "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 4)"
     ]
    }
   ],
   "source": [
    "theta_ppi_ci = ppi_multiclass_logistic_ci(\n",
    "            X,\n",
    "            Y,\n",
    "            Yhat,\n",
    "            X_unlabeled,\n",
    "            Yhat_unlabeled,\n",
    "            optimizer_options = {'disp': True, 'maxiter':1000},\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "5f599dda",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(261,)"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "8d7f0551-3200-4c2a-9ff5-6b3a699f7f83",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "power tuning parameter value is:\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'theta_ppi_ci' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[51], line 2\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpower tuning parameter value is:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m----> 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(theta_ppi_ci[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlhat\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[0;32m      3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter estimates: \u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m      4\u001b[0m \u001b[38;5;28mprint\u001b[39m(theta_ppi_ci[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpointest\u001b[39m\u001b[38;5;124m'\u001b[39m])\n",
      "\u001b[1;31mNameError\u001b[0m: name 'theta_ppi_ci' is not defined"
     ]
    }
   ],
   "source": [
    "print(\"power tuning parameter value is:\")\n",
    "print(theta_ppi_ci['lhat'])\n",
    "print(\"parameter estimates: \")\n",
    "print(theta_ppi_ci['pointest'])\n",
    "print(\"prediction-powered confidence interval:\")\n",
    "print(theta_ppi_ci['ci'])\n",
    "print('SE of PPI estiamtes:')\n",
    "print(theta_ppi_ci['se'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7741e6bb-aaec-43bb-bf02-aedc725cf7a9",
   "metadata": {},
   "source": [
    "## Pemba"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "64e55fd3-3fd4-42a2-92c8-eab8c18c4e5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_dir = \"../classic_predictions/\"\n",
    "site_name = \"pemba\"\n",
    "file_name = folder_dir + \"classic_predictions_ex_\" + site_name + \".csv\"\n",
    "df = pd.read_csv(file_name)\n",
    "\n",
    "labeled_df = df.sample(frac=0.4)\n",
    "unlabeled_df = df.drop(labeled_df.index)\n",
    "\n",
    "Y = labeled_df['Y'].to_numpy()\n",
    "X = labeled_df['X'].to_numpy()\n",
    "X = X.reshape(-1,1)\n",
    "Yhat_NB = labeled_df['Yhat_NB'].to_numpy()\n",
    "Yhat = Yhat_NB\n",
    "X_unlabeled = unlabeled_df['X'].to_numpy()\n",
    "X_unlabeled = X_unlabeled.reshape(-1,1)\n",
    "Yhat_unlabeled_NB = unlabeled_df['Yhat_NB'].to_numpy()\n",
    "Yhat_unlabeled = Yhat_unlabeled_NB\n",
    "\n",
    "theta_ppi_ci = ppi_multiclass_logistic_ci(\n",
    "            X,\n",
    "            Y,\n",
    "            Yhat,\n",
    "            X_unlabeled,\n",
    "            Yhat_unlabeled,\n",
    "            optimizer_options = {'disp': True, 'maxiter':1000},\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0319b6fd-a2b6-4611-ae9a-80e973a291e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "power tuning parameter value is:\n",
      "0.47697416050122554\n",
      "parameter estimates: \n",
      "[-0.00264279 -0.01421964  0.01279802 -0.02854514]\n",
      "prediction-powered confidence interval:\n",
      "(array([-0.00371149, -0.01726476,  0.01186213, -0.03280631]), array([-0.00157409, -0.01117452,  0.01373391, -0.02428398]))\n",
      "SE of PPI estiamtes when lambda = 0:\n",
      "[0.00064972 0.0018513  0.00056898 0.00259061]\n"
     ]
    }
   ],
   "source": [
    "print(\"power tuning parameter value is:\")\n",
    "print(theta_ppi_ci['lhat'])\n",
    "print(\"parameter estimates: \")\n",
    "print(theta_ppi_ci['pointest'])\n",
    "print(\"prediction-powered confidence interval:\")\n",
    "print(theta_ppi_ci['ci'])\n",
    "print('SE of PPI estiamtes when:')\n",
    "print(theta_ppi_ci['se'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1ae8752-1209-4ed6-873a-59bf290bcfee",
   "metadata": {},
   "source": [
    "## AP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3cd66b7a-afe1-44b5-a99a-7f457bb3f9ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_dir = \"../classic_predictions/\"\n",
    "site_name = \"ap\"\n",
    "file_name = folder_dir + \"classic_predictions_ex_\" + site_name + \".csv\"\n",
    "df = pd.read_csv(file_name)\n",
    "\n",
    "labeled_df = df.sample(frac=0.4)\n",
    "unlabeled_df = df.drop(labeled_df.index)\n",
    "\n",
    "Y = labeled_df['Y'].to_numpy()\n",
    "X = labeled_df['X'].to_numpy()\n",
    "X = X.reshape(-1,1)\n",
    "Yhat_NB = labeled_df['Yhat_NB'].to_numpy()\n",
    "Yhat = Yhat_NB\n",
    "X_unlabeled = unlabeled_df['X'].to_numpy()\n",
    "X_unlabeled = X_unlabeled.reshape(-1,1)\n",
    "Yhat_unlabeled_NB = unlabeled_df['Yhat_NB'].to_numpy()\n",
    "Yhat_unlabeled = Yhat_unlabeled_NB\n",
    "\n",
    "theta_ppi_ci = ppi_multiclass_logistic_ci(\n",
    "            X,\n",
    "            Y,\n",
    "            Yhat,\n",
    "            X_unlabeled,\n",
    "            Yhat_unlabeled,\n",
    "            optimizer_options = {'disp': True, 'maxiter':1000},\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ee723848-c3ad-454b-b4e6-c8a3a5c2257d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The left-out site is:ap\n",
      "power tuning parameter value is:\n",
      "0.3078180666869222\n",
      "parameter estimates: \n",
      "[-0.00624251 -0.00664452  0.02804172  0.00819674]\n",
      "prediction-powered confidence interval:\n",
      "(array([-0.00899999, -0.01039156,  0.02668012,  0.00667108]), array([-0.00348502, -0.00289748,  0.02940332,  0.00972241]))\n",
      "SE of PPI estiamtes when:\n",
      "[0.00167643 0.00227804 0.00082779 0.00092754]\n"
     ]
    }
   ],
   "source": [
    "print(\"The left-out site is:\" + site_name)\n",
    "print(\"power tuning parameter value is:\")\n",
    "print(theta_ppi_ci['lhat'])\n",
    "print(\"parameter estimates: \")\n",
    "print(theta_ppi_ci['pointest'])\n",
    "print(\"prediction-powered confidence interval:\")\n",
    "print(theta_ppi_ci['ci'])\n",
    "print('SE of PPI estiamtes when:')\n",
    "print(theta_ppi_ci['se'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3ca5842-bd7b-4fb7-8ba8-faeab65ad382",
   "metadata": {},
   "source": [
    "## Bohol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "624581aa-0e34-4e09-8af9-89a657d6f727",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The left-out site is:bohol\n",
      "power tuning parameter value is:\n",
      "0.39727643474894964\n",
      "parameter estimates: \n",
      "[-0.00361776 -0.03339858  0.02243895 -0.04621501]\n",
      "prediction-powered confidence interval:\n",
      "(array([-0.00494442, -0.04367511,  0.02125039, -0.05895053]), array([-0.00229111, -0.02312204,  0.02362751, -0.03347949]))\n",
      "SE of PPI estiamtes when:\n",
      "[0.00080655 0.00624769 0.00072259 0.00774265]\n"
     ]
    }
   ],
   "source": [
    "folder_dir = \"../classic_predictions/\"\n",
    "site_name = \"bohol\"\n",
    "file_name = folder_dir + \"classic_predictions_ex_\" + site_name + \".csv\"\n",
    "df = pd.read_csv(file_name)\n",
    "\n",
    "labeled_df = df.sample(frac=0.4)\n",
    "unlabeled_df = df.drop(labeled_df.index)\n",
    "\n",
    "Y = labeled_df['Y'].to_numpy()\n",
    "X = labeled_df['X'].to_numpy()\n",
    "X = X.reshape(-1,1)\n",
    "Yhat_NB = labeled_df['Yhat_NB'].to_numpy()\n",
    "Yhat = Yhat_NB\n",
    "X_unlabeled = unlabeled_df['X'].to_numpy()\n",
    "X_unlabeled = X_unlabeled.reshape(-1,1)\n",
    "Yhat_unlabeled_NB = unlabeled_df['Yhat_NB'].to_numpy()\n",
    "Yhat_unlabeled = Yhat_unlabeled_NB\n",
    "\n",
    "theta_ppi_ci = ppi_multiclass_logistic_ci(\n",
    "            X,\n",
    "            Y,\n",
    "            Yhat,\n",
    "            X_unlabeled,\n",
    "            Yhat_unlabeled,\n",
    "            optimizer_options = {'disp': True, 'maxiter':1000},\n",
    "        )\n",
    "\n",
    "print(\"The left-out site is:\" + site_name)\n",
    "print(\"power tuning parameter value is:\")\n",
    "print(theta_ppi_ci['lhat'])\n",
    "print(\"parameter estimates: \")\n",
    "print(theta_ppi_ci['pointest'])\n",
    "print(\"prediction-powered confidence interval:\")\n",
    "print(theta_ppi_ci['ci'])\n",
    "print('SE of PPI estiamtes when:')\n",
    "print(theta_ppi_ci['se'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd24f655-bd96-467a-8565-248e3b023e70",
   "metadata": {},
   "source": [
    "## Dar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "066fef1b-e7dc-4681-a8eb-4ba10d47665b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The left-out site is:dar\n",
      "power tuning parameter value is:\n",
      "0.3613783126538752\n",
      "parameter estimates: \n",
      "[-0.00527376 -0.02782146  0.03221836 -0.01039343]\n",
      "prediction-powered confidence interval:\n",
      "(array([-0.00749293, -0.03821043,  0.03071345, -0.01323097]), array([-0.0030546 , -0.01743249,  0.03372328, -0.0075559 ]))\n",
      "SE of PPI estiamtes when:\n",
      "[0.00134916 0.00631605 0.00091492 0.0017251 ]\n"
     ]
    }
   ],
   "source": [
    "folder_dir = \"../classic_predictions/\"\n",
    "site_name = \"dar\"\n",
    "file_name = folder_dir + \"classic_predictions_ex_\" + site_name + \".csv\"\n",
    "df = pd.read_csv(file_name)\n",
    "\n",
    "labeled_df = df.sample(frac=0.4)\n",
    "unlabeled_df = df.drop(labeled_df.index)\n",
    "\n",
    "Y = labeled_df['Y'].to_numpy()\n",
    "X = labeled_df['X'].to_numpy()\n",
    "X = X.reshape(-1,1)\n",
    "Yhat_NB = labeled_df['Yhat_NB'].to_numpy()\n",
    "Yhat = Yhat_NB\n",
    "X_unlabeled = unlabeled_df['X'].to_numpy()\n",
    "X_unlabeled = X_unlabeled.reshape(-1,1)\n",
    "Yhat_unlabeled_NB = unlabeled_df['Yhat_NB'].to_numpy()\n",
    "Yhat_unlabeled = Yhat_unlabeled_NB\n",
    "\n",
    "theta_ppi_ci = ppi_multiclass_logistic_ci(\n",
    "            X,\n",
    "            Y,\n",
    "            Yhat,\n",
    "            X_unlabeled,\n",
    "            Yhat_unlabeled,\n",
    "            optimizer_options = {'disp': True, 'maxiter':1000},\n",
    "        )\n",
    "\n",
    "print(\"The left-out site is:\" + site_name)\n",
    "print(\"power tuning parameter value is:\")\n",
    "print(theta_ppi_ci['lhat'])\n",
    "print(\"parameter estimates: \")\n",
    "print(theta_ppi_ci['pointest'])\n",
    "print(\"prediction-powered confidence interval:\")\n",
    "print(theta_ppi_ci['ci'])\n",
    "print('SE of PPI estiamtes when:')\n",
    "print(theta_ppi_ci['se'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdaa57cc-0298-4e65-b533-8d75a235cf1a",
   "metadata": {},
   "source": [
    "## UP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3a6713e2-a696-4713-87b2-46a0f0b7c8b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The left-out site is:up\n",
      "power tuning parameter value is:\n",
      "0.4485352499454914\n",
      "parameter estimates: \n",
      "[ 0.00577535 -0.00656915  0.00574502 -0.05527534]\n",
      "prediction-powered confidence interval:\n",
      "(array([ 0.00349284, -0.01143386,  0.00308815, -0.0980854 ]), array([ 0.00805785, -0.00170444,  0.00840188, -0.01246529]))\n",
      "SE of PPI estiamtes when:\n",
      "[0.00138767 0.00295753 0.00161526 0.02602667]\n"
     ]
    }
   ],
   "source": [
    "folder_dir = \"../classic_predictions/\"\n",
    "site_name = \"up\"\n",
    "file_name = folder_dir + \"classic_predictions_ex_\" + site_name + \".csv\"\n",
    "df = pd.read_csv(file_name)\n",
    "\n",
    "labeled_df = df.sample(frac=0.4)\n",
    "unlabeled_df = df.drop(labeled_df.index)\n",
    "\n",
    "Y = labeled_df['Y'].to_numpy()\n",
    "X = labeled_df['X'].to_numpy()\n",
    "X = X.reshape(-1,1)\n",
    "Yhat_NB = labeled_df['Yhat_NB'].to_numpy()\n",
    "Yhat = Yhat_NB\n",
    "X_unlabeled = unlabeled_df['X'].to_numpy()\n",
    "X_unlabeled = X_unlabeled.reshape(-1,1)\n",
    "Yhat_unlabeled_NB = unlabeled_df['Yhat_NB'].to_numpy()\n",
    "Yhat_unlabeled = Yhat_unlabeled_NB\n",
    "\n",
    "theta_ppi_ci = ppi_multiclass_logistic_ci(\n",
    "            X,\n",
    "            Y,\n",
    "            Yhat,\n",
    "            X_unlabeled,\n",
    "            Yhat_unlabeled,\n",
    "            optimizer_options = {'disp': True, 'maxiter':1000},\n",
    "        )\n",
    "\n",
    "print(\"The left-out site is:\" + site_name)\n",
    "print(\"power tuning parameter value is:\")\n",
    "print(theta_ppi_ci['lhat'])\n",
    "print(\"parameter estimates: \")\n",
    "print(theta_ppi_ci['pointest'])\n",
    "print(\"prediction-powered confidence interval:\")\n",
    "print(theta_ppi_ci['ci'])\n",
    "print('SE of PPI estiamtes when:')\n",
    "print(theta_ppi_ci['se'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa4d870b-f9c6-48b8-9206-5e6689a07477",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
