{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bce6b16b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc17804f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tabulate import tabulate\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a085f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2286e4b",
   "metadata": {},
   "source": [
    "# Utility Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00b5bf4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_json_from_path(file_path):\n",
    "    \"\"\"Load and return the JSON document stored at file_path.\n",
    "\n",
    "    Prints a short diagnostic and returns None when the file is missing or\n",
    "    does not contain valid JSON; other errors (e.g. permissions) propagate.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        with open(file_path, 'r') as handle:\n",
    "            return json.load(handle)\n",
    "    except FileNotFoundError:\n",
    "        print(f\"File not found at: {file_path}\")\n",
    "    except json.JSONDecodeError:\n",
    "        print(f\"Invalid JSON format in file: {file_path}\")\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1b6e4e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_ptr_cot_output(response_map):\n",
    "    \"\"\"Reduce each PTR chain-of-thought response to its final answer word.\n",
    "\n",
    "    Args:\n",
    "        response_map: dict mapping a key to a list of raw response strings.\n",
    "\n",
    "    Returns:\n",
    "        dict with the same keys, each mapped to a list of lower-cased answer\n",
    "        words (one per response, in order). Empty or whitespace-only\n",
    "        responses yield ''.\n",
    "    \"\"\"\n",
    "    # Bracket, quote and separator characters deleted from the answer word.\n",
    "    strip_table = str.maketrans('', '', '[]\"\\'();:')\n",
    "    one_word_map = {}\n",
    "\n",
    "    for k, responses in response_map.items():\n",
    "        one_word_map[k] = []\n",
    "\n",
    "        for a in responses:\n",
    "            words = a.split()\n",
    "            if not words:\n",
    "                # No tokens to take the last word from (words[-1] would raise).\n",
    "                one_word_map[k].append('')\n",
    "                continue\n",
    "            # Drop trailing sentence punctuation: full stop, then question mark.\n",
    "            last_word = words[-1].rstrip(\".\").rstrip(\"?\")\n",
    "            # Keep only the part before a ':' or ',' (e.g. \"red,\" -> \"red\").\n",
    "            last_word = last_word.split(':')[0].split(',')[0]\n",
    "            last_word = last_word.translate(strip_table)\n",
    "            # Punctuation that sat inside quotes/brackets (e.g. '\"yes.\"') only\n",
    "            # becomes trailing once those characters are deleted.\n",
    "            last_word = last_word.rstrip(\".?\")\n",
    "            one_word_map[k].append(last_word.lower())\n",
    "\n",
    "    return one_word_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "047e539f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_clevr_cot_output(response_map):\n",
    "    \"\"\"Reduce each CLEVR chain-of-thought response to its last word.\n",
    "\n",
    "    No punctuation is stripped: the last whitespace-separated token is only\n",
    "    lower-cased.\n",
    "\n",
    "    Args:\n",
    "        response_map: dict mapping a key to a list of raw response strings.\n",
    "\n",
    "    Returns:\n",
    "        dict with the same keys, each mapped to a list of lower-cased last\n",
    "        words (one per response, in order). Empty or whitespace-only\n",
    "        responses yield ''.\n",
    "    \"\"\"\n",
    "    one_word_map = {}\n",
    "\n",
    "    for k, responses in response_map.items():\n",
    "        # str.split() drops surrounding whitespace; no tokens means no answer.\n",
    "        one_word_map[k] = [\n",
    "            words[-1].lower() if words else ''\n",
    "            for words in (a.split() for a in responses)\n",
    "        ]\n",
    "\n",
    "    return one_word_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86e9351e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this cell re-defines process_ptr_cot_output from an earlier\n",
    "# cell; whichever cell runs last wins, so keep the two copies identical.\n",
    "def process_ptr_cot_output(response_map):\n",
    "    \"\"\"Reduce each PTR chain-of-thought response to its final answer word.\n",
    "\n",
    "    Args:\n",
    "        response_map: dict mapping a key to a list of raw response strings.\n",
    "\n",
    "    Returns:\n",
    "        dict with the same keys, each mapped to a list of lower-cased answer\n",
    "        words (one per response, in order). Empty or whitespace-only\n",
    "        responses yield ''.\n",
    "    \"\"\"\n",
    "    # Bracket, quote and separator characters deleted from the answer word.\n",
    "    strip_table = str.maketrans('', '', '[]\"\\'();:')\n",
    "    one_word_map = {}\n",
    "\n",
    "    for k, responses in response_map.items():\n",
    "        one_word_map[k] = []\n",
    "\n",
    "        for a in responses:\n",
    "            words = a.split()\n",
    "            if not words:\n",
    "                # No tokens to take the last word from (words[-1] would raise).\n",
    "                one_word_map[k].append('')\n",
    "                continue\n",
    "            # Drop trailing sentence punctuation: full stop, then question mark.\n",
    "            last_word = words[-1].rstrip(\".\").rstrip(\"?\")\n",
    "            # Keep only the part before a ':' or ',' (e.g. \"red,\" -> \"red\").\n",
    "            last_word = last_word.split(':')[0].split(',')[0]\n",
    "            last_word = last_word.translate(strip_table)\n",
    "            # Punctuation that sat inside quotes/brackets (e.g. '\"yes.\"') only\n",
    "            # becomes trailing once those characters are deleted.\n",
    "            last_word = last_word.rstrip(\".?\")\n",
    "            one_word_map[k].append(last_word.lower())\n",
    "\n",
    "    return one_word_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aecf115b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_clevr_flant5_non_cot(response_map, scene_mapping):\n",
    "    \"\"\"Score FLAN-T5 (non-CoT) predictions against CLEVR ground truth.\n",
    "\n",
    "    Args:\n",
    "        response_map: dict mapping a scene key to a list of predictions,\n",
    "            aligned by position with scene_mapping[key]['questions'].\n",
    "        scene_mapping: dict mapping a scene key to a dict with a 'questions'\n",
    "            list; each question has 'question_family_index' and 'answer'.\n",
    "\n",
    "    Returns:\n",
    "        (total, correct, question_family_index_performance), the last mapping\n",
    "        each question family index to {'total': int, 'correct': int}.\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    correct = 0\n",
    "    question_family_index_performance = {}\n",
    "\n",
    "    for k in tqdm(response_map):\n",
    "        predictions = response_map[k]\n",
    "        scene_qns = scene_mapping[k]['questions']\n",
    "\n",
    "        for i, pred in enumerate(predictions):\n",
    "            qn = scene_qns[i]\n",
    "            qn_family = qn['question_family_index']\n",
    "            # Compare as lower-cased strings, like the gpt evaluators (which also\n",
    "            # wrap the answer in str()), so non-string answers can still match.\n",
    "            ans = str(qn['answer']).lower()\n",
    "\n",
    "            family_stats = question_family_index_performance.setdefault(\n",
    "                qn_family, {'total': 0, 'correct': 0})\n",
    "            total += 1\n",
    "            family_stats['total'] += 1\n",
    "\n",
    "            if str(pred).lower() == ans:\n",
    "                correct += 1\n",
    "                family_stats['correct'] += 1\n",
    "\n",
    "    return total, correct, question_family_index_performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da96b29d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_ptr_flant5_non_cot(response_map, scene_mapping):\n",
    "    \"\"\"Score FLAN-T5 (non-CoT) predictions against PTR ground truth.\n",
    "\n",
    "    Predictions and answers are compared case-insensitively, and a predicted\n",
    "    'yes'/'no' is mapped to 'true'/'false' before the comparison.\n",
    "\n",
    "    Returns:\n",
    "        (total, correct, question_family_index_performance), the last mapping\n",
    "        each question family index to {'total': int, 'correct': int}.\n",
    "    \"\"\"\n",
    "    yes_no_to_bool = {'yes': 'true', 'no': 'false'}\n",
    "    total = correct = 0\n",
    "    per_family = {}\n",
    "\n",
    "    for scene_key in tqdm(response_map):\n",
    "        questions = scene_mapping[scene_key]['questions']\n",
    "\n",
    "        for idx, prediction in enumerate(response_map[scene_key]):\n",
    "            question = questions[idx]\n",
    "            family = question['question_family_index']\n",
    "            expected = str(question['answer']).lower()\n",
    "\n",
    "            if family not in per_family:\n",
    "                per_family[family] = {'total': 0, 'correct': 0}\n",
    "            stats = per_family[family]\n",
    "\n",
    "            total += 1\n",
    "            stats['total'] += 1\n",
    "\n",
    "            normalized = prediction.lower()\n",
    "            normalized = yes_no_to_bool.get(normalized, normalized)\n",
    "\n",
    "            if normalized == expected:\n",
    "                correct += 1\n",
    "                stats['correct'] += 1\n",
    "\n",
    "    return total, correct, per_family"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d7c2d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _accuracy_by_key(results_dict):\n",
    "    \"\"\"Return (sorted keys, accuracies) for a {key: {'total', 'correct'}} dict.\n",
    "\n",
    "    A key whose total is 0 gets accuracy 0 instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    keys = sorted(results_dict.keys())\n",
    "    accuracies = [\n",
    "        results_dict[key]['correct'] / results_dict[key]['total']\n",
    "        if results_dict[key]['total'] else 0\n",
    "        for key in keys\n",
    "    ]\n",
    "    return keys, accuracies\n",
    "\n",
    "\n",
    "def plot_accuracy_by_family_index(results_dict):\n",
    "    \"\"\"Bar-plot accuracy (correct / total) for every key of results_dict.\"\"\"\n",
    "    family_indices, accuracies = _accuracy_by_key(results_dict)\n",
    "\n",
    "    _, ax = plt.subplots(figsize=(20, 6))\n",
    "    sns.barplot(x=family_indices, y=accuracies, ax=ax)\n",
    "    ax.set(xlabel='Question Family Index', ylabel='Accuracy',\n",
    "           title='Accuracy by Question Family Index')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def plot_correct_vs_total_by_family_index(results_dict):\n",
    "    \"\"\"Overlay correct-answer bars on total-question bars for every key.\"\"\"\n",
    "    family_indices = sorted(results_dict.keys())\n",
    "    correct = [results_dict[idx]['correct'] for idx in family_indices]\n",
    "    total = [results_dict[idx]['total'] for idx in family_indices]\n",
    "\n",
    "    _, ax = plt.subplots(figsize=(20, 6))\n",
    "    # Drawn second, so the 'correct' bars sit on top of the 'total' bars.\n",
    "    ax.bar(family_indices, total, label='Total Questions')\n",
    "    ax.bar(family_indices, correct, label='Correct Answers')\n",
    "    ax.set(xlabel='Question Family Index', ylabel='Count',\n",
    "           title='Correct vs Total Questions by Question Family Index')\n",
    "    ax.legend()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def plot_heatmap(results_dict):\n",
    "    \"\"\"Draw a one-row annotated heatmap of accuracy per key.\"\"\"\n",
    "    family_indices, accuracies = _accuracy_by_key(results_dict)\n",
    "\n",
    "    _, ax = plt.subplots(figsize=(20, 6))\n",
    "    sns.heatmap([accuracies], xticklabels=family_indices, yticklabels=[\"Accuracy\"],\n",
    "                cmap=\"YlGnBu\", annot=True, fmt=\".2f\", ax=ax)\n",
    "    ax.set(xlabel='Question Family Index',\n",
    "           title='Accuracy Heatmap by Question Family Index')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78638536",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_answer_set(mapping):\n",
    "    \"\"\"Return the set of distinct lower-cased answers in mapping.\n",
    "\n",
    "    mapping maps a key to an iterable of answers; non-string answers are\n",
    "    converted with str() before lower-casing.\n",
    "    \"\"\"\n",
    "    return {str(a).lower() for answers in mapping.values() for a in answers}\n",
    "\n",
    "\n",
    "def get_gt_set(scene_map):\n",
    "    \"\"\"Return the set of distinct lower-cased ground-truth answers.\n",
    "\n",
    "    scene_map maps a scene key to a dict whose 'questions' list holds dicts\n",
    "    with an 'answer' field.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        str(q['answer']).lower()\n",
    "        for scene in scene_map.values()\n",
    "        for q in scene['questions']\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b020e2ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_dict_table(d):\n",
    "    \"\"\"Print a {key: {'total', 'correct'}} dict as a table with accuracy.\n",
    "\n",
    "    Accuracy is reported as 0 for keys whose total is 0.\n",
    "    \"\"\"\n",
    "    header = ['Key', 'Total', 'Correct', 'Accuracy']\n",
    "    rows = [\n",
    "        [k, v['total'], v['correct'], v['correct'] / v['total'] if v['total'] != 0 else 0]\n",
    "        for k, v in d.items()\n",
    "    ]\n",
    "    print(tabulate(rows, headers=header, tablefmt='pretty'))\n",
    "\n",
    "\n",
    "def print_pd_dict_table(d):\n",
    "    \"\"\"Print a {key: {'total', 'correct'}} dict as a DataFrame sorted by accuracy.\n",
    "\n",
    "    Columns are Key, the inner-dict fields and accuracy. Rows are sorted by\n",
    "    accuracy, highest first; keys with a zero total (accuracy NaN) come last.\n",
    "    All rows are printed.\n",
    "    \"\"\"\n",
    "    if not d:\n",
    "        # from_dict({}) has no 'correct'/'total' columns to divide.\n",
    "        print(pd.DataFrame(columns=['Key', 'total', 'correct', 'accuracy']))\n",
    "        return\n",
    "\n",
    "    df = pd.DataFrame.from_dict(d, orient='index')\n",
    "\n",
    "    # x/0 gives inf and 0/0 gives NaN; map inf to NaN as well so the column\n",
    "    # stays float (np.nan rather than pd.NA) and such rows sort last.\n",
    "    df['accuracy'] = (df['correct'] / df['total']).replace([np.inf, -np.inf], np.nan)\n",
    "\n",
    "    df = (\n",
    "        df.sort_values('accuracy', ascending=False)\n",
    "        .reset_index()\n",
    "        .rename(columns={'index': 'Key'})\n",
    "    )\n",
    "\n",
    "    # Show every row without permanently changing the global pandas options.\n",
    "    with pd.option_context('display.max_rows', None):\n",
    "        print(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40f79da6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _report_breakdown(title, performance):\n",
    "    \"\"\"Print, tabulate and plot one {key: {'total', 'correct'}} breakdown.\"\"\"\n",
    "    print(title)\n",
    "    print(performance)\n",
    "    print()\n",
    "    print_pd_dict_table(performance)\n",
    "    plot_accuracy_by_family_index(performance)\n",
    "    plot_correct_vs_total_by_family_index(performance)\n",
    "    plot_heatmap(performance)\n",
    "\n",
    "\n",
    "def analyse_eval(total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance):\n",
    "    \"\"\"Print overall counts and accuracy, then report each performance breakdown.\n",
    "\n",
    "    Args:\n",
    "        total: number of evaluated questions.\n",
    "        correct: number answered correctly.\n",
    "        question_family_index_performance, major_question_family_performance,\n",
    "        reasoning_step_performance: dicts mapping a key to\n",
    "            {'total': int, 'correct': int}, e.g. as returned by evaluate_clevr_gpt.\n",
    "    \"\"\"\n",
    "    print(total)\n",
    "    print()\n",
    "    print(correct)\n",
    "    print()\n",
    "    # Guard against ZeroDivisionError when nothing was evaluated.\n",
    "    print(correct / total if total else 0)\n",
    "    print()\n",
    "\n",
    "    sections = (\n",
    "        ('Question Family Index Performance:', question_family_index_performance),\n",
    "        ('Major Question Family Performance:', major_question_family_performance),\n",
    "        ('Reasoning Step Performance:', reasoning_step_performance),\n",
    "    )\n",
    "    for title, performance in sections:\n",
    "        _report_breakdown(title, performance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd5d211f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _last_token_answer(line, strip_table):\n",
    "    '''\n",
    "    Return the last whitespace-separated token of line, with the characters in\n",
    "    strip_table deleted and everything up to its last '=' dropped.\n",
    "    Returns None when line has no tokens (blank or whitespace-only).\n",
    "    '''\n",
    "    tokens = line.split()\n",
    "    if not tokens:\n",
    "        return None\n",
    "    return tokens[-1].translate(strip_table).split('=')[-1]\n",
    "\n",
    "\n",
    "def get_processed_answers_updated(response_map):\n",
    "    '''\n",
    "    Extract one short answer per line from the 'non_cot' and 'cot' GPT responses.\n",
    "\n",
    "    non_cot: the last 10 lines of the message content are used.\n",
    "    cot: the content is sliced from the last occurrence of 'final answer'\n",
    "    (case-insensitive; this also covers 'final answers') and the up to 10\n",
    "    lines after the line containing that marker are used. Responses without\n",
    "    the marker contribute no cot answers and count as not following the format.\n",
    "\n",
    "    Each line becomes its lower-cased last token (see _last_token_answer);\n",
    "    blank lines become '' and their key is reported as a format issue.\n",
    "\n",
    "    Returns:\n",
    "        (cot_answers, non_cot_answers, cot_not_present, non_cot_not_present):\n",
    "        two dicts mapping each key to its answer list, and the sets of keys\n",
    "        with no 'cot' / 'non_cot' entry.\n",
    "    '''\n",
    "    cot_answers = {}\n",
    "    non_cot_answers = {}\n",
    "\n",
    "    followed_format = 0\n",
    "\n",
    "    non_cot_not_present = set()\n",
    "    cot_not_present = set()\n",
    "\n",
    "    process_exceptions_cot = set()\n",
    "    process_exceptions_non_cot = set()\n",
    "\n",
    "    unwanted_chars = \"\\\"'()-[]{}<>,`.’‘?:;\"\n",
    "    # Built once rather than once per line.\n",
    "    strip_table = str.maketrans('', '', unwanted_chars)\n",
    "\n",
    "    for k in response_map:\n",
    "        cot_answers[k] = []\n",
    "        non_cot_answers[k] = []\n",
    "\n",
    "        if 'non_cot' not in response_map[k]:\n",
    "            non_cot_not_present.add(k)\n",
    "        else:\n",
    "            content = response_map[k]['non_cot'][0]['message']['content']\n",
    "            for line in content.split('\\n')[-10:]:\n",
    "                ans = _last_token_answer(line, strip_table)\n",
    "                if ans is None:\n",
    "                    # Keep a placeholder so answers stay aligned with lines.\n",
    "                    process_exceptions_non_cot.add(k)\n",
    "                    ans = ''\n",
    "                non_cot_answers[k].append(ans.lower())\n",
    "\n",
    "        if 'cot' not in response_map[k]:\n",
    "            cot_not_present.add(k)\n",
    "        else:\n",
    "            cot_ans = response_map[k]['cot'][0]['message']['content']\n",
    "            # 'final answers' starts with 'final answer', so one rfind finds the\n",
    "            # last occurrence of either spelling.\n",
    "            final_ans_index = cot_ans.lower().rfind('final answer')\n",
    "\n",
    "            if final_ans_index != -1:\n",
    "                # [1:11] skips the marker line itself and keeps the next 10 lines.\n",
    "                final_cot_ans = cot_ans[final_ans_index:].split('\\n')[1:11]\n",
    "                followed_format += 1\n",
    "\n",
    "                for line in final_cot_ans:\n",
    "                    # Full stops are removed before tokenising ('1. red' -> '1 red').\n",
    "                    ans = _last_token_answer(line.replace('.', ''), strip_table)\n",
    "                    if ans is None:\n",
    "                        process_exceptions_cot.add(k)\n",
    "                        ans = ''\n",
    "                    cot_answers[k].append(ans.lower())\n",
    "\n",
    "    did_not_follow_format = len(response_map.keys()) - followed_format - len(cot_not_present)\n",
    "\n",
    "    print('followed format: ', followed_format)\n",
    "    print('did not follow format: ', did_not_follow_format)\n",
    "\n",
    "    print('non cot format issues:', process_exceptions_non_cot)\n",
    "    print()\n",
    "    print('cot format issues:', process_exceptions_cot)\n",
    "\n",
    "    return cot_answers, non_cot_answers, cot_not_present, non_cot_not_present\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37de091e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_qn_family_clevr(reasoning_step):\n",
    "    \"\"\"Map a CLEVR program step (callers pass the last one) to a coarse family.\n",
    "\n",
    "    Only reasoning_step['function'] is inspected. Returns one of 'exist',\n",
    "    'query attribute', 'compare numbers', 'compare attribute', 'count' or\n",
    "    'unknown'.\n",
    "    \"\"\"\n",
    "    function = reasoning_step['function']\n",
    "    if 'exist' in function:\n",
    "        return 'exist'\n",
    "    if 'query' in function:\n",
    "        return 'query attribute'\n",
    "    # Integer comparisons (equal_integer, less_than, greater_than) must be\n",
    "    # tested before the generic 'equal' substring check, which would otherwise\n",
    "    # classify equal_integer as an attribute comparison.\n",
    "    if 'than' in function or function == 'equal_integer':\n",
    "        return 'compare numbers'\n",
    "    if 'equal' in function:\n",
    "        return 'compare attribute'\n",
    "    if 'count' in function:\n",
    "        return 'count'\n",
    "    return 'unknown'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a7d0236",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_qn_family_ptr(reasoning_step):\n",
    "    \"\"\"Map a PTR program step (callers pass the last one) to a coarse family.\n",
    "\n",
    "    Looks up reasoning_step['type'] in the groups below; any type not listed\n",
    "    is treated as a 'relation' question.\n",
    "    \"\"\"\n",
    "    family_members = {\n",
    "        'concept': ('query_object-category', 'query_part-category', 'query_part-color',\n",
    "                    'count', 'exist'),\n",
    "        'analogy': ('query_geometric-analogy-color', 'query_geometric-analogy-count',\n",
    "                    'query_positional-analogy-category', 'query_positional-analogy-count',\n",
    "                    'query_positional-analogy-exist'),\n",
    "        'arithmetic': ('equal_integer', 'greater_than', 'less_than', 'minus_less',\n",
    "                       'minus_more', 'sum'),\n",
    "        'physics': ('query_stability', 'query_unstability', 'query_change'),\n",
    "    }\n",
    "    step_type = reasoning_step['type']\n",
    "    for family, members in family_members.items():\n",
    "        if step_type in members:\n",
    "            return family\n",
    "    return 'relation'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a68f49a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _evaluate_gpt(response_map, scene_mapping, family_fn):\n",
    "    \"\"\"Shared scoring loop behind evaluate_clevr_gpt and evaluate_ptr_gpt.\n",
    "\n",
    "    family_fn maps the last program step of a question to its coarse\n",
    "    (major) family name; the number of reasoning steps is len(program).\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    correct = 0\n",
    "    question_family_index_performance = {}\n",
    "    major_question_family_performance = {}\n",
    "    reasoning_step_performance = {}\n",
    "\n",
    "    for k in tqdm(response_map):\n",
    "        predictions = response_map[k]\n",
    "        scene_qns = scene_mapping[k]['questions']\n",
    "        for i, pred in enumerate(predictions):\n",
    "            try:\n",
    "                qn = scene_qns[i]\n",
    "                qn_family = qn['question_family_index']\n",
    "                ans = str(qn['answer'])\n",
    "                major_family = family_fn(qn['program'][-1])\n",
    "                reasoning_steps = len(qn['program'])\n",
    "\n",
    "                # The stats dict of every breakdown this question belongs to.\n",
    "                buckets = (\n",
    "                    question_family_index_performance.setdefault(qn_family, {'total': 0, 'correct': 0}),\n",
    "                    major_question_family_performance.setdefault(major_family, {'total': 0, 'correct': 0}),\n",
    "                    reasoning_step_performance.setdefault(reasoning_steps, {'total': 0, 'correct': 0}),\n",
    "                )\n",
    "\n",
    "                # Totals are counted before comparing, so a prediction that\n",
    "                # cannot be compared (e.g. non-string) still counts as wrong.\n",
    "                total += 1\n",
    "                for bucket in buckets:\n",
    "                    bucket['total'] += 1\n",
    "\n",
    "                if pred.lower() == ans.lower():\n",
    "                    correct += 1\n",
    "                    for bucket in buckets:\n",
    "                        bucket['correct'] += 1\n",
    "            except Exception as e:\n",
    "                # Best-effort: report the offending scene/question and continue.\n",
    "                print(k)\n",
    "                print(i)\n",
    "                print(\"An exception occurred:\", str(e))\n",
    "                print()\n",
    "\n",
    "    return total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance\n",
    "\n",
    "\n",
    "def evaluate_clevr_gpt(response_map, scene_mapping):\n",
    "    \"\"\"Score GPT predictions against CLEVR ground truth.\n",
    "\n",
    "    Args:\n",
    "        response_map: dict mapping a scene key to a list of predicted answer\n",
    "            strings, aligned by position with scene_mapping[key]['questions'].\n",
    "        scene_mapping: dict mapping a scene key to a dict with a 'questions'\n",
    "            list; each question has 'question_family_index', 'answer' and\n",
    "            'program' (a list of steps whose last one has a 'function').\n",
    "\n",
    "    Returns:\n",
    "        (total, correct, question_family_index_performance,\n",
    "         major_question_family_performance, reasoning_step_performance);\n",
    "        each performance dict maps a key to {'total': int, 'correct': int}.\n",
    "        Major families come from get_qn_family_clevr.\n",
    "    \"\"\"\n",
    "    return _evaluate_gpt(response_map, scene_mapping, get_qn_family_clevr)\n",
    "\n",
    "\n",
    "def evaluate_ptr_gpt(response_map, scene_mapping):\n",
    "    \"\"\"Score GPT predictions against PTR ground truth.\n",
    "\n",
    "    Same inputs and return value as evaluate_clevr_gpt, except that the last\n",
    "    program step needs a 'type' and major families come from get_qn_family_ptr.\n",
    "    \"\"\"\n",
    "    return _evaluate_gpt(response_map, scene_mapping, get_qn_family_ptr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b75cb44",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ba9f1ea8",
   "metadata": {},
   "source": [
    "# InstructGPT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1bd90f6",
   "metadata": {},
   "source": [
    "# InstructGPT Standard"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdc1781d",
   "metadata": {},
   "source": [
    "# CLEVR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d79d321c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of the local experiment outputs; set the VQA_RESEARCH_DIR environment\n",
    "# variable to point elsewhere instead of editing this path.\n",
    "vqa_research_dir = os.environ.get('VQA_RESEARCH_DIR', '/home/user/Desktop/vqa_research')\n",
    "clevr_scene_path = os.path.join(vqa_research_dir, 'clevr_val_experiments', 'clevr_val_scene_mapping.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "505bdafa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_instruct_gpt_cot_output(res_mapping):\n",
    "    \"\"\"Reduce InstructGPT completions to one lower-cased answer word each.\n",
    "\n",
    "    Args:\n",
    "        res_mapping: dict mapping a scene key to a dict whose 'cot' entry is a\n",
    "            list of responses; each response is either [] (no completion) or\n",
    "            a sequence whose first element is the completion text.\n",
    "\n",
    "    Returns:\n",
    "        dict with the same keys, each mapped to a list of answers (one per\n",
    "        response, in order): the last word of the completion, lower-cased,\n",
    "        with the characters . \" > < ? ] [ ( ) removed. Empty responses and\n",
    "        empty or whitespace-only completions yield ''.\n",
    "\n",
    "    Also prints how many responses were [].\n",
    "    \"\"\"\n",
    "    strip_table = str.maketrans('', '', '.\"><?][()')\n",
    "    final_map = {}\n",
    "    count = 0\n",
    "    for k in res_mapping:\n",
    "        final_map[k] = []\n",
    "        for response in res_mapping[k]['cot']:\n",
    "            if response == []:\n",
    "                count += 1\n",
    "                final_map[k].append('')\n",
    "                continue\n",
    "            words = response[0].split()\n",
    "            if not words:\n",
    "                # '' or whitespace-only text has no last word to take.\n",
    "                final_map[k].append('')\n",
    "                continue\n",
    "            final_map[k].append(words[-1].lower().translate(strip_table))\n",
    "    print(count)\n",
    "    return final_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80435c77",
   "metadata": {},
   "outputs": [],
   "source": [
    "scene_mapping = load_json_from_path(clevr_scene_path)\n",
    "# load_json_from_path returns None on failure; stop here rather than later.\n",
    "assert scene_mapping is not None, f\"could not load scene mapping from {clevr_scene_path}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebe69391",
   "metadata": {},
   "source": [
    "# Ada"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7027015d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c994abc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_map = load_json_from_path('/home/user/Desktop/vqa_research/gpt_exps/instruct_GPT_Standard-Separate_experiments/ada_clevr/response_mapping_clevr_instruct_gpt_standard_ada.json')\n",
    "# load_json_from_path returns None on failure; stop here rather than later.\n",
    "assert res_map is not None, 'could not load the Ada response mapping'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5546f787",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to display the raw response mapping loaded above:\n",
    "# res_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e7f6307",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(res_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b36a500",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_res_map = process_instruct_gpt_cot_output(res_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f7277b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the processed answers for scene '0'.\n",
    "final_res_map['0']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a3b6dda",
   "metadata": {},
   "outputs": [],
   "source": [
    "(total, correct,\n",
    " question_family_index_performance,\n",
    " major_question_family_performance,\n",
    " reasoning_step_performance) = evaluate_clevr_gpt(final_res_map, scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8212143",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(\n",
    "    total,\n",
    "    correct,\n",
    "    question_family_index_performance,\n",
    "    major_question_family_performance,\n",
    "    reasoning_step_performance,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67ef365b",
   "metadata": {},
   "source": [
    "# Babbage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77f374aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_map = load_json_from_path('/home/user/Desktop/vqa_research/gpt_exps/instruct_GPT_Standard-Separate_experiments/babbage_clevr/response_mapping_clevr_instruct_gpt_standard_babbage.json')\n",
    "# load_json_from_path returns None on failure; stop here rather than later.\n",
    "assert res_map is not None, 'could not load the Babbage response mapping'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8509ebd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_res_map = process_instruct_gpt_cot_output(res_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acf72272",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_res_map['0']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75039620",
   "metadata": {},
   "outputs": [],
   "source": [
    "total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance = evaluate_clevr_gpt(final_res_map, scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "183b6687",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b4ef88b",
   "metadata": {},
   "source": [
    "# Curie (CLEVR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b46d48dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Curie responses: file name must match the curie_clevr directory (previously pointed at the Babbage file).\n",
    "res_map = load_json_from_path('/home/user/Desktop/vqa_research/gpt_exps/instruct_GPT_Standard-Separate_experiments/curie_clevr/response_mapping_clevr_instruct_gpt_standard_curie.json')\n",
    "# load_json_from_path returns None (after printing) on a missing/invalid file; fail loudly here.\n",
    "assert res_map is not None, 'could not load Curie CLEVR responses'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de09f158",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_res_map = process_instruct_gpt_cot_output(res_mapping=res_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed2e85ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the parsed answers for id '500'\n",
    "final_res_map['500']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58f2097d",
   "metadata": {},
   "outputs": [],
   "source": [
    "total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance = evaluate_clevr_gpt(final_res_map, scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e5b760",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "616a863d",
   "metadata": {},
   "source": [
    "# PTR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aad79972",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the PTR scene mapping (ptr_val_scene_mapping.json). NOTE: absolute local path; adjust per machine.\n",
    "ptr_scene_path = '/home/user/Desktop/vqa_research/ptr_val/ptr_val_scene_mapping.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56b28cac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): cells above call process_instruct_gpt_cot_output before this cell, so an\n",
    "# earlier definition presumably exists; this PTR definition shadows it from here on.\n",
    "\n",
    "# Characters removed from the extracted answer word.\n",
    "_ANSWER_STRIP_TABLE = str.maketrans('', '', '.\"<>?[]()')\n",
    "# yes/no answers are normalised to 'true'/'false'.\n",
    "_BOOL_WORDS = {'yes': 'true', 'no': 'false'}\n",
    "\n",
    "\n",
    "def process_instruct_gpt_cot_output(res_mapping):\n",
    "    \"\"\"Reduce every raw model response to a single normalised answer word.\n",
    "\n",
    "    Args:\n",
    "        res_mapping: dict mapping an id to a dict whose 'cot' entry is a list of\n",
    "            responses. A well-formed response is a 2-element list whose first\n",
    "            element is the response text.\n",
    "\n",
    "    Returns:\n",
    "        dict with the same keys, each mapped to a list holding one answer string per\n",
    "        response (in order): the last whitespace-separated word, lower-cased, with\n",
    "        periods, double quotes, angle brackets, question marks, square brackets and\n",
    "        parentheses removed, and 'yes'/'no' mapped to 'true'/'false'. Malformed or\n",
    "        blank responses yield '' so positions stay aligned.\n",
    "\n",
    "    Side effect: prints how many responses were malformed.\n",
    "    \"\"\"\n",
    "    final_map = {}\n",
    "    malformed = 0\n",
    "    for key, entry in res_mapping.items():\n",
    "        answers = []\n",
    "        for response in entry['cot']:\n",
    "            # [], None, bare strings and other non-pairs cannot be parsed; count them\n",
    "            # and keep an empty placeholder.\n",
    "            if not isinstance(response, (list, tuple)) or len(response) != 2:\n",
    "                malformed += 1\n",
    "                answers.append('')\n",
    "                continue\n",
    "            words = response[0].split()\n",
    "            if not words:  # empty or whitespace-only text\n",
    "                answers.append('')\n",
    "                continue\n",
    "            last_word = words[-1].lower().translate(_ANSWER_STRIP_TABLE)\n",
    "            answers.append(_BOOL_WORDS.get(last_word, last_word))\n",
    "        final_map[key] = answers\n",
    "    print(f'{malformed} malformed responses')\n",
    "    return final_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a6d52f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: overwrites the CLEVR scene_mapping used above; reload it before re-running CLEVR cells.\n",
    "scene_mapping = load_json_from_path(ptr_scene_path)\n",
    "# load_json_from_path returns None (after printing) on a missing/invalid file; fail loudly here.\n",
    "assert scene_mapping is not None, 'could not load PTR scene mapping'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e8bb9d8",
   "metadata": {},
   "source": [
    "# Ada (PTR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee46e95f",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_map = load_json_from_path('/home/user/Desktop/vqa_research/gpt_exps/instruct_GPT_Standard-Separate_experiments/ada_ptr/response_mapping_ptr_instruct_gpt_standard_ada.json')\n",
    "# load_json_from_path returns None (after printing) on a missing/invalid file; fail loudly here.\n",
    "assert res_map is not None, 'could not load Ada PTR responses'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44d73324",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce each response to its normalised final answer word\n",
    "final_res_map = process_instruct_gpt_cot_output(res_mapping=res_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a6716d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the parsed answers for one id\n",
    "final_res_map['PTR_val_007239']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f143da7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the parsed answers using the PTR scene_mapping\n",
    "total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance = evaluate_ptr_gpt(final_res_map, scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ecce61b",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9d7c662",
   "metadata": {},
   "source": [
    "# Babbage (PTR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6779ac15",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_map = load_json_from_path('/home/user/Desktop/vqa_research/gpt_exps/instruct_GPT_Standard-Separate_experiments/babbage_ptr/response_mapping_ptr_instruct_gpt_standard_babbage.json')\n",
    "# load_json_from_path returns None (after printing) on a missing/invalid file; fail loudly here.\n",
    "assert res_map is not None, 'could not load Babbage PTR responses'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf434019",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_res_map = process_instruct_gpt_cot_output(res_mapping=res_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fc51949",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the parsed answers for one id\n",
    "final_res_map['PTR_val_007239']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aef51f7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance = evaluate_ptr_gpt(final_res_map, scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57d32f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e3a3ae3",
   "metadata": {},
   "source": [
    "# Curie (PTR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c02a45d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_map = load_json_from_path('/home/user/Desktop/vqa_research/gpt_exps/instruct_GPT_Standard-Separate_experiments/curie_ptr/response_mapping_ptr_instruct_gpt_standard_curie.json')\n",
    "# load_json_from_path returns None (after printing) on a missing/invalid file; fail loudly here.\n",
    "assert res_map is not None, 'could not load Curie PTR responses'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48f943e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_res_map = process_instruct_gpt_cot_output(res_mapping=res_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5b35a82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check the parsed answers for one id\n",
    "final_res_map['PTR_val_007239']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b575a4a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance = evaluate_ptr_gpt(final_res_map, scene_mapping)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f219acf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "analyse_eval(total, correct, question_family_index_performance, major_question_family_performance, reasoning_step_performance)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
