{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3657da81",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tabulate import tabulate\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a55c508d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1e6cda9",
   "metadata": {},
   "source": [
    "# Utility Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24d99900",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_json_from_path(file_path):\n",
    "    try:\n",
    "        with open(file_path, 'r') as json_file:\n",
    "            data = json.load(json_file)\n",
    "        return data\n",
    "    except FileNotFoundError:\n",
    "        print(f\"File not found at: {file_path}\")\n",
    "        return None\n",
    "    except json.JSONDecodeError:\n",
    "        print(f\"Invalid JSON format in file: {file_path}\")\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e995859",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_clevr_cot_output(response_map):\n",
    "    one_word_map = {}\n",
    "    \n",
    "    for k in response_map:\n",
    "        one_word_map[k] = []\n",
    "        \n",
    "        for a in response_map[k]:\n",
    "            words = a.split()\n",
    "            last_word = words[-1]\n",
    "            # Remove the full stop (.) if it exists\n",
    "#             last_word = last_word.rstrip(\".\")\n",
    "#             last_word = last_word.rstrip(\"?\")\n",
    "#             if ':' in last_word:\n",
    "#                 last_word = last_word.split(':')[0]\n",
    "#             if '?' in last_word:\n",
    "#                 last_word = last_word.split('?')[-1]\n",
    "#             last_word = last_word.replace('[','')\n",
    "#             last_word = last_word.replace(']','')\n",
    "#             last_word = last_word.replace('\"','')\n",
    "#             last_word = last_word.replace(\"'\",'')\n",
    "#             last_word = last_word.replace(\",\",'')\n",
    "#             last_word = last_word.replace(\")\",'')\n",
    "#             last_word = last_word.replace(\"(\",'')\n",
    "#             last_word = last_word.replace(\";\",'')\n",
    "            one_word_map[k].append(last_word.lower())\n",
    "    \n",
    "    return one_word_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a24c23",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_ptr_cot_output(response_map):\n",
    "    one_word_map = {}\n",
    "    \n",
    "    for k in response_map:\n",
    "        one_word_map[k] = []\n",
    "        \n",
    "        for a in response_map[k]:\n",
    "            words = a.split()\n",
    "            last_word = words[-1]\n",
    "            # Remove the full stop (.) if it exists\n",
    "            last_word = last_word.rstrip(\".\")\n",
    "            last_word = last_word.rstrip(\"?\")\n",
    "            if ':' in last_word:\n",
    "                last_word = last_word.split(':')[0]\n",
    "            if ',' in last_word:\n",
    "                last_word = last_word.split(',')[0]\n",
    "#             if '?' in last_word:\n",
    "#                 last_word = last_word.split('?')[-1]\n",
    "            last_word = last_word.replace('[','')\n",
    "            last_word = last_word.replace(']','')\n",
    "            last_word = last_word.replace('\"','')\n",
    "            last_word = last_word.replace(\"'\",'')\n",
    "#             last_word = last_word.replace(\",\",'')\n",
    "            last_word = last_word.replace(\")\",'')\n",
    "            last_word = last_word.replace(\"(\",'')\n",
    "            last_word = last_word.replace(\";\",'')\n",
    "            last_word = last_word.replace(\":\",'')\n",
    "            one_word_map[k].append(last_word.lower())\n",
    "    \n",
    "    return one_word_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61320c12",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_qn_family_clevr(reasoning_step):\n",
    "    function = reasoning_step['function']\n",
    "    if 'exist' in function:\n",
    "        return 'exist'\n",
    "    if 'query' in function:\n",
    "        return 'query attribute'\n",
    "    if 'equal' in function:\n",
    "        return 'compare attribute'\n",
    "    if 'count' in function:\n",
    "        return 'count'\n",
    "    if 'than' in function:\n",
    "        return 'compare numbers'\n",
    "    return 'unknown'\n",
    "\n",
    "def evaluate_clevr_flant5_non_cot(response_map, scene_mapping):\n",
    "    \"\"\"Score one-word CLEVR predictions against the ground-truth answers.\n",
    "\n",
    "    ``response_map`` maps a scene key to a list of predictions whose i-th entry\n",
    "    answers the i-th question in ``scene_mapping[key]['questions']``.\n",
    "    Prediction and answer are compared after ``str()``, whitespace stripping\n",
    "    and lower-casing, so e.g. 'Yes' matches 'yes' and an answer of 3 matches\n",
    "    '3' (consistent with the lower-casing done by the PTR evaluator).\n",
    "\n",
    "    Returns ``(total, correct, per_question_family, per_major_family,\n",
    "    per_reasoning_steps)``; each breakdown maps a key to\n",
    "    ``{'total': int, 'correct': int}``.\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    correct = 0\n",
    "    question_family_index_performance = {}\n",
    "    major_question_family_performance = {}\n",
    "    reasoning_step_performance = {}\n",
    "\n",
    "    for k in tqdm(response_map):\n",
    "        predictions = response_map[k]\n",
    "        scene_qns = scene_mapping[k]['questions']\n",
    "\n",
    "        for i, pred in enumerate(predictions):\n",
    "            qn = scene_qns[i]\n",
    "            # Raw (non-CoT) model output is not lower-cased upstream, so\n",
    "            # normalise both sides before comparing.\n",
    "            is_correct = str(pred).strip().lower() == str(qn['answer']).strip().lower()\n",
    "            breakdowns = (\n",
    "                (question_family_index_performance, qn['question_family_index']),\n",
    "                (major_question_family_performance, get_qn_family_clevr(qn['program'][-1])),\n",
    "                (reasoning_step_performance, len(qn['program'])),\n",
    "            )\n",
    "\n",
    "            total += 1\n",
    "            correct += int(is_correct)\n",
    "            for stats, key in breakdowns:\n",
    "                bucket = stats.setdefault(key, {'total': 0, 'correct': 0})\n",
    "                bucket['total'] += 1\n",
    "                bucket['correct'] += int(is_correct)\n",
    "\n",
    "    return total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab93e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_qn_family_ptr(reasoning_step):\n",
    "    \n",
    "    concept = ['query_object-category','query_part-category','query_part-color','count','exist']\n",
    "#     relation = [] # The rest\n",
    "    analogy = ['query_geometric-analogy-color','query_geometric-analogy-count','query_positional-analogy-category','query_positional-analogy-count','query_positional-analogy-exist']\n",
    "    arithmetic = ['equal_integer','greater_than','less_than','minus_less','minus_more','sum']\n",
    "    physics = ['query_stability','query_unstability','query_change']\n",
    "    \n",
    "    \n",
    "    function = reasoning_step['type']\n",
    "    if function in concept:\n",
    "        return 'concept'\n",
    "    if function in analogy:\n",
    "        return 'analogy'\n",
    "    if function in arithmetic:\n",
    "        return 'arithmetic'\n",
    "    if function in physics:\n",
    "        return 'physics'\n",
    "    return 'relation'\n",
    "\n",
    "def evaluate_ptr_flant5_non_cot(response_map, scene_mapping):\n",
    "    \"\"\"Score one-word PTR predictions against the ground-truth answers.\n",
    "\n",
    "    ``response_map`` maps a scene key to a list of predictions whose i-th entry\n",
    "    answers the i-th question in ``scene_mapping[key]['questions']``. Both\n",
    "    sides are lower-cased (the answer after ``str()``), and a prediction of\n",
    "    'yes'/'no' is read as 'true'/'false' so it can match boolean answers.\n",
    "\n",
    "    Returns ``(total, correct, per_question_family, per_major_family,\n",
    "    per_reasoning_steps)``; each breakdown maps a key to\n",
    "    ``{'total': int, 'correct': int}``.\n",
    "    \"\"\"\n",
    "    yes_no_as_bool = {'yes': 'true', 'no': 'false'}\n",
    "    by_family_index, by_major_family, by_step_count = {}, {}, {}\n",
    "    total = correct = 0\n",
    "\n",
    "    for scene_key in tqdm(response_map):\n",
    "        predictions = response_map[scene_key]\n",
    "        questions = scene_mapping[scene_key]['questions']\n",
    "\n",
    "        for position, raw_pred in enumerate(predictions):\n",
    "            question = questions[position]\n",
    "            family_index = question['question_family_index']\n",
    "            target = str(question['answer'])\n",
    "            major_family = get_qn_family_ptr(question['program'][-1])\n",
    "            step_count = len(question['program'])\n",
    "            tallies = (\n",
    "                (by_family_index, family_index),\n",
    "                (by_major_family, major_family),\n",
    "                (by_step_count, step_count),\n",
    "            )\n",
    "\n",
    "            total += 1\n",
    "            for stats, key in tallies:\n",
    "                stats.setdefault(key, {'total': 0, 'correct': 0})['total'] += 1\n",
    "\n",
    "            guess = raw_pred.lower()\n",
    "            guess = yes_no_as_bool.get(guess, guess)\n",
    "            if guess == target.lower():\n",
    "                correct += 1\n",
    "                for stats, key in tallies:\n",
    "                    stats[key]['correct'] += 1\n",
    "\n",
    "    return total, correct, by_family_index, by_major_family, by_step_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15d4c1f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_accuracy_by_family_index(results_dict, xlabel='Question Family Index', figsize=(20, 6)):\n",
    "    \"\"\"Bar-plot per-key accuracy (correct / total) of a performance breakdown.\n",
    "\n",
    "    ``results_dict`` maps a key to ``{'total': int, 'correct': int}``; keys are\n",
    "    plotted in sorted order. ``xlabel`` names the breakdown on the x-axis and\n",
    "    in the title, so the same function can label major-family or\n",
    "    reasoning-step breakdowns correctly.\n",
    "    \"\"\"\n",
    "    family_indices = sorted(results_dict.keys())\n",
    "    accuracies = [results_dict[idx]['correct'] / results_dict[idx]['total'] for idx in family_indices]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "    sns.barplot(x=family_indices, y=accuracies, ax=ax)\n",
    "    ax.set_xlabel(xlabel)\n",
    "    ax.set_ylabel('Accuracy')\n",
    "    ax.set_title(f'Accuracy by {xlabel}')\n",
    "    plt.show()\n",
    "\n",
    "def plot_correct_vs_total_by_family_index(results_dict, xlabel='Question Family Index', figsize=(20, 6)):\n",
    "    \"\"\"Overlay per-key correct counts on total counts for a performance breakdown.\n",
    "\n",
    "    ``results_dict`` maps a key to ``{'total': int, 'correct': int}``; keys are\n",
    "    plotted in sorted order. ``xlabel`` names the breakdown on the x-axis and\n",
    "    in the title.\n",
    "    \"\"\"\n",
    "    family_indices = sorted(results_dict.keys())\n",
    "    correct = [results_dict[idx]['correct'] for idx in family_indices]\n",
    "    total = [results_dict[idx]['total'] for idx in family_indices]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "    # Correct bars are drawn on top of the total bars, so the visible part of\n",
    "    # each total bar above the correct bar is the number of wrong answers.\n",
    "    ax.bar(family_indices, total, label='Total Questions')\n",
    "    ax.bar(family_indices, correct, label='Correct Answers')\n",
    "    ax.set_xlabel(xlabel)\n",
    "    ax.set_ylabel('Count')\n",
    "    ax.set_title(f'Correct vs Total Questions by {xlabel}')\n",
    "    ax.legend()\n",
    "    plt.show()\n",
    "\n",
    "def plot_heatmap(results_dict, xlabel='Question Family Index', figsize=(20, 6)):\n",
    "    \"\"\"Draw a one-row annotated heatmap of per-key accuracy.\n",
    "\n",
    "    ``results_dict`` maps a key to ``{'total': int, 'correct': int}``; keys are\n",
    "    shown in sorted order. ``xlabel`` names the breakdown on the x-axis and\n",
    "    in the title.\n",
    "    \"\"\"\n",
    "    family_indices = sorted(results_dict.keys())\n",
    "    accuracies = [results_dict[idx]['correct'] / results_dict[idx]['total'] for idx in family_indices]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "    sns.heatmap([accuracies], xticklabels=family_indices, yticklabels=['Accuracy'],\n",
    "                cmap='YlGnBu', annot=True, fmt='.2f', ax=ax)\n",
    "    ax.set_xlabel(xlabel)\n",
    "    ax.set_title(f'Accuracy Heatmap by {xlabel}')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e6409de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_answer_set(mapping):\n",
    "    answer_set = set()\n",
    "    count = 0\n",
    "    for k in mapping:\n",
    "        answers = mapping[k]\n",
    "        for a in answers:\n",
    "            answer_set.add(str(a).lower())\n",
    "    return answer_set\n",
    "\n",
    "def get_gt_set(scene_map):\n",
    "    gt_set = set()\n",
    "    for k in scene_map:\n",
    "        qns = scene_map[k]['questions']\n",
    "        for q in qns:\n",
    "            gt = str(q['answer']).lower()\n",
    "            gt_set.add(gt)\n",
    "    return gt_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "025191cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_dict_table(d):\n",
    "    \"\"\"Print a pretty table with one row per key: total, correct and accuracy.\n",
    "\n",
    "    ``d`` maps a key to ``{'total': int, 'correct': int}``; accuracy is shown\n",
    "    as 0 when the total is 0.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for key, stats in d.items():\n",
    "        n_total, n_correct = stats['total'], stats['correct']\n",
    "        accuracy = n_correct / n_total if n_total != 0 else 0\n",
    "        rows.append([key, n_total, n_correct, accuracy])\n",
    "    print(tabulate(rows, headers=['Key', 'Total', 'Correct', 'Accuracy'], tablefmt='pretty'))\n",
    "\n",
    "def print_pd_dict_table(d):\n",
    "    # Convert dictionary to DataFrame\n",
    "    df = pd.DataFrame.from_dict(d, orient='index')\n",
    "\n",
    "    # Calculate accuracy\n",
    "    df['accuracy'] = df['correct'] / df['total']\n",
    "    df = df.replace([np.inf, -np.inf], pd.NA)\n",
    "\n",
    "    # Sort by accuracy\n",
    "    df = df.sort_values('accuracy', ascending=False)\n",
    "\n",
    "    # Reset index and rename columns\n",
    "    df = df.reset_index().rename(columns={'index': 'Key'})\n",
    "    \n",
    "    # Print DataFrame\n",
    "    pd.set_option('display.max_rows', None)\n",
    "    # Print DataFrame\n",
    "    print(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "520de619",
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyse_eval(total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance):\n",
    "    \"\"\"Report overall accuracy and each performance breakdown.\n",
    "\n",
    "    Prints the total, correct and accuracy figures. Then, for every non-empty\n",
    "    breakdown, prints the raw dict and a sorted table and draws the accuracy\n",
    "    bar chart, the correct-vs-total chart and the accuracy heatmap.\n",
    "    Accuracy is reported as 0.0 when ``total`` is 0 instead of raising\n",
    "    ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    print(f'Total: {total}')\n",
    "    print(f'Correct: {correct}')\n",
    "    print(f'Accuracy: {correct / total if total else 0.0}')\n",
    "    print()\n",
    "\n",
    "    breakdowns = (\n",
    "        ('Question Family Index Performance:', question_family_index_performance),\n",
    "        ('Major Question Family Performance:', major_question_family_performance),\n",
    "        ('Reasoning Step Performance:', reasoning_step_performance),\n",
    "    )\n",
    "    for title, performance in breakdowns:\n",
    "        print(title)\n",
    "        if not performance:\n",
    "            # print_pd_dict_table needs at least one row (it reads the\n",
    "            # 'correct'/'total' columns), so skip empty breakdowns.\n",
    "            print('(no questions evaluated)')\n",
    "            print()\n",
    "            continue\n",
    "        print(performance)\n",
    "        print()\n",
    "        print_pd_dict_table(performance)\n",
    "        plot_accuracy_by_family_index(performance)\n",
    "        plot_correct_vs_total_by_family_index(performance)\n",
    "        plot_heatmap(performance)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52d67e85",
   "metadata": {},
   "source": [
    "# Clevr\n",
    "\n",
    "CLEVR questions fall into 5 task categories: Exist, Count, Compare Integer, Query Attribute, and Compare Attribute.\n",
    "\n",
    "There are 90 question families in total, each corresponding to a different chain of logical operations, and each family falls under one or more of these broader categories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99ed2794",
   "metadata": {},
   "outputs": [],
   "source": [
    "CLEVR_SCENE_MAPPING_PATH = './clevr_val_scene_mapping.json'\n",
    "clevr_scene_mapping = load_json_from_path(CLEVR_SCENE_MAPPING_PATH)\n",
    "# load_json_from_path returns None on failure; stop here rather than with a TypeError later.\n",
    "assert clevr_scene_mapping is not None, f'Could not load {CLEVR_SCENE_MAPPING_PATH}'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b1afe53",
   "metadata": {},
   "source": [
    "## Flant5_xl_non_Cot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b76e096",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9334ce1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check; only meaningful after both CoT CLEVR response maps below are loaded:\n",
    "# flan_t5_xl_cot_clevr == flan_t5_xxl_cot_clevr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d68519",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the ground-truth question records for scene '0'.\n",
    "clevr_scene_mapping['0']['questions']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4482a8c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct lower-cased ground-truth answers; useful to eyeball against the prediction sets below.\n",
    "clevr_gt_set = get_gt_set(clevr_scene_mapping)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83748e0d",
   "metadata": {},
   "source": [
    "## flan_t5_xl_non_cot_clevr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38deb85a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Results location; set the CLEVR_RESULTS_DIR environment variable to override.\n",
    "CLEVR_RESULTS_DIR = os.environ.get('CLEVR_RESULTS_DIR', '/Users/aishiknagar/Desktop/vqa_experiments/clevr_results')\n",
    "flan_t5_xl_non_cot_clevr = load_json_from_path(os.path.join(CLEVR_RESULTS_DIR, 'flant5_xl_non_cot_clevr', 'response_mapping_flan_t5_xl_final.json'))\n",
    "assert flan_t5_xl_non_cot_clevr is not None, 'failed to load flan_t5_xl_non_cot_clevr responses'\n",
    "answer_set = get_answer_set(flan_t5_xl_non_cot_clevr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de86fdb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct lower-cased predictions, to spot answers outside the CLEVR answer vocabulary.\n",
    "answer_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22be0357",
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    total,\n",
    "    correct,\n",
    "    question_family_index_performance,\n",
    "    major_question_family_performance,\n",
    "    reasoning_step_performance,\n",
    ") = evaluate_clevr_flant5_non_cot(flan_t5_xl_non_cot_clevr, clevr_scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "068bb8b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(\n",
    "    total=total,\n",
    "    correct=correct,\n",
    "    question_family_index_performance=question_family_index_performance,\n",
    "    major_question_family_performance=major_question_family_performance,\n",
    "    reasoning_step_performance=reasoning_step_performance,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ffc1985",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f0c47c52",
   "metadata": {},
   "source": [
    "## flan_t5_xl_cot_clevr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4562fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_clevr_cot_output(response_map):\n",
    "    one_word_map = {}\n",
    "    \n",
    "    for k in response_map:\n",
    "        one_word_map[k] = []\n",
    "        \n",
    "        for a in response_map[k]:\n",
    "            words = a.split()\n",
    "            last_word = words[-1]\n",
    "#             Remove the full stop (.) if it exists\n",
    "            last_word = last_word.rstrip(\".\")\n",
    "            last_word = last_word.rstrip(\"?\")\n",
    "            if ':' in last_word:\n",
    "                last_word = last_word.split(':')[0]\n",
    "            if '?' in last_word:\n",
    "                last_word = last_word.split('?')[-1]\n",
    "            last_word = last_word.replace('[','')\n",
    "            last_word = last_word.replace(']','')\n",
    "            last_word = last_word.replace('\"','')\n",
    "            last_word = last_word.replace(\"'\",'')\n",
    "            last_word = last_word.replace(\",\",'')\n",
    "            last_word = last_word.replace(\")\",'')\n",
    "            last_word = last_word.replace(\"(\",'')\n",
    "            last_word = last_word.replace(\";\",'')\n",
    "            one_word_map[k].append(last_word.lower())\n",
    "    \n",
    "    return one_word_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d32e8d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Results location; set the CLEVR_RESULTS_DIR environment variable to override.\n",
    "CLEVR_RESULTS_DIR = os.environ.get('CLEVR_RESULTS_DIR', '/Users/aishiknagar/Desktop/vqa_experiments/clevr_results')\n",
    "flan_t5_xl_cot_clevr = load_json_from_path(os.path.join(CLEVR_RESULTS_DIR, 'flant5_xl_cot_clevr', 'response_mapping_flan_t5_xl_cot.json'))\n",
    "assert flan_t5_xl_cot_clevr is not None, 'failed to load flan_t5_xl_cot_clevr responses'\n",
    "processed_cot_output = process_clevr_cot_output(flan_t5_xl_cot_clevr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88eb82af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct final-answer words extracted from the XL CoT responses.\n",
    "answer_set = get_answer_set(processed_cot_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a29c0be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the extracted answers for leftover punctuation or non-answers.\n",
    "answer_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2a6c502",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CoT output is already reduced to one word per response, so the non-CoT scorer applies.\n",
    "(\n",
    "    total,\n",
    "    correct,\n",
    "    question_family_index_performance,\n",
    "    major_question_family_performance,\n",
    "    reasoning_step_performance,\n",
    ") = evaluate_clevr_flant5_non_cot(processed_cot_output, clevr_scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baa0c5b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(\n",
    "    total=total,\n",
    "    correct=correct,\n",
    "    question_family_index_performance=question_family_index_performance,\n",
    "    major_question_family_performance=major_question_family_performance,\n",
    "    reasoning_step_performance=reasoning_step_performance,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fda383c",
   "metadata": {},
   "source": [
    "## flan_t5_xxl_cot_clevr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ad6fffb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Results location; set the CLEVR_RESULTS_DIR environment variable to override.\n",
    "CLEVR_RESULTS_DIR = os.environ.get('CLEVR_RESULTS_DIR', '/Users/aishiknagar/Desktop/vqa_experiments/clevr_results')\n",
    "# NOTE(review): this folder is spelled 'flant_t5_...' unlike the other result folders; confirm it matches disk.\n",
    "flan_t5_xxl_cot_clevr = load_json_from_path(os.path.join(CLEVR_RESULTS_DIR, 'flant_t5_xxl_cot_clevr', 'response_mapping_flan_t5_xxl_cot.json'))\n",
    "assert flan_t5_xxl_cot_clevr is not None, 'failed to load flan_t5_xxl_cot_clevr responses'\n",
    "processed_cot_output = process_clevr_cot_output(flan_t5_xxl_cot_clevr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d839f612",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct final-answer words extracted from the XXL CoT responses.\n",
    "answer_set = get_answer_set(processed_cot_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a5efca5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the extracted answers for leftover punctuation or non-answers.\n",
    "answer_set "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95884344",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CoT output is already reduced to one word per response, so the non-CoT scorer applies.\n",
    "(\n",
    "    total,\n",
    "    correct,\n",
    "    question_family_index_performance,\n",
    "    major_question_family_performance,\n",
    "    reasoning_step_performance,\n",
    ") = evaluate_clevr_flant5_non_cot(processed_cot_output, clevr_scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54fc3675",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(\n",
    "    total=total,\n",
    "    correct=correct,\n",
    "    question_family_index_performance=question_family_index_performance,\n",
    "    major_question_family_performance=major_question_family_performance,\n",
    "    reasoning_step_performance=reasoning_step_performance,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56faf455",
   "metadata": {},
   "source": [
    "# PTR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6015abde",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the PTR validation scene mapping (ground-truth questions per scene).\n",
    "ptr_scene_mapping = load_json_from_path('./ptr_val_scene_mapping.json')\n",
    "# load_json_from_path returns None on failure; stop here rather than with a TypeError later.\n",
    "assert ptr_scene_mapping is not None, 'failed to load ./ptr_val_scene_mapping.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d443b0b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct lower-cased ground-truth answers (booleans appear as 'true'/'false').\n",
    "ptr_gt_set = get_gt_set(ptr_scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d12694ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the ground-truth question records for one PTR scene.\n",
    "ptr_scene_mapping['PTR_val_007239']['questions']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2a24b4a",
   "metadata": {},
   "source": [
    "## flan_t5_xl_non_cot_ptr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e46492",
   "metadata": {},
   "outputs": [],
   "source": [
    "flan_t5_xl_non_cot_ptr = load_json_from_path('./ptr_response maps/flant5_xl_ptr_non_cot/response_mapping_flan_t5_xl.json')\n",
    "# load_json_from_path returns None on failure; fail here rather than inside get_answer_set.\n",
    "assert flan_t5_xl_non_cot_ptr is not None, 'failed to load flan_t5_xl_non_cot_ptr responses'\n",
    "ans_set = get_answer_set(flan_t5_xl_non_cot_ptr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16dc1244",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct lower-cased predictions, to spot answers outside the PTR answer vocabulary.\n",
    "ans_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe369bba",
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    total,\n",
    "    correct,\n",
    "    question_family_index_performance,\n",
    "    major_question_family_performance,\n",
    "    reasoning_step_performance,\n",
    ") = evaluate_ptr_flant5_non_cot(flan_t5_xl_non_cot_ptr, ptr_scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73045e54",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(\n",
    "    total=total,\n",
    "    correct=correct,\n",
    "    question_family_index_performance=question_family_index_performance,\n",
    "    major_question_family_performance=major_question_family_performance,\n",
    "    reasoning_step_performance=reasoning_step_performance,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61bc55e0",
   "metadata": {},
   "source": [
    "## flan_t5_xxl_non_cot_ptr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "284ccfc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "flan_t5_xxl_non_cot_ptr = load_json_from_path('./ptr_response maps/flant5_xxl_ptr_non_cot/response_mapping_flan_t5_xxl.json')\n",
    "# load_json_from_path returns None on failure; fail here rather than in a later cell.\n",
    "assert flan_t5_xxl_non_cot_ptr is not None, 'failed to load flan_t5_xxl_non_cot_ptr responses'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1422991",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to display the raw XXL non-CoT PTR response map.\n",
    "# flan_t5_xxl_non_cot_ptr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa15abf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct lower-cased predictions of the XXL non-CoT model.\n",
    "ans_set = get_answer_set(flan_t5_xxl_non_cot_ptr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c5772f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the predictions for answers outside the PTR answer vocabulary.\n",
    "ans_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6598a4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    total,\n",
    "    correct,\n",
    "    question_family_index_performance,\n",
    "    major_question_family_performance,\n",
    "    reasoning_step_performance,\n",
    ") = evaluate_ptr_flant5_non_cot(flan_t5_xxl_non_cot_ptr, ptr_scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eae56431",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(\n",
    "    total=total,\n",
    "    correct=correct,\n",
    "    question_family_index_performance=question_family_index_performance,\n",
    "    major_question_family_performance=major_question_family_performance,\n",
    "    reasoning_step_performance=reasoning_step_performance,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42af1c5c",
   "metadata": {},
   "source": [
    "## flan_t5_xl_cot_ptr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19c019eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "flan_t5_xl_cot_ptr = load_json_from_path('./ptr_response maps/flant5_xl_cot_ptr/response_mapping_flan_t5_xl_cot.json')\n",
    "# load_json_from_path returns None on failure; fail here rather than inside process_ptr_cot_output.\n",
    "assert flan_t5_xl_cot_ptr is not None, 'failed to load flan_t5_xl_cot_ptr responses'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3988272a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce each CoT response to its final answer word (one entry per question, in order).\n",
    "processed_cot_output = process_ptr_cot_output(flan_t5_xl_cot_ptr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2e19277",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct final-answer words extracted from the XL CoT responses.\n",
    "ans_set = get_answer_set(processed_cot_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e383d39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the extracted answers for leftover punctuation or non-answers.\n",
    "ans_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0941f44d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CoT output is already reduced to one word per response, so the non-CoT scorer applies.\n",
    "(\n",
    "    total,\n",
    "    correct,\n",
    "    question_family_index_performance,\n",
    "    major_question_family_performance,\n",
    "    reasoning_step_performance,\n",
    ") = evaluate_ptr_flant5_non_cot(processed_cot_output, ptr_scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dadd5fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(\n",
    "    total=total,\n",
    "    correct=correct,\n",
    "    question_family_index_performance=question_family_index_performance,\n",
    "    major_question_family_performance=major_question_family_performance,\n",
    "    reasoning_step_performance=reasoning_step_performance,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1a3c9ee",
   "metadata": {},
   "source": [
    "## flan_t5_xxl_cot_ptr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52b80e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "flan_t5_xxl_cot_ptr = load_json_from_path('./ptr_response maps/flant5_xxl_ptr_cot/response_mapping_flan_t5_xxl_cot.json')\n",
    "# load_json_from_path returns None on failure; fail here rather than inside process_ptr_cot_output.\n",
    "assert flan_t5_xxl_cot_ptr is not None, 'failed to load flan_t5_xxl_cot_ptr responses'\n",
    "processed_cot_output = process_ptr_cot_output(flan_t5_xxl_cot_ptr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1682272",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct final-answer words extracted from the XXL CoT responses.\n",
    "ans_set = get_answer_set(processed_cot_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46dc3895",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the extracted answers for leftover punctuation or non-answers.\n",
    "ans_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15b6c47a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CoT output is already reduced to one word per response, so the non-CoT scorer applies.\n",
    "(\n",
    "    total,\n",
    "    correct,\n",
    "    question_family_index_performance,\n",
    "    major_question_family_performance,\n",
    "    reasoning_step_performance,\n",
    ") = evaluate_ptr_flant5_non_cot(processed_cot_output, ptr_scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4ef08ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(\n",
    "    total=total,\n",
    "    correct=correct,\n",
    "    question_family_index_performance=question_family_index_performance,\n",
    "    major_question_family_performance=major_question_family_performance,\n",
    "    reasoning_step_performance=reasoning_step_performance,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcc2787a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12884f7d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
