{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from importlib import reload\n",
    "\n",
    "import numpy as np\n",
    "from scipy.stats import poisson, gamma\n",
    "from sklearn.metrics import mean_squared_error as mse\n",
    "\n",
    "import spf_vi as spf\n",
    "import poisson_influence_factorization as pif\n",
    "import adjusted_model as am\n",
    "import utils as ut"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_set_overlap(Beta_p, Beta, k=50):\n",
    "    \"\"\"Jaccard overlap between the top-k indices of two score vectors.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    Beta_p : 1-D array-like\n",
    "        Estimated scores.\n",
    "    Beta : 1-D array-like\n",
    "        Reference scores, assumed to index the same units as ``Beta_p``.\n",
    "    k : int, default 50\n",
    "        Number of highest-scoring indices taken from each vector; if ``k``\n",
    "        exceeds the vector length, all indices are used.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Size of the intersection of the two top-k index sets divided by the\n",
    "        size of their union, in [0, 1]. Ties follow np.argsort's ordering.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``k < 1`` (``argsort(...)[-0:]`` would silently select every index,\n",
    "        and a negative ``k`` would select all but the |k| smallest entries),\n",
    "        or if both score vectors are empty.\n",
    "    \"\"\"\n",
    "    if k < 1:\n",
    "        raise ValueError(\"k must be >= 1, got %r\" % (k,))\n",
    "    top = np.argsort(Beta)[-k:]\n",
    "    top_p = np.argsort(Beta_p)[-k:]\n",
    "    union = np.union1d(top, top_p)\n",
    "    if union.size == 0:\n",
    "        raise ValueError(\"cannot compute top-k overlap of empty score vectors\")\n",
    "    return np.intersect1d(top, top_p).size / union.size\n",
    "\n",
    "def report_sparsity(A, Y, Z, Gamma, Beta, Y_past):\n",
    "    \"\"\"Print one summary statistic per simulated array.\n",
    "\n",
    "    Despite the \"Sparsity\" labels, every value printed is ``arr.mean()``:\n",
    "    for a 0/1 array that is the fraction of nonzero entries (density), for\n",
    "    count- or real-valued arrays it is the average entry. The final line\n",
    "    reports ``Beta.max()``. Output order: A, Y, Y_past, Z, Gamma, Beta.\n",
    "    \"\"\"\n",
    "    # (label, array, reduction method); each reduction runs just before its print.\n",
    "    summaries = (\n",
    "        (\"Sparsity in A:\", A, \"mean\"),\n",
    "        (\"Sparsity in Obs:\", Y, \"mean\"),\n",
    "        (\"Sparsity in Past Obs:\", Y_past, \"mean\"),\n",
    "        (\"Sparsity in Z:\", Z, \"mean\"),\n",
    "        (\"Sparsity in Gamma:\", Gamma, \"mean\"),\n",
    "        (\"Sparsity in Beta:\", Beta, \"mean\"),\n",
    "        (\"Max in Beta:\", Beta, \"max\"),\n",
    "    )\n",
    "    for label, arr, reducer in summaries:\n",
    "        print(label, getattr(arr, reducer)())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sparsity in A: 0.049908\n",
      "Sparsity in Obs: 0.678685\n",
      "Sparsity in Past Obs: 0.102518\n",
      "Sparsity in Z: 0.050747058738564096\n",
      "Sparsity in Gamma: 0.05108208421695091\n",
      "Sparsity in Beta: 0.1086371359450045\n",
      "Max in Beta: 31.335932034746513\n"
     ]
    }
   ],
   "source": [
    "# Simulation sizes and seed\n",
    "N = 1000  # units: rows/columns of A, rows of Y\n",
    "M = 1000  # items: columns of Y\n",
    "K = 20    # latent components\n",
    "SEED = 0\n",
    "# scipy.stats .rvs falls back to NumPy's global RNG when no random_state is given,\n",
    "# so seeding it makes the scipy draws reproducible under Restart & Run All.\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Gamma(shape, scale) has mean shape * scale, so shape is chosen to hit `mean`.\n",
    "mean = 0.05\n",
    "scale = 1./10.\n",
    "shape = mean/scale\n",
    "\n",
    "Z = gamma.rvs(shape, scale=scale, size=(N, K))\n",
    "A = poisson.rvs(np.dot(Z, Z.transpose()))\n",
    "A[A>1] = 1  # binarize edge counts\n",
    "# NOTE(review): diag(Z Z^T) > 0, so A can contain self-loops (A[i, i] = 1); confirm intended.\n",
    "\n",
    "eta = gamma.rvs(shape, scale=scale, size=(N, K))\n",
    "Gamma = gamma.rvs(shape, scale=scale, size=(M, K))\n",
    "\n",
    "# Baseline (non-influence) Poisson rate, shared by the past and current outcomes.\n",
    "P = np.dot((Z + eta), Gamma.transpose())\n",
    "Y_past = poisson.rvs(P)\n",
    "\n",
    "infshape = 0.01\n",
    "infscale = 10.\n",
    "Beta = gamma.rvs(infshape, scale=infscale, size=N)\n",
    "\n",
    "# influence_weights[i, j] = Beta[j] * A[i, j] (Beta broadcasts along columns).\n",
    "# Kept out of M so that M remains the item count used by later cells.\n",
    "influence_weights = Beta*A\n",
    "I = np.dot(influence_weights, Y_past)\n",
    "rate = I + P\n",
    "Y = poisson.rvs(rate)\n",
    "\n",
    "report_sparsity(A, Y, Z, Gamma, Beta, Y_past)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num. nonzero obs: 6974\n",
      "Computing element-wise product of A, Y past...\n",
      "Computing nonzero elements of AY...\n",
      "Bound: -133521.3597273614\n",
      "Bound: -63684.61832563224\n",
      "Bound: -31691.791188564093\n",
      "Bound: -25428.04393704461\n",
      "Bound: -23772.368164264994\n",
      "Bound: -23644.669067936527\n",
      "Bound: -23586.980770978866\n",
      "Bound: -23553.416659317612\n",
      "Bound: -23534.921745852884\n",
      "Bound: -23523.648849469395\n"
     ]
    }
   ],
   "source": [
    "# Re-import so edits to poisson_influence_factorization.py take effect.\n",
    "pif = reload(pif)\n",
    "model = pif.PoissonInfluenceModel(\n",
    "    n_components=K,\n",
    "    verbose=True,\n",
    "    max_iter=10,\n",
    ")\n",
    "model.fit(Y, A, Z, Y_past)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num. nonzero obs: 213721\n",
      "Computing element-wise product of A, Y past...\n",
      "Computing nonzero elements of AY...\n",
      "Bound: -1954011.4692824583\n",
      "Bound: -1287492.1615231142\n",
      "Bound: -1242874.3882415232\n",
      "Bound: -1198988.141037965\n",
      "Bound: -1146687.4769900506\n",
      "Bound: -1045079.3378411288\n",
      "Bound: -913279.5166077882\n",
      "Bound: -777219.9952940337\n",
      "Bound: -684478.3686974661\n",
      "Bound: -635180.4929041285\n",
      "Bound: -611225.0050030864\n",
      "Bound: -596468.1360763303\n",
      "Bound: -587425.6350380523\n",
      "Bound: -581935.2360978422\n",
      "Bound: -578887.5546272103\n",
      "Bound: -576946.3118670288\n",
      "Bound: -575781.6164765793\n",
      "Bound: -574858.236477788\n",
      "Bound: -574198.6138619623\n",
      "Bound: -573585.0093336877\n",
      "Bound: -573107.5595895171\n",
      "Bound: -572644.9166049585\n",
      "Bound: -572275.0321603825\n",
      "Bound: -571916.72622318\n",
      "Bound: -571628.0161834751\n",
      "Bound: -571351.1513163171\n",
      "Bound: -571127.8052509308\n",
      "Bound: -570915.5255692946\n",
      "Bound: -570743.7818414487\n",
      "Bound: -570581.3554398128\n",
      "Bound: -570449.307339749\n",
      "Bound: -570324.8429080596\n",
      "Bound: -570223.2164131234\n",
      "Bound: -570127.8936619038\n",
      "Bound: -570049.9550403304\n",
      "Bound: -569977.3600751488\n",
      "Bound: -569918.1500415573\n",
      "Bound: -569863.5664459136\n",
      "Bound: -569819.4370717441\n",
      "Bound: -569779.3837816678\n",
      "Bound: -569747.6271007303\n",
      "Bound: -569719.5356197341\n",
      "Bound: -569698.1864068205\n",
      "Bound: -569680.2075484068\n",
      "Bound: -569667.819822689\n",
      "Bound: -569658.5501087438\n",
      "Bound: -569653.9795297898\n",
      "Bound: -569652.2658410888\n",
      "Bound: -569654.5192947036\n",
      "Bound: -569659.364886429\n",
      "Bound: -569667.5726463445\n",
      "Bound: -569678.0983051681\n",
      "Bound: -569691.4366654688\n",
      "Bound: -569706.7677517205\n",
      "Bound: -569724.3930440049\n",
      "Bound: -569743.6599230879\n",
      "Bound: -569764.7292119311\n",
      "Bound: -569787.0688480369\n",
      "Bound: -569810.7362600318\n",
      "Bound: -569835.2927609882\n",
      "Bound: -569860.7330581689\n",
      "Bound: -569886.7059894303\n",
      "Bound: -569913.1993203592\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-35-dcf17300c3e6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mreload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAdjustedInfluenceModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_components\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mK\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mY\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mZ\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mY_past\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/Documents/social-fake-news/src/adjusted_model.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, Y, A, Z, Y_past)\u001b[0m\n\u001b[1;32m    182\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    183\u001b[0m                         \u001b[0;31m#E-step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 184\u001b[0;31m                         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_psi_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    185\u001b[0m                         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_rho_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mK\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    186\u001b[0m                         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_exp_normalize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/social-fake-news/src/adjusted_model.py\u001b[0m in \u001b[0;36m_compute_psi_prob\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     94\u001b[0m                 \u001b[0mlog_ay\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnonzero_AY_vals\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m                 \u001b[0mlog_beta\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspecial\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpsi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbeta_shape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mJ\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbeta_rates\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mJ\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpsi_prob\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mrows\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mJ\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_ay\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlog_beta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     98\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0m_compute_rho_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mK\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Re-import so edits to adjusted_model.py take effect.\n",
    "am = reload(am)\n",
    "# NOTE(review): the saved run was interrupted by hand. The printed bound peaked after\n",
    "# ~48 iterations (about -569652) and then decreased steadily. If fit() performs exact\n",
    "# coordinate-ascent variational updates, the bound should not decrease, so the updates\n",
    "# in adjusted_model.py deserve a check. Unlike the PoissonInfluenceModel cell, no\n",
    "# iteration cap is passed here.\n",
    "s = am.AdjustedInfluenceModel(\n",
    "    n_components=K,\n",
    "    verbose=True,\n",
    ")\n",
    "s.fit(Y, A, Z, Y_past)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6129032258064516"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Jaccard overlap of the 50 highest-Beta units: estimated vs. simulated influence.\n",
    "bhat = s.E_beta\n",
    "get_set_overlap(bhat, Beta, k=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11.126268319553038 20.677016796175373\n"
     ]
    }
   ],
   "source": [
    "# Baseline: an independent draw from the same Gamma prior used to simulate Beta.\n",
    "random_beta = gamma.rvs(infshape, scale=infscale, size=N)\n",
    "{\"mse_estimate\": mse(Beta, bhat), \"mse_random_prior_draw\": mse(Beta, random_beta)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Item factors estimated by the adjusted model (presumably posterior means).\n",
    "gamma_hat = s.E_gamma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5354596773233189"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): factor models are often identifiable only up to permutation/scaling of\n",
    "# components, so element-wise MSE may be a poor agreement measure. In the saved run this\n",
    "# was ~0.54, while an independent prior draw scored ~0.010 against Gamma.\n",
    "mse(y_true=Gamma, y_pred=gamma_hat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.010239626294968825, 0.010164929830776293)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: independent draws from the same Gamma prior used in the simulation.\n",
    "# Sizes come from the true arrays instead of resetting a hard-coded M = 1000.\n",
    "randomZ = gamma.rvs(shape, scale=scale, size=Z.shape)\n",
    "randomGamma = gamma.rvs(shape, scale=scale, size=Gamma.shape)\n",
    "mse(Z, randomZ), mse(Gamma, randomGamma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [],
   "source": [
    "def exp_normalize(x, axis=None):\n",
    "    \"\"\"Numerically stable softmax: exp(x) / sum(exp(x)).\n",
    "\n",
    "    Subtracting the maximum before exponentiating cancels out in the ratio but\n",
    "    keeps exp() from overflowing on large inputs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : array-like\n",
    "        Unnormalized log-weights.\n",
    "    axis : int or None, default None\n",
    "        Axis to normalize over. None normalizes over all entries; e.g.\n",
    "        ``axis=-1`` normalizes each row of a 2-D array independently.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Array with the shape of ``x`` whose entries sum to 1 along ``axis``.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x)\n",
    "    # keepdims=True so both reductions broadcast back against x for any axis.\n",
    "    b = x.max(axis=axis, keepdims=True)\n",
    "    y = np.exp(x - b)\n",
    "    return y / y.sum(axis=axis, keepdims=True)\n",
    "def normalize(x):\n",
    "    \"\"\"Naive softmax exp(x) / sum(exp(x)) without the max shift.\n",
    "\n",
    "    Matches exp_normalize for moderate inputs, but np.exp overflows to inf for\n",
    "    large ones (e.g. x = [1000., 1001.]), which makes the result nan.\n",
    "    \"\"\"\n",
    "    weights = np.exp(x)\n",
    "    total = weights.sum()\n",
    "    return weights / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([1.23353201e-04, 9.99541338e-01, 3.35308764e-04]),\n",
       " array([1.23353201e-04, 9.99541338e-01, 3.35308764e-04]))"
      ]
     },
     "execution_count": 158,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# On small inputs the stable and naive softmax agree.\n",
    "x = np.array([1., 10., 2.])\n",
    "stable, naive = exp_normalize(x), normalize(x)\n",
    "stable, naive"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
