{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_positions(A, B):\n",
    "    \"\"\"Locate the distinct elements of A inside B.\n",
    "\n",
    "    Every distinct element of A that also occurs in B is reported once, in\n",
    "    the order in which it is first met while iterating over A.\n",
    "\n",
    "    Args:\n",
    "        A: iterable of hashable elements to look up.\n",
    "        B: sequence that is searched with ``B.index``.\n",
    "\n",
    "    Returns:\n",
    "        A tuple ``(positions, matched)`` of two lists of equal length, where\n",
    "        ``positions[i]`` is the index of the first occurrence of\n",
    "        ``matched[i]`` in B.\n",
    "    \"\"\"\n",
    "    positions = []\n",
    "    matched = []\n",
    "    seen = set()\n",
    "    for num in A:\n",
    "        if num in seen:\n",
    "            continue\n",
    "        try:\n",
    "            # One scan of B both tests membership and yields the index.\n",
    "            pos = B.index(num)\n",
    "        except ValueError:\n",
    "            continue\n",
    "        seen.add(num)\n",
    "        positions.append(pos)\n",
    "        # An ordered list keeps each element paired with its position;\n",
    "        # turning the set into a list would give an arbitrary order.\n",
    "        matched.append(num)\n",
    "    return positions, matched\n",
    "    \n",
    "def extract_elements_by_positions(remain_map, image_id, pred):\n",
    "    \"\"\"Map predicted ids to entries of remain_map via their position in image_id.\n",
    "\n",
    "    A prediction that occurs exactly once in image_id uses that position.\n",
    "    A prediction that occurs several times is resolved afterwards: among its\n",
    "    positions that are not taken yet, the one with the smallest summed\n",
    "    distance to all positions chosen so far wins (earliest position on ties).\n",
    "    Predictions that do not occur in image_id are skipped.\n",
    "\n",
    "    Args:\n",
    "        remain_map: sequence indexed by the chosen positions.\n",
    "        image_id: sequence in which the predictions are looked up.\n",
    "        pred: iterable of predicted ids.\n",
    "\n",
    "    Returns:\n",
    "        List of remain_map entries: those of the uniquely located predictions\n",
    "        first (in pred order), followed by those of the resolved duplicates.\n",
    "    \"\"\"\n",
    "    ind_temp = []\n",
    "    ind_multiple = []\n",
    "\n",
    "    for p in pred:\n",
    "        count = image_id.count(p)\n",
    "        if count == 0:\n",
    "            # Absent prediction: image_id.index(p) would raise ValueError.\n",
    "            continue\n",
    "        if count > 1:\n",
    "            # Ambiguous; resolved once all unique positions are known.\n",
    "            ind_multiple.append(p)\n",
    "            continue\n",
    "        ind_temp.append(image_id.index(p))\n",
    "\n",
    "    for p in ind_multiple:\n",
    "        min_distance_sum = float('inf')\n",
    "        min_index = None\n",
    "        for index, value in enumerate(image_id):\n",
    "            if value == p and index not in ind_temp:\n",
    "                distance_sum = sum(abs(index - taken) for taken in ind_temp)\n",
    "                # Strict '<' keeps the earliest candidate on ties.\n",
    "                if distance_sum < min_distance_sum:\n",
    "                    min_distance_sum = distance_sum\n",
    "                    min_index = index\n",
    "        # When every occurrence is already taken the prediction is dropped.\n",
    "        if min_index is not None:\n",
    "            ind_temp.append(min_index)\n",
    "\n",
    "    return [remain_map[index] for index in ind_temp]\n",
    "\n",
    "\n",
    "def filter_elements(J, K):\n",
    "    \"\"\"Return the elements of J that are also in K, keeping J's order and repeats.\"\"\"\n",
    "    kept = []\n",
    "    for candidate in J:\n",
    "        if candidate in K:\n",
    "            kept.append(candidate)\n",
    "    return kept\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# Load the JSON-lines results file (one JSON object per line). The context\n",
    "# manager guarantees the file handle is closed; UTF-8 matches the encoding\n",
    "# used when the summary is written out.\n",
    "with open('your_results', \"r\", encoding='utf-8') as results_file:\n",
    "    # Blank lines (e.g. a trailing newline) are skipped: json.loads rejects them.\n",
    "    data = [json.loads(q) for q in results_file if q.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _safe_ratio(numerator, denominator):\n",
    "    \"\"\"Return numerator / denominator, or 0.0 when the denominator is zero.\"\"\"\n",
    "    return numerator / denominator if denominator else 0.0\n",
    "\n",
    "\n",
    "# Per-sample records and metric accumulators.\n",
    "loc_list = []\n",
    "null_count = 0  # number of samples whose overlap ratio is zero\n",
    "intersect = []\n",
    "reli = []\n",
    "prec = []\n",
    "iou = []\n",
    "for item in data:\n",
    "    M = item['image_map']\n",
    "    A = item['predict_tokens'][1:]  # the first predicted token is dropped\n",
    "    A_ = filter_elements(A, M)  # predictions that occur in the image map\n",
    "    B = item['gt_tokens']\n",
    "    num_pred = len(A)\n",
    "    corr, corr_list = find_positions(A, B)\n",
    "    # A sample without predictions or without ground-truth tokens would\n",
    "    # divide by zero; such ratios are scored 0.0 instead.\n",
    "    ratio = _safe_ratio(len(corr), num_pred)\n",
    "    precision = _safe_ratio(len(corr), len(B))\n",
    "    union = len(A_) + len(B) - len(corr)\n",
    "    iou_ = _safe_ratio(len(corr), union)\n",
    "    reliability = _safe_ratio(len(A_), num_pred)\n",
    "    iou.append(iou_)\n",
    "    prec.append(precision)\n",
    "    intersect.append(ratio)\n",
    "    C = item['remained_map']\n",
    "    D = item['image_map']\n",
    "    findd = extract_elements_by_positions(C, D, A_)\n",
    "    loc = {}\n",
    "    loc['loc'] = findd\n",
    "    loc['corr'] = corr_list\n",
    "    loc['gt_box'] = item['gt_answer']\n",
    "    loc['image'] = item['file_name']\n",
    "    loc['object'] = item['prompt']\n",
    "    loc['precision'] = precision  # matched tokens / ground-truth tokens\n",
    "    loc['pred_intersect_gt'] = '{:.3f}'.format(ratio)  # matched tokens / predicted tokens\n",
    "    loc['iou'] = iou_\n",
    "    loc['reliability'] = '{:.3f}'.format(reliability)  # share of predictions found in the image map\n",
    "    reli.append(reliability)\n",
    "    if ratio == 0:\n",
    "        null_count = null_count + 1\n",
    "    loc_list.append(loc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4957919046483946\n",
      "0.334850178651876\n",
      "0.2676380397075303\n",
      "0.21899404960448327\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "# Convert every per-sample metric list to an array, then print the mean of\n",
    "# each one in the order: reli, intersect, prec, iou.\n",
    "reli, intersect, prec, iou = (\n",
    "    np.array(values) for values in (reli, intersect, prec, iou)\n",
    ")\n",
    "for metric in (reli, intersect, prec, iou):\n",
    "    print(metric.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-sample summaries as pretty-printed JSON; non-ASCII\n",
    "# characters are written as-is rather than as \\u escapes.\n",
    "dump_options = {'ensure_ascii': False, 'indent': 4}\n",
    "with open('output_file', mode='w', encoding='utf-8') as f:\n",
    "    json.dump(loc_list, fp=f, **dump_options)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llava",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
