{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "from importlib import reload\n",
    "\n",
    "import numpy as np\n",
    "import scipy.sparse as sparse\n",
    "from scipy.special import expit\n",
    "from scipy.stats import gamma, poisson, bernoulli, truncnorm\n",
    "from sklearn.metrics import mean_squared_error as mse\n",
    "\n",
    "import influence_cavi as ic\n",
    "import minibatch_influence_cavi as mic\n",
    "import utils as ut"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_set_overlap(Beta_p, Beta, k=50):\n",
    "    \"\"\"Jaccard overlap between the top-k index sets of two score vectors.\n",
    "\n",
    "    Args:\n",
    "        Beta_p: array of estimated scores (intended for 1-D input).\n",
    "        Beta: array of reference scores (intended for 1-D input).\n",
    "        k: how many of the highest-scoring indices to take from each vector.\n",
    "            If k exceeds a vector's length, all of its indices are used.\n",
    "\n",
    "    Returns:\n",
    "        Size of the intersection of the two top-k index sets divided by the\n",
    "        size of their union, as a float in [0, 1]. Returns 0.0 when both\n",
    "        index sets are empty.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if k is smaller than 1.\n",
    "    \"\"\"\n",
    "    if k < 1:\n",
    "        # arr[-0:] is the whole array and arr[-k:] with negative k drops the\n",
    "        # first |k| entries, so a non-positive k would silently pick the wrong set.\n",
    "        raise ValueError('k must be a positive integer, got %r' % (k,))\n",
    "    top = np.argsort(Beta)[-k:]\n",
    "    top_p = np.argsort(Beta_p)[-k:]\n",
    "    union_size = np.union1d(top, top_p).shape[0]\n",
    "    if union_size == 0:\n",
    "        # Both inputs were empty; avoid dividing by zero.\n",
    "        return 0.0\n",
    "    return np.intersect1d(top, top_p).shape[0]/union_size\n",
    "\n",
    "def report_sparsity(A, Y, Z, Gamma, Beta, Y_past):\n",
    "    \"\"\"Print the mean of each input array, then the largest value in Beta.\n",
    "\n",
    "    Each line labelled Sparsity reports ``array.mean()`` for the matching\n",
    "    argument; the final line reports ``Beta.max()``. Nothing is returned.\n",
    "    \"\"\"\n",
    "    # (label, array) pairs listed in the exact order they are printed.\n",
    "    labelled_arrays = (\n",
    "        (\"Sparsity in A:\", A),\n",
    "        (\"Sparsity in Obs:\", Y),\n",
    "        (\"Sparsity in Past Obs:\", Y_past),\n",
    "        (\"Sparsity in Z:\", Z),\n",
    "        (\"Sparsity in Gamma:\", Gamma),\n",
    "        (\"Sparsity in Beta:\", Beta),\n",
    "    )\n",
    "    for label, values in labelled_arrays:\n",
    "        print(label, values.mean())\n",
    "    print(\"Max in Beta:\", Beta.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulate the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sparsity in A: 0.024155\n",
      "Sparsity in Obs: 0.069212\n",
      "Sparsity in Past Obs: 0.023637\n",
      "Sparsity in Z: 0.04965549912496958\n",
      "Sparsity in Gamma: 0.04849078587465319\n",
      "Sparsity in Beta: 0.07175969865499741\n",
      "Max in Beta: 18.100305793131714\n"
     ]
    }
   ],
   "source": [
    "# Seed NumPy's global RNG, which scipy.stats' rvs() draws from by default,\n",
    "# so the simulated dataset is identical on every Restart & Run All.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "\n",
    "N = 1000  # rows of Z, A, Y_past and Y; length of Beta\n",
    "M = 1000  # rows of Gamma, i.e. columns of Y_past and Y\n",
    "K = 10    # number of latent components shared by Z and Gamma\n",
    "\n",
    "# Gamma prior for Z and Gamma, parameterised so that shape*scale equals `mean`.\n",
    "mean = 0.05\n",
    "scale = 1./10.\n",
    "shape = mean/scale\n",
    "\n",
    "# A: Poisson draws with rate Z Z^T, clipped to a 0/1 matrix.\n",
    "Z = gamma.rvs(shape, scale=scale, size=(N, K))\n",
    "A = poisson.rvs(np.dot(Z, Z.transpose()))\n",
    "A[A>1] = 1\n",
    "\n",
    "# Past observations: Poisson draws with rate Z Gamma^T.\n",
    "Gamma = gamma.rvs(shape, scale=scale, size=(M, K))\n",
    "Y_past = poisson.rvs(np.dot(Z, Gamma.transpose()))\n",
    "\n",
    "# Alternative prior for Beta (disabled): shape tied to the row sums of Z.\n",
    "# alpha = 2.\n",
    "# scale = 1./50.\n",
    "# infmean = alpha*Z.sum(axis=1)\n",
    "# infshape = infmean/scale\n",
    "\n",
    "# Active prior for Beta: gamma with a small shape and a large scale.\n",
    "infshape = 0.01\n",
    "infscale = 10.\n",
    "Beta = gamma.rvs(infshape, scale=infscale, size=N)\n",
    "\n",
    "# Alternative prior for Beta (disabled): truncated normal around a Bernoulli switch.\n",
    "# p = 0.01\n",
    "# J = bernoulli.rvs(p, size=N)\n",
    "# Beta = truncnorm.rvs(0,10., loc=J*2., scale=1./10.,)\n",
    "\n",
    "# Beta (length N) broadcasts across the rows of A, so entry (i, j) is\n",
    "# Beta[j]*A[i, j]. Kept under its own name so the item count M is not clobbered.\n",
    "influence = Beta*A\n",
    "I = np.dot(influence, Y_past)\n",
    "P = np.dot(Z, Gamma.transpose())\n",
    "rate = I + P\n",
    "Y = poisson.rvs(rate)\n",
    "\n",
    "report_sparsity(A, Y, Z, Gamma, Beta, Y_past)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import influence_cavi so edits made to the module on disk are picked up\n",
    "# without restarting the kernel, then build a fresh model with K components.\n",
    "reload(ic)\n",
    "model = ic.PoissonInfluenceModel(n_components=K, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bound: -14778.085262928942\n",
      "Bound: -12310.046818572138\n",
      "Bound: -11550.900993437725\n",
      "Bound: -11443.770947940706\n",
      "Bound: -11399.054934280368\n",
      "Bound: -11376.06446453091\n",
      "Bound: -11363.475452185914\n",
      "Bound: -11356.205791341292\n",
      "Bound: -11351.806965453663\n",
      "Bound: -11349.036105987512\n",
      "Bound: -11347.230345783859\n",
      "Bound: -11346.01925004554\n",
      "Bound: -11345.18686235424\n",
      "Bound: -11344.602536097338\n",
      "Bound: -11344.184662588847\n",
      "Bound: -11343.88083366588\n",
      "Bound: -11343.65657771697\n",
      "Bound: -11343.488746234807\n",
      "Bound: -11343.361511922467\n",
      "Bound: -11343.263880507942\n",
      "Bound: -11343.188106109072\n",
      "Bound: -11343.128660697252\n",
      "Bound: -11343.081551672623\n",
      "Bound: -11343.04386281141\n",
      "Bound: -11343.01344118614\n",
      "Bound: -11342.988680996345\n",
      "Bound: -11342.968372630323\n",
      "Bound: -11342.951596174558\n",
      "Bound: -11342.937645545351\n",
      "Bound: -11342.925973932635\n",
      "Bound: -11342.916154217513\n",
      "Bound: -11342.907850004616\n",
      "Bound: -11342.900794243933\n",
      "Bound: -11342.894773323766\n",
      "Bound: -11342.889615139498\n",
      "Bound: -11342.885180074061\n",
      "Bound: -11342.881354127057\n",
      "Bound: -11342.878043641493\n",
      "Bound: -11342.875171226973\n",
      "Bound: -11342.872672585487\n",
      "Bound: -11342.870494022874\n",
      "Bound: -11342.868590484815\n",
      "Bound: -11342.866923996775\n",
      "Bound: -11342.865462417067\n",
      "Bound: -11342.86417843425\n",
      "Bound: -11342.863048756304\n",
      "Bound: -11342.862053451301\n",
      "Bound: -11342.861175408427\n",
      "Bound: -11342.860399895166\n",
      "Bound: -11342.859714191683\n",
      "Bound: -11342.859107287666\n",
      "Bound: -11342.858569629712\n"
     ]
    }
   ],
   "source": [
    "# Fit on the simulated data. Z is the simulated latent matrix itself, passed in as an input.\n",
    "# NOTE(review): presumably fit() treats Z as known rather than estimating it -- confirm in influence_cavi.\n",
    "model.fit(Y, A, Z, Y_past)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([58, 61, 85, 49, 81, 54, 69, 78, 22,  1])"
      ]
     },
     "execution_count": 158,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Indices of the 10 largest entries of the model's E_beta, ordered from smallest to largest value.\n",
    "bhat=model.E_beta\n",
    "np.argsort(bhat)[-10:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([44, 58, 28,  7, 92, 61, 69, 22, 78,  1])"
      ]
     },
     "execution_count": 159,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Indices of the 10 largest entries of the simulated Beta, for comparison with the fitted ranking.\n",
    "np.argsort(Beta)[-10:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.42857142857142855"
      ]
     },
     "execution_count": 164,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Intersection-over-union of the fitted and simulated top-10 index sets (1.0 means identical sets).\n",
    "get_set_overlap(bhat, Beta,k=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 194,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bound: -274.7631100493652\n",
      "Bound: -220.67177300350585\n",
      "Bound: -217.02586862476545\n",
      "Bound: -197.9468081839389\n",
      "Bound: -192.30652667771145\n",
      "Bound: -203.74611685228734\n",
      "Bound: -199.03514404818023\n",
      "Bound: -203.83952166746033\n",
      "Bound: -191.39440287980221\n",
      "Bound: -187.07010435702702\n",
      "Bound: -199.3731983695717\n",
      "Bound: -196.7297440664757\n",
      "Bound: -201.39017247206212\n",
      "Bound: -189.82849240809148\n",
      "Bound: -185.6057707818171\n",
      "Bound: -197.9181051430487\n",
      "Bound: -195.9130155830102\n",
      "Bound: -200.3649029758743\n",
      "Bound: -189.1355135211117\n",
      "Bound: -184.93343424158076\n",
      "Bound: -197.20678554748372\n",
      "Bound: -195.5083611180293\n",
      "Bound: -199.806555311368\n",
      "Bound: -188.74785110635466\n",
      "Bound: -184.55209072915648\n",
      "Bound: -196.79024607173895\n",
      "Bound: -195.27073933702843\n",
      "Bound: -199.456084308986\n",
      "Bound: -188.50034952312615\n",
      "Bound: -184.30758940230683\n",
      "Bound: -196.51874662374235\n",
      "Bound: -195.11585757748958\n",
      "Bound: -199.21556095924035\n",
      "Bound: -188.32829789510816\n",
      "Bound: -184.13772167124898\n",
      "Bound: -196.32874497465798\n",
      "Bound: -195.0075076960611\n",
      "Bound: -199.04009869539811\n",
      "Bound: -188.20140011230285\n",
      "Bound: -184.01282914453535\n",
      "Bound: -196.1888698810366\n",
      "Bound: -194.92773629505828\n",
      "Bound: -198.90629932340212\n",
      "Bound: -188.1036493011061\n",
      "Bound: -183.91705954033444\n",
      "Bound: -196.08191850432803\n",
      "Bound: -194.86669339718364\n",
      "Bound: -198.80078792697194\n",
      "Bound: -188.0258121138234\n"
     ]
    }
   ],
   "source": [
    "# Same workflow with the minibatch module: reload it, rebind `model` to a new\n",
    "# instance (the previously fitted model is no longer reachable by that name),\n",
    "# and fit with fit_stochastic on the same simulated inputs.\n",
    "# NOTE(review): batch_size=200 is presumably the number of rows per stochastic update -- confirm in minibatch_influence_cavi.\n",
    "reload(mic)\n",
    "model = mic.PoissonInfluenceModel(n_components=K, batch_size=200, verbose=True)\n",
    "model.fit_stochastic(Y, A, Z, Y_past)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 197,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8181818181818182"
      ]
     },
     "execution_count": 197,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-10 overlap between the current model's E_beta and the simulated Beta.\n",
    "bhat=model.E_beta\n",
    "get_set_overlap(bhat, Beta,k=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MSE in each dim K: 0 0.007258264381788932\n",
      "Random baseline MSE in each dim K: 0 0.010961647941648141\n",
      "****************************************\n",
      "MSE in each dim K: 1 0.006416355209766748\n",
      "Random baseline MSE in each dim K: 1 0.010471312151890564\n",
      "****************************************\n",
      "MSE in each dim K: 2 0.007078706215902152\n",
      "Random baseline MSE in each dim K: 2 0.010332348675580649\n",
      "****************************************\n",
      "MSE in each dim K: 3 0.007703752569680112\n",
      "Random baseline MSE in each dim K: 3 0.011458356284885996\n",
      "****************************************\n",
      "MSE in each dim K: 4 0.0066277628932556246\n",
      "Random baseline MSE in each dim K: 4 0.010607716932089978\n",
      "****************************************\n",
      "MSE in each dim K: 5 0.006140618249398667\n",
      "Random baseline MSE in each dim K: 5 0.009050979336241959\n",
      "****************************************\n",
      "MSE in each dim K: 6 0.005823072514256757\n",
      "Random baseline MSE in each dim K: 6 0.009238932447931569\n",
      "****************************************\n",
      "MSE in each dim K: 7 0.007674043965004905\n",
      "Random baseline MSE in each dim K: 7 0.0112078579626222\n",
      "****************************************\n",
      "MSE in each dim K: 8 0.005819906511710219\n",
      "Random baseline MSE in each dim K: 8 0.007983937945362708\n",
      "****************************************\n",
      "MSE in each dim K: 9 0.006794922339042028\n",
      "Random baseline MSE in each dim K: 9 0.010571274967069055\n",
      "****************************************\n"
     ]
    }
   ],
   "source": [
    "# Per-component error of the model's E_gamma against the simulated Gamma,\n",
    "# next to a baseline that is an independent draw from the same gamma prior.\n",
    "# `mse` is sklearn's mean_squared_error, imported in the imports cell.\n",
    "ghat = model.E_gamma\n",
    "\n",
    "# Take the dimensions from Gamma itself so the baseline always matches the\n",
    "# simulated data instead of relying on hard-coded copies of M and K.\n",
    "M, K = Gamma.shape\n",
    "random_gamma = gamma.rvs(shape, scale=scale, size=(M, K))\n",
    "\n",
    "for i in range(K):\n",
    "    print(\"MSE in each dim K:\", i, mse(ghat[:,i], Gamma[:,i]))\n",
    "    print(\"Random baseline MSE in each dim K:\", i, mse(random_gamma[:,i], Gamma[:,i]))\n",
    "    print(\"*\"*40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
