{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import requests\n",
    "from transformers import Blip2Processor, BlipForQuestionAnswering, Blip2Model\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForCausalLM, AutoProcessor\n",
    "import torch\n",
    "from PIL import Image\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "\n",
    "\n",
    "# Load the pretrained BLIP VQA processor and model, then move the model to\n",
    "# the GPU when one is available (CPU otherwise).\n",
    "processor = AutoProcessor.from_pretrained('Salesforce/blip-vqa-base')\n",
    "model = BlipForQuestionAnswering.from_pretrained(\"Salesforce/blip-vqa-base\")\n",
    "# model.load_state_dict(torch.load(\"blip/blip_weights.pth\"))\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "torch.manual_seed(42)\n",
    "\n",
    "class VQADataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Wrap VQA-style records as tensors ready for the BLIP model.\n",
    "\n",
    "    Every record of ``dataset`` is read through the keys ``question``,\n",
    "    ``answer`` and ``image_id``.\n",
    "\n",
    "    Args:\n",
    "        dataset: Indexable collection of records (must support ``len`` and ``[]``).\n",
    "        processor: Processor that encodes the image/question pair; its\n",
    "            ``tokenizer`` attribute encodes the answer into ``labels``.\n",
    "        image_dir: Directory that contains the image files.\n",
    "            TODO: set this to where the images are stored.\n",
    "        image_name_format: ``str.format`` template, filled with ``image_id``,\n",
    "            that produces the image file name.\n",
    "            TODO: confirm the template matches the file names on disk.\n",
    "        max_answer_length: Fixed number of tokens in the encoded answer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset, processor, image_dir='.',\n",
    "                 image_name_format='{image_id}.jpg', max_answer_length=8):\n",
    "        self.dataset = dataset\n",
    "        self.processor = processor\n",
    "        self.image_dir = image_dir\n",
    "        self.image_name_format = image_name_format\n",
    "        self.max_answer_length = max_answer_length\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of records in the wrapped dataset.\"\"\"\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def _image_path(self, image_id):\n",
    "        \"\"\"Return the path of the image file that belongs to ``image_id``.\"\"\"\n",
    "        file_name = self.image_name_format.format(image_id=image_id)\n",
    "        return os.path.join(self.image_dir, file_name)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the encoded image/question pair plus ``labels`` for record ``idx``.\"\"\"\n",
    "        # Fetch the record once instead of once per field.\n",
    "        record = self.dataset[idx]\n",
    "        # The context manager closes the file handle; convert() returns an\n",
    "        # image that no longer depends on the open file.\n",
    "        with Image.open(self._image_path(record['image_id'])) as image_file:\n",
    "            image = image_file.convert(\"RGB\")\n",
    "        encoding = self.processor(image, record['question'], padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
    "        # Pad AND truncate so every label tensor has the same length; without\n",
    "        # truncation a long answer yields a longer tensor that the default\n",
    "        # collate function cannot stack with the others.\n",
    "        encoding[\"labels\"] = self.processor.tokenizer.encode(\n",
    "            record['answer'],\n",
    "            max_length=self.max_answer_length,\n",
    "            padding=\"max_length\",\n",
    "            truncation=True,\n",
    "            return_tensors='pt',\n",
    "        )\n",
    "        # Remove only the leading batch dimension added by return_tensors;\n",
    "        # a bare squeeze() would also drop any other size-1 dimension.\n",
    "        for key in list(encoding.keys()):\n",
    "            encoding[key] = encoding[key].squeeze(0)\n",
    "        return encoding\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: point this at the training annotations JSON file.\n",
    "TRAIN_ANNOTATIONS_PATH = \"path/to/train_annotations.json\"\n",
    "\n",
    "# Disjoint 70/30 split of the same file. The validation slice must start at\n",
    "# 70%: a slice starting at 30% would share 40% of its rows with the training\n",
    "# slice and make the validation loss look better than it really is.\n",
    "training_dataset = load_dataset(\"json\", data_files=TRAIN_ANNOTATIONS_PATH, split=\"train[:70%]\")\n",
    "valid_dataset = load_dataset(\"json\", data_files=TRAIN_ANNOTATIONS_PATH, split=\"train[70%:]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the summary the dataset object reports for the training split.\n",
    "print(training_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of records in each split.\n",
    "split_sizes = (len(training_dataset), len(valid_dataset))\n",
    "print(\"Training sets: {} - Validating set: {}\".format(*split_sizes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = VQADataset(dataset=training_dataset,\n",
    "                          processor=processor)\n",
    "# Bind the wrapper to its own name: rebinding valid_dataset would make a\n",
    "# second run of this cell wrap an already-wrapped dataset.\n",
    "valid_vqa_dataset = VQADataset(dataset=valid_dataset,\n",
    "                               processor=processor)\n",
    "batch_size = 4\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)\n",
    "# Validation only averages the loss over all batches, so shuffling adds nothing.\n",
    "valid_dataloader = DataLoader(valid_vqa_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)\n",
    "\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=4e-5)\n",
    "# The learning rate is multiplied by gamma on every scheduler.step() call.\n",
    "# last_epoch and verbose stay at their defaults (verbose is deprecated).\n",
    "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 100\n",
    "# Mixed precision is only used on CUDA. On CPU both autocast and the gradient\n",
    "# scaler are disabled, so the same loop runs in full precision.\n",
    "use_amp = device.type == 'cuda'\n",
    "scaler = torch.cuda.amp.GradScaler(enabled=use_amp)\n",
    "\n",
    "\n",
    "def compute_loss(batch):\n",
    "    \"\"\"Run one forward pass of ``model`` on ``batch`` and return its loss.\n",
    "\n",
    "    ``batch`` must provide ``input_ids``, ``pixel_values``, ``attention_mask``\n",
    "    and ``labels``; each tensor is moved to ``device`` before the call.\n",
    "    \"\"\"\n",
    "    with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):\n",
    "        outputs = model(input_ids=batch['input_ids'].to(device),\n",
    "                        pixel_values=batch['pixel_values'].to(device),\n",
    "                        attention_mask=batch['attention_mask'].to(device),\n",
    "                        labels=batch['labels'].to(device))\n",
    "    return outputs.loss\n",
    "\n",
    "\n",
    "# NOTE(review): the trained weights are never written to disk here, while the\n",
    "# evaluation cell loads blip/blip_weights.pth - confirm where that file comes from.\n",
    "for epoch in range(num_epochs):\n",
    "    epoch_loss = 0\n",
    "    model.train()\n",
    "    for batch in tqdm(train_dataloader, desc='Training'):\n",
    "        optimizer.zero_grad()\n",
    "        loss = compute_loss(batch)\n",
    "        epoch_loss += loss.item()\n",
    "\n",
    "        # Scale the loss so float16 gradients do not underflow; step() unscales\n",
    "        # them before the optimizer update.\n",
    "        scaler.scale(loss).backward()\n",
    "        scaler.step(optimizer)\n",
    "        scaler.update()\n",
    "\n",
    "    model.eval()\n",
    "    eval_loss = 0\n",
    "    # No gradients are needed for validation; skipping the autograd graph\n",
    "    # saves memory and time.\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(valid_dataloader, desc='Validation'):\n",
    "            eval_loss += compute_loss(batch).item()\n",
    "    print(\"Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}\".format(epoch+1, epoch_loss/len(train_dataloader), eval_loss/len(valid_dataloader), optimizer.param_groups[0][\"lr\"]))\n",
    "    scheduler.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# TODO: point this at the test questions JSON file.\n",
    "question_file = \"path/to/test_questions.json\"\n",
    "\n",
    "with open(question_file, 'r', encoding='utf-8') as input_file:\n",
    "    questions = json.load(input_file)\n",
    "\n",
    "# Pretty-printed copy of the question records for manual inspection.\n",
    "# NOTE(review): the annotations cell writes to this same file.json, so this\n",
    "# dump is overwritten once that cell runs.\n",
    "with open(\"file.json\", \"w\") as file:\n",
    "    json.dump(questions['questions'], file, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parallel lists: qs[i] is the question text of record i and imgs[i] is the\n",
    "# image id stored in that same record.\n",
    "question_records = questions[\"questions\"]\n",
    "qs = [record['question'] for record in question_records]\n",
    "imgs = [record[\"image_id\"] for record in question_records]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the number of questions, then the number of image ids.\n",
    "for values in (qs, imgs):\n",
    "    print(len(values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# TODO: point this at the test annotations JSON file.\n",
    "annotation_file = \"path/to/test_annotations.json\"\n",
    "\n",
    "with open(annotation_file, 'r', encoding='utf-8') as input_file:\n",
    "    annotations = json.load(input_file)\n",
    "\n",
    "# Pretty-printed copy of the annotation records for manual inspection.\n",
    "# NOTE(review): the questions cell writes to this same file.json, so this\n",
    "# dump replaces the one written there.\n",
    "with open(\"file.json\", \"w\") as file:\n",
    "    json.dump(annotations['annotations'], file, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One reference answer per annotation record, in file order.\n",
    "answers = [record['multiple_choice_answer'] for record in annotations['annotations']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of reference answers collected above.\n",
    "answer_count = len(answers)\n",
    "print(answer_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from PIL import Image\n",
    "import requests\n",
    "from transformers import AutoProcessor, BlipForQuestionAnswering\n",
    "import torch\n",
    "\n",
    "# TODO: set the directory that holds the test images and confirm the\n",
    "# file-name template (filled with each question's image_id).\n",
    "TEST_IMAGE_DIR = \"path/to/test_images\"\n",
    "TEST_IMAGE_NAME_FORMAT = \"{image_id}.jpg\"\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model = BlipForQuestionAnswering.from_pretrained(\"Salesforce/blip-vqa-base\")\n",
    "# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.\n",
    "# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.\n",
    "model.load_state_dict(torch.load(\"blip/blip_weights.pth\", map_location=device))\n",
    "model.to(device)\n",
    "model.eval()\n",
    "processor = AutoProcessor.from_pretrained(\"Salesforce/blip-vqa-base\")\n",
    "\n",
    "# res collects one predicted answer per question; correct counts the\n",
    "# predictions that are string-identical to the reference answer.\n",
    "# Assumes qs, imgs and answers are aligned by index - TODO confirm.\n",
    "res = []\n",
    "correct = 0\n",
    "for idx in range(len(qs)):\n",
    "    image_path = os.path.join(TEST_IMAGE_DIR, TEST_IMAGE_NAME_FORMAT.format(image_id=imgs[idx]))\n",
    "    # The context manager closes the file handle once the image is converted.\n",
    "    with Image.open(image_path) as image_file:\n",
    "        image = image_file.convert(\"RGB\")\n",
    "    text = qs[idx]\n",
    "    # Inputs must live on the same device as the model.\n",
    "    inputs = processor(images=image, text=text, return_tensors=\"pt\").to(device)\n",
    "    outputs = model.generate(**inputs)\n",
    "    print(idx)\n",
    "    # Decoding only needs the generated token ids; the reference answer is\n",
    "    # not an input to decode() and is compared separately below.\n",
    "    out = processor.decode(outputs[0], skip_special_tokens=True)\n",
    "    res.append(out)\n",
    "    print(out, answers[idx])\n",
    "    if out == answers[idx]:\n",
    "        correct += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exact-match count, number of predictions, and their ratio (accuracy).\n",
    "total = len(res)\n",
    "print(correct)\n",
    "print(total)\n",
    "print(correct/total)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "florence",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
