{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8a9f07be-ed84-4748-a8ac-7fd9ef71084a",
   "metadata": {},
   "source": [
    "# BERT model for Argument Unit Recognition and Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cfdf39e-4987-46a8-b756-e8fae12b54af",
   "metadata": {},
   "source": [
    "This notebook implements a model made only of BERT fine-tuned for token classification. Parts of this code come from https://github.com/trtm/AURC\n",
    "\n",
    "BERT is fine-tuned on the training split and evaluated on the test split."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "yellow-honduras",
   "metadata": {},
   "source": [
    "# Import the libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d222c130-a05e-4dde-a5c3-eca36a2c4c07",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import datetime as dt\n",
    "from tqdm import tqdm\n",
    "import sys\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from torch.utils.data import (TensorDataset, DataLoader,\n",
    "                              RandomSampler, SequentialSampler)\n",
    "\n",
    "from transformers import (AdamW, BertConfig, get_linear_schedule_with_warmup, \n",
    "                                BertForTokenClassification, BertTokenizer, BertPreTrainedModel, BertModel)\n",
    "\n",
    "from torch_geometric.data import Data\n",
    "from sklearn import metrics\n",
    "\n",
    "sys.path.insert(0, '../../../utils/')\n",
    "from utils_AURC import InputFeatures\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "lasting-cleanup",
   "metadata": {},
   "source": [
    "# Load the AURC dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "saving-friday",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = '../../../data/aurc/bert/large_depth'\n",
    "input_file = 'AURC_DATA_dict.json'\n",
    "num_labels = 3\n",
    "\n",
    "pretrained_weights = 'bert-large-uncased'\n",
    "\n",
    "max_sequence_length = 64\n",
    "\n",
    "# Name of the per-sentence split key ('Train' / 'Test'); in the In-Domain setting the domains are shuffled.\n",
    "# NOTE(review): data_dir, max_sequence_length and target_domain are not read by the cells below.\n",
    "target_domain = 'In-Domain'\n",
    "\n",
    "fname = '../../../data/aurc/' + input_file\n",
    "\n",
    "# Load the JSON file as UTF-8 explicitly so decoding does not depend on the system locale.\n",
    "with open(fname, 'r', encoding='utf-8') as my_file:\n",
    "    AURC_DATA_dict = json.load(my_file)\n",
    "\n",
    "# Number of topics and number of sentences per topic\n",
    "print(len(AURC_DATA_dict), [len(AURC_DATA_dict[topic]) for topic in AURC_DATA_dict.keys()])\n",
    "\n",
    "# Sorted topic names\n",
    "topics = sorted(AURC_DATA_dict)\n",
    "print(len(topics), topics)\n",
    "\n",
    "# Label to id dictionary; it must define exactly num_labels classes.\n",
    "label2id = {'non': 0, 'con': 1, 'pro': 2}\n",
    "assert len(label2id) == num_labels, (len(label2id), num_labels)\n",
    "\n",
    "# Choose the tokenizer from Hugging Face transformers\n",
    "tokenizer = BertTokenizer.from_pretrained(pretrained_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a535d5dd-3cd3-4263-b63f-c91f097f921e",
   "metadata": {},
   "source": [
    "## PyTorch model definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cade0f5-d9bb-491a-9785-578e1738154a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TokenBERT(nn.Module):\n",
    "    \"\"\"Wrap a pretrained `BertForTokenClassification` for token-level tagging.\n",
    "\n",
    "    `forward` returns the cross-entropy loss when `labels` is given (training)\n",
    "    and the arg-max label id of every token otherwise (inference).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_labels, model_name, output_hidden_states=False,\n",
    "            output_attentions=False, batch_first=True):\n",
    "        super(TokenBERT, self).__init__()\n",
    "\n",
    "        self.num_labels = num_labels\n",
    "        # NOTE(review): batch_first is stored but never read by this class.\n",
    "        self.batch_first = batch_first\n",
    "\n",
    "        self.tokenbert = BertForTokenClassification.from_pretrained(\n",
    "            model_name,\n",
    "            num_labels=self.num_labels,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            output_attentions=output_attentions\n",
    "        )\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):\n",
    "        \"\"\"Return the training loss, or the predicted label ids at inference.\n",
    "\n",
    "        Args:\n",
    "            input_ids, attention_mask, token_type_ids: BERT inputs, presumably\n",
    "                (batch, seq_len) long tensors -- TODO confirm against callers.\n",
    "            labels: gold label id of every token (presumably the same shape as\n",
    "                input_ids), or None at inference.\n",
    "\n",
    "        Returns:\n",
    "            The mean cross-entropy over the non-padding tokens (attention_mask == 1)\n",
    "            when labels is given, else torch.argmax(logits, dim=2).\n",
    "        \"\"\"\n",
    "        outputs = self.tokenbert.bert(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids\n",
    "        )\n",
    "\n",
    "        sequence_output = self.tokenbert.dropout(outputs[0])\n",
    "        logits = self.tokenbert.classifier(sequence_output)\n",
    "\n",
    "        if labels is None:  # inference\n",
    "            return torch.argmax(logits, dim=2)\n",
    "\n",
    "        # training: padding positions (attention_mask == 0) only carry the filler\n",
    "        # label 0, so they are excluded from the loss, as in the Hugging Face\n",
    "        # BertForTokenClassification loss; otherwise padding dominates the average.\n",
    "        flat_logits = logits.view(-1, self.num_labels)\n",
    "        flat_labels = labels.view(-1)\n",
    "        if attention_mask is not None:\n",
    "            active = attention_mask.view(-1) == 1\n",
    "            flat_logits = flat_logits[active]\n",
    "            flat_labels = flat_labels[active]\n",
    "        return nn.CrossEntropyLoss()(flat_logits, flat_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1c90db3-3b68-4274-a193-fc5936e1cbaa",
   "metadata": {},
   "source": [
    "### Dataset creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "433ec0fb-b03d-401f-a8cf-27ae0b4b4a57",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _features_to_dataloader(features, sampler_cls, bs):\n",
    "    \"\"\"Stack InputFeatures into a TensorDataset and wrap it in a DataLoader.\n",
    "\n",
    "    Prints the number of examples and of batches, then returns\n",
    "    (number of examples, dataloader).\n",
    "    \"\"\"\n",
    "    columns = ('input_ids', 'attention_mask', 'token_type_ids', 'label_ids')\n",
    "    tensors = [torch.tensor([getattr(f, col) for f in features], dtype=torch.long)\n",
    "               for col in columns]\n",
    "    data = TensorDataset(*tensors)\n",
    "    sampler = sampler_cls(data)\n",
    "    dataloader = DataLoader(data, sampler=sampler, batch_size=bs)\n",
    "    print(len(sampler), len(dataloader))\n",
    "    return len(sampler), dataloader\n",
    "\n",
    "\n",
    "def create_dataset_for_training_BERT(bs, max_len=64, domain='In-Domain'):\n",
    "    \"\"\"Tokenize every AURC sentence and build the train / test DataLoaders.\n",
    "\n",
    "    Reads the notebook globals AURC_DATA_dict, tokenizer and label2id. A sentence\n",
    "    goes to the train or test split when the value of its `domain` key is\n",
    "    'Train' or 'Test'; sentences with any other value are skipped.\n",
    "\n",
    "    Args:\n",
    "        bs: batch size of both DataLoaders.\n",
    "        max_len: padded/truncated sequence length, including [CLS] and [SEP].\n",
    "        domain: name of the split key in each sentence dict.\n",
    "\n",
    "    Returns:\n",
    "        (number of training examples, train DataLoader with random sampling,\n",
    "        test DataLoader with sequential sampling).\n",
    "    \"\"\"\n",
    "    train_features = []\n",
    "    test_features = []\n",
    "    for topic, AD in AURC_DATA_dict.items():\n",
    "        for ad in AD:\n",
    "            sequence_dict = tokenizer.encode_plus(\n",
    "                ad['sentence'], truncation=True, max_length=max_len,\n",
    "                padding='max_length', add_special_tokens=True)\n",
    "            # NOTE(review): assumes one label per BERT word-piece of the sentence.\n",
    "            word_labels = [label2id[label] for label in ad['tokenized_sentence_bert_labels'].split(' ')]\n",
    "            # Inputs look like [CLS] w_1 ... w_k [SEP] [PAD] ...: at most max_len - 2\n",
    "            # word-piece labels fit, and [CLS], [SEP] and padding get label 0.\n",
    "            # (Keeping max_len - 1 labels gave [SEP] a word's label on truncation.)\n",
    "            kept_labels = word_labels[:max_len - 2]\n",
    "            sequence_dict['label_ids'] = [0] + kept_labels + [0] * (max_len - 1 - len(kept_labels))\n",
    "            sequence_dict['input_tokens'] = tokenizer.convert_ids_to_tokens(sequence_dict['input_ids'])\n",
    "\n",
    "            for k, v in sequence_dict.items():\n",
    "                assert len(v) == max_len, k\n",
    "            feature = InputFeatures(\n",
    "                input_ids      = sequence_dict['input_ids'],\n",
    "                attention_mask = sequence_dict['attention_mask'],\n",
    "                token_type_ids = sequence_dict['token_type_ids'],\n",
    "                label_ids      = sequence_dict['label_ids'],\n",
    "                sentence_hash  = ad['sentence_hash']\n",
    "            )\n",
    "\n",
    "            split = ad[domain]\n",
    "            if split == 'Train':\n",
    "                train_features.append(feature)\n",
    "            elif split == 'Test':\n",
    "                test_features.append(feature)\n",
    "\n",
    "    len_train, train_dataloader = _features_to_dataloader(train_features, RandomSampler, bs)\n",
    "    _, test_dataloader = _features_to_dataloader(test_features, SequentialSampler, bs)\n",
    "    return len_train, train_dataloader, test_dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a33c7fcd-66e0-4d31-abf7-7bd2059556aa",
   "metadata": {},
   "source": [
    "### Training Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6880dc3a-a332-4ddd-b0cf-75bb63f71aef",
   "metadata": {},
   "outputs": [],
   "source": [
    "def training(train_dataloader, model, device, optimizer, max_grad_norm, scheduler=None):\n",
    "    \"\"\"Train `model` for one epoch over `train_dataloader`.\n",
    "\n",
    "    Each batch is a 4-tuple (input_ids, attention_mask, token_type_ids, label_ids)\n",
    "    as built by create_dataset_for_training_BERT.\n",
    "\n",
    "    Args:\n",
    "        train_dataloader: iterable of batches of tensors.\n",
    "        model: module whose forward returns the loss when `labels` is passed.\n",
    "        device: device the batch tensors are moved to.\n",
    "        optimizer: optimizer over model.parameters().\n",
    "        max_grad_norm: gradient-norm clipping threshold.\n",
    "        scheduler: optional learning-rate scheduler, stepped once per batch after\n",
    "            the optimizer. The default None keeps the learning rate unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (model, optimizer, total_loss), where total_loss is the sum of the\n",
    "        per-batch losses of the epoch.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for batch in train_dataloader:\n",
    "        batch_input_ids, batch_input_mask, batch_sentence_ids, batch_label_ids = (\n",
    "            t.to(device) for t in batch)\n",
    "\n",
    "        # forward pass\n",
    "        loss = model(\n",
    "            batch_input_ids,\n",
    "            token_type_ids=batch_sentence_ids,\n",
    "            attention_mask=batch_input_mask,\n",
    "            labels=batch_label_ids\n",
    "        )\n",
    "\n",
    "        # backward pass and loss tracking\n",
    "        loss.backward()\n",
    "        total_loss += loss.item()\n",
    "\n",
    "        # clip gradients, update parameters, then (optionally) the learning rate\n",
    "        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)\n",
    "        optimizer.step()\n",
    "        if scheduler is not None:\n",
    "            scheduler.step()\n",
    "\n",
    "        model.zero_grad()\n",
    "\n",
    "    return model, optimizer, total_loss\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e06f301b-5d99-4ba1-b929-47a211be0529",
   "metadata": {},
   "source": [
    "### Evaluation Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "103e6674-47f6-4113-a928-5f260e89cac5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(sample_dataloader, model, device, tokenizer, average='micro'):\n",
    "    \"\"\"Score model predictions against gold labels at word level.\n",
    "\n",
    "    Only word-pieces strictly between [CLS] and the first [SEP] that do not start\n",
    "    with '##' (the first piece of each word) are scored, so padding and special\n",
    "    tokens are ignored.\n",
    "\n",
    "    Args:\n",
    "        sample_dataloader: batches of (input_ids, attention_mask, token_type_ids,\n",
    "            label_ids) tensors.\n",
    "        model: module that returns predicted label ids per token when called\n",
    "            without labels.\n",
    "        device: device the batch tensors are moved to.\n",
    "        tokenizer: BERT tokenizer used to build the inputs.\n",
    "        average: averaging mode of sklearn's precision_recall_fscore_support.\n",
    "            With the default 'micro' on this single-label multi-class task,\n",
    "            precision, recall and F1 all equal token accuracy; pass 'macro' to\n",
    "            average the per-class scores instead.\n",
    "\n",
    "    Returns:\n",
    "        (y_true, y_pred, p, r, f1s): per-sentence lists of word-level gold and\n",
    "        predicted label ids, followed by the aggregate precision, recall and F1.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    sep_token_id = tokenizer.sep_token_id\n",
    "    y_true = []\n",
    "    y_pred = []\n",
    "    for batch in sample_dataloader:\n",
    "        batch_input_ids, batch_input_mask, batch_sentence_ids, batch_label_ids = (\n",
    "            t.to(device) for t in batch)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            batch_tags = model(\n",
    "                batch_input_ids,\n",
    "                token_type_ids=batch_sentence_ids,\n",
    "                attention_mask=batch_input_mask)\n",
    "\n",
    "        for input_ids, label_ids, tags in zip(batch_input_ids, batch_label_ids, batch_tags):\n",
    "            input_ids = input_ids.cpu().tolist()\n",
    "            label_ids = label_ids.cpu().tolist()\n",
    "            if not isinstance(tags, list):\n",
    "                tags = tags.cpu().tolist()\n",
    "            # index of the first [SEP]; it and everything after it are ignored\n",
    "            seq_len = input_ids.index(sep_token_id)\n",
    "            input_tokens = tokenizer.convert_ids_to_tokens(input_ids)\n",
    "            word_pieces = input_tokens[1:seq_len]\n",
    "            correct_label_ids = [l for t, l in zip(word_pieces, label_ids[1:seq_len]) if not t.startswith('##')]\n",
    "            correct_tags = [l for t, l in zip(word_pieces, tags[1:seq_len]) if not t.startswith('##')]\n",
    "            # sanity check: exactly one scored label per reconstructed word\n",
    "            seq = ' '.join(word_pieces).replace(' ##', '')\n",
    "            assert len(correct_label_ids) == len(seq.split(' '))\n",
    "            y_true.append(correct_label_ids)\n",
    "            y_pred.append(correct_tags)\n",
    "\n",
    "    # flatten the per-sentence lists\n",
    "    YT = [label for sentence in y_true for label in sentence]\n",
    "    YP = [tag for sentence in y_pred for tag in sentence]\n",
    "    assert len(YT) == len(YP)\n",
    "\n",
    "    p, r, f1s, _ = metrics.precision_recall_fscore_support(\n",
    "        y_pred=YP, y_true=YT, average=average, warn_for=tuple())\n",
    "    return y_true, y_pred, p, r, f1s"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39ecd6bc-0f65-4dda-b642-2c9e4633edca",
   "metadata": {},
   "source": [
    "### Configuration, model instantiation and training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccf5c9b8-6be0-4a3a-97c5-5ba7ea793c0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's global RNG (classifier-head initialisation, dropout, RandomSampler\n",
    "# shuffling) before the model and dataloaders are created, to make runs repeatable.\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Set use_cuda = True to train on a GPU when one is available; False keeps training on the CPU.\n",
    "use_cuda = False\n",
    "device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')\n",
    "num_epochs = 10\n",
    "batch_size = 64\n",
    "#PATH_MODEL = \"../../saves/cross_domain_3_labels_pretrained_large_bert\"\n",
    "# NOTE(review): already_trained and fine_tuning are not read by any cell of this notebook.\n",
    "already_trained = False\n",
    "fine_tuning = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab69e8c0-7a5c-496e-b9a3-b82d00f49d35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model from the configured checkpoint and label count.\n",
    "model = TokenBERT(\n",
    "            model_name=pretrained_weights,\n",
    "            num_labels=num_labels,\n",
    "            output_hidden_states=False).to(device)\n",
    "learning_rate = 1e-5\n",
    "optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)\n",
    "len_train_features, train_dataloader, test_dataloader = create_dataset_for_training_BERT(batch_size)\n",
    "# One optimizer step per batch, including the last partial batch (the DataLoader keeps it),\n",
    "# so count batches instead of int(len_train_features / batch_size).\n",
    "num_train_optimization_steps = len(train_dataloader) * num_epochs\n",
    "scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_train_optimization_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf6c6e07-219b-4baf-a6df-a8b056d64083",
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(num_epochs):\n",
    "    print(f'Epoch: {epoch:4d}', dt.datetime.now())\n",
    "\n",
    "    # One epoch of fine-tuning; the updated model and optimizer are handed back.\n",
    "    model, optimizer, tr_loss = training(\n",
    "        train_dataloader,\n",
    "        model=model,\n",
    "        device=device,\n",
    "        optimizer=optimizer,\n",
    "        max_grad_norm=1.0,\n",
    "    )\n",
    "\n",
    "    # Word-level scores on the data the model was just trained on\n",
    "    y_true_train, y_pred_train, p_train, r_train, f1s_train = evaluation(\n",
    "        train_dataloader, model=model, device=device, tokenizer=tokenizer)\n",
    "    print(f'TRAIN:  Pre. {p_train:.3f} | Rec. {r_train:.3f} | F1 {f1s_train:.3f}')\n",
    "\n",
    "    # Word-level scores on the held-out test split\n",
    "    y_true_test, y_pred_test, p_test, r_test, f1s_test = evaluation(\n",
    "        test_dataloader, model=model, device=device, tokenizer=tokenizer)\n",
    "    print(f'TEST:   Pre. {p_test:.3f} | Rec. {r_test:.3f} | F1 {f1s_test:.3f} ')\n",
    "\n",
    "# Uncomment (and define PATH_MODEL) to save the fine-tuned weights.\n",
    "#torch.save(model.state_dict(), PATH_MODEL)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concept_env",
   "language": "python",
   "name": "concept_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
