{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "42c76c6d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.11.0+cu113\n",
      "Torchvision Version:  0.12.0+cu113\n",
      "cpu\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import sys\n",
    "import logging\n",
    "from pathlib import Path\n",
    "import random\n",
    "import tarfile\n",
    "import tempfile\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "import pprint\n",
    "print(\"PyTorch Version: \",torch.__version__)\n",
    "print(\"Torchvision Version: \",torchvision.__version__)\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
    "\n",
    "from collections import OrderedDict\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from transformers import BertModel\n",
    "from transformers import BertForSequenceClassification, AdamW, BertConfig\n",
    "\n",
    "\n",
    "import logging\n",
    "logging.propagate = False \n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "# WandB – Import the wandb library\n",
    "import wandb\n",
    "\n",
    "from models import  MultimodalFramework\n",
    "from model_utils import set_seed, build_optimizer, MemesDataset, evaluate_f1\n",
    "\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "#train_model(model_name, dataloaders_dict, criterion, len_train, len_val, config, path)\n",
    "def train_model(model_name, dataloaders, criterion, config, path):\n",
    "#def train_model(model, dataloaders, criterion, optimizer, len_train, len_val, num_epochs, path):\n",
    "    \n",
    "    set_seed(42)\n",
    "    \"\"\"\n",
    "    if model_name == \"bert\":\n",
    "        model = Bert() #Bert()\n",
    "        #for param in model.bert.parameters():\n",
    "        #    param.requires_grad = False\n",
    "    elif model_name == \"mlp\":\n",
    "        model = Net()\n",
    "    else:\n",
    "        model = ResNet() \n",
    "        #for param in model.resnet18.parameters():\n",
    "        #    param.requires_grad = False\n",
    "    \"\"\"\n",
    "    model = MultimodalFramework()\n",
    "    \n",
    "    torch.cuda.empty_cache()\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    model.to(device)\n",
    "    \n",
    "    num_epochs = 1\n",
    "    optimizer = build_optimizer(model, \"adamW\",0.01, 0.9)\n",
    "    #\"\"\"\n",
    "    since = time.time()\n",
    "\n",
    "    val_acc_history = []\n",
    "    val_loss_history = []\n",
    "    train_acc_history = []\n",
    "    train_loss_history = []\n",
    "\n",
    "    #best_model_wts = copy.deepcopy(model.state_dict())\n",
    "    best_acc = 0.0\n",
    "    patience = 5 \n",
    "    trigger = 0\n",
    "    \n",
    "    acc_dict = {}\n",
    "    for epoch in range(num_epochs):\n",
    "        #scheduler.step()\n",
    "        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    "        print('-' * 10)\n",
    "\n",
    "        # Each epoch has a training and validation phase\n",
    "        for phase in ['train', 'val']:\n",
    "            if phase == 'train':\n",
    "                model.train()  # Set model to training mode\n",
    "            else:\n",
    "                model.eval()   # Set model to evaluate mode\n",
    "\n",
    "            running_loss = 0.0\n",
    "            running_corrects = 0\n",
    "            f1 = 0\n",
    "\n",
    "            for data in dataloaders[phase]:\n",
    "                if model_name == \"bert\":\n",
    "                    inputs, masks, labels = data\n",
    "                    inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)\n",
    "                elif model_name == \"mlp\":\n",
    "                    inputs, labels = data\n",
    "                    inputs, labels =inputs.float(), labels.long()\n",
    "                    inputs, labels = inputs.to(device), labels.to(device)\n",
    "                else:\n",
    "                    inputs, labels = data\n",
    "                    inputs, labels = inputs.to(device), labels.to(device)\n",
    "\n",
    "\n",
    "                #print(torch.equal(text_labels,img_labels))\n",
    "\n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward\n",
    "                # track history if only in train\n",
    "                with torch.set_grad_enabled(phase == 'train'):\n",
    "                    # Get model outputs and calculate loss\n",
    "                    # Special case for inception because in training it has an auxiliary output. In train\n",
    "                    #   mode we calculate the loss by summing the final output and the auxiliary output\n",
    "                    #   but in testing we only consider the final output.\n",
    "                    if model_name == \"bert\":\n",
    "                        outputs = model([inputs, masks], model_name)\n",
    "\n",
    "                    else:\n",
    "                        outputs = model(inputs, model_name)\n",
    "\n",
    "\n",
    "                    loss = criterion(outputs, labels)\n",
    "\n",
    "                    _, preds = torch.max(outputs, 1)\n",
    "\n",
    "\n",
    "                    # backward + optimize only if in training phase\n",
    "                    if phase == 'train':\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                # statistics\n",
    "                #print(\"text_inp.size(0)\")\n",
    "                #print(text_inp.size(0))\n",
    "\n",
    "                running_loss += loss.item() * inputs.size(0)\n",
    "                running_corrects += torch.sum(preds == labels.data)\n",
    "                f = evaluate_f1(preds, labels.data).to(device)\n",
    "                print(f)\n",
    "                f1 += f\n",
    "                print(\"f1: \" + str(f1))\n",
    "            epoch_loss = running_loss / len(labels)\n",
    "            epoch_acc = running_corrects.double() / len(labels)\n",
    "            epoch_f1 = f1.double() / len(labels)\n",
    "            print(\"epoch f1: \" + str(epoch_f1))\n",
    "\n",
    "            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))\n",
    "\n",
    "            if phase == 'val':\n",
    "                #wandb.log({\"val_loss\": epoch_loss, \"val_f1\": epoch_f1})\n",
    "                acc_dict[epoch] = float(epoch_acc.detach().cpu())\n",
    "                val_acc_history.append(epoch_acc.detach().cpu())\n",
    "                val_loss_history.append(epoch_loss)\n",
    "                torch.save(model.state_dict(), path+\"_current.pth\")\n",
    "                if epoch_acc > best_acc:\n",
    "                    best_acc = epoch_acc\n",
    "                    #best_model_wts = copy.deepcopy(model.state_dict())\n",
    "                    #torch.save(model.state_dict(), path+\"_best.pth\")\n",
    "                #\"\"\"\n",
    "                if (epoch > 10) and (acc_dict[epoch] <= acc_dict[epoch - 10]):\n",
    "                    trigger +=1\n",
    "                    if trigger >= patience:\n",
    "                        return model, {\"train_acc\":train_acc_history, \"val_acc\":val_acc_history,\"train_loss\":train_loss_history, \"val_loss\":val_loss_history}\n",
    "                else:\n",
    "                    trigger = 0\n",
    "                #\"\"\"    \n",
    "            if phase == 'train':\n",
    "                #wandb.log({\"train_loss\": epoch_loss, \"train_acc\": epoch_f1, \"epoch\": epoch})\n",
    "                train_acc_history.append(epoch_acc.detach().cpu())\n",
    "                train_loss_history.append(epoch_loss)\n",
    "\n",
    "\n",
    "    time_elapsed = time.time() - since\n",
    "    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n",
    "    print('Best val Acc: {:4f}'.format(best_acc))\n",
    "\n",
    "    # load best model weights\n",
    "    #model.load_state_dict(best_model_wts)\n",
    "    return model, {\"train_acc\":train_acc_history, \"val_acc\":val_acc_history,\"train_loss\":train_loss_history, \"val_loss\":val_loss_history}\n",
    "\n",
    "         \n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c018650e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random seed set as 42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/0\n",
      "----------\n",
      "tensor(0.5000)\n",
      "0.5\n",
      "f1: tensor(0.5000)\n",
      "tensor(0.7500)\n",
      "0.6666666666666666\n",
      "f1: tensor(1.2500)\n",
      "tensor(0.)\n",
      "0.0\n",
      "f1: tensor(1.2500)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1580: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 due to no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, \"true nor predicted\", \"F-score is\", len(true_sum))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1.)\n",
      "0.0\n",
      "f1: tensor(2.2500)\n",
      "tensor(0.2500)\n",
      "0.0\n",
      "f1: tensor(2.5000)\n",
      "tensor(0.)\n",
      "0.0\n",
      "f1: tensor(2.5000)\n",
      "tensor(0.5000)\n",
      "0.6666666666666666\n",
      "f1: tensor(3.)\n",
      "tensor(0.2500)\n",
      "0.4\n",
      "f1: tensor(3.2500)\n",
      "tensor(0.7500)\n",
      "0.0\n",
      "f1: tensor(4.)\n",
      "tensor(0.7500)\n",
      "0.0\n",
      "f1: tensor(4.7500)\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_53191/4201735545.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     79\u001b[0m                 \u001b[0;31m#   but in testing we only consider the final output.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     80\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mmodel_name\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"bert\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 81\u001b[0;31m                     \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     82\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     83\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1108\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1111\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/home/mgolovan/Memes_2/models.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, model)\u001b[0m\n\u001b[1;32m    126\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"bert\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    127\u001b[0m             \u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 128\u001b[0;31m             \u001b[0mbert\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmasks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtoken_type_ids\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlast_hidden_state\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    129\u001b[0m             \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbert_classification\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    130\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1108\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1111\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1026\u001b[0m             \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1027\u001b[0m             \u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1028\u001b[0;31m             \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1029\u001b[0m         )\n\u001b[1;32m   1030\u001b[0m         \u001b[0msequence_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mencoder_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1108\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1111\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    612\u001b[0m                     \u001b[0mencoder_attention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    613\u001b[0m                     \u001b[0mpast_key_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 614\u001b[0;31m                     \u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    615\u001b[0m                 )\n\u001b[1;32m    616\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1108\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1111\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    496\u001b[0m             \u001b[0mhead_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    497\u001b[0m             \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 498\u001b[0;31m             \u001b[0mpast_key_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself_attn_past_key_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    499\u001b[0m         )\n\u001b[1;32m    500\u001b[0m         \u001b[0mattention_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself_attention_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1108\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1111\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    430\u001b[0m             \u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    431\u001b[0m         )\n\u001b[0;32m--> 432\u001b[0;31m         \u001b[0mattention_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    433\u001b[0m         \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mattention_output\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m  \u001b[0;31m# add attentions if we output them\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    434\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1108\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1111\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/transformers/models/bert/modeling_bert.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, input_tensor)\u001b[0m\n\u001b[1;32m    381\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_tensor\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    382\u001b[0m         \u001b[0mhidden_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdense\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 383\u001b[0;31m         \u001b[0mhidden_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    384\u001b[0m         \u001b[0mhidden_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLayerNorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden_states\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0minput_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    385\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mhidden_states\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1108\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1111\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/dropout.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m     56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minplace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     59\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mdropout\u001b[0;34m(input, p, training, inplace)\u001b[0m\n\u001b[1;32m   1277\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0.0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1278\u001b[0m         \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"dropout probability has to be between 0 and 1, \"\u001b[0m \u001b[0;34m\"but got {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1279\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_VF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minplace\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0m_VF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1280\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1281\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "dataloaders = dataloaders_dict\n",
    "set_seed(42)\n",
    "\"\"\"\n",
    "if model_name == \"bert\":\n",
    "    model = Bert() #Bert()\n",
    "    #for param in model.bert.parameters():\n",
    "    #    param.requires_grad = False\n",
    "elif model_name == \"mlp\":\n",
    "    model = Net()\n",
    "else:\n",
    "    model = ResNet() \n",
    "    #for param in model.resnet18.parameters():\n",
    "    #    param.requires_grad = False\n",
    "\"\"\"\n",
    "model = MultimodalFramework()\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model.to(device)\n",
    "\n",
    "num_epochs = 1\n",
    "optimizer = build_optimizer(model, \"adamW\",0.01, 0.9)\n",
    "#\"\"\"\n",
    "since = time.time()\n",
    "\n",
    "val_acc_history = []\n",
    "val_loss_history = []\n",
    "train_acc_history = []\n",
    "train_loss_history = []\n",
    "\n",
    "#best_model_wts = copy.deepcopy(model.state_dict())\n",
    "best_acc = 0.0\n",
    "patience = 5 \n",
    "trigger = 0\n",
    "\n",
    "acc_dict = {}\n",
    "for epoch in range(num_epochs):\n",
    "    predicted_labels, ground_truth_labels = [], []\n",
    "    #scheduler.step()\n",
    "    print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    "    print('-' * 10)\n",
    "\n",
    "    # Each epoch has a training and validation phase\n",
    "    for phase in ['train', 'val']:\n",
    "        if phase == 'train':\n",
    "            model.train()  # Set model to training mode\n",
    "        else:\n",
    "            model.eval()   # Set model to evaluate mode\n",
    "\n",
    "        running_loss = 0.0\n",
    "        running_corrects = 0\n",
    "        f1 = 0\n",
    "\n",
    "        for data in dataloaders[phase]:\n",
    "            if model_name == \"bert\":\n",
    "                inputs, masks, labels = data\n",
    "                inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)\n",
    "            elif model_name == \"mlp\":\n",
    "                inputs, labels = data\n",
    "                inputs, labels =inputs.float(), labels.long()\n",
    "                inputs, labels = inputs.to(device), labels.to(device)\n",
    "            else:\n",
    "                inputs, labels = data\n",
    "                inputs, labels = inputs.to(device), labels.to(device)\n",
    "\n",
    "\n",
    "            #print(torch.equal(text_labels,img_labels))\n",
    "\n",
    "            # zero the parameter gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # forward\n",
    "            # track history if only in train\n",
    "            with torch.set_grad_enabled(phase == 'train'):\n",
    "                # Get model outputs and calculate loss\n",
    "                # Special case for inception because in training it has an auxiliary output. In train\n",
    "                #   mode we calculate the loss by summing the final output and the auxiliary output\n",
    "                #   but in testing we only consider the final output.\n",
    "                if model_name == \"bert\":\n",
    "                    outputs = model([inputs, masks], model_name)\n",
    "\n",
    "                else:\n",
    "                    outputs = model(inputs, model_name)\n",
    "\n",
    "\n",
    "                loss = criterion(outputs, labels)\n",
    "\n",
    "                _, preds = torch.max(outputs, 1)\n",
    "\n",
    "\n",
    "                # backward + optimize only if in training phase\n",
    "                if phase == 'train':\n",
    "                    loss.backward()\n",
    "                    optimizer.step()\n",
    "\n",
    "            # statistics\n",
    "            #print(\"text_inp.size(0)\")\n",
    "            #print(text_inp.size(0))\n",
    "\n",
    "            running_loss += loss.item() * inputs.size(0)\n",
    "            running_corrects += torch.sum(preds == labels.data)\n",
    "            predicted_labels.extend(preds.cpu().detach().numpy())\n",
    "            ground_truth_labels.extend(labels.cpu().detach().numpy())\n",
    "            \n",
    "            #abc\n",
    "            \n",
    "            f = evaluate_f1(preds, labels.data).to(device)\n",
    "            \n",
    "            from sklearn.metrics import f1_score   \n",
    "            f1_score = f1_score(labels.cpu().data, preds.cpu())\n",
    "            print(f)\n",
    "            print(f1_score)\n",
    "            f1 += f\n",
    "            print(\"f1: \" + str(f1))\n",
    "        epoch_loss = running_loss / len(dataloaders[phase].dataset)\n",
    "        epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)\n",
    "        epoch_f1 = f1.double() / len(dataloaders[phase].dataset)\n",
    "        print(\"epoch f1: \" + str(epoch_f1))\n",
    "\n",
    "        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))\n",
    "\n",
    "        if phase == 'val':\n",
    "            #wandb.log({\"val_loss\": epoch_loss, \"val_f1\": epoch_f1})\n",
    "            acc_dict[epoch] = float(epoch_acc.detach().cpu())\n",
    "            val_acc_history.append(epoch_acc.detach().cpu())\n",
    "            val_loss_history.append(epoch_loss)\n",
    "            torch.save(model.state_dict(), path+\"_current.pth\")\n",
    "            if epoch_acc > best_acc:\n",
    "                best_acc = epoch_acc\n",
    "                #best_model_wts = copy.deepcopy(model.state_dict())\n",
    "                #torch.save(model.state_dict(), path+\"_best.pth\")\n",
    "            #\"\"\"\n",
    "            if (epoch > 10) and (acc_dict[epoch] <= acc_dict[epoch - 10]):\n",
    "                trigger +=1\n",
    "                if trigger >= patience:\n",
    "                    print(\"here\")#return model, {\"train_acc\":train_acc_history, \"val_acc\":val_acc_history,\"train_loss\":train_loss_history, \"val_loss\":val_loss_history}\n",
    "            else:\n",
    "                trigger = 0\n",
    "            #\"\"\"    \n",
    "        if phase == 'train':\n",
    "            #wandb.log({\"train_loss\": epoch_loss, \"train_acc\": epoch_f1, \"epoch\": epoch})\n",
    "            train_acc_history.append(epoch_acc.detach().cpu())\n",
    "            train_loss_history.append(epoch_loss)\n",
    "\n",
    "\n",
    "time_elapsed = time.time() - since\n",
    "print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n",
    "print('Best val Acc: {:4f}'.format(best_acc))\n",
    "\n",
    "# load best model weights\n",
    "    #model.load_state_dict(best_model_wts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2895d59f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3225806451612903"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import f1_score\n",
    "f1_score(ground_truth_labels, predicted_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1a17ece9",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"bert\"\n",
    "\n",
    "if model_name == \"bert\":\n",
    "    train_inputs = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_txt.pt')\n",
    "    val_inputs = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')\n",
    "\n",
    "else:\n",
    "    train_inputs = torch.load(\"/users/mgolovan/data/mgolovan/facebook_memes/data/train_img.pt\")\n",
    "    val_inputs = torch.load(\"/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt\")\n",
    "\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "\n",
    "train_dataloader = DataLoader(train_inputs, batch_size=4,shuffle=False)\n",
    "\n",
    "val_dataloader= DataLoader(val_inputs, batch_size=4, shuffle=False)\n",
    "\n",
    "dataloaders_dict = {'train':train_dataloader, 'val':val_dataloader}\n",
    "\n",
    "path = '//users/mgolovan/data/mgolovan/facebook_memes/unimodal_models/model_' \n",
    "#train_model(model_name, dataloaders_dict, criterion, \"config\", path)  \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8e4d474d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dataloaders_dict[\"train\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "242bd6ef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[  101,  2043, 19817,  1008,  1050, 15580,  2215,  2000,  2022,  3970,\n",
       "           2021,  2027, 14163, 26065,  2618,  2037,  4230,   102,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0],\n",
       "         [  101,  3342, 12455,  3531,  2393,  2491,  1996,  6355,  1010,  2317,\n",
       "           6494,  4095,  1010, 10041,  2313,  1025,  2031,  2115,  8398,  6793,\n",
       "          12403, 20821,  2030, 11265, 19901,  2098,   102,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0],\n",
       "         [  101,  7489,  6206, 12114,   102,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0],\n",
       "         [  101,  1016,  1009,  1016,  2003, 14557, 15718,  1015,  2003, 13433,\n",
       "           2080,  4248,  8785,   999,   102,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]),\n",
       " tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0],\n",
       "         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0],\n",
       "         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0],\n",
       "         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0]]),\n",
       " tensor([1, 1, 1, 0])]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7d2a8795",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.11.0+cu113\n",
      "Torchvision Version:  0.12.0+cu113\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 67 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 70 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 70 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 68 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 69 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 69 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 68 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 68 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 69 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 70 %\n",
      "AUROC         0.643529\n",
      "accuracy     68.800000\n",
      "precision    66.177446\n",
      "recall       64.352910\n",
      "f1-score     64.797999\n",
      "dtype: float64\n",
      "AUROC        0.009263\n",
      "accuracy     1.032796\n",
      "precision    1.116388\n",
      "recall       0.926337\n",
      "f1-score     0.975161\n",
      "dtype: float64\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/ipykernel_launcher.py:135: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/ipykernel_launcher.py:136: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import sys\n",
    "import logging\n",
    "from pathlib import Path\n",
    "import random\n",
    "import tarfile\n",
    "import tempfile\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "print(\"PyTorch Version: \",torch.__version__)\n",
    "print(\"Torchvision Version: \",torchvision.__version__)\n",
    "\n",
    "from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score\n",
    "\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset\n",
    " \n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "from transformers import BertTokenizer\n",
    "from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel\n",
    "\n",
    "\n",
    "from models import MultimodalFramework\n",
    "from model_utils import set_seed, build_optimizer, ReviewsDataset\n",
    "\n",
    " #'/gpfs/home/mgolovan/data/mgolovan/Reviews/amazon/home/test_rvw_inputs.pt'\n",
    "model_name = \"bert_resnet_l\"\n",
    "lr = 0.00005\n",
    "epochs = 34\n",
    "batch_size = int(32)\n",
    "#best_model_1e-06_22_adamW_20_resnet.pth\n",
    "random_seeds = [15, 0, 1, 67, 128, 87, 261, 510, 340, 22] # \n",
    "df = pd.DataFrame(columns = ['AUROC','accuracy', \"precision\", \"recall\", \"f1-score\", \"CM\", \"CR\"])\n",
    "\n",
    "for seed in random_seeds:\n",
    "    model_path = '/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_'+str(lr)+'_' + str(seed)+'_adamW_' + str(epochs)+'_' + str(model_name)+ '.pth_current.pth'\n",
    "\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(device)\n",
    "\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    model = MultimodalFramework()\n",
    "    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) #eager-sweep-1\n",
    "    model.to(device)\n",
    "\n",
    "    test_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_txt.pt')\n",
    "    test_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_img.pt')\n",
    "\n",
    "    if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size)\n",
    "        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size) \n",
    "\n",
    "    elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size) \n",
    "        modality_2 = DataLoader(test_inputs_tab, batch_size=batch_size)\n",
    "\n",
    "    else:\n",
    "        modality_1 = DataLoader(test_inputs_tab, batch_size=batch_size) \n",
    "        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size)\n",
    "\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    pred = []\n",
    "    test_labels = []\n",
    "\n",
    "    # since we're not training, we don't need to calculate the gradients for our outputs\n",
    "    with torch.no_grad():\n",
    "        for modality1, modality2 in zip(modality_1, modality_2):\n",
    "\n",
    "            if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "                text_inp, masks, text_labels = modality2\n",
    "                img_inp, labels = modality1\n",
    "\n",
    "                text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)\n",
    "                img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "\n",
    "                outputs = model([img_inp, text_inp, masks], model_name)\n",
    "\n",
    "            elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "                img_inp, labels = modality1\n",
    "                tab_inp, tab_labels = modality2\n",
    "                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "\n",
    "                outputs = model([tab_inp, img_inp], model_name)\n",
    "            else:\n",
    "                tab_inp, tab_labels = modality1\n",
    "                text_inp, masks, labels = modality2\n",
    "                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)\n",
    "\n",
    "                outputs = model([tab_inp, text_inp, masks], model_name)\n",
    "\n",
    "            test_labels.extend(np.array(labels.cpu()))\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            pred.extend(predicted.cpu().numpy())\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "\n",
    "    acc= 100 * correct // total\n",
    "    print(f'Accuracy of the bert: {100 * correct // total} %')\n",
    "\n",
    "    test_labels = np.array(test_labels)\n",
    "\n",
    "    #print(confusion_matrix(test_labels, pred))\n",
    "    cm = confusion_matrix(test_labels, pred)\n",
    "    #print(classification_report(test_labels, pred))\n",
    "    cr = classification_report(test_labels, pred, output_dict=True)\n",
    "    auc = roc_auc_score(test_labels, pred)\n",
    "    df = df.append({'AUROC': auc,'accuracy': acc, \"precision\":cr[\"macro avg\"][\"precision\"]*100 ,\n",
    "                    \"recall\":cr[\"macro avg\"][\"recall\"]*100, \"f1-score\":cr[\"macro avg\"][\"f1-score\"]*100,\n",
    "                    \"CM\":cm, \"CR\":cr}, ignore_index=True)\n",
    "\n",
    "df.to_csv(model_name + \"_current_results.csv\")\n",
    "print(df.mean())\n",
    "print(df.std())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3fa310e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/ipykernel_launcher.py:1: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  \"\"\"Entry point for launching an IPython kernel.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Unnamed: 0     3.027650\n",
       "AUROC          0.048424\n",
       "accuracy       2.796824\n",
       "precision     12.079991\n",
       "recall         4.842412\n",
       "f1-score       9.065280\n",
       "dtype: float64"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(\"bert_resnet_luong_current_results.csv\").std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b1b440a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "from torch import flatten\n",
    "\n",
    "from collections import OrderedDict\n",
    "from transformers import BertModel, DistilBertModel\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from transformers import BertModel\n",
    "\n",
    "\n",
    "class Attention(torch.nn.Module):\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__()\n",
    "        self.encoder_dim = encoder_dim\n",
    "        self.decoder_dim = decoder_dim\n",
    "\n",
    "    def forward(self, \n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "        ):\n",
    "        weights = self._get_weights(query, values) # [seq_length]\n",
    "        weights = torch.nn.functional.softmax(weights, dim=0)\n",
    "        return weights @ values  # [encoder_dim]\n",
    "\n",
    "class MultiplicativeAttention(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(\n",
    "            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,\n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "    ):\n",
    "        weights = query @ self.W @ values.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim)    \n",
    "\n",
    "\n",
    "class OneVSOthers(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        mean = sum(others) / len(others)\n",
    "        weights = mean @ self.W @ main.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim) \n",
    "    \n",
    "class MultimodalFramework(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MultimodalFramework, self).__init__()\n",
    "        ##MLP\n",
    "        self.fc1 = nn.Linear(53, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 2)\n",
    "        self.relu = nn.ReLU()\n",
    "        \n",
    "        ##RESNET\n",
    "        self.resnet18 = models.resnet18(pretrained=True)\n",
    "        n_inputs = self.resnet18.fc.in_features\n",
    "\n",
    "        self.resnet18.fc = nn.Sequential(OrderedDict([\n",
    "            ('fc1', nn.Linear(n_inputs, 512))\n",
    "        ])) \n",
    "\n",
    "        self.resnet_classification = nn.Linear(512, 2) #4\n",
    "        \n",
    "        ##BERT\n",
    "        self.bert = BertModel.from_pretrained('bert-base-uncased') \n",
    "        self.bert_classification = nn.Linear(768, 2)\n",
    "        \n",
    "        #Two Modality models\n",
    "        self.bert_resnet_classification = nn.Linear(512 + 768, 2)\n",
    "        self.bert_mlp_classification = nn.Linear(256 + 768, 2)\n",
    "        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)\n",
    "        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)\n",
    "        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)\n",
    "        self.att_classification = nn.Linear(256*2, 2)\n",
    "        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)\n",
    "        self.OvO_classification = nn.Linear(3*256, 2)\n",
    "\n",
    "        self.res_wrap = nn.Linear(512, 256)\n",
    "        self.bert_wrap = nn.Linear(768, 256)\n",
    "        \n",
    "        self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)\n",
    "        self.vaswani_attention  = nn.MultiheadAttention(256, 2, batch_first = True)\n",
    "        self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)\n",
    "        self.OvO_multihead_attention = MultiHeadAttention(256,2, typ = \"OvO\")\n",
    "        self.luong_multihead_attention = MultiHeadAttention(256,2, typ = \"luong\")\n",
    "        \n",
    "        \n",
    "    def bi_directional_att(self, pair):\n",
    "        x = pair[0]\n",
    "        y = pair[1]\n",
    "        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)\n",
    "        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)\n",
    "        combined = torch.cat((attn_output_LV,\n",
    "                              attn_output_VL), dim=1)\n",
    "        return combined\n",
    "\n",
    "    def forward(self, x, model):\n",
    "        if model == \"mlp\":\n",
    "            x = self.fc1(x)\n",
    "            x = self.relu(x)\n",
    "            x = self.fc2(x)\n",
    "            x = self.relu(x)\n",
    "            out = self.fc3(x)\n",
    "            \n",
    "        elif model == \"resnet\":\n",
    "            res = self.resnet18(x)       \n",
    "            out = self.resnet_classification(res)\n",
    "        \n",
    "        elif model == \"bert\":\n",
    "            text, masks = x\n",
    "            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]\n",
    "            out = self.bert_classification(bert)\n",
    "            \n",
    "        elif model == \"bert_resnet\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "            \n",
    "            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]\n",
    "            combined = torch.cat((res_emb,\n",
    "                                  bert_emb), dim=1)\n",
    "            out = self.bert_resnet_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_l\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "            res = self.res_wrap(res_emb)\n",
    "            \n",
    "            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "            \n",
    "            combined = torch.cat((res,\n",
    "                                  bert), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_luong\":\n",
    "            img, text, masks = x\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, res)\n",
    "            attn_output_VL = self.luong_attention(res, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_vaswani\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "        \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.res_wrap(res_emb)\n",
    "            res = res[:, None, :]\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "            bert = bert[:, None, :]\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV.squeeze(1),\n",
    "                                  attn_output_VL.squeeze(1)), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            \n",
    "            combined = torch.cat((bert,feat), dim=1)\n",
    "            out = self.bert_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp_luong\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"bert_mlp_vaswani\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            \n",
    "            combined = torch.cat((feat,res), dim=1)\n",
    "            out = self.resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp_luong\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(res, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"resnet_mlp_vaswani\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            \n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.resnet18(img)\n",
    "\n",
    "            \n",
    "            combined = torch.cat((bert,feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp_l\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            combined = torch.cat((bert, feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_l_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_mlp_vaswani\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            pairs = [[feat, bert],[feat,res],[bert,res]]\n",
    "        \n",
    "            results = []\n",
    "            for pair in pairs:\n",
    "                combined = self.bi_directional_att(pair)\n",
    "                results.append(combined)\n",
    "\n",
    "            comb = torch.cat(results, dim=1)\n",
    "            out = self.vaswani_3_classification(comb)\n",
    "            \n",
    "        else:\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            print(res.shape)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            attn_txt = self.OvO_concat_attention([feat, res], bert)\n",
    "            attn_img = self.OvO_concat_attention([feat, bert],res)\n",
    "            attn_tab = self.OvO_concat_attention([bert, res], feat)\n",
    "\n",
    "            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)\n",
    "            out = self.OvO_classification(comb)\n",
    "\n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "174ed459",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.11.0+cu113\n",
      "Torchvision Version:  0.12.0+cu113\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 62 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 58 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 61 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 62 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 59 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 59 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 60 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 58 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 63 %\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 61 %\n",
      "AUROC         0.562000\n",
      "accuracy     60.300000\n",
      "precision    56.513300\n",
      "recall       56.200030\n",
      "f1-score     56.265014\n",
      "dtype: float64\n",
      "AUROC        0.016231\n",
      "accuracy     1.766981\n",
      "precision    1.751285\n",
      "recall       1.623102\n",
      "f1-score     1.674605\n",
      "dtype: float64\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/ipykernel_launcher.py:120: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/ipykernel_launcher.py:121: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import sys\n",
    "import logging\n",
    "from pathlib import Path\n",
    "import random\n",
    "import tarfile\n",
    "import tempfile\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "print(\"PyTorch Version: \",torch.__version__)\n",
    "print(\"Torchvision Version: \",torchvision.__version__)\n",
    "\n",
    "from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score\n",
    "\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset\n",
    " \n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "from transformers import BertTokenizer\n",
    "from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel\n",
    "\n",
    "\n",
    "#from models import MultimodalFramework\n",
    "from model_utils import set_seed, build_optimizer, ReviewsDataset\n",
    "\n",
    " #'/gpfs/home/mgolovan/data/mgolovan/Reviews/amazon/home/test_rvw_inputs.pt'\n",
    "model_name = \"resnet\"\n",
    "lr = 0.001\n",
    "epochs = 25\n",
    "batch_size = int(32)\n",
    "#best_model_1e-06_22_adamW_20_resnet.pth\n",
    "random_seeds = [15, 0, 1, 67, 128, 87, 261, 510, 340, 22] # \n",
    "df = pd.DataFrame(columns = ['AUROC','accuracy', \"precision\", \"recall\", \"f1-score\", \"CM\", \"CR\"])\n",
    "\n",
    "for seed in random_seeds:\n",
    "    model_path = '/users/mgolovan/data/mgolovan/facebook_memes/unimodal_models/model_'+str(lr)+'_' + str(seed)+'_adamW_' + str(epochs)+'_' + str(model_name)+ '_current.pth'\n",
    "\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(device)\n",
    "\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    model = MultimodalFramework()\n",
    "\n",
    "    if model_name == \"bert\":\n",
    "        test_inputs = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_txt.pt')\n",
    "        test_dataloader = DataLoader(test_inputs, batch_size=batch_size)\n",
    "\n",
    "    elif model_name == \"resnet\":\n",
    "        test_inputs = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_img.pt')\n",
    "        test_dataloader = DataLoader(test_inputs, batch_size=batch_size)\n",
    "    else:\n",
    "        test_inputs = torch.load('/gpfs/home/mgolovan/data/mgolovan/Reviews/amazon/Electronics/test_rvw_binary_tab.pt')\n",
    "        test_dataloader = DataLoader(test_inputs, batch_size=batch_size)\n",
    "\n",
    "    model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu'))) #eager-sweep-1\n",
    "    model.to(device)\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    pred = []\n",
    "    test_labels = []\n",
    "\n",
    "    # since we're not training, we don't need to calculate the gradients for our outputs\n",
    "    with torch.no_grad():\n",
    "        for data in test_dataloader:\n",
    "            if model_name == \"bert\":\n",
    "                inputs, masks, labels = data\n",
    "                inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)\n",
    "                outputs = model([inputs, masks], model_name)\n",
    "\n",
    "            elif model_name == \"resnet\":\n",
    "                inputs, labels = data\n",
    "                inputs, labels = inputs.to(device), labels.to(device)\n",
    "                outputs = model(inputs, model_name)\n",
    "            else:\n",
    "                inputs, labels = data\n",
    "                inputs, labels = inputs.to(device), labels.to(device)\n",
    "                inputs, labels =inputs.float(), labels.long()\n",
    "                outputs = model(inputs, model_name)\n",
    "\n",
    "            test_labels.extend(np.array(labels.cpu()))\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            pred.extend(predicted.cpu().numpy())\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "\n",
    "    acc= 100 * correct // total\n",
    "    print(f'Accuracy of the bert: {100 * correct // total} %')\n",
    "\n",
    "    test_labels = np.array(test_labels)\n",
    "\n",
    "    #print(confusion_matrix(test_labels, pred))\n",
    "    cm = confusion_matrix(test_labels, pred)\n",
    "    #print(classification_report(test_labels, pred))\n",
    "    cr = classification_report(test_labels, pred, output_dict=True)\n",
    "    auc = roc_auc_score(test_labels, pred)\n",
    "    df = df.append({'AUROC': auc,'accuracy': acc, \"precision\":cr[\"macro avg\"][\"precision\"]*100 ,\n",
    "                    \"recall\":cr[\"macro avg\"][\"recall\"]*100, \"f1-score\":cr[\"macro avg\"][\"f1-score\"]*100,\n",
    "                    \"CM\":cm, \"CR\":cr}, ignore_index=True)\n",
    "\n",
    "df.to_csv(model_name + \"_current_results.csv\")\n",
    "\n",
    "\n",
    "df.to_csv(model_name + \"_current_results.csv\")\n",
    "print(df.mean())\n",
    "print(df.std())\n",
    "    \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "e06047b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = torch.rand((64,256))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "8540275e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0136, 0.0166, 0.0119,  ..., 0.0106, 0.0193, 0.0188],\n",
       "        [0.0217, 0.0102, 0.0218,  ..., 0.0102, 0.0110, 0.0099],\n",
       "        [0.0234, 0.0094, 0.0222,  ..., 0.0246, 0.0182, 0.0183],\n",
       "        ...,\n",
       "        [0.0103, 0.0122, 0.0147,  ..., 0.0236, 0.0208, 0.0157],\n",
       "        [0.0248, 0.0157, 0.0184,  ..., 0.0111, 0.0204, 0.0153],\n",
       "        [0.0109, 0.0236, 0.0130,  ..., 0.0199, 0.0097, 0.0097]])"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.softmax(i, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "ece7c845",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0035, 0.0043, 0.0030,  ..., 0.0027, 0.0050, 0.0048],\n",
       "        [0.0055, 0.0026, 0.0055,  ..., 0.0026, 0.0028, 0.0025],\n",
       "        [0.0058, 0.0023, 0.0055,  ..., 0.0061, 0.0045, 0.0046],\n",
       "        ...,\n",
       "        [0.0026, 0.0031, 0.0037,  ..., 0.0061, 0.0053, 0.0040],\n",
       "        [0.0059, 0.0038, 0.0044,  ..., 0.0027, 0.0049, 0.0037],\n",
       "        [0.0027, 0.0059, 0.0032,  ..., 0.0050, 0.0024, 0.0024]])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.softmax(i, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "3a463b78",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0035, 0.0043, 0.0030,  ..., 0.0027, 0.0050, 0.0048],\n",
       "        [0.0055, 0.0026, 0.0055,  ..., 0.0026, 0.0028, 0.0025],\n",
       "        [0.0058, 0.0023, 0.0055,  ..., 0.0061, 0.0045, 0.0046],\n",
       "        ...,\n",
       "        [0.0026, 0.0031, 0.0037,  ..., 0.0061, 0.0053, 0.0040],\n",
       "        [0.0059, 0.0038, 0.0044,  ..., 0.0027, 0.0049, 0.0037],\n",
       "        [0.0027, 0.0059, 0.0032,  ..., 0.0050, 0.0024, 0.0024]])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.softmax(i, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f3efde57",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "from torch import flatten\n",
    "\n",
    "from collections import OrderedDict\n",
    "from transformers import BertModel, DistilBertModel\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from transformers import BertModel\n",
    "\n",
    "\n",
    "class Attention(torch.nn.Module):\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__()\n",
    "        self.encoder_dim = encoder_dim\n",
    "        self.decoder_dim = decoder_dim\n",
    "\n",
    "    def forward(self, \n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "        ):\n",
    "        weights = self._get_weights(query, values) # [seq_length]\n",
    "        weights = torch.nn.functional.softmax(weights, dim=0)\n",
    "        return weights @ values  # [encoder_dim]\n",
    "\n",
    "class MultiplicativeAttention(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(\n",
    "            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,\n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "    ):\n",
    "        weights = query @ self.W @ values.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim)    \n",
    "\n",
    "\n",
    "class OneVSOthers(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        mean = sum(others) / len(others)\n",
    "        weights = mean @ self.W @ main.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim) \n",
    "    \n",
    "class OneVSOthers_concat(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        m2 = others[0]\n",
    "        m3 = others[1]\n",
    "        con = torch.cat((m2, m3), dim=1)\n",
    "        weights = con @ self.W @ main.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim) \n",
    "    \n",
    "class MultimodalFramework(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MultimodalFramework, self).__init__()\n",
    "        ##MLP\n",
    "        self.fc1 = nn.Linear(53, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 2)\n",
    "        self.relu = nn.ReLU()\n",
    "        \n",
    "        ##RESNET\n",
    "        self.resnet18 = models.resnet18(pretrained=True)\n",
    "        n_inputs = self.resnet18.fc.in_features\n",
    "\n",
    "        self.resnet18.fc = nn.Sequential(OrderedDict([\n",
    "            ('fc1', nn.Linear(n_inputs, 512))\n",
    "        ])) \n",
    "\n",
    "        self.resnet_classification = nn.Linear(512, 2) #4\n",
    "        \n",
    "        ##BERT\n",
    "        self.bert = BertModel.from_pretrained('bert-base-uncased') \n",
    "        self.bert_classification = nn.Linear(768, 2)\n",
    "        \n",
    "        #Two Modality models\n",
    "        self.bert_resnet_classification = nn.Linear(512 + 768, 2)\n",
    "        self.bert_mlp_classification = nn.Linear(256 + 768, 2)\n",
    "        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)\n",
    "        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)\n",
    "        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)\n",
    "        self.att_classification = nn.Linear(256*2, 2)\n",
    "        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)\n",
    "        self.OvO_classification = nn.Linear(3*256, 2)\n",
    "\n",
    "        self.res_wrap = nn.Linear(512, 256)\n",
    "        self.bert_wrap = nn.Linear(768, 256)\n",
    "        \n",
    "        self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)\n",
    "        self.vaswani_attention  = nn.MultiheadAttention(256, 2, batch_first = True)\n",
    "        self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)\n",
    "        self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)\n",
    "        \n",
    "    def bi_directional_att(self, pair):\n",
    "        x = pair[0]\n",
    "        y = pair[1]\n",
    "        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)\n",
    "        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)\n",
    "        combined = torch.cat((attn_output_LV,\n",
    "                              attn_output_VL), dim=1)\n",
    "        return combined\n",
    "\n",
    "    def forward(self, x, model):\n",
    "        if model == \"mlp\":\n",
    "            x = self.fc1(x)\n",
    "            x = self.relu(x)\n",
    "            x = self.fc2(x)\n",
    "            x = self.relu(x)\n",
    "            out = self.fc3(x)\n",
    "            \n",
    "        elif model == \"resnet\":\n",
    "            res = self.resnet18(x)       \n",
    "            out = self.resnet_classification(res)\n",
    "        \n",
    "        elif model == \"bert\":\n",
    "            text, masks = x\n",
    "            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]\n",
    "            out = self.bert_classification(bert)\n",
    "            \n",
    "        elif model == \"bert_resnet\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "            \n",
    "            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]\n",
    "            combined = torch.cat((res_emb,\n",
    "                                  bert_emb), dim=1)\n",
    "            out = self.bert_resnet_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_l\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "            res = self.res_wrap(res_emb)\n",
    "            \n",
    "            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "            \n",
    "            combined = torch.cat((res,\n",
    "                                  bert), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_luong\":\n",
    "            img, text, masks = x\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, res)\n",
    "            attn_output_VL = self.luong_attention(res, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_vaswani\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "        \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.res_wrap(res_emb)\n",
    "            res = res[:, None, :]\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "            bert = bert[:, None, :]\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV.squeeze(1),\n",
    "                                  attn_output_VL.squeeze(1)), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            \n",
    "            combined = torch.cat((bert,feat), dim=1)\n",
    "            out = self.bert_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp_luong\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"bert_mlp_vaswani\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            \n",
    "            combined = torch.cat((feat,res), dim=1)\n",
    "            out = self.resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp_luong\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(res, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"resnet_mlp_vaswani\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            \n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.resnet18(img)\n",
    "\n",
    "            \n",
    "            combined = torch.cat((bert,feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp_l\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            combined = torch.cat((bert, feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_l_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_mlp_vaswani\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            pairs = [[feat, bert],[feat,res],[bert,res]]\n",
    "        \n",
    "            results = []\n",
    "            for pair in pairs:\n",
    "                combined = self.bi_directional_att(pair)\n",
    "                results.append(combined)\n",
    "\n",
    "            comb = torch.cat(results, dim=1)\n",
    "            out = self.vaswani_3_classification(comb)\n",
    "            \n",
    "        else:\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            print(res.shape)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            attn_txt = self.OvO_concat_attention([feat, res], bert)\n",
    "            attn_img = self.OvO_concat_attention([feat, bert],res)\n",
    "            attn_tab = self.OvO_concat_attention([bert, res], feat)\n",
    "\n",
    "            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)\n",
    "            out = self.OvO_classification(comb)\n",
    "\n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "12928128",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.11.0+cu113\n",
      "Torchvision Version:  0.12.0+cu113\n",
      "cpu\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import sys\n",
    "import logging\n",
    "from pathlib import Path\n",
    "import random\n",
    "import tarfile\n",
    "import tempfile\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "import pprint\n",
    "print(\"PyTorch Version: \",torch.__version__)\n",
    "print(\"Torchvision Version: \",torchvision.__version__)\n",
    "from sklearn.metrics import confusion_matrix,f1_score\n",
    "\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
    "\n",
    "from collections import OrderedDict\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from transformers import BertModel\n",
    "from transformers import BertForSequenceClassification, AdamW, BertConfig\n",
    "\n",
    "\n",
    "import logging\n",
    "logging.propagate = False \n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "# WandB – Import the wandb library\n",
    "import wandb\n",
    "\n",
    "from model_utils import set_seed, build_optimizer\n",
    "\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "#train_model(model_name, dataloaders_dict, criterion, len_train, len_val, config, path)\n",
    "def train_model(model_name, dataloaders, criterion, len_train, len_val, config, path):\n",
    "    \n",
    "    set_seed(42)\n",
    "    \"\"\"\n",
    "    if model_name == \"bert_resnet\":\n",
    "        model = BertResNet()\n",
    "    elif model_name == \"bert_resnet_luong\":\n",
    "        model = BertResNetLuong() \n",
    "    else:\n",
    "        model = BertResNetVaswani()  \n",
    "    \n",
    "    #for param in model.resnet18.parameters():\n",
    "    #        param.requires_grad = False\n",
    "    \n",
    "    #for param in model.bert.parameters():\n",
    "    #        param.requires_grad = False\n",
    "    \"\"\"\n",
    "    \n",
    "    model = MultimodalFramework()\n",
    "    \n",
    "    torch.cuda.empty_cache()\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    model.to(device)\n",
    "    \n",
    "    num_epochs = 11\n",
    "    optimizer = build_optimizer(model, \"adamW\", 0.1, 0.9)\n",
    "\n",
    "    since = time.time()\n",
    "\n",
    "    val_acc_history = []\n",
    "    val_loss_history = []\n",
    "    train_acc_history = []\n",
    "    train_loss_history = []\n",
    "\n",
    "    best_acc = 0.0\n",
    "    patience = 5 \n",
    "    trigger = 0\n",
    "    acc_dict = {}\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        #scheduler.step()\n",
    "        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    "        print('-' * 10)\n",
    "\n",
    "        # Each epoch has a training and validation phase\n",
    "        for phase in ['train', 'val']:\n",
    "            if phase == 'train':\n",
    "                length = len_train\n",
    "                model.train()  # Set model to training mode\n",
    "            else:\n",
    "                length = len_val\n",
    "                model.eval()   # Set model to evaluate mode\n",
    "\n",
    "            running_loss = 0.0\n",
    "            running_corrects = 0\n",
    "            predicted_labels, ground_truth_labels = [], []\n",
    "\n",
    "            for modality1, modality2 in zip(dataloaders[phase][0], dataloaders[phase][1]):\n",
    "                \n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward\n",
    "                # track history if only in train\n",
    "                with torch.set_grad_enabled(phase == 'train'):\n",
    "                    \n",
    "                    if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "                        text_inp, masks, text_labels = modality2\n",
    "                        img_inp, labels = modality1\n",
    "\n",
    "                        text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)\n",
    "                        img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "                        \n",
    "                        inp_len = text_inp.size(0)\n",
    "                        outputs = model([img_inp, text_inp, masks], model_name)\n",
    "\n",
    "                    elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "                        img_inp, labels = modality1\n",
    "                        tab_inp, tab_labels = modality2\n",
    "                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                        img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "                        \n",
    "                        inp_len = tab_inp.size(0)\n",
    "                        outputs = model([tab_inp, img_inp], model_name)\n",
    "                    else:\n",
    "                        tab_inp, tab_labels = modality1\n",
    "                        text_inp, masks, labels = modality2\n",
    "                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                        text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)\n",
    "                        \n",
    "                        inp_len = tab_inp.size(0)\n",
    "                        outputs = model([tab_inp, text_inp, masks], model_name)\n",
    "                    \n",
    "                    loss = criterion(outputs, labels)\n",
    "\n",
    "                    _, preds = torch.max(outputs, 1)\n",
    "\n",
    "                    # backward + optimize only if in training phase\n",
    "                    if phase == 'train':\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                # statistics\n",
    "                #print(\"text_inp.size(0)\")\n",
    "                #print(text_inp.size(0))\n",
    "\n",
    "                running_loss += loss.item() * labels.size(0)\n",
    "                running_corrects += torch.sum(preds == labels.data)\n",
    "                predicted_labels.extend(preds.cpu().detach().numpy())\n",
    "                ground_truth_labels.extend(labels.cpu().detach().numpy())\n",
    "                \n",
    "            epoch_loss = running_loss / length\n",
    "            epoch_acc = running_corrects.double() / length\n",
    "            #epoch_f1 = f1.double() / len(dataloaders[phase].dataset)\n",
    "            epoch_f1 = f1_score(ground_truth_labels, predicted_labels)\n",
    "\n",
    "            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))\n",
    "\n",
    "            if phase == 'val':\n",
    "                #wandb.log({\"val_loss\": epoch_loss, \"val_acc\": epoch_acc, \"val_f1\": epoch_f1})\n",
    "                acc_dict[epoch] = float(epoch_acc.detach().cpu())\n",
    "                val_acc_history.append(epoch_acc.detach().cpu())\n",
    "                val_loss_history.append(epoch_loss)\n",
    "                torch.save(model.state_dict(), path+\"_current.pth\")\n",
    "                if epoch_acc > best_acc:\n",
    "                    best_acc = epoch_acc\n",
    "                    #best_model_wts = copy.deepcopy(model.state_dict())\n",
    "                    #torch.save(model.state_dict(), path+\"_best.pth\")\n",
    "                #\"\"\"\n",
    "                if (epoch > 10) and (acc_dict[epoch] <= acc_dict[epoch - 10]):\n",
    "                    trigger +=1\n",
    "                    if trigger >= patience:\n",
    "                        return model, {\"train_acc\":train_acc_history, \"val_acc\":val_acc_history,\"train_loss\":train_loss_history, \"val_loss\":val_loss_history}\n",
    "                else:\n",
    "                    trigger = 0\n",
    "                #\"\"\"    \n",
    "            if phase == 'train':\n",
    "                wandb.log({\"train_loss\": epoch_loss, \"train_acc\": epoch_acc,\"train_f1\": epoch_f1, \"epoch\": epoch})\n",
    "                train_acc_history.append(epoch_acc.detach().cpu())\n",
    "                train_loss_history.append(epoch_loss)\n",
    "\n",
    "\n",
    "    time_elapsed = time.time() - since\n",
    "    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n",
    "    print('Best val Acc: {:4f}'.format(best_acc))\n",
    "\n",
    "    # load best model weights\n",
    "    #model.load_state_dict(best_model_wts)\n",
    "    return model, {\"train_acc\":train_acc_history, \"val_acc\":val_acc_history,\"train_loss\":train_loss_history, \"val_loss\":val_loss_history}\n",
    "\n",
    " \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e967b16f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random seed set as 42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/10\n",
      "----------\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "not enough values to unpack (expected 4, got 3)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_14823/81535617.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[0mpath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0mtrain_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloaders_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"config\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_14823/1316478374.py\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model_name, dataloaders, criterion, len_train, len_val, config, path)\u001b[0m\n\u001b[1;32m    129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    130\u001b[0m                         \u001b[0minp_len\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtext_inp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m                         \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mimg_inp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtext_inp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m                     \u001b[0;32melif\u001b[0m \u001b[0mmodel_name\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"_\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m\"resnet\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"mlp\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1108\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1111\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_14823/2190398549.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, model)\u001b[0m\n\u001b[1;32m    343\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    344\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 345\u001b[0;31m             \u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    346\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    347\u001b[0m             \u001b[0mfeat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: not enough values to unpack (expected 4, got 3)"
     ]
    }
   ],
   "source": [
    "\n",
    "model_name = \"bert_resnet_ours_concat\"\n",
    "\n",
    "train_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_txt.pt')\n",
    "val_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')\n",
    "\n",
    "train_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_img.pt')\n",
    "val_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt')\n",
    "\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "train_dataloader_text = DataLoader(train_inputs_txt, batch_size=4,shuffle=False)\n",
    "val_dataloader_text = DataLoader(val_inputs_txt, batch_size=4, shuffle=False)\n",
    "\n",
    "train_dataloader_img = DataLoader(train_inputs_img, batch_size=4,shuffle=False)\n",
    "val_dataloader_img = DataLoader(val_inputs_img, batch_size=4, shuffle=False)\n",
    "\n",
    "len_val = len(val_inputs_txt)\n",
    "len_train = len(train_inputs_txt)\n",
    "\n",
    "dataloaders_dict = {'train':[train_dataloader_img, train_dataloader_text], 'val':[val_dataloader_img, val_dataloader_text]}\n",
    "\n",
    "\n",
    "path = '/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_' \n",
    "\n",
    "train_model(model_name, dataloaders_dict, criterion, len_train, len_val, \"config\", path)  \n",
    "\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f3f7134a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.11.0+cu113\n",
      "Torchvision Version:  0.12.0+cu113\n",
      "cpu\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import sys\n",
    "import logging\n",
    "from pathlib import Path\n",
    "import random\n",
    "import tarfile\n",
    "import tempfile\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "import pprint\n",
    "print(\"PyTorch Version: \",torch.__version__)\n",
    "print(\"Torchvision Version: \",torchvision.__version__)\n",
    "from sklearn.metrics import confusion_matrix,f1_score\n",
    "\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
    "\n",
    "from collections import OrderedDict\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from transformers import BertModel\n",
    "from transformers import BertForSequenceClassification, AdamW, BertConfig\n",
    "\n",
    "\n",
    "import logging\n",
    "logging.propagate = False \n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "# WandB – Import the wandb library\n",
    "import wandb\n",
    "\n",
    "#from models import MultimodalFramework\n",
    "from model_utils import set_seed, build_optimizer\n",
    "\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "#train_model(model_name, dataloaders_dict, criterion, len_train, len_val, config, path)\n",
    "def train_model(model_name, dataloaders, criterion, len_train, len_val, config, path):\n",
    "    \n",
    "    set_seed(42)\n",
    "\n",
    "    model = MultimodalFramework()\n",
    "    torch.cuda.empty_cache()\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    model.to(device)\n",
    "    \n",
    "    num_epochs = 1\n",
    "    optimizer = build_optimizer(model, \"adamW\", 0.1, 0.9)\n",
    "\n",
    "    since = time.time()\n",
    "\n",
    "    val_acc_history = []\n",
    "    val_loss_history = []\n",
    "    train_acc_history = []\n",
    "    train_loss_history = []\n",
    "\n",
    "    best_acc = 0.0\n",
    "    patience = 10 \n",
    "    trigger = 0\n",
    "    acc_dict = {}\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        #scheduler.step()\n",
    "        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    "        print('-' * 10)\n",
    "\n",
    "        # Each epoch has a training and validation phase\n",
    "        for phase in ['train', 'val']:\n",
    "            if phase == 'train':\n",
    "                length = len_train\n",
    "                model.train()  # Set model to training mode\n",
    "            else:\n",
    "                length = len_val\n",
    "                model.eval()   # Set model to evaluate mode\n",
    "\n",
    "            running_loss = 0.0\n",
    "            running_corrects = 0\n",
    "            predicted_labels, ground_truth_labels = [], []\n",
    "\n",
    "            for modality1, modality2 in zip(dataloaders[phase][0], dataloaders[phase][1]):\n",
    "                \n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward\n",
    "                # track history if only in train\n",
    "                with torch.set_grad_enabled(phase == 'train'):\n",
    "                    \n",
    "                    if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "                        text_inp, masks, text_labels = modality2\n",
    "                        img_inp, labels = modality1\n",
    "\n",
    "                        text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)\n",
    "                        img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "                        \n",
    "                        inp_len = text_inp.size(0)\n",
    "                        outputs = model([img_inp, text_inp, masks], model_name)\n",
    "\n",
    "                    elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "                        img_inp, labels = modality1\n",
    "                        tab_inp, tab_labels = modality2\n",
    "                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                        img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "                        \n",
    "                        inp_len = tab_inp.size(0)\n",
    "                        outputs = model([tab_inp, img_inp], model_name)\n",
    "                    else:\n",
    "                        tab_inp, tab_labels = modality1\n",
    "                        text_inp, masks, labels = modality2\n",
    "                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                        text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)\n",
    "                        \n",
    "                        inp_len = tab_inp.size(0)\n",
    "                        outputs = model([tab_inp, text_inp, masks], model_name)\n",
    "                    \n",
    "                    loss = criterion(outputs, labels)\n",
    "\n",
    "                    _, preds = torch.max(outputs, 1)\n",
    "\n",
    "                    # backward + optimize only if in training phase\n",
    "                    if phase == 'train':\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                # statistics\n",
    "                #print(\"text_inp.size(0)\")\n",
    "                #print(text_inp.size(0))\n",
    "\n",
    "                running_loss += loss.item() * labels.size(0)\n",
    "                running_corrects += torch.sum(preds == labels.data)\n",
    "                predicted_labels.extend(preds.cpu().detach().numpy())\n",
    "                ground_truth_labels.extend(labels.cpu().detach().numpy())\n",
    "                \n",
    "            epoch_loss = running_loss / length\n",
    "            epoch_acc = running_corrects.double() / length\n",
    "            #epoch_f1 = f1.double() / len(dataloaders[phase].dataset)\n",
    "            epoch_f1 = f1_score(ground_truth_labels, predicted_labels)\n",
    "\n",
    "            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))\n",
    "\n",
    "            if phase == 'val':\n",
    "                wandb.log({\"val_loss\": epoch_loss, \"val_acc\": epoch_acc, \"val_f1\": epoch_f1})\n",
    "                acc_dict[epoch] = float(epoch_acc.detach().cpu())\n",
    "                val_acc_history.append(epoch_acc.detach().cpu())\n",
    "                val_loss_history.append(epoch_loss)\n",
    "                torch.save(model.state_dict(), path+\"_current.pth\")\n",
    "                if epoch_acc > best_acc:\n",
    "                    best_acc = epoch_acc\n",
    "                    #best_model_wts = copy.deepcopy(model.state_dict())\n",
    "                    #torch.save(model.state_dict(), path+\"_best.pth\")\n",
    "                #\"\"\"\n",
    "                if (epoch > 10) and (acc_dict[epoch] <= acc_dict[epoch - 10]):\n",
    "                    trigger +=1\n",
    "                    if trigger >= patience:\n",
    "                        return model, {\"train_acc\":train_acc_history, \"val_acc\":val_acc_history,\"train_loss\":train_loss_history, \"val_loss\":val_loss_history}\n",
    "                else:\n",
    "                    trigger = 0\n",
    "                #\"\"\"    \n",
    "            if phase == 'train':\n",
    "                wandb.log({\"train_loss\": epoch_loss, \"train_acc\": epoch_acc,\"train_f1\": epoch_f1, \"epoch\": epoch})\n",
    "                train_acc_history.append(epoch_acc.detach().cpu())\n",
    "                train_loss_history.append(epoch_loss)\n",
    "\n",
    "\n",
    "    time_elapsed = time.time() - since\n",
    "    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n",
    "    print('Best val Acc: {:4f}'.format(best_acc))\n",
    "\n",
    "    # load best model weights\n",
    "    #model.load_state_dict(best_model_wts)\n",
    "    return model, {\"train_acc\":train_acc_history, \"val_acc\":val_acc_history,\"train_loss\":train_loss_history, \"val_loss\":val_loss_history}\n",
    "\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "be2dac40",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "from torch import flatten\n",
    "\n",
    "from collections import OrderedDict\n",
    "from transformers import BertModel, DistilBertModel\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from MHA_modified import MultiheadAttention\n",
    "\n",
    "class Attention(torch.nn.Module):\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__()\n",
    "        self.encoder_dim = encoder_dim\n",
    "        self.decoder_dim = decoder_dim\n",
    "\n",
    "    def forward(self, \n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "        ):\n",
    "        weights = self._get_weights(query, values) # [seq_length]\n",
    "        weights = torch.nn.functional.softmax(weights, dim=0)\n",
    "        return weights @ values  # [encoder_dim]\n",
    "\n",
    "class MultiplicativeAttention(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(\n",
    "            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')\n",
    "\n",
    "    def _get_weights(self,\n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "    ):\n",
    "        print(query.shape)\n",
    "        print(self.W.shape)\n",
    "        print(values.shape)\n",
    "        print(values.transpose(2,3).shape)\n",
    "        weights = query.transpose(2, 3) @ self.W @ values.transpose(2, 3) #.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim)    \n",
    "class OneVSOthers(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        mean = sum(others) / len(others)\n",
    "        weights = mean @ self.W @ main.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim) \n",
    "    \n",
    "class OneVSOthers_concat(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        m2 = others[0]\n",
    "        m3 = others[1]\n",
    "        con = torch.cat((m2, m3), dim=1)\n",
    "        weights = con @ self.W @ main.T  # [seq_length]\n",
    "        return weights/np.sqrt(self.decoder_dim)\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "\n",
    "    def __init__(self,\n",
    "                 in_features,\n",
    "                 head_num, typ,\n",
    "                 bias=True,\n",
    "                 activation=F.relu):\n",
    "        \n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "        if in_features % head_num != 0:\n",
    "            raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))\n",
    "        self.in_features = in_features\n",
    "        self.type = typ\n",
    "        self.head_num = head_num\n",
    "        self.activation = activation\n",
    "        self.bias = bias\n",
    "        self.linear_q = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_k = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_v = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_o = nn.Linear(in_features, in_features, bias)\n",
    "\n",
    "    def forward(self, q, k, v, mask=None): \n",
    "        #q = self.linear_q(q)\n",
    "        #k = self.linear_k(k)\n",
    "        #v = self.linear_v(v)\n",
    "        \n",
    "        dim = int(self.in_features / self.head_num)\n",
    "        #y = ScaledDotProductAttention()(q, k, v, mask)\n",
    "        q = self._reshape_to_batches(q)\n",
    "        k = self._reshape_to_batches(k)\n",
    "        v = self._reshape_to_batches(v)\n",
    "        if self.type == \"OvO\":\n",
    "            att = OneVSOthers(dim, dim)\n",
    "            y = att([q, k], v) #.cuda()\n",
    "        else:\n",
    "            att = MultiplicativeAttention(dim, dim)\n",
    "            y = att(q, v) #.cuda()\n",
    "        y = self._reshape_from_batches(y)\n",
    "        y = self.linear_o(y)\n",
    "        #if self.activation is not None:\n",
    "        #    y = self.activation(y)\n",
    "        return y\n",
    "\n",
    "    \"\"\"\n",
    "    def _reshape_to_batches(self, x):\n",
    "        seq_len = 1\n",
    "        batch_size, in_feature = x.size()\n",
    "        sub_dim = in_feature // self.head_num\n",
    "        return x.reshape(batch_size, self.head_num, sub_dim)\\\n",
    "                .permute(0, 1, 2)\\\n",
    "                .reshape(batch_size * self.head_num, sub_dim)\n",
    "\n",
    "    def _reshape_from_batches(self, x):\n",
    "        seq_len = 1\n",
    "        batch_size, in_feature = x.size()\n",
    "        batch_size //= self.head_num\n",
    "        out_dim = in_feature * self.head_num\n",
    "        return x.reshape(batch_size, self.head_num,  in_feature)\\\n",
    "                .permute(0, 1, 2)\\\n",
    "                .reshape(batch_size,  out_dim)\n",
    "    \"\"\"\n",
    "\n",
    "    def _reshape_to_batches(self, x):\n",
    "        x = x.unsqueeze(1)\n",
    "        batch_size, seq_len, in_feature = x.size()\n",
    "        sub_dim = in_feature // self.head_num\n",
    "        return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\\\n",
    "                .permute(0, 2, 1, 3)\\\n",
    "                .reshape(batch_size * self.head_num, seq_len, sub_dim)\n",
    "\n",
    "    def _reshape_from_batches(self, x):\n",
    "        x = x.unsqueeze(1)\n",
    "        batch_size, seq_len, in_feature = x.size()\n",
    "        batch_size //= self.head_num\n",
    "        out_dim = in_feature * self.head_num\n",
    "        return x.reshape(batch_size, self.head_num, seq_len, in_feature)\\\n",
    "                .permute(0, 2, 1, 3)\\\n",
    "                .reshape(batch_size, seq_len, out_dim)\n",
    "    \n",
    "    def extra_repr(self):\n",
    "        return 'in_features={}, head_num={}, bias={}, activation={}'.format(\n",
    "            self.in_features, self.head_num, self.bias, self.activation)\n",
    "#\"\"\"\n",
    "\n",
    "class MultimodalFramework(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MultimodalFramework, self).__init__()\n",
    "        ##MLP\n",
    "        self.fc1 = nn.Linear(53, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 2)\n",
    "        self.relu = nn.ReLU()\n",
    "        \n",
    "        ##RESNET\n",
    "        self.resnet18 = models.resnet18(pretrained=True)\n",
    "        n_inputs = self.resnet18.fc.in_features\n",
    "\n",
    "        self.resnet18.fc = nn.Sequential(OrderedDict([\n",
    "            ('fc1', nn.Linear(n_inputs, 512))\n",
    "        ])) \n",
    "\n",
    "        self.resnet_classification = nn.Linear(512, 2) #4\n",
    "        \n",
    "        ##BERT\n",
    "        self.bert = BertModel.from_pretrained('bert-base-uncased') \n",
    "        self.bert_classification = nn.Linear(768, 2)\n",
    "        \n",
    "        #Two Modality models\n",
    "        self.bert_resnet_classification = nn.Linear(512 + 768, 2)\n",
    "        self.bert_mlp_classification = nn.Linear(256 + 768, 2)\n",
    "        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)\n",
    "        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)\n",
    "        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)\n",
    "        self.att_classification = nn.Linear(256*2, 2)\n",
    "        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)\n",
    "        self.OvO_classification = nn.Linear(3*256, 2)\n",
    "\n",
    "        self.res_wrap = nn.Linear(512, 256)\n",
    "        self.bert_wrap = nn.Linear(768, 256)\n",
    "        \n",
    "        self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)\n",
    "        self.vaswani_attention  = nn.MultiheadAttention(256, 8, batch_first = True)\n",
    "        self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)\n",
    "        self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)\n",
    "        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = \"OvO\")\n",
    "        self.luong_multihead_attention = MultiHeadAttention(256,2, typ = \"luong\")\n",
    "        \n",
    "    def bi_directional_att(self, pair):\n",
    "        x = pair[0]\n",
    "        y = pair[1]\n",
    "        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)\n",
    "        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)\n",
    "        combined = torch.cat((attn_output_LV,\n",
    "                              attn_output_VL), dim=1)\n",
    "        return combined\n",
    "\n",
    "    def forward(self, x, model):\n",
    "        if model == \"mlp\":\n",
    "            x = self.fc1(x)\n",
    "            x = self.relu(x)\n",
    "            x = self.fc2(x)\n",
    "            x = self.relu(x)\n",
    "            out = self.fc3(x)\n",
    "            \n",
    "        elif model == \"resnet\":\n",
    "            res = self.resnet18(x)       \n",
    "            out = self.resnet_classification(res)\n",
    "        \n",
    "        elif model == \"bert\":\n",
    "            text, masks = x\n",
    "            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]\n",
    "            out = self.bert_classification(bert)\n",
    "            \n",
    "        elif model == \"bert_resnet\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "            \n",
    "            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]\n",
    "            combined = torch.cat((res_emb,\n",
    "                                  bert_emb), dim=1)\n",
    "            out = self.bert_resnet_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_luong\":\n",
    "            img, text, masks = x\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "\n",
    "            attn_output_LV = self.luong_multihead_attention(bert, res, res)\n",
    "            attn_output_VL = self.luong_multihead_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_vaswani\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "        \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.res_wrap(res_emb)\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV.,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            \n",
    "            combined = torch.cat((bert,feat), dim=1)\n",
    "            out = self.bert_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp_luong\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"bert_mlp_vaswani\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            \n",
    "            combined = torch.cat((feat,res), dim=1)\n",
    "            out = self.resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp_luong\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(res, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"resnet_mlp_vaswani\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            \n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.resnet18(img)\n",
    "\n",
    "            \n",
    "            combined = torch.cat((bert,feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp_l\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            combined = torch.cat((bert, feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_l_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_mlp_vaswani\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            pairs = [[feat, bert],[feat,res],[bert,res]]\n",
    "        \n",
    "            results = []\n",
    "            for pair in pairs:\n",
    "                combined = self.bi_directional_att(pair)\n",
    "                results.append(combined)\n",
    "\n",
    "            comb = torch.cat(results, dim=1)\n",
    "            out = self.vaswani_3_classification(comb)\n",
    "            \n",
    "        else:\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert\n",
    "            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)\n",
    "            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)\n",
    "\n",
    "            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)\n",
    "            out = self.OvO_classification(comb)\n",
    "\n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6bbd2a29",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random seed set as 42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/0\n",
      "----------\n",
      "torch.Size([8, 1, 128])\n",
      "torch.Size([128, 128])\n"
     ]
    },
    {
     "ename": "IndexError",
     "evalue": "Dimension out of range (expected to be in range of [-3, 2], but got 3)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_250771/1031777710.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[0mpath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mtrain_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloaders_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"config\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_250771/1702783436.py\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model_name, dataloaders, criterion, len_train, len_val, config, path)\u001b[0m\n\u001b[1;32m    115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m                         \u001b[0minp_len\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtext_inp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m                         \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mimg_inp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtext_inp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmasks\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    118\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    119\u001b[0m                     \u001b[0;32melif\u001b[0m \u001b[0mmodel_name\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"_\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m\"resnet\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"mlp\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1108\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1111\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_250771/599584889.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, model)\u001b[0m\n\u001b[1;32m    248\u001b[0m             \u001b[0mbert\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbert_wrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    249\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 250\u001b[0;31m             \u001b[0mattn_output_LV\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mluong_multihead_attention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbert\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    251\u001b[0m             \u001b[0mattn_output_VL\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mluong_multihead_attention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbert\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    252\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1108\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1111\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_250771/599584889.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, q, k, v, mask)\u001b[0m\n\u001b[1;32m    112\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    113\u001b[0m             \u001b[0matt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMultiplicativeAttention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m             \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0matt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#.cuda()\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    115\u001b[0m         \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reshape_from_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m         \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear_o\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1108\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1111\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1112\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_250771/599584889.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, query, values)\u001b[0m\n\u001b[1;32m     32\u001b[0m         \u001b[0mvalues\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# [seq_length, encoder_dim]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m         ):\n\u001b[0;32m---> 34\u001b[0;31m         \u001b[0mweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# [seq_length]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     35\u001b[0m         \u001b[0mweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunctional\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mweights\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mvalues\u001b[0m  \u001b[0;31m# [encoder_dim]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_250771/599584889.py\u001b[0m in \u001b[0;36m_get_weights\u001b[0;34m(self, query, values)\u001b[0m\n\u001b[1;32m     49\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     50\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mW\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     52\u001b[0m         \u001b[0mweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquery\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mW\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#.T  # [seq_length]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     53\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mweights\u001b[0m \u001b[0;31m#/np.sqrt(self.decoder_dim)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mIndexError\u001b[0m: Dimension out of range (expected to be in range of [-3, 2], but got 3)"
     ]
    }
   ],
   "source": [
    "model_name = \"bert_resnet_luong\"\n",
    "train_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_txt.pt')\n",
    "val_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')\n",
    "\n",
    "train_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_img.pt')\n",
    "val_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt')\n",
    "\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "\n",
    "train_dataloader_text = DataLoader(train_inputs_txt, batch_size=4,shuffle=False)\n",
    "val_dataloader_text = DataLoader(val_inputs_txt, batch_size=4, shuffle=False)\n",
    "\n",
    "train_dataloader_img = DataLoader(train_inputs_img, batch_size=4,shuffle=False)\n",
    "val_dataloader_img = DataLoader(val_inputs_img, batch_size=4, shuffle=False)\n",
    "\n",
    "len_val = len(val_inputs_txt)\n",
    "len_train = len(train_inputs_txt)\n",
    "\n",
    "dataloaders_dict = {'train':[train_dataloader_img, train_dataloader_text], 'val':[val_dataloader_img, val_dataloader_text]}\n",
    "\n",
    "\n",
    "path = '/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_' \n",
    "train_model(model_name, dataloaders_dict, criterion, len_train, len_val, \"config\", path)  \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "c49df8e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "from torch import flatten\n",
    "\n",
    "from collections import OrderedDict\n",
    "from transformers import BertModel, DistilBertModel\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from MHA_modified import MultiheadAttention\n",
    "\n",
    "class Attention(torch.nn.Module):\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__()\n",
    "        self.encoder_dim = encoder_dim\n",
    "        self.decoder_dim = decoder_dim\n",
    "\n",
    "    def forward(self, \n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "        ):\n",
    "        weights = self._get_weights(query, values) # [seq_length]\n",
    "        weights = torch.nn.functional.softmax(weights, dim=0)\n",
    "        return weights @ values  # [encoder_dim]\n",
    "\n",
    "class MultiplicativeAttention(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(\n",
    "            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')\n",
    "\n",
    "    def _get_weights(self,\n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "    ):\n",
    "        weights = query @ self.W @ values.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim)    \n",
    "\n",
    "\n",
    "\n",
    "class OneVSOthers(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        mean = sum(others) / len(others)\n",
    "        weights = mean @ self.W @ main.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim) \n",
    "    \n",
    "class OneVSOthers_concat(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        m2 = others[0]\n",
    "        m3 = others[1]\n",
    "        con = torch.cat((m2, m3), dim=1)\n",
    "        weights = con @ self.W @ main.T  # [seq_length]\n",
    "        return weights/np.sqrt(self.decoder_dim)\n",
    "#\"\"\"\n",
    "class MultiHeadAttention(nn.Module):\n",
    "\n",
    "    def __init__(self,\n",
    "                 in_features,\n",
    "                 head_num, typ,\n",
    "                 bias=True,\n",
    "                 activation=F.relu):\n",
    "        \n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "        if in_features % head_num != 0:\n",
    "            raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))\n",
    "        self.in_features = in_features\n",
    "        self.type = typ\n",
    "        self.head_num = head_num\n",
    "        self.activation = activation\n",
    "        self.bias = bias\n",
    "        self.linear_q = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_k = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_v = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_o = nn.Linear(in_features, in_features, bias)\n",
    "\n",
    "    def forward(self, q, k, v, mask=None): \n",
    "        #q = self.linear_q(q)\n",
    "        #k = self.linear_k(k)\n",
    "        #v = self.linear_v(v)\n",
    "        \n",
    "        dim = int(self.in_features / self.head_num)\n",
    "        #y = ScaledDotProductAttention()(q, k, v, mask)\n",
    "        q = self._reshape_to_batches(q)\n",
    "        k = self._reshape_to_batches(k)\n",
    "        v = self._reshape_to_batches(v)\n",
    "        if self.type == \"OvO\":\n",
    "            att = OneVSOthers(dim, dim)\n",
    "            y = att([q, k], v) #.cuda()\n",
    "        else:\n",
    "            att = MultiplicativeAttention(dim, dim)\n",
    "            y = att(q, v) #.cuda()\n",
    "        y = self._reshape_from_batches(y)\n",
    "        y = self.linear_o(y)\n",
    "        #if self.activation is not None:\n",
    "        #    y = self.activation(y)\n",
    "        return y\n",
    "\n",
    "    \n",
    "    def _reshape_to_batches(self, x):\n",
    "        seq_len = 1\n",
    "        batch_size, in_feature = x.size()\n",
    "        sub_dim = in_feature // self.head_num\n",
    "        return x.reshape(batch_size, self.head_num, sub_dim)\\\n",
    "                .permute(0, 1, 2)\\\n",
    "                .reshape(batch_size * self.head_num, sub_dim)\n",
    "\n",
    "    def _reshape_from_batches(self, x):\n",
    "        seq_len = 1\n",
    "        batch_size, in_feature = x.size()\n",
    "        batch_size //= self.head_num\n",
    "        out_dim = in_feature * self.head_num\n",
    "        return x.reshape(batch_size, self.head_num,  in_feature)\\\n",
    "                .permute(0, 1, 2)\\\n",
    "                .reshape(batch_size,  out_dim)\n",
    "    \"\"\"\n",
    "\n",
    "    def _reshape_to_batches(self, x):\n",
    "        x = x.unsqueeze(1)\n",
    "        batch_size, seq_len, in_feature = x.size()\n",
    "        sub_dim = in_feature // self.head_num\n",
    "        return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\\\n",
    "                .permute(0, 2, 1, 3)\\\n",
    "                .reshape(batch_size * self.head_num, seq_len, sub_dim)\n",
    "\n",
    "    def _reshape_from_batches(self, x):\n",
    "        x = x.unsqueeze(1)\n",
    "        batch_size, seq_len, in_feature = x.size()\n",
    "        batch_size //= self.head_num\n",
    "        out_dim = in_feature * self.head_num\n",
    "        return x.reshape(batch_size, self.head_num, seq_len, in_feature)\\\n",
    "                .permute(0, 2, 1, 3)\\\n",
    "                .reshape(batch_size, seq_len, out_dim)\n",
    "    \n",
    "    def extra_repr(self):\n",
    "        return 'in_features={}, head_num={}, bias={}, activation={}'.format(\n",
    "            self.in_features, self.head_num, self.bias, self.activation)\n",
    "    \"\"\"\n",
    "\n",
    "class MultimodalFramework(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MultimodalFramework, self).__init__()\n",
    "        ##MLP\n",
    "        self.fc1 = nn.Linear(53, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 2)\n",
    "        self.relu = nn.ReLU()\n",
    "        \n",
    "        ##RESNET\n",
    "        self.resnet18 = models.resnet18(pretrained=True)\n",
    "        n_inputs = self.resnet18.fc.in_features\n",
    "\n",
    "        self.resnet18.fc = nn.Sequential(OrderedDict([\n",
    "            ('fc1', nn.Linear(n_inputs, 512))\n",
    "        ])) \n",
    "\n",
    "        self.resnet_classification = nn.Linear(512, 2) #4\n",
    "        \n",
    "        ##BERT\n",
    "        self.bert = BertModel.from_pretrained('bert-base-uncased') \n",
    "        self.bert_classification = nn.Linear(768, 2)\n",
    "        \n",
    "        #Two Modality models\n",
    "        self.bert_resnet_classification = nn.Linear(512 + 768, 2)\n",
    "        self.bert_mlp_classification = nn.Linear(256 + 768, 2)\n",
    "        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)\n",
    "        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)\n",
    "        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)\n",
    "        self.att_classification = nn.Linear(256*2, 2)\n",
    "        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)\n",
    "        self.OvO_classification = nn.Linear(3*256, 2)\n",
    "\n",
    "        self.res_wrap = nn.Linear(512, 256)\n",
    "        self.bert_wrap = nn.Linear(768, 256)\n",
    "        \n",
    "        #self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)\n",
    "        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)\n",
    "        #self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)\n",
    "        self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)\n",
    "        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = \"OvO\")\n",
    "        self.luong_multihead_attention = MultiHeadAttention(256,2, typ = \"luong\")\n",
    "        \n",
    "    def bi_directional_att(self, pair):\n",
    "        x = pair[0]\n",
    "        y = pair[1]\n",
    "        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)\n",
    "        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)\n",
    "        combined = torch.cat((attn_output_LV,\n",
    "                              attn_output_VL), dim=1)\n",
    "        return combined\n",
    "\n",
    "    def forward(self, x, model):\n",
    "        if model == \"mlp\":\n",
    "            x = self.fc1(x)\n",
    "            x = self.relu(x)\n",
    "            x = self.fc2(x)\n",
    "            x = self.relu(x)\n",
    "            out = self.fc3(x)\n",
    "            \n",
    "        elif model == \"resnet\":\n",
    "            res = self.resnet18(x)       \n",
    "            out = self.resnet_classification(res)\n",
    "        \n",
    "        elif model == \"bert\":\n",
    "            text, masks = x\n",
    "            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]\n",
    "            out = self.bert_classification(bert)\n",
    "            \n",
    "        elif model == \"bert_resnet\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "            \n",
    "            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]\n",
    "            combined = torch.cat((res_emb,\n",
    "                                  bert_emb), dim=1)\n",
    "            out = self.bert_resnet_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_luong\":\n",
    "            img, text, masks = x\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "\n",
    "            attn_output_LV = self.luong_multihead_attention(bert, res, res)\n",
    "            attn_output_VL = self.luong_multihead_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_vaswani\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "        \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.res_wrap(res_emb)\n",
    "            res = res[:, None, :]\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "            bert = bert[:, None, :]\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV.squeeze(1),\n",
    "                                  attn_output_VL.squeeze(1)), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            \n",
    "            combined = torch.cat((bert,feat), dim=1)\n",
    "            out = self.bert_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp_luong\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"bert_mlp_vaswani\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            \n",
    "            combined = torch.cat((feat,res), dim=1)\n",
    "            out = self.resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp_luong\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(res, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"resnet_mlp_vaswani\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            \n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.resnet18(img)\n",
    "\n",
    "            \n",
    "            combined = torch.cat((bert,feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp_l\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            combined = torch.cat((bert, feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_l_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_mlp_vaswani\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            pairs = [[feat, bert],[feat,res],[bert,res]]\n",
    "        \n",
    "            results = []\n",
    "            for pair in pairs:\n",
    "                combined = self.bi_directional_att(pair)\n",
    "                results.append(combined)\n",
    "\n",
    "            comb = torch.cat(results, dim=1)\n",
    "            out = self.vaswani_3_classification(comb)\n",
    "            \n",
    "        else:\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert\n",
    "            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)\n",
    "            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)\n",
    "\n",
    "            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)\n",
    "            out = self.OvO_classification(comb)\n",
    "\n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "id": "dee9bfda",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "from torch import flatten\n",
    "\n",
    "from collections import OrderedDict\n",
    "from transformers import BertModel, DistilBertModel\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from MHA_modified import MultiheadAttention\n",
    "\n",
    "class Attention(torch.nn.Module):\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__()\n",
    "        self.encoder_dim = encoder_dim\n",
    "        self.decoder_dim = decoder_dim\n",
    "\n",
    "    def forward(self, \n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "        ):\n",
    "        weights = self._get_weights(query, values) # [seq_length]\n",
    "        weights = torch.nn.functional.softmax(weights, dim=0)\n",
    "        return weights @ values  # [encoder_dim]\n",
    "\n",
    "class MultiplicativeAttention(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(\n",
    "            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')\n",
    "\n",
    "    def _get_weights(self,\n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "    ):\n",
    "        weights = query @ self.W @ values.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim)    \n",
    "\n",
    "\n",
    "\n",
    "class OneVSOthers(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)).to('cpu')\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        mean = sum(others) / len(others)\n",
    "        weights = mean @ self.W @ main.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim) \n",
    "    \n",
    "class OneVSOthers_concat(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        m2 = others[0]\n",
    "        m3 = others[1]\n",
    "        con = torch.cat((m2, m3), dim=1)\n",
    "        weights = con @ self.W @ main.T  # [seq_length]\n",
    "        return weights/np.sqrt(self.decoder_dim)\n",
    "#\"\"\"\n",
    "class MultiHeadAttention(nn.Module):\n",
    "\n",
    "    def __init__(self,\n",
    "                 in_features,\n",
    "                 head_num, typ,\n",
    "                 bias=True,\n",
    "                 activation=F.relu):\n",
    "        \n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "        if in_features % head_num != 0:\n",
    "            raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))\n",
    "        self.in_features = in_features\n",
    "        self.type = typ\n",
    "        self.head_num = head_num\n",
    "        self.activation = activation\n",
    "        self.bias = bias\n",
    "        self.linear_q = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_k = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_v = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_o = nn.Linear(in_features, in_features, bias)\n",
    "\n",
    "    def forward(self, q, k, v, mask=None): \n",
    "        #q = self.linear_q(q)\n",
    "        #k = self.linear_k(k)\n",
    "        #v = self.linear_v(v)\n",
    "        \n",
    "        dim = int(self.in_features / self.head_num)\n",
    "        #y = ScaledDotProductAttention()(q, k, v, mask)\n",
    "        q = self._reshape_to_batches(q)\n",
    "        k = self._reshape_to_batches(k)\n",
    "        v = self._reshape_to_batches(v)\n",
    "        if self.type == \"OvO\":\n",
    "            att = OneVSOthers(dim, dim)\n",
    "            y = att([q, k], v) #.cuda()\n",
    "        else:\n",
    "            att = MultiplicativeAttention(dim, dim)\n",
    "            y = att(q, v) #.cuda()\n",
    "        y = self._reshape_from_batches(y)\n",
    "        y = self.linear_o(y)\n",
    "        #if self.activation is not None:\n",
    "        #    y = self.activation(y)\n",
    "        return y\n",
    "\n",
    "    \n",
    "    def _reshape_to_batches(self, x):\n",
    "        seq_len = 1\n",
    "        batch_size, in_feature = x.size()\n",
    "        sub_dim = in_feature // self.head_num\n",
    "        return x.reshape(batch_size, self.head_num, sub_dim)\\\n",
    "                .permute(0, 1, 2)\\\n",
    "                .reshape(batch_size * self.head_num, sub_dim)\n",
    "\n",
    "    def _reshape_from_batches(self, x):\n",
    "        seq_len = 1\n",
    "        batch_size, in_feature = x.size()\n",
    "        batch_size //= self.head_num\n",
    "        out_dim = in_feature * self.head_num\n",
    "        return x.reshape(batch_size, self.head_num,  in_feature)\\\n",
    "                .permute(0, 1, 2)\\\n",
    "                .reshape(batch_size,  out_dim)\n",
    "    \"\"\"\n",
    "\n",
    "    def _reshape_to_batches(self, x):\n",
    "        x = x.unsqueeze(1)\n",
    "        batch_size, seq_len, in_feature = x.size()\n",
    "        sub_dim = in_feature // self.head_num\n",
    "        return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\\\n",
    "                .permute(0, 2, 1, 3)\\\n",
    "                .reshape(batch_size * self.head_num, seq_len, sub_dim)\n",
    "\n",
    "    def _reshape_from_batches(self, x):\n",
    "        x = x.unsqueeze(1)\n",
    "        batch_size, seq_len, in_feature = x.size()\n",
    "        batch_size //= self.head_num\n",
    "        out_dim = in_feature * self.head_num\n",
    "        return x.reshape(batch_size, self.head_num, seq_len, in_feature)\\\n",
    "                .permute(0, 2, 1, 3)\\\n",
    "                .reshape(batch_size, seq_len, out_dim)\n",
    "    \n",
    "    def extra_repr(self):\n",
    "        return 'in_features={}, head_num={}, bias={}, activation={}'.format(\n",
    "            self.in_features, self.head_num, self.bias, self.activation)\n",
    "    \"\"\"\n",
    "\n",
    "class MultimodalFramework(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MultimodalFramework, self).__init__()\n",
    "        ##MLP\n",
    "        self.fc1 = nn.Linear(53, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 2)\n",
    "        self.relu = nn.ReLU()\n",
    "        \n",
    "        ##RESNET\n",
    "        self.resnet18 = models.resnet18(pretrained=True)\n",
    "        n_inputs = self.resnet18.fc.in_features\n",
    "\n",
    "        self.resnet18.fc = nn.Sequential(OrderedDict([\n",
    "            ('fc1', nn.Linear(n_inputs, 512))\n",
    "        ])) \n",
    "\n",
    "        self.resnet_classification = nn.Linear(512, 2) #4\n",
    "        \n",
    "        ##BERT\n",
    "        self.bert = BertModel.from_pretrained('bert-base-uncased') \n",
    "        self.bert_classification = nn.Linear(768, 2)\n",
    "        \n",
    "        #Two Modality models\n",
    "        self.bert_resnet_classification = nn.Linear(512 + 768, 2)\n",
    "        self.bert_mlp_classification = nn.Linear(256 + 768, 2)\n",
    "        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)\n",
    "        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)\n",
    "        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)\n",
    "        self.att_classification = nn.Linear(256*2, 2)\n",
    "        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)\n",
    "        self.OvO_classification = nn.Linear(3*256, 2)\n",
    "\n",
    "        self.res_wrap = nn.Linear(512, 256)\n",
    "        self.bert_wrap = nn.Linear(768, 256)\n",
    "        \n",
    "        self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)\n",
    "        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)\n",
    "        self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)\n",
    "        #self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)\n",
    "        #self.OvO_multihead_attention = MultiheadAttention(256,2, typ = \"OvO\")\n",
    "        #self.luong_multihead_attention = MultiHeadAttention(256,2, typ = \"luong\")\n",
    "        \n",
    "    def bi_directional_att(self, pair):\n",
    "        x = pair[0]\n",
    "        y = pair[1]\n",
    "        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)\n",
    "        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)\n",
    "        combined = torch.cat((attn_output_LV,\n",
    "                              attn_output_VL), dim=1)\n",
    "        return combined\n",
    "\n",
    "    def forward(self, x, model):\n",
    "        if model == \"mlp\":\n",
    "            x = self.fc1(x)\n",
    "            x = self.relu(x)\n",
    "            x = self.fc2(x)\n",
    "            x = self.relu(x)\n",
    "            out = self.fc3(x)\n",
    "            \n",
    "        elif model == \"resnet\":\n",
    "            res = self.resnet18(x)       \n",
    "            out = self.resnet_classification(res)\n",
    "        \n",
    "        elif model == \"bert\":\n",
    "            text, masks = x\n",
    "            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]\n",
    "            out = self.bert_classification(bert)\n",
    "            \n",
    "        elif model == \"bert_resnet\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "            \n",
    "            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]\n",
    "            combined = torch.cat((res_emb,\n",
    "                                  bert_emb), dim=1)\n",
    "            out = self.bert_resnet_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_luong\":\n",
    "            img, text, masks = x\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, res)\n",
    "            attn_output_VL = self.luong_attention(res, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_vaswani\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "        \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.res_wrap(res_emb)\n",
    "            res = res[:, None, :]\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "            bert = bert[:, None, :]\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV.squeeze(1),\n",
    "                                  attn_output_VL.squeeze(1)), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            \n",
    "            combined = torch.cat((bert,feat), dim=1)\n",
    "            out = self.bert_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp_luong\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"bert_mlp_vaswani\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            \n",
    "            combined = torch.cat((feat,res), dim=1)\n",
    "            out = self.resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp_luong\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(res, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"resnet_mlp_vaswani\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            \n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.resnet18(img)\n",
    "\n",
    "            \n",
    "            combined = torch.cat((bert,feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp_l\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            combined = torch.cat((bert, feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_l_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_mlp_vaswani\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            pairs = [[feat, bert],[feat,res],[bert,res]]\n",
    "        \n",
    "            results = []\n",
    "            for pair in pairs:\n",
    "                combined = self.bi_directional_att(pair)\n",
    "                results.append(combined)\n",
    "\n",
    "            comb = torch.cat(results, dim=1)\n",
    "            out = self.vaswani_3_classification(comb)\n",
    "            \n",
    "        else:\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            attn_txt, weights_txt = self.OvO_attention(feat, res, bert) #[feat, res], bert\n",
    "            attn_img, weights_img = self.OvO_attention(feat, bert,res)\n",
    "            attn_tab, weights_tab = self.OvO_attention(bert, res, feat)\n",
    "\n",
    "            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)\n",
    "            out = self.OvO_classification(comb)\n",
    "\n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "id": "84209806",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.11.0+cu113\n",
      "Torchvision Version:  0.12.0+cu113\n",
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the bert: 56 %\n",
      "AUROC         0.568000\n",
      "accuracy     56.800000\n",
      "precision    57.343286\n",
      "recall       56.800000\n",
      "f1-score     55.985915\n",
      "dtype: float64\n",
      "AUROC       NaN\n",
      "accuracy    NaN\n",
      "precision   NaN\n",
      "recall      NaN\n",
      "f1-score    NaN\n",
      "dtype: float64\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/ipykernel_launcher.py:136: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/ipykernel_launcher.py:137: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import sys\n",
    "import logging\n",
    "from pathlib import Path\n",
    "import random\n",
    "import tarfile\n",
    "import tempfile\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "print(\"PyTorch Version: \",torch.__version__)\n",
    "print(\"Torchvision Version: \",torchvision.__version__)\n",
    "\n",
    "from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score\n",
    "\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset\n",
    " \n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "from transformers import BertTokenizer\n",
    "from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel\n",
    "\n",
    "\n",
    "#from models import MultimodalFramework\n",
    "from model_utils import set_seed, build_optimizer, ReviewsDataset\n",
    "\n",
    " #'/gpfs/home/mgolovan/data/mgolovan/Reviews/amazon/home/test_rvw_inputs.pt'\n",
    "model_name = \"bert_resnet_luong\"\n",
    "lr = 5e-05 \n",
    "epochs = 6\n",
    "batch_size = 64\n",
    "#best_model_1e-06_22_adamW_20_resnet.pth\n",
    "random_seeds = [15] #15, 0, 1,67,  128, 87, 261, 510, 340, 22\n",
    "df = pd.DataFrame(columns = ['AUROC','accuracy', \"precision\", \"recall\", \"f1-score\", \"CM\", \"CR\"]) #str(batch_size)+ '_'\n",
    "\n",
    "for seed in random_seeds:\n",
    "    model_path = '/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_' + str(lr)+'_' + str(seed)+'_adamW_'  +  str(epochs)+'_' + str(model_name)+ '.pth_current.pth'\n",
    "\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\") #str(batch_size)+ '_' +\n",
    "    print(device)\n",
    "\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    model = MultimodalFramework()\n",
    "    model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu'))) #eager-sweep-1\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    test_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')\n",
    "    test_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt')\n",
    "\n",
    "    if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False)\n",
    "        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False ) \n",
    "\n",
    "    elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False ) \n",
    "        modality_2 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False )\n",
    "\n",
    "    else:\n",
    "        modality_1 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False ) \n",
    "        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False )\n",
    "\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    pred = []\n",
    "    test_labels = []\n",
    "\n",
    "    # since we're not training, we don't need to calculate the gradients for our outputs\n",
    "    with torch.no_grad():\n",
    "        for modality1, modality2 in zip(modality_1, modality_2):\n",
    "\n",
    "            if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "                text_inp, masks, text_labels = modality2\n",
    "                img_inp, labels = modality1\n",
    "\n",
    "                text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)\n",
    "                img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "\n",
    "                outputs = model([img_inp, text_inp, masks], model_name)\n",
    "\n",
    "            elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "                img_inp, labels = modality1\n",
    "                tab_inp, tab_labels = modality2\n",
    "                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "\n",
    "                outputs = model([tab_inp, img_inp], model_name)\n",
    "            else:\n",
    "                tab_inp, tab_labels = modality1\n",
    "                text_inp, masks, labels = modality2\n",
    "                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)\n",
    "\n",
    "                outputs = model([tab_inp, text_inp, masks], model_name)\n",
    "\n",
    "            test_labels.extend(np.array(labels.cpu()))\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            pred.extend(predicted.cpu().numpy())\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "\n",
    "    acc= 100 * correct / total\n",
    "    print(f'Accuracy of the bert: {100 * correct // total} %')\n",
    "\n",
    "    test_labels = np.array(test_labels)\n",
    "\n",
    "    #print(confusion_matrix(test_labels, pred))\n",
    "    cm = confusion_matrix(test_labels, pred)\n",
    "    #print(classification_report(test_labels, pred))\n",
    "    cr = classification_report(test_labels, pred, output_dict=True)\n",
    "    auc = roc_auc_score(test_labels, pred)\n",
    "    df = df.append({'AUROC': auc,'accuracy': acc, \"precision\":cr[\"macro avg\"][\"precision\"]*100 ,\n",
    "                    \"recall\":cr[\"macro avg\"][\"recall\"]*100, \"f1-score\":cr[\"macro avg\"][\"f1-score\"]*100,\n",
    "                    \"CM\":cm, \"CR\":cr}, ignore_index=True)\n",
    "\n",
    "#df.to_csv(model_name + \"_2_no_act_one_results.csv\")\n",
    "print(df.mean())\n",
    "print(df.std())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb8b7bb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "Accuracy of the bert: 56 %\n",
    "AUROC         0.568000\n",
    "accuracy     56.800000\n",
    "precision    57.343286\n",
    "recall       56.800000\n",
    "f1-score     55.985915\n",
    "dtype: float64\n",
    "AUROC       NaN\n",
    "accuracy    NaN\n",
    "precision   NaN\n",
    "recall      NaN\n",
    "f1-score    NaN\n",
    "dtype: float64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "998e7313",
   "metadata": {},
   "outputs": [],
   "source": [
    "Accuracy of the bert: 60 %\n",
    "AUROC         0.606000\n",
    "accuracy     60.600000\n",
    "precision    61.775894\n",
    "recall       60.600000\n",
    "f1-score     59.591236\n",
    "dtype: float64\n",
    "AUROC       NaN\n",
    "accuracy    NaN\n",
    "precision   NaN\n",
    "recall      NaN\n",
    "f1-score    NaN\n",
    "dtype: float64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8227308",
   "metadata": {},
   "outputs": [],
   "source": [
    "Accuracy of the bert: 66 %\n",
    "AUROC         0.653181\n",
    "accuracy     66.000000\n",
    "precision    64.550000\n",
    "recall       65.318093\n",
    "f1-score     64.732094\n",
    "dtype: float64\n",
    "AUROC       NaN\n",
    "accuracy    NaN\n",
    "precision   NaN\n",
    "recall      NaN\n",
    "f1-score    NaN\n",
    "dtype: float64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "571c282a",
   "metadata": {},
   "outputs": [],
   "source": [
    "Accuracy of the bert: 66 %\n",
    "AUROC         0.659618\n",
    "accuracy     66.000000\n",
    "precision    64.841198\n",
    "recall       65.961799\n",
    "f1-score     64.824099\n",
    "dtype: float64\n",
    "AUROC       NaN\n",
    "accuracy    NaN\n",
    "precision   NaN\n",
    "recall      NaN\n",
    "f1-score    NaN\n",
    "dtype: float64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "58f58460",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<generator object Module.parameters at 0x7fcc38b1a8d0>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a487d114",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "from torch import flatten\n",
    "\n",
    "from collections import OrderedDict\n",
    "from transformers import BertModel, DistilBertModel\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from MHA_modified import MultiheadAttention\n",
    "\n",
    "class Attention(torch.nn.Module):\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__()\n",
    "        self.encoder_dim = encoder_dim\n",
    "        self.decoder_dim = decoder_dim\n",
    "\n",
    "    def forward(self, \n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "        ):\n",
    "        weights = self._get_weights(query, values) # [seq_length]\n",
    "        weights = torch.nn.functional.softmax(weights, dim=0)\n",
    "        return weights @ values  # [encoder_dim]\n",
    "\n",
    "class MultiplicativeAttention(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(\n",
    "            self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,\n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "    ):\n",
    "        weights = query @ self.W @ values.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim)    \n",
    "\n",
    "\n",
    "\n",
    "class OneVSOthers(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        mean = sum(others) / len(others)\n",
    "        weights = mean @ self.W @ main.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim) \n",
    "    \n",
    "class OneVSOthers_concat(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        m2 = others[0]\n",
    "        m3 = others[1]\n",
    "        con = torch.cat((m2, m3), dim=1)\n",
    "        weights = con @ self.W @ main.T  # [seq_length]\n",
    "        return weights/np.sqrt(self.decoder_dim)\n",
    "#\"\"\"\n",
    "class MultiHeadAttention(nn.Module):\n",
    "\n",
    "    def __init__(self,\n",
    "                 in_features,\n",
    "                 head_num, typ,\n",
    "                 bias=True,\n",
    "                 activation=F.relu):\n",
    "        \n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "        if in_features % head_num != 0:\n",
    "            raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))\n",
    "        self.in_features = in_features\n",
    "        self.type = typ\n",
    "        self.head_num = head_num\n",
    "        self.activation = activation\n",
    "        self.bias = bias\n",
    "        self.linear_q = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_k = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_v = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_o = nn.Linear(in_features, in_features, bias)\n",
    "\n",
    "    def forward(self, q, k, v, mask=None): \n",
    "        #q = self.linear_q(q)\n",
    "        #k = self.linear_k(k)\n",
    "        #v = self.linear_v(v)\n",
    "        \n",
    "        dim = int(self.in_features / self.head_num)\n",
    "        #y = ScaledDotProductAttention()(q, k, v, mask)\n",
    "        q = self._reshape_to_batches(q)\n",
    "        k = self._reshape_to_batches(k)\n",
    "        v = self._reshape_to_batches(v)\n",
    "        if self.type == \"OvO\":\n",
    "            att = OneVSOthers(dim, dim)\n",
    "            y = att([q, k], v) #.cuda()\n",
    "        else:\n",
    "            att = MultiplicativeAttention(dim, dim)\n",
    "            y = att(q, v) #.cuda()\n",
    "        y = self._reshape_from_batches(y)\n",
    "        y = self.linear_o(y)\n",
    "        #if self.activation is not None:\n",
    "        #    y = self.activation(y)\n",
    "        return y\n",
    "\n",
    "    \n",
    "    def _reshape_to_batches(self, x):\n",
    "        seq_len = 1\n",
    "        batch_size, in_feature = x.size()\n",
    "        sub_dim = in_feature // self.head_num\n",
    "        return x.reshape(batch_size, self.head_num, sub_dim)\\\n",
    "                .permute(0, 1, 2)\\\n",
    "                .reshape(batch_size * self.head_num, sub_dim)\n",
    "\n",
    "    def _reshape_from_batches(self, x):\n",
    "        seq_len = 1\n",
    "        batch_size, in_feature = x.size()\n",
    "        batch_size //= self.head_num\n",
    "        out_dim = in_feature * self.head_num\n",
    "        return x.reshape(batch_size, self.head_num,  in_feature)\\\n",
    "                .permute(0, 1, 2)\\\n",
    "                .reshape(batch_size,  out_dim)\n",
    "    \"\"\"\n",
    "\n",
    "    def _reshape_to_batches(self, x):\n",
    "        x = x.unsqueeze(1)\n",
    "        batch_size, seq_len, in_feature = x.size()\n",
    "        sub_dim = in_feature // self.head_num\n",
    "        return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\\\n",
    "                .permute(0, 2, 1, 3)\\\n",
    "                .reshape(batch_size * self.head_num, seq_len, sub_dim)\n",
    "\n",
    "    def _reshape_from_batches(self, x):\n",
    "        x = x.unsqueeze(1)\n",
    "        batch_size, seq_len, in_feature = x.size()\n",
    "        batch_size //= self.head_num\n",
    "        out_dim = in_feature * self.head_num\n",
    "        return x.reshape(batch_size, self.head_num, seq_len, in_feature)\\\n",
    "                .permute(0, 2, 1, 3)\\\n",
    "                .reshape(batch_size, seq_len, out_dim)\n",
    "    \n",
    "    def extra_repr(self):\n",
    "        return 'in_features={}, head_num={}, bias={}, activation={}'.format(\n",
    "            self.in_features, self.head_num, self.bias, self.activation)\n",
    "    \"\"\"\n",
    "\n",
    "class MultimodalFramework(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MultimodalFramework, self).__init__()\n",
    "        ##MLP\n",
    "        self.fc1 = nn.Linear(53, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 2)\n",
    "        self.relu = nn.ReLU()\n",
    "        \n",
    "        ##RESNET\n",
    "        self.resnet18 = models.resnet18(pretrained=True)\n",
    "        n_inputs = self.resnet18.fc.in_features\n",
    "\n",
    "        self.resnet18.fc = nn.Sequential(OrderedDict([\n",
    "            ('fc1', nn.Linear(n_inputs, 512))\n",
    "        ])) \n",
    "\n",
    "        self.resnet_classification = nn.Linear(512, 2) #4\n",
    "        \n",
    "        ##BERT\n",
    "        self.bert = BertModel.from_pretrained('bert-base-uncased') \n",
    "        self.bert_classification = nn.Linear(768, 2)\n",
    "        \n",
    "        #Two Modality models\n",
    "        self.bert_resnet_classification = nn.Linear(512 + 768, 2)\n",
    "        self.bert_mlp_classification = nn.Linear(256 + 768, 2)\n",
    "        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)\n",
    "        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)\n",
    "        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)\n",
    "        self.att_classification = nn.Linear(256*2, 2)\n",
    "        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)\n",
    "        self.OvO_classification = nn.Linear(3*256, 2)\n",
    "\n",
    "        self.res_wrap = nn.Linear(512, 256)\n",
    "        self.bert_wrap = nn.Linear(768, 256)\n",
    "        \n",
    "        #self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)\n",
    "        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)\n",
    "        #self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)\n",
    "        self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)\n",
    "        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = \"OvO\")\n",
    "        self.luong_multihead_attention = MultiHeadAttention(256,2, typ = \"luong\")\n",
    "        \n",
    "    def bi_directional_att(self, pair):\n",
    "        x = pair[0]\n",
    "        y = pair[1]\n",
    "        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)\n",
    "        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)\n",
    "        combined = torch.cat((attn_output_LV,\n",
    "                              attn_output_VL), dim=1)\n",
    "        return combined\n",
    "\n",
    "    def forward(self, x, model):\n",
    "        if model == \"mlp\":\n",
    "            x = self.fc1(x)\n",
    "            x = self.relu(x)\n",
    "            x = self.fc2(x)\n",
    "            x = self.relu(x)\n",
    "            out = self.fc3(x)\n",
    "            \n",
    "        elif model == \"resnet\":\n",
    "            res = self.resnet18(x)       \n",
    "            out = self.resnet_classification(res)\n",
    "        \n",
    "        elif model == \"bert\":\n",
    "            text, masks = x\n",
    "            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]\n",
    "            out = self.bert_classification(bert)\n",
    "            \n",
    "        elif model == \"bert_resnet\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "            \n",
    "            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]\n",
    "            combined = torch.cat((res_emb,\n",
    "                                  bert_emb), dim=1)\n",
    "            out = self.bert_resnet_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_luong\":\n",
    "            img, text, masks = x\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "\n",
    "            attn_output_LV = self.luong_multihead_attention(bert, res, res)\n",
    "            attn_output_VL = self.luong_multihead_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_vaswani\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "        \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.res_wrap(res_emb)\n",
    "            res = res[:, None, :]\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "            bert = bert[:, None, :]\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV.squeeze(1),\n",
    "                                  attn_output_VL.squeeze(1)), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            \n",
    "            combined = torch.cat((bert,feat), dim=1)\n",
    "            out = self.bert_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp_luong\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"bert_mlp_vaswani\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            \n",
    "            combined = torch.cat((feat,res), dim=1)\n",
    "            out = self.resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp_luong\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(res, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"resnet_mlp_vaswani\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            \n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.resnet18(img)\n",
    "\n",
    "            \n",
    "            combined = torch.cat((bert,feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp_l\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            combined = torch.cat((bert, feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_l_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_mlp_vaswani\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            pairs = [[feat, bert],[feat,res],[bert,res]]\n",
    "        \n",
    "            results = []\n",
    "            for pair in pairs:\n",
    "                combined = self.bi_directional_att(pair)\n",
    "                results.append(combined)\n",
    "\n",
    "            comb = torch.cat(results, dim=1)\n",
    "            out = self.vaswani_3_classification(comb)\n",
    "            \n",
    "        else:\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert\n",
    "            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)\n",
    "            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)\n",
    "\n",
    "            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)\n",
    "            out = self.OvO_classification(comb)\n",
    "\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8ddb5a10",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "from torch import flatten\n",
    "\n",
    "from collections import OrderedDict\n",
    "from transformers import BertModel, DistilBertModel\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from MHA_modified import MultiheadAttention\n",
    "\n",
    "class Attention(torch.nn.Module):\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__()\n",
    "        self.encoder_dim = encoder_dim\n",
    "        self.decoder_dim = decoder_dim\n",
    "\n",
    "    def forward(self, \n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "        ):\n",
    "        weights = self._get_weights(query, values) # [seq_length]\n",
    "        weights = torch.nn.functional.softmax(weights, dim=0)\n",
    "        return weights @ values  # [encoder_dim]\n",
    "\n",
    "class MultiplicativeAttention(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')\n",
    "\n",
    "    def _get_weights(self,\n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "    ):\n",
    "        weights = query @ self.W @ values.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim)    \n",
    "\n",
    "\n",
    "\n",
    "class OneVSOthers(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        mean = sum(others) / len(others)\n",
    "        weights = mean @ self.W @ main.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim) \n",
    "    \n",
    "class OneVSOthers_concat(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        m2 = others[0]\n",
    "        m3 = others[1]\n",
    "        con = torch.cat((m2, m3), dim=1)\n",
    "        weights = con @ self.W @ main.T  # [seq_length]\n",
    "        return weights/np.sqrt(self.decoder_dim)\n",
    "#\"\"\"\n",
    "\n",
    "class ScaledDotProductAttention(nn.Module):\n",
    "    \"\"\"\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__()\n",
    "        self.encoder_dim = encoder_dim\n",
    "        self.decoder_dim = decoder_dim\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1))\n",
    "    \"\"\"\n",
    "    def forward(self, W, query, key, value):\n",
    "\n",
    "        scores = query @ W @ value.T\n",
    "        attention = torch.nn.functional.softmax(scores, dim=0)\n",
    "        #attention = F.softmax(scores, dim=-1)\n",
    "        return attention.matmul(value)\n",
    "    \n",
    "class MultiHeadAttention(nn.Module):\n",
    "\n",
    "    def __init__(self,\n",
    "                 in_features,\n",
    "                 head_num, typ,\n",
    "                 bias=True,\n",
    "                 activation=F.relu):\n",
    "        \n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "        if in_features % head_num != 0:\n",
    "            raise ValueError('`in_features`({}) should be divisible by `head_num`({})'.format(in_features, head_num))\n",
    "        self.in_features = in_features\n",
    "        self.type = typ\n",
    "        self.head_num = head_num\n",
    "        self.activation = activation\n",
    "        self.bias = bias\n",
    "        self.dim = int(self.in_features / self.head_num)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.dim, self.dim).uniform_(-0.1, 0.1))\n",
    "        self.linear_q = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_k = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_v = nn.Linear(in_features, in_features, bias)\n",
    "        self.linear_o = nn.Linear(in_features, in_features, bias)\n",
    "\n",
    "    def forward(self, q, k, v, mask=None): \n",
    "        #q = self.linear_q(q)\n",
    "        #k = self.linear_k(k)\n",
    "        #v = self.linear_v(v)\n",
    "        \n",
    "        dim = int(self.in_features / self.head_num)\n",
    "        #y = ScaledDotProductAttention()(q, k, v, mask)\n",
    "        #k, q, v = k.to(\"cpu\"), q.to(\"cpu\"), v.to(\"cpu\")\n",
    "        q = self._reshape_to_batches(q)\n",
    "        k = self._reshape_to_batches(k)\n",
    "        v = self._reshape_to_batches(v)\n",
    "        if self.type == \"OvO\":\n",
    "            att = OneVSOthers(dim, dim)\n",
    "            y = att([q, k], v) #.cuda()\n",
    "        else:\n",
    "            #att = MultiplicativeAttention(dim, dim)\n",
    "            att = ScaledDotProductAttention()\n",
    "            y = att(self.W, q,k, v) #.cuda()\n",
    "        #y = y.to(\"cuda\")\n",
    "        y = self._reshape_from_batches(y)\n",
    "        y = self.linear_o(y)\n",
    "        #if self.activation is not None:\n",
    "        #    y = self.activation(y)\n",
    "        return y\n",
    "\n",
    "    \n",
    "    def _reshape_to_batches(self, x):\n",
    "        seq_len = 1\n",
    "        batch_size, in_feature = x.size()\n",
    "        sub_dim = in_feature // self.head_num\n",
    "        return x.reshape(batch_size, self.head_num, sub_dim)\\\n",
    "                .permute(0, 1, 2)\\\n",
    "                .reshape(batch_size * self.head_num, sub_dim)\n",
    "\n",
    "    def _reshape_from_batches(self, x):\n",
    "        seq_len = 1\n",
    "        batch_size, in_feature = x.size()\n",
    "        batch_size //= self.head_num\n",
    "        out_dim = in_feature * self.head_num\n",
    "        return x.reshape(batch_size, self.head_num,  in_feature)\\\n",
    "                .permute(0, 1, 2)\\\n",
    "                .reshape(batch_size,  out_dim)\n",
    "    \"\"\"\n",
    "\n",
    "    def _reshape_to_batches(self, x):\n",
    "        x = x.unsqueeze(1)\n",
    "        batch_size, seq_len, in_feature = x.size()\n",
    "        sub_dim = in_feature // self.head_num\n",
    "        return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\\\n",
    "                .permute(0, 2, 1, 3)\\\n",
    "                .reshape(batch_size * self.head_num, seq_len, sub_dim)\n",
    "\n",
    "    def _reshape_from_batches(self, x):\n",
    "        x = x.unsqueeze(1)\n",
    "        batch_size, seq_len, in_feature = x.size()\n",
    "        batch_size //= self.head_num\n",
    "        out_dim = in_feature * self.head_num\n",
    "        return x.reshape(batch_size, self.head_num, seq_len, in_feature)\\\n",
    "                .permute(0, 2, 1, 3)\\\n",
    "                .reshape(batch_size, seq_len, out_dim)\n",
    "    \n",
    "    def extra_repr(self):\n",
    "        return 'in_features={}, head_num={}, bias={}, activation={}'.format(\n",
    "            self.in_features, self.head_num, self.bias, self.activation)\n",
    "    \"\"\"\n",
    "\n",
    "class MultimodalFramework(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MultimodalFramework, self).__init__()\n",
    "        ##MLP\n",
    "        self.fc1 = nn.Linear(53, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 2)\n",
    "        self.relu = nn.ReLU()\n",
    "        \n",
    "        ##RESNET\n",
    "        self.resnet18 = models.resnet18(pretrained=True)\n",
    "        n_inputs = self.resnet18.fc.in_features\n",
    "\n",
    "        self.resnet18.fc = nn.Sequential(OrderedDict([\n",
    "            ('fc1', nn.Linear(n_inputs, 512))\n",
    "        ])) \n",
    "\n",
    "        self.resnet_classification = nn.Linear(512, 2) #4\n",
    "        \n",
    "        ##BERT\n",
    "        self.bert = BertModel.from_pretrained('bert-base-uncased') \n",
    "        self.bert_classification = nn.Linear(768, 2)\n",
    "        \n",
    "        #Two Modality models\n",
    "        self.bert_resnet_classification = nn.Linear(512 + 768, 2)\n",
    "        self.bert_mlp_classification = nn.Linear(256 + 768, 2)\n",
    "        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)\n",
    "        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)\n",
    "        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)\n",
    "        self.att_classification = nn.Linear(256*2, 2)\n",
    "        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)\n",
    "        self.OvO_classification = nn.Linear(3*256, 2)\n",
    "\n",
    "        self.res_wrap = nn.Linear(512, 256)\n",
    "        self.bert_wrap = nn.Linear(768, 256)\n",
    "        \n",
    "        #self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)\n",
    "        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)\n",
    "        #self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)\n",
    "        self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)\n",
    "        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = \"OvO\")\n",
    "        self.luong_multihead_attention = MultiHeadAttention(256,2, typ = \"luong\")\n",
    "        \n",
    "    def bi_directional_att(self, pair):\n",
    "        x = pair[0]\n",
    "        y = pair[1]\n",
    "        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)\n",
    "        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)\n",
    "        combined = torch.cat((attn_output_LV,\n",
    "                              attn_output_VL), dim=1)\n",
    "        return combined\n",
    "\n",
    "    def forward(self, x, model):\n",
    "        if model == \"mlp\":\n",
    "            x = self.fc1(x)\n",
    "            x = self.relu(x)\n",
    "            x = self.fc2(x)\n",
    "            x = self.relu(x)\n",
    "            out = self.fc3(x)\n",
    "            \n",
    "        elif model == \"resnet\":\n",
    "            res = self.resnet18(x)       \n",
    "            out = self.resnet_classification(res)\n",
    "        \n",
    "        elif model == \"bert\":\n",
    "            text, masks = x\n",
    "            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]\n",
    "            out = self.bert_classification(bert)\n",
    "            \n",
    "        elif model == \"bert_resnet\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "            \n",
    "            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]\n",
    "            combined = torch.cat((res_emb,\n",
    "                                  bert_emb), dim=1)\n",
    "            out = self.bert_resnet_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_luong\":\n",
    "            img, text, masks = x\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "\n",
    "            attn_output_LV = self.luong_multihead_attention(bert, res, res)\n",
    "            attn_output_VL = self.luong_multihead_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_vaswani\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "        \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.res_wrap(res_emb)\n",
    "            res = res[:, None, :]\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "            bert = bert[:, None, :]\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV.squeeze(1),\n",
    "                                  attn_output_VL.squeeze(1)), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            \n",
    "            combined = torch.cat((bert,feat), dim=1)\n",
    "            out = self.bert_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp_luong\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"bert_mlp_vaswani\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            \n",
    "            combined = torch.cat((feat,res), dim=1)\n",
    "            out = self.resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp_luong\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(res, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"resnet_mlp_vaswani\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            \n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.resnet18(img)\n",
    "\n",
    "            \n",
    "            combined = torch.cat((bert,feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp_l\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            combined = torch.cat((bert, feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_l_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_mlp_vaswani\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            pairs = [[feat, bert],[feat,res],[bert,res]]\n",
    "        \n",
    "            results = []\n",
    "            for pair in pairs:\n",
    "                combined = self.bi_directional_att(pair)\n",
    "                results.append(combined)\n",
    "\n",
    "            comb = torch.cat(results, dim=1)\n",
    "            out = self.vaswani_3_classification(comb)\n",
    "            \n",
    "        else:\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert\n",
    "            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)\n",
    "            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)\n",
    "\n",
    "            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)\n",
    "            out = self.OvO_classification(comb)\n",
    "\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "b63d7efb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#from models import MultimodalFramework\n",
    "model = MultimodalFramework()\n",
    "model.load_state_dict(torch.load(\"/gpfs/home/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_lin1e-06_340_adamW_128_34_bert_resnet_luong.pth_best.pth\",map_location=torch.device('cpu')))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "36b57f75",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.11.0+cu113\n",
      "Torchvision Version:  0.12.0+cu113\n",
      "Random seed set as 42\n",
      "Accuracy of the bert: 57.2 %\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import sys\n",
    "import logging\n",
    "from pathlib import Path\n",
    "import random\n",
    "import tarfile\n",
    "import tempfile\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "print(\"PyTorch Version: \",torch.__version__)\n",
    "print(\"Torchvision Version: \",torchvision.__version__)\n",
    "\n",
    "from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score\n",
    "\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset\n",
    " \n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "from transformers import BertTokenizer\n",
    "from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\") #str(batch_size)+ '_' +\n",
    "from model_utils import set_seed\n",
    "\n",
    "model.eval()\n",
    "set_seed(42)\n",
    "batch_size =128\n",
    "model_name = \"bert_resnet_luong\"\n",
    "test_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')\n",
    "test_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt')\n",
    "\n",
    "if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "    modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False)\n",
    "    modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False ) \n",
    "\n",
    "elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "    modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False ) \n",
    "    modality_2 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False )\n",
    "\n",
    "else:\n",
    "    modality_1 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False ) \n",
    "    modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False )\n",
    "\n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "running_loss = 0\n",
    "pred = []\n",
    "test_labels = []\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# since we're not training, we don't need to calculate the gradients for our outputs\n",
    "with torch.no_grad():\n",
    "    for modality1, modality2 in zip(modality_1, modality_2):\n",
    "\n",
    "        if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "            text_inp, masks, text_labels = modality2\n",
    "            img_inp, labels = modality1\n",
    "\n",
    "            text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)\n",
    "            img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "\n",
    "            outputs = model([img_inp, text_inp, masks], model_name)\n",
    "\n",
    "        elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "            img_inp, labels = modality1\n",
    "            tab_inp, tab_labels = modality2\n",
    "            tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "            tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "            img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "\n",
    "            outputs = model([tab_inp, img_inp], model_name)\n",
    "        else:\n",
    "            tab_inp, tab_labels = modality1\n",
    "            text_inp, masks, labels = modality2\n",
    "            tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "            tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "            text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)\n",
    "\n",
    "            outputs = model([tab_inp, text_inp, masks], model_name)\n",
    "        \n",
    "        loss = criterion(outputs, labels)\n",
    "        test_labels.extend(np.array(labels.cpu()))\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        pred.extend(predicted.cpu().numpy())\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "        running_loss += loss.item() * labels.size(0)\n",
    "\n",
    "acc= 100 * correct / total\n",
    "loss_f = running_loss/ total\n",
    "print(f'Accuracy of the bert: {100 * correct / total} %')\n",
    "\n",
    "test_labels = np.array(test_labels)\n",
    "\n",
    "#print(confusion_matrix(test_labels, pred))\n",
    "cm = confusion_matrix(test_labels, pred)\n",
    "#print(classification_report(test_labels, pred))\n",
    "cr = classification_report(test_labels, pred, output_dict=True)\n",
    "auc = roc_auc_score(test_labels, pred)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ca0085a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.11.0+cu113\n",
      "Torchvision Version:  0.12.0+cu113\n",
      "Random seed set as 42\n",
      "Accuracy of the bert: 66.52941176470588 %\n",
      "0.6334595049107635\n"
     ]
    }
   ],
   "source": [
    "lr = 5e-07\n",
    "epochs = 68\n",
    "batch_size = 128 #32\n",
    "#best_model_1e-06_22_adamW_20_resnet.pth\n",
    "random_seeds = [15, 0, 1,67,  128, 87, 261, 510, 340, 22] #15, 0, 1,67,  128, 87, 261, 510, 340, 22\n",
    "#df = pd.DataFrame(columns = ['AUROC','accuracy', \"precision\", \"recall\", \"f1-score\", \"CM\", \"CR\"])\n",
    "model_name = \"bert_resnet_luong\"\n",
    "for seed in random_seeds:\n",
    "    model_path = '/users/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_' + str(lr)+'_' + str(seed)+'_adamW_' +str(batch_size)+ '_' +  str(epochs)+'_' + str(model_name)+ '.pth_current.pth'\n",
    "    model = MultimodalFramework()\n",
    "    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))\n",
    "    from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score\n",
    "    import pandas as pd\n",
    "    import numpy as np\n",
    "    import json\n",
    "    import sys\n",
    "    import logging\n",
    "    from pathlib import Path\n",
    "    import random\n",
    "    import tarfile\n",
    "    import tempfile\n",
    "    import warnings\n",
    "    import matplotlib.pyplot as plt\n",
    "    # import pandas_path  # Path style access for pandas\n",
    "    from tqdm import tqdm\n",
    "    import torch                    \n",
    "    import torch.nn as nn\n",
    "    import torch.optim as optim\n",
    "    import torchvision\n",
    "    from torchvision import datasets, models, transforms\n",
    "    import matplotlib.pyplot as plt\n",
    "    import time\n",
    "    import os\n",
    "    import copy\n",
    "    print(\"PyTorch Version: \",torch.__version__)\n",
    "    print(\"Torchvision Version: \",torchvision.__version__)\n",
    "\n",
    "    from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score\n",
    "\n",
    "    from torch.utils.data import TensorDataset\n",
    "    from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset\n",
    "\n",
    "    import matplotlib.pyplot as plt\n",
    "    from PIL import Image\n",
    "    from transformers import BertTokenizer\n",
    "    from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\") #str(batch_size)+ '_' +\n",
    "    from model_utils import set_seed\n",
    "\n",
    "    model.eval()\n",
    "    set_seed(42)\n",
    "    batch_size =128\n",
    "    model_name = \"bert_resnet_luong\"\n",
    "    test_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_txt.pt')\n",
    "    test_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_img.pt')\n",
    "\n",
    "    if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False)\n",
    "        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False ) \n",
    "\n",
    "    elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "        modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False ) \n",
    "        modality_2 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False )\n",
    "\n",
    "    else:\n",
    "        modality_1 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False ) \n",
    "        modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False )\n",
    "\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    running_loss = 0\n",
    "    pred = []\n",
    "    test_labels = []\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    # since we're not training, we don't need to calculate the gradients for our outputs\n",
    "    with torch.no_grad():\n",
    "        for modality1, modality2 in zip(modality_1, modality_2):\n",
    "\n",
    "            if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "                text_inp, masks, text_labels = modality2\n",
    "                img_inp, labels = modality1\n",
    "\n",
    "                text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)\n",
    "                img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "\n",
    "                outputs = model([img_inp, text_inp, masks], model_name)\n",
    "\n",
    "            elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "                img_inp, labels = modality1\n",
    "                tab_inp, tab_labels = modality2\n",
    "                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "\n",
    "                outputs = model([tab_inp, img_inp], model_name)\n",
    "            else:\n",
    "                tab_inp, tab_labels = modality1\n",
    "                text_inp, masks, labels = modality2\n",
    "                tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)\n",
    "\n",
    "                outputs = model([tab_inp, text_inp, masks], model_name)\n",
    "\n",
    "            loss = criterion(outputs, labels)\n",
    "            test_labels.extend(np.array(labels.cpu()))\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            pred.extend(predicted.cpu().numpy())\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "            running_loss += loss.item() * labels.size(0)\n",
    "\n",
    "    acc= 100 * correct / total\n",
    "    loss_f = running_loss/ total\n",
    "    print(f'Accuracy of the bert: {100 * correct / total} %')\n",
    "\n",
    "    test_labels = np.array(test_labels)\n",
    "\n",
    "    #print(confusion_matrix(test_labels, pred))\n",
    "    cm = confusion_matrix(test_labels, pred)\n",
    "    #print(classification_report(test_labels, pred))\n",
    "    cr = classification_report(test_labels, pred, output_dict=True)\n",
    "    auc = roc_auc_score(test_labels, pred)\n",
    "    print(cr['macro avg']['f1-score'])\n",
    "    df = df.append({'AUROC': auc,'accuracy': acc, \"precision\":cr[\"macro avg\"][\"precision\"]*100 ,\n",
    "                    \"recall\":cr[\"macro avg\"][\"recall\"]*100, \"f1-score\":cr[\"macro avg\"][\"f1-score\"]*100,\n",
    "                    \"CM\":cm, \"CR\":cr}, ignore_index=True)\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "55be7f0d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>AUROC</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>f1-score</th>\n",
       "      <th>CM</th>\n",
       "      <th>CR</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.649902</td>\n",
       "      <td>69.176471</td>\n",
       "      <td>66.140263</td>\n",
       "      <td>64.990224</td>\n",
       "      <td>65.364182</td>\n",
       "      <td>[[870, 220], [304, 306]]</td>\n",
       "      <td>{'0': {'precision': 0.7410562180579217, 'recal...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.638555</td>\n",
       "      <td>68.647059</td>\n",
       "      <td>65.444905</td>\n",
       "      <td>63.855467</td>\n",
       "      <td>64.270153</td>\n",
       "      <td>[[881, 209], [324, 286]]</td>\n",
       "      <td>{'0': {'precision': 0.7311203319502074, 'recal...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.614183</td>\n",
       "      <td>65.058824</td>\n",
       "      <td>61.726473</td>\n",
       "      <td>61.418258</td>\n",
       "      <td>61.543207</td>\n",
       "      <td>[[810, 280], [314, 296]]</td>\n",
       "      <td>{'0': {'precision': 0.7206405693950177, 'recal...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.652286</td>\n",
       "      <td>70.176471</td>\n",
       "      <td>67.312720</td>\n",
       "      <td>65.228606</td>\n",
       "      <td>65.752486</td>\n",
       "      <td>[[902, 188], [319, 291]]</td>\n",
       "      <td>{'0': {'precision': 0.7387387387387387, 'recal...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.645044</td>\n",
       "      <td>69.294118</td>\n",
       "      <td>66.226512</td>\n",
       "      <td>64.504437</td>\n",
       "      <td>64.958087</td>\n",
       "      <td>[[888, 202], [320, 290]]</td>\n",
       "      <td>{'0': {'precision': 0.7350993377483444, 'recal...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.647142</td>\n",
       "      <td>69.470588</td>\n",
       "      <td>66.441448</td>\n",
       "      <td>64.714243</td>\n",
       "      <td>65.175910</td>\n",
       "      <td>[[889, 201], [318, 292]]</td>\n",
       "      <td>{'0': {'precision': 0.7365368682684341, 'recal...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.628215</td>\n",
       "      <td>66.117647</td>\n",
       "      <td>63.017255</td>\n",
       "      <td>62.821477</td>\n",
       "      <td>62.909091</td>\n",
       "      <td>[[812, 278], [298, 312]]</td>\n",
       "      <td>{'0': {'precision': 0.7315315315315315, 'recal...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.618439</td>\n",
       "      <td>65.882353</td>\n",
       "      <td>62.450593</td>\n",
       "      <td>61.843886</td>\n",
       "      <td>62.053571</td>\n",
       "      <td>[[830, 260], [320, 290]]</td>\n",
       "      <td>{'0': {'precision': 0.7217391304347827, 'recal...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.650526</td>\n",
       "      <td>69.117647</td>\n",
       "      <td>66.090551</td>\n",
       "      <td>65.052639</td>\n",
       "      <td>65.403638</td>\n",
       "      <td>[[866, 224], [301, 309]]</td>\n",
       "      <td>{'0': {'precision': 0.7420736932305055, 'recal...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.660911</td>\n",
       "      <td>70.588235</td>\n",
       "      <td>67.790275</td>\n",
       "      <td>66.091142</td>\n",
       "      <td>66.591412</td>\n",
       "      <td>[[894, 196], [304, 306]]</td>\n",
       "      <td>{'0': {'precision': 0.7462437395659433, 'recal...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      AUROC   accuracy  precision     recall   f1-score  \\\n",
       "0  0.649902  69.176471  66.140263  64.990224  65.364182   \n",
       "1  0.638555  68.647059  65.444905  63.855467  64.270153   \n",
       "2  0.614183  65.058824  61.726473  61.418258  61.543207   \n",
       "3  0.652286  70.176471  67.312720  65.228606  65.752486   \n",
       "4  0.645044  69.294118  66.226512  64.504437  64.958087   \n",
       "5  0.647142  69.470588  66.441448  64.714243  65.175910   \n",
       "6  0.628215  66.117647  63.017255  62.821477  62.909091   \n",
       "7  0.618439  65.882353  62.450593  61.843886  62.053571   \n",
       "8  0.650526  69.117647  66.090551  65.052639  65.403638   \n",
       "9  0.660911  70.588235  67.790275  66.091142  66.591412   \n",
       "\n",
       "                         CM                                                 CR  \n",
       "0  [[870, 220], [304, 306]]  {'0': {'precision': 0.7410562180579217, 'recal...  \n",
       "1  [[881, 209], [324, 286]]  {'0': {'precision': 0.7311203319502074, 'recal...  \n",
       "2  [[810, 280], [314, 296]]  {'0': {'precision': 0.7206405693950177, 'recal...  \n",
       "3  [[902, 188], [319, 291]]  {'0': {'precision': 0.7387387387387387, 'recal...  \n",
       "4  [[888, 202], [320, 290]]  {'0': {'precision': 0.7350993377483444, 'recal...  \n",
       "5  [[889, 201], [318, 292]]  {'0': {'precision': 0.7365368682684341, 'recal...  \n",
       "6  [[812, 278], [298, 312]]  {'0': {'precision': 0.7315315315315315, 'recal...  \n",
       "7  [[830, 260], [320, 290]]  {'0': {'precision': 0.7217391304347827, 'recal...  \n",
       "8  [[866, 224], [301, 309]]  {'0': {'precision': 0.7420736932305055, 'recal...  \n",
       "9  [[894, 196], [304, 306]]  {'0': {'precision': 0.7462437395659433, 'recal...  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "130ecc38",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "64.40217369520445"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df[\"f1-score\"].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4e32f735",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/ipykernel_launcher.py:1: FutureWarning: Dropping of nuisance columns in DataFrame reductions (with 'numeric_only=None') is deprecated; in a future version this will raise TypeError.  Select only valid columns before calling the reduction.\n",
      "  \"\"\"Entry point for launching an IPython kernel.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "AUROC         0.640520\n",
       "accuracy     68.352941\n",
       "precision    65.264099\n",
       "recall       64.052038\n",
       "f1-score     64.402174\n",
       "dtype: float64"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "73034652",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.586"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "correct / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "e17342b8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.3884286813735962"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss_f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "ad53be8b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.648864681332673"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cr['macro avg']['f1-score']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "a2e3e9c9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8200"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.backends.cudnn.version()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaf508d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "val Loss: 0.8501 Acc: 0.5000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "080b3aa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "train Loss: 0.6654 Acc: 0.6334\n",
    "val Loss: 0.7465 Acc: 0.5020\n",
    "Training complete in 1m 57s\n",
    "Best val Acc: 0.502000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcee9e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "train Loss: 0.6606 Acc: 0.6379\n",
    "val Loss: 0.7492 Acc: 0.5000\n",
    "Training complete in 4m 37s\n",
    "Best val Acc: 0.500000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46e16b72",
   "metadata": {},
   "outputs": [],
   "source": [
    "train Loss: 0.6575 Acc: 0.6391\n",
    "val Loss: 0.7548 Acc: 0.5000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6328af67",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46104cf6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4c4bd13",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "274034cb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91d16179",
   "metadata": {},
   "outputs": [],
   "source": [
    "train Loss: 0.6589 Acc: 0.6363\n",
    "val Loss: 0.7782 Acc: 0.5000\n",
    "Training complete in 1m 55s\n",
    "Best val Acc: 0.500000\n",
    "    \n",
    "train Loss: 0.6603 Acc: 0.6394\n",
    "val Loss: 0.7611 Acc: 0.5000\n",
    "Training complete in 2m 27s\n",
    "Best val Acc: 0.500000\n",
    "\n",
    "train Loss: 0.6603 Acc: 0.6394\n",
    "val Loss: 0.7611 Acc: 0.5000\n",
    "Training complete in 2m 16s\n",
    "Best val Acc: 0.500000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "9ca6f5c9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 128,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train Loss: 0.6418 Acc: 0.6437\n",
    "val Loss: 0.8099 Acc: 0.4940\n",
    "Training complete in 2m 6s\n",
    "Best val Acc: 0.494000\n",
    "\n",
    "train Loss: 0.6418 Acc: 0.6437\n",
    "val Loss: 0.8099 Acc: 0.4940\n",
    "Training complete in 2m 6s\n",
    "Best val Acc: 0.494000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "16b1de5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred1 = pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "1889de8a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50.2"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "a1a3349b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "48.8"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eda0f749",
   "metadata": {},
   "outputs": [],
   "source": [
    "train Loss: 55.9858 Acc: 0.5862\n",
    "val Loss: 543.6128 Acc: 0.5040\n",
    "Training complete in 4m 12s\n",
    "Best val Acc: 0.504000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "58a6190c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "from torch import flatten\n",
    "\n",
    "from collections import OrderedDict\n",
    "from transformers import BertModel, DistilBertModel\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from MHA_modified import MultiheadAttention\n",
    "\n",
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "import numpy as np\n",
    "from typing import Optional, Tuple\n",
    "\n",
    "class Attention(torch.nn.Module):\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__()\n",
    "        self.encoder_dim = encoder_dim\n",
    "        self.decoder_dim = decoder_dim\n",
    "\n",
    "    def forward(self, \n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "        ):\n",
    "        weights = self._get_weights(query, values) # [seq_length]\n",
    "        weights = torch.nn.functional.softmax(weights, dim=0)\n",
    "        return weights @ values  # [encoder_dim]\n",
    "\n",
    "class MultiplicativeAttention(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')\n",
    "\n",
    "    def _get_weights(self,\n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "    ):\n",
    "        weights = query @ self.W @ values.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim)    \n",
    "\n",
    "\n",
    "\n",
    "class OneVSOthers(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        mean = sum(others) / len(others)\n",
    "        weights = mean @ self.W @ main.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim) \n",
    "\"\"\"\n",
    "class ScaledDotProductAttention(nn.Module):\n",
    "    \n",
    "    def __init__(self, dim: int):\n",
    "        super(ScaledDotProductAttention, self).__init__()\n",
    "        self.sqrt_dim = np.sqrt(dim)\n",
    "\n",
    "    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:\n",
    "        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim\n",
    "\n",
    "        if mask is not None:\n",
    "            score.masked_fill_(mask.view(score.size()), -float('Inf'))\n",
    "\n",
    "        attn = F.softmax(score, -1)\n",
    "        context = torch.bmm(attn, value)\n",
    "\n",
    "        return context, attn\n",
    "\"\"\"\n",
    "\n",
    "class ScaledDotProductAttention(nn.Module):\n",
    "    def __init__(self, dim: int):\n",
    "        super(ScaledDotProductAttention, self).__init__()\n",
    "        self.sqrt_dim = np.sqrt(dim)\n",
    "    \n",
    "    def forward(self, query, key, value, W):\n",
    "\n",
    "        score = query @ W @ value.transpose(1, 2) #.T\n",
    "        attn = F.softmax(score, -1)\n",
    "        context = torch.bmm(attn, value)\n",
    "        return context, attn\n",
    "    \n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"\n",
    "    Multi-Head Attention proposed in \"Attention Is All You Need\"\n",
    "    Instead of performing a single attention function with d_model-dimensional keys, values, and queries,\n",
    "    project the queries, keys and values h times with different, learned linear projections to d_head dimensions.\n",
    "    These are concatenated and once again projected, resulting in the final values.\n",
    "    Multi-head attention allows the model to jointly attend to information from different representation\n",
    "    subspaces at different positions.\n",
    "    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o\n",
    "        where head_i = Attention(Q · W_q, K · W_k, V · W_v)\n",
    "    Args:\n",
    "        d_model (int): The dimension of keys / values / quries (default: 512)\n",
    "        num_heads (int): The number of attention heads. (default: 8)\n",
    "    Inputs: query, key, value, mask\n",
    "        - **query** (batch, q_len, d_model): In transformer, three different ways:\n",
    "            Case 1: come from previoys decoder layer\n",
    "            Case 2: come from the input embedding\n",
    "            Case 3: come from the output embedding (masked)\n",
    "        - **key** (batch, k_len, d_model): In transformer, three different ways:\n",
    "            Case 1: come from the output of the encoder\n",
    "            Case 2: come from the input embeddings\n",
    "            Case 3: come from the output embedding (masked)\n",
    "        - **value** (batch, v_len, d_model): In transformer, three different ways:\n",
    "            Case 1: come from the output of the encoder\n",
    "            Case 2: come from the input embeddings\n",
    "            Case 3: come from the output embedding (masked)\n",
    "        - **mask** (-): tensor containing indices to be masked\n",
    "    Returns: output, attn\n",
    "        - **output** (batch, output_len, dimensions): tensor containing the attended output features.\n",
    "        - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.\n",
    "    \"\"\"\n",
    "    def __init__(self, d_model: int = 512, num_heads: int = 8):\n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "\n",
    "        assert d_model % num_heads == 0, \"d_model % num_heads should be zero.\"\n",
    "\n",
    "        self.d_head = int(d_model / num_heads)\n",
    "        self.num_heads = num_heads\n",
    "        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)\n",
    "        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)\n",
    "        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)\n",
    "        self.value_proj = nn.Linear(d_model, self.d_head * num_heads)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.d_head, self.d_head).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def forward(\n",
    "            self,\n",
    "            query: Tensor,\n",
    "            key: Tensor,\n",
    "            value: Tensor,\n",
    "            mask: Optional[Tensor] = None\n",
    "    ) -> Tuple[Tensor, Tensor]:\n",
    "        batch_size = value.size(0)\n",
    "\n",
    "        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)  # BxQ_LENxNxD\n",
    "        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head)      # BxK_LENxNxD\n",
    "        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head)  # BxV_LENxNxD\n",
    "\n",
    "        query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxQ_LENxD\n",
    "        key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)      # BNxK_LENxD\n",
    "        value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxV_LENxD\n",
    "\n",
    "        if mask is not None:\n",
    "            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)  # BxNxQ_LENxK_LEN\n",
    "\n",
    "        #context, attn = self.scaled_dot_attn(query, key, value, mask)\n",
    "        context, attn = self.scaled_dot_attn(query, key, value, self.W)\n",
    "\n",
    "        context = context.view(self.num_heads, batch_size, -1, self.d_head)\n",
    "        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head)  # BxTxND\n",
    "\n",
    "        return context\n",
    "\n",
    "class MultimodalFramework(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MultimodalFramework, self).__init__()\n",
    "        ##MLP\n",
    "        self.fc1 = nn.Linear(53, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 2)\n",
    "        self.relu = nn.ReLU()\n",
    "        \n",
    "        ##RESNET\n",
    "        self.resnet18 = models.resnet18(pretrained=True)\n",
    "        n_inputs = self.resnet18.fc.in_features\n",
    "\n",
    "        self.resnet18.fc = nn.Sequential(OrderedDict([\n",
    "            ('fc1', nn.Linear(n_inputs, 512))\n",
    "        ])) \n",
    "\n",
    "        self.resnet_classification = nn.Linear(512, 2) #4\n",
    "        \n",
    "        ##BERT\n",
    "        self.bert = BertModel.from_pretrained('bert-base-uncased') \n",
    "        self.bert_classification = nn.Linear(768, 2)\n",
    "        \n",
    "        #Two Modality models\n",
    "        self.bert_resnet_classification = nn.Linear(512 + 768, 2)\n",
    "        self.bert_mlp_classification = nn.Linear(256 + 768, 2)\n",
    "        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)\n",
    "        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)\n",
    "        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)\n",
    "        self.att_classification = nn.Linear(256*2, 2)\n",
    "        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)\n",
    "        self.OvO_classification = nn.Linear(3*256, 2)\n",
    "\n",
    "        self.res_wrap = nn.Linear(512, 256)\n",
    "        self.bert_wrap = nn.Linear(768, 256)\n",
    "        \n",
    "        #self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)\n",
    "        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)\n",
    "        #self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)\n",
    "        #self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)\n",
    "        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = \"OvO\")\n",
    "        #self.luong_multihead_attention = MultiHeadAttention(256,2, typ = \"luong\")\n",
    "        self.luong_multihead_attention = MultiHeadAttention(256,2)\n",
    "        \n",
    "    def bi_directional_att(self, pair):\n",
    "        x = pair[0]\n",
    "        y = pair[1]\n",
    "        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)\n",
    "        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)\n",
    "        combined = torch.cat((attn_output_LV,\n",
    "                              attn_output_VL), dim=1)\n",
    "        return combined\n",
    "\n",
    "    def forward(self, x, model):\n",
    "        if model == \"mlp\":\n",
    "            x = self.fc1(x)\n",
    "            x = self.relu(x)\n",
    "            x = self.fc2(x)\n",
    "            x = self.relu(x)\n",
    "            out = self.fc3(x)\n",
    "            \n",
    "        elif model == \"resnet\":\n",
    "            res = self.resnet18(x)       \n",
    "            out = self.resnet_classification(res)\n",
    "        \n",
    "        elif model == \"bert\":\n",
    "            text, masks = x\n",
    "            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]\n",
    "            out = self.bert_classification(bert)\n",
    "            \n",
    "        elif model == \"bert_resnet\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "            \n",
    "            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]\n",
    "            combined = torch.cat((res_emb,\n",
    "                                  bert_emb), dim=1)\n",
    "            out = self.bert_resnet_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_luong\":\n",
    "            img, text, masks = x\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res).unsqueeze(1)\n",
    "\n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert).unsqueeze(1)\n",
    "\n",
    "            attn_output_LV = self.luong_multihead_attention(bert, res, res)\n",
    "            attn_output_VL = self.luong_multihead_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV.squeeze(1),\n",
    "                                  attn_output_VL.squeeze(1)), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_vaswani\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "        \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.res_wrap(res_emb)\n",
    "            res = res[:, None, :]\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "            bert = bert[:, None, :]\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV.squeeze(1),\n",
    "                                  attn_output_VL.squeeze(1)), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            \n",
    "            combined = torch.cat((bert,feat), dim=1)\n",
    "            out = self.bert_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp_luong\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"bert_mlp_vaswani\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            \n",
    "            combined = torch.cat((feat,res), dim=1)\n",
    "            out = self.resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp_luong\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(res, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"resnet_mlp_vaswani\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            \n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.resnet18(img)\n",
    "\n",
    "            \n",
    "            combined = torch.cat((bert,feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp_l\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            combined = torch.cat((bert, feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_l_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_mlp_vaswani\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            pairs = [[feat, bert],[feat,res],[bert,res]]\n",
    "        \n",
    "            results = []\n",
    "            for pair in pairs:\n",
    "                combined = self.bi_directional_att(pair)\n",
    "                results.append(combined)\n",
    "\n",
    "            comb = torch.cat(results, dim=1)\n",
    "            out = self.vaswani_3_classification(comb)\n",
    "            \n",
    "        else:\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert\n",
    "            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)\n",
    "            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)\n",
    "\n",
    "            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)\n",
    "            out = self.OvO_classification(comb)\n",
    "\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b08dcad3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '/gpfs/home/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_new_0.00005_42_adamW_16_150_bert_resnet_luong.pth_current.pth'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_49186/2023426640.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMultimodalFramework\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/gpfs/home/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_new_0.00005_42_adamW_16_150_bert_resnet_luong.pth_current.pth\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmap_location\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cpu'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m    697\u001b[0m         \u001b[0mpickle_load_args\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'encoding'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'utf-8'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    698\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 699\u001b[0;31m     \u001b[0;32mwith\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mopened_file\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    700\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0m_is_zipfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    701\u001b[0m             \u001b[0;31m# The zipfile reader is going to advance the current file position.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m    229\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_open_file_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    230\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0m_is_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 231\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    232\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    233\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;34m'w'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m    210\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_opener\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    211\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 212\u001b[0;31m         \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_open_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    213\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    214\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__exit__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/gpfs/home/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_new_0.00005_42_adamW_16_150_bert_resnet_luong.pth_current.pth'"
     ]
    }
   ],
   "source": [
    "model = MultimodalFramework()\n",
    "model.load_state_dict(torch.load(\"/gpfs/home/mgolovan/data/mgolovan/facebook_memes/bert_resnet_models/best_model_new_0.00005_42_adamW_16_150_bert_resnet_luong.pth_current.pth\",map_location=torch.device('cpu')))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c1c8386e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.11.0+cu113\n",
      "Torchvision Version:  0.12.0+cu113\n",
      "Random seed set as 42\n",
      "Accuracy of the bert: 64.11764705882354 %\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import sys\n",
    "import logging\n",
    "from pathlib import Path\n",
    "import random\n",
    "import tarfile\n",
    "import tempfile\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "print(\"PyTorch Version: \",torch.__version__)\n",
    "print(\"Torchvision Version: \",torchvision.__version__)\n",
    "\n",
    "from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score\n",
    "\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,Dataset\n",
    " \n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "from transformers import BertTokenizer\n",
    "from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\") #str(batch_size)+ '_' +\n",
    "from model_utils import set_seed\n",
    "\n",
    "model.eval()\n",
    "set_seed(42)\n",
    "batch_size =16\n",
    "model_name = \"bert_resnet_luong\"\n",
    "test_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_txt.pt')\n",
    "test_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/test_img.pt')\n",
    "\n",
    "if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "    modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False)\n",
    "    modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False ) \n",
    "\n",
    "elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "    modality_1 = DataLoader(test_inputs_img, batch_size=batch_size, shuffle=False ) \n",
    "    modality_2 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False )\n",
    "\n",
    "else:\n",
    "    modality_1 = DataLoader(test_inputs_tab, batch_size=batch_size, shuffle=False ) \n",
    "    modality_2 = DataLoader(test_inputs_txt, batch_size=batch_size, shuffle=False )\n",
    "\n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "running_loss = 0\n",
    "pred = []\n",
    "test_labels = []\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# since we're not training, we don't need to calculate the gradients for our outputs\n",
    "with torch.no_grad():\n",
    "    for modality1, modality2 in zip(modality_1, modality_2):\n",
    "\n",
    "        if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "            text_inp, masks, text_labels = modality2\n",
    "            img_inp, labels = modality1\n",
    "\n",
    "            text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)\n",
    "            img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "\n",
    "            outputs = model([img_inp, text_inp, masks], model_name)\n",
    "\n",
    "        elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "            img_inp, labels = modality1\n",
    "            tab_inp, tab_labels = modality2\n",
    "            tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "            tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "            img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "\n",
    "            outputs = model([tab_inp, img_inp], model_name)\n",
    "        else:\n",
    "            tab_inp, tab_labels = modality1\n",
    "            text_inp, masks, labels = modality2\n",
    "            tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "            tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "            text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)\n",
    "\n",
    "            outputs = model([tab_inp, text_inp, masks], model_name)\n",
    "        \n",
    "        loss = criterion(outputs, labels)\n",
    "        test_labels.extend(np.array(labels.cpu()))\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        pred.extend(predicted.cpu().numpy())\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "        running_loss += loss.item() * labels.size(0)\n",
    "\n",
    "acc= 100 * correct / total\n",
    "loss_f = running_loss/ total\n",
    "print(f'Accuracy of the bert: {100 * correct / total} %')\n",
    "\n",
    "test_labels = np.array(test_labels)\n",
    "\n",
    "#print(confusion_matrix(test_labels, pred))\n",
    "cm = confusion_matrix(test_labels, pred)\n",
    "#print(classification_report(test_labels, pred))\n",
    "cr = classification_report(test_labels, pred, output_dict=True)\n",
    "auc = roc_auc_score(test_labels, pred)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "1321b95d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.472200929697822"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss_f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5de93463",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3906810035842294"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cr['macro avg']['f1-score']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c04b9661",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version:  1.11.0+cu113\n",
      "Torchvision Version:  0.12.0+cu113\n",
      "cpu\n",
      "cpu\n",
      "Random seed set as 42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/0\n",
      "----------\n",
      "torch.Size([16, 1, 256])\n",
      "torch.Size([16, 1, 2, 128])\n",
      "torch.Size([32, 1, 128])\n",
      "torch.Size([16, 1, 256])\n",
      "torch.Size([16, 1, 2, 128])\n",
      "torch.Size([32, 1, 128])\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_49186/3894401215.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m    699\u001b[0m \u001b[0mpath\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m''\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mmodel_name\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"_original_\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    700\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 701\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdic\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloaders_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"config\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    702\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    703\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_49186/3894401215.py\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model_name, dataloaders, criterion, len_train, len_val, config, path)\u001b[0m\n\u001b[1;32m    618\u001b[0m                     \u001b[0;32mif\u001b[0m \u001b[0mphase\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'train'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    619\u001b[0m                         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 620\u001b[0;31m                         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    621\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    622\u001b[0m                 \u001b[0;31m# statistics\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/optim/optimizer.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     86\u001b[0m                 \u001b[0mprofile_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"Optimizer.step#{}.step\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     87\u001b[0m                 \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprofile_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m                     \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     89\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     90\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     25\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m             \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     28\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/optim/adamw.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m    155\u001b[0m                     \u001b[0mweight_decay\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'weight_decay'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    156\u001b[0m                     \u001b[0meps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'eps'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m                     maximize=group['maximize'])\n\u001b[0m\u001b[1;32m    158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    159\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/gpfs/data/ceickhof/mgolovan/test-yang-py3.7/lib/python3.7/site-packages/torch/optim/_functional.py\u001b[0m in \u001b[0;36madamw\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize)\u001b[0m\n\u001b[1;32m    149\u001b[0m             \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmax_exp_avg_sqs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    150\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 151\u001b[0;31m             \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    152\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    153\u001b[0m         \u001b[0mstep_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlr\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mbias_correction1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "from torch import flatten\n",
    "\n",
    "from collections import OrderedDict\n",
    "from transformers import BertModel, DistilBertModel\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from MHA_modified import MultiheadAttention\n",
    "\n",
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "import numpy as np\n",
    "from typing import Optional, Tuple\n",
    "\n",
    "class Attention(torch.nn.Module):\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__()\n",
    "        self.encoder_dim = encoder_dim\n",
    "        self.decoder_dim = decoder_dim\n",
    "\n",
    "    def forward(self, \n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "        ):\n",
    "        weights = self._get_weights(query, values) # [seq_length]\n",
    "        weights = torch.nn.functional.softmax(weights, dim=0)\n",
    "        return weights @ values  # [encoder_dim]\n",
    "\n",
    "class MultiplicativeAttention(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')\n",
    "\n",
    "    def _get_weights(self,\n",
    "        query: torch.Tensor,  # [decoder_dim]\n",
    "        values: torch.Tensor, # [seq_length, encoder_dim]\n",
    "    ):\n",
    "        weights = query @ self.W @ values.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim)    \n",
    "\n",
    "\n",
    "\n",
    "class OneVSOthers(Attention):\n",
    "\n",
    "    def __init__(self, encoder_dim: int, decoder_dim: int):\n",
    "        super().__init__(encoder_dim, decoder_dim)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.decoder_dim, self.encoder_dim).uniform_(-0.1, 0.1)) #.to('cuda')\n",
    "\n",
    "    def _get_weights(self,others, main):\n",
    "        mean = sum(others) / len(others)\n",
    "        weights = mean @ self.W @ main.T  # [seq_length]\n",
    "        return weights #/np.sqrt(self.decoder_dim) \n",
    "\"\"\"\n",
    "class ScaledDotProductAttention(nn.Module):\n",
    "    \n",
    "    def __init__(self, dim: int):\n",
    "        super(ScaledDotProductAttention, self).__init__()\n",
    "        self.sqrt_dim = np.sqrt(dim)\n",
    "\n",
    "    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:\n",
    "        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim\n",
    "\n",
    "        if mask is not None:\n",
    "            score.masked_fill_(mask.view(score.size()), -float('Inf'))\n",
    "\n",
    "        attn = F.softmax(score, -1)\n",
    "        context = torch.bmm(attn, value)\n",
    "\n",
    "        return context, attn\n",
    "\"\"\"\n",
    "\n",
    "class ScaledDotProductAttention(nn.Module):\n",
    "    def __init__(self, dim: int):\n",
    "        super(ScaledDotProductAttention, self).__init__()\n",
    "        self.sqrt_dim = np.sqrt(dim)\n",
    "    \n",
    "    def forward(self, query, key, value, W):\n",
    "\n",
    "        score = query @ W @ value.transpose(1, 2) #.T\n",
    "        attn = F.softmax(score, -1)\n",
    "        context = torch.bmm(attn, value)\n",
    "        return context, attn\n",
    "    \n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"\n",
    "    Multi-Head Attention proposed in \"Attention Is All You Need\"\n",
    "    Instead of performing a single attention function with d_model-dimensional keys, values, and queries,\n",
    "    project the queries, keys and values h times with different, learned linear projections to d_head dimensions.\n",
    "    These are concatenated and once again projected, resulting in the final values.\n",
    "    Multi-head attention allows the model to jointly attend to information from different representation\n",
    "    subspaces at different positions.\n",
    "    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o\n",
    "        where head_i = Attention(Q · W_q, K · W_k, V · W_v)\n",
    "    Args:\n",
    "        d_model (int): The dimension of keys / values / quries (default: 512)\n",
    "        num_heads (int): The number of attention heads. (default: 8)\n",
    "    Inputs: query, key, value, mask\n",
    "        - **query** (batch, q_len, d_model): In transformer, three different ways:\n",
    "            Case 1: come from previoys decoder layer\n",
    "            Case 2: come from the input embedding\n",
    "            Case 3: come from the output embedding (masked)\n",
    "        - **key** (batch, k_len, d_model): In transformer, three different ways:\n",
    "            Case 1: come from the output of the encoder\n",
    "            Case 2: come from the input embeddings\n",
    "            Case 3: come from the output embedding (masked)\n",
    "        - **value** (batch, v_len, d_model): In transformer, three different ways:\n",
    "            Case 1: come from the output of the encoder\n",
    "            Case 2: come from the input embeddings\n",
    "            Case 3: come from the output embedding (masked)\n",
    "        - **mask** (-): tensor containing indices to be masked\n",
    "    Returns: output, attn\n",
    "        - **output** (batch, output_len, dimensions): tensor containing the attended output features.\n",
    "        - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.\n",
    "    \"\"\"\n",
    "    def __init__(self, d_model: int = 512, num_heads: int = 8):\n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "\n",
    "        assert d_model % num_heads == 0, \"d_model % num_heads should be zero.\"\n",
    "\n",
    "        self.d_head = int(d_model / num_heads)\n",
    "        self.num_heads = num_heads\n",
    "        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)\n",
    "        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)\n",
    "        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)\n",
    "        self.value_proj = nn.Linear(d_model, self.d_head * num_heads)\n",
    "        self.W = torch.nn.Parameter(torch.FloatTensor(self.d_head, self.d_head).uniform_(-0.1, 0.1))\n",
    "\n",
    "    def forward(\n",
    "            self,\n",
    "            query: Tensor,\n",
    "            key: Tensor,\n",
    "            value: Tensor,\n",
    "            mask: Optional[Tensor] = None\n",
    "    ) -> Tuple[Tensor, Tensor]:\n",
    "        batch_size = value.size(0)\n",
    "        \n",
    "        print(query.shape)\n",
    "        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)  # BxQ_LENxNxD\n",
    "        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head)      # BxK_LENxNxD\n",
    "        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head)  # BxV_LENxNxD\n",
    "        print(query.shape)\n",
    "        query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxQ_LENxD\n",
    "        print(query.shape)\n",
    "        key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)      # BNxK_LENxD\n",
    "        value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)  # BNxV_LENxD\n",
    "\n",
    "        if mask is not None:\n",
    "            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)  # BxNxQ_LENxK_LEN\n",
    "\n",
    "        #context, attn = self.scaled_dot_attn(query, key, value, mask)\n",
    "        context, attn = self.scaled_dot_attn(query, key, value, self.W)\n",
    "\n",
    "        context = context.view(self.num_heads, batch_size, -1, self.d_head)\n",
    "        context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head)  # BxTxND\n",
    "\n",
    "        return context\n",
    "\n",
    "class MultimodalFramework(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MultimodalFramework, self).__init__()\n",
    "        ##MLP\n",
    "        self.fc1 = nn.Linear(53, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 2)\n",
    "        self.relu = nn.ReLU()\n",
    "        \n",
    "        ##RESNET\n",
    "        self.resnet18 = models.resnet18(pretrained=True)\n",
    "        n_inputs = self.resnet18.fc.in_features\n",
    "\n",
    "        self.resnet18.fc = nn.Sequential(OrderedDict([\n",
    "            ('fc1', nn.Linear(n_inputs, 512))\n",
    "        ])) \n",
    "\n",
    "        self.resnet_classification = nn.Linear(512, 2) #4\n",
    "        \n",
    "        ##BERT\n",
    "        self.bert = BertModel.from_pretrained('bert-base-uncased') \n",
    "        self.bert_classification = nn.Linear(768, 2)\n",
    "        \n",
    "        #Two Modality models\n",
    "        self.bert_resnet_classification = nn.Linear(512 + 768, 2)\n",
    "        self.bert_mlp_classification = nn.Linear(256 + 768, 2)\n",
    "        self.resnet_mlp_classification = nn.Linear(256 + 512, 2)\n",
    "        self.bert_resnet_mlp_classification = nn.Linear(256 + 512 +768, 2)\n",
    "        self.bert_resnet_mlp_l_classification = nn.Linear(256*3, 2)\n",
    "        self.att_classification = nn.Linear(256*2, 2)\n",
    "        self.vaswani_3_classification = nn.Linear(3 *(256*2), 2)\n",
    "        self.OvO_classification = nn.Linear(3*256, 2)\n",
    "\n",
    "        self.res_wrap = nn.Linear(512, 256)\n",
    "        self.bert_wrap = nn.Linear(768, 256)\n",
    "        \n",
    "        #self.luong_attention  = MultiplicativeAttention(encoder_dim=256, decoder_dim=256)\n",
    "        self.vaswani_attention  = nn.MultiheadAttention(256, 4, batch_first = True)\n",
    "        #self.OvO_attention = OneVSOthers(encoder_dim=256, decoder_dim=256)\n",
    "        #self.OvO_concat_attention = OneVSOthers_concat(encoder_dim=256, decoder_dim=512)\n",
    "        self.OvO_multihead_attention = MultiheadAttention(256,2, typ = \"OvO\")\n",
    "        #self.luong_multihead_attention = MultiHeadAttention(256,2, typ = \"luong\")\n",
    "        self.luong_multihead_attention = MultiHeadAttention(256,2)\n",
    "        \n",
    "    def bi_directional_att(self, pair):\n",
    "        x = pair[0]\n",
    "        y = pair[1]\n",
    "        attn_output_LV, attn_output_weights_LV = self.vaswani_attention(x, y, y)\n",
    "        attn_output_VL, attn_output_weights_VL = self.vaswani_attention(y, x, x)\n",
    "        combined = torch.cat((attn_output_LV,\n",
    "                              attn_output_VL), dim=1)\n",
    "        return combined\n",
    "\n",
    "    def forward(self, x, model):\n",
    "        if model == \"mlp\":\n",
    "            x = self.fc1(x)\n",
    "            x = self.relu(x)\n",
    "            x = self.fc2(x)\n",
    "            x = self.relu(x)\n",
    "            out = self.fc3(x)\n",
    "            \n",
    "        elif model == \"resnet\":\n",
    "            res = self.resnet18(x)       \n",
    "            out = self.resnet_classification(res)\n",
    "        \n",
    "        elif model == \"bert\":\n",
    "            text, masks = x\n",
    "            bert = self.bert(text, attention_mask=masks, token_type_ids=None).last_hidden_state[:,0,:]\n",
    "            out = self.bert_classification(bert)\n",
    "            \n",
    "        elif model == \"bert_resnet\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "            \n",
    "            bert_emb = self.bert(text, attention_mask=masks).last_hidden_state[:,0,:]\n",
    "            combined = torch.cat((res_emb,\n",
    "                                  bert_emb), dim=1)\n",
    "            out = self.bert_resnet_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_luong\":\n",
    "            img, text, masks = x\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res) #.unsqueeze(1)\n",
    "\n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert) #.unsqueeze(1)\n",
    "\n",
    "            attn_output_LV = self.luong_multihead_attention(bert, res, res)\n",
    "            attn_output_VL = self.luong_multihead_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV, #.squeeze(1)\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_vaswani\":\n",
    "            img, text, masks = x\n",
    "            res_emb = self.resnet18(img)\n",
    "        \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.res_wrap(res_emb)\n",
    "            res = res[:, None, :]\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "            bert = bert[:, None, :]\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, res, res)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(res, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV.squeeze(1),\n",
    "                                  attn_output_VL.squeeze(1)), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            \n",
    "            combined = torch.cat((bert,feat), dim=1)\n",
    "            out = self.bert_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_mlp_luong\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(bert, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"bert_mlp_vaswani\":\n",
    "            features, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert_emb = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert_emb)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(bert, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, bert, bert)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            \n",
    "            combined = torch.cat((feat,res), dim=1)\n",
    "            out = self.resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"resnet_mlp_luong\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV = self.luong_attention(res, feat)\n",
    "            attn_output_VL = self.luong_attention(feat, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            out = self.att_classification(combined)   \n",
    "            \n",
    "        elif model == \"resnet_mlp_vaswani\":\n",
    "            features, img = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            attn_output_LV, attn_output_weights_LV = self.vaswani_attention(res, feat, feat)\n",
    "            attn_output_VL, attn_output_weights_VL = self.vaswani_attention(feat, res, res)\n",
    "\n",
    "            combined = torch.cat((attn_output_LV,\n",
    "                                  attn_output_VL), dim=1)\n",
    "            \n",
    "            out = self.att_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            res = self.resnet18(img)\n",
    "\n",
    "            \n",
    "            combined = torch.cat((bert,feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_classification(combined)\n",
    "            \n",
    "        elif model == \"bert_resnet_mlp_l\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "\n",
    "            combined = torch.cat((bert, feat, res), dim=1)\n",
    "            out = self.bert_resnet_mlp_l_classification(combined)\n",
    "        \n",
    "        elif model == \"bert_resnet_mlp_vaswani\":\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            pairs = [[feat, bert],[feat,res],[bert,res]]\n",
    "        \n",
    "            results = []\n",
    "            for pair in pairs:\n",
    "                combined = self.bi_directional_att(pair)\n",
    "                results.append(combined)\n",
    "\n",
    "            comb = torch.cat(results, dim=1)\n",
    "            out = self.vaswani_3_classification(comb)\n",
    "            \n",
    "        else:\n",
    "            features, img, text, masks = x\n",
    "            \n",
    "            feat = self.fc1(features)\n",
    "            feat = self.relu(feat)\n",
    "            feat = self.fc2(feat)\n",
    "            \n",
    "            bert = self.bert(text,attention_mask=masks).last_hidden_state[:,0,:] #.last_hidden_state #.logits\n",
    "            bert = self.bert_wrap(bert)\n",
    "            \n",
    "            res = self.resnet18(img)\n",
    "            res = self.res_wrap(res)\n",
    "            \n",
    "            attn_txt, weights_txt = self.OvO_multihead_attention(feat, res, bert) #[feat, res], bert\n",
    "            attn_img, weights_img = self.OvO_multihead_attention(feat, bert,res)\n",
    "            attn_tab, weights_tab = self.OvO_multihead_attention(bert, res, feat)\n",
    "\n",
    "            comb = torch.cat([attn_txt, attn_img, attn_tab], dim=1)\n",
    "            out = self.OvO_classification(comb)\n",
    "\n",
    "        return out\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import sys\n",
    "import logging\n",
    "from pathlib import Path\n",
    "import random\n",
    "import tarfile\n",
    "import tempfile\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "# import pandas_path  # Path style access for pandas\n",
    "from tqdm import tqdm\n",
    "import torch                    \n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "import pprint\n",
    "print(\"PyTorch Version: \",torch.__version__)\n",
    "print(\"Torchvision Version: \",torchvision.__version__)\n",
    "from sklearn.metrics import confusion_matrix,f1_score\n",
    "\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
    "\n",
    "from collections import OrderedDict\n",
    "\n",
    "from PIL import ImageFile,Image\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
    "\n",
    "from transformers import BertModel\n",
    "from transformers import BertForSequenceClassification, AdamW, BertConfig\n",
    "\n",
    "\n",
    "import logging\n",
    "logging.propagate = False \n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "# WandB – Import the wandb library\n",
    "import wandb\n",
    "\n",
    "#from models import MultimodalFramework\n",
    "from model_utils import set_seed, build_optimizer\n",
    "\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "def save_model(net, optim, ckpt_fname, epoch):                                                                                                                                                             \n",
    "    state_dict = net.state_dict()                                                                          \n",
    "    for key in state_dict.keys():                                                                                 \n",
    "        state_dict[key] = state_dict[key].cpu()  \n",
    "\n",
    "    torch.save({                                                                                                                                                                                                 \n",
    "        'epoch': epoch,                                                                                                                                                                                     \n",
    "        'state_dict': state_dict,                                                                                                                                                                                \n",
    "        'optimizer': optim},                                                                                                                                                                                     \n",
    "        ckpt_fname)\n",
    "\n",
    "#train_model(model_name, dataloaders_dict, criterion, len_train, len_val, config, path)\n",
    "def train_model(model_name, dataloaders, criterion, len_train, len_val, config, path):\n",
    "    \n",
    "    set_seed(42)\n",
    "    model = MultimodalFramework()\n",
    "    \n",
    "    #torch.cuda.empty_cache()\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "   \n",
    "    #model = model.to('cuda')\n",
    "    \n",
    "    num_epochs = 1\n",
    "    optimizer = build_optimizer(model, \"adamW\", 0.0001, 0.9)\n",
    "\n",
    "    since = time.time()\n",
    "\n",
    "    val_acc_history = []\n",
    "    val_loss_history = []\n",
    "    train_acc_history = []\n",
    "    train_loss_history = []\n",
    "\n",
    "    best_acc = 0.0\n",
    "    patience = 5 \n",
    "    trigger = 0\n",
    "    acc_dict = {}\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        #scheduler.step()\n",
    "        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    "        print('-' * 10)\n",
    "\n",
    "        # Each epoch has a training and validation phase\n",
    "        for phase in ['train', 'val']:\n",
    "            if phase == 'train':\n",
    "                length = len_train\n",
    "                model.train()  # Set model to training mode\n",
    "            else:\n",
    "                length = len_val\n",
    "                model.eval()   # Set model to evaluate mode\n",
    "\n",
    "            running_loss = 0.0\n",
    "            running_corrects = 0\n",
    "            predicted_labels, ground_truth_labels = [], []\n",
    "\n",
    "            for modality1, modality2 in zip(dataloaders[phase][0], dataloaders[phase][1]):\n",
    "                \n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward\n",
    "                # track history if only in train\n",
    "                with torch.set_grad_enabled(phase == 'train'):\n",
    "                    \n",
    "                    if model_name.split(\"_\")[:2] == [\"bert\", \"resnet\"]:\n",
    "                        text_inp, masks, text_labels = modality2\n",
    "                        img_inp, labels = modality1\n",
    "\n",
    "                        text_inp, masks, text_labels = text_inp.to(device), masks.to(device), text_labels.to(device)\n",
    "                        img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "                        #text_inp, masks, text_labels = text_inp.cuda(), masks.cuda(), text_labels.cuda()\n",
    "                        #img_inp, labels = img_inp.cuda(), labels.cuda()\n",
    "                        \n",
    "                        inp_len = text_inp.size(0)\n",
    "                        outputs = model([img_inp, text_inp, masks], model_name)\n",
    "\n",
    "                    elif model_name.split(\"_\")[:2] == [\"resnet\", \"mlp\"]:\n",
    "                        img_inp, labels = modality1\n",
    "                        tab_inp, tab_labels = modality2\n",
    "                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                        img_inp, labels = img_inp.to(device), labels.to(device)\n",
    "                        \n",
    "                        inp_len = tab_inp.size(0)\n",
    "                        outputs = model([tab_inp, img_inp], model_name)\n",
    "                    else:\n",
    "                        tab_inp, tab_labels = modality1\n",
    "                        text_inp, masks, labels = modality2\n",
    "                        tab_inp, tab_labels = tab_inp.float(), tab_labels.long()\n",
    "\n",
    "                        tab_inp, tab_labels = tab_inp.to(device), tab_labels.to(device)\n",
    "                        text_inp, masks, labels = text_inp.to(device), masks.to(device), labels.to(device)\n",
    "                        \n",
    "                        inp_len = tab_inp.size(0)\n",
    "                        outputs = model([tab_inp, text_inp, masks], model_name)\n",
    "                    \n",
    "                    loss = criterion(outputs, labels)\n",
    "\n",
    "                    _, preds = torch.max(outputs, 1)\n",
    "\n",
    "                    # backward + optimize only if in training phase\n",
    "                    if phase == 'train':\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                # statistics\n",
    "                #print(\"text_inp.size(0)\")\n",
    "                #print(text_inp.size(0))\n",
    "\n",
    "                running_loss += loss.item() * labels.size(0)\n",
    "                running_corrects += torch.sum(preds == labels.data)\n",
    "                predicted_labels.extend(preds.cpu().detach().numpy())\n",
    "                ground_truth_labels.extend(labels.cpu().detach().numpy())\n",
    "                \n",
    "            epoch_loss = running_loss / length\n",
    "            epoch_acc = running_corrects.double() / length\n",
    "            #epoch_f1 = f1.double() / len(dataloaders[phase].dataset)\n",
    "            epoch_f1 = f1_score(ground_truth_labels, predicted_labels)\n",
    "\n",
    "            print('{} Loss: {} Acc: {}'.format(phase, epoch_loss, epoch_acc))\n",
    "\n",
    "            if phase == 'val':\n",
    "                #wandb.log({\"val_loss\": epoch_loss, \"val_acc\": epoch_acc, \"val_f1\": epoch_f1})\n",
    "                acc_dict[epoch] = float(epoch_acc.detach().cpu())\n",
    "                val_acc_history.append(epoch_acc)\n",
    "                val_loss_history.append(epoch_loss)\n",
    "                save_model(model, optimizer, path+\"_save.pth\", epoch)\n",
    "                #print(model.state_dict())\n",
    "                #torch.save(model.cpu().state_dict(), path+\"_current.pth\")\n",
    "                #model = model.cuda()\n",
    "                if epoch_acc > best_acc:\n",
    "                    best_acc = epoch_acc\n",
    "                    #best_model_wts = copy.deepcopy(model.state_dict())\n",
    "                    #torch.save(model.state_dict(), path+\"_best.pth\")\n",
    "                \"\"\"\n",
    "                if (epoch > 10) and (acc_dict[epoch] <= acc_dict[epoch - 10]):\n",
    "                    trigger +=1\n",
    "                    if trigger >= patience:\n",
    "                        return model, {\"train_acc\":train_acc_history, \"val_acc\":val_acc_history,\"train_loss\":train_loss_history, \"val_loss\":val_loss_history}\n",
    "                else:\n",
    "                    trigger = 0\n",
    "                \"\"\"    \n",
    "            if phase == 'train':\n",
    "                #wandb.log({\"train_loss\": epoch_loss, \"train_acc\": epoch_acc,\"train_f1\": epoch_f1, \"epoch\": epoch})\n",
    "                train_acc_history.append(epoch_acc.detach().cpu())\n",
    "                train_loss_history.append(epoch_loss)\n",
    "\n",
    "\n",
    "    time_elapsed = time.time() - since\n",
    "    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n",
    "    print('Best val Acc: {:4f}'.format(best_acc))\n",
    "\n",
    "    # load best model weights\n",
    "    #model.load_state_dict(best_model_wts)\n",
    "    #torch.save(model.cpu().state_dict(), path+\"_last.pth\")\n",
    "    return model, {\"train_acc\":train_acc_history, \"val_acc\":val_acc_history,\"train_loss\":train_loss_history, \"val_loss\":val_loss_history}\n",
    "\n",
    "\n",
    "model_name = \"bert_resnet_luong\"\n",
    "\n",
    "train_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_txt.pt')\n",
    "val_inputs_txt = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_txt.pt')\n",
    "\n",
    "train_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/train_img.pt')\n",
    "val_inputs_img = torch.load('/users/mgolovan/data/mgolovan/facebook_memes/data/val_img.pt')\n",
    "\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "train_dataloader_text = DataLoader(train_inputs_txt, batch_size=16,shuffle=False)\n",
    "val_dataloader_text = DataLoader(val_inputs_txt, batch_size=16, shuffle=False)\n",
    "\n",
    "train_dataloader_img = DataLoader(train_inputs_img, batch_size=16,shuffle=False)\n",
    "val_dataloader_img = DataLoader(val_inputs_img, batch_size=16, shuffle=False)\n",
    "\n",
    "len_val = len(val_inputs_txt)\n",
    "len_train = len(train_inputs_txt)\n",
    "\n",
    "dataloaders_dict = {'train':[train_dataloader_img, train_dataloader_text], 'val':[val_dataloader_img, val_dataloader_text]}\n",
    "\n",
    "\n",
    "path = '' + model_name + \"_original_\"\n",
    "\n",
    "model, dic = train_model(model_name, dataloaders_dict, criterion, len_train, len_val, \"config\", path)  \n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca1a5bea",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.Size([16, 1, 256])\n",
    "torch.Size([16, 1, 2, 128])\n",
    "torch.Size([32, 1, 128])\n",
    "torch.Size([16, 1, 256])\n",
    "torch.Size([16, 1, 2, 128])\n",
    "torch.Size([32, 1, 128])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e057dea1",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.Size([16, 1, 256])\n",
    "torch.Size([16, 1, 2, 128])\n",
    "torch.Size([32, 1, 128])\n",
    "torch.Size([16, 1, 256])\n",
    "torch.Size([16, 1, 2, 128])\n",
    "torch.Size([32, 1, 128])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
