{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import requests\n",
    "from transformers import Blip2Processor, BlipForQuestionAnswering, Blip2Model\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForCausalLM, AutoProcessor\n",
    "import torch\n",
    "from PIL import Image\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both splits are carved from the same annotations file, so the path is named once.\n",
    "# TODO: fill in the path to the training annotations JSON file.\n",
    "train_annotations_file = #path to train annotations\n",
    "\n",
    "# First 70% of the rows for training, the remaining 30% for validation. The two slices\n",
    "# must meet at the same percentage; otherwise validation rows are also trained on.\n",
    "training_dataset = load_dataset(\"json\", data_files=train_annotations_file, split=\"train[:70%]\")\n",
    "valid_dataset = load_dataset(\"json\", data_files=train_annotations_file, split=\"train[70%:]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch.utils.data import Dataset \n",
    "\n",
    "class VQADataset(Dataset):\n",
    "    \"\"\"Map-style dataset that yields (question, answer, image) triples for VQA.\n",
    "\n",
    "    Args:\n",
    "        dataset: Indexable collection of records; each record is read through its\n",
    "            'question', 'answer' and 'image_id' keys.\n",
    "        processor: Kept on the instance as `self.processor`; this class does not call it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset, processor):\n",
    "        self.dataset = dataset\n",
    "        self.processor = processor\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of records in the wrapped dataset.\"\"\"\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return `(question, answer, image)` for record `idx`, with the image decoded as RGB.\"\"\"\n",
    "        # Look the record up once instead of re-indexing the dataset for every field.\n",
    "        record = self.dataset[idx]\n",
    "        question = record['question']\n",
    "        answer = record['answer']\n",
    "        image_id = record['image_id']\n",
    "        # TODO: fill in the image file path (presumably derived from image_id).\n",
    "        image_path = #image path\n",
    "        # Image.open keeps the file open until the handle is closed; convert() returns a\n",
    "        # separate in-memory copy, so the handle can be released right away.\n",
    "        with Image.open(image_path) as raw_image:\n",
    "            image = raw_image.convert(\"RGB\")\n",
    "        return question, answer, image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoProcessor\n",
    "import torch\n",
    "\n",
    "# Run on the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# trust_remote_code=True executes the modelling code shipped in the hub repository;\n",
    "# the revision is pinned so model and processor come from the same snapshot.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"microsoft/Florence-2-base-ft\",\n",
    "    trust_remote_code=True,\n",
    "    revision='refs/pr/6'\n",
    ").to(device)\n",
    "processor = AutoProcessor.from_pretrained(\"microsoft/Florence-2-base-ft\",\n",
    "    trust_remote_code=True, revision='refs/pr/6')\n",
    "\n",
    "# Freeze the vision tower. `requires_grad` is the flag autograd and the optimizer honour;\n",
    "# assigning any other attribute name on a Parameter has no effect on training.\n",
    "for param in model.vision_tower.parameters():\n",
    "    param.requires_grad = False\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report how many records ended up in each split.\n",
    "print(f\"Training sets: {len(training_dataset)} - Validating set: {len(valid_dataset)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm \n",
    "from transformers import AdamW, get_scheduler\n",
    "\n",
    "def collate_fn(batch):\n",
    "    \"\"\"Collate (question, answer, image) samples into processor inputs plus the raw answers.\n",
    "\n",
    "    Relies on the notebook-level `processor` and `device`. Returns a pair: the processor\n",
    "    output moved to `device`, and the tuple of answers left untouched.\n",
    "    \"\"\"\n",
    "    question_batch, answer_batch, image_batch = zip(*batch)\n",
    "    model_inputs = processor(\n",
    "        text=list(question_batch),\n",
    "        images=list(image_batch),\n",
    "        return_tensors=\"pt\",\n",
    "        padding=True,\n",
    "    )\n",
    "    return model_inputs.to(device), answer_batch\n",
    "# Wrap the raw splits so each item becomes a (question, answer, image) triple.\n",
    "train_dataset = VQADataset(dataset=training_dataset,\n",
    "                          processor=processor)\n",
    "# Bind the wrapper to its own name instead of rebinding `valid_dataset`: re-running this\n",
    "# cell would otherwise wrap an already wrapped dataset and break item lookup.\n",
    "val_dataset = VQADataset(dataset=valid_dataset,\n",
    "                        processor=processor)\n",
    "\n",
    "batch_size = 1\n",
    "# collate_fn moves tensors to `device`, so loading stays in the main process.\n",
    "num_workers = 0\n",
    "\n",
    "# Only the training loader shuffles; validation keeps the dataset order.\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size,\n",
    "                          collate_fn=collate_fn, num_workers=num_workers, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=batch_size,\n",
    "                          collate_fn=collate_fn, num_workers=num_workers)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 10\n",
    "# The evaluation cell loads its weights from this path, so training has to write it.\n",
    "checkpoint_path = \"blip/florence_weights.pth\"\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=1e-6)\n",
    "num_training_steps = epochs * len(train_loader)\n",
    "\n",
    "# Linear decay over every optimizer step, with no warm-up.\n",
    "lr_scheduler = get_scheduler(name=\"linear\", optimizer=optimizer,\n",
    "                              num_warmup_steps=0, num_training_steps=num_training_steps,)\n",
    "\n",
    "\n",
    "def _compute_loss(inputs, answers):\n",
    "    \"\"\"Tokenize `answers` as labels, run the model on `inputs` and return `outputs.loss`.\n",
    "\n",
    "    Shared by the training and validation passes so both compute the loss identically.\n",
    "    \"\"\"\n",
    "    labels = processor.tokenizer(text=answers, return_tensors=\"pt\", padding=True, return_token_type_ids=False).input_ids.to(device)\n",
    "    outputs = model(input_ids=inputs[\"input_ids\"], pixel_values=inputs[\"pixel_values\"], labels=labels)\n",
    "    return outputs.loss\n",
    "\n",
    "\n",
    "best_val_loss = float(\"inf\")\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # ---- training pass ----\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    for inputs, answers in tqdm(train_loader, desc=f\"Training Epoch {epoch + 1}/{epochs}\"):\n",
    "        loss = _compute_loss(inputs, answers)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        train_loss += loss.item()\n",
    "    avg_train_loss = train_loss / len(train_loader)\n",
    "    print(f\"Average Training Loss: {avg_train_loss}\")\n",
    "\n",
    "    # ---- validation pass (no gradients) ----\n",
    "    model.eval()\n",
    "    val_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, answers in tqdm(val_loader, desc=f\"Validation Epoch {epoch + 1}/{epochs}\"):\n",
    "            val_loss += _compute_loss(inputs, answers).item()\n",
    "    avg_val_loss = val_loss / len(val_loader)\n",
    "    print(avg_val_loss)\n",
    "\n",
    "    # Keep the weights of the epoch with the lowest validation loss on disk.\n",
    "    if avg_val_loss < best_val_loss:\n",
    "        best_val_loss = avg_val_loss\n",
    "        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)\n",
    "        torch.save(model.state_dict(), checkpoint_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# TODO: fill in the path to the test questions JSON file.\n",
    "question_file = #test questions \n",
    "\n",
    "# Read as UTF-8 explicitly rather than relying on the platform's default encoding.\n",
    "with open(question_file, 'r', encoding='utf-8') as input_file:\n",
    "    questions = json.load(input_file)\n",
    "\n",
    "# Write the list under the 'questions' key back out, pretty-printed (presumably for\n",
    "# manual inspection). NOTE: the annotations cell further down writes to the same\n",
    "# \"file.json\", so this dump is overwritten once that cell runs.\n",
    "temp = questions['questions']\n",
    "with open(\"file.json\", \"w\", encoding='utf-8') as file:\n",
    "    json.dump(temp, file, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two parallel lists over the question records: the question text and the id of the\n",
    "# image each question refers to.\n",
    "qs = [entry['question'] for entry in questions[\"questions\"]]\n",
    "imgs = [entry[\"image_id\"] for entry in questions[\"questions\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both counts, one per line.\n",
    "print(len(qs), len(imgs), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# TODO: fill in the path to the test annotations JSON file.\n",
    "annotation_file = #test annotations\n",
    "\n",
    "# Read as UTF-8 explicitly rather than relying on the platform's default encoding.\n",
    "with open(annotation_file, 'r', encoding='utf-8') as input_file:\n",
    "    annotations = json.load(input_file)\n",
    "\n",
    "# Write the list under the 'annotations' key back out, pretty-printed (presumably for\n",
    "# manual inspection). NOTE: this replaces whatever an earlier cell wrote to \"file.json\".\n",
    "temp = annotations['annotations']\n",
    "with open(\"file.json\", \"w\", encoding='utf-8') as file:\n",
    "    json.dump(temp, file, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference answer for every annotation record, in file order.\n",
    "answers = [entry['multiple_choice_answer'] for entry in annotations['annotations']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of reference answers collected from the annotations.\n",
    "print(f\"{len(answers)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "from transformers import AutoProcessor, BlipForQuestionAnswering\n",
    "import torch\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Rebuild the base model and processor, then load the fine-tuned weights on top.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"microsoft/Florence-2-base-ft\",\n",
    "    trust_remote_code=True,\n",
    "    revision='refs/pr/6'\n",
    ").to(device)\n",
    "processor = AutoProcessor.from_pretrained(\"microsoft/Florence-2-base-ft\",\n",
    "    trust_remote_code=True, revision='refs/pr/6')\n",
    "# map_location lets a checkpoint saved on a GPU load on a CPU-only machine;\n",
    "# weights_only=True restricts unpickling to tensors and plain containers.\n",
    "model.load_state_dict(torch.load(\"blip/florence_weights.pth\", map_location=device, weights_only=True))\n",
    "model.eval()  # inference only\n",
    "\n",
    "res = []       # generated answer per question, in question order\n",
    "correct = 0    # number of exact string matches against `answers`\n",
    "idx = 0\n",
    "# `i` and `idx` advance together; both stay defined for the image-path expression.\n",
    "for i in range(len(qs)):\n",
    "    # TODO: fill in the image file path for the current question.\n",
    "    image_path = #image path \n",
    "    # Close the file handle as soon as the RGB copy has been made.\n",
    "    with Image.open(image_path) as raw_image:\n",
    "        image = raw_image.convert(\"RGB\")\n",
    "    text = qs[idx]\n",
    "    inputs = processor(images=image, text=text, return_tensors=\"pt\").to(device)\n",
    "    input_ids = inputs[\"input_ids\"]\n",
    "    pixel_values = inputs[\"pixel_values\"]\n",
    "    preds = model.generate(input_ids=input_ids, pixel_values=pixel_values, max_new_tokens=1024, num_beams=3)\n",
    "    generated_text = processor.batch_decode(preds, skip_special_tokens=True)[0]\n",
    "    res.append(generated_text)\n",
    "    print(res[idx], answers[idx])\n",
    "    # A prediction counts only when it equals the reference answer exactly.\n",
    "    if res[idx] == answers[idx]:\n",
    "        correct += 1\n",
    "    idx += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exact-match accuracy: share of generated answers identical to the reference.\n",
    "print(f\"{correct / len(res)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "florence",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
