{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0617 14:45:57.437368 26044 file_utils.py:39] PyTorch version 1.2.0+cu92 available.\n",
      "C:\\Users\\danil\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\h5py\\__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "from __future__ import print_function\n",
    "import os\n",
    "import math\n",
    "import numpy as np\n",
    "import json\n",
    "import time\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from make_fn_data import load_fn_data\n",
    "from neural_net import Model, NpClassDataset\n",
    "from transformers import BertTokenizer, BertModel, BertForMaskedLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a pretrained BERT encoder plus its matching tokenizer and switch it to eval mode.\n",
    "# A single BERT_NAME keeps tokenizer and model weights in sync.\n",
    "BERT_NAME = 'bert-base-uncased'\n",
    "DEVICE = 'cpu'  # set to 'cuda' to run BERT on a GPU\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained(BERT_NAME)\n",
    "bert_model = BertModel.from_pretrained(BERT_NAME)\n",
    "bert_model.eval()  # eval mode: dropout disabled, so embeddings are deterministic\n",
    "bert_model.to(DEVICE)\n",
    "print(bert_model.config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "statistics\n",
      "# lex units:  13572\n",
      "# frames:  1073\n",
      "# data points:  200751\n",
      "# lex units without data:  3271\n"
     ]
    }
   ],
   "source": [
    "# Load the raw lexical-unit records; load_fn_data() also prints summary statistics.\n",
    "# Turning these records into model inputs/labels is done in the following cell.\n",
    "data = load_fn_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# datapoints =  200750\n",
      "max labels =  1072\n",
      "1073\n"
     ]
    }
   ],
   "source": [
    "# Create datapoints from data: one (text, start, end) input and one integer frame label\n",
    "# per sentence that has index pairs. Frame ids are assigned in first-seen order;\n",
    "# frame_dict maps frame name -> id and frame_dict_rev maps id -> frame name.\n",
    "\n",
    "frame_dict = {}\n",
    "frame_dict_rev = {}\n",
    "\n",
    "inputs = []  # (text, start, end) tuples\n",
    "labels = []  # frame id for the matching entry of inputs\n",
    "\n",
    "for lu in data:\n",
    "    frame = lu[\"frame\"]\n",
    "    if frame not in frame_dict:\n",
    "        new_id = len(frame_dict)\n",
    "        frame_dict[frame] = new_id\n",
    "        frame_dict_rev[new_id] = frame\n",
    "    frame_id = frame_dict[frame]\n",
    "\n",
    "    for sentence in lu[\"sentences\"]:\n",
    "        text = sentence[\"text\"]\n",
    "        indexes = sentence[\"indexes\"]\n",
    "        # Sentences without any index pairs produce no datapoint.\n",
    "        if len(indexes) == 0:\n",
    "            continue\n",
    "        # Collapse all index pairs into one covering span (smallest start, largest end).\n",
    "        # NOTE(review): presumably character offsets of the target word(s) -- confirm in make_fn_data.\n",
    "        start = min(int(i[0]) for i in indexes)\n",
    "        end = max(int(i[1]) for i in indexes)\n",
    "        inputs.append((text, start, end))\n",
    "        labels.append(frame_id)\n",
    "\n",
    "print(\"# datapoints = \", len(labels))\n",
    "# Guard: max() of an empty list raises ValueError when no sentence had indexes.\n",
    "print(\"max labels = \", max(labels) if labels else None)\n",
    "print(len(frame_dict))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Abandonment': 664,\n",
      "'Abounding_with': 557,\n",
      "'Absorb_heat': 338,\n",
      "'Abundance': 452,\n",
      "'Abusing': 544,\n",
      "'Accompaniment': 915,\n",
      "'Accomplishment': 402,\n",
      "'Accoutrements': 387,\n",
      "'Accuracy': 609,\n",
      "'Achieving_first': 853,\n",
      "'Active_substance': 292,\n",
      "'Activity_done_state': 1027,\n",
      "'Activity_finish': 268,\n",
      "'Activity_ongoing': 269,\n",
      "'Activity_pause': 270,\n",
      "'Activity_prepare': 271,\n",
      "'Activity_ready_state': 399,\n",
      "'Activity_resume': 153,\n",
      "'Activity_start': 487,\n",
      "'Activity_stop': 261,\n",
      "'Actually_occurring_entity': 437,\n",
      "'Addiction': 556,\n",
      "'Adding_up': 1058,\n",
      "'Adducing': 451,\n",
      "'Adjacency': 960,\n",
      "'Adjusting': 326,\n",
      "'Adopt_selection': 440,\n",
      "'Aesthetics': 56,\n",
      "'Affirm_or_deny': 975,\n",
      "'Age': 16,\n",
      "'Aggregate': 152,\n",
      "'Aging': 619,\n",
      "'Agree_or_refuse_to_act': 477,\n",
      "'Agriculture': 656,\n",
      "'Aiming': 607,\n",
      "'Alliance': 433,\n",
      "'Alternatives': 632,\n",
      "'Amalgamation': 340,\n",
      "'Amassing': 382,\n",
      "'Ambient_temperature': 87,\n",
      "'Ammunition': 424,\n",
      "'Amounting_to': 206,\n",
      "'Animals': 951,\n",
      "'Annoyance': 674,\n",
      "'Appeal': 1052,\n",
      "'Appellations': 772,\n",
      "'Apply_heat': 341,\n",
      "'Appointing': 624,\n",
      "'Architectural_part': 14,\n",
      "'Armor': 829,\n",
      "'Arraignment': 1000,\n",
      "'Arranging': 314,\n",
      "'Arrest': 342,\n",
      "'Arriving': 204,\n",
      "'Arson': 1012,\n",
      "'Artifact': 83,\n",
      "'Artifact_subpart': 785,\n",
      "'Artificiality': 504,\n",
      "'Artistic_style': 790,\n",
      "'Assemble': 745,\n",
      "'Assessing': 343,\n",
      "'Assigned_location': 844,\n",
      "'Assistance': 289,\n",
      "'Atonement': 1040,\n",
      "'Attaching': 318,\n",
      "'Attack': 344,\n",
      "'Attempt': 50,\n",
      "'Attempt_means': 703,\n",
      "'Attempt_obtain_food_scenario': 998,\n",
      "'Attempt_suasion': 24,\n",
      "'Attending': 663,\n",
      "'Attention': 630,\n",
      "'Attention_getting': 722,\n",
      "'Attitude_description': 897,\n",
      "'Attributed_information': 144,\n",
      "'Authority': 784,\n",
      "'Avoiding': 346,\n",
      "'Awareness': 181,\n",
      "'Awareness_status': 826,\n",
      "'Bail_decision': 1001,\n",
      "'Be_in_agreement_on_action': 331,\n",
      "'Be_in_agreement_on_assessment': 392,\n",
      "'Be_on_alert': 907,\n",
      "'Be_subset_of': 576,\n",
      "'Be_translation_equivalent': 501,\n",
      "'Bearing_arms': 448,\n",
      "'Beat_opponent': 691,\n",
      "'Becoming': 972,\n",
      "'Becoming_a_member': 526,\n",
      "'Becoming_attached': 273,\n",
      "'Becoming_aware': 348,\n",
      "'Becoming_detached': 514,\n",
      "'Becoming_dry': 662,\n",
      "'Becoming_separated': 583,\n",
      "'Becoming_silent': 267,\n",
      "'Becoming_visible': 819,\n",
      "'Behind_the_scenes': 207,\n",
      "'Being_active': 412,\n",
      "'Being_at_risk': 420,\n",
      "'Being_attached': 470,\n",
      "'Being_awake': 512,\n",
      "'Being_born': 1072,\n",
      "'Being_detached': 347,\n",
      "'Being_dry': 208,\n",
      "'Being_employed': 117,\n",
      "'Being_in_captivity': 868,\n",
      "'Being_in_category': 577,\n",
      "'Being_in_control': 569,\n",
      "'Being_in_effect': 393,\n",
      "'Being_in_operation': 209,\n",
      "'Being_incarcerated': 870,\n",
      "'Being_located': 9,\n",
      "'Being_named': 279,\n",
      "'Being_necessary': 155,\n",
      "'Being_obligated': 724,\n",
      "'Being_obligatory': 161,\n",
      "'Being_operational': 178,\n",
      "'Being_questionable': 1056,\n",
      "'Being_relevant': 800,\n",
      "'Being_rotted': 1034,\n",
      "'Being_up_to_it': 362,\n",
      "'Being_wet': 1031,\n",
      "'Besieging': 575,\n",
      "'Beyond_compare': 120,\n",
      "'Billing': 979,\n",
      "'Biological_area': 640,\n",
      "'Biological_classification': 891,\n",
      "'Biological_entity': 809,\n",
      "'Biological_mechanisms': 914,\n",
      "'Biological_urge': 817,\n",
      "'Board_vehicle': 483,\n",
      "'Body_decoration': 1028,\n",
      "'Body_description_holistic': 109,\n",
      "'Body_description_part': 210,\n",
      "'Body_mark': 29,\n",
      "'Body_movement': 1,\n",
      "'Body_parts': 69,\n",
      "'Bond_maturation': 770,\n",
      "'Borrowing': 918,\n",
      "'Boundary': 265,\n",
      "'Bragging': 476,\n",
      "'Breaking_apart': 582,\n",
      "'Breaking_off': 581,\n",
      "'Breaking_out_captive': 862,\n",
      "'Breathing': 351,\n",
      "'Bringing': 361,\n",
      "'Building': 358,\n",
      "'Building_subparts': 22,\n",
      "'Buildings': 63,\n",
      "'Bungling': 359,\n",
      "'Burying': 911,\n",
      "'Business_closure': 839,\n",
      "'Businesses': 211,\n",
      "'Cache': 814,\n",
      "'Calendric_unit': 586,\n",
      "'Candidness': 315,\n",
      "'Capability': 41,\n",
      "'Capacity': 708,\n",
      "'Capital_stock': 878,\n",
      "'Cardinal_numbers': 71,\n",
      "'Carry_goods': 580,\n",
      "'Catastrophe': 538,\n",
      "'Catching_fire': 933,\n",
      "'Categorization': 172,\n",
      "'Causation': 177,\n",
      "'Cause_bodily_experience': 764,\n",
      "'Cause_change': 168,\n",
      "'Cause_change_of_consistency': 262,\n",
      "'Cause_change_of_phase': 212,\n",
      "'Cause_change_of_position_on_a_scale': 26,\n",
      "'Cause_change_of_strength': 559,\n",
      "'Cause_emotion': 666,\n",
      "'Cause_expansion': 263,\n",
      "'Cause_fluidic_motion': 213,\n",
      "'Cause_harm': 93,\n",
      "'Cause_impact': 214,\n",
      "'Cause_motion': 188,\n",
      "'Cause_proliferation_in_number': 611,\n",
      "'Cause_temperature_change': 366,\n",
      "'Cause_to_amalgamate': 367,\n",
      "'Cause_to_be_dry': 215,\n",
      "'Cause_to_be_included': 738,\n",
      "'Cause_to_be_sharp': 272,\n",
      "'Cause_to_be_wet': 368,\n",
      "'Cause_to_continue': 89,\n",
      "'Cause_to_end': 145,\n",
      "'Cause_to_experience': 369,\n",
      "'Cause_to_fragment': 370,\n",
      "'Cause_to_land': 849,\n",
      "'Cause_to_make_noise': 37,\n",
      "'Cause_to_make_progress': 312,\n",
      "'Cause_to_move_in_place': 216,\n",
      "'Cause_to_perceive': 713,\n",
      "'Cause_to_resume': 324,\n",
      "'Cause_to_rot': 746,\n",
      "'Cause_to_start': 372,\n",
      "'Cause_to_wake': 1024,\n",
      "'Ceasing_to_be': 697,\n",
      "'Certainty': 373,\n",
      "'Change_accessibility': 883,\n",
      "'Change_direction': 496,\n",
      "'Change_event_duration': 546,\n",
      "'Change_event_time': 23,\n",
      "'Change_of_consistency': 375,\n",
      "'Change_of_leadership': 47,\n",
      "'Change_of_phase': 940,\n",
      "'Change_of_quantity_of_possession': 733,\n",
      "'Change_of_temperature': 238,\n",
      "'Change_operational_state': 513,\n",
      "'Change_position_on_a_scale': 298,\n",
      "'Change_post-state': 994,\n",
      "'Change_posture': 486,\n",
      "'Change_resistance': 560,\n",
      "'Change_tool': 173,\n",
      "'Chaos': 893,\n",
      "'Chatting': 217,\n",
      "'Chemical-sense_description': 1021,\n",
      "'Chemical_potency': 908,\n",
      "'Choosing': 480,\n",
      "'Circumscribed_existence': 834,\n",
      "'Citing': 981,\n",
      "'Claim_ownership': 218,\n",
      "'Clemency': 103,\n",
      "'Closure': 1018,\n",
      "'Clothing': 222,\n",
      "'Clothing_parts': 1013,\n",
      "'Co-association': 995,\n",
      "'Cogitation': 183,\n",
      "'Cognitive_connection': 200,\n",
      "'Cognitive_impact': 792,\n",
      "'Coincidence': 572,\n",
      "'Collaboration': 156,\n",
      "'Colonization': 564,\n",
      "'Color': 921,\n",
      "'Color_qualities': 756,\n",
      "'Come_down_with': 997,\n",
      "'Come_into_effect': 969,\n",
      "'Come_together': 321,\n",
      "'Coming_to_be': 599,\n",
      "'Coming_to_believe': 404,\n",
      "'Coming_up_with': 309,\n",
      "'Commemorative': 643,\n",
      "'Commerce_buy': 920,\n",
      "'Commerce_collect': 977,\n",
      "'Commerce_pay': 547,\n",
      "'Commerce_scenario': 219,\n",
      "'Commerce_sell': 646,\n",
      "'Commercial_transaction': 278,\n",
      "'Commitment': 141,\n",
      "'Committing_crime': 275,\n",
      "'Commonality': 864,\n",
      "'Communicate_categorization': 565,\n",
      "'Communication': 39,\n",
      "'Communication_manner': 949,\n",
      "'Communication_means': 220,\n",
      "'Communication_noise': 1029,\n",
      "'Communication_response': 533,\n",
      "'Commutation': 104,\n",
      "'Commutative_process': 747,\n",
      "'Commutative_statement': 748,\n",
      "'Compatibility': 221,\n",
      "'Competition': 434,\n",
      "'Complaining': 473,\n",
      "'Completeness': 411,\n",
      "'Compliance': 110,\n",
      "'Concessive': 712,\n",
      "'Condition_symptom_relation': 901,\n",
      "'Conditional_occurrence': 962,\n",
      "'Conduct': 27,\n",
      "'Conferring_benefit': 905,\n",
      "'Confronting_problem': 680,\n",
      "'Connecting_architecture': 68,\n",
      "'Connectors': 578,\n",
      "'Conquering': 659,\n",
      "'Contacting': 223,\n",
      "'Containers': 224,\n",
      "'Containing': 528,\n",
      "'Contingency': 34,\n",
      "'Continued_state_of_affairs': 444,\n",
      "'Contrary_circumstances': 793,\n",
      "'Contrition': 1038,\n",
      "'Control': 568,\n",
      "'Controller_object': 855,\n",
      "'Convey_importance': 416,\n",
      "'Convoy': 827,\n",
      "'Cooking_creation': 879,\n",
      "'Corporal_punishment': 685,\n",
      "'Correctness': 395,\n",
      "'Corroding': 941,\n",
      "'Corroding_caused': 946,\n",
      "'Cotheme': 133,\n",
      "'Counterattack': 873,\n",
      "'Court_examination': 1053,\n",
      "'Craft': 21,\n",
      "'Create_physical_artwork': 502,\n",
      "'Create_representation': 500,\n",
      "'Creating': 264,\n",
      "'Criminal_investigation': 992,\n",
      "'Cure': 667,\n",
      "'Custom': 553,\n",
      "'Cutting': 485,\n",
      "'Damaging': 225,\n",
      "'Daring': 923,\n",
      "'Dead_or_alive': 479,\n",
      "'Death': 677,\n",
      "'Deception_success': 919,\n",
      "'Deciding': 888,\n",
      "'Defending': 293,\n",
      "'Degree': 650,\n",
      "'Degree_of_processing': 158,\n",
      "'Delimitation_of_diversity': 511,\n",
      "'Delivery': 300,\n",
      "'Deny_or_grant_permission': 642,\n",
      "'Departing': 31,\n",
      "'Deserving': 226,\n",
      "'Desirability': 32,\n",
      "'Desirable_event': 333,\n",
      "'Desiring': 73,\n",
      "'Destiny': 570,\n",
      "'Destroying': 290,\n",
      "'Detaching': 337,\n",
      "'Detaining': 983,\n",
      "'Detonate_explosive': 813,\n",
      "'Differentiation': 545,\n",
      "'Difficulty': 328,\n",
      "'Dimension': 723,\n",
      "'Direction': 227,\n",
      "'Directional_locative_relation': 968,\n",
      "'Discussion': 140,\n",
      "'Disembarking': 484,\n",
      "'Disgraceful_situation': 462,\n",
      "'Dispersal': 307,\n",
      "'Distant_operated_IED': 865,\n",
      "'Distinctiveness': 130,\n",
      "'Distributed_position': 339,\n",
      "'Diversity': 603,\n",
      "'Documents': 228,\n",
      "'Dodging': 283,\n",
      "'Domain': 1049,\n",
      "'Dominate_competitor': 567,\n",
      "'Dominate_situation': 566,\n",
      "'Dough_rising': 356,\n",
      "'Downing': 847,\n",
      "'Dressing': 194,\n",
      "'Drop_in_on': 636,\n",
      "'Dunking': 727,\n",
      "'Duplication': 0,\n",
      "'Duration_description': 592,\n",
      "'Duration_relation': 700,\n",
      "'Dying': 678,\n",
      "'Dynamism': 980,\n",
      "'Earnings_and_losses': 587,\n",
      "'Eclipse': 1059,\n",
      "'Economy': 431,\n",
      "'Education_teaching': 38,\n",
      "'Electricity': 410,\n",
      "'Elusive_goal': 284,\n",
      "'Emanating': 943,\n",
      "'Emergency': 934,\n",
      "'Emergency_fire': 937,\n",
      "'Emitting': 944,\n",
      "'Emotion_active': 46,\n",
      "'Emotion_directed': 142,\n",
      "'Emotion_heat': 1064,\n",
      "'Emotions_by_stimulus': 668,\n",
      "'Emotions_of_mental_activity': 671,\n",
      "'Emotions_success_or_failure': 669,\n",
      "'Emphasizing': 415,\n",
      "'Employing': 229,\n",
      "'Emptying': 7,\n",
      "'Encoding': 96,\n",
      "'Encounter': 751,\n",
      "'Endangering': 306,\n",
      "'Endeavor_failure': 837,\n",
      "'Enforcing': 573,\n",
      "'Entering_of_plea': 1005,\n",
      "'Entity': 523,\n",
      "'Entourage': 825,\n",
      "'Erasing': 788,\n",
      "'Escaping': 930,\n",
      "'Estimated_value': 398,\n",
      "'Estimating': 397,\n",
      "'Evading': 985,\n",
      "'Evaluative_comparison': 230,\n",
      "'Event': 277,\n",
      "'Event_instance': 710,\n",
      "'Eventive_affecting': 688,\n",
      "'Eventive_cognizer_affecting': 542,\n",
      "'Evidence': 176,\n",
      "'Evoking': 54,\n",
      "'Examination': 999,\n",
      "'Exchange': 162,\n",
      "'Exchange_currency': 171,\n",
      "'Exclude_member': 231,\n",
      "'Excreting': 40,\n",
      "'Execute_plan': 522,\n",
      "'Execution': 232,\n",
      "'Exemplar': 753,\n",
      "'Exemplariness': 754,\n",
      "'Exercising': 773,\n",
      "'Existence': 79,\n",
      "'Expansion': 233,\n",
      "'Expectation': 515,\n",
      "'Expected_location_of_person': 467,\n",
      "'Expend_resource': 737,\n",
      "'Expensiveness': 926,\n",
      "'Experience_bodily_harm': 235,\n",
      "'Experiencer_focus': 66,\n",
      "'Experiencer_obj': 352,\n",
      "'Experimentation': 880,\n",
      "'Expertise': 295,\n",
      "'Explaining_the_facts': 377,\n",
      "'Explosion': 815,\n",
      "'Exporting': 427,\n",
      "'Expressing_publicly': 474,\n",
      "'Extradition': 322,\n",
      "'Extreme_point': 432,\n",
      "'Extreme_value': 445,\n",
      "'Facial_expression': 1010,\n",
      "'Fairness_evaluation': 543,\n",
      "'Fall_asleep': 718,\n",
      "'Fall_for': 812,\n",
      "'Fame': 590,\n",
      "'Familiarity': 446,\n",
      "'Fastener': 447,\n",
      "'Fear': 676,\n",
      "'Feeling': 75,\n",
      "'Feigning': 506,\n",
      "'Fields': 301,\n",
      "'Fighting_activity': 1057,\n",
      "'Filling': 15,\n",
      "'Fining': 125,\n",
      "'Finish_competition': 236,\n",
      "'Finish_game': 730,\n",
      "'Fire_break': 939,\n",
      "'Fire_burning': 929,\n",
      "'Fire_going_out': 928,\n",
      "'Firefighting': 935,\n",
      "'Firing': 922,\n",
      "'Firing_point': 860,\n",
      "'First_experience': 72,\n",
      "'First_rank': 308,\n",
      "'Fleeing': 503,\n",
      "'Fluidic_motion': 942,\n",
      "'Food': 518,\n",
      "'Food_gathering': 654,\n",
      "'Foreign_or_domestic_country': 383,\n",
      "'Forging': 505,\n",
      "'Forgiveness': 1043,\n",
      "'Forgoing': 282,\n",
      "'Forming_relationships': 1065,\n",
      "'Freeing_from_confinement': 869,\n",
      "'Frequency': 443,\n",
      "'Friction': 1069,\n",
      "'Friendly_or_hostile': 871,\n",
      "'Front_for': 529,\n",
      "'Frugality': 90,\n",
      "'Fugitive': 881,\n",
      "'Fullness': 1015,\n",
      "'Function': 798,\n",
      "'Funding': 978,\n",
      "'Gathering_up': 417,\n",
      "'Gesture': 185,\n",
      "'Get_a_job': 116,\n",
      "'Getting': 380,\n",
      "'Getting_triggered': 846,\n",
      "'Getting_underway': 490,\n",
      "'Getting_up': 355,\n",
      "'Getting_vehicle_underway': 334,\n",
      "'Give_impression': 530,\n",
      "'Giving': 365,\n",
      "'Giving_birth': 349,\n",
      "'Giving_in': 389,\n",
      "'Gizmo': 33,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'Go_into_shape': 735,\n",
      "'Goal': 165,\n",
      "'Going_back_on_a_commitment': 426,\n",
      "'Gradable_artistic_quality': 906,\n",
      "'Gradable_proximity': 961,\n",
      "'Graph_shape': 796,\n",
      "'Grasp': 20,\n",
      "'Grinding': 984,\n",
      "'Grooming': 253,\n",
      "'Ground_up': 554,\n",
      "'Growing_food': 657,\n",
      "'Guest_and_host': 634,\n",
      "'Guilt_or_innocence': 113,\n",
      "'Gusto': 810,\n",
      "'Hair_configuration': 237,\n",
      "'Halt': 67,\n",
      "'Have_as_requirement': 385,\n",
      "'Have_as_translation_equivalent': 498,\n",
      "'Have_associated': 633,\n",
      "'Have_visitor_over': 635,\n",
      "'Having_or_lacking_access': 584,\n",
      "'Health_response': 1068,\n",
      "'Hearsay': 1025,\n",
      "'Heat_potential': 761,\n",
      "'Hedging': 971,\n",
      "'Heralding': 88,\n",
      "'Hiding_objects': 510,\n",
      "'Hindering': 186,\n",
      "'Hiring': 118,\n",
      "'Historic_event': 613,\n",
      "'History': 698,\n",
      "'Hit_or_miss': 608,\n",
      "'Hit_target': 1050,\n",
      "'Holding_off_on': 435,\n",
      "'Hospitality': 638,\n",
      "'Hostile_encounter': 105,\n",
      "'Hunting': 653,\n",
      "'Hunting_success_or_failure': 655,\n",
      "'Identicality': 74,\n",
      "'Identity': 990,\n",
      "'Idiosyncrasy': 35,\n",
      "'Imitating': 509,\n",
      "'Immobilization': 734,\n",
      "'Impact': 195,\n",
      "'Import_export_scenario': 1048,\n",
      "'Importance': 391,\n",
      "'Importing': 425,\n",
      "'Imposing_obligation': 386,\n",
      "'Impression': 755,\n",
      "'Imprisonment': 114,\n",
      "'Improvement_or_decline': 993,\n",
      "'Improvised_explosive_device': 841,\n",
      "'Inclination': 840,\n",
      "'Inclusion': 148,\n",
      "'Increment': 44,\n",
      "'Indicating': 695,\n",
      "'Indigenous_origin': 419,\n",
      "'Individual_history': 612,\n",
      "'Ineffability': 631,\n",
      "'Infecting': 956,\n",
      "'Information': 441,\n",
      "'Information_display': 854,\n",
      "'Infrastructure': 381,\n",
      "'Ingest_substance': 701,\n",
      "'Ingestion': 471,\n",
      "'Ingredients': 304,\n",
      "'Inherent_purpose': 902,\n",
      "'Inhibit_movement': 364,\n",
      "'Inspecting': 254,\n",
      "'Installing': 413,\n",
      "'Instance': 1006,\n",
      "'Institutionalization': 472,\n",
      "'Institutions': 454,\n",
      "'Intentional_deception': 811,\n",
      "'Intentional_traversing': 491,\n",
      "'Intentionally_act': 132,\n",
      "'Intentionally_affect': 626,\n",
      "'Intentionally_create': 154,\n",
      "'Intercepting': 403,\n",
      "'Interior_profile_relation': 166,\n",
      "'Interrupt_process': 604,\n",
      "'Intoxicants': 954,\n",
      "'Intoxication': 497,\n",
      "'Invading': 660,\n",
      "'Irregular_combatants': 828,\n",
      "'Isolated_places': 982,\n",
      "'Judgment': 164,\n",
      "'Judgment_communication': 121,\n",
      "'Judgment_direct_address': 122,\n",
      "'Judgment_of_intensity': 904,\n",
      "'Judicial_body': 457,\n",
      "'Jury_deliberation': 1045,\n",
      "'Just_found_out': 675,\n",
      "'Justifying': 1054,\n",
      "'Key': 913,\n",
      "'Kidnapping': 1002,\n",
      "'Killing': 644,\n",
      "'Kinship': 316,\n",
      "'Knot_creation': 1017,\n",
      "'Labeling': 239,\n",
      "'Labor_product': 714,\n",
      "'Launch_process': 687,\n",
      "'Law': 8,\n",
      "'Law_enforcement_agency': 856,\n",
      "'Leadership': 135,\n",
      "'Leaving_traces': 794,\n",
      "'Left_to_do': 81,\n",
      "'Legal_rulings': 693,\n",
      "'Legality': 108,\n",
      "'Lending': 917,\n",
      "'Level_of_force_exertion': 428,\n",
      "'Level_of_force_resistance': 682,\n",
      "'Level_of_light': 896,\n",
      "'Light_movement': 786,\n",
      "'Likelihood': 179,\n",
      "'Limitation': 899,\n",
      "'Limiting': 898,\n",
      "'Linguistic_meaning': 1047,\n",
      "'Lively_place': 610,\n",
      "'Living_conditions': 807,\n",
      "'Locale': 91,\n",
      "'Locale_by_characteristic_entity': 884,\n",
      "'Locale_by_collocation': 842,\n",
      "'Locale_by_event': 620,\n",
      "'Locale_by_ownership': 715,\n",
      "'Locale_by_use': 55,\n",
      "'Locale_closure': 838,\n",
      "'Locating': 243,\n",
      "'Location_in_time': 711,\n",
      "'Location_of_light': 618,\n",
      "'Locative_relation': 184,\n",
      "'Losing': 729,\n",
      "'Losing_it': 19,\n",
      "'Losing_someone': 728,\n",
      "'Losing_track_of_perceiver': 732,\n",
      "'Losing_track_of_theme': 731,\n",
      "'Luck': 571,\n",
      "'Make_acquaintance': 102,\n",
      "'Make_agreement_on_action': 299,\n",
      "'Make_cognitive_connection': 531,\n",
      "'Make_compromise': 885,\n",
      "'Make_noise': 246,\n",
      "'Making_arrangements': 950,\n",
      "'Making_faces': 1071,\n",
      "'Manipulate_into_doing': 516,\n",
      "'Manipulate_into_shape': 707,\n",
      "'Manipulation': 187,\n",
      "'Manner': 806,\n",
      "'Manner_of_life': 805,\n",
      "'Manufacturing': 276,\n",
      "'Margin_of_resolution': 422,\n",
      "'Mass_motion': 1022,\n",
      "'Mathematical_relationship': 903,\n",
      "'Means': 562,\n",
      "'Measurable_attributes': 721,\n",
      "'Measure_area': 1009,\n",
      "'Measure_by_action': 247,\n",
      "'Measure_duration': 551,\n",
      "'Measure_linear_extent': 552,\n",
      "'Measure_mass': 414,\n",
      "'Measure_volume': 991,\n",
      "'Medical_conditions': 85,\n",
      "'Medical_instruments': 1020,\n",
      "'Medical_interaction_scenario': 894,\n",
      "'Medical_intervention': 912,\n",
      "'Medical_professionals': 536,\n",
      "'Medical_specialties': 1019,\n",
      "'Medium': 651,\n",
      "'Meet_specifications': 741,\n",
      "'Meet_with': 743,\n",
      "'Meet_with_response': 739,\n",
      "'Member_of_military': 866,\n",
      "'Membership': 70,\n",
      "'Memorization': 128,\n",
      "'Memory': 1008,\n",
      "'Mental_property': 353,\n",
      "'Mental_stimulus_exp_focus': 670,\n",
      "'Mental_stimulus_stimulus_focus': 672,\n",
      "'Mention': 694,\n",
      "'Military': 396,\n",
      "'Military_operation': 872,\n",
      "'Mining': 781,\n",
      "'Misdeed': 111,\n",
      "'Money': 539,\n",
      "'Morality_evaluation': 92,\n",
      "'Motion': 190,\n",
      "'Motion_directional': 488,\n",
      "'Motion_noise': 42,\n",
      "'Moving_in_place': 36,\n",
      "'Name_conferral': 645,\n",
      "'Namesake': 241,\n",
      "'Natural_features': 49,\n",
      "'Needing': 325,\n",
      "'Negation': 957,\n",
      "'Negative_conditional': 963,\n",
      "'Network': 461,\n",
      "'Noise_makers': 57,\n",
      "'Non-commutative_process': 749,\n",
      "'Non-commutative_statement': 750,\n",
      "'Non-gradable_proximity': 958,\n",
      "'Noncombatant': 833,\n",
      "'Notability': 801,\n",
      "'Notification_of_charges': 1004,\n",
      "'Nuclear_process': 625,\n",
      "'Objective_influence': 101,\n",
      "'Obscurity': 591,\n",
      "'Obviousness': 320,\n",
      "'Occupy_rank': 330,\n",
      "'Offenses': 260,\n",
      "'Offering': 742,\n",
      "'Offshoot': 996,\n",
      "'Omen': 86,\n",
      "'Ontogeny': 889,\n",
      "'Openness': 585,\n",
      "'Operate_vehicle': 466,\n",
      "'Operating_a_system': 388,\n",
      "'Operational_testing': 280,\n",
      "'Opinion': 350,\n",
      "'Opportunity': 767,\n",
      "'Optical_image': 763,\n",
      "'Ordinal_numbers': 248,\n",
      "'Organization': 323,\n",
      "'Origin': 126,\n",
      "'Others_situation_as_stimulus': 673,\n",
      "'Out_of_existence': 696,\n",
      "'Pardon': 1062,\n",
      "'Part_edge': 1032,\n",
      "'Part_inner_outer': 250,\n",
      "'Part_ordered_segments': 251,\n",
      "'Part_orientational': 629,\n",
      "'Part_piece': 851,\n",
      "'Part_whole': 77,\n",
      "'Partiality': 249,\n",
      "'Participation': 136,\n",
      "'Partitive': 53,\n",
      "'Passing': 726,\n",
      "'Passing_off': 508,\n",
      "'Path_shape': 119,\n",
      "'Path_traveled': 442,\n",
      "'Patrolling': 832,\n",
      "'Pattern': 705,\n",
      "'People': 313,\n",
      "'People_along_political_spectrum': 423,\n",
      "'People_by_age': 17,\n",
      "'People_by_jurisdiction': 64,\n",
      "'People_by_military_specialty': 867,\n",
      "'People_by_morality': 112,\n",
      "'People_by_origin': 18,\n",
      "'People_by_religion': 61,\n",
      "'People_by_residence': 252,\n",
      "'People_by_vocation': 95,\n",
      "'Perception_active': 51,\n",
      "'Perception_body': 25,\n",
      "'Perception_experience': 30,\n",
      "'Performers': 107,\n",
      "'Performers_and_roles': 594,\n",
      "'Performing_arts': 589,\n",
      "'Personal_relationship': 258,\n",
      "'Personal_success': 774,\n",
      "'Physical_artworks': 499,\n",
      "'Piracy': 259,\n",
      "'Placing': 174,\n",
      "'Planned_trajectory': 808,\n",
      "'Planting': 658,\n",
      "'Plants': 955,\n",
      "'Point_of_dispute': 458,\n",
      "'Political_actions': 789,\n",
      "'Political_locales': 45,\n",
      "'Popularity': 771,\n",
      "'Posing_as': 507,\n",
      "'Position_on_a_scale': 602,\n",
      "'Possession': 205,\n",
      "'Possibility': 699,\n",
      "'Posture': 84,\n",
      "'Practice': 10,\n",
      "'Praiseworthiness': 1041,\n",
      "'Prank': 822,\n",
      "'Precariousness': 892,\n",
      "'Precipitation': 683,\n",
      "'Predicament': 1026,\n",
      "'Predicting': 534,\n",
      "'Preference': 478,\n",
      "'Preliminaries': 628,\n",
      "'Presence': 621,\n",
      "'Presentation_of_mitigation': 988,\n",
      "'Preserving': 257,\n",
      "'Prevarication': 99,\n",
      "'Prevent_or_allow_possession': 679,\n",
      "'Preventing_or_letting': 256,\n",
      "'Price_per_unit': 778,\n",
      "'Prison': 1039,\n",
      "'Probability': 768,\n",
      "'Process': 1060,\n",
      "'Process_completed_state': 1030,\n",
      "'Process_continue': 305,\n",
      "'Process_end': 182,\n",
      "'Process_resume': 947,\n",
      "'Process_start': 1007,\n",
      "'Process_stop': 641,\n",
      "'Processing_materials': 139,\n",
      "'Procreative_sex': 967,\n",
      "'Product_development': 887,\n",
      "'Product_line': 799,\n",
      "'Progression': 296,\n",
      "'Prohibiting_or_licensing': 648,\n",
      "'Project': 159,\n",
      "'Proliferating_in_number': 449,\n",
      "'Prominence': 374,\n",
      "'Proper_reference': 439,\n",
      "'Proportion': 779,\n",
      "'Proportional_quantity': 622,\n",
      "'Protecting': 481,\n",
      "'Protest': 823,\n",
      "'Provide_lodging': 616,\n",
      "'Public_services': 455,\n",
      "'Publishing': 835,\n",
      "'Punctual_perception': 692,\n",
      "'Purpose': 157,\n",
      "'Putting_out_fire': 927,\n",
      "'Quantified_mass': 311,\n",
      "'Quantity': 831,\n",
      "'Quarreling': 600,\n",
      "'Questioning': 1037,\n",
      "'Quitting': 354,\n",
      "'Quitting_a_place': 492,\n",
      "'Race_descriptor': 752,\n",
      "'Range': 317,\n",
      "'Rank': 327,\n",
      "'Ranked_expectation': 409,\n",
      "'Rape': 1003,\n",
      "'Rashness': 438,\n",
      "'Rate_description': 775,\n",
      "'Rate_quantification': 776,\n",
      "'Ratification': 319,\n",
      "'Reading_activity': 936,\n",
      "'Reading_aloud': 736,\n",
      "'Reading_perception': 555,\n",
      "'Reason': 468,\n",
      "'Reasoning': 550,\n",
      "'Reassuring': 329,\n",
      "'Rebellion': 820,\n",
      "'Receiving': 169,\n",
      "'Recording': 463,\n",
      "'Records': 877,\n",
      "'Recovery': 6,\n",
      "'Redirecting': 406,\n",
      "'Reference_text': 623,\n",
      "'Referring_by_name': 242,\n",
      "'Reforming_a_system': 429,\n",
      "'Regard': 345,\n",
      "'Rejuvenation': 464,\n",
      "'Relating_concepts': 681,\n",
      "'Relation': 1044,\n",
      "'Relational_natural_features': 384,\n",
      "'Relational_political_locales': 595,\n",
      "'Relational_quantity': 780,\n",
      "'Relative_time': 98,\n",
      "'Releasing': 520,\n",
      "'Reliance': 3,\n",
      "'Reliance_on_expectation': 5,\n",
      "'Religious_belief': 60,\n",
      "'Remainder': 78,\n",
      "'Remembering_experience': 197,\n",
      "'Remembering_information': 198,\n",
      "'Remembering_to_do': 199,\n",
      "'Removing': 234,\n",
      "'Render_nonfunctional': 180,\n",
      "'Renting': 524,\n",
      "'Renting_out': 525,\n",
      "'Renunciation': 202,\n",
      "'Reparation': 1042,\n",
      "'Repayment': 769,\n",
      "'Repel': 661,\n",
      "'Replacing': 124,\n",
      "'Reporting': 469,\n",
      "'Representative': 875,\n",
      "'Representing': 717,\n",
      "'Request': 58,\n",
      "'Request_entity': 890,\n",
      "'Required_event': 332,\n",
      "'Rescuing': 952,\n",
      "'Research': 147,\n",
      "'Reserving': 876,\n",
      "'Reshaping': 686,\n",
      "'Residence': 617,\n",
      "'Resolve_problem': 405,\n",
      "'Respond_to_proposal': 475,\n",
      "'Response': 281,\n",
      "'Responsibility': 857,\n",
      "'Rest': 43,\n",
      "'Resurrection': 357,\n",
      "'Retaining': 924,\n",
      "'Reveal_secret': 143,\n",
      "'Revenge': 4,\n",
      "'Revolution': 821,\n",
      "'Rewards_and_punishments': 1014,\n",
      "'Ride_vehicle': 549,\n",
      "'Rising_to_a_challenge': 363,\n",
      "'Risky_situation': 540,\n",
      "'Rite': 59,\n",
      "'Roadways': 639,\n",
      "'Robbery': 288,\n",
      "'Rope_manipulation': 1016,\n",
      "'Rotting': 1033,\n",
      "'Run_risk': 541,\n",
      "'Sacrificing_for': 886,\n",
      "'Satisfying': 740,\n",
      "'Scarcity': 453,\n",
      "'Scheduling': 948,\n",
      "'Scope': 627,\n",
      "'Scouring': 255,\n",
      "'Scrutiny': 192,\n",
      "'Secrecy_status': 138,\n",
      "'See_through': 970,\n",
      "'Seeking': 910,\n",
      "'Seeking_to_achieve': 379,\n",
      "'Self_control': 852,\n",
      "'Self_motion': 287,\n",
      "'Sending': 150,\n",
      "'Sensation': 52,\n",
      "'Sent_items': 459,\n",
      "'Sentencing': 115,\n",
      "'Separating': 574,\n",
      "'Sequence': 706,\n",
      "'Serving_in_capacity': 797,\n",
      "'Set_of_interrelated_entities': 465,\n",
      "'Setting_back_burn': 932,\n",
      "'Setting_fire': 938,\n",
      "'Setting_out': 494,\n",
      "'Severity_of_offense': 1055,\n",
      "'Sex': 966,\n",
      "'Shaped_part': 13,\n",
      "'Shapes': 274,\n",
      "'Sharing': 925,\n",
      "'Sharpness': 266,\n",
      "'Shoot_projectiles': 310,\n",
      "'Shopping': 593,\n",
      "'Short_selling': 719,\n",
      "'Sidereal_appearance': 360,\n",
      "'Sign': 647,\n",
      "'Sign_agreement': 146,\n",
      "'Silencing': 1023,\n",
      "'Similarity': 129,\n",
      "'Simple_name': 244,\n",
      "'Simple_naming': 245,\n",
      "'Simultaneity': 606,\n",
      "'Size': 759,\n",
      "'Sleep': 203,\n",
      "'Smuggling': 1011,\n",
      "'Soaking': 945,\n",
      "'Soaking_up': 665,\n",
      "'Sociability': 783,\n",
      "'Social_connection': 532,\n",
      "'Social_desirability': 859,\n",
      "'Social_event': 62,\n",
      "'Social_event_collective': 757,\n",
      "'Social_event_individuals': 758,\n",
      "'Social_interaction_evaluation': 588,\n",
      "'Sole_instance': 1035,\n",
      "'Sound_level': 716,\n",
      "'Sound_movement': 1070,\n",
      "'Sounds': 189,\n",
      "'Source_of_getting': 303,\n",
      "'Spatial_co-location': 964,\n",
      "'Spatial_contact': 959,\n",
      "'Speak_on_topic': 709,\n",
      "'Specific_individual': 863,\n",
      "'Speed_description': 11,\n",
      "'Spelling_and_pronouncing': 97,\n",
      "'Sports_jargon': 795,\n",
      "'Stage_of_progress': 297,\n",
      "'Standing_by': 845,\n",
      "'State_continue': 151,\n",
      "'State_of_entity': 450,\n",
      "'Statement': 196,\n",
      "'Stimulus_focus': 65,\n",
      "'Stinginess': 1066,\n",
      "'Store': 394,\n",
      "'Storing': 418,\n",
      "'Strictness': 791,\n",
      "'Studying': 127,\n",
      "'Suasion': 986,\n",
      "'Subjective_influence': 100,\n",
      "'Subjective_temperature': 760,\n",
      "'Submitting_documents': 401,\n",
      "'Subordinates_and_superiors': 561,\n",
      "'Subsisting': 804,\n",
      "'Substance': 291,\n",
      "'Substance_by_phase': 803,\n",
      "'Subversion': 106,\n",
      "'Success_or_failure': 335,\n",
      "'Successful_action': 336,\n",
      "'Successfully_communicate_message': 495,\n",
      "'Sufficiency': 240,\n",
      "'Suicide_attack': 858,\n",
      "'Suitability': 430,\n",
      "'Summarizing': 517,\n",
      "'Supply': 302,\n",
      "'Supporting': 649,\n",
      "'Surpassing': 489,\n",
      "'Surrendering': 652,\n",
      "'Surrendering_possession': 390,\n",
      "'Surrounding': 965,\n",
      "'Surviving': 563,\n",
      "'Suspicion': 1051,\n",
      "'System': 460,\n",
      "'System_complexity': 782,\n",
      "'Take_place_of': 163,\n",
      "'Taking': 684,\n",
      "'Taking_captive': 861,\n",
      "'Taking_sides': 285,\n",
      "'Taking_time': 12,\n",
      "'Talking_into': 989,\n",
      "'Tasting': 765,\n",
      "'Team': 931,\n",
      "'Telling': 836,\n",
      "'Temperature': 82,\n",
      "'Temporal_collocation': 605,\n",
      "'Temporal_pattern': 80,\n",
      "'Temporal_subregion': 378,\n",
      "'Temporary_group': 824,\n",
      "'Temporary_leave': 895,\n",
      "'Temporary_stay': 597,\n",
      "'Terms_of_agreement': 407,\n",
      "'Terrorism': 537,\n",
      "'Text': 48,\n",
      "'Text_creation': 535,\n",
      "'Theft': 689,\n",
      "'Thermodynamic_phase': 802,\n",
      "'Thriving': 558,\n",
      "'Thwarting': 704,\n",
      "'Time_period_of_action': 762,\n",
      "'Time_vector': 123,\n",
      "'Timespan': 720,\n",
      "'Timetable': 953,\n",
      "'Tolerating': 702,\n",
      "'Tool_purpose': 900,\n",
      "'Topic': 2,\n",
      "'Touring': 614,\n",
      "'Toxic_substance': 160,\n",
      "'Transfer': 170,\n",
      "'Transition_to_a_quality': 973,\n",
      "'Transition_to_a_situation': 976,\n",
      "'Transition_to_state': 94,\n",
      "'Translating': 493,\n",
      "'Transportation_status': 850,\n",
      "'Trap': 816,\n",
      "'Travel': 598,\n",
      "'Traversing': 286,\n",
      "'Treating_and_mistreating': 916,\n",
      "'Trendiness': 596,\n",
      "'Trial': 1063,\n",
      "'Triggering': 818,\n",
      "'Trust': 527,\n",
      "'Try_defendant': 1067,\n",
      "'Trying_out': 766,\n",
      "'Turning_out': 725,\n",
      "'Type': 28,\n",
      "'Typicality': 131,\n",
      "'Unattributed_information': 149,\n",
      "'Undergo_change': 167,\n",
      "'Undergo_transformation': 974,\n",
      "'Undergoing': 521,\n",
      "'Undressing': 193,\n",
      "'Unemployment_rate': 456,\n",
      "'Use_firearm': 987,\n",
      "'Used_up': 408,\n",
      "'Usefulness': 294,\n",
      "'Using': 134,\n",
      "'Using_resource': 787,\n",
      "'Vehicle': 519,\n",
      "'Vehicle_departure_initial_stage': 843,\n",
      "'Vehicle_landing': 848,\n",
      "'Vehicle_subpart': 777,\n",
      "'Verdict': 1036,\n",
      "'Verification': 201,\n",
      "'Version_sequence': 637,\n",
      "'Victim_operated_IED': 874,\n",
      "'Violence': 882,\n",
      "'Visiting': 615,\n",
      "'Vocalizations': 579,\n",
      "'Volubility': 175,\n",
      "'Wagering': 744,\n",
      "'Waiting': 421,\n",
      "'Waking_up': 371,\n",
      "'Want_suspect': 1061,\n",
      "'Warning': 909,\n",
      "'Waver_between_options': 436,\n",
      "'Wealthiness': 76,\n",
      "'Weapon': 137,\n",
      "'Wearing': 191,\n",
      "'Weather': 548,\n",
      "'Willingness': 482,\n",
      "'Win_prize': 690,\n",
      "'Withdraw_from_participation': 400,\n",
      "'Within_distance': 376,\n",
      "'Word_relations': 1046,\n",
      "'Work': 601,\n",
      "'Working_a_post': 830}\n"
     ]
    }
   ],
   "source": [
    "# Display the frame name -> label id mapping.\n",
    "import pprint\n",
    "pp = pprint.PrettyPrinter(depth=6, indent=0)\n",
    "pp.pprint(frame_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset that turns (sentence, target span) pairs into fixed-size BERT features.\n",
    "class FnBertDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Frame-classification dataset whose features come from a BERT encoder.\n",
    "\n",
    "    Each item is (x, y): x is the flattened last-layer BERT hidden states of\n",
    "    the first MAX_LEN word pieces of the target span (zero-padded when the\n",
    "    span is shorter), computed under torch.no_grad; y is the frame label id\n",
    "    as a long tensor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, inputs, labels, frame_dict, tokenizer, bert_model, max_len=4):\n",
    "        \"\"\"\n",
    "        inputs: [(text1, start1, end1), ...] where start/end are character\n",
    "            offsets of the target word(s) in text; end is inclusive.\n",
    "        labels: [label_id1, ...], parallel to inputs.\n",
    "        frame_dict: frame name -> label id mapping; its size is OUTPUT_DIM.\n",
    "        tokenizer, bert_model: BERT tokenizer and model; features are\n",
    "            computed on whichever device bert_model's parameters live on.\n",
    "        max_len: number of target word pieces kept per example.\n",
    "        \"\"\"\n",
    "        self.inputs = inputs\n",
    "        self.labels = labels\n",
    "\n",
    "        self.tokenizer = tokenizer\n",
    "        self.bert_model = bert_model\n",
    "\n",
    "        self.MAX_LEN = max_len\n",
    "        self.INPUT_DIM = self.MAX_LEN * self.bert_model.config.hidden_size\n",
    "        self.OUTPUT_DIM = len(frame_dict)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        text, start, end = self.inputs[index]\n",
    "        x = self.get_bert_hidden_state(text, start, end)\n",
    "        y = torch.tensor(self.labels[index]).long()\n",
    "        return x, y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "    def get_bert_hidden_state(self, text, start, end):\n",
    "        \"\"\"Return a flat CPU tensor of size MAX_LEN * hidden_size for text[start:end+1].\"\"\"\n",
    "        text = \"[CLS] \" + text + \" [SEP]\"\n",
    "        start += len(\"[CLS] \")\n",
    "        end += len(\"[CLS] \")\n",
    "\n",
    "        # Map the character span to word-piece indexes, keeping at most MAX_LEN pieces.\n",
    "        tk_start, tk_end = self.pos_to_token_idx(text, start, end)\n",
    "        tk_end = min(tk_start + self.MAX_LEN, tk_end)\n",
    "\n",
    "        # Tokenize and convert word pieces to vocabulary ids.\n",
    "        indexed_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))\n",
    "        # BERT cannot encode more positions than its position-embedding table;\n",
    "        # keep the head of the sequence plus the closing [SEP] instead of crashing.\n",
    "        max_positions = self.bert_model.config.max_position_embeddings\n",
    "        if len(indexed_tokens) > max_positions:\n",
    "            indexed_tokens = indexed_tokens[:max_positions - 1] + indexed_tokens[-1:]\n",
    "            tk_end = min(tk_end, max_positions - 1)\n",
    "\n",
    "        # Build the input on the model's own device, so moving bert_model to\n",
    "        # the GPU needs no edits here.\n",
    "        device = next(self.bert_model.parameters()).device\n",
    "        tokens_tensor = torch.tensor([indexed_tokens], device=device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            outputs = self.bert_model(tokens_tensor)\n",
    "        # Last-layer hidden states of the single sequence: (seq_len, hidden_size).\n",
    "        hidden = torch.squeeze(outputs[0], dim=0)\n",
    "\n",
    "        # Plain slicing clamps to the sequence length (narrow would raise), so a\n",
    "        # target lost to truncation yields zero features.\n",
    "        span = hidden[tk_start:tk_end]\n",
    "        features = torch.zeros(self.MAX_LEN, hidden.size(1))\n",
    "        features[:span.size(0)] = span.cpu()\n",
    "        return torch.flatten(features)\n",
    "\n",
    "    def pos_to_token_idx(self, text, start, end):\n",
    "        \"\"\"Return (tk_start, tk_end) word-piece indexes of text[start:end+1].\n",
    "\n",
    "        Assumes tokenizing text[:start] on its own yields the same pieces as\n",
    "        the corresponding prefix of the full tokenization.\n",
    "        \"\"\"\n",
    "        target_prefix = self.tokenizer.tokenize(text[:start])\n",
    "        target = self.tokenizer.tokenize(text[start:end+1])\n",
    "        tk_start = len(target_prefix)\n",
    "        tk_end = tk_start + len(target)\n",
    "        return tk_start, tk_end\n",
    "\n",
    "dataset = FnBertDataset(inputs, labels, frame_dict, tokenizer, bert_model)\n",
    "# Featurize one example once rather than running BERT three times.\n",
    "sample_x, sample_y = dataset[100]\n",
    "print(\"dataset in = \", sample_x)\n",
    "print(\"dataset out = \", sample_y, sample_y.type())\n",
    "print(\"dimensions: in =\", dataset.INPUT_DIM, \" out = \", dataset.OUTPUT_DIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_net(input_dim, output_dim, hidden_dim=400, dropout=0.5):\n",
    "    \"\"\"Build the frame classifier: dropout -> linear -> ReLU -> dropout -> linear.\n",
    "\n",
    "    input_dim: feature size (FnBertDataset.INPUT_DIM).\n",
    "    output_dim: number of frame labels (FnBertDataset.OUTPUT_DIM).\n",
    "    hidden_dim, dropout: hidden width and dropout probability; the defaults\n",
    "        reproduce the original layers (nn.Dropout defaults to p=0.5).\n",
    "    Returns an nn.Sequential producing unnormalized logits.\n",
    "    \"\"\"\n",
    "    layers = [\n",
    "        nn.Dropout(p=dropout),\n",
    "        nn.Linear(input_dim, hidden_dim),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(p=dropout),\n",
    "        nn.Linear(hidden_dim, output_dim),\n",
    "    ]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "\n",
    "# Saved classifier weights; set FRAME_CLS_STATE_DICT to use another file\n",
    "# instead of editing this machine-specific path.\n",
    "STATE_DICT_PATH = os.environ.get(\n",
    "    'FRAME_CLS_STATE_DICT',\n",
    "    'C:\\\\Users\\\\danil\\\\Documents\\\\Northwestern\\\\Research\\\\projects\\\\frame_classification\\\\state_dict_3')\n",
    "\n",
    "# Build the classifier on the CPU and load the trained weights.\n",
    "net = create_net(input_dim=dataset.INPUT_DIM, output_dim=dataset.OUTPUT_DIM)\n",
    "net = net.cpu()\n",
    "# map_location='cpu' lets weights saved from a GPU run load without CUDA.\n",
    "net.load_state_dict(torch.load(STATE_DICT_PATH, map_location='cpu'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Wrap the classifier with its loss and optimizer.\n",
    "# NOTE(review): 10e-5 == 1e-4; if 1e-5 was intended, change lr here.\n",
    "model = Model(\n",
    "    net,\n",
    "    criterion=nn.CrossEntropyLoss(),\n",
    "    optimizer=optim.Adam(net.parameters(), lr=1e-4),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# model.fit(dataset, n_epochs=10, batch_size=32, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# torch.save(\n",
    "#     net.state_dict(), 'C:\\\\Users\\\\danil\\\\Documents\\\\Northwestern\\\\Research\\\\projects\\\\frame_classification\\\\state_dict_5')\n",
    "# torch.save(\n",
    "#     net, 'C:\\\\Users\\\\danil\\\\Documents\\\\Northwestern\\\\Research\\\\projects\\\\frame_classification\\\\net_5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the classifier on a random subset of the data. Sample without\n",
    "# replacement (random.choices repeats examples) and with a fixed seed so the\n",
    "# reported number is reproducible.\n",
    "# NOTE(review): the subset comes from the same `inputs` the full `dataset`\n",
    "# was built from; if the loaded weights were trained on it, this is not a\n",
    "# held-out estimate. TODO confirm / use a proper split.\n",
    "DEV_SIZE = 1000\n",
    "DEV_SEED = 0\n",
    "dev_idxs = random.Random(DEV_SEED).sample(range(len(inputs)), k=min(DEV_SIZE, len(inputs)))\n",
    "dev_inputs = [inputs[idx] for idx in dev_idxs]\n",
    "dev_labels = [labels[idx] for idx in dev_idxs]\n",
    "\n",
    "net.eval()\n",
    "dev_dataset = FnBertDataset(dev_inputs, dev_labels, frame_dict, tokenizer, bert_model)\n",
    "print(\"length of dev set: \", len(dev_dataset))\n",
    "model.test(dev_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def predict_top_k_dataset(dataset, k, batch_size=1):\n",
    "    \"\"\"Top-k predictions for every item of `dataset`, in dataset order.\n",
    "\n",
    "    dataset: yields (features, label) pairs (e.g. FnBertDataset); labels are ignored.\n",
    "    k: number of predictions per example.\n",
    "    batch_size: DataLoader batch size; order is kept (shuffle=False).\n",
    "    Returns (predicted, probs), both of shape (len(dataset), k): label ids\n",
    "    and their softmax probabilities.\n",
    "    \"\"\"\n",
    "    predicted_batches = []\n",
    "    prob_batches = []\n",
    "    data_loader = torch.utils.data.DataLoader(\n",
    "        dataset=dataset, batch_size=batch_size, shuffle=False)\n",
    "    for batch_inputs, _ in data_loader:\n",
    "        predicted, probs = predict_top_k(batch_inputs, k)\n",
    "        predicted_batches.append(predicted)\n",
    "        prob_batches.append(probs)\n",
    "    if not predicted_batches:\n",
    "        # torch.cat rejects an empty list; an empty dataset has no predictions.\n",
    "        return torch.empty(0, k, dtype=torch.long), torch.empty(0, k)\n",
    "    return torch.cat(predicted_batches, 0), torch.cat(prob_batches, 0)\n",
    "\n",
    "\n",
    "def predict_top_k(inputs, k, batch_size=1):\n",
    "    \"\"\"Return the k most probable label ids for a batch and their probabilities.\n",
    "\n",
    "    Uses the global classifier `net`, moving `inputs` to whichever device its\n",
    "    parameters are on. Probabilities are a softmax over all classes, i.e. the\n",
    "    model's own confidence, not renormalized over the top k.\n",
    "    `batch_size` is unused and kept for backward compatibility.\n",
    "    \"\"\"\n",
    "    device = next(net.parameters()).device\n",
    "    # Run in eval mode (dropout off) and restore the caller's mode afterwards.\n",
    "    was_training = net.training\n",
    "    net.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            logits = net(inputs.to(device))\n",
    "    finally:\n",
    "        net.train(was_training)\n",
    "    probs = F.softmax(logits, dim=1)\n",
    "    top_probs, predicted = torch.topk(probs, k, dim=1)\n",
    "    return predicted, top_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the classifier on hand-written sentences. Each example is\n",
    "# (text, start, end, gold frame) with character offsets of the target word\n",
    "# (end inclusive). demo_* names avoid overwriting the sampled dev_dataset.\n",
    "demo_examples = [\n",
    "    (\"the problem is telling which is the original document and which the copy\", 68, 71, \"Duplication\"),\n",
    "    (\"the cause of the accident is not clear\", 4, 8, \"Causation\"),\n",
    "    (\"Rubella, also known as German measles or three-day measles, is an infection caused by the rubella virus.\", 0, 6, \"Medical_conditions\"),\n",
    "    (\"he died after a long illness\", 21, 27, \"Medical_conditions\"),\n",
    "    (\"for a time revolution was a strong probability\", 35, 45, \"Probability\"),\n",
    "]\n",
    "demo_inputs = [(text, start, end) for text, start, end, _ in demo_examples]\n",
    "demo_labels = [frame_dict[frame] for _, _, _, frame in demo_examples]\n",
    "demo_dataset = FnBertDataset(demo_inputs, demo_labels, frame_dict, tokenizer, bert_model)\n",
    "demo_preds, demo_probs = predict_top_k_dataset(demo_dataset, 5)\n",
    "for pred, prob in zip(demo_preds.tolist(), demo_probs.tolist()):\n",
    "    print([(frame_dict_rev[x], round(y, 2)) for x, y in zip(pred, prob)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
