{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "47be19a0-0291-4a85-b7ea-13c488a3c101",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"\"\n",
    "\n",
    "import spacy\n",
    "from spacy import displacy\n",
    "\n",
    "nlp = spacy.load('en_core_web_sm')\n",
    "\n",
    "import benepar\n",
    "from benepar import BeneparComponent, NonConstituentException\n",
    "benepar.download('benepar_en3')\n",
    "\n",
    "nlp.add_pipe(\"benepar\", config={\"model\": \"benepar_en3\"})\n",
    "\n",
    "import sys\n",
    "sys.path.insert(0, '../../utils/')\n",
    "from utils_AURC import (\n",
    "    InputFeatures_Constituent, \n",
    "    parse_data)\n",
    "\n",
    "from utils_GNN_CRF import (\n",
    "    compute_tab_logits_pred_sentence_level,\n",
    "    extract_sentence_level_label,\n",
    "    F1_Loss,\n",
    "    find_max_size_batch_without_nt,\n",
    "    remove_internal_nodes,\n",
    "    remove_internal_nodes_without_labels,\n",
    "    reshape_label_to_batch_padded,\n",
    "    reshape_mask_to_batch_padded,\n",
    "    reshape_data_to_batch_padded,\n",
    "    find_the_most_frequent_element,\n",
    "    add_dimensions_null_to_embedding,\n",
    "    construct_the_edges_matrix_consistuency_padding,\n",
    "    find_the_most_frequent_element\n",
    ")\n",
    "\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import random\n",
    "import datetime\n",
    "from collections import defaultdict\n",
    "import numpy as np\n",
    "\n",
    "# Data Loading\n",
    "import os.path as osp\n",
    "\n",
    "import json\n",
    "\n",
    "from transformers import (AdamW, BertConfig, get_linear_schedule_with_warmup, \n",
    "                                BertForTokenClassification, BertTokenizer, BertPreTrainedModel, BertModel)\n",
    "\n",
    "# Pytorch Module\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import Linear, ReLU\n",
    "\n",
    "\n",
    "from torch_geometric.nn import GCNConv,GATConv, SAGEConv,Sequential,AGNNConv,global_mean_pool\n",
    "from torch_geometric.data import Dataset, download_url,DataLoader, Data\n",
    "from torch_geometric.utils import to_undirected, sort_edge_index, to_networkx\n",
    "\n",
    "# Optuna Module\n",
    "import optuna\n",
    "from optuna.trial import TrialState\n",
    "from optuna.importance import get_param_importances\n",
    "\n",
    "# F1-Score\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Parsing\n",
    "import argparse\n",
    "\n",
    "from torchcrf import CRF"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b05c6dff-4214-4fa2-a237-d4da9e26079d",
   "metadata": {},
   "source": [
    "## Load Dataset "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f1db22fe-a623-4874-95c0-df6b4e1b7779",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Define the Class to load the data\n",
    "class MyOwnDataset(Dataset):\n",
    "    def __init__(self, num_data, root, transform=None, pre_transform=None):\n",
    "        super(MyOwnDataset, self).__init__(root, transform, pre_transform)\n",
    "        self.num_data = num_data\n",
    "    @property\n",
    "    def raw_file_names(self):\n",
    "        return []\n",
    "    @property\n",
    "    def processed_file_names(self):\n",
    "        return ['data_{}.pt'.format(idx) for idx in range(self.num_data)]\n",
    "    def len(self):\n",
    "        return len(self.processed_file_names)\n",
    "    def get(self, idx):\n",
    "        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))\n",
    "        return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fd89b112-d609-4a5d-a584-be1aec7d9116",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the Sequential Network\n",
    "def define_model_GATConv(trial, CLASSES,NODE_FEATURES):\n",
    "    n_layers = trial.suggest_int(\"n_layers\", 1, 3)\n",
    "    layers = []\n",
    "    in_features = NODE_FEATURES\n",
    "\n",
    "    out_features = trial.suggest_int(\"n_units_l{}\".format(0), 50, 300)\n",
    "    heads = trial.suggest_int(\"head_l{}\".format(0), 1, 3)\n",
    "    layers.append( (GATConv(in_features, out_features, heads), 'x, edge_index -> x') )\n",
    "    in_features = out_features*heads\n",
    "    layers.append( (lambda x, edge_index: (F.relu(x), to_undirected(edge_index)) , 'x, edge_index -> x, edge_index') )\n",
    "\n",
    "    for i in range(1,n_layers):\n",
    "        out_features = trial.suggest_int(\"n_units_l{}\".format(i), 50, 300)\n",
    "        heads = trial.suggest_int(\"head_l{}\".format(i), 1, 3)\n",
    "        layers.append((GATConv(in_features, out_features, heads), 'x, edge_index -> x'))\n",
    "        layers.append(nn.ReLU(inplace=True))\n",
    "        p = trial.suggest_float(\"dropout_l{}\".format(i), 0.01, 0.1)\n",
    "        layers.append(nn.Dropout(p))\n",
    "        in_features = out_features*heads\n",
    "\n",
    "    n_layers_linear = trial.suggest_int(\"n_layers_linear\", 1, 3)\n",
    "    for i in range(1,n_layers_linear):\n",
    "        out_features = trial.suggest_int(\"n_units_lin_l{}\".format(i), 50, 250)\n",
    "        layers.append(( nn.Linear(in_features, out_features), 'x -> x'))\n",
    "        layers.append(nn.ReLU(inplace=True))\n",
    "        in_features = out_features\n",
    "\n",
    "    layers.append(( nn.Linear(in_features, CLASSES), 'x -> x'))\n",
    "    return Sequential('x, edge_index',layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "df93f109-488a-472a-a72e-a0a40f9e682e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net_large_CRF_end2end(torch.nn.Module):\n",
    "    def __init__(self,trial,num_node_features, num_labels = 3, batch_first=True, output_hidden_states=True,\n",
    "            output_attentions=False, max_length = 64, model_name=\"bert-large-uncased\"):\n",
    "        super(Net_large_CRF_end2end, self).__init__()\n",
    "        \n",
    "        self.num_labels = num_labels\n",
    "        self.batch_first = batch_first\n",
    "        self.max_length = max_length\n",
    "        self.num_node_features = num_node_features\n",
    "        \n",
    "        self.tokenbert = BertForTokenClassification.from_pretrained(\n",
    "            model_name,\n",
    "            num_labels=self.num_labels,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            output_attentions=output_attentions\n",
    "        )\n",
    "        \n",
    "        self.GNN_attention = define_model_GATConv(trial, self.num_labels, self.num_node_features)\n",
    "        \n",
    "        self.crf = CRF(self.num_labels, batch_first=self.batch_first)\n",
    "\n",
    "    def forward(self, data, labels=None):\n",
    "        max_size_actu = 21\n",
    "        \n",
    "        edge_index  = data.edge_index\n",
    "        \n",
    "        input_ids, attention_mask, token_type_ids = data.input_ids, data.attention_mask, data.token_type_ids\n",
    "        \n",
    "        outputs= self.tokenbert.bert(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids\n",
    "        ).last_hidden_state \n",
    "        \n",
    "        hidden_states = torch.zeros((outputs.size(0),self.max_length+max_size_actu,self.num_node_features))\n",
    "        hidden_states[:,max_size_actu:,:] = outputs\n",
    "        hidden_states = hidden_states.view(-1,self.num_node_features)\n",
    "        \n",
    "            \n",
    "        out = self.GNN_attention(hidden_states, edge_index)\n",
    "        out = out.view(-1,self.max_length+max_size_actu,3)\n",
    "        \n",
    "        if(labels is not None):    \n",
    "            logits = out[:,max_size_actu:,:]\n",
    "            labels = labels.view((-1,self.max_length))\n",
    "            return -self.crf(emissions = logits, tags=labels,mask=data.attention_mask.byte())\n",
    "        else:\n",
    "            logits = out[:,max_size_actu:,:]\n",
    "            return self.crf.decode(emissions =logits, mask=data.attention_mask.byte())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6f051c31-8eb9-4def-8bd4-ef5f3f9e1c3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define The optuna objective function\n",
    "def objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES, train_loader, dev_loader, test_loader):\n",
    "    time_string = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "    # Choose the model.\n",
    "    model = Net_large_CRF_end2end(trial, NODE_FEATURES).to(DEVICE)\n",
    "    \n",
    "    # Generate the optimizers.\n",
    "    optimizer_name = trial.suggest_categorical(\"optimizer\", [\"AdamW\"])\n",
    "    lr = trial.suggest_float(\"lr\", 1e-6, 1e-4, log=True)\n",
    "    \n",
    "    param_optimizer = list(model.named_parameters())\n",
    "    no_decay = ['bias', 'gamma', 'beta']\n",
    "    optimizer_grouped_parameters = [\n",
    "        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n",
    "         'weight_decay_rate': 0.01},\n",
    "        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],\n",
    "         'weight_decay_rate': 0.0}\n",
    "    ]\n",
    "    \n",
    "    optimizer = getattr(optim, optimizer_name)(optimizer_grouped_parameters, lr=lr)\n",
    "    \n",
    "    lmbda = lambda epoch: 0.99\n",
    "    scheduler =  torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)\n",
    "    max_grad_norm = trial.suggest_float(\"max_grad_norm\", 1e-1, 1e+1)\n",
    "    \n",
    "    # Training of the model.\n",
    "    for epoch in range(EPOCHS):\n",
    "        model.train()\n",
    "        for idx, (ele) in enumerate(train_loader):\n",
    "            ele = ele.to(DEVICE)\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            loss = model(ele,ele[\"y\"])\n",
    "\n",
    "            loss.backward()\n",
    "            # gradient clipping\n",
    "            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)\n",
    "            optimizer.step()\n",
    "        \n",
    "        model.eval()\n",
    "        \n",
    "        \n",
    "        # initalize the variable to compute the F1 score at token level\n",
    "        total_class_predict_token = []\n",
    "        total_true_class_token = []\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for idx, (ele) in enumerate(dev_loader):\n",
    "                ele = ele.to(DEVICE)\n",
    "                #ele[\"x\"] = ele[\"x\"].float()\n",
    "                class_predict = model(ele)\n",
    "\n",
    "                ## Flatten the result\n",
    "                class_predict_flat = [item for sublist in class_predict for item in sublist]\n",
    "                total_class_predict_token.extend(class_predict_flat)\n",
    "                \n",
    "                ## Extract the true label without the padding\n",
    "                attention_mask = ele[\"attention_mask\"].flatten()\n",
    "                ele_y_clean = torch.where(attention_mask  > 0, ele[\"y\"], 10)\n",
    "                ele_y_clean = ele_y_clean[ele_y_clean!=10].tolist()\n",
    "                total_true_class_token.extend(ele_y_clean)\n",
    "\n",
    "                \n",
    "\n",
    "            f1_token_level = f1_score(total_class_predict_token,total_true_class_token,labels = [0,1,2], average=\"macro\")\n",
    "            trial.report(f1_token_level, epoch)\n",
    "            scheduler.step()\n",
    "            \n",
    "    f = open(time_string+\".txt\", \"a\")\n",
    "    f.write(\"f1 score on token level dev is : \" + str(f1_token_level))\n",
    "    f.close()\n",
    "\n",
    "    return f1_token_level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e44ec0f6-4462-4b76-85bf-5d91ceef3e35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the main function to call by script\n",
    "def main(actual_number_tested):\n",
    "    \n",
    "    batch_size = 190\n",
    "    num_trials = 300\n",
    "    depth_list_params = [\n",
    "        {\"data_dir\":'../../data/aurc/bert/end_2_end_depth_2_Cross',\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend_depth_2\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_2.db\"},\n",
    "        \n",
    "        {\"data_dir\":\"../../data/aurc/bert/end_2_end\",\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend.db\"},\n",
    "        \n",
    "        {\"data_dir\":'../../data/aurc/bert/end_2_end_depth_4_Cross',\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend_depth_4\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_4.db\"},\n",
    "        \n",
    "        \n",
    "        \n",
    "        {\"data_dir\":'../../data/aurc/bert/end_2_end_depth_2_IN',\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend_depth_2_IN\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_2_IN.db\"},\n",
    "        \n",
    "        {\"data_dir\":\"../../data/aurc/bert/end_2_end_IN\",\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend_IN\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend_IN.db\"},\n",
    "        \n",
    "        {\"data_dir\":'../../data/aurc/bert/end_2_end_depth_4_IN',\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend_depth_4_IN\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_4_IN.db\"}\n",
    "    ]\n",
    "\n",
    "    data_dir = depth_list_params[actual_number_tested][\"data_dir\"]\n",
    "    EPOCHS = 8\n",
    "    NODE_FEATURES = 1024\n",
    "    DEVICE = \"cpu\"\n",
    "    CLASSES = 3\n",
    "    timeout = 3600*24\n",
    "    num_data_train = 4157\n",
    "    num_data_dev = 593\n",
    "    num_data_test = 1189\n",
    "    #num_data_train = 3960\n",
    "    #num_data_dev = 790\n",
    "    #num_data_test = 1959\n",
    "    \n",
    "    \n",
    "    \n",
    "    train_dataset = MyOwnDataset(num_data = num_data_train,root = data_dir+\"/Train\")\n",
    "    test_dataset = MyOwnDataset(num_data = num_data_test,root = data_dir+\"/Test\")\n",
    "    dev_dataset = MyOwnDataset(num_data = num_data_dev,root = data_dir+\"/Dev\")\n",
    "    \n",
    "    \n",
    "    train_loader = DataLoader(train_dataset, batch_size=batch_size)\n",
    "    dev_loader = DataLoader(dev_dataset, batch_size=batch_size)\n",
    "    test_loader = DataLoader(test_dataset, batch_size=batch_size)\n",
    "    \n",
    "    ## Initialize the optuna research\n",
    "    study = optuna.create_study(\n",
    "        direction=\"maximize\",\n",
    "        study_name = depth_list_params[actual_number_tested][\"study_name\"], \n",
    "        storage = depth_list_params[actual_number_tested][\"storage\"],\n",
    "        load_if_exists=True)\n",
    "    \n",
    "    study.optimize(lambda trial: objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES, train_loader, dev_loader, test_loader), n_trials=num_trials, timeout=timeout)\n",
    "\n",
    "    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])\n",
    "    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])\n",
    "\n",
    "    print(\"Study statistics: \")\n",
    "    print(\"  Number of finished trials: \", len(study.trials))\n",
    "    print(\"  Number of pruned trials: \", len(pruned_trials))\n",
    "    print(\"  Number of complete trials: \", len(complete_trials))\n",
    "\n",
    "    print(\"Best trial:\")\n",
    "    trial = study.best_trial\n",
    "\n",
    "    print(\"  Value: \", trial.value)\n",
    "\n",
    "    print(\"  Params: \")\n",
    "    for key, value in trial.params.items():\n",
    "        print(\"    {}: {}\".format(key, value))\n",
    "    \n",
    "    print(\"Name of the study : \" + study.study_name)\n",
    "    \n",
    "    print(\"  Params Importance: \")\n",
    "    dict_params = get_param_importances(study)\n",
    "    for key, value in dict_params.items():\n",
    "        print(\"    {}: {}\".format(key, value))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5418c2ed-9273-439c-89d4-9c55f9f979d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2021-08-18 07:50:34,802]\u001b[0m Using an existing study with name 'Trained_BERT_GNN_CRF_endtoend_depth_4' instead of creating a new one.\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Study statistics: \n",
      "  Number of finished trials:  53\n",
      "  Number of pruned trials:  0\n",
      "  Number of complete trials:  52\n",
      "Best trial:\n",
      "  Value:  0.6279585463516141\n",
      "  Params: \n",
      "    dropout_l1: 0.011584407920879683\n",
      "    head_l0: 3\n",
      "    head_l1: 3\n",
      "    lr: 2.8165088960394907e-05\n",
      "    max_grad_norm: 9.755862297742242\n",
      "    n_layers: 2\n",
      "    n_layers_linear: 2\n",
      "    n_units_l0: 290\n",
      "    n_units_l1: 103\n",
      "    n_units_lin_l1: 96\n",
      "    optimizer: AdamW\n",
      "Name of the study : Trained_BERT_GNN_CRF_endtoend_depth_4\n",
      "  Params Importance: \n",
      "    max_grad_norm: 0.49597189450000223\n",
      "    lr: 0.2920221608683945\n",
      "    head_l0: 0.07889252240736525\n",
      "    n_layers_linear: 0.05661960351366516\n",
      "    n_units_l0: 0.055713677852698734\n",
      "    n_layers: 0.02078014085787429\n",
      "    optimizer: 0.0\n"
     ]
    }
   ],
   "source": [
    "## Initialize the optuna research\n",
    "\n",
    "depth_list_params = [\n",
    "        {\"data_dir\":'../../data/aurc/bert/end_2_end_depth_2_Cross',\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend_depth_2\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_2.db\"},\n",
    "        \n",
    "        {\"data_dir\":\"../../data/aurc/bert/end_2_end\",\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend.db\"},\n",
    "        \n",
    "        {\"data_dir\":'../../data/aurc/bert/end_2_end_depth_4_Cross',\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend_depth_4\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_4.db\"},\n",
    "    \n",
    "    \n",
    "        {\"data_dir\":'../../data/aurc/bert/end_2_end_depth_2_IN',\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend_depth_2_IN\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_2_IN.db\"},\n",
    "        \n",
    "        {\"data_dir\":\"../../data/aurc/bert/end_2_end_IN\",\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend_IN\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend_IN.db\"},\n",
    "        \n",
    "        {\"data_dir\":'../../data/aurc/bert/end_2_end_depth_4_IN',\n",
    "        \"study_name\":\"Trained_BERT_GNN_CRF_endtoend_depth_4_IN\",\n",
    "        \"storage\":\"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_4_IN.db\"}\n",
    "    ]\n",
    "actual_number_tested = 2\n",
    "\n",
    "study = optuna.create_study(\n",
    "    direction=\"maximize\",\n",
    "    study_name = depth_list_params[actual_number_tested][\"study_name\"], \n",
    "    storage = depth_list_params[actual_number_tested][\"storage\"],\n",
    "    load_if_exists=True)\n",
    "\n",
    "print(\"Study statistics: \")\n",
    "print(\"  Number of finished trials: \", len(study.trials))\n",
    "pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])\n",
    "complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])\n",
    "print(\"  Number of pruned trials: \", len(pruned_trials))\n",
    "print(\"  Number of complete trials: \", len(complete_trials))\n",
    "    \n",
    "print(\"Best trial:\")\n",
    "trial = study.best_trial\n",
    "\n",
    "print(\"  Value: \", trial.value)\n",
    "\n",
    "print(\"  Params: \")\n",
    "for key, value in trial.params.items():\n",
    "    print(\"    {}: {}\".format(key, value))\n",
    "\n",
    "print(\"Name of the study : \" + study.study_name)\n",
    "\n",
    "print(\"  Params Importance: \")\n",
    "dict_params = get_param_importances(study)\n",
    "for key, value in dict_params.items():\n",
    "    print(\"    {}: {}\".format(key, value))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b93139a9-6688-4590-9429-dedc2af9be93",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m[I 2021-08-17 17:10:23,404]\u001b[0m Using an existing study with name 'Trained_BERT_GNN_CRF_endtoend_depth_4_IN' instead of creating a new one.\u001b[0m\n",
      "Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-large-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "main(2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concept_env",
   "language": "python",
   "name": "concept_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
