{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import jax.random as jr\n",
    "import optax\n",
    "\n",
    "# import matplotlib\n",
    "# matplotlib.use('TkAgg') # need this on my machine for some reason\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "\n",
    "from nmrnn.data_generation import sample_all\n",
    "from nmrnn.util import random_nmrnn_params, log_wandb_model\n",
    "from nmrnn.fitting import fit_mwg_nm_rnn, fit_mwg_nm_only\n",
    "from nmrnn.rnn_code import batched_nm_rnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load the trained NM-RNN model artifact from Weights & Biases.\n",
    "ARTIFACT_NAME = 'nm-rnn/nm-rnn-mwg/nmrnn_r3_n100_m5:v1'  # entity/project/name:version\n",
    "\n",
    "# Start a run in the same entity/project that owns the artifact, so the\n",
    "# artifact usage (and anything logged later in this notebook) is recorded\n",
    "# there instead of in an unrelated project's run.\n",
    "run = wandb.init(entity='nm-rnn', project='nm-rnn-mwg', job_type='analysis')\n",
    "\n",
    "artifact = run.use_artifact(ARTIFACT_NAME, type='model')\n",
    "artifact_dir = artifact.download()  # local path of the downloaded artifact contents"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}