{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Welcome!\n",
    "\n",
    "This notebook also lets you play around with the extraction method described in our ICML 2018 paper, [Extracting Automata from Recurrent Neural Networks Using Queries and Counterexamples](https://arxiv.org/abs/1711.09576), but without the accompanying documentation: it contains just the code to train an RNN on a target language and extract a DFA from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# RNN model classes: LSTMNetwork is used below; GRUNetwork can presumably be passed as RNNClass instead\n",
    "from LSTM import LSTMNetwork\n",
    "from GRU import GRUNetwork\n",
    "from RNNClassifier import RNNClassifier\n",
    "from Training_Functions import make_train_set_for_target, mixed_curriculum_train\n",
    "from Tomita_Grammars import tomita_1, tomita_2, tomita_3, tomita_4, tomita_5, tomita_6, tomita_7\n",
    "from Extraction import extract"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target = tomita_3\n",
    "alphabet = \"01\"\n",
    "\n",
    "# alternative option (example):\n",
    "# def target(w):\n",
    "#     if len(w)==0:\n",
    "#         return True\n",
    "#     return w[0]==w[-1]\n",
    "# alphabet = \"abc\"\n",
    "\n",
    "train_set = make_train_set_for_target(target,alphabet)\n",
    "\n",
    "# the network that is trained, extracted from and scored in the cells below\n",
    "rnn = RNNClassifier(alphabet, num_layers=1, hidden_dim=10, RNNClass=LSTMNetwork)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the train set can be large, so show its size and a few entries rather than all of it\n",
    "print(\"train set size:\", len(train_set))\n",
    "print(dict(list(train_set.items())[:10]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# train the rnn on train_set; stop_threshold is presumably the loss level at which training stops (see Training_Functions)\n",
    "mixed_curriculum_train(rnn, train_set, stop_threshold=0.0005)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a shortest word of each label according to the train set (None when no word has that label)\n",
    "all_words = sorted(train_set, key=len)\n",
    "pos = next(filter(lambda w: train_set[w] == True, all_words), None)\n",
    "neg = next(filter(lambda w: train_set[w] == False, all_words), None)\n",
    "starting_examples = [w for w in (pos, neg) if w is not None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['', '10']\n"
     ]
    }
   ],
   "source": [
    "# examples chosen from the train set's own labels (a later cell recomputes them from the rnn)\n",
    "starting_examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# recompute the starting examples from the rnn's own classifications, shortest words first\n",
    "all_words = sorted(train_set, key=len)\n",
    "pos = next(filter(lambda w: rnn.classify_word(w) == True, all_words), None)\n",
    "neg = next(filter(lambda w: rnn.classify_word(w) == False, all_words), None)\n",
    "starting_examples = [w for w in (pos, neg) if w is not None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Only really needed if you start messing about and doing weird stuff:\n",
    "# renew() cleans the computation graph but doesn't reset the weights, so don't worry.\n",
    "rnn.renew()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# extract a DFA from the rnn; time_limit is presumably in seconds (see Extraction.extract)\n",
    "dfa = extract(\n",
    "    rnn,\n",
    "    time_limit=50,\n",
    "    initial_split_depth=10,\n",
    "    starting_examples=starting_examples,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def percent(num, digits=2):\n",
    "    \"\"\"Format a fraction (e.g. 0.25) as a percentage string truncated to `digits` decimals (e.g. \"25.0%\").\n",
    "\n",
    "    Float noise is rounded away before truncating: 100*0.29*100 == 2899.9999999999995, which\n",
    "    plain truncation would show as \"28.99%\". Truncating (rather than rounding to `digits`)\n",
    "    keeps a near-perfect score from being rounded up to 100%.\n",
    "    \"\"\"\n",
    "    scale = 10 ** digits\n",
    "    return str(int(round(100 * num * scale, 6)) / scale) + \"%\"\n",
    "\n",
    "dfa.draw_nicely(maximum=30) #max size willing to draw\n",
    "\n",
    "test_set = train_set\n",
    "print(\"testing on train set, i.e. test set is train set\")\n",
    "# we're printing stats on the train set for now, but you can define other test sets by using\n",
    "# make_train_set_for_target again\n",
    "\n",
    "n = len(test_set)\n",
    "print(\"test set size:\", n)\n",
    "assert n > 0, \"test set is empty\"\n",
    "\n",
    "# classify every word once per classifier instead of re-running target/rnn/dfa for each statistic\n",
    "target_labels = {w: target(w) for w in test_set}\n",
    "rnn_labels = {w: rnn.classify_word(w) for w in test_set}\n",
    "dfa_labels = {w: dfa.classify_word(w) for w in test_set}\n",
    "\n",
    "# named n_pos so the positive starting example stored in `pos` is not overwritten\n",
    "n_pos = sum(1 for w in test_set if target_labels[w])\n",
    "print(\"of which positive:\",n_pos,\"(\"+percent(n_pos/n)+\")\")\n",
    "rnn_target = sum(1 for w in test_set if rnn_labels[w] == target_labels[w])\n",
    "print(\"rnn score against target on test set:\",rnn_target,\"(\"+percent(rnn_target/n)+\")\")\n",
    "dfa_rnn = sum(1 for w in test_set if rnn_labels[w] == dfa_labels[w])\n",
    "print(\"extracted dfa score against rnn on test set:\",dfa_rnn,\"(\"+percent(dfa_rnn/n)+\")\")\n",
    "dfa_target = sum(1 for w in test_set if dfa_labels[w] == target_labels[w])\n",
    "print(\"extracted dfa score against target on rnn's test set:\",dfa_target,\"(\"+percent(dfa_target/n)+\")\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.11.2 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
