{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import poisson, gamma\n",
    "import spf_vi as spf\n",
    "import fixed_pref_influence_model as im\n",
    "import adjusted_model as am\n",
    "import confounder_model as cm\n",
    "import poisson_influence_factorization as pif\n",
    "import utils as ut\n",
    "from importlib import reload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_set_overlap(Beta_p, Beta, k=50):\n",
    "    \"\"\"Jaccard overlap of the top-k index sets of two score vectors.\n",
    "\n",
    "    Compares an estimated score vector with a reference one by the indices of\n",
    "    their k largest entries: 1.0 means identical top-k sets, 0.0 disjoint ones.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    Beta_p : array-like, shape (n,)\n",
    "        Estimated scores.\n",
    "    Beta : array-like, shape (n,)\n",
    "        Reference (true) scores.\n",
    "    k : int, default 50\n",
    "        Size of each top set; values larger than n select every index.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Size of the intersection of the two top-k index sets divided by the\n",
    "        size of their union.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If k < 1, if the two vectors differ in shape, or if they are empty.\n",
    "    \"\"\"\n",
    "    Beta_p = np.asarray(Beta_p)\n",
    "    Beta = np.asarray(Beta)\n",
    "    # Without this guard, k == 0 makes argsort(...)[-0:] return every index,\n",
    "    # silently reporting a perfect overlap.\n",
    "    if k < 1:\n",
    "        raise ValueError(\"k must be >= 1, got %r\" % (k,))\n",
    "    if Beta_p.shape != Beta.shape:\n",
    "        raise ValueError(\"shape mismatch: %r vs %r\" % (Beta_p.shape, Beta.shape))\n",
    "    if Beta.size == 0:\n",
    "        raise ValueError(\"score vectors are empty\")\n",
    "\n",
    "    top = np.argsort(Beta)[-k:]\n",
    "    top_p = np.argsort(Beta_p)[-k:]\n",
    "    return np.intersect1d(top, top_p).shape[0] / np.union1d(top, top_p).shape[0]\n",
    "\n",
    "def report_sparsity(A, Y, Z, Gamma, Beta, Y_past):\n",
    "    \"\"\"Print the mean of each simulated array and the maximum of Beta.\n",
    "\n",
    "    Despite the 'Sparsity' labels, each line reports the array mean. For a 0/1\n",
    "    matrix such as A that is the fraction of nonzero entries; for count or\n",
    "    real-valued arrays it is simply the average value.\n",
    "    \"\"\"\n",
    "    print(\"Sparsity in A:\", A.mean())\n",
    "    print(\"Sparsity in Obs:\", Y.mean())\n",
    "    print(\"Sparsity in Past Obs:\", Y_past.mean())\n",
    "    print(\"Sparsity in Z:\", Z.mean())\n",
    "    print(\"Sparsity in Gamma:\", Gamma.mean())\n",
    "    print(\"Sparsity in Beta:\", Beta.mean())\n",
    "    print(\"Max in Beta:\", Beta.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sparsity in A: 0.0242\n",
      "Sparsity in Obs: 0.21713\n",
      "Sparsity in Past Obs: 0.05003\n",
      "Sparsity in Z: 0.05036480383572492\n",
      "Sparsity in Gamma: 0.04898457145150887\n",
      "Sparsity in Beta: 1.350858147186832\n",
      "Max in Beta: 25.21054241541104\n"
     ]
    }
   ],
   "source": [
    "# Simulated data. Every draw uses one seeded RandomState so that Restart & Run All\n",
    "# reproduces the same network, observations and true influence values.\n",
    "SEED = 0\n",
    "rng = np.random.RandomState(SEED)\n",
    "\n",
    "N = 100   # rows of Z; the network A is N x N\n",
    "M = 1000  # rows of Gamma, i.e. columns of the N x M observations Y_past and Y\n",
    "K = 20    # number of latent factors\n",
    "\n",
    "# Gamma(shape, scale) has mean shape * scale, so the factor entries have mean 0.05.\n",
    "mean = 0.05\n",
    "scale = 1./10.\n",
    "shape = mean/scale\n",
    "\n",
    "Z = gamma.rvs(shape, scale=scale, size=(N, K), random_state=rng)\n",
    "\n",
    "# Network: Poisson edge counts with rate Z Z^T, binarized and restricted to the\n",
    "# strictly upper triangle (no self-loops), stored as float64.\n",
    "edge_counts = poisson.rvs(Z.dot(Z.T), random_state=rng)\n",
    "A = np.triu(np.minimum(edge_counts, 1), k=1).astype(np.float64)\n",
    "\n",
    "Gamma = gamma.rvs(shape, scale=scale, size=(M, K), random_state=rng)\n",
    "\n",
    "# Factor-driven Poisson rate, shared by the past and the current observations.\n",
    "P = Z.dot(Gamma.T)\n",
    "Y_past = poisson.rvs(P, random_state=rng)\n",
    "\n",
    "# True per-node influence. Gamma(0.1, scale=10) has mean 1.0 but, with shape < 1,\n",
    "# most draws are close to 0 and a few are large.\n",
    "infshape = 0.1\n",
    "infscale = 10.\n",
    "Beta = gamma.rvs(infshape, scale=infscale, size=N, random_state=rng)\n",
    "\n",
    "# Beta * A broadcasts Beta along columns, so row i's influence term is\n",
    "# sum_j A[i, j] * Beta[j] * Y_past[j, :]. Named Beta_A (not M) so that M keeps\n",
    "# holding the column count defined above.\n",
    "Beta_A = Beta * A\n",
    "I = Beta_A.dot(Y_past)\n",
    "rate = I + P\n",
    "Y = poisson.rvs(rate, random_state=rng)\n",
    "\n",
    "report_sparsity(A, Y, Z, Gamma, Beta, Y_past)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tAfter ITERATION: 15\tObjective: -35847.69\tOld objective: -35862.61\tImprovement: 0.00042\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/dhanyasridhar/Documents/social-fake-news/src/confounder_model.py:176: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  improvement = (bound - old_bd) / abs(old_bd)\n"
     ]
    }
   ],
   "source": [
    "# Fit the joint Poisson factorization to the past observations and the network,\n",
    "# and keep its fitted factors (Et, Eb transposed) as the estimates Z_hat, Gamma_hat.\n",
    "# NOTE(review): the RuntimeWarning is a NaN in improvement = (bound - old_bd) / abs(old_bd)\n",
    "# inside confounder_model.py, e.g. from 0/0 or inf - inf; check the initial old_bd there.\n",
    "confounder_model = cm.JointPoissonMF(n_components=K, verbose=True)\n",
    "confounder_model.fit(Y_past, A)\n",
    "Z_hat, Gamma_hat = confounder_model.Et, confounder_model.Eb.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num. nonzero obs: 7504\n",
      "Computing element-wise product of A, Y past...\n",
      "Computing nonzero elements of AY...\n",
      "Bound: inf\n",
      "Bound: -70089.56250955052\n",
      "Bound: -39503.98981841449\n",
      "Bound: -33538.45262583048\n",
      "Bound: -26898.15431444227\n",
      "Bound: -26710.966251629976\n",
      "Bound: -26700.949263415725\n",
      "Bound: -26697.652846222187\n",
      "Bound: -26696.700979462705\n",
      "Bound: -26696.483743693552\n",
      "Bound: -26696.494606118747\n",
      "Bound: -26696.585775499134\n",
      "Bound: -26696.716726170573\n",
      "Bound: -26696.845958597558\n",
      "Bound: -26696.946206665758\n",
      "Bound: -26697.05856535593\n",
      "Bound: -26697.190682093777\n",
      "Bound: -26697.317702156848\n",
      "Bound: -26697.381821306473\n",
      "Bound: -26697.416066544585\n",
      "Bound: -26697.45302211293\n",
      "Bound: -26697.495145498113\n",
      "Bound: -26697.544530641324\n",
      "Bound: -26697.604045206932\n",
      "Bound: -26697.67745088044\n",
      "Bound: -26697.76855137118\n",
      "Bound: -26697.875677173295\n",
      "Bound: -26697.974054519433\n",
      "Bound: -26698.03379626608\n",
      "Bound: -26698.09397612748\n",
      "Bound: -26698.17210944273\n",
      "Bound: -26698.261744109233\n",
      "Bound: -26698.32028992983\n",
      "Bound: -26698.327297877302\n"
     ]
    }
   ],
   "source": [
    "# Influence model fit given both confounder estimates (Z_hat and Gamma_hat).\n",
    "# Distinct name so the PIF model fitted later does not overwrite this one.\n",
    "# NOTE(review): the logged bound peaks (about -26696.48) and then decreases; if this\n",
    "# is coordinate-ascent VI the bound should never decrease - check the updates or\n",
    "# the bound computation in fixed_pref_influence_model.\n",
    "fixed_pref_model = im.PoissonInfluenceModel(n_components=K, verbose=True)\n",
    "fixed_pref_model.fit(Y, A, Z_hat, Gamma_hat, Y_past)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6949152542372882"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Top-50 overlap (get_set_overlap's default k) between estimated and true influence.\n",
    "bhat_fixed = fixed_pref_model.E_beta\n",
    "get_set_overlap(bhat_fixed, Beta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num. nonzero obs: 7504\n",
      "Computing element-wise product of A, Y past...\n",
      "Computing nonzero elements of AY...\n",
      "Bound: -97979.58361404922\n",
      "Bound: -58770.66295552756\n",
      "Bound: -40401.77439273102\n",
      "Bound: -28930.458336360876\n",
      "Bound: -25410.84762622217\n",
      "Bound: -24947.87119307924\n",
      "Bound: -24925.98176779184\n",
      "Bound: -24922.21206655182\n",
      "Bound: -24921.11876727831\n",
      "Bound: -24920.71153248907\n",
      "Bound: -24920.55823811791\n",
      "Bound: -24920.50554911856\n",
      "Bound: -24920.46893603228\n",
      "Bound: -24920.44077606642\n",
      "Bound: -24920.428211686925\n",
      "Bound: -24920.422759107834\n",
      "Bound: -24920.405525801845\n",
      "Bound: -24920.385720056718\n",
      "Bound: -24920.366608853823\n",
      "Bound: -24920.352109185475\n",
      "Bound: -24920.350145625962\n",
      "Bound: -24920.369274951518\n",
      "Bound: -24920.394413718033\n",
      "Bound: -24920.40844332546\n",
      "Bound: -24920.41445607328\n",
      "Bound: -24920.417893778857\n",
      "Bound: -24920.42010009695\n",
      "Bound: -24920.42162337006\n",
      "Bound: -24920.422723660853\n",
      "Bound: -24920.42354023192\n",
      "Bound: -24920.424156510435\n",
      "Bound: -24920.424626858596\n",
      "Bound: -24920.424988749073\n",
      "Bound: -24920.425268962088\n",
      "Bound: -24920.425487088163\n",
      "Bound: -24920.425657683736\n",
      "Bound: -24920.42579168378\n",
      "Bound: -24920.425897369623\n",
      "Bound: -24920.42598105273\n",
      "Bound: -24920.426047568377\n",
      "Bound: -24920.426100637633\n",
      "Bound: -24920.4261431356\n",
      "Bound: -24920.42617729217\n",
      "Bound: -24920.4262048429\n",
      "Bound: -24920.426227143413\n",
      "Bound: -24920.426245256123\n",
      "Bound: -24920.42626001654\n",
      "Bound: -24920.426272083954\n",
      "Bound: -24920.426281980428\n",
      "Bound: -24920.42629012077\n",
      "Bound: -24920.426296835703\n",
      "Bound: -24920.426302389853\n",
      "Bound: -24920.42630699566\n",
      "Bound: -24920.4263108243\n",
      "Bound: -24920.426314014116\n",
      "Bound: -24920.42631667734\n",
      "Bound: -24920.42631890529\n",
      "Bound: -24920.42632077253\n",
      "Bound: -24920.426322340118\n",
      "Bound: -24920.426323658186\n",
      "Bound: -24920.426324768065\n",
      "Bound: -24920.426325703866\n",
      "Bound: -24920.426326493845\n",
      "Bound: -24920.426327161455\n",
      "Bound: -24920.42632772623\n",
      "Bound: -24920.42632820444\n",
      "Bound: -24920.426328609712\n",
      "Bound: -24920.4263289534\n",
      "Bound: -24920.42632924509\n",
      "Bound: -24920.426329492777\n",
      "Bound: -24920.426329703238\n",
      "Bound: -24920.426329882157\n",
      "Bound: -24920.42633003432\n",
      "Bound: -24920.42633016378\n",
      "Bound: -24920.426330273982\n",
      "Bound: -24920.42633036781\n",
      "Bound: -24920.42633044772\n",
      "Bound: -24920.42633051581\n",
      "Bound: -24920.426330573824\n",
      "Bound: -24920.42633062327\n",
      "Bound: -24920.426330665432\n",
      "Bound: -24920.426330701375\n",
      "Bound: -24920.426330732025\n",
      "Bound: -24920.42633075817\n",
      "Bound: -24920.426330780476\n",
      "Bound: -24920.426330799502\n",
      "Bound: -24920.426330815735\n",
      "Bound: -24920.426330829592\n",
      "Bound: -24920.426330841405\n",
      "Bound: -24920.4263308515\n",
      "Bound: -24920.42633086011\n",
      "Bound: -24920.42633086746\n",
      "Bound: -24920.426330873728\n",
      "Bound: -24920.42633087909\n",
      "Bound: -24920.426330883656\n",
      "Bound: -24920.426330887563\n",
      "Bound: -24920.426330890896\n",
      "Bound: -24920.426330893744\n",
      "Bound: -24920.426330896167\n",
      "Bound: -24920.426330898248\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.6949152542372882"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Influence model fit given Z_hat only (no Gamma_hat). Distinct names keep the\n",
    "# earlier fixed-preference model and its estimates available for comparison.\n",
    "# NOTE(review): the logged bound rises to about -24920.350 and then slowly decreases;\n",
    "# if this is coordinate-ascent VI the bound should never decrease - check the model code.\n",
    "pif_model = pif.PoissonInfluenceModel(n_components=K, verbose=True)\n",
    "pif_model.fit(Y, A, Z_hat, Y_past)\n",
    "bhat_pif = pif_model.E_beta\n",
    "get_set_overlap(bhat_pif, Beta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
