{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e14337a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "from allennlp.predictors.predictor import Predictor\n",
    "import allennlp_models.coref\n",
    "#print('ner')\n",
    "#import allennlp_models.ner\n",
    "#print('consti')\n",
    "#import allennlp_models.syntax.constituency_parser\n",
    "#print('srl')\n",
    "#import allennlp_models.syntax.srl\n",
    "import spacy\n",
    "from spacy import displacy\n",
    "from collections import Counter\n",
    "import en_core_web_sm\n",
    "import os\n",
    "import math\n",
    "import csv\n",
    "import itertools\n",
    "import pandas as pd\n",
    "import re\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "from joblib import Parallel, delayed\n",
    "total = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29e6e02c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nltk\n",
    "from nltk.corpus import stopwords         \n",
    "from nltk.tokenize import word_tokenize   \n",
    "from nltk.stem import PorterStemmer       \n",
    "from nltk.stem import WordNetLemmatizer  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b5e5065",
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_words = set(stopwords.words('english'))  \n",
    "text = \"The 2011201312 VCU Rams men's basketball team, led by third year head coach Shaka Smart, represented Virginia Commonwealth University which was founded in what year?\"  #文本\n",
    "word_tokens = word_tokenize(text)   \n",
    "\n",
    "print(word_tokens)\n",
    "\n",
    "if \"had\" in word_tokens:\n",
    "    print(\"yes\")\n",
    "else:\n",
    "    print(\"no\")\n",
    "#filtered_sentence = list(set(word_tokens)-stop_words) \n",
    "filtered_sentence = [w for w in word_tokens if w not in stop_words]\n",
    "print(\" \".join(word_tokens)) \n",
    "print(\" \".join(filtered_sentence)) \n",
    "\n",
    "stem_words = []\n",
    "ps =PorterStemmer()\n",
    "for w in filtered_sentence: \n",
    "\trootWord=ps.stem(w) \n",
    "\tstem_words.append(rootWord)\n",
    "print(filtered_sentence)\n",
    "print(stem_words)\n",
    "\n",
    "lemma_word = []\n",
    "wordnet_lemmatizer = WordNetLemmatizer()\n",
    "for w in filtered_sentence: \n",
    "\tword1 = wordnet_lemmatizer.lemmatize(w, pos = \"n\") \n",
    "\tword2 = wordnet_lemmatizer.lemmatize(word1, pos = \"v\") \n",
    "\tword3 = wordnet_lemmatizer.lemmatize(word2, pos = (\"a\")) \n",
    "\t#pos参数 是词性\n",
    "\tlemma_word.append(word3)\n",
    "print(lemma_word)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "441a00f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def nltk_wordlist(sent):\n",
    "    word_list = word_tokenize(sent) \n",
    "    return word_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a242a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_dict(str):\n",
    "    dic={\n",
    "        }\n",
    "    key,v = \"\",\"\"\n",
    "    s = 0\n",
    "    for x in str:\n",
    "        if x == '[':\n",
    "            if s == 0:\n",
    "                s = 1\n",
    "        elif x == ':':\n",
    "            if s == 1:\n",
    "                s = 2\n",
    "                v = \"\"\n",
    "        elif x== ']':\n",
    "            if s == 2:\n",
    "                dic[key] = v.strip()\n",
    "                key,v = \"\",\"\"\n",
    "                s = 0\n",
    "        else:\n",
    "            if s == 1:\n",
    "                key += x\n",
    "            elif s == 2:\n",
    "                v += x\n",
    "            else:\n",
    "                s = 0\n",
    "    return dic\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91466f6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor_srl = Predictor.from_path('bert-base-srl-2020.03.24.tar.gz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7181fb17",
   "metadata": {},
   "outputs": [],
   "source": [
    "def verb_num_be_dict(num,sent):\n",
    "    \n",
    "    #print(sent)\n",
    "    dict_be={}\n",
    "    special_word= ['is','are','were','have','has','was','Is','Are','Was','Were','Have','Has','Had']\n",
    "    if num == 0:\n",
    "        dict_sent = srl_sent_is_to_dict(sent)\n",
    "        if dict_sent == None:\n",
    "            return None\n",
    "        return dict_sent\n",
    "    if num == 1 and \"born\" in nltk_wordlist(sent) :\n",
    "        sent_trunk = pro_srl_sent(sent)\n",
    "        dict_sent = srl_sent_is_to_dict(sent_trunk)\n",
    "        return dict_sent\n",
    "    \n",
    "    if num > 0 :\n",
    "        for sp_word in special_word:\n",
    "              if sp_word in sent[0:60]:\n",
    "                    left_label = sent.find(sp_word)\n",
    "                    right_label = left_label + len(sp_word)\n",
    "                    \n",
    "                    first_verb_lefr_label = sent.find(first_verb(predictor_srl,sent))\n",
    "                    if first_verb_lefr_label > left_label and first_verb_lefr_label - left_label > 5 :\n",
    "                        dict_be['ARG0'] = sent[0:left_label]\n",
    "                        dict_be['V'] = sp_word\n",
    "                        dict_be['ARG1'] = sent[right_label:first_verb_lefr_label]\n",
    "                        return dict_be"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b861f94e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def verb_num(predictor_srl,sent):           \n",
    "    output = predictor_srl.predict(sent)\n",
    "    v_num = len(output['verbs'])\n",
    "    \n",
    "    return v_num\n",
    "    \n",
    "def first_verb(predictor_srl,sent):\n",
    "    output = predictor_srl.predict(sent)\n",
    "    first_verb = output['verbs'][0]['verb']\n",
    "        \n",
    "    return first_verb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b39ffd2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def srl_parse_sent(predictor_srl,sent):\n",
    "    sent = sent.replace('] ',' ').replace('[ ',' ').replace('; ',' ').replace(': ',' ')  \n",
    "    sent = pro_srl_sent(sent) \n",
    "\n",
    "    output = predictor_srl.predict(sent)\n",
    "    #print(output)\n",
    "    srl_v_dic = []\n",
    "    \n",
    "    speci = verb_num_be_dict(len(output['verbs']),sent)\n",
    "    if speci != {} and speci != None:\n",
    "    #if  speci != None:\n",
    "        srl_v_dic.append(speci)\n",
    "\n",
    "    for v_dic in output[\"verbs\"]:\n",
    "        srl_v_dic.append( to_dict(v_dic[\"description\"]) )\n",
    "    return srl_v_dic\n",
    "\n",
    "def refine_data( case ):\n",
    "    sup_title = set([x[0] for x in case[\"supporting_facts\"] ])\n",
    "    gold_context = []\n",
    "    \n",
    "    for para in case[\"context\"]:\n",
    "        if para[0] in sup_title:\n",
    "            sub_context = []\n",
    "            for sent_idx,sent in enumerate(para[1]):\n",
    "                sub_context.append([ sent , srl_parse_sent(predictor_srl,sent) , [para[0],sent_idx] in case[\"supporting_facts\"] ])\n",
    "            gold_context.append(sub_context)\n",
    "    #total += 1\n",
    "    #print(\"finish {}\".format(total))\n",
    "    return dict([\n",
    "                    (\"_id\",case[\"_id\"]),\n",
    "                    (\"answer\",case['answer']),\n",
    "                    (\"question\", case['question']),\n",
    "                    (\"srl_question\", srl_parse_sent(predictor_srl,case['question']) ),\n",
    "                    (\"supporting_facts\",case[\"supporting_facts\"]),\n",
    "                    (\"context\",gold_context),\n",
    "                    (\"type\",case[\"type\"]),\n",
    "                    (\"level\",case[\"level\"]),\n",
    "                    ])\n",
    "\n",
    "def dataset(file_name,predictor_srl,out_filename):\n",
    "\n",
    "    with open(file_name, \"r\", encoding='utf-8') as reader:\n",
    "            orig_data = json.load(reader)\n",
    "    #orig_data = orig_data[:3]\n",
    "    print(\"Load ok\")\n",
    "    srl_data = []\n",
    "    #srl_data = Parallel(n_jobs=8, verbose=20)(delayed(refine_data)(article) for article in orig_data)\n",
    "    for article in tqdm(orig_data):\n",
    "        srl_data.append( refine_data(article) )\n",
    "    #srl_data = [refine_data(article) for article in orig_data ]\n",
    "    print(\"write data\")\n",
    "    with open(out_filename,'w') as file_obj:\n",
    "        json.dump(srl_data,file_obj)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d26ddb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "dev_file = \"hotpot_train_v1.1.json\"\n",
    "out_file = \"srl_hotpot_train_v1.1.json\"\n",
    "dataset(dev_file,predictor_srl,out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "765a18e3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e963a320",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54854254",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pt3_py3",
   "language": "python",
   "name": "pt3_py3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
