{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import nltk\n",
    "import json\n",
    "import numpy as np\n",
    "from nltk.corpus import stopwords         #停用词\n",
    "from nltk.tokenize import word_tokenize   #分词\n",
    "from nltk.stem import PorterStemmer       #词干化\n",
    "from nltk.stem import WordNetLemmatizer   #词形还原"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import math\n",
    "import csv\n",
    "import itertools\n",
    "import pandas as pd\n",
    "import re\n",
    "import json\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#out_file = \"srl_hotpot_dev_distractor_v1_old3.json\"\n",
    "#dev_file = \"hotpot_train_v1.1.json\"\n",
    "#out_file = \"srl_hotpot_train_v1.1.json\"\n",
    "#dataset(dev_file,predictor_srl,out_file)\n",
    "# print(\"??????\")\n",
    "# with open(out_file, \"r\", encoding='utf-8') as reader:\n",
    "#     orig_data = json.load(reader)\n",
    "# print(\"!!!!!\")\n",
    "# print(len(orig_data))\n",
    "# print(orig_data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_words = set(stopwords.words('english'))  #英文停用分词集合\n",
    "\n",
    "#sent_dict_value1 = orig_data[0]['context'][0][0][1][0]['ARG1']\n",
    "#print(sent_dict_value1)\n",
    "def no_stop_word(sent_dict_value):\n",
    "    sent_dict_value = sent_dict_value.replace(',',' ').replace('.', ' ').replace('(',' ').replace(')',' ').replace('\"',' ')\n",
    "    word_tokens =word_tokenize(sent_dict_value)\n",
    "    #word_tokens.remove('.').remove(',')\n",
    "    filtered_sentence = [w for w in word_tokens if w not in stop_words]\n",
    "    \n",
    "    lemma_word_value = []\n",
    "    wordnet_lemmatizer = WordNetLemmatizer()\n",
    "    for w in filtered_sentence: \n",
    "        word1 = wordnet_lemmatizer.lemmatize(w, pos = \"n\") \n",
    "        word2 = wordnet_lemmatizer.lemmatize(word1, pos = \"v\") \n",
    "        word3 = wordnet_lemmatizer.lemmatize(word2, pos = (\"a\")) \n",
    "        #pos参数 是词性\n",
    "        lemma_word_value.append(word3)\n",
    "\n",
    "    \n",
    "    return lemma_word_value\n",
    "\n",
    "# print(no_stop_word(sent_dict_value1))\n",
    "# print(no_stop_word(''))    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#第一种core_graph_data_core_other\n",
    "#=============================================================================================================================\n",
    "sent_max_num = 5\n",
    "verb_max_num = 4\n",
    "\n",
    "\n",
    "def core_graph_data_core_other(para,core_sent_dict):\n",
    "    \n",
    "    \n",
    "    \n",
    "    label_list=[\n",
    "    'ARG0',     \n",
    "    'ARG1',     \n",
    "    'V',        \n",
    "    'ARG2',     \n",
    "    'ARG3',     \n",
    "    'ARG4',     \n",
    "    'ARG5',     \n",
    "    'ARGM-COM', \n",
    "    'ARGM-LOC', \n",
    "    'ARGM-DIR', \n",
    "    'ARGM-MNR', \n",
    "    'ARGM-TMP', \n",
    "    'ARGM-EXT', \n",
    "    'ARGM-REC', \n",
    "    'ARGM-PRD', \n",
    "    'ARGM-PRP', \n",
    "    'ARGM-CAU', \n",
    "    'ARGM-DIS', \n",
    "    'ARGM-ADV', \n",
    "    'ARGM-ADJ', \n",
    "    'ARGM-MOD', \n",
    "    'ARGM-NEG', \n",
    "    'ARGM-DSP', \n",
    "    'ARGM-LVB', \n",
    "    'ARGM-CNX'   \n",
    "    ]           \n",
    "    pos_len = len(label_list)\n",
    "    \n",
    "    #print(\"arg0:\",label_list.index('ARG0'))\n",
    "    #print(\"arg1:\",label_list.index('ARG1'))\n",
    "    #print(\"ARGM-CNX:\",label_list.index('ARGM-CNX'))\n",
    "    #core_sent_dict = sent_pos_dict = sent_srl_data[1]\n",
    "    #sent_max_num = 15\n",
    "    #verb_max_num = 5\n",
    "    train_data_example = torch.zeros(sent_max_num,pos_len * verb_max_num,dtype = torch.int)\n",
    "    \n",
    "    \n",
    "#    first_para = para[0]                                                                                       \n",
    "#     print(\"1first_len:\",len(first_para))  \n",
    "#     for s in first_para:\n",
    "#         print(s)\n",
    "#         print(\"***********************************************************************************\")\n",
    "    \n",
    "    \n",
    "   # second_para = para[1]                                                                                      \n",
    "    #first_para.extend(second_para)                                                                                          \n",
    "    #print(\"1total-len:\",len(first_para))  \n",
    "    para_sum = []\n",
    "    para_sum.extend(para[0])\n",
    "    para_sum.extend(para[1])\n",
    "    sent_id = -1                                                                                                            \n",
    "    #for sent_srl_data in first_para:                                        #单个实例的所有句子                             \n",
    "        #sent_pos_dict = sent_srl_data[1]                      #0：原句字符，1：解析后的label字典合集，2：句子label ：is_sup \n",
    "        #print(\"sent_pos_dict:\",sent_pos_dict)      \n",
    "    for sent_pos_dict in para_sum:    \n",
    "        sent_id = sent_id+1                                                                                                 \n",
    "        for min_sent_pos_dict in sent_pos_dict:                            #单个句子的所有以verb计数的字典                  \n",
    "            if min_sent_pos_dict == None :                                                                                  \n",
    "                continue                                                                                                    \n",
    "            for pos_value in min_sent_pos_dict:                            # 单个字典的所有key                              \n",
    "                #print(pos_value)                                                                                            \n",
    "                #print(min_sent_pos_dict[pos_value])                                                                         \n",
    "                sent_no_stop_list=no_stop_word(min_sent_pos_dict[pos_value])                                                \n",
    "                for sent_value in sent_no_stop_list:                       #去除停用词后的所有value                         \n",
    "                    v_turn_to_num = -1                                                                                       \n",
    "                    for min_core_sent_dict in core_sent_dict:             #核心句                                           \n",
    "                        if min_core_sent_dict == None :                                                                     \n",
    "                            continue                                                                                        \n",
    "                        v_turn_to_num = v_turn_to_num+1                                                                     \n",
    "                        for core_pos_value in min_core_sent_dict:                                                           \n",
    "                            core_sent_no_stop_list=no_stop_word(min_core_sent_dict[core_pos_value])                         \n",
    "                            if sent_value in core_sent_no_stop_list:\n",
    "                                if core_pos_value not in label_list:\n",
    "                                    continue\n",
    "                                #print(\"yes\")                                                                                \n",
    "                                #train_data_example[sent_id][v_turn_to_num*pos_len+label_list.index(pos_value)]=1            \n",
    "                                train_data_example[sent_id][v_turn_to_num*pos_len+label_list.index(core_pos_value)]=1      \n",
    "                                #train_data_example[sent_id-1][v_turn_to_num*pos_len+label_list.index(core_pos_value)]  += 1\n",
    "\n",
    "                           \n",
    "                        \n",
    "    example_data = \" \".join('%s' %id for id in train_data_example.numpy().tolist())\n",
    "    return example_data\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "#第二种core_graph_data_other\n",
    "#=============================================================================================================================\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def core_graph_data_other(para,core_sent_dict):\n",
    "    \n",
    "    label_list=[\n",
    "    'ARG0',     \n",
    "    'ARG1',     \n",
    "    'V',        \n",
    "    'ARG2',     \n",
    "    'ARG3',     \n",
    "    'ARG4',     \n",
    "    'ARG5',     \n",
    "    'ARGM-COM', \n",
    "    'ARGM-LOC', \n",
    "    'ARGM-DIR', \n",
    "    'ARGM-MNR', \n",
    "    'ARGM-TMP', \n",
    "    'ARGM-EXT', \n",
    "    'ARGM-REC', \n",
    "    'ARGM-PRD', \n",
    "    'ARGM-PRP', \n",
    "    'ARGM-CAU', \n",
    "    'ARGM-DIS', \n",
    "    'ARGM-ADV', \n",
    "    'ARGM-ADJ', \n",
    "    'ARGM-MOD', \n",
    "    'ARGM-NEG', \n",
    "    'ARGM-DSP', \n",
    "    'ARGM-LVB', \n",
    "    'ARGM-CNX'   \n",
    "    ]           \n",
    "    pos_len = len(label_list)\n",
    "    \n",
    "    #print(\"arg0:\",label_list.index('ARG0'))\n",
    "    #print(\"arg1:\",label_list.index('ARG1'))\n",
    "    #print(\"ARGM-CNX:\",label_list.index('ARGM-CNX'))\n",
    "    #core_sent_dict = sent_pos_dict = sent_srl_data[1]\n",
    "    #sent_max_num = 15\n",
    "    #verb_max_num = 5\n",
    "    train_data_example = torch.zeros(sent_max_num,pos_len * verb_max_num,dtype = torch.int)\n",
    "    \n",
    "    para_sum = []\n",
    "    para_sum.extend(para[0])\n",
    "    para_sum.extend(para[1])\n",
    "    #first_para = para[0]                                                                                       \n",
    "    #print(len(first_para))                                                                                                  \n",
    "    #second_para = para[1]                                                                                      \n",
    "    #first_para.extend(second_para)                                                                                          \n",
    "    #print(\"2total-len:\",len(first_para))                                                                                                                                                                                     \n",
    "    sent_id = -1                                                                                                            \n",
    "    #for sent_srl_data in first_para:                                        #单个实例的所有句子                             \n",
    "        #sent_pos_dict = sent_srl_data[1]                      #0：原句字符，1：解析后的label字典合集，2：句子label ：is_sup \n",
    "        #print(\"sent_pos_dict:\",sent_pos_dict)    \n",
    "    for sent_pos_dict in para_sum:           \n",
    "        sent_id = sent_id+1                                                                                                 \n",
    "        for min_sent_pos_dict in sent_pos_dict:                            #单个句子的所有以verb计数的字典                  \n",
    "            if min_sent_pos_dict == None :                                                                                  \n",
    "                continue                                                                                                    \n",
    "            for pos_value in min_sent_pos_dict:                            # 单个字典的所有key                              \n",
    "                #print(pos_value)                                                                                            \n",
    "                #print(min_sent_pos_dict[pos_value])                                                                         \n",
    "                sent_no_stop_list=no_stop_word(min_sent_pos_dict[pos_value])                                                \n",
    "                for sent_value in sent_no_stop_list:                       #去除停用词后的所有value                         \n",
    "                    v_turn_to_num = -1                                                                                       \n",
    "                    for min_core_sent_dict in core_sent_dict:             #核心句                                           \n",
    "                        if min_core_sent_dict == None :                                                                     \n",
    "                            continue                                                                                        \n",
    "                        v_turn_to_num = v_turn_to_num+1                                                                     \n",
    "                        for core_pos_value in min_core_sent_dict:                                                           \n",
    "                            core_sent_no_stop_list=no_stop_word(min_core_sent_dict[core_pos_value])                         \n",
    "                            if sent_value in core_sent_no_stop_list:\n",
    "                                if pos_value not in label_list:\n",
    "                                    continue\n",
    "                                #print(\"yes\")                                                                                \n",
    "                                train_data_example[sent_id][v_turn_to_num*pos_len+label_list.index(pos_value)]=1            \n",
    "                                #train_data_example[sent_id][v_turn_to_num*pos_len+label_list.index(core_pos_value)]=1      \n",
    "                                #train_data_example[sent_id-1][v_turn_to_num*pos_len+label_list.index(core_pos_value)]  += 1\n",
    "\n",
    "                           \n",
    "                        \n",
    "    example_data = \" \".join('%s' %id for id in train_data_example.numpy().tolist())\n",
    "    return example_data\n",
    "\n",
    "\n",
    "\n",
    "#第三种core_graph_data_qu\n",
    "#=============================================================================================================================\n",
    "\n",
    "\n",
    "def core_graph_data_qu(core_sent_dict,qu_sent_dict):\n",
    "    \n",
    "    sent_pos_dict = qu_sent_dict\n",
    "    label_list=[\n",
    "    'ARG0',     \n",
    "    'ARG1',     \n",
    "    'V',        \n",
    "    'ARG2',     \n",
    "    'ARG3',     \n",
    "    'ARG4',     \n",
    "    'ARG5',     \n",
    "    'ARGM-COM', \n",
    "    'ARGM-LOC', \n",
    "    'ARGM-DIR', \n",
    "    'ARGM-MNR', \n",
    "    'ARGM-TMP', \n",
    "    'ARGM-EXT', \n",
    "    'ARGM-REC', \n",
    "    'ARGM-PRD', \n",
    "    'ARGM-PRP', \n",
    "    'ARGM-CAU', \n",
    "    'ARGM-DIS', \n",
    "    'ARGM-ADV', \n",
    "    'ARGM-ADJ', \n",
    "    'ARGM-MOD', \n",
    "    'ARGM-NEG', \n",
    "    'ARGM-DSP', \n",
    "    'ARGM-LVB', \n",
    "    'ARGM-CNX'   \n",
    "    ]           \n",
    "    pos_len = len(label_list)\n",
    "    \n",
    "    #sent_max_num = 15\n",
    "    #verb_max_num = 5\n",
    "    \n",
    "    #print(\"arg0:\",label_list.index('ARG0'))\n",
    "    #print(\"arg1:\",label_list.index('ARG1'))\n",
    "    #print(\"ARGM-CNX:\",label_list.index('ARGM-CNX'))\n",
    "    #core_sent_dict = sent_pos_dict = sent_srl_data[1]\n",
    "    \n",
    "    train_data_example = torch.zeros(pos_len * verb_max_num,dtype = torch.int)\n",
    "    \n",
    "    \n",
    "                                                                    \n",
    "    for min_sent_pos_dict in sent_pos_dict:                            #单个句子的所有以verb计数的字典                  \n",
    "        if min_sent_pos_dict == None :                                                                                  \n",
    "            continue                                                                                                    \n",
    "        for pos_value in min_sent_pos_dict:                            # 单个字典的所有key                              \n",
    "            #print(pos_value)                                                                                            \n",
    "            #print(min_sent_pos_dict[pos_value])                                                                         \n",
    "            sent_no_stop_list=no_stop_word(min_sent_pos_dict[pos_value])                                                \n",
    "            for sent_value in sent_no_stop_list:                       #去除停用词后的所有value                         \n",
    "                v_turn_to_num = -1                                                                                       \n",
    "                for min_core_sent_dict in core_sent_dict:             #核心句                                           \n",
    "                    if min_core_sent_dict == None :                                                                     \n",
    "                        continue                                                                                        \n",
    "                    v_turn_to_num = v_turn_to_num+1                                                                     \n",
    "                    for core_pos_value in min_core_sent_dict:                                                           \n",
    "                        core_sent_no_stop_list=no_stop_word(min_core_sent_dict[core_pos_value])                         \n",
    "                        if sent_value in core_sent_no_stop_list:                                                        \n",
    "                            #print(\"yes\")                                                                                \n",
    "                            #train_data_example[0][v_turn_to_num*pos_len+label_list.index(pos_value)]=1\n",
    "                            if core_pos_value not in label_list:\n",
    "                                continue\n",
    "                            train_data_example[v_turn_to_num*pos_len+label_list.index(core_pos_value)]=1      \n",
    "                            #train_data_example[sent_id-1][v_turn_to_num*pos_len+label_list.index(core_pos_value)]  += 1\n",
    "\n",
    "                           \n",
    "                        \n",
    "    example_data = \" \".join('%s' %id for id in train_data_example.numpy().tolist())\n",
    "    return example_data\n",
    "\n",
    "\n",
    "\n",
    "#第四种qu_graph_data_core\n",
    "\n",
    "#=============================================================================================================================\n",
    "\n",
    "\n",
    "\n",
    "def qu_graph_data_core(core_sent_dict,qu_sent_dict):\n",
    "    \n",
    "    sent_pos_dict = qu_sent_dict\n",
    "    label_list=[\n",
    "    'ARG0',     \n",
    "    'ARG1',     \n",
    "    'V',        \n",
    "    'ARG2',     \n",
    "    'ARG3',     \n",
    "    'ARG4',     \n",
    "    'ARG5',     \n",
    "    'ARGM-COM', \n",
    "    'ARGM-LOC', \n",
    "    'ARGM-DIR', \n",
    "    'ARGM-MNR', \n",
    "    'ARGM-TMP', \n",
    "    'ARGM-EXT', \n",
    "    'ARGM-REC', \n",
    "    'ARGM-PRD', \n",
    "    'ARGM-PRP', \n",
    "    'ARGM-CAU', \n",
    "    'ARGM-DIS', \n",
    "    'ARGM-ADV', \n",
    "    'ARGM-ADJ', \n",
    "    'ARGM-MOD', \n",
    "    'ARGM-NEG', \n",
    "    'ARGM-DSP', \n",
    "    'ARGM-LVB', \n",
    "    'ARGM-CNX'   \n",
    "    ]           \n",
    "    pos_len = len(label_list)\n",
    "    \n",
    "    #sent_max_num = 15\n",
    "    #verb_max_num = 5\n",
    "    \n",
    "    #print(\"arg0:\",label_list.index('ARG0'))\n",
    "    #print(\"arg1:\",label_list.index('ARG1'))\n",
    "    #print(\"ARGM-CNX:\",label_list.index('ARGM-CNX'))\n",
    "    #core_sent_dict = sent_pos_dict = sent_srl_data[1]\n",
    "    \n",
    "    train_data_example = torch.zeros(pos_len * verb_max_num,dtype = torch.int)\n",
    "    \n",
    "    \n",
    "                                                                    \n",
    "    for min_sent_pos_dict in sent_pos_dict:                            #单个句子的所有以verb计数的字典                  \n",
    "        if min_sent_pos_dict == None :                                                                                  \n",
    "            continue                                                                                                    \n",
    "        for pos_value in min_sent_pos_dict:                            # 单个字典的所有key                              \n",
    "            #print(pos_value)                                                                                            \n",
    "            #print(min_sent_pos_dict[pos_value])                                                                         \n",
    "            sent_no_stop_list=no_stop_word(min_sent_pos_dict[pos_value])                                                \n",
    "            for sent_value in sent_no_stop_list:                       #去除停用词后的所有value                         \n",
    "                v_turn_to_num = -1                                                                                       \n",
    "                for min_core_sent_dict in core_sent_dict:             #核心句                                           \n",
    "                    if min_core_sent_dict == None :                                                                     \n",
    "                        continue                                                                                        \n",
    "                    v_turn_to_num = v_turn_to_num+1                                                                     \n",
    "                    for core_pos_value in min_core_sent_dict:                                                           \n",
    "                        core_sent_no_stop_list=no_stop_word(min_core_sent_dict[core_pos_value])                         \n",
    "                        if sent_value in core_sent_no_stop_list:\n",
    "                            if pos_value not in label_list:\n",
    "                                continue\n",
    "                            #print(\"yes\")                                                                                \n",
    "                            train_data_example[v_turn_to_num*pos_len+label_list.index(pos_value)]=1            \n",
    "                            #train_data_example[sent_id][v_turn_to_num*pos_len+label_list.index(core_pos_value)]=1      \n",
    "                            #train_data_example[sent_id-1][v_turn_to_num*pos_len+label_list.index(core_pos_value)]  += 1\n",
    "\n",
    "                           \n",
    "                        \n",
    "    example_data = \" \".join('%s' %id for id in train_data_example.numpy().tolist()) \n",
    "    return example_data\n",
    "    \n",
    "\n",
    "#第五种core_sent_dict_to_graph\n",
    "\n",
    "#=============================================================================================================================\n",
    "\n",
    "def core_sent_dict_to_graph(core_sent_dict):\n",
    "    \n",
    "    sent_pos_dict = core_sent_dict\n",
    "    \n",
    "    label_list=[\n",
    "    'ARG0',     \n",
    "    'ARG1',     \n",
    "    'V',        \n",
    "    'ARG2',     \n",
    "    'ARG3',     \n",
    "    'ARG4',     \n",
    "    'ARG5',     \n",
    "    'ARGM-COM', \n",
    "    'ARGM-LOC', \n",
    "    'ARGM-DIR', \n",
    "    'ARGM-MNR', \n",
    "    'ARGM-TMP', \n",
    "    'ARGM-EXT', \n",
    "    'ARGM-REC', \n",
    "    'ARGM-PRD', \n",
    "    'ARGM-PRP', \n",
    "    'ARGM-CAU', \n",
    "    'ARGM-DIS', \n",
    "    'ARGM-ADV', \n",
    "    'ARGM-ADJ', \n",
    "    'ARGM-MOD', \n",
    "    'ARGM-NEG', \n",
    "    'ARGM-DSP', \n",
    "    'ARGM-LVB', \n",
    "    'ARGM-CNX'   \n",
    "    ]           \n",
    "    \n",
    "    pos_len = len(label_list)\n",
    "    #verb_max_num = 5\n",
    "\n",
    "    #print(\"arg0:\",label_list.index('ARG0'))\n",
    "    #print(\"arg1:\",label_list.index('ARG1'))\n",
    "    #print(\"ARGM-CNX:\",label_list.index('ARGM-CNX'))\n",
    "    \n",
    "    train_data_example = torch.zeros(pos_len * verb_max_num,dtype = torch.int)\n",
    "    \n",
    "    v_turn_to_num = -1\n",
    "    for min_sent_pos_dict in sent_pos_dict:                            #单个句子的所有以verb计数的字典                  \n",
    "            if min_sent_pos_dict == None :                                                                                  \n",
    "                continue\n",
    "            v_turn_to_num = v_turn_to_num+1\n",
    "            #print(\"v_t\",)\n",
    "            for pos_value in min_sent_pos_dict:                            # 单个字典的所有key                              \n",
    "                #print(pos_value)                                                                                            \n",
    "                #print(min_sent_pos_dict[pos_value])                                                                         \n",
    "                #sent_no_stop_list=no_stop_word(min_sent_pos_dict[pos_value])\n",
    "                if pos_value not in label_list:\n",
    "                    continue\n",
    "                train_data_example[v_turn_to_num*pos_len+label_list.index(pos_value)]=1\n",
    "        \n",
    "        \n",
    "    example_data = \" \".join('%s' %id for id in train_data_example.numpy().tolist())      \n",
    "    return example_data  \n",
    "    \n",
    "\n",
    "    \n",
    "#core_dict1 = orig_data[0]['context'][0][2][1]\n",
    "#core_dict2 = [None, {'ARG0': 'He', 'V': 'lives', 'ARGM-LOC': 'in Los Angeles , California'}]\n",
    "#print(core_sent_dict_to_graph(core_dict2))\n",
    "#test_str1 = \"\".join()\n",
    "#print(test_str1)\n",
    "#print(\" \".join('%s' %id for id in core_sent_dict_to_graph(core_dict2)))\n",
    "    \n",
    "    \n",
    "#第六种qu_sent_dict_to_graph\n",
    "\n",
    "\n",
    "def qu_sent_dict_to_graph(qu_sent_dict):\n",
    "    \n",
    "    sent_pos_dict = qu_sent_dict\n",
    "    \n",
    "    label_list=[\n",
    "    'ARG0',     \n",
    "    'ARG1',     \n",
    "    'V',        \n",
    "    'ARG2',     \n",
    "    'ARG3',     \n",
    "    'ARG4',     \n",
    "    'ARG5',     \n",
    "    'ARGM-COM', \n",
    "    'ARGM-LOC', \n",
    "    'ARGM-DIR', \n",
    "    'ARGM-MNR', \n",
    "    'ARGM-TMP', \n",
    "    'ARGM-EXT', \n",
    "    'ARGM-REC', \n",
    "    'ARGM-PRD', \n",
    "    'ARGM-PRP', \n",
    "    'ARGM-CAU', \n",
    "    'ARGM-DIS', \n",
    "    'ARGM-ADV', \n",
    "    'ARGM-ADJ', \n",
    "    'ARGM-MOD', \n",
    "    'ARGM-NEG', \n",
    "    'ARGM-DSP', \n",
    "    'ARGM-LVB', \n",
    "    'ARGM-CNX'   \n",
    "    ]           \n",
    "    \n",
    "    pos_len = len(label_list)\n",
    "    #verb_max_num = 5\n",
    "\n",
    "    #print(\"arg0:\",label_list.index('ARG0'))\n",
    "    #print(\"arg1:\",label_list.index('ARG1'))\n",
    "    #print(\"ARGM-CNX:\",label_list.index('ARGM-CNX'))\n",
    "    \n",
    "    train_data_example = torch.zeros(pos_len * verb_max_num,dtype = torch.int)\n",
    "    \n",
    "    v_turn_to_num = -1\n",
    "    for min_sent_pos_dict in sent_pos_dict:                            #单个句子的所有以verb计数的字典                  \n",
    "            if min_sent_pos_dict == None :                                                                                  \n",
    "                continue\n",
    "            v_turn_to_num = v_turn_to_num+1\n",
    "            for pos_value in min_sent_pos_dict:                            # 单个字典的所有key                              \n",
    "                #print(pos_value)                                                                                            \n",
    "                #print(min_sent_pos_dict[pos_value])                                                                         \n",
    "                #sent_no_stop_list=no_stop_word(min_sent_pos_dict[pos_value]) \n",
    "                if pos_value not in label_list:\n",
    "                    continue\n",
    "                train_data_example[v_turn_to_num*pos_len+label_list.index(pos_value)]=1\n",
    "        \n",
    "        \n",
    "    example_data = \" \".join('%s' %id for id in train_data_example.numpy().tolist())     \n",
    "    return example_data  \n",
    "    \n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def srl_parse_sent(predictor_srl,sent):\n",
    "#     sent = sent.replace('] ',' ').replace('[ ',' ').replace('; ',' ').replace(': ',' ')  \n",
    "#     sent = pro_srl_sent(sent) #去除括号里的内容\n",
    "\n",
    "#     output = predictor_srl.predict(sent)\n",
    "#     #print(output)\n",
    "#     srl_v_dic = []\n",
    "    \n",
    "#     speci = verb_num_be_dict(len(output['verbs']),sent)\n",
    "#     if speci != {}:\n",
    "#         srl_v_dic.append(speci)\n",
    "\n",
    "#     for v_dic in output[\"verbs\"]:\n",
    "#         srl_v_dic.append( to_dict(v_dic[\"description\"]) )\n",
    "#     return srl_v_dic\n",
    "\n",
    "\n",
    "#qu_graph_data_core(core_sent_dict,qu_sent_dict)\n",
    "#core_graph_data_qu(core_sent_dict,qu_sent_dict)\n",
    "#qu_sent_dict_to_graph(qu_sent_dict)\n",
    "#core_sent_dict_to_graph(core_sent_dict)\n",
    "#core_graph_data_core_other(para,core_sent_dict)\n",
    "#core_graph_data_other(para,core_sent_dict)\n",
    "\n",
    "def graph_refine_data( case ):\n",
    "    \"\"\"Convert one annotated example into its graph representation.\n",
    "\n",
    "    `case` is indexed with the keys `_id`, `answer`, `question`,\n",
    "    `srl_question`, `context`, `type` and `level`. Every sentence entry of a\n",
    "    context paragraph is read at index 1 (handed to the graph helpers) and at\n",
    "    index 2 (copied through unchanged).\n",
    "\n",
    "    Returns a dict with the keys `_id`, `answer`, `question`,\n",
    "    `srl_question_to_graph`, `context`, `type` and `level`.\n",
    "\n",
    "    NOTE: this function is also defined in another cell of this notebook;\n",
    "    whichever definition is executed last is the one in effect.\n",
    "    \"\"\"\n",
    "    # Index-1 entries of the first two paragraphs; the cross-sentence\n",
    "    # helpers receive this pair as their `para` argument.\n",
    "    para = [[sent[1] for sent in case[\"context\"][0]],\n",
    "            [sent[1] for sent in case[\"context\"][1]]]\n",
    "\n",
    "    # The question annotation is the same for every sentence, so read it once.\n",
    "    question_srl = case['srl_question']\n",
    "\n",
    "    graph_context = []\n",
    "    for paragraph in case[\"context\"]:\n",
    "        sentence_graphs = []\n",
    "        for sent_entry in paragraph:\n",
    "            sent_srl = sent_entry[1]\n",
    "            sentence_graphs.append([\n",
    "                core_sent_dict_to_graph(sent_srl),\n",
    "                core_graph_data_core_other(para, sent_srl),\n",
    "                core_graph_data_other(para, sent_srl),\n",
    "                core_graph_data_qu(sent_srl, question_srl),\n",
    "                qu_graph_data_core(sent_srl, question_srl),\n",
    "                sent_entry[2],\n",
    "            ])\n",
    "        graph_context.append(sentence_graphs)\n",
    "\n",
    "    return {\n",
    "        \"_id\": case[\"_id\"],\n",
    "        \"answer\": case['answer'],\n",
    "        \"question\": case['question'],\n",
    "        \"srl_question_to_graph\": qu_sent_dict_to_graph(question_srl),\n",
    "        \"context\": graph_context,\n",
    "        \"type\": case[\"type\"],\n",
    "        \"level\": case[\"level\"],\n",
    "    }\n",
    "\n",
    "def dataset(file_name,out_filename,sent_num,verb_num):\n",
    "    \"\"\"Filter the articles stored in `file_name` and write their graph form.\n",
    "\n",
    "    Loads a JSON list of articles from `file_name`. An article is kept when\n",
    "    the number of sentences in its first two context paragraphs is at most\n",
    "    `sent_num` and the largest of verb_c_num() for those two paragraphs and\n",
    "    len(article['srl_question']) is at most `verb_num`. Kept articles are\n",
    "    converted with graph_refine_data() and the resulting list is dumped as\n",
    "    JSON to `out_filename`.\n",
    "\n",
    "    NOTE: this function is also defined in another cell of this notebook;\n",
    "    whichever definition is executed last is the one in effect.\n",
    "    \"\"\"\n",
    "    with open(file_name, \"r\", encoding='utf-8') as reader:\n",
    "        articles = json.load(reader)\n",
    "    print(\"Load ok\")\n",
    "\n",
    "    graph_articles = []\n",
    "    for article in tqdm(articles):\n",
    "        largest_verb_count = max(\n",
    "            verb_c_num(article['context'][0]),\n",
    "            verb_c_num(article['context'][1]),\n",
    "            len(article['srl_question']),\n",
    "        )\n",
    "        sentence_total = len(article['context'][0]) + len(article['context'][1])\n",
    "\n",
    "        within_limits = sentence_total <= sent_num and largest_verb_count <= verb_num\n",
    "        if not within_limits:\n",
    "            continue\n",
    "        graph_articles.append(graph_refine_data(article))\n",
    "\n",
    "    print(\"write data\")\n",
    "    with open(out_filename,'w') as file_obj:\n",
    "        json.dump(graph_articles,file_obj)\n",
    "#  [{'V': 'born', '': 'ARGM-TMP 18 November 1963 )'}, {'ARG0': 'a Danish former professional footballer', '': 'ARGM-MNR as a goalkeeper'}, {'V': 'voted', '': 'ARGM-TMP in 1992 and 1993'}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def graph_refine_data( case ):\n",
    "    \"\"\"Convert one annotated example into its graph representation.\n",
    "\n",
    "    `case` is indexed with the keys `_id`, `answer`, `question`,\n",
    "    `srl_question`, `context`, `type` and `level`. Every sentence entry of a\n",
    "    context paragraph is read at index 1 (handed to the graph helpers) and at\n",
    "    index 2 (copied through unchanged).\n",
    "\n",
    "    Returns a dict with the keys `_id`, `answer`, `question`,\n",
    "    `srl_question_to_graph`, `context`, `type` and `level`.\n",
    "\n",
    "    NOTE: this function is also defined in another cell of this notebook;\n",
    "    whichever definition is executed last is the one in effect.\n",
    "    \"\"\"\n",
    "    # Index-1 entries of the first two paragraphs; the cross-sentence\n",
    "    # helpers receive this pair as their `para` argument.\n",
    "    para = [[sent[1] for sent in case[\"context\"][0]],\n",
    "            [sent[1] for sent in case[\"context\"][1]]]\n",
    "\n",
    "    # The question annotation is the same for every sentence, so read it once.\n",
    "    question_srl = case['srl_question']\n",
    "\n",
    "    graph_context = []\n",
    "    for paragraph in case[\"context\"]:\n",
    "        sentence_graphs = []\n",
    "        for sent_entry in paragraph:\n",
    "            sent_srl = sent_entry[1]\n",
    "            sentence_graphs.append([\n",
    "                core_sent_dict_to_graph(sent_srl),\n",
    "                core_graph_data_core_other(para, sent_srl),\n",
    "                core_graph_data_other(para, sent_srl),\n",
    "                core_graph_data_qu(sent_srl, question_srl),\n",
    "                qu_graph_data_core(sent_srl, question_srl),\n",
    "                sent_entry[2],\n",
    "            ])\n",
    "        graph_context.append(sentence_graphs)\n",
    "\n",
    "    return {\n",
    "        \"_id\": case[\"_id\"],\n",
    "        \"answer\": case['answer'],\n",
    "        \"question\": case['question'],\n",
    "        \"srl_question_to_graph\": qu_sent_dict_to_graph(question_srl),\n",
    "        \"context\": graph_context,\n",
    "        \"type\": case[\"type\"],\n",
    "        \"level\": case[\"level\"],\n",
    "    }\n",
    "\n",
    "def dataset(file_name,out_filename,sent_num,verb_num):\n",
    "    \"\"\"Filter the articles stored in `file_name` and write their graph form.\n",
    "\n",
    "    Loads a JSON list of articles from `file_name`. An article is kept when\n",
    "    the number of sentences in its first two context paragraphs is at most\n",
    "    `sent_num` and the largest of verb_c_num() for those two paragraphs and\n",
    "    len(article['srl_question']) is at most `verb_num`. Kept articles are\n",
    "    converted with graph_refine_data() and the resulting list is dumped as\n",
    "    JSON to `out_filename`.\n",
    "\n",
    "    NOTE: this function is also defined in another cell of this notebook;\n",
    "    whichever definition is executed last is the one in effect.\n",
    "    \"\"\"\n",
    "    with open(file_name, \"r\", encoding='utf-8') as reader:\n",
    "        articles = json.load(reader)\n",
    "    print(\"Load ok\")\n",
    "\n",
    "    graph_articles = []\n",
    "    for article in tqdm(articles):\n",
    "        largest_verb_count = max(\n",
    "            verb_c_num(article['context'][0]),\n",
    "            verb_c_num(article['context'][1]),\n",
    "            len(article['srl_question']),\n",
    "        )\n",
    "        sentence_total = len(article['context'][0]) + len(article['context'][1])\n",
    "\n",
    "        within_limits = sentence_total <= sent_num and largest_verb_count <= verb_num\n",
    "        if not within_limits:\n",
    "            continue\n",
    "        graph_articles.append(graph_refine_data(article))\n",
    "\n",
    "    print(\"write data\")\n",
    "    with open(out_filename,'w') as file_obj:\n",
    "        json.dump(graph_articles,file_obj)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def verb_c_num(context):\n",
    "    verb_min_list=[]\n",
    "    for sent in context:\n",
    "        verb_min_list.append(len(sent[1]))\n",
    "               \n",
    "    return max(verb_min_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/90447 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load ok\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 90447/90447 [5:00:37<00:00,  5.01it/s]   \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "write data\n"
     ]
    }
   ],
   "source": [
    "# Alternative input/output paths for the dev split, kept for reference:\n",
    "#file_name = \"srl_hotpot_dev_distractor_v1.json\"\n",
    "#out_filename = \"dev_sent_5_verb_4_data_v1.json\"\n",
    "file_name = \"srl_hotpot_train_v1.1.json\"\n",
    "out_filename = \"train_sent_5_verb_4_data_v1.json\"\n",
    "\n",
    "# Keep only articles within the sentence and verb limits (see dataset()).\n",
    "dataset(file_name, out_filename, sent_num=5, verb_num=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pt3_py3",
   "language": "python",
   "name": "pt3_py3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
