{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "flag = torch.cuda.is_available()\n",
    "print(flag)\n",
    "\n",
    "ngpu= 1  \n",
    "# Decide which device we want to run on \n",
    "device = torch.device(\"cuda:0\" if (torch.cuda.is_available() and ngpu > 0) else \"cpu\")\n",
    "torch.cuda.set_device(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch.utils import data \n",
    "from torch.autograd import Variable \n",
    "import torchvision\n",
    "from torchvision.datasets import mnist \n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable\n",
    "from torch import optim\n",
    "from torch import nn\n",
    "import json\n",
    "import numpy as np\n",
    "import string\n",
    "import os\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class qa_train_context_dataset(Dataset):\n",
    "    \n",
    "        \n",
    "    def graph_refine_data( self,case ):\n",
    "        graph_context = []\n",
    "    \n",
    "        for paragh in case[\"context\"]:\n",
    "            qu_graph = self.float_to_int(self.str_to_int(case['srl_question_to_graph']))\n",
    "            sub_context = []\n",
    "      \n",
    "            for sent_idx,sent_value in enumerate(paragh):\n",
    "                #print(\"s_v:\",sent_value[5])\n",
    "           \n",
    "                sub_context.append([self.float_to_int(self.str_to_int(sent_value[0])),   \n",
    "                                    self.float_to_int(self.str_ar_to_int(sent_value[1])),\n",
    "                                    self.float_to_int(self.str_ar_to_int(sent_value[2])),\n",
    "                                    self.float_to_int(self.str_to_int(sent_value[3])),   \n",
    "                                    self.float_to_int(self.str_to_int(sent_value[4])),\n",
    "                                    self.true_1_false_0(sent_value[5]),\n",
    "                                     ])\n",
    "            \n",
    "            graph_context.append(sub_context)\n",
    "    \n",
    "        return dict([\n",
    "                    (\"_id\",case[\"_id\"]),\n",
    "                    #(\"answer\",case['answer']),\n",
    "                    #(\"question\", case['question']),\n",
    "                    (\"qu_graph\", qu_graph),\n",
    "                    #(\"srl_question\", qu_sent_dict_to_graph(qu_sent_dict),srl_parse_sent(predictor_srl,case['question']) ),\n",
    "                    #(\"supporting_facts\",case[\"supporting_facts\"]),\n",
    "                    (\"context\",graph_context),\n",
    "                    (\"type\",case[\"type\"]),\n",
    "                    (\"level\",case[\"level\"]), \n",
    "                    ])\n",
    "    \n",
    "    \n",
    "    def example_to_sentlist(self,case,example_id):\n",
    "        sent_graph_list = []\n",
    "        \n",
    "        qu_graph = case['qu_graph']\n",
    "        context0 = case['context'][0]\n",
    "        context1 = case['context'][1]\n",
    "        art_id = case[\"_id\"]\n",
    "        #print(\"c0_n:\",len(context0))\n",
    "        #print(\"c1_n:\",len(context1))\n",
    "        context0.extend(context1)\n",
    "        #print(\"c0_u_n:\",len(context0))\n",
    "        #print(context0[0][5])\n",
    "\n",
    "            #print(qu_graph)\n",
    "        \n",
    "#flow_power_start=======================================================================        \n",
    "#         core_article_infor_list = []\n",
    "#         core_qu_infor_list = []\n",
    "        verb_num = 3\n",
    "        srl_label_num = 25\n",
    "        sent_size = srl_label_num*verb_num\n",
    "        sent_num = 5\n",
    "        qu_graph = qu_graph.float().reshape(1,1,sent_size)\n",
    "#         flow_once_qu_article_power_list = []\n",
    "        \n",
    "#         for con_value in context0 :\n",
    "#             #con_value[1] = con_value[1].float().reshape(1,1,75,5).unsqueeze(0).unsqueeze(0).cuda()\n",
    "            \n",
    "#             core_article_infor = con_value[1].float().reshape(1,sent_num,sent_size)\n",
    "#             core_qu_infor = con_value[4].float().reshape(1,1,sent_size)\n",
    "#             flow_once_qu_article_power_list.append(\n",
    "#                 self.article_value(core_article_infor,sent_num,sent_size).mul(\n",
    "#                 self.pad_core_qu_info(core_qu_infor,sent_num,sent_size)))\n",
    "#             #print(self.article_value(core_article_infor,sent_num,sent_size).reshape(10,100))\n",
    "#             #print(self.pad_core_qu_info(core_qu_infor,sent_num,sent_size).reshape(10,100))\n",
    "#         flow_once_qu_article_power_graph = self.sum_flow_article_qu_power(flow_once_qu_article_power_list)\n",
    "#         #print(flow_once_qu_article_power_graph.reshape(10,100))\n",
    "#         flow_once_core_list = self.divide_article_to_sent(flow_once_qu_article_power_graph,\n",
    "#                                                          sent_num,sent_size)\n",
    "        \n",
    "#         #flow_twice==========================================================================================================\n",
    "#         flow_twice_qu_article_power_list = []\n",
    "        \n",
    "#         sent_id = 0\n",
    "#         for con_value in context0 :\n",
    "#             #con_value[1] = con_value[1].float().reshape(1,1,75,5).unsqueeze(0).unsqueeze(0).cuda()\n",
    "            \n",
    "#             core_article_infor = con_value[1].float().reshape(1,sent_num,sent_size)\n",
    "            \n",
    "#             #print(flow_once_qu_article_power_list[sent_id].size())\n",
    "#             flow_twice_qu_article_power_list.append(\n",
    "#                 self.article_value(core_article_infor,sent_num,sent_size).mul(\n",
    "#                 self.pad_core_qu_info(flow_once_core_list[sent_id],sent_num,sent_size)))\n",
    "#             sent_id = sent_id + 1 \n",
    "            \n",
    "#         flow_twice_qu_article_power_graph = self.sum_flow_article_qu_power(flow_twice_qu_article_power_list)\n",
    "#         flow_twice_core_list = self.divide_article_to_sent(flow_twice_qu_article_power_graph,\n",
    "#                                                          sent_num,sent_size)\n",
    "        \n",
    "        #flow_once_core_list\n",
    "#flow_power_end=======================================================================               \n",
    "            \n",
    "           \n",
    "        \n",
    "        core_sent_id = -1\n",
    "        for con_value in context0 :\n",
    "            core_sent_id = core_sent_id + 1\n",
    "            #con_value[1] = con_value[1].float().reshape(1,1,75,5).unsqueeze(0).unsqueeze(0).cuda()\n",
    "            core_graph = con_value[0].float().reshape(1,1,sent_size)\n",
    "            core_article_infor = con_value[1].float().reshape(1,sent_num,sent_size)\n",
    "            core_article_power = con_value[2].float().reshape(1,sent_num,sent_size)\n",
    "            core_qu_infor = con_value[4].float().reshape(1,1,sent_size)\n",
    "            core_qu_power = con_value[3].float().reshape(1,1,sent_size)\n",
    "            #core_qu_infor = torch.cat((con_value[0],con_value[3]),0).reshape(1,2,75)\n",
    "            #core_qu_power = torch.cat((con_value[0],con_value[4]),0).reshape(1,2,75)\n",
    "            #qu_core_infor = torch.cat((qu_graph,con_value[4]),0).reshape(1,2,75)\n",
    "            #qu_core_power = torch.cat((qu_graph,con_value[3]),0).reshape(1,2,75)\n",
    "\n",
    "            con_value[5] = con_value[5].float()\n",
    "            label = con_value[5]\n",
    "            #print(\"example_id_N:\",example_id)\n",
    "#                                  core_sent_dict_to_graph(core_sent_dict) , \n",
    "#                                  core_graph_data_core_other(para,core_sent_dict) , \n",
    "#                                  core_graph_data_other(para,core_sent_dict),\n",
    "#                                  core_graph_data_qu(core_sent_dict,qu_sent_dict),\n",
    "#                                  qu_graph_data_core(core_sent_dict,qu_sent_dict),\n",
    "            \n",
    "            merge_sum_sent_qu = []\n",
    "            merge_sum_sent_other = []\n",
    "            \n",
    "            article_sum = []\n",
    "            qu_article_info_sum = []\n",
    "            qu_article_power_sum = []\n",
    "            other_sent_id = 0\n",
    "            for other_sent_value in context0 :\n",
    "                other_sent_value_self = other_sent_value[0].float().reshape(1,1,sent_size)\n",
    "                other_sent_value_cq = other_sent_value[3].float().reshape(1,1,sent_size)\n",
    "                other_sent_value_qc = other_sent_value[4].float().reshape(1,1,sent_size)\n",
    "                core_article_power_list = core_article_power.reshape(sent_num,sent_size)\n",
    "                c_o_sent_value = core_article_power_list[other_sent_id]\n",
    "                \n",
    "                #merge_sent_qu_cosent = torch.cat((qu_graph,other_sent_value_qc + core_qu_power),0).reshape(1,2,100)\n",
    "                #merge_sent_o_cqusent = torch.cat((other_sent_value_self,other_sent_value_cq+c_o_sent_value),0).reshape(1,2,100)\n",
    "               \n",
    "                #merge_sum_sent_qu.extend(merge_sent_qu_cosent)\n",
    "                #merge_sum_sent_other.extend(merge_sent_o_cqusent)\n",
    "                article_sum.extend(other_sent_value_self)\n",
    "                qu_article_info_sum.extend(other_sent_value_cq)\n",
    "                qu_article_power_sum.extend(other_sent_value_qc)\n",
    "                \n",
    "                other_sent_id = other_sent_id + 1\n",
    "                \n",
    "                \n",
    "            sent_graph_list.append([core_article_infor,\n",
    "                                   core_article_power,\n",
    "                                   core_qu_infor,\n",
    "                                   core_qu_power,\n",
    "                                   qu_graph,\n",
    "                                   core_graph,\n",
    "                                   label,\n",
    "                                   example_id,\n",
    "                                   self.merge_list_to_graph(self.padding_sum_graph(sent_num,article_sum,sent_size)),\n",
    "                                   self.merge_list_to_graph(self.padding_sum_graph(sent_num,qu_article_info_sum,sent_size)),\n",
    "                                   self.merge_list_to_graph(self.padding_sum_graph(sent_num,qu_article_power_sum,sent_size)),\n",
    "                                   art_id,\n",
    "#                                    flow_once_qu_article_power_graph.reshape(1,sent_num,sent_size),\n",
    "#                                    flow_once_core_list[core_sent_id].reshape(1,1,sent_size),\n",
    "#                                    flow_twice_qu_article_power_graph.reshape(1,sent_num,sent_size),\n",
    "#                                    flow_twice_core_list[core_sent_id].reshape(1,1,sent_size),\n",
    "                                   ])\n",
    "        return sent_graph_list\n",
    "        \n",
    "        \n",
    "    \n",
    "                \n",
    "        \n",
    "#flow_power_tools_start==============================================================================       \n",
    "\n",
    "    def article_value(self,core_article_info,sent_num,sent_size):\n",
    "        a = core_article_info.reshape(sent_num*sent_size).int()\n",
    "        contant = sent_size\n",
    "        i = 0\n",
    "        b = []\n",
    "        #graph_value = null\n",
    "\n",
    "        for i in range(sent_num):\n",
    "            b.append(sum(a[i*contant:i*contant+contant]))\n",
    "\n",
    "        token = 0\n",
    "        pad_0 = torch.zeros(sent_size)\n",
    "        pad_1 = torch.ones(sent_size)\n",
    "        for value in b:\n",
    "            if value == 0 and token == 0:\n",
    "                graph_value = pad_0\n",
    "                token = token+1\n",
    "                continue\n",
    "            if value >= 1 and token == 0:\n",
    "                graph_value = pad_1\n",
    "                token = token+1\n",
    "                continue\n",
    "            if value == 0:\n",
    "                graph_value = torch.cat((graph_value,pad_0),0)\n",
    "            if value >= 1:\n",
    "                graph_value = torch.cat((graph_value,pad_1),0)\n",
    "\n",
    "\n",
    "        return graph_value.int()\n",
    "\n",
    "   \n",
    "    def pad_core_qu_info(self,core_qu_info,sent_num,sent_size):\n",
    "        a = core_qu_info.reshape(sent_size)\n",
    "        pad_graph = a\n",
    "        for i in range(sent_num-1):\n",
    "            pad_graph = torch.cat((pad_graph,a),0)\n",
    "        return pad_graph.int()\n",
    " \n",
    "\n",
    " #    def core_article_qu_power_graph_list():\n",
    "#         core_article_qu_power_graph_list = []\n",
    "\n",
    "#         for case in data:\n",
    "#             core_article_qu_power_graph_list.append(article_value(case['']).mual(pad_core_qu_power(case[''])))\n",
    "\n",
    "\n",
    "\n",
    "#         return core_article_qu_info_graph_list\n",
    "#\n",
    "\n",
    "    def sum_flow_article_qu_power(self,core_article_qu_power_list):\n",
    "        token_id = 0\n",
    "\n",
    "        for case in core_article_qu_power_list:\n",
    "            if token_id == 0:\n",
    "                sum_article_qu_power_graph = case\n",
    "                token_id = token_id + 1\n",
    "                continue\n",
    "            sum_article_qu_power_graph = sum_article_qu_power_graph + case\n",
    "            #print(sum_article_qu_power_graph.reshape(10,100))\n",
    "        pos_id = 0    \n",
    "        for value in sum_article_qu_power_graph:\n",
    "            \n",
    "            if value >= 1:\n",
    "                #print(value)\n",
    "                sum_article_qu_power_graph[pos_id] = 1\n",
    "            pos_id = pos_id + 1\n",
    "        return sum_article_qu_power_graph\n",
    "\n",
    "#将flow_power_qu_article_power,\n",
    "    def divide_article_to_sent(self,sum_article_qu_power_graph,sent_num,sent_size):\n",
    "        a = sum_article_qu_power_graph.reshape(sent_num*sent_size).int()\n",
    "        contant = sent_size\n",
    "        for value in a:\n",
    "            if value >= 1:\n",
    "                a[value] = 1\n",
    "\n",
    "\n",
    "        flow_once_core_qu_power_list = []\n",
    "        contant = sent_size\n",
    "        for i in range(sent_num):\n",
    "            flow_once_core_qu_power_list.append(a[i*contant:i*contant+contant])\n",
    "\n",
    "\n",
    "\n",
    "        return flow_once_core_qu_power_list\n",
    "\n",
    "\n",
    "#flow_power_tools_end==============================================================================                    \n",
    "                \n",
    "    \n",
    "    \n",
    "    def __init__(self,file_name,transform = None):\n",
    "        self.file_name = file_name\n",
    "        self.transform = transform \n",
    "        with open(file_name, \"r\", encoding='utf-8') as reader:\n",
    "            orig_data = json.load(reader)\n",
    "            print(\"Load ok\")\n",
    "            #orig_data = orig_data[0:2]\n",
    "        self.orig_data = orig_data \n",
    "        self.srl_data = []\n",
    "        self.sum_sent_graph_list = []\n",
    "        for article in tqdm(orig_data):\n",
    "            #print(\"article:\",article)\n",
    "            self.srl_data.append( self.graph_refine_data(article) )\n",
    "         \n",
    "        example_id = 0\n",
    "        for case in tqdm(self.srl_data) :\n",
    "            #print(\"example_id:\",example_id)\n",
    "            self.sum_sent_graph_list.extend(self.example_to_sentlist(case,example_id)) \n",
    "            example_id = example_id + 1\n",
    "            \n",
    "            \n",
    "        #print(\"s:\",self.srl_data[0]['context'][0])\n",
    "        #print(\"sum:\",self.sum_sent_graph_list[0][6])\n",
    "        #print(\"sum:\",self.sum_sent_graph_list[0][7])\n",
    "        #print(\"sum:\",len(self.sum_sent_graph_list))\n",
    "    \n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.sum_sent_graph_list)  \n",
    "    \n",
    "    \n",
    "    def str_to_int(self,data_str):\n",
    "        data_str_list = data_str.split()\n",
    "        b_data = [ int(i) for i in data_str_list ]\n",
    "        \n",
    "        return b_data\n",
    "    \n",
    "    def str_ar_to_int(self,data_str):\n",
    "    \n",
    "    \n",
    "        data_str_list = data_str.replace(\"[\",' ').replace(\"]\",' ').replace(\",\",' ').split()\n",
    "    #print(data_str_list)\n",
    "        b_data = [ int(i) for i in data_str_list ]\n",
    "    #for s_idx,s_data in data_str_list:\n",
    "        #data_str_list[s_idx]=int(s_data)\n",
    "        \n",
    "        \n",
    "        return b_data\n",
    "    \n",
    "    def float_to_int(self,data):\n",
    "        return torch.Tensor(data).int()\n",
    "    \n",
    "    def true_1_false_0(self,data):\n",
    "        if data == True :\n",
    "            return torch.tensor(1)\n",
    "        if data == False :\n",
    "            return torch.tensor(0)\n",
    "    def merge_list_to_graph(self,merge_sum_sent_list):\n",
    "        merge_list_num = 0    \n",
    "        for sent_merge_data in merge_sum_sent_list:\n",
    "            if merge_list_num == 0 :\n",
    "                merge_data = sent_merge_data\n",
    "                merge_list_num = merge_list_num + 1\n",
    "                continue\n",
    "            merge_data = torch.cat((merge_data,sent_merge_data),0)\n",
    "            merge_list_num = merge_list_num + 1\n",
    "        return merge_data\n",
    "    \n",
    "    def padding_sum_graph(self,sent_num,sum_graph_list,sent_size):\n",
    "        pro_list = []\n",
    "        pro_list = sum_graph_list\n",
    "        add_pad_size = torch.zeros(sent_size,1).reshape(1,1,sent_size)\n",
    "\n",
    "        for add_padding in range(sent_num - len(sum_graph_list)):\n",
    "            pro_list.extend(add_pad_size)\n",
    "\n",
    "        return pro_list \n",
    "            \n",
    "        \n",
    "        \n",
    "    \n",
    "    def __getitem__(self,index):\n",
    "        \n",
    "        \n",
    "        sent_graph_index = self.sum_sent_graph_list[index]\n",
    "        #print(sent_graph_index)\n",
    "#         article_index = self.orig_data[index]\n",
    "        \n",
    "#         qu_graph = article_index['qu_graph']\n",
    "#         context0_value = article_index['context'][0]\n",
    "#         context1_value = article_index['context'][1]\n",
    "\n",
    "        core_article_infor = sent_graph_index[0]\n",
    "        core_article_power = sent_graph_index[1]\n",
    "        core_qu_infor = sent_graph_index[2]\n",
    "        core_qu_power = sent_graph_index[3]\n",
    "        qu_graph = sent_graph_index[4]\n",
    "        core_graph = sent_graph_index[5]\n",
    "        label = sent_graph_index[6]\n",
    "        example_id = sent_graph_index[7]\n",
    "        #merge_sum_sent_qu = sent_graph_index[8]\n",
    "        #merge_sum_sent_other = sent_graph_index[9]\n",
    "        article_sum = sent_graph_index[8]\n",
    "        qu_article_info_sum = sent_graph_index[9]\n",
    "        qu_article_power_sum = sent_graph_index[10]\n",
    "        art_id = sent_graph_index[11]\n",
    "#         flow_once_qu_article_power = sent_graph_index[12]\n",
    "#         flow_once_qu_core_power = sent_graph_index[13]\n",
    "#         flow_twice_qu_article_power = sent_graph_index[14]\n",
    "#         flow_twice_qu_core_power = sent_graph_index[15]\n",
    "       \n",
    "        sample = {'core_article_infor': core_article_infor,\n",
    "                  'core_article_power':core_article_power,\n",
    "                  'core_qu_infor': core_qu_infor,\n",
    "                  'core_qu_power': core_qu_power,\n",
    "                  'qu_graph': qu_graph,\n",
    "                  'core_graph': core_graph,\n",
    "                  #'merge_sum_sent_qu': merge_sum_sent_qu,\n",
    "                  #'merge_sum_sent_other':merge_sum_sent_other,\n",
    "                  'article_sum':article_sum,\n",
    "                  'qu_article_info_sum':qu_article_info_sum,\n",
    "                  'qu_article_power_sum':qu_article_power_sum,\n",
    "                  'label': label,\n",
    "                  'example_id': example_id,\n",
    "                  'art_id': art_id,\n",
    "#                   'flow_once_qu_article_power':flow_once_qu_article_power,\n",
    "#                   'flow_once_qu_core_power':flow_once_qu_core_power,\n",
    "#                   'flow_twice_qu_article_power':flow_twice_qu_article_power,\n",
    "#                   'flow_twice_qu_core_power':flow_twice_qu_core_power,\n",
    "                 }\n",
    "        \n",
    "        if self.transform:\n",
    "            sample = self.transform(sample)\n",
    "        return sample \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_file_name = \"train_equal_sent_5_verb_3_data_v1.json\"\n",
    "dev_file_name = \"dev_equal_sent_5_verb_3_data_v1.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 56/5319 [00:00<00:09, 556.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load ok\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5319/5319 [00:10<00:00, 508.57it/s]\n",
      "100%|██████████| 5319/5319 [00:12<00:00, 432.39it/s]\n",
      " 10%|█         | 49/471 [00:00<00:00, 488.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load ok\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 471/471 [00:00<00:00, 559.84it/s]\n",
      "100%|██████████| 471/471 [00:00<00:00, 581.26it/s]\n"
     ]
    }
   ],
   "source": [
    "train_data = qa_train_context_dataset(train_file_name,transform=None)\n",
    "dev_data = qa_train_context_dataset(dev_file_name,transform=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "#sent 5,6,7,8\n",
    "train_loader = DataLoader(train_data,batch_size=64,drop_last=True,shuffle=False)\n",
    "dev_loader = DataLoader(dev_data,batch_size=64,drop_last=True,shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def conv_article_Batch(in_planes, out_planes, kernel_size,stride):                      \n",
    "    return torch.nn.Sequential(\n",
    "            torch.nn.Conv2d(in_channels = in_planes,\n",
    "                            out_channels = out_planes,\n",
    "                            kernel_size=kernel_size,\n",
    "                            stride=stride,\n",
    "                            #padding=1\n",
    "                           ),\n",
    "            #torch.nn.BatchNorm2d(num_features=out_planes),\n",
    "            #torch.nn.ReLU(),\n",
    "            torch.nn.Sigmoid()\n",
    "            #torch.nn.Dropout(0.5),\n",
    "        )\n",
    "def conv_sent_Batch(in_planes, out_planes, kernel_size,stride):                      \n",
    "    return torch.nn.Sequential(\n",
    "            torch.nn.Conv2d(in_channels = in_planes,\n",
    "                            out_channels = out_planes,\n",
    "                            kernel_size=kernel_size,\n",
    "                            stride=stride,\n",
    "                            #padding=1\n",
    "                           ),\n",
    "            #torch.nn.BatchNorm2d(num_features=out_planes),\n",
    "            torch.nn.Sigmoid()\n",
    "            #torch.nn.ReLU(),\n",
    "            #torch.nn.Dropout(0.5),\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "class conv_one_sent_layer(torch.nn.Module):\n",
    "    def __init__(self,batch_size,graph_size,in_channels,out_channels,verb_num):\n",
    "        super(conv_one_sent_layer,self).__init__()\n",
    "        self.batch_size = batch_size\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.sent_kernel_size = (1,graph_size)\n",
    "        self.pos_kernel_size = (verb_num,1)\n",
    "        self.sent_stride = 1\n",
    "        self.pos_stride = 1\n",
    "        self.graph_size = graph_size\n",
    "        self.verb_num = verb_num\n",
    "\n",
    "        self.sent_power_conv = conv_sent_Batch(self.in_channels,self.out_channels,self.sent_kernel_size,self.sent_stride)\n",
    "        self.pos_power_conv = conv_sent_Batch(self.in_channels,self.out_channels,self.pos_kernel_size,self.pos_stride)\n",
    "        \n",
    "    def forward(self,   \n",
    "                sent_graph,):\n",
    "        sent_power = self.sent_power_conv(sent_graph.reshape(self.batch_size,1,1,self.graph_size))\n",
    "        #print(sent_graph.reshape(self.batch_size,-1)[:,0::25].size())\n",
    "        pos_i_power = self.pos_power_conv(sent_graph.reshape(self.batch_size,1,self.verb_num,25))\n",
    "        \n",
    "        return sent_power,pos_i_power\n",
    "\n",
    "    \n",
    "class conv_article_layer(torch.nn.Module):\n",
    "    def __init__(self,batch_size,graph_size,sent_num,in_channels,out_channels,verb_num):\n",
    "        super(conv_article_layer,self).__init__()\n",
    "        self.batch_size = batch_size\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.verb_num = verb_num\n",
    "        self.sent_num = sent_num\n",
    "        self.srl_len = 25\n",
    "        self.sent_size = self.verb_num*self.srl_len\n",
    "        self.sent_kernel_size = (self.sent_num,self.sent_size)\n",
    "        self.pos_kernel_size = (self.sent_num*self.verb_num,1)\n",
    "        self.sent_stride = 1\n",
    "        self.pos_stride = 1\n",
    "        self.graph_size = graph_size\n",
    "        \n",
    "\n",
    "        self.article_power_conv = conv_article_Batch(self.in_channels,self.out_channels,self.sent_kernel_size,self.sent_stride)\n",
    "        self.pos_power_conv = conv_article_Batch(self.in_channels,self.out_channels,self.pos_kernel_size,self.pos_stride)\n",
    "        \n",
    "        \n",
    "    def forward(self,   \n",
    "                sent_graph,):\n",
    "        #print(\"self.sent_size\",self.sent_size)\n",
    "        article_power = self.article_power_conv(sent_graph.reshape(self.batch_size,1,self.sent_num,self.sent_size))\n",
    "        \n",
    "        a_pos_i_power = self.pos_power_conv(sent_graph.reshape(self.batch_size,1,self.sent_num*self.verb_num,25))\n",
    "        \n",
    "        \n",
    "        return article_power,a_pos_i_power\n",
    "            \n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv_pos_power_net(\n",
      "  (core_power_conv): conv_one_sent_layer(\n",
      "    (sent_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(1, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(3, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (qu_power_conv): conv_one_sent_layer(\n",
      "    (sent_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(1, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(3, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (article_sum_power_conv): conv_article_layer(\n",
      "    (article_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(5, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(15, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (qu_article_power_conv): conv_article_layer(\n",
      "    (article_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(5, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(15, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (core_article_power_conv): conv_article_layer(\n",
      "    (article_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(5, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(15, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (cqi_info_conv): conv_one_sent_layer(\n",
      "    (sent_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(1, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(3, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (cqp_power_conv): conv_one_sent_layer(\n",
      "    (sent_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(1, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(3, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (core_article_info_conv): conv_article_layer(\n",
      "    (article_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(5, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(15, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (qu_article_info_conv): conv_article_layer(\n",
      "    (article_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(5, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(15, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (mlp1): Linear(in_features=234, out_features=200, bias=True)\n",
      "  (mlp2): Linear(in_features=200, out_features=100, bias=True)\n",
      "  (mlp3): Linear(in_features=100, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class conv_pos_power_net(torch.nn.Module):\n",
    "    \"\"\"Score one candidate sentence from nine SRL-graph feature inputs.\n",
    "\n",
    "    Each of the nine ``forward`` inputs goes through its own\n",
    "    ``conv_one_sent_layer`` / ``conv_article_layer`` (defined in an earlier\n",
    "    cell). Every layer returns a pair of tensors; all eighteen are\n",
    "    concatenated along dim 3, flattened per sample and passed through a\n",
    "    three-layer MLP followed by a sigmoid.\n",
    "\n",
    "    Args:\n",
    "        batch_size: forwarded to every conv layer. ``forward`` itself takes\n",
    "            the batch size from its input instead of this value.\n",
    "        in_channels: forwarded to every conv layer.\n",
    "        out_channels: forwarded to every conv layer; also scales the MLP widths.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, batch_size, in_channels, out_channels):\n",
    "        super(conv_pos_power_net, self).__init__()\n",
    "        self.batch_size = batch_size\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.verb_num = 3\n",
    "        self.srl_len = 25\n",
    "        self.graph_size = self.verb_num * self.srl_len\n",
    "        self.sent_num = 5\n",
    "\n",
    "        def one_sent_layer():\n",
    "            return conv_one_sent_layer(self.batch_size, self.graph_size, self.in_channels,\n",
    "                                       self.out_channels, self.verb_num)\n",
    "\n",
    "        def article_layer():\n",
    "            return conv_article_layer(self.batch_size, self.graph_size, self.sent_num,\n",
    "                                      self.in_channels, self.out_channels, self.verb_num)\n",
    "\n",
    "        # Attribute names become state_dict keys; keep them (and their creation\n",
    "        # order) unchanged so saved checkpoints and printed reprs still match.\n",
    "        self.core_power_conv = one_sent_layer()\n",
    "        self.qu_power_conv = one_sent_layer()\n",
    "        self.article_sum_power_conv = article_layer()\n",
    "        self.qu_article_power_conv = article_layer()\n",
    "        self.core_article_power_conv = article_layer()\n",
    "        self.cqi_info_conv = one_sent_layer()\n",
    "        self.cqp_power_conv = one_sent_layer()\n",
    "        self.core_article_info_conv = article_layer()\n",
    "        self.qu_article_info_conv = article_layer()\n",
    "\n",
    "        # 9 + 25*9 = 234 input features (matches the printed mlp1 in_features).\n",
    "        # NOTE(review): there is no activation between mlp1, mlp2 and mlp3, so the\n",
    "        # three layers compose to a single affine map; consider adding one.\n",
    "        self.mlp1 = torch.nn.Linear((9 + 25 * 9), 200 * self.out_channels)\n",
    "        self.mlp2 = torch.nn.Linear(200 * self.out_channels, 100 * self.out_channels)\n",
    "        self.mlp3 = torch.nn.Linear(100 * self.out_channels, 1)\n",
    "\n",
    "    def forward(self, core_graph, qu_graph, core_article_power, core_qu_info,\n",
    "                core_qu_power, article_sum, qu_article_power_sum,\n",
    "                core_article_info, qu_article_info_sum):\n",
    "        \"\"\"Return one sigmoid score per sample, shape (N, 1).\n",
    "\n",
    "        N is dim 0 of the concatenated features. Callers must pass\n",
    "        ``core_graph`` before ``qu_graph``: the two go through different\n",
    "        conv layers.\n",
    "        \"\"\"\n",
    "        # (layer, input) pairs in the order their outputs are concatenated;\n",
    "        # this order fixes the feature layout seen by mlp1.\n",
    "        branches = (\n",
    "            (self.core_power_conv, core_graph),\n",
    "            (self.qu_power_conv, qu_graph),\n",
    "            (self.core_article_power_conv, core_article_power),\n",
    "            (self.article_sum_power_conv, article_sum),\n",
    "            (self.qu_article_power_conv, qu_article_power_sum),\n",
    "            (self.cqp_power_conv, core_qu_power),\n",
    "            (self.cqi_info_conv, core_qu_info),\n",
    "            (self.core_article_info_conv, core_article_info),\n",
    "            (self.qu_article_info_conv, qu_article_info_sum),\n",
    "        )\n",
    "        features = []\n",
    "        for layer, layer_input in branches:\n",
    "            power, pos_power = layer(layer_input)\n",
    "            features.append(power)\n",
    "            features.append(pos_power)\n",
    "        x = torch.cat(features, 3)\n",
    "\n",
    "        # Flatten per sample using the actual leading dimension rather than the\n",
    "        # constructor's batch_size, so a smaller final batch is not rejected or\n",
    "        # silently regrouped. NOTE(review): the conv sub-layers also receive\n",
    "        # batch_size; confirm they accept a short batch too.\n",
    "        x = x.reshape(x.size(0), -1)\n",
    "        x = self.mlp1(x)\n",
    "        x = self.mlp2(x)\n",
    "        x = self.mlp3(x)\n",
    "        return torch.sigmoid(x)\n",
    "\n",
    "\n",
    "model = conv_pos_power_net(64, 1, 1)\n",
    "print(model)\n",
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The network already ends in torch.sigmoid, so it is trained with plain BCE\n",
    "# on probabilities (BCEWithLogitsLoss would expect raw logits instead).\n",
    "learning_rate = 0.01\n",
    "criterion = nn.BCELoss()\n",
    "optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model\n",
    "def train_code_confing(train_loader, dev_loader, epoch_num,\n",
    "                       eval_every=500, save_path='./flow_srl_feature_v2.0_5.pkl'):\n",
    "    \"\"\"Train the global ``model`` with the global ``criterion`` and ``optimizer``.\n",
    "\n",
    "    Args:\n",
    "        train_loader: yields dicts of tensors (the feature keys below plus 'label').\n",
    "        dev_loader: passed to ``eval_code_confing`` for evaluation.\n",
    "        epoch_num: number of passes over ``train_loader``.\n",
    "        eval_every: also evaluate every this many steps within an epoch\n",
    "            (0 or None disables the mid-epoch evaluation).\n",
    "        save_path: where the whole model is pickled with ``torch.save``\n",
    "            after every epoch.\n",
    "\n",
    "    Batches are moved to the device the model's parameters live on, so this\n",
    "    also trains when CUDA is unavailable.\n",
    "    \"\"\"\n",
    "    # Batch keys in the positional order of conv_pos_power_net.forward;\n",
    "    # the dataset key 'core_article_infor' feeds its core_article_info argument.\n",
    "    feature_keys = ('core_graph', 'qu_graph', 'core_article_power',\n",
    "                    'core_qu_infor', 'core_qu_power', 'article_sum',\n",
    "                    'qu_article_power_sum', 'core_article_infor',\n",
    "                    'qu_article_info_sum')\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    for _ in range(epoch_num):\n",
    "        # eval_code_confing calls model.eval(); make sure we train in train mode.\n",
    "        model.train()\n",
    "        step = 0\n",
    "        for batch in tqdm(train_loader):\n",
    "            inputs = [batch[key].to(device) for key in feature_keys]\n",
    "            label = batch['label'].to(device)\n",
    "\n",
    "            # forward returns (N, 1); BCELoss needs the same shape as label.\n",
    "            out = model(*inputs).squeeze(1)\n",
    "            loss = criterion(out, label)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            step += 1\n",
    "            if eval_every and step % eval_every == 0:\n",
    "                # The log label says 'epoch' but the number is the step within this epoch.\n",
    "                print('epoch: {}, loss: {:.4}'.format(step, loss.item()))\n",
    "                eval_code_confing(dev_loader)\n",
    "                model.train()\n",
    "\n",
    "        eval_code_confing(dev_loader)\n",
    "        torch.save(model, save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "def exact_example_num(predict_label_list, dev_label_list, example_id_list, threshold=0.5):\n",
    "    \"\"\"Example-level supporting-sentence exact match (EM) and F1.\n",
    "\n",
    "    The three arguments are parallel sequences with one entry per dev\n",
    "    sentence: predicted probability, gold label and example id. Entries may be\n",
    "    Python numbers or 0-d tensors (on any device).\n",
    "\n",
    "    Sentences of the same example are numbered in the order they appear. A\n",
    "    sentence is predicted supporting when its probability is > ``threshold``\n",
    "    and gold supporting when its label, truncated to int, equals 1. Per-example\n",
    "    F1 and EM are averaged over the example ids that actually occur, so ids\n",
    "    missing from the input are not counted as perfect matches.\n",
    "\n",
    "    Returns:\n",
    "        (em, f1) as floats; (0.0, 0.0) for empty input.\n",
    "    \"\"\"\n",
    "    preds = [float(p) for p in predict_label_list]\n",
    "    golds = [int(g) for g in dev_label_list]\n",
    "    ids = [int(e) for e in example_id_list]\n",
    "\n",
    "    assert len(preds) == len(golds)\n",
    "    assert len(preds) == len(ids)\n",
    "\n",
    "    num = len(golds)\n",
    "    # Per example id: next sentence position, gold positions, predicted positions.\n",
    "    # A per-id counter (instead of resetting whenever the id changes) keeps\n",
    "    # positions distinct even if an example's sentences are not contiguous.\n",
    "    next_pos = {}\n",
    "    gold_sp = {}\n",
    "    pred_sp = {}\n",
    "    for eid, prob, gold in zip(ids, preds, golds):\n",
    "        pos = next_pos.get(eid, 0)\n",
    "        next_pos[eid] = pos + 1\n",
    "        gold_sp.setdefault(eid, set())\n",
    "        pred_sp.setdefault(eid, set())\n",
    "        if gold == 1:\n",
    "            gold_sp[eid].add(pos)\n",
    "        if prob > threshold:\n",
    "            pred_sp[eid].add(pos)\n",
    "\n",
    "    n = len(next_pos)\n",
    "    print(\"num:\",num)\n",
    "    print(\"n:\",n)\n",
    "\n",
    "    em, f1 = 0.0, 0.0\n",
    "    for eid in next_pos:\n",
    "        pred_set, gold_set = pred_sp[eid], gold_sp[eid]\n",
    "        tp = len(pred_set & gold_set)\n",
    "        fp = len(pred_set - gold_set)\n",
    "        fn = len(gold_set - pred_set)\n",
    "        prec = tp / (tp + fp) if tp + fp > 0 else 0.0\n",
    "        recall = tp / (tp + fn) if tp + fn > 0 else 0.0\n",
    "        f1 += 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0\n",
    "        em += 1.0 if fp + fn == 0 else 0.0\n",
    "    if n > 0:\n",
    "        em /= n\n",
    "        f1 /= n\n",
    "    print(\"em:\",em)\n",
    "    print(\"f1:\",f1)\n",
    "    return em,f1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_code_confing(dev_loader):\n",
    "    \"\"\"Evaluate the global ``model`` on ``dev_loader`` and print dev metrics.\n",
    "\n",
    "    Prints the mean per-sentence loss, sentence accuracy at a 0.5 threshold,\n",
    "    and the example-level EM / F1 computed by ``exact_example_num``.\n",
    "    Runs under ``torch.no_grad()`` and restores the model's previous\n",
    "    train/eval mode before returning, so it is safe to call mid-training.\n",
    "\n",
    "    Returns:\n",
    "        (mean_loss, accuracy, exact_example_num result), or None when\n",
    "        ``dev_loader`` yields no sentences.\n",
    "    \"\"\"\n",
    "    # Model evaluation\n",
    "    print(\"load_function\")\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    device = next(model.parameters()).device\n",
    "    # Batch keys in the positional order of conv_pos_power_net.forward.\n",
    "    # core_graph comes first, then qu_graph (they used to be swapped here,\n",
    "    # so evaluation fed each graph to the other's conv layer).\n",
    "    feature_keys = ('core_graph', 'qu_graph', 'core_article_power',\n",
    "                    'core_qu_infor', 'core_qu_power', 'article_sum',\n",
    "                    'qu_article_power_sum', 'core_article_infor',\n",
    "                    'qu_article_info_sum')\n",
    "\n",
    "    eval_loss = 0.0\n",
    "    eval_acc = 0\n",
    "    predict_label_list = []\n",
    "    dev_label_list = []\n",
    "    example_id_list = []\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(dev_loader):\n",
    "            inputs = [batch[key].to(device) for key in feature_keys]\n",
    "            label = batch['label'].to(device)\n",
    "            out = model(*inputs).squeeze(1)\n",
    "            loss = criterion(out, label)\n",
    "\n",
    "            # criterion averages over the batch; weight by batch size so the\n",
    "            # final division yields a true per-sentence mean.\n",
    "            eval_loss += loss.item() * label.numel()\n",
    "            pred = out > 0.5\n",
    "            eval_acc += (pred.long().view(-1) == label.long().view(-1)).sum().item()\n",
    "\n",
    "            predict_label_list.extend(out.cpu())\n",
    "            dev_label_list.extend(label.cpu())\n",
    "            example_id_list.extend(batch['example_id'])\n",
    "    model.train(was_training)\n",
    "\n",
    "    print(\"predict_label_list:\",len(predict_label_list),\"dev_label_list:\",len(dev_label_list),\"example_id_list:\",len(example_id_list))\n",
    "    sent_num = len(dev_label_list)\n",
    "    if sent_num == 0:\n",
    "        print(\"dev_loader yielded no sentences\")\n",
    "        return None\n",
    "    context_exact = exact_example_num(predict_label_list,dev_label_list,example_id_list)\n",
    "    mean_loss = eval_loss / sent_num\n",
    "    accuracy = eval_acc / sent_num\n",
    "    print('Test Loss: {:.6f}, Acc: {:.6f}'.format(mean_loss, accuracy))\n",
    "    return mean_loss, accuracy, context_exact"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 59.29it/s]\n",
      " 36%|███▌      | 13/36 [00:00<00:00, 122.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 134.58it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.19305856832971802\n",
      "f1: 0.7388957752298333\n",
      "Test Loss: 0.009469, Acc: 0.710938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 60.92it/s]\n",
      " 44%|████▍     | 16/36 [00:00<00:00, 154.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 155.86it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24078091106290672\n",
      "f1: 0.7312743173914555\n",
      "Test Loss: 0.008861, Acc: 0.731771\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:07<00:00, 54.49it/s]\n",
      " 47%|████▋     | 17/36 [00:00<00:00, 168.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 165.14it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.26898047722342733\n",
      "f1: 0.7275195399924267\n",
      "Test Loss: 0.008847, Acc: 0.738281\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 59.90it/s]\n",
      " 47%|████▋     | 17/36 [00:00<00:00, 168.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 168.66it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2646420824295011\n",
      "f1: 0.726052749371623\n",
      "Test Loss: 0.008839, Acc: 0.740885\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:07<00:00, 58.75it/s]\n",
      " 36%|███▌      | 13/36 [00:00<00:00, 128.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 128.37it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.25162689804772237\n",
      "f1: 0.7109096856385372\n",
      "Test Loss: 0.008899, Acc: 0.736111\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:07<00:00, 56.30it/s]\n",
      " 42%|████▏     | 15/36 [00:00<00:00, 148.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 151.11it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24078091106290672\n",
      "f1: 0.7014151430637342\n",
      "Test Loss: 0.008836, Acc: 0.730469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:07<00:00, 57.47it/s]\n",
      " 44%|████▍     | 16/36 [00:00<00:00, 157.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 158.60it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2386117136659436\n",
      "f1: 0.7085184037461704\n",
      "Test Loss: 0.009262, Acc: 0.733073\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:07<00:00, 55.36it/s]\n",
      " 39%|███▉      | 14/36 [00:00<00:00, 131.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 131.30it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.25162689804772237\n",
      "f1: 0.7142047997796381\n",
      "Test Loss: 0.009012, Acc: 0.736545\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:07<00:00, 55.73it/s]\n",
      " 47%|████▋     | 17/36 [00:00<00:00, 164.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 162.07it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24511930585683298\n",
      "f1: 0.7120269944564965\n",
      "Test Loss: 0.008886, Acc: 0.734375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:07<00:00, 58.83it/s]\n",
      " 47%|████▋     | 17/36 [00:00<00:00, 165.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 163.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2472885032537961\n",
      "f1: 0.714500912440176\n",
      "Test Loss: 0.008915, Acc: 0.732639\n"
     ]
    }
   ],
   "source": [
    "# sent5verb3 run: train for 10 epochs; train_code_confing scores the dev set after each one.\n",
    "epoch_num = 10\n",
    "train_code_confing(train_loader=train_loader, dev_loader=dev_loader, epoch_num=epoch_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 66.54it/s]\n",
      " 36%|███▌      | 13/36 [00:00<00:00, 128.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 128.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 8/415 [00:00<00:05, 76.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "em: 0.25813449023861174\n",
      "f1: 0.7069724202045252\n",
      "Test Loss: 0.008556, Acc: 0.732639\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 69.72it/s]\n",
      " 44%|████▍     | 16/36 [00:00<00:00, 152.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 153.09it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2472885032537961\n",
      "f1: 0.703775436421858\n",
      "Test Loss: 0.008571, Acc: 0.733507\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 76.97it/s]\n",
      " 44%|████▍     | 16/36 [00:00<00:00, 156.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 156.51it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24078091106290672\n",
      "f1: 0.7007798781117661\n",
      "Test Loss: 0.008588, Acc: 0.731771\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 80.91it/s]\n",
      " 44%|████▍     | 16/36 [00:00<00:00, 153.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 153.27it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23427331887201736\n",
      "f1: 0.7021399304479574\n",
      "Test Loss: 0.008612, Acc: 0.733941\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 78.76it/s]\n",
      " 47%|████▋     | 17/36 [00:00<00:00, 162.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 161.60it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23210412147505424\n",
      "f1: 0.6958991839685991\n",
      "Test Loss: 0.008637, Acc: 0.730469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 80.10it/s]\n",
      " 44%|████▍     | 16/36 [00:00<00:00, 154.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 158.63it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.21691973969631237\n",
      "f1: 0.6891419619185357\n",
      "Test Loss: 0.008659, Acc: 0.725694\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 79.36it/s]\n",
      " 44%|████▍     | 16/36 [00:00<00:00, 153.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 153.14it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.22125813449023862\n",
      "f1: 0.6905140653513768\n",
      "Test Loss: 0.008678, Acc: 0.725260\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 81.04it/s]\n",
      " 47%|████▋     | 17/36 [00:00<00:00, 167.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 169.05it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.21258134490238612\n",
      "f1: 0.6900027545363782\n",
      "Test Loss: 0.008678, Acc: 0.724826\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 81.52it/s]\n",
      " 47%|████▋     | 17/36 [00:00<00:00, 166.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 163.08it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.20607375271149675\n",
      "f1: 0.6890627689976941\n",
      "Test Loss: 0.008675, Acc: 0.723958\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 81.86it/s]\n",
      " 47%|████▋     | 17/36 [00:00<00:00, 165.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 165.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.20824295010845986\n",
      "f1: 0.6906328547326391\n",
      "Test Loss: 0.008675, Acc: 0.724826\n"
     ]
    }
   ],
   "source": [
    "#sent5verb4_global\n",
    "epoch_num = 10\n",
    "train_code_confing(train_loader,dev_loader,epoch_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 84.58it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 173.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 169.71it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.19522776572668113\n",
      "f1: 0.6809265571738466\n",
      "Test Loss: 0.008960, Acc: 0.707899\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 84.33it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 170.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 173.08it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.14533622559652928\n",
      "f1: 0.7312812037323988\n",
      "Test Loss: 0.009795, Acc: 0.688368\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 80.26it/s]\n",
      " 42%|████▏     | 15/36 [00:00<00:00, 148.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 151.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.19305856832971802\n",
      "f1: 0.7359449781358691\n",
      "Test Loss: 0.009260, Acc: 0.709635\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 80.38it/s]\n",
      " 44%|████▍     | 16/36 [00:00<00:00, 157.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 160.61it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23427331887201736\n",
      "f1: 0.7363220053024842\n",
      "Test Loss: 0.008900, Acc: 0.726562\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 78.50it/s]\n",
      " 39%|███▉      | 14/36 [00:00<00:00, 138.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 145.47it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2386117136659436\n",
      "f1: 0.7328753916606427\n",
      "Test Loss: 0.008865, Acc: 0.726562\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 80.94it/s]\n",
      " 47%|████▋     | 17/36 [00:00<00:00, 169.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 168.96it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24945770065075923\n",
      "f1: 0.7239661880659728\n",
      "Test Loss: 0.008622, Acc: 0.735243\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 83.23it/s]\n",
      " 47%|████▋     | 17/36 [00:00<00:00, 162.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 157.41it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2559652928416486\n",
      "f1: 0.7245635781427557\n",
      "Test Loss: 0.008691, Acc: 0.733073\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 82.49it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 177.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 179.79it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24945770065075923\n",
      "f1: 0.7207468236752418\n",
      "Test Loss: 0.008643, Acc: 0.735677\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 86.33it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 181.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 8/415 [00:00<00:05, 79.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2559652928416486\n",
      "f1: 0.719460799504185\n",
      "Test Loss: 0.008639, Acc: 0.736111\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 79.53it/s]\n",
      " 42%|████▏     | 15/36 [00:00<00:00, 143.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 145.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.25162689804772237\n",
      "f1: 0.7201787005474655\n",
      "Test Loss: 0.008583, Acc: 0.733941\n"
     ]
    }
   ],
   "source": [
    "#sent5verb4_local\n",
    "epoch_num = 10\n",
    "train_code_confing(train_loader,dev_loader,epoch_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "learn_py3",
   "language": "python",
   "name": "learn_py3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
