{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from runner import Config, create_weather_data\n",
    "import runner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example of running models on weather data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the run configuration and load the weather dataset.\n",
    "# experiment_case options noted by the notebook author:\n",
    "#   'rand_in'  -> case 2\n",
    "#   'rand_out' -> case 3\n",
    "base_config = Config(dict(experiment_case='rand_in', debug=False))\n",
    "data, meta = create_weather_data(\n",
    "    base_config,\n",
    "    './data/jena_climate_2009_2016.csv',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for the weather experiment.\n",
    "# Only `epochs` is passed to the weather model.fit call in this notebook;\n",
    "# the remaining names are kept available in the kernel namespace.\n",
    "N_rep = 5\n",
    "epochs = 10\n",
    "batch_size = 1024\n",
    "pred_steps = 1\n",
    "\n",
    "n_filters = 32\n",
    "filter_width = 2\n",
    "# Powers of two: 1, 2, 4, ..., 128\n",
    "dilation_rates = [1 << i for i in range(8)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick one of the model builders exposed by the `runner` module.\n",
    "embedding_dim = meta['embedding_dim']\n",
    "model = runner.attention_deep_time(embedding_dim, 1)\n",
    "\n",
    "# data[0] is the fit input and data[1] is supplied as validation data.\n",
    "history = model.fit(\n",
    "    data[0],\n",
    "    validation_data=data[1],\n",
    "    epochs=epochs,\n",
    "    verbose=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example of running models on Wikipedia data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the run configuration and load the Wikipedia traffic dataset.\n",
    "# experiment_case options noted by the notebook author:\n",
    "#   'rand_in'  -> case 2\n",
    "#   'rand_out' -> case 3\n",
    "base_config = Config(dict(experiment_case='rand_in', debug=False))\n",
    "train, val = runner.create_wiki_data(\n",
    "    base_config,\n",
    "    './data/wiki_traffic.csv',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for the Wikipedia experiment.\n",
    "# `batch_size` and `epochs` are passed to model.fit further down;\n",
    "# the remaining names are kept available in the kernel namespace.\n",
    "N_rep = 5\n",
    "epochs = 10\n",
    "batch_size = 1024\n",
    "pred_steps = 1\n",
    "\n",
    "n_filters = 32\n",
    "filter_width = 2\n",
    "# Powers of two: 1, 2, 4, ..., 128\n",
    "dilation_rates = [1 << i for i in range(8)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size the model from the last axis of the first training array.\n",
    "input_dim = train[0].shape[-1]\n",
    "model = runner.attention_deep_time(input_dim, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit on train[0] / train[1], holding out 20% via validation_split.\n",
    "# NOTE(review): `val` returned by create_wiki_data is not used here - confirm\n",
    "# that validating on a split of `train` is intended.\n",
    "fit_kwargs = dict(\n",
    "    batch_size=batch_size,\n",
    "    epochs=epochs,\n",
    "    validation_split=0.2,\n",
    "    verbose=1,\n",
    ")\n",
    "history = model.fit(train[0], train[1], **fit_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the Wikipedia traffic dataset for the 'rand_out' experiment case.\n",
    "# experiment_case options noted by the notebook author:\n",
    "#   'rand_in'  -> case 2\n",
    "#   'rand_out' -> case 3\n",
    "# NOTE(review): debug is True here but False in the other data-loading\n",
    "# cells - confirm this is intended.\n",
    "base_config = Config(dict(experiment_case='rand_out', debug=True))\n",
    "train, val = runner.create_wiki_data(\n",
    "    base_config,\n",
    "    './data/wiki_traffic.csv',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the model, sized from the last axis of the first training array.\n",
    "input_dim = train[0].shape[-1]\n",
    "model = runner.attention_deep_time(input_dim, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit on train[0] / train[1], holding out 20% via validation_split.\n",
    "# NOTE(review): `val` returned by create_wiki_data is not used here - confirm\n",
    "# that validating on a split of `train` is intended.\n",
    "fit_kwargs = dict(\n",
    "    batch_size=batch_size,\n",
    "    epochs=epochs,\n",
    "    validation_split=0.2,\n",
    "    verbose=1,\n",
    ")\n",
    "history = model.fit(train[0], train[1], **fit_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
