{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "504f642b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "flag = torch.cuda.is_available()\n",
    "print(flag)\n",
    "\n",
    "ngpu = 1  # number of GPUs to use; 0 forces the CPU\n",
    "# Decide which device we want to run on.\n",
    "device = torch.device(\"cuda:0\" if (flag and ngpu > 0) else \"cpu\")\n",
    "# Only select a CUDA device when one was actually chosen: calling\n",
    "# torch.cuda.set_device unconditionally fails on CPU-only machines.\n",
    "if device.type == \"cuda\":\n",
    "    torch.cuda.set_device(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "98a8925c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch.utils import data \n",
    "from torch.autograd import Variable \n",
    "import torchvision\n",
    "from torchvision.datasets import mnist \n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable\n",
    "from torch import optim\n",
    "from torch import nn\n",
    "import json\n",
    "import numpy as np\n",
    "import string\n",
    "import os\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c88ea01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class qa_train_context_dataset(Dataset):\n",
    "    \n",
    "        \n",
    "    def graph_refine_data(self, case):\n",
    "        \"\"\"Convert one raw JSON example into integer tensors.\n",
    "\n",
    "        Args:\n",
    "            case: dict with keys ``_id``, ``srl_question_to_graph``, ``context``,\n",
    "                ``type`` and ``level``. ``context`` is a list of paragraphs, each a\n",
    "                list of sentence records. In a record, fields 0, 3 and 4 are\n",
    "                whitespace-separated integer strings, fields 1 and 2 are\n",
    "                bracketed/comma-separated integer strings and field 5 is a bool.\n",
    "\n",
    "        Returns:\n",
    "            dict with ``_id``, ``qu_graph``, ``context``, ``type`` and ``level``.\n",
    "            ``qu_graph`` and record fields 0-4 are int tensors; field 5 is the\n",
    "            0/1 tensor produced by ``true_1_false_0``.\n",
    "        \"\"\"\n",
    "        # Parsed once, before the paragraph loop, so it is defined even when\n",
    "        # case[\"context\"] is empty and is not re-parsed for every paragraph.\n",
    "        qu_graph = self.float_to_int(self.str_to_int(case['srl_question_to_graph']))\n",
    "\n",
    "        graph_context = []\n",
    "        for paragh in case[\"context\"]:\n",
    "            sub_context = []\n",
    "            for sent_value in paragh:\n",
    "                sub_context.append([\n",
    "                    self.float_to_int(self.str_to_int(sent_value[0])),\n",
    "                    self.float_to_int(self.str_ar_to_int(sent_value[1])),\n",
    "                    self.float_to_int(self.str_ar_to_int(sent_value[2])),\n",
    "                    self.float_to_int(self.str_to_int(sent_value[3])),\n",
    "                    self.float_to_int(self.str_to_int(sent_value[4])),\n",
    "                    self.true_1_false_0(sent_value[5]),\n",
    "                ])\n",
    "            graph_context.append(sub_context)\n",
    "\n",
    "        return {\n",
    "            \"_id\": case[\"_id\"],\n",
    "            \"qu_graph\": qu_graph,\n",
    "            \"context\": graph_context,\n",
    "            \"type\": case[\"type\"],\n",
    "            \"level\": case[\"level\"],\n",
    "        }\n",
    "    \n",
    "    \n",
    "    def example_to_sentlist(self,case,example_id):\n",
    "        \"\"\"Expand one refined example into a list of per-sentence samples.\n",
    "\n",
    "        The two paragraphs in ``case['context']`` are concatenated and every\n",
    "        sentence in turn becomes the \"core\" sentence of one sample.\n",
    "\n",
    "        Args:\n",
    "            case: dict returned by ``graph_refine_data``; its first two\n",
    "                paragraphs are read and left unmodified.\n",
    "            example_id: index of the example, copied into every sample.\n",
    "\n",
    "        Returns:\n",
    "            A list with one 12-item list per sentence, laid out as\n",
    "            [core_article_infor, core_article_power, core_qu_infor,\n",
    "            core_qu_power, qu_graph, core_graph, label, example_id,\n",
    "            article_sum, qu_article_info_sum, qu_article_power_sum, art_id].\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the example has more than ``sent_num`` sentences.\n",
    "        \"\"\"\n",
    "        verb_num = 3\n",
    "        srl_label_num = 25\n",
    "        sent_size = srl_label_num*verb_num\n",
    "        sent_num = 5\n",
    "\n",
    "        art_id = case[\"_id\"]\n",
    "        qu_graph = case['qu_graph'].float().reshape(1,1,sent_size)\n",
    "\n",
    "        # Build a new list rather than extend()-ing case['context'][0] in place;\n",
    "        # the in-place version grew the stored example on every call.\n",
    "        sentences = case['context'][0] + case['context'][1]\n",
    "        if not sentences:\n",
    "            return []\n",
    "        if len(sentences) > sent_num:\n",
    "            raise ValueError(\n",
    "                'example %r has %d sentences, expected at most %d'\n",
    "                % (art_id, len(sentences), sent_num))\n",
    "\n",
    "        # The article-level summaries do not depend on which sentence is the\n",
    "        # core one, so they are built once instead of once per core sentence.\n",
    "        article_sum = []\n",
    "        qu_article_info_sum = []\n",
    "        qu_article_power_sum = []\n",
    "        for other_sent_value in sentences:\n",
    "            # extend() with a (1,1,sent_size) tensor appends one (1,sent_size) row.\n",
    "            article_sum.extend(other_sent_value[0].float().reshape(1,1,sent_size))\n",
    "            qu_article_info_sum.extend(other_sent_value[3].float().reshape(1,1,sent_size))\n",
    "            qu_article_power_sum.extend(other_sent_value[4].float().reshape(1,1,sent_size))\n",
    "        article_graph = self.merge_list_to_graph(\n",
    "            self.padding_sum_graph(sent_num,article_sum,sent_size))\n",
    "        qu_article_info_graph = self.merge_list_to_graph(\n",
    "            self.padding_sum_graph(sent_num,qu_article_info_sum,sent_size))\n",
    "        qu_article_power_graph = self.merge_list_to_graph(\n",
    "            self.padding_sum_graph(sent_num,qu_article_power_sum,sent_size))\n",
    "\n",
    "        # NOTE: the experimental \"flow power\" features (see article_value,\n",
    "        # pad_core_qu_info, sum_flow_article_qu_power, divide_article_to_sent)\n",
    "        # are currently not part of the sample.\n",
    "        sent_graph_list = []\n",
    "        for con_value in sentences:\n",
    "            sent_graph_list.append([\n",
    "                con_value[1].float().reshape(1,sent_num,sent_size),  # core_article_infor\n",
    "                con_value[2].float().reshape(1,sent_num,sent_size),  # core_article_power\n",
    "                con_value[4].float().reshape(1,1,sent_size),         # core_qu_infor\n",
    "                con_value[3].float().reshape(1,1,sent_size),         # core_qu_power\n",
    "                qu_graph,\n",
    "                con_value[0].float().reshape(1,1,sent_size),         # core_graph\n",
    "                con_value[5].float(),                                # label (stored record is not modified)\n",
    "                example_id,\n",
    "                # clone() keeps one independent tensor per sample, as before\n",
    "                article_graph.clone(),\n",
    "                qu_article_info_graph.clone(),\n",
    "                qu_article_power_graph.clone(),\n",
    "                art_id,\n",
    "            ])\n",
    "        return sent_graph_list\n",
    "        \n",
    "        \n",
    "    \n",
    "                \n",
    "        \n",
    "#flow_power_tools_start==============================================================================       \n",
    "\n",
    "    def article_value(self,core_article_info,sent_num,sent_size):\n",
    "        \"\"\"Build a 0/1 mask marking which sentence blocks of a graph are non-empty.\n",
    "\n",
    "        Args:\n",
    "            core_article_info: tensor holding ``sent_num*sent_size`` values.\n",
    "            sent_num: number of consecutive blocks.\n",
    "            sent_size: number of values per block.\n",
    "\n",
    "        Returns:\n",
    "            1-D int tensor of length ``sent_num*sent_size``. A block is all ones\n",
    "            when the sum of its int-truncated values is >= 1, otherwise all\n",
    "            zeros. ``sent_num == 0`` yields an empty tensor, and a block with a\n",
    "            negative sum yields zeros instead of being dropped.\n",
    "        \"\"\"\n",
    "        rows = core_article_info.reshape(sent_num*sent_size).int().reshape(sent_num,sent_size)\n",
    "        occupied = (rows.sum(dim=1) >= 1).int()\n",
    "        # Repeat each per-block flag sent_size times and flatten again.\n",
    "        return occupied.unsqueeze(1).expand(sent_num,sent_size).reshape(sent_num*sent_size)\n",
    "\n",
    "    \n",
    "    def pad_core_qu_info(self,core_qu_info,sent_num,sent_size):\n",
    "        \"\"\"Tile one flattened sentence vector ``sent_num`` times.\n",
    "\n",
    "        ``core_qu_info`` is flattened to ``sent_size`` values and repeated\n",
    "        back to back; the result is a 1-D int tensor. For ``sent_num <= 1``\n",
    "        a single copy is returned.\n",
    "        \"\"\"\n",
    "        flat = core_qu_info.reshape(sent_size)\n",
    "        repeats = max(sent_num, 1)\n",
    "        tiled = flat if repeats == 1 else torch.cat([flat] * repeats, 0)\n",
    "        return tiled.int()\n",
    " \n",
    "\n",
    " #    def core_article_qu_power_graph_list():\n",
    "#         core_article_qu_power_graph_list = []\n",
    "\n",
    "#         for case in data:\n",
    "#             core_article_qu_power_graph_list.append(article_value(case['']).mual(pad_core_qu_power(case[''])))\n",
    "\n",
    "\n",
    "\n",
    "#         return core_article_qu_info_graph_list\n",
    "#\n",
    "\n",
    "    def sum_flow_article_qu_power(self,core_article_qu_power_list):\n",
    "        \"\"\"Add the given tensors element-wise and cap every entry >= 1 at 1.\n",
    "\n",
    "        Args:\n",
    "            core_article_qu_power_list: non-empty iterable of equally shaped\n",
    "                tensors.\n",
    "\n",
    "        Returns:\n",
    "            A new tensor with the element-wise sum, where entries >= 1 are\n",
    "            replaced by 1 and smaller entries are kept. None of the input\n",
    "            tensors is modified.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the iterable is empty.\n",
    "        \"\"\"\n",
    "        graphs = list(core_article_qu_power_list)\n",
    "        if not graphs:\n",
    "            raise ValueError('core_article_qu_power_list must contain at least one tensor')\n",
    "\n",
    "        # Start from a copy so a single-element list does not alias (and then\n",
    "        # overwrite) the caller's tensor.\n",
    "        total = graphs[0].clone()\n",
    "        for graph in graphs[1:]:\n",
    "            total = total + graph\n",
    "\n",
    "        # Vectorized cap; unlike a per-element loop this works for any rank.\n",
    "        return torch.where(total >= 1, torch.ones_like(total), total)\n",
    "\n",
    "# Split the flow_power qu-article power graph into one slice per sentence.\n",
    "    def divide_article_to_sent(self,sum_article_qu_power_graph,sent_num,sent_size):\n",
    "        \"\"\"Binarize an article graph and cut it into per-sentence slices.\n",
    "\n",
    "        Args:\n",
    "            sum_article_qu_power_graph: tensor with ``sent_num*sent_size`` values.\n",
    "            sent_num: number of slices to return.\n",
    "            sent_size: length of each slice.\n",
    "\n",
    "        Returns:\n",
    "            List of ``sent_num`` 1-D int tensors of length ``sent_size``, in\n",
    "            order. Entries >= 1 are replaced by 1, smaller entries are kept.\n",
    "            The input tensor is not modified.\n",
    "        \"\"\"\n",
    "        flat = sum_article_qu_power_graph.reshape(sent_num*sent_size).int()\n",
    "        # Cap by position. The previous loop wrote to a[value], i.e. it used the\n",
    "        # element's value as an index and so changed the wrong entries.\n",
    "        flat = torch.where(flat >= 1, torch.ones_like(flat), flat)\n",
    "        return [flat[i*sent_size:(i+1)*sent_size] for i in range(sent_num)]\n",
    "\n",
    "\n",
    "#flow_power_tools_end==============================================================================                    \n",
    "                \n",
    "    \n",
    "    \n",
    "    def __init__(self,file_name,transform = None):\n",
    "        \"\"\"Load a JSON file and pre-compute every per-sentence sample.\n",
    "\n",
    "        Args:\n",
    "            file_name: path of the JSON file; it must contain a list of examples.\n",
    "            transform: optional callable applied to each sample in ``__getitem__``.\n",
    "        \"\"\"\n",
    "        self.file_name = file_name\n",
    "        self.transform = transform\n",
    "        with open(file_name, \"r\", encoding='utf-8') as reader:\n",
    "            self.orig_data = json.load(reader)\n",
    "            print(\"Load ok\")\n",
    "\n",
    "        # Pass 1: parse the string fields of every example into tensors.\n",
    "        self.srl_data = []\n",
    "        self.sum_sent_graph_list = []\n",
    "        for article in tqdm(self.orig_data):\n",
    "            self.srl_data.append(self.graph_refine_data(article))\n",
    "\n",
    "        # Pass 2: flatten the examples into one sample per sentence.\n",
    "        for example_id, case in enumerate(tqdm(self.srl_data)):\n",
    "            self.sum_sent_graph_list.extend(self.example_to_sentlist(case, example_id))\n",
    "            \n",
    "            \n",
    "        #print(\"s:\",self.srl_data[0]['context'][0])\n",
    "        #print(\"sum:\",self.sum_sent_graph_list[0][6])\n",
    "        #print(\"sum:\",self.sum_sent_graph_list[0][7])\n",
    "        #print(\"sum:\",len(self.sum_sent_graph_list))\n",
    "    \n",
    "        \n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of per-sentence samples held by the dataset.\"\"\"\n",
    "        sample_count = len(self.sum_sent_graph_list)\n",
    "        return sample_count\n",
    "    \n",
    "    \n",
    "    def str_to_int(self,data_str):\n",
    "        \"\"\"Parse a whitespace-separated string of integers into a list of ints.\"\"\"\n",
    "        return list(map(int, data_str.split()))\n",
    "    \n",
    "    def str_ar_to_int(self,data_str):\n",
    "        \"\"\"Parse a bracketed, comma-separated integer string into a flat list.\n",
    "\n",
    "        Brackets and commas are treated as whitespace, so nested input such as\n",
    "        ``[[1, 0], [2]]`` becomes ``[1, 0, 2]``.\n",
    "        \"\"\"\n",
    "        separators = str.maketrans({'[': ' ', ']': ' ', ',': ' '})\n",
    "        return [int(token) for token in data_str.translate(separators).split()]\n",
    "    \n",
    "    def float_to_int(self,data):\n",
    "        \"\"\"Build a tensor from ``data`` with ``torch.Tensor`` and cast it to int.\"\"\"\n",
    "        as_float = torch.Tensor(data)\n",
    "        return as_float.int()\n",
    "    \n",
    "    def true_1_false_0(self,data):\n",
    "        \"\"\"Encode a boolean as a 0-dim tensor: 1 for True, 0 for False.\n",
    "\n",
    "        Values that compare equal to neither True nor False yield ``None``.\n",
    "        \"\"\"\n",
    "        for truth_value, encoded in ((True, 1), (False, 0)):\n",
    "            if data == truth_value:\n",
    "                return torch.tensor(encoded)\n",
    "        return None\n",
    "    def merge_list_to_graph(self,merge_sum_sent_list):\n",
    "        \"\"\"Concatenate a sequence of tensors along dimension 0.\n",
    "\n",
    "        Args:\n",
    "            merge_sum_sent_list: non-empty iterable of tensors whose shapes\n",
    "                agree in every dimension except the first.\n",
    "\n",
    "        Returns:\n",
    "            The concatenated tensor; a single-element input is returned as is.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the iterable is empty.\n",
    "        \"\"\"\n",
    "        tensors = list(merge_sum_sent_list)\n",
    "        if not tensors:\n",
    "            raise ValueError('merge_sum_sent_list must contain at least one tensor')\n",
    "        if len(tensors) == 1:\n",
    "            return tensors[0]\n",
    "        # One cat call instead of re-copying the growing result per element.\n",
    "        return torch.cat(tensors, 0)\n",
    "    \n",
    "    def padding_sum_graph(self,sent_num,sum_graph_list,sent_size):\n",
    "        \"\"\"Pad a list of sentence rows with zero rows up to ``sent_num`` entries.\n",
    "\n",
    "        Args:\n",
    "            sent_num: target number of rows.\n",
    "            sum_graph_list: list of tensors; it is left unmodified.\n",
    "            sent_size: width of each zero row that is appended.\n",
    "\n",
    "        Returns:\n",
    "            A new list with the original entries followed by as many\n",
    "            ``(1, sent_size)`` float zero tensors as are needed to reach\n",
    "            ``sent_num``. Longer inputs are returned as an unpadded copy.\n",
    "        \"\"\"\n",
    "        # Copy first: the caller's list used to be extended in place.\n",
    "        padded = list(sum_graph_list)\n",
    "        missing = sent_num - len(padded)\n",
    "        if missing > 0:\n",
    "            pad_row = torch.zeros(1, sent_size)\n",
    "            padded.extend(pad_row for _ in range(missing))\n",
    "        return padded\n",
    "            \n",
    "        \n",
    "        \n",
    "    \n",
    "    def __getitem__(self,index):\n",
    "        \"\"\"Return the sample at ``index`` as a dict, after the optional transform.\n",
    "\n",
    "        The stored entry is the 12-item list built by ``example_to_sentlist``;\n",
    "        the table below maps each output key to its position in that list.\n",
    "        \"\"\"\n",
    "        entry = self.sum_sent_graph_list[index]\n",
    "\n",
    "        # (output key, position in the stored entry), in output order\n",
    "        layout = (\n",
    "            ('core_article_infor', 0),\n",
    "            ('core_article_power', 1),\n",
    "            ('core_qu_infor', 2),\n",
    "            ('core_qu_power', 3),\n",
    "            ('qu_graph', 4),\n",
    "            ('core_graph', 5),\n",
    "            ('article_sum', 8),\n",
    "            ('qu_article_info_sum', 9),\n",
    "            ('qu_article_power_sum', 10),\n",
    "            ('label', 6),\n",
    "            ('example_id', 7),\n",
    "            ('art_id', 11),\n",
    "        )\n",
    "        sample = {key: entry[position] for key, position in layout}\n",
    "\n",
    "        if self.transform:\n",
    "            sample = self.transform(sample)\n",
    "        return sample\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "230dbfa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# JSON inputs for the train and dev splits. Each path is passed to\n",
    "# qa_train_context_dataset, which opens it and parses it with json.load;\n",
    "# relative paths resolve against the kernel's working directory.\n",
    "train_file_name = \"train_equal_sent_5_verb_3_data_v1.json\"\n",
    "dev_file_name = \"dev_equal_sent_5_verb_3_data_v1.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9c8a2570",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 1/5319 [00:00<09:26,  9.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load ok\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5319/5319 [00:08<00:00, 597.19it/s]\n",
      "100%|██████████| 5319/5319 [00:13<00:00, 386.25it/s]\n",
      " 13%|█▎        | 61/471 [00:00<00:00, 609.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load ok\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 471/471 [00:00<00:00, 606.24it/s]\n",
      "100%|██████████| 471/471 [00:01<00:00, 321.99it/s]\n"
     ]
    }
   ],
   "source": [
    "# Each constructor call parses its whole JSON file and builds every sample up front.\n",
    "train_data = qa_train_context_dataset(file_name=train_file_name, transform=None)\n",
    "dev_data = qa_train_context_dataset(file_name=dev_file_name, transform=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "40e1db17",
   "metadata": {},
   "outputs": [],
   "source": [
    "#sent 5,6,7,8\n",
    "# drop_last=True discards a trailing partial batch, so every batch holds exactly 64 samples.\n",
    "train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=False, drop_last=True)\n",
    "dev_loader = DataLoader(dataset=dev_data, batch_size=64, shuffle=False, drop_last=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5fe187dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def conv_article_Batch(in_planes, out_planes, kernel_size,stride):\n",
    "    \"\"\"Return an unpadded ``Conv2d`` followed by ``Sigmoid``.\n",
    "\n",
    "    Args:\n",
    "        in_planes: number of input channels.\n",
    "        out_planes: number of output channels.\n",
    "        kernel_size: convolution kernel size (int or pair).\n",
    "        stride: convolution stride.\n",
    "    \"\"\"\n",
    "    convolution = torch.nn.Conv2d(in_planes, out_planes, kernel_size, stride)\n",
    "    activation = torch.nn.Sigmoid()\n",
    "    return torch.nn.Sequential(convolution, activation)\n",
    "def conv_sent_Batch(in_planes, out_planes, kernel_size,stride):\n",
    "    \"\"\"Return an unpadded ``Conv2d`` followed by ``Sigmoid``.\n",
    "\n",
    "    Args:\n",
    "        in_planes: number of input channels.\n",
    "        out_planes: number of output channels.\n",
    "        kernel_size: convolution kernel size (int or pair).\n",
    "        stride: convolution stride.\n",
    "    \"\"\"\n",
    "    convolution = torch.nn.Conv2d(in_planes, out_planes, kernel_size, stride)\n",
    "    activation = torch.nn.Sigmoid()\n",
    "    return torch.nn.Sequential(convolution, activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f07de668",
   "metadata": {},
   "outputs": [],
   "source": [
    "class conv_one_sent_layer(torch.nn.Module):\n",
    "    \"\"\"Two parallel convolutions over a single sentence graph.\n",
    "\n",
    "    ``sent_power_conv`` sees the graph as one ``(1, graph_size)`` row and\n",
    "    collapses it with a ``(1, graph_size)`` kernel; ``pos_power_conv`` sees it\n",
    "    as ``(verb_num, graph_size // verb_num)`` and collapses the verb axis with\n",
    "    a ``(verb_num, 1)`` kernel. Inputs are reshaped to a single channel.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,batch_size,graph_size,in_channels,out_channels,verb_num):\n",
    "        super(conv_one_sent_layer,self).__init__()\n",
    "        self.batch_size = batch_size  # kept for callers; forward infers the batch size\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.sent_kernel_size = (1,graph_size)\n",
    "        self.pos_kernel_size = (verb_num,1)\n",
    "        self.sent_stride = 1\n",
    "        self.pos_stride = 1\n",
    "        self.graph_size = graph_size\n",
    "        self.verb_num = verb_num\n",
    "        # Width of one verb row; 25 for the usual graph_size=75, verb_num=3.\n",
    "        self.srl_len = graph_size // verb_num\n",
    "\n",
    "        self.sent_power_conv = conv_sent_Batch(self.in_channels,self.out_channels,self.sent_kernel_size,self.sent_stride)\n",
    "        self.pos_power_conv = conv_sent_Batch(self.in_channels,self.out_channels,self.pos_kernel_size,self.pos_stride)\n",
    "\n",
    "    def forward(self,\n",
    "                sent_graph,):\n",
    "        \"\"\"Return ``(sent_power, pos_i_power)`` for a batch of sentence graphs.\n",
    "\n",
    "        ``sent_graph`` must hold ``graph_size`` values per sample. The batch\n",
    "        dimension is inferred (-1), so a final partial batch also works.\n",
    "        \"\"\"\n",
    "        sent_power = self.sent_power_conv(sent_graph.reshape(-1,1,1,self.graph_size))\n",
    "        pos_i_power = self.pos_power_conv(sent_graph.reshape(-1,1,self.verb_num,self.srl_len))\n",
    "\n",
    "        return sent_power,pos_i_power\n",
    "\n",
    "    \n",
    "class conv_article_layer(torch.nn.Module):\n",
    "    \"\"\"Two parallel convolutions over an article graph of ``sent_num`` sentences.\n",
    "\n",
    "    ``article_power_conv`` sees the graph as ``(sent_num, sent_size)`` and\n",
    "    collapses it with a kernel of the same size; ``pos_power_conv`` sees it as\n",
    "    ``(sent_num*verb_num, srl_len)`` and collapses the first axis with a\n",
    "    ``(sent_num*verb_num, 1)`` kernel. Inputs are reshaped to a single channel.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,batch_size,graph_size,sent_num,in_channels,out_channels,verb_num):\n",
    "        super(conv_article_layer,self).__init__()\n",
    "        self.batch_size = batch_size  # kept for callers; forward infers the batch size\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.verb_num = verb_num\n",
    "        self.sent_num = sent_num\n",
    "        self.srl_len = 25\n",
    "        self.sent_size = self.verb_num*self.srl_len\n",
    "        self.sent_kernel_size = (self.sent_num,self.sent_size)\n",
    "        self.pos_kernel_size = (self.sent_num*self.verb_num,1)\n",
    "        self.sent_stride = 1\n",
    "        self.pos_stride = 1\n",
    "        self.graph_size = graph_size  # stored only; the layout uses sent_size\n",
    "\n",
    "        self.article_power_conv = conv_article_Batch(self.in_channels,self.out_channels,self.sent_kernel_size,self.sent_stride)\n",
    "        self.pos_power_conv = conv_article_Batch(self.in_channels,self.out_channels,self.pos_kernel_size,self.pos_stride)\n",
    "\n",
    "    def forward(self,\n",
    "                sent_graph,):\n",
    "        \"\"\"Return ``(article_power, a_pos_i_power)`` for a batch of article graphs.\n",
    "\n",
    "        ``sent_graph`` must hold ``sent_num*sent_size`` values per sample. The\n",
    "        batch dimension is inferred (-1), so a final partial batch also works.\n",
    "        \"\"\"\n",
    "        article_power = self.article_power_conv(sent_graph.reshape(-1,1,self.sent_num,self.sent_size))\n",
    "        # self.srl_len replaces the literal 25 that used to be repeated here.\n",
    "        a_pos_i_power = self.pos_power_conv(sent_graph.reshape(-1,1,self.sent_num*self.verb_num,self.srl_len))\n",
    "\n",
    "        return article_power,a_pos_i_power\n",
    "            \n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c366b096",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv_pos_power_net(\n",
      "  (mlp1): Linear(in_features=2175, out_features=200, bias=True)\n",
      "  (mlp2): Linear(in_features=200, out_features=100, bias=True)\n",
      "  (mlp3): Linear(in_features=100, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class conv_pos_power_net(torch.nn.Module):\n",
    "    def __init__(self,batch_size,in_channels,out_channels):\n",
    "        super(conv_pos_power_net,self).__init__()\n",
    "        self.batch_size = batch_size\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.verb_num = 3\n",
    "        self.srl_len = 25\n",
    "        self.graph_size = self.verb_num*self.srl_len\n",
    "        self.sent_num = 5\n",
    "        \n",
    "        \n",
    "#         self.core_power_conv = conv_one_sent_layer(self.batch_size,self.graph_size,self.in_channels,self.out_channels,self.verb_num)\n",
    "#         self.qu_power_conv = conv_one_sent_layer(self.batch_size,self.graph_size,self.in_channels,self.out_channels,self.verb_num)\n",
    "#         self.article_sum_power_conv = conv_article_layer(self.batch_size,self.graph_size,self.sent_num,self.in_channels,self.out_channels,self.verb_num)\n",
    "#         self.qu_article_power_conv = conv_article_layer(self.batch_size,self.graph_size,self.sent_num,self.in_channels,self.out_channels,self.verb_num)\n",
    "#         self.core_article_power_conv = conv_article_layer(self.batch_size,self.graph_size,self.sent_num,self.in_channels,self.out_channels,self.verb_num)\n",
    "#         self.cqi_info_conv = conv_one_sent_layer(self.batch_size,self.graph_size,self.in_channels,self.out_channels,self.verb_num)\n",
    "#         self.cqp_power_conv = conv_one_sent_layer(self.batch_size,self.graph_size,self.in_channels,self.out_channels,self.verb_num)\n",
    "#         self.core_article_info_conv = conv_article_layer(self.batch_size,self.graph_size,self.sent_num,self.in_channels,self.out_channels,self.verb_num)\n",
    "#         self.qu_article_info_conv = conv_article_layer(self.batch_size,self.graph_size,self.sent_num,self.in_channels,self.out_channels,self.verb_num)\n",
    "\n",
    "        \n",
    "        \n",
    "        \n",
    "        #         self.flow_once_qu_article_power_conv = conv_article_layer(self.batch_size,self.graph_size,self.sent_num,self.in_channels,self.out_channels)\n",
    "#         self.flow_once_qu_core_power_conv = conv_one_sent_layer(self.batch_size,self.graph_size,self.in_channels,self.out_channels)\n",
    "#         self.flow_twice_qu_article_power_conv = conv_article_layer(self.batch_size,self.graph_size,self.sent_num,self.in_channels,self.out_channels)\n",
    "#         self.flow_twice_qu_core_power_conv = conv_one_sent_layer(self.batch_size,self.graph_size,self.in_channels,self.out_channels)\n",
    "        \n",
    "        #self.mlp1 = torch.nn.Linear((9+25*9+4*9),200*self.out_channels)\n",
    "        \n",
    "        self.mlp1 = torch.nn.Linear((25*3*(4+5*5)),200*self.out_channels)\n",
    "        self.mlp2 = torch.nn.Linear(200*self.out_channels,100*self.out_channels)\n",
    "        self.mlp3 = torch.nn.Linear(100*self.out_channels,1)\n",
    "\n",
    "\n",
    "#         self.mlp1 = torch.nn.Linear((25*9),200*self.out_channels)\n",
    "#         self.mlp2 = torch.nn.Linear(200*self.out_channels,100*self.out_channels)\n",
    "#         self.mlp3 = torch.nn.Linear(100*self.out_channels,1)\n",
    "        \n",
    "#         self.mlp1 = torch.nn.Linear(9,20*self.out_channels)\n",
    "#         self.mlp2 = torch.nn.Linear(20*self.out_channels,10*self.out_channels)\n",
    "#         self.mlp3 = torch.nn.Linear(10*self.out_channels,1)\n",
    "        \n",
    "    def forward(self,\n",
    "                core_graph,\n",
    "                qu_graph,\n",
    "                core_article_power,\n",
    "                core_qu_info,\n",
    "                core_qu_power,\n",
    "                article_sum,\n",
    "                qu_article_power_sum,\n",
    "                core_article_info,\n",
    "                qu_article_info_sum,\n",
    "               ):\n",
    "        \"\"\"Score each sample from the raw graph tensors, with no convolution step.\n",
    "\n",
    "        Every input is flattened to one row per sample, the rows are joined\n",
    "        along the last axis in argument order, and the result is pushed\n",
    "        through mlp1 -> mlp2 -> mlp3 followed by a sigmoid.\n",
    "\n",
    "        Args:\n",
    "            core_graph, qu_graph, core_qu_info, core_qu_power: tensors with\n",
    "                ``graph_size`` values per sample.\n",
    "            core_article_power, article_sum, qu_article_power_sum,\n",
    "            core_article_info, qu_article_info_sum: tensors with\n",
    "                ``graph_size * 5`` values per sample.\n",
    "\n",
    "        Returns:\n",
    "            Tensor holding one sigmoid score per sample (one row per sample).\n",
    "        \"\"\"\n",
    "        sent_width = self.graph_size\n",
    "        # NOTE(review): the factor 5 looks like the number of sentences per\n",
    "        # article -- confirm against __init__.\n",
    "        article_width = self.graph_size * 5\n",
    "\n",
    "        def flatten(tensor, width):\n",
    "            # -1 infers the batch dimension from the data, so a final batch\n",
    "            # smaller than self.batch_size no longer breaks the reshape.\n",
    "            return tensor.reshape(-1, 1, 1, width)\n",
    "\n",
    "        x = torch.cat((\n",
    "            flatten(core_graph, sent_width),\n",
    "            flatten(qu_graph, sent_width),\n",
    "            flatten(core_article_power, article_width),\n",
    "            flatten(core_qu_info, sent_width),\n",
    "            flatten(core_qu_power, sent_width),\n",
    "            flatten(article_sum, article_width),\n",
    "            flatten(qu_article_power_sum, article_width),\n",
    "            flatten(core_article_info, article_width),\n",
    "            flatten(qu_article_info_sum, article_width),\n",
    "        ), 3)\n",
    "\n",
    "        # One row per sample, sized to what mlp1 expects.\n",
    "        x = x.view(-1, self.mlp1.in_features)\n",
    "        # NOTE(review): there is no activation between the three Linear layers,\n",
    "        # so together they act as a single affine map before the sigmoid.\n",
    "        x = self.mlp1(x)\n",
    "        x = self.mlp2(x)\n",
    "        x = self.mlp3(x)\n",
    "        return torch.sigmoid(x)\n",
    "        \n",
    "# Build the network with the same (64, 1, 1) constructor arguments and show its layout.\n",
    "model = conv_pos_power_net(64, 1, 1)\n",
    "print(model)\n",
    "\n",
    "# Keep the model on the GPU when one is present; otherwise leave it where it is.\n",
    "model = model.cuda() if torch.cuda.is_available() else model\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f9f0d73",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "23dd05e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv_pos_power_net(\n",
      "  (core_power_conv): conv_one_sent_layer(\n",
      "    (sent_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(1, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(3, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (qu_power_conv): conv_one_sent_layer(\n",
      "    (sent_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(1, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(3, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (article_sum_power_conv): conv_article_layer(\n",
      "    (article_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(5, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(15, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (qu_article_power_conv): conv_article_layer(\n",
      "    (article_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(5, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(15, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (core_article_power_conv): conv_article_layer(\n",
      "    (article_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(5, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(15, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (cqi_info_conv): conv_one_sent_layer(\n",
      "    (sent_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(1, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(3, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (cqp_power_conv): conv_one_sent_layer(\n",
      "    (sent_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(1, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(3, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (core_article_info_conv): conv_article_layer(\n",
      "    (article_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(5, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(15, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (qu_article_info_conv): conv_article_layer(\n",
      "    (article_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(5, 75), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "    (pos_power_conv): Sequential(\n",
      "      (0): Conv2d(1, 1, kernel_size=(15, 1), stride=(1, 1))\n",
      "      (1): Sigmoid()\n",
      "    )\n",
      "  )\n",
      "  (mlp1): Linear(in_features=225, out_features=200, bias=True)\n",
      "  (mlp2): Linear(in_features=200, out_features=100, bias=True)\n",
      "  (mlp3): Linear(in_features=100, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class conv_pos_power_net(torch.nn.Module):\n",
    "    \"\"\"Sentence scorer built from nine convolution branches and a three-layer MLP.\n",
    "\n",
    "    Every forward() input goes through its own conv_one_sent_layer or\n",
    "    conv_article_layer. Only the second output of each branch is used: the\n",
    "    nine results are concatenated on the last axis and fed through\n",
    "    mlp1 -> mlp2 -> mlp3 -> sigmoid.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,batch_size,in_channels,out_channels):\n",
    "        \"\"\"Create the conv branches and the MLP head.\n",
    "\n",
    "        Args:\n",
    "            batch_size: batch size handed to every conv branch.\n",
    "            in_channels: input channel count handed to every conv branch.\n",
    "            out_channels: output channel count handed to every conv branch;\n",
    "                it also scales the hidden widths of the MLP.\n",
    "        \"\"\"\n",
    "        super(conv_pos_power_net,self).__init__()\n",
    "        self.batch_size = batch_size\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.verb_num = 3\n",
    "        self.srl_len = 25\n",
    "        self.graph_size = self.verb_num*self.srl_len\n",
    "        self.sent_num = 5\n",
    "\n",
    "        def one_sent_branch():\n",
    "            return conv_one_sent_layer(self.batch_size,self.graph_size,self.in_channels,self.out_channels,self.verb_num)\n",
    "\n",
    "        def article_branch():\n",
    "            return conv_article_layer(self.batch_size,self.graph_size,self.sent_num,self.in_channels,self.out_channels,self.verb_num)\n",
    "\n",
    "        # Creation order is unchanged so that module registration (print(model),\n",
    "        # state_dict ordering) and weight initialisation order stay the same.\n",
    "        self.core_power_conv = one_sent_branch()\n",
    "        self.qu_power_conv = one_sent_branch()\n",
    "        self.article_sum_power_conv = article_branch()\n",
    "        self.qu_article_power_conv = article_branch()\n",
    "        self.core_article_power_conv = article_branch()\n",
    "        self.cqi_info_conv = one_sent_branch()\n",
    "        self.cqp_power_conv = one_sent_branch()\n",
    "        self.core_article_info_conv = article_branch()\n",
    "        self.qu_article_info_conv = article_branch()\n",
    "\n",
    "        # Head input width: assumes each of the nine branches contributes\n",
    "        # srl_len values per sample; 9 * 25 = 225 equals the previous literal.\n",
    "        branch_count = 9\n",
    "        self.mlp1 = torch.nn.Linear(branch_count*self.srl_len,200*self.out_channels)\n",
    "        self.mlp2 = torch.nn.Linear(200*self.out_channels,100*self.out_channels)\n",
    "        self.mlp3 = torch.nn.Linear(100*self.out_channels,1)\n",
    "\n",
    "    def forward(self,\n",
    "                core_graph,\n",
    "                qu_graph,\n",
    "                core_article_power,\n",
    "                core_qu_info,\n",
    "                core_qu_power,\n",
    "                article_sum,\n",
    "                qu_article_power_sum,\n",
    "                core_article_info,\n",
    "                qu_article_info_sum,\n",
    "               ):\n",
    "        \"\"\"Run every input through its conv branch and score the batch.\n",
    "\n",
    "        Returns:\n",
    "            Tensor holding one sigmoid score per sample (one row per sample).\n",
    "        \"\"\"\n",
    "        # (branch, input) pairs in the original call and concatenation order.\n",
    "        branches = (\n",
    "            (self.core_power_conv, core_graph),\n",
    "            (self.qu_power_conv, qu_graph),\n",
    "            (self.core_article_power_conv, core_article_power),\n",
    "            (self.article_sum_power_conv, article_sum),\n",
    "            (self.qu_article_power_conv, qu_article_power_sum),\n",
    "            (self.cqp_power_conv, core_qu_power),\n",
    "            (self.cqi_info_conv, core_qu_info),\n",
    "            (self.core_article_info_conv, core_article_info),\n",
    "            (self.qu_article_info_conv, qu_article_info_sum),\n",
    "        )\n",
    "        # Each branch returns a pair; only its second element (the value the\n",
    "        # earlier code named *_pos_i_power / *_pos_i_info) feeds the head.\n",
    "        pos_outputs = [branch(inputs)[1] for branch, inputs in branches]\n",
    "        x = torch.cat(pos_outputs, 3)\n",
    "\n",
    "        # Size the rows from mlp1 instead of self.batch_size, so this step no\n",
    "        # longer rejects a batch of a different size. Whether the conv branches\n",
    "        # themselves accept other batch sizes is not visible here.\n",
    "        x = x.view(-1, self.mlp1.in_features)\n",
    "        # NOTE(review): there is no activation between the three Linear layers,\n",
    "        # so together they act as a single affine map before the sigmoid.\n",
    "        x = self.mlp1(x)\n",
    "        x = self.mlp2(x)\n",
    "        x = self.mlp3(x)\n",
    "        return torch.sigmoid(x)\n",
    "        \n",
    "# Instantiate the scorer and print its layer structure.\n",
    "model = conv_pos_power_net(batch_size=64, in_channels=1, out_channels=1)\n",
    "print(model)\n",
    "\n",
    "# Use the GPU when CUDA is present.\n",
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0b838a7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "13fe57e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 0.01\n",
    "criterion = nn.BCELoss()\n",
    "optimizer = optim.Adam(model.parameters(),lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "c188cc8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model\n",
    "def train_code_confing(train_loader,dev_loader,epoch_num):\n",
    "    \"\"\"Train the notebook-level ``model`` for ``epoch_num`` passes over ``train_loader``.\n",
    "\n",
    "    Uses the notebook-level ``model``, ``criterion`` and ``optimizer``. Every\n",
    "    500 optimisation steps within an epoch, and once more at the end of each\n",
    "    epoch, ``eval_code_confing(dev_loader)`` is run; after each epoch the whole\n",
    "    model object is saved to './full_connection.pkl'.\n",
    "\n",
    "    Args:\n",
    "        train_loader: iterable of batch dicts holding the feature tensors\n",
    "            listed in ``feature_keys`` plus 'label'.\n",
    "        dev_loader: loader passed unchanged to ``eval_code_confing``.\n",
    "        epoch_num: number of passes over ``train_loader``.\n",
    "    \"\"\"\n",
    "    # Batch keys in the positional order expected by model.forward().\n",
    "    feature_keys = (\n",
    "        'core_graph',\n",
    "        'qu_graph',\n",
    "        'core_article_power',\n",
    "        'core_qu_infor',\n",
    "        'core_qu_power',\n",
    "        'article_sum',\n",
    "        'qu_article_power_sum',\n",
    "        'core_article_infor',\n",
    "        'qu_article_info_sum',\n",
    "    )\n",
    "    # Inputs follow the model's own device, so batches are no longer skipped\n",
    "    # (with a 'train_error' print) when CUDA is unavailable.\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    for _ in range(epoch_num):\n",
    "        # eval_code_confing() switches the model to eval mode; switch back\n",
    "        # before every training pass.\n",
    "        model.train()\n",
    "        step = 0\n",
    "\n",
    "        for data in tqdm(train_loader):\n",
    "            inputs = [data[key].to(device) for key in feature_keys]\n",
    "            label = data['label'].to(device)\n",
    "\n",
    "            out = model(*inputs).squeeze(1)\n",
    "            loss = criterion(out,label)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            step += 1\n",
    "            if step%500 == 0:\n",
    "                # The value printed after 'epoch:' is the step count within this epoch.\n",
    "                print('epoch: {}, loss: {:.4}'.format(step, loss.item()))\n",
    "                eval_code_confing(dev_loader)\n",
    "                model.train()\n",
    "\n",
    "        eval_code_confing(dev_loader)\n",
    "        # NOTE(review): this pickles the whole module object; saving\n",
    "        # model.state_dict() is the more portable format, but it would change\n",
    "        # the file that existing loading code reads, so it is left as is.\n",
    "        torch.save(model, './full_connection.pkl')\n",
    "\n",
    "\n",
    "        \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "d5a5eba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def exact_example_num(predict_label_list,dev_label_list,example_id_list):\n",
    "    \"\"\"Print and return per-example exact match (EM) and F1 of the selected sentences.\n",
    "\n",
    "    The three lists are parallel, one entry per sentence. Consecutive entries\n",
    "    with the same example id form one example, and sentences are numbered\n",
    "    0, 1, 2, ... inside each run. A sentence is gold when its label truncates\n",
    "    to 1 and predicted when its score is above 0.5.\n",
    "\n",
    "    Args:\n",
    "        predict_label_list: scores, as 0-d tensors or plain numbers.\n",
    "        dev_label_list: gold labels, as 0-d tensors or plain numbers.\n",
    "        example_id_list: integer example ids, as 0-d tensors or plain ints.\n",
    "\n",
    "    Returns:\n",
    "        (em, f1) averaged over ``max(example_id) + 1`` examples. Ids in that\n",
    "        range that never occur count as exact matches with F1 0. Empty input\n",
    "        yields (0.0, 0.0) instead of raising.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the three lists differ in length.\n",
    "    \"\"\"\n",
    "    def as_number(value):\n",
    "        # The evaluation loop collects 0-d tensors; plain numbers are accepted too.\n",
    "        if isinstance(value, torch.Tensor):\n",
    "            return value.detach().cpu().item()\n",
    "        return value\n",
    "\n",
    "    example_ids = [int(as_number(v)) for v in example_id_list]\n",
    "    scores = [float(as_number(v)) for v in predict_label_list]\n",
    "    labels = [int(as_number(v)) for v in dev_label_list]\n",
    "\n",
    "    assert len(scores) == len(labels)\n",
    "    assert len(scores) == len(example_ids)\n",
    "\n",
    "    num = len(labels)\n",
    "    n = max(example_ids) + 1 if example_ids else 0\n",
    "    print(\"num:\",num)\n",
    "    print(\"n:\",n)\n",
    "\n",
    "    em,f1 = 0.0,0.0\n",
    "    if n > 0:\n",
    "        gold_sp = [ [] for _ in range(n) ]\n",
    "        pred_sp = [ [] for _ in range(n) ]\n",
    "\n",
    "        # sd is the sentence position inside the current run of equal ids.\n",
    "        last_id,sd = -1,0\n",
    "        for example_id, label, score in zip(example_ids, labels, scores):\n",
    "            if example_id != last_id:\n",
    "                last_id, sd = example_id, 0\n",
    "            if label == 1:\n",
    "                gold_sp[last_id].append(sd)\n",
    "            if score > 0.5:\n",
    "                pred_sp[last_id].append(sd)\n",
    "            sd += 1\n",
    "\n",
    "        for cur_pred, cur_gold in zip(pred_sp, gold_sp):\n",
    "            # Sets give constant-time membership; counting still walks the\n",
    "            # lists, so the totals match the former list-based loops.\n",
    "            gold_lookup = set(cur_gold)\n",
    "            pred_lookup = set(cur_pred)\n",
    "            tp = sum(1 for e in cur_pred if e in gold_lookup)\n",
    "            fp = len(cur_pred) - tp\n",
    "            fn = sum(1 for e in cur_gold if e not in pred_lookup)\n",
    "            prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0\n",
    "            recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0\n",
    "            f1 += 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0\n",
    "            em += 1.0 if fp + fn == 0 else 0.0\n",
    "        em /= n\n",
    "        f1 /= n\n",
    "\n",
    "    print(\"em:\",em)\n",
    "    print(\"f1:\",f1)\n",
    "    return em,f1\n",
    "                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "b2d3787f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_code_confing(dev_loader):\n",
    "    \"\"\"Evaluate the notebook-level ``model`` on ``dev_loader`` and print the metrics.\n",
    "\n",
    "    Switches the model to eval mode, runs every batch without gradient\n",
    "    tracking, then prints the per-example EM/F1 from ``exact_example_num``\n",
    "    followed by the sample-averaged loss and the accuracy at threshold 0.5.\n",
    "    The model is left in eval mode; callers that continue training must call\n",
    "    ``model.train()`` again.\n",
    "\n",
    "    Args:\n",
    "        dev_loader: iterable of batch dicts holding the feature tensors listed\n",
    "            in ``feature_keys`` plus 'label' and 'example_id'.\n",
    "    \"\"\"\n",
    "    # Model evaluation\n",
    "    print(\"load_function\")\n",
    "    model.eval()\n",
    "\n",
    "    # Batch keys in the positional order of model.forward(): core_graph first,\n",
    "    # then qu_graph. The previous call passed these two swapped, so evaluation\n",
    "    # fed the network differently from training.\n",
    "    feature_keys = (\n",
    "        'core_graph',\n",
    "        'qu_graph',\n",
    "        'core_article_power',\n",
    "        'core_qu_infor',\n",
    "        'core_qu_power',\n",
    "        'article_sum',\n",
    "        'qu_article_power_sum',\n",
    "        'core_article_infor',\n",
    "        'qu_article_info_sum',\n",
    "    )\n",
    "    # Inputs follow the model's own device, so batches are no longer skipped\n",
    "    # when CUDA is unavailable.\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    eval_loss = 0.0\n",
    "    eval_acc = 0\n",
    "    sample_count = 0\n",
    "\n",
    "    predict_label_list = []\n",
    "    dev_label_list = []\n",
    "    example_id_list = []\n",
    "\n",
    "    # No gradients are needed here; without no_grad every collected output\n",
    "    # would keep its autograd graph alive.\n",
    "    with torch.no_grad():\n",
    "        for data in tqdm(dev_loader):\n",
    "            inputs = [data[key].to(device) for key in feature_keys]\n",
    "            label = data['label'].to(device)\n",
    "            example_id = data['example_id']\n",
    "\n",
    "            out = model(*inputs).squeeze(1)\n",
    "            loss = criterion(out,label)\n",
    "\n",
    "            predict_label_list.extend(out)\n",
    "            dev_label_list.extend(label)\n",
    "            example_id_list.extend(example_id)\n",
    "\n",
    "            batch_samples = label.numel()\n",
    "            sample_count += batch_samples\n",
    "            # criterion is created as nn.BCELoss() in this notebook, i.e. it\n",
    "            # returns the batch mean; weight it by the batch size so the final\n",
    "            # figure is a true per-sample mean.\n",
    "            eval_loss += loss.item() * batch_samples\n",
    "            pred = out > 0.5\n",
    "            num_correct = (pred.long().reshape(1,-1) == label.long()).sum()\n",
    "            eval_acc += num_correct.item()\n",
    "\n",
    "    print(\"predict_label_list:\",len(predict_label_list),\"dev_label_list:\",len(dev_label_list),\"example_id_list:\",len(example_id_list))\n",
    "    exact_example_num(predict_label_list,dev_label_list,example_id_list)\n",
    "\n",
    "    # Divide by the samples actually seen rather than len(dev_loader)*64, which\n",
    "    # was wrong for any batch that does not hold exactly 64 samples.\n",
    "    sent_num = max(sample_count, 1)\n",
    "    print('Test Loss: {:.6f}, Acc: {:.6f}'.format(\n",
    "        eval_loss / sent_num,\n",
    "        eval_acc / sent_num,\n",
    "    ))\n",
    "    #return dev_merge_list_to_graph(predict_label_list),dev_merge_list_to_graph(dev_label_list),dev_merge_list_to_graph(example_id_list)\n",
    "\n",
    "\n",
    "#     print('Test Loss: {:.6f}, Acc: {:.6f}'.format(\n",
    "#         eval_loss / (len(dev_s5_v3_loader)),\n",
    "#         eval_acc / (len(dev_s5_v3_loader))\n",
    "#     ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3957aed1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01f512bd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "ce4b0331",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 88.59it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 188.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 90.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.19739696312364424\n",
      "f1: 0.7102520400785055\n",
      "Test Loss: 0.009138, Acc: 0.710938\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 90.37it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 182.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 89.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.16052060737527116\n",
      "f1: 0.7316479013875996\n",
      "Test Loss: 0.009686, Acc: 0.695747\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 91.32it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 196.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 8/415 [00:00<00:05, 70.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.21475054229934923\n",
      "f1: 0.7391695072823071\n",
      "Test Loss: 0.009321, Acc: 0.717882\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 88.54it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 188.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 90.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.22559652928416485\n",
      "f1: 0.7369469407430377\n",
      "Test Loss: 0.009136, Acc: 0.720052\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 91.11it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 194.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 92.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2386117136659436\n",
      "f1: 0.7365544193093015\n",
      "Test Loss: 0.009019, Acc: 0.724392\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 87.65it/s]\n",
      " 47%|████▋     | 17/36 [00:00<00:00, 167.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 167.51it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.25813449023861174\n",
      "f1: 0.7229211858279119\n",
      "Test Loss: 0.008700, Acc: 0.730469\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 86.05it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 184.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 88.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.25379609544468545\n",
      "f1: 0.7236442516268994\n",
      "Test Loss: 0.008795, Acc: 0.727865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 87.88it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 184.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 88.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.26247288503253796\n",
      "f1: 0.7239334779464945\n",
      "Test Loss: 0.008660, Acc: 0.733073\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 87.67it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 182.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 88.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2603036876355748\n",
      "f1: 0.7226061357297813\n",
      "Test Loss: 0.008641, Acc: 0.736111\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 86.98it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 183.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 87.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23644251626898047\n",
      "f1: 0.7298849981062576\n",
      "Test Loss: 0.008815, Acc: 0.723090\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 85.64it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 173.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 173.27it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.25379609544468545\n",
      "f1: 0.7271889956271747\n",
      "Test Loss: 0.008843, Acc: 0.728299\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 87.53it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 185.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 92.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2472885032537961\n",
      "f1: 0.7283011396894272\n",
      "Test Loss: 0.008881, Acc: 0.727431\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 92.31it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 199.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 66.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2559652928416486\n",
      "f1: 0.7289742795165801\n",
      "Test Loss: 0.008774, Acc: 0.734375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 87.04it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 183.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 87.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.26898047722342733\n",
      "f1: 0.7238869951451309\n",
      "Test Loss: 0.008723, Acc: 0.736111\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 87.70it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 183.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24295010845986983\n",
      "f1: 0.723754433081983\n",
      "Test Loss: 0.008792, Acc: 0.723958\n"
     ]
    }
   ],
   "source": [
    "#l\n",
    "epoch_num = 15\n",
    "train_code_confing(train_loader,dev_loader,epoch_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85cff9ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "#l\n",
    "l = [0.710938,0.695747,0.717882,0.720052,0.724392,0.730469,0.727865,0.733073]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "d85a4e47",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 87.38it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 185.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 95.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2386117136659436\n",
      "f1: 0.7168801432358929\n",
      "Test Loss: 0.008890, Acc: 0.727431\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 89.37it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 183.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 88.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23210412147505424\n",
      "f1: 0.7221223702785541\n",
      "Test Loss: 0.008918, Acc: 0.729601\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 85.08it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 177.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 176.47it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24511930585683298\n",
      "f1: 0.7218486382260801\n",
      "Test Loss: 0.008890, Acc: 0.732205\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 85.71it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 173.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 174.32it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24078091106290672\n",
      "f1: 0.7208105223289623\n",
      "Test Loss: 0.008864, Acc: 0.729601\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 86.70it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 182.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 8/415 [00:00<00:05, 70.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23644251626898047\n",
      "f1: 0.7183951382432961\n",
      "Test Loss: 0.008794, Acc: 0.729167\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 90.60it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 192.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 88.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23427331887201736\n",
      "f1: 0.7172244602830302\n",
      "Test Loss: 0.008844, Acc: 0.727865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 89.91it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 189.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 92.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23644251626898047\n",
      "f1: 0.7203835691905122\n",
      "Test Loss: 0.008834, Acc: 0.729167\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 92.32it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 196.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 89.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2299349240780911\n",
      "f1: 0.7219657060221069\n",
      "Test Loss: 0.008902, Acc: 0.728733\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 92.19it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 193.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 92.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23427331887201736\n",
      "f1: 0.724336328891645\n",
      "Test Loss: 0.008953, Acc: 0.730903\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 92.09it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 196.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 91.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23427331887201736\n",
      "f1: 0.725472575147197\n",
      "Test Loss: 0.008944, Acc: 0.731771\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 89.72it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 187.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 90.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.227765726681128\n",
      "f1: 0.7206900113624638\n",
      "Test Loss: 0.008964, Acc: 0.729601\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 92.22it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 196.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 92.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.227765726681128\n",
      "f1: 0.7218865131012651\n",
      "Test Loss: 0.008972, Acc: 0.728733\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 92.34it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 196.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 92.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23210412147505424\n",
      "f1: 0.718214371793549\n",
      "Test Loss: 0.008819, Acc: 0.728733\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 92.38it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 194.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 88.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2299349240780911\n",
      "f1: 0.717028199566162\n",
      "Test Loss: 0.008960, Acc: 0.727431\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 91.96it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 194.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23210412147505424\n",
      "f1: 0.7193557828048086\n",
      "Test Loss: 0.008998, Acc: 0.730035\n"
     ]
    }
   ],
   "source": [
    "#g\n",
    "epoch_num = 15\n",
    "train_code_confing(train_loader,dev_loader,epoch_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f06ea753",
   "metadata": {},
   "outputs": [],
   "source": [
    "g = [0.727431,0.729601,0.732205,0.729601,0.729167,0.727865,0.729167,0.728733]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "84f4f29b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 75.23it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 181.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 8/415 [00:00<00:05, 77.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24295010845986983\n",
      "f1: 0.6987552938746012\n",
      "Test Loss: 0.008912, Acc: 0.721354\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 75.66it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 182.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 8/415 [00:00<00:05, 75.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24511930585683298\n",
      "f1: 0.7134541886168801\n",
      "Test Loss: 0.009093, Acc: 0.721354\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 75.30it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 182.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 8/415 [00:00<00:05, 74.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.19956616052060738\n",
      "f1: 0.7249578211617275\n",
      "Test Loss: 0.010354, Acc: 0.711806\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 76.81it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 187.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 61.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.18872017353579176\n",
      "f1: 0.7170023757876267\n",
      "Test Loss: 0.010956, Acc: 0.700955\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 65.74it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 179.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 180.34it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.17787418655097614\n",
      "f1: 0.6954601797334999\n",
      "Test Loss: 0.012358, Acc: 0.682292\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 65.80it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 196.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 66.59it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.22342733188720174\n",
      "f1: 0.6935561064628322\n",
      "Test Loss: 0.014679, Acc: 0.712674\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 66.80it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 192.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 67.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.19522776572668113\n",
      "f1: 0.6840805013256218\n",
      "Test Loss: 0.013017, Acc: 0.698785\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 66.38it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 197.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 65.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2017353579175705\n",
      "f1: 0.6571256412905011\n",
      "Test Loss: 0.011696, Acc: 0.697049\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 65.49it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 185.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 65.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.17787418655097614\n",
      "f1: 0.6385497365974596\n",
      "Test Loss: 0.012243, Acc: 0.689236\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 66.55it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 195.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 67.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.18655097613882862\n",
      "f1: 0.6920720999896717\n",
      "Test Loss: 0.011330, Acc: 0.704861\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 66.93it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 195.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 66.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.16919739696312364\n",
      "f1: 0.6907275419206016\n",
      "Test Loss: 0.013281, Acc: 0.678385\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 66.88it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 193.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 66.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1735357917570499\n",
      "f1: 0.6332524188272561\n",
      "Test Loss: 0.014197, Acc: 0.677083\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 66.80it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 196.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 66.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.15835140997830802\n",
      "f1: 0.6262249078951905\n",
      "Test Loss: 0.013357, Acc: 0.680990\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 63.78it/s]\n",
      " 42%|████▏     | 15/36 [00:00<00:00, 141.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 137.67it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.19088937093275488\n",
      "f1: 0.695912956650485\n",
      "Test Loss: 0.011383, Acc: 0.704427\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 65.21it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 183.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.18655097613882862\n",
      "f1: 0.6841820748545269\n",
      "Test Loss: 0.014102, Acc: 0.688802\n"
     ]
    }
   ],
   "source": [
    "# Run label: l+g+m(2)\n",
    "epoch_num = 15  # number of epochs handed to train_code_confing\n",
    "train_code_confing(\n",
    "    train_loader,\n",
    "    dev_loader,\n",
    "    epoch_num,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6c9f100",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Acc values noted for the l+g+m configuration, kept in a list so order is preserved.\n",
    "# Presumably the first 8 per-epoch Acc values of the l+g+m(2) run above -- TODO confirm.\n",
    "# (`l+g+m` is an expression, not a valid assignment target, hence the underscore name.)\n",
    "acc_l_g_m = [0.721354, 0.721354, 0.711806, 0.700955, 0.682292, 0.712674, 0.698785, 0.697049]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29a4a8c4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "2d448d60",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:05<00:00, 82.56it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 176.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 175.73it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24078091106290672\n",
      "f1: 0.6913180456564418\n",
      "Test Loss: 0.008735, Acc: 0.724392\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 85.08it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 177.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 177.62it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.23427331887201736\n",
      "f1: 0.6784629687015813\n",
      "Test Loss: 0.008952, Acc: 0.723958\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 85.65it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 172.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 170.77it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.21691973969631237\n",
      "f1: 0.6920531625520794\n",
      "Test Loss: 0.009747, Acc: 0.719618\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 89.41it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 194.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 91.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1561822125813449\n",
      "f1: 0.7075612023551304\n",
      "Test Loss: 0.012228, Acc: 0.677951\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 89.50it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 179.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 180.02it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.18655097613882862\n",
      "f1: 0.6578814172089669\n",
      "Test Loss: 0.010717, Acc: 0.702257\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 89.14it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 197.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 90.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.18655097613882862\n",
      "f1: 0.6968546637744045\n",
      "Test Loss: 0.011735, Acc: 0.687500\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 90.83it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 194.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 90.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.16702819956616052\n",
      "f1: 0.6683193884929257\n",
      "Test Loss: 0.013107, Acc: 0.694878\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 89.88it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 191.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 87.59it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.14316702819956617\n",
      "f1: 0.6909186378817624\n",
      "Test Loss: 0.016505, Acc: 0.656684\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 88.34it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 191.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 91.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.18655097613882862\n",
      "f1: 0.6821970870777816\n",
      "Test Loss: 0.012204, Acc: 0.698351\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 89.69it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 197.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 91.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.15835140997830802\n",
      "f1: 0.6409083083703482\n",
      "Test Loss: 0.012834, Acc: 0.684028\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 90.85it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 197.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 91.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.16919739696312364\n",
      "f1: 0.6659608167200369\n",
      "Test Loss: 0.012302, Acc: 0.689670\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 90.89it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 193.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 91.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.175704989154013\n",
      "f1: 0.6592586853975154\n",
      "Test Loss: 0.013023, Acc: 0.693576\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 91.11it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 196.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 91.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.17136659436008678\n",
      "f1: 0.6814275384774312\n",
      "Test Loss: 0.012068, Acc: 0.687934\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 90.74it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 193.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 89.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.20607375271149675\n",
      "f1: 0.6647367696174645\n",
      "Test Loss: 0.013595, Acc: 0.707031\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 90.10it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 193.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.10629067245119306\n",
      "f1: 0.4868763557483731\n",
      "Test Loss: 0.015071, Acc: 0.659288\n"
     ]
    }
   ],
   "source": [
    "# Run label: l+m\n",
    "epoch_num = 15  # number of epochs handed to train_code_confing\n",
    "train_code_confing(\n",
    "    train_loader,\n",
    "    dev_loader,\n",
    "    epoch_num,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d9f7b7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Acc printed by epochs 1-8 of the `#l+m` run above, in epoch order.\n",
    "# A list (not a set) keeps the epoch order and any repeated values;\n",
    "# `l+m` is an expression, not a valid assignment target, hence the underscore name.\n",
    "acc_l_m = [0.724392, 0.723958, 0.719618, 0.677951, 0.702257, 0.687500, 0.694878, 0.656684]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "1ebc7208",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 83.57it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 179.11it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 177.64it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.22125813449023862\n",
      "f1: 0.6898357607685172\n",
      "Test Loss: 0.008884, Acc: 0.721354\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 88.26it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 185.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 90.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.21908893709327548\n",
      "f1: 0.6878628240884223\n",
      "Test Loss: 0.009525, Acc: 0.718316\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 88.86it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 183.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 88.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.21908893709327548\n",
      "f1: 0.6877061598319741\n",
      "Test Loss: 0.010232, Acc: 0.704861\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 89.44it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 194.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 90.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.19088937093275488\n",
      "f1: 0.7045690872155094\n",
      "Test Loss: 0.011514, Acc: 0.695312\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 89.70it/s]\n",
      " 53%|█████▎    | 19/36 [00:00<00:00, 180.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 179.67it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.13882863340563992\n",
      "f1: 0.6943979616430817\n",
      "Test Loss: 0.011822, Acc: 0.655382\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 88.84it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 194.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 91.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.19739696312364424\n",
      "f1: 0.6899717660021359\n",
      "Test Loss: 0.010580, Acc: 0.700087\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 91.03it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 195.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 91.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.15835140997830802\n",
      "f1: 0.6812725958062193\n",
      "Test Loss: 0.011947, Acc: 0.688802\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 90.78it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 195.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 90.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1279826464208243\n",
      "f1: 0.6786420135660937\n",
      "Test Loss: 0.014691, Acc: 0.652778\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 88.23it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 177.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 175.77it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1475054229934924\n",
      "f1: 0.6347278173742388\n",
      "Test Loss: 0.013829, Acc: 0.685330\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 84.24it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 178.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 177.14it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.13882863340563992\n",
      "f1: 0.6191182040422827\n",
      "Test Loss: 0.011697, Acc: 0.675347\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 86.72it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 184.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 86.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.16919739696312364\n",
      "f1: 0.6746875322797237\n",
      "Test Loss: 0.012599, Acc: 0.691840\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 86.88it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 184.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 9/415 [00:00<00:04, 86.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1561822125813449\n",
      "f1: 0.6542643666287932\n",
      "Test Loss: 0.014272, Acc: 0.683160\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 86.75it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 182.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 10/415 [00:00<00:04, 90.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.17787418655097614\n",
      "f1: 0.6649932858175822\n",
      "Test Loss: 0.013739, Acc: 0.681858\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 88.99it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 175.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 175.08it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.16052060737527116\n",
      "f1: 0.5975226388458496\n",
      "Test Loss: 0.012443, Acc: 0.677517\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:04<00:00, 87.65it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 183.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1475054229934924\n",
      "f1: 0.5488207141135558\n",
      "Test Loss: 0.013758, Acc: 0.658420\n"
     ]
    }
   ],
   "source": [
    "# Run label: g+m\n",
    "epoch_num = 15  # number of epochs handed to train_code_confing\n",
    "train_code_confing(\n",
    "    train_loader,\n",
    "    dev_loader,\n",
    "    epoch_num,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76f2d8b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Acc printed by epochs 1-8 of the `#g+m` run above, in epoch order.\n",
    "# A list (not a set) keeps the epoch order and any repeated values;\n",
    "# `g+m` is an expression, not a valid assignment target, hence the underscore name.\n",
    "acc_g_m = [0.721354, 0.718316, 0.704861, 0.695312, 0.655382, 0.700087, 0.688802, 0.652778]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "908239e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:09<00:00, 45.01it/s]\n",
      " 31%|███       | 11/36 [00:00<00:00, 102.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 103.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 4/415 [00:00<00:10, 39.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "em: 0.004338394793926247\n",
      "f1: 0.6373222463244179\n",
      "Test Loss: 0.225799, Acc: 0.476997\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:09<00:00, 45.66it/s]\n",
      " 36%|███▌      | 13/36 [00:00<00:00, 128.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 129.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 62.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.004338394793926247\n",
      "f1: 0.6373222463244179\n",
      "Test Loss: 0.225799, Acc: 0.476997\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 61.39it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 182.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 60.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.004338394793926247\n",
      "f1: 0.6373222463244179\n",
      "Test Loss: 0.225799, Acc: 0.476997\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 61.96it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 171.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 172.60it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.004338394793926247\n",
      "f1: 0.6373222463244179\n",
      "Test Loss: 0.225799, Acc: 0.476997\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 62.78it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 185.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 66.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.004338394793926247\n",
      "f1: 0.6373222463244179\n",
      "Test Loss: 0.225799, Acc: 0.476997\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 65.57it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 185.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 64.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.004338394793926247\n",
      "f1: 0.6373222463244179\n",
      "Test Loss: 0.225799, Acc: 0.476997\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 66.12it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 193.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 67.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.004338394793926247\n",
      "f1: 0.6373222463244179\n",
      "Test Loss: 0.225799, Acc: 0.476997\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 54%|█████▍    | 225/415 [00:03<00:02, 66.50it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-42-eca051d78d5b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m#g+l+m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mepoch_num\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m15\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtrain_code_confing\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdev_loader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepoch_num\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-39-d5b71fe85cb0>\u001b[0m in \u001b[0;36mtrain_code_confing\u001b[0;34m(train_loader, dev_loader, epoch_num)\u001b[0m\n\u001b[1;32m     80\u001b[0m                 \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     81\u001b[0m                 \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m                 \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     83\u001b[0m                 \u001b[0mepoch\u001b[0m\u001b[0;34m+=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     84\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m%\u001b[0m\u001b[0;36m500\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/learn_py3/lib/python3.7/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m     99\u001b[0m                     \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmax_exp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'eps'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    100\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 101\u001b[0;31m                     \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'eps'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    102\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    103\u001b[0m                 \u001b[0mbias_correction1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mbeta1\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'step'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Run label: g+l+m (presumably a model variant configured in an earlier cell; confirm).\n",
    "# NOTE(review): in the saved output every evaluation prints the same numbers\n",
    "# (em 0.0043, f1 0.6373, Test Loss 0.225799, Acc 0.476997), so the predictions did not change\n",
    "# between passes; the run was then stopped by hand (KeyboardInterrupt). Investigate before reuse.\n",
    "epoch_num = 15\n",
    "train_code_confing(\n",
    "    train_loader=train_loader,\n",
    "    dev_loader=dev_loader,\n",
    "    epoch_num=epoch_num,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "539b6414",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "ae297bd4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:08<00:00, 47.76it/s]\n",
      " 36%|███▌      | 13/36 [00:00<00:00, 129.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 129.46it/s]\n",
      "/home/baiyq/.conda/envs/learn_py3/lib/python3.7/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type conv_one_sent_layer. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1822125813449024\n",
      "f1: 0.7391729504527791\n",
      "Test Loss: 0.009602, Acc: 0.703993\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/baiyq/.conda/envs/learn_py3/lib/python3.7/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type conv_article_layer. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n",
      "100%|██████████| 415/415 [00:08<00:00, 48.03it/s]\n",
      " 36%|███▌      | 13/36 [00:00<00:00, 126.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 126.53it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24945770065075923\n",
      "f1: 0.7339066212168182\n",
      "Test Loss: 0.008727, Acc: 0.734809\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:08<00:00, 48.75it/s]\n",
      " 39%|███▉      | 14/36 [00:00<00:00, 133.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 132.66it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2668112798264642\n",
      "f1: 0.7337861102503199\n",
      "Test Loss: 0.008695, Acc: 0.742622\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:08<00:00, 48.68it/s]\n",
      " 39%|███▉      | 14/36 [00:00<00:00, 130.59it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 58.79it/s] \n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.25813449023861174\n",
      "f1: 0.7227696863271715\n",
      "Test Loss: 0.008697, Acc: 0.742188\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 62.36it/s]\n",
      " 50%|█████     | 18/36 [00:00<00:00, 174.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 175.30it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.25162689804772237\n",
      "f1: 0.7052301759460118\n",
      "Test Loss: 0.008700, Acc: 0.735677\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 61.38it/s]\n",
      " 39%|███▉      | 14/36 [00:00<00:00, 130.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 149.10it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24511930585683298\n",
      "f1: 0.7186946940743046\n",
      "Test Loss: 0.008663, Acc: 0.740885\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 64.89it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 186.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 7/415 [00:00<00:06, 65.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24295010845986983\n",
      "f1: 0.7228626519298978\n",
      "Test Loss: 0.008633, Acc: 0.740451\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:06<00:00, 63.28it/s]\n",
      " 39%|███▉      | 14/36 [00:00<00:00, 137.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 136.39it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.24295010845986983\n",
      "f1: 0.7145267362187113\n",
      "Test Loss: 0.008702, Acc: 0.739583\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:08<00:00, 50.39it/s]\n",
      " 39%|███▉      | 14/36 [00:00<00:00, 137.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 136.88it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2386117136659436\n",
      "f1: 0.7169076885996635\n",
      "Test Loss: 0.008660, Acc: 0.738281\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:08<00:00, 50.27it/s]\n",
      " 39%|███▉      | 14/36 [00:00<00:00, 133.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 36/36 [00:00<00:00, 133.20it/s]\n",
      "  0%|          | 0/415 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.22559652928416485\n",
      "f1: 0.7033605343800584\n",
      "Test Loss: 0.008731, Acc: 0.732639\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 95%|█████████▍| 393/415 [00:07<00:00, 50.09it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-30-15c82d7f4790>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m#g+l\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mepoch_num\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m15\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtrain_code_confing\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdev_loader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepoch_num\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-27-d5b71fe85cb0>\u001b[0m in \u001b[0;36mtrain_code_confing\u001b[0;34m(train_loader, dev_loader, epoch_num)\u001b[0m\n\u001b[1;32m     80\u001b[0m                 \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     81\u001b[0m                 \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m                 \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     83\u001b[0m                 \u001b[0mepoch\u001b[0m\u001b[0;34m+=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     84\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m%\u001b[0m\u001b[0;36m500\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.conda/envs/learn_py3/lib/python3.7/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m     91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     92\u001b[0m                 \u001b[0;31m# Decay the first and second moment running average coefficient\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m                 \u001b[0mexp_avg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbeta1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mbeta1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     94\u001b[0m                 \u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbeta2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maddcmul_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mbeta2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mamsgrad\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Run label: g+l (presumably a model variant configured in an earlier cell; confirm).\n",
    "# Per the saved output, ten evaluations completed (best Acc 0.742622 on the third)\n",
    "# before the run was stopped by hand (KeyboardInterrupt).\n",
    "epoch_num = 15\n",
    "train_code_confing(\n",
    "    train_loader=train_loader,\n",
    "    dev_loader=dev_loader,\n",
    "    epoch_num=epoch_num,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41d9589b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy noted by hand for the g+l run: these are the first eight Acc values\n",
    "# printed in that run's saved output, in order.\n",
    "# The original `g+l = [...]` was a SyntaxError (an expression cannot be an assignment target),\n",
    "# so the list is bound to a valid identifier instead.\n",
    "acc_g_l = [0.703993, 0.734809, 0.742622, 0.742188, 0.735677, 0.740885, 0.740451, 0.739583]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "682e9564",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 227.03it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 332.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 24/415 [00:00<00:01, 237.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.14316702819956617\n",
      "f1: 0.6329855731157258\n",
      "Test Loss: 0.014763, Acc: 0.666667\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 236.20it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 348.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 23/415 [00:00<00:01, 228.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1843817787418655\n",
      "f1: 0.6620252728712609\n",
      "Test Loss: 0.011352, Acc: 0.692274\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 222.48it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 318.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 25/415 [00:00<00:01, 242.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.21258134490238612\n",
      "f1: 0.6770770925868552\n",
      "Test Loss: 0.011710, Acc: 0.713976\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 236.68it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 343.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 25/415 [00:00<00:01, 246.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1735357917570499\n",
      "f1: 0.6510140137038197\n",
      "Test Loss: 0.012344, Acc: 0.687934\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 240.55it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 354.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 25/415 [00:00<00:01, 246.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "em: 0.21475054229934923\n",
      "f1: 0.69251971215095\n",
      "Test Loss: 0.012314, Acc: 0.704861\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 237.37it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 347.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 25/415 [00:00<00:01, 241.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.2017353579175705\n",
      "f1: 0.6937110491340436\n",
      "Test Loss: 0.016319, Acc: 0.691840\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 221.70it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 317.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 24/415 [00:00<00:01, 232.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.12364425162689804\n",
      "f1: 0.6528991495368938\n",
      "Test Loss: 0.012197, Acc: 0.671875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 236.84it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 334.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 25/415 [00:00<00:01, 244.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1475054229934924\n",
      "f1: 0.5947371139345115\n",
      "Test Loss: 0.011989, Acc: 0.675781\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 233.81it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 331.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 25/415 [00:00<00:01, 242.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1301518438177874\n",
      "f1: 0.5326877388699514\n",
      "Test Loss: 0.015994, Acc: 0.662326\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 238.27it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 353.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 24/415 [00:00<00:01, 238.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1193058568329718\n",
      "f1: 0.5214526736218712\n",
      "Test Loss: 0.013833, Acc: 0.654948\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 233.09it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 347.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▌         | 22/415 [00:00<00:01, 219.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1301518438177874\n",
      "f1: 0.6101349722824784\n",
      "Test Loss: 0.013140, Acc: 0.661892\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 226.67it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 319.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 25/415 [00:00<00:01, 243.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.1540130151843818\n",
      "f1: 0.674995696036912\n",
      "Test Loss: 0.013810, Acc: 0.674913\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 239.35it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 353.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 25/415 [00:00<00:01, 245.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num: 2304\n",
      "n: 461\n",
      "em: 0.14316702819956617\n",
      "f1: 0.5890369452191582\n",
      "Test Loss: 0.014085, Acc: 0.676649\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 239.74it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 353.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  6%|▌         | 25/415 [00:00<00:01, 247.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "em: 0.21475054229934923\n",
      "f1: 0.6194401404813558\n",
      "Test Loss: 0.012737, Acc: 0.699653\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 415/415 [00:01<00:00, 239.82it/s]\n",
      "100%|██████████| 36/36 [00:00<00:00, 341.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_function\n",
      "predict_label_list: 2304 dev_label_list: 2304 example_id_list: 2304\n",
      "num: 2304\n",
      "n: 461\n",
      "em: 0.20607375271149675\n",
      "f1: 0.693204903074753\n",
      "Test Loss: 0.013531, Acc: 0.703993\n"
     ]
    }
   ],
   "source": [
    "epoch_num = 15\n",
    "train_code_confing(train_loader,dev_loader,epoch_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c1968cc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "4e5972e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "m = [0.666667,0.692274,0.713976,0.687934,0.704861,0.691840,0.671875,0.675781]\n",
    "g = [0.727431,0.729601,0.732205,0.729601,0.729167,0.727865,0.729167,0.728733]\n",
    "l = [0.710938,0.695747,0.717882,0.720052,0.724392,0.730469,0.727865,0.733073]\n",
    "gl = [0.703993,0.734809,0.742622,0.742188,0.735677,0.740885,0.740451,0.739583]\n",
    "lm= [0.724392,0.723958,0.719618,0.677951,0.702257,0.687500,0.694878,0.656684]\n",
    "gm = [0.721354,0.718316,0.704861,0.695312,0.655382,0.700087,0.688802,0.652778]\n",
    "lgm = [0.721354,0.721354,0.711806,0.700955,0.682292,0.712674,0.698785,0.697049]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "605b32fc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "learn_py3",
   "language": "python",
   "name": "learn_py3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
