{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import re\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEP_TOKEN = \" <S>\"\n",
    "\n",
    "def camel_case_split(term):\n",
    "    matches = re.finditer('.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)', term)\n",
    "    return [m.group(0) for m in matches]\n",
    "\n",
    "def attached_number_split(term):    \n",
    "    matches = re.finditer(r\"([a-zA-Z]+)|([0-9]+)\", term)\n",
    "    return [m.group(0) for m in matches]\n",
    "\n",
    "def cyc_term_split(term):\n",
    "    # splits the string in camelCase, snake-case and underlines\n",
    "    ret = term.split(\"-\")\n",
    "    ret = [y for x in ret for y in x.split(\"_\")]\n",
    "    ret = [y for x in ret for y in camel_case_split(x)]\n",
    "    ret = [y for x in ret for y in attached_number_split(x)]\n",
    "    return ret\n",
    "\n",
    "def fact_to_datapoint(fact):    \n",
    "    text = \"\"\n",
    "    for term in fact:  \n",
    "        if text != \"\":\n",
    "            text += SEP_TOKEN\n",
    "        for word in cyc_term_split(term):\n",
    "            if text != \"\":\n",
    "                text += \" \"\n",
    "            text += word.lower()\n",
    "    text += \".\"\n",
    "    return text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def process_fact_file(filename = 'usable_facts.txt'):\n",
    "    with open(filename) as file:\n",
    "        usable_facts = [line.rstrip() for line in file]    \n",
    "        simple_facts = [fact[1:-1].split(\" \") for fact in usable_facts \n",
    "                        if fact[0] == \"(\" and fact[-1] == \")\" and not \"(\" in fact[1:-1]]\n",
    "        unique_terms = list(set([el for fact in simple_facts for el in fact if len(el) > 0]))\n",
    "        unique_words = list(set([word.lower() for term in unique_terms for word in cyc_term_split(term)]))\n",
    "        print(\"number of usable facts: \", len(usable_facts))\n",
    "        print(usable_facts[:10])\n",
    "        print(\"number of simple facts: \", len(simple_facts))\n",
    "        print(simple_facts[:10])\n",
    "        print(\"number of unique elements: \", len(unique_terms))\n",
    "        print(unique_terms[:10])\n",
    "        print(\"number of unique words: \", len(unique_words))\n",
    "        print(unique_words[:10])\n",
    "        print(\"\")\n",
    "        print(\"datapoints: \", [fact_to_datapoint(fact) for fact in simple_facts[:20]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# save datapoints to files\n",
    "DATASET_SIZE = len(simple_facts)\n",
    "SPLIT_LEN = int(0.9*DATASET_SIZE)\n",
    "idxs = list(range(DATASET_SIZE))\n",
    "random.shuffle(idxs)\n",
    "train_idxs = idxs[:SPLIT_LEN]\n",
    "test_idxs = idxs[SPLIT_LEN:]\n",
    "\n",
    "with open('fact_dataset_train.txt', 'w') as file:\n",
    "    for i, fact in enumerate(simple_facts):\n",
    "        if i in train_idxs:\n",
    "            datapoint = fact_to_datapoint(fact)\n",
    "            file.write(datapoint)\n",
    "            file.write(\"\\n\")\n",
    "            \n",
    "with open('fact_dataset_test.txt', 'w') as file:\n",
    "    for i, fact in enumerate(simple_facts):\n",
    "        if i in test_idxs:\n",
    "            datapoint = fact_to_datapoint(fact)\n",
    "            file.write(datapoint)\n",
    "            file.write(\"\\n\")            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup GPT-2 model\n",
    "import torch\n",
    "\n",
    "from transformers import LineByLineTextDataset, TextDataset, DataCollatorForLanguageModeling\n",
    "from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel\n",
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token = \"<PAD>\")\n",
    "model = GPT2LMHeadModel.from_pretrained('gpt2').cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare dataset\n",
    "train_dataset = TextDataset(\n",
    "    tokenizer=tokenizer,\n",
    "    file_path=\"./fact_dataset_train.txt\",\n",
    "    # TODO: maybe this can be much less?\n",
    "    block_size=64,\n",
    ")\n",
    "\n",
    "eval_dataset = TextDataset(\n",
    "    tokenizer=tokenizer,\n",
    "    file_path=\"./fact_dataset_test.txt\",\n",
    "    # TODO: maybe this can be much less?\n",
    "    block_size=64,\n",
    ")\n",
    "\n",
    "data_collator = DataCollatorForLanguageModeling(\n",
    "    tokenizer=tokenizer, mlm=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(train_dataset))\n",
    "print(train_dataset[0].size())\n",
    "print(train_dataset[1].size())\n",
    "print(tokenizer.pad_token_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup trainer\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./fact_lm\",\n",
    "    overwrite_output_dir=True,\n",
    "    num_train_epochs=1,\n",
    "    per_gpu_train_batch_size=16,\n",
    "    save_steps=5000,\n",
    "    save_total_limit=2,\n",
    "    no_cuda=True\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    prediction_loss_only=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on fact dataset\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model(\"./fact_lm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "next_word = pipeline(\n",
    "    \"text-generation\",\n",
    "    model=\"./fact_lm\",\n",
    "    tokenizer=tokenizer\n",
    ")\n",
    "\n",
    "next_word(\"contrary feelings <S> love <S> \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
