{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import math\n",
    "import os\n",
    "\n",
    "import pathlib\n",
    "from functools import partial\n",
    "import warnings\n",
    "\n",
    "import pandas as pd\n",
    "import torch.multiprocessing as mp\n",
    "from joblib import Memory\n",
    "from num2words import num2words\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "from rich.console import Console\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from configs import config\n",
    "from utils import seed_everything\n",
    "import datasets\n",
    "import openai\n",
    "\n",
    "intermediate_variable = []\n",
    "\n",
    "def inject_saver(code, save_intermediate_steps=True):\n",
    "    injected_function_name = 'save_intermediate'\n",
    "    if injected_function_name in code:\n",
    "        return code\n",
    "    code = code.split(\"\\n\")\n",
    "    newcode = []\n",
    "    for n, codeline in enumerate(code):\n",
    "        codeline, indent = split_codeline_and_indent_level(codeline)\n",
    "\n",
    "        if codeline.startswith('#') or codeline == '':  # this will cause issues if you have lots of comment lines\n",
    "            continue\n",
    "        if '#' in codeline:\n",
    "            codeline = codeline.split('#')[0]\n",
    "\n",
    "        thing_to_show, code_type = get_thing_to_show_codetype(codeline)\n",
    "        # console.print(thing_to_show)\n",
    "\n",
    "        if code_type in ('assign', 'append', 'if', 'return', 'for', 'sort', 'add'):\n",
    "            if '\\'' in codeline:\n",
    "                codeline.replace('\\'', '\\\\\\'')\n",
    "\n",
    "            if save_intermediate_steps:\n",
    "                escape_thing = lambda x: x.replace(\"'\", \"\\\\'\")\n",
    "                injection_string_format = \\\n",
    "                    lambda \\\n",
    "                        thing: f\"{indent}{injected_function_name}(lineno={n},value=({thing}),valuename='{escape_thing(thing)}'); \" \n",
    "\n",
    "            extension_list = []\n",
    "            if isinstance(thing_to_show, list):\n",
    "                injection_string_list = [injection_string_format(f\"{thing}\") for thing in thing_to_show]\n",
    "                extension_list.extend(injection_string_list)\n",
    "            elif code_type == 'for':\n",
    "                injection_string = injection_string_format(f\"{thing_to_show}\")\n",
    "                injection_string = \" \" * 4 + injection_string\n",
    "                extension_list.append(injection_string)\n",
    "            else:\n",
    "                extension_list.append(injection_string_format(f\"{thing_to_show}\"))\n",
    "\n",
    "            if code_type in ('if', 'return'):\n",
    "                extension_list = extension_list + [f\"{indent}{codeline}\"]\n",
    "            else:\n",
    "                extension_list = [f\"{indent}{codeline}\"] + extension_list\n",
    "\n",
    "            newcode.extend(extension_list)\n",
    "\n",
    "        elif code_type == 'elif_else':\n",
    "            newcode.append(f\"{indent}{codeline}\")\n",
    "        else:\n",
    "            newcode.append(f\"{indent}{codeline}\")\n",
    "\n",
    "    return \"\\n\".join(newcode)\n",
    "\n",
    "def get_thing_to_show_codetype(codeline):\n",
    "    # can output either a list of things to show, or a single thing to show\n",
    "    things_to_show = []\n",
    "    if codeline.startswith(\"if\"):\n",
    "        condition, rest = codeline[3:].split(\":\", 1)\n",
    "        codeline = f\"if {condition}:{rest}\"\n",
    "        code_type = \"if\"\n",
    "\n",
    "        operators = ['==', '!=', '>=', '<=', '>', '<']\n",
    "        things_to_show = []\n",
    "        for op in operators:\n",
    "            if op in condition:\n",
    "                things_to_show = [x.strip() for x in condition.split(op)]\n",
    "                # print(things_to_show)\n",
    "                break\n",
    "        # things_to_show.append(thing_to_show)\n",
    "        thing_to_show = things_to_show + [condition.strip()]\n",
    "\n",
    "    elif codeline.startswith(\"for\"):\n",
    "        code_type = 'for'\n",
    "        thing_to_show = codeline.split(\"for \")[1].split(\" in \")[0]\n",
    "\n",
    "    elif codeline.startswith(\"return\"):\n",
    "        thing_to_show = codeline.split(\"return \")[1]\n",
    "        code_type = 'return'\n",
    "\n",
    "    elif ' = ' in codeline:\n",
    "        code_type = 'assign'\n",
    "        thing_to_show = codeline.split(' = ')[0]\n",
    "    elif ' += ' in codeline:\n",
    "        code_type = 'assign'\n",
    "        thing_to_show = codeline.split(' += ')[0]\n",
    "    elif ' -= ' in codeline:\n",
    "        code_type = 'assign'\n",
    "        thing_to_show = codeline.split(' -= ')[0]\n",
    "    elif ' *= ' in codeline:\n",
    "        code_type = 'assign'\n",
    "        thing_to_show = codeline.split(' *= ')[0]\n",
    "    elif ' /= ' in codeline:\n",
    "        code_type = 'assign'\n",
    "        thing_to_show = codeline.split(' /= ')[0]\n",
    "\n",
    "    elif '.append(' in codeline:\n",
    "        code_type = 'append'\n",
    "        thing_to_show = codeline.split('.append(')[0] + '[-1]'\n",
    "    elif '.add(' in codeline:\n",
    "        code_type = 'add'\n",
    "        thing_to_show = codeline.split('.add(')[0]\n",
    "\n",
    "    elif '.sort(' in codeline:\n",
    "        code_type = 'sort'\n",
    "        thing_to_show = codeline.split('.sort(')[0]\n",
    "\n",
    "    elif codeline.startswith(\"elif\") or codeline.startswith(\"else\"):\n",
    "        thing_to_show = None\n",
    "        code_type = 'elif_else'\n",
    "    else:\n",
    "        thing_to_show = None\n",
    "        code_type = 'other'\n",
    "\n",
    "    if isinstance(thing_to_show, list):\n",
    "        thing_to_show = [thing if not (thing.strip().startswith(\"'\") and thing.strip().endswith(\"'\"))\n",
    "                         else thing.replace(\"'\", '\"') for thing in thing_to_show if thing is not None]\n",
    "    elif isinstance(thing_to_show, str):\n",
    "        thing_to_show = thing_to_show if not (thing_to_show.strip().startswith(\"'\") and\n",
    "                                              thing_to_show.strip().endswith(\"'\")) else thing_to_show.replace(\"'\", '\"')\n",
    "    return thing_to_show, code_type\n",
    "\n",
    "def split_codeline_and_indent_level(codeline):\n",
    "    origlen = len(codeline)\n",
    "    codeline = codeline.lstrip()\n",
    "    indent = origlen - len(codeline)\n",
    "    indent = \" \" * indent\n",
    "    return codeline, indent\n",
    "\n",
    "def save_intermediate(lineno, value, valuename):\n",
    "    thing_to_show = value\n",
    "\n",
    "    from PIL import Image\n",
    "    if isinstance(thing_to_show, Image.Image):\n",
    "        save_intermediate_variable(valuename , thing_to_show)\n",
    "    elif str(type(thing_to_show)) == \"<class 'image_patch.ImagePatch'>\":\n",
    "        save_intermediate_variable(valuename , thing_to_show)\n",
    "    elif isinstance(thing_to_show, list) or isinstance(thing_to_show, tuple):\n",
    "        if len(thing_to_show) > 0:\n",
    "            for i, thing in enumerate(thing_to_show):\n",
    "                save_intermediate(None, thing, f\"{valuename}[{i}]\")\n",
    "            save_intermediate_variable(valuename , len(thing_to_show))\n",
    "            return\n",
    "        else:\n",
    "            save_intermediate_variable(valuename , None)\n",
    "        \n",
    "    elif isinstance(thing_to_show, dict):\n",
    "        if len(thing_to_show) > 0:\n",
    "            for i, (thing_k, thing_v) in enumerate(thing_to_show.items()):\n",
    "                save_intermediate(None, thing_v, f\"{valuename}['{thing_k}']\")\n",
    "            save_intermediate_variable(valuename , len(thing_to_show))\n",
    "            return\n",
    "        else:\n",
    "            save_intermediate_variable(valuename , None)\n",
    "    else:\n",
    "        save_intermediate_variable(valuename , thing_to_show)\n",
    "        return\n",
    "\n",
    "def save_intermediate_variable(name , value):\n",
    "    intermediate_variable.append([name,value])\n",
    "    return\n",
    "\n",
    "import torch\n",
    "def general_postprocessing(prediction):\n",
    "    try:\n",
    "        if type(prediction).__name__ == 'ImagePatch':\n",
    "            prediction = prediction.classify_object()\n",
    "\n",
    "        if isinstance(prediction, list):\n",
    "            prediction = prediction[0] if len(prediction) > 0 else \"no\"\n",
    "\n",
    "        if isinstance(prediction, torch.Tensor):\n",
    "            prediction = prediction.item()\n",
    "        if prediction is None:\n",
    "            prediction = \"no\"\n",
    "        if isinstance(prediction, bool):\n",
    "            if prediction:\n",
    "                prediction = \"yes\"\n",
    "            else:\n",
    "                prediction = \"no\"\n",
    "        elif isinstance(prediction, int):\n",
    "            prediction = str(prediction)\n",
    "            print(\"No answer is a number, so this will be wrong\")\n",
    "    except:\n",
    "        prediction = str(prediction)\n",
    "\n",
    "    prediction = str(prediction)\n",
    "\n",
    "    prediction = prediction.replace('\\n', ' ')\n",
    "    prediction = prediction.replace('\\t', ' ')\n",
    "    prediction = prediction.strip()\n",
    "    prediction = prediction.lower()\n",
    "\n",
    "    if prediction == 'true':\n",
    "        prediction = 'yes'\n",
    "    elif prediction == 'false':\n",
    "        prediction = 'no'\n",
    "    return prediction\n",
    "\n",
    "def answer_right_or_not(prediction, ground_truth):\n",
    "    if general_postprocessing(prediction) == str(ground_truth):\n",
    "        return True\n",
    "    else:\n",
    "        return False\n",
    "\n",
     "# Use torch's 'file_system' sharing strategy for tensors passed between processes.\n",
     "# See https://github.com/pytorch/pytorch/issues/11201, https://github.com/pytorch/pytorch/issues/973\n",
     "# Not for dataloader, but for multiprocessing batches\n",
     "mp.set_sharing_strategy('file_system')\n",
     "# Result queue read by run_program(); worker_init() rebinds it in each pool worker.\n",
     "queue_results = None\n",
     "\n",
     "# joblib cache under cache/ when config.use_cache is set; with a None location it presumably\n",
     "# caches nothing (joblib behaviour - confirm). main() clears it when config.clear_cache is set.\n",
     "cache = Memory('cache/' if config.use_cache else None, verbose=0)\n",
     "# NOTE(review): runs_dict is not referenced anywhere in this cell.\n",
     "runs_dict = {}\n",
     "# Seed RNGs via the project helper in utils (its defaults are not visible here).\n",
     "seed_everything()\n",
     "# Shared rich console used for all progress and error messages.\n",
     "console = Console(highlight=False)\n",
    "\n",
    "\n",
    "def my_collate(batch):\n",
    "    # Avoid stacking images (different size). Return everything as a list\n",
    "    to_return = {k: [d[k] for d in batch] for k in batch[0].keys()}\n",
    "    return to_return\n",
    "\n",
    "\n",
    "def run_program(parameters, queues_in_, input_type_, retrying=False):\n",
    "    from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno\n",
    "    from video_segment import VideoSegment\n",
    "\n",
    "    global queue_results\n",
    "\n",
    "    code, sample_id, image, possible_answers, query = parameters\n",
    "    for _ in range(1):\n",
    "        code_header = f'def execute_command_{sample_id}(' \\\n",
    "                    f'{input_type_}, possible_answers, query, ' \\\n",
    "                    f'ImagePatch, VideoSegment, ' \\\n",
    "                    'llm_query, bool_to_yesno, distance, best_image_match):\\n'\n",
    "        lines = code.split(\"\\n\")\n",
    "        lines[0] = code_header\n",
    "        code_origin = \"\\n\".join(lines)\n",
    "\n",
    "        code = inject_saver(code_origin, save_intermediate_steps=True) \n",
    "        try:\n",
    "            exec(compile(code, 'Codex', 'exec'), globals())\n",
    "        except Exception as e:\n",
    "            console.print(f'Sample {sample_id} failed at compilation time with error: {e}')\n",
    "            try:\n",
    "                with open(config.fixed_code_file, 'r') as f:\n",
    "                    fixed_code = f.read()\n",
    "                code = code_header + fixed_code \n",
    "                exec(compile(code, 'Codex', 'exec'), globals())\n",
    "            except Exception as e2:\n",
    "                console.print(f'Not even the fixed code worked. Sample {sample_id} failed at compilation time with error: {e2}')\n",
    "                return None, code\n",
    "            \n",
    "        queues = [queues_in_, queue_results]\n",
    "\n",
    "        image_patch_partial = partial(ImagePatch, queues=queues)\n",
    "        video_segment_partial = partial(VideoSegment, queues=queues)\n",
    "        llm_query_partial = partial(llm_query, queues=queues)\n",
    "\n",
    "        try:\n",
    "            result = globals()[f'execute_command_{sample_id}'](\n",
    "                # Inputs to the function\n",
    "                image, possible_answers, query,\n",
    "                # Classes to be used\n",
    "                image_patch_partial, video_segment_partial,\n",
    "                # Functions to be used\n",
    "                llm_query_partial, bool_to_yesno, distance, best_image_match)\n",
    "\n",
    "            if answer_right_or_not(result , possible_answers):\n",
    "                break\n",
    "            \n",
    "            ls = os.listdir(\"/home/viper/intermediate\")\n",
    "            for i in ls:\n",
    "                c_path = os.path.join(\"/home/viper/intermediate\", i)\n",
    "                os.remove(c_path)\n",
    "\n",
    "            import matplotlib\n",
    "            #image_save = []\n",
    "            text_save = []\n",
    "            for i ,item in enumerate(intermediate_variable):\n",
    "                if isinstance(item[1] , ImagePatch):\n",
    "                    matplotlib.pyplot.imsave('./intermediate/'+item[0]+str(i)+'.jpg'  , item[1].cropped_image.permute(1,2,0).numpy())\n",
    "                else:\n",
    "                    item = [str(j) for j in item]\n",
    "                    item = \" \".join(item)\n",
    "                    text_save.append(item)\n",
    "\n",
    "            from mPLUGOwl.pipeline.interface import get_model , do_generate\n",
    "            model_mPLUG, tokenizer_mPLUG, processor_mPLUG = get_model(pretrained_ckpt='mPLUGOwl/MAGAer13/mplug-owl-llama-7b-pt', use_bf16=True)\n",
    "            prompts_img = ['''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\\n''']\n",
    "            \n",
    "            relative_path = \"/home/viper/intermediate\"\n",
    "            image_save = os.listdir(relative_path)\n",
    "            image_save = [relative_path +\"/\"+ name for name in image_save]\n",
    "            \n",
    "            for i in range(len(image_save)):\n",
    "                prompts_img.extend([\"Human: <image>\\n\"])\n",
    "            prompts_img.extend([\"AI:\"])\n",
    "            print(\"mPLUGOwl do_generate\")\n",
    "            sentence_img = do_generate(prompts_img, image_save, model_mPLUG, tokenizer_mPLUG, processor_mPLUG,\n",
    "                                    use_bf16=True, max_length=512, top_k=5, do_sample=True)\n",
    "\n",
    "            whole_prompt = [\"Please give me a caption of this picture\"]\n",
    "            whole_img = do_generate(whole_prompt, image, model_mPLUG, tokenizer_mPLUG, processor_mPLUG,\n",
    "                                    use_bf16=True, max_length=512, top_k=5, do_sample=True)\n",
    "\n",
    "            prompts_txt = [\"The follwings are intermediate variables from a program execution process. What can you summarize?\\n\"]\n",
    "            prompts_img.extend([\"[intermediate variables]:\"])\n",
    "            prompts_img.extend(text_save)\n",
    "            sentence_txt = do_generate(prompts_txt, None, model_mPLUG, tokenizer_mPLUG, processor_mPLUG,\n",
    "                                    use_bf16=True, max_length=512, top_k=5, do_sample=True)\n",
    "\n",
    "            file_object = open(\"prompts/feedback.prompt\",'r')\n",
    "            prompt_feedback = file_object.read()\n",
    "            prompt_feedback.replace(\"INSERT_QUERY_HERE\", query)\n",
    "            prompt_feedback.replace(\"INSERT_ADDITIONAL_DESCRIPTION_HERE\",whole_img)\n",
    "            prompt_feedback.replace(\"INSERT_IMAGE_CAPTION_HERE\",sentence_img)\n",
    "            prompt_feedback.replace(\"INSERT_TEXT_INFORMATION_HERE\",sentence_txt)\n",
    "            prompt_feedback.replace(\"INSERT_CODE_HERE\",code_origin)\n",
    "            feedback = openai.ChatCompletion.create(\n",
    "            model=\"gpt-3.5-turbo\",\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": \"You are a feedback generater.\"},\n",
    "                {\"role\": \"system\", \"content\": \"I will provide you with a [Question] [Additional Description] [Image Caption] [Text Information] and [Code]\"},\n",
    "                {\"role\": \"system\", \"content\": \"[Question] is the problem to be solved, [Additional Description] is additional description of the question, [Image Caption] is image captions of a series question-based pictures, [Text Information] of the intermediate output of a piece of code and [Code](may be some errors inside) is the subject you have to provide feedback on\"},\n",
    "                {\"role\": \"user\", \"content\": prompt_feedback},\n",
    "                {\"role\": \"user\", \"content\": \"What suggestions can you provide to correct(or optimize) [Code]?\"},\n",
    "            ]\n",
    "            )\n",
    "\n",
    "            file_object = open(\"prompts/gen_code.prompt\",'r')\n",
    "            prompt_gen_code = file_object.read()\n",
    "            prompt_gen_code.replace(\"INSERT_QUERY_HERE\", query)\n",
    "            prompt_gen_code.replace(\"INSERT_FEEDBACK_HERE\", feedback.choices[0].message[\"content\"])\n",
    "            prompt_gen_code.replace(\"INSERT_CODE_HERE\",code_origin)\n",
    "            code = openai.ChatCompletion.create(\n",
    "            model=\"gpt-3.5-turbo\",\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": \"You are a code optimizer(or corrector). Only answer with a code function starting def execute_command.\"},\n",
    "                {\"role\": \"system\", \"content\": \"I will provide you with [API usage methods], [Question], [Feedback] and [Code].\"},\n",
    "                {\"role\": \"system\", \"content\": \"[API usage methods] is the usage of API , [Question] is the problem to be solved, and [Feedback] is the feedback to the possible incorrect [Code].\"},\n",
    "                {\"role\": \"user\", \"content\": prompt_gen_code},\n",
    "                {\"role\": \"user\", \"content\": \"Please optimize(or correct) the [Code]. Only answer with a code function starting def execute_command.\"},\n",
    "            ]\n",
    "            )\n",
    "\n",
    "            code = code.choices[0].message[\"content\"]\n",
    "\n",
    "        except Exception as e:\n",
    "            if retrying:\n",
    "                return None, code\n",
    "            console.print(f'Sample {sample_id} failed with error: {e}. Next you will see an \"expected an indented block\" error. ')\n",
    "            # Retry again with fixed code\n",
    "            new_code = \"[\"  # This code will break upon execution, and it will be caught by the except clause\n",
    "            result = run_program((new_code, sample_id, image, possible_answers, query), queues_in_, input_type_,\n",
    "                                retrying=True)[0]\n",
    "\n",
    "    # The function run_{sample_id} is defined globally (exec doesn't work locally). A cleaner alternative would be to\n",
    "    # save it in a global dict (replace globals() for dict_name in exec), but then it doesn't detect the imported\n",
    "    # libraries for some reason. Because defining it globally is not ideal, we just delete it after running it.\n",
    "    if f'execute_command_{sample_id}' in globals():\n",
    "        del globals()[f'execute_command_{sample_id}']  # If it failed to compile the code, it won't be defined\n",
    "    return result, code_origin\n",
    "\n",
    "\n",
    "def worker_init(queue_results_):\n",
    "    global queue_results\n",
    "    index_queue = mp.current_process()._identity[0] % len(queue_results_)\n",
    "    queue_results = queue_results_[index_queue]\n",
    "\n",
    "\n",
    "def main():\n",
    "    mp.set_start_method('spawn')\n",
    "    from vision_processes import queues_in, finish_all_consumers, forward, manager\n",
    "    from datasets.dataset import MyDataset\n",
    "    console.print(\"ALL LOADED!\")\n",
    "    batch_size = config.dataset.batch_size\n",
    "    num_processes = min(batch_size, 50)\n",
    "\n",
    "    if config.multiprocessing:\n",
    "        queue_results_main = manager.Queue()\n",
    "        queues_results = [manager.Queue() for _ in range(batch_size)]\n",
    "    else:\n",
    "        queue_results_main = None\n",
    "        queues_results = [None for _ in range(batch_size)]\n",
    "\n",
    "    codex = partial(forward, model_name='codex', queues=[queues_in, queue_results_main])\n",
    "    if config.clear_cache:\n",
    "        cache.clear()\n",
    "\n",
    "    # if config.wandb:\n",
    "    #     import wandb\n",
    "    #     wandb.init(project=\"viper\", config=OmegaConf.to_container(config))\n",
    "    #     # log the prompt file\n",
    "    #     wandb.save(config.prompt)\n",
    "\n",
    "    dataset = MyDataset(**config.dataset)\n",
    "    with open(config.codex.prompt) as f:\n",
    "        base_prompt = f.read().strip()\n",
    "\n",
    "    codes_all = None\n",
    "    if config.use_cached_codex:\n",
    "        results = pd.read_csv(config.cached_codex_path)\n",
    "        codes_all = [r.split('# Answer is:')[1] for r in results['code']]\n",
    "    # python -c \"from joblib import Memory; cache = Memory('cache/', verbose=0); cache.clear()\"\n",
    "    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True,\n",
    "                            collate_fn=my_collate)\n",
    "    input_type = dataset.input_type\n",
    "\n",
    "    all_results = []\n",
    "    all_answers = []\n",
    "    all_codes = []\n",
    "    all_ids = []\n",
    "    all_querys = []\n",
    "    all_img_paths = []\n",
    "    all_possible_answers = []\n",
    "    all_query_types = []\n",
    "\n",
    "    with mp.Pool(processes=num_processes, initializer=worker_init, initargs=(queues_results,)) \\\n",
    "            if config.multiprocessing else open(os.devnull, \"w\") as pool:\n",
    "        \n",
    "        n_batches = len(dataloader)\n",
    "\n",
    "        for i, batch in tqdm(enumerate(dataloader), total=n_batches):\n",
    "            # Combine all querys and get Codex predictions for them\n",
    "            # TODO compute Codex for next batch as current batch is being processed\n",
    "            if not config.use_cached_codex:\n",
    "                console.print(batch) #batch包含query，answer，image，index\n",
    "                # print(base_prompt) #base_prompt就是api.prompt\n",
    "                codes = codex(prompt=batch['query'], base_prompt=base_prompt)\n",
    "            else:\n",
    "                codes = codes_all[i * batch_size:(i + 1) * batch_size]  # If cache\n",
    "\n",
    "            # Run the code\n",
    "            if config.execute_code:\n",
    "                if not config.multiprocessing:\n",
    "                    # Otherwise, we would create a new model for every process\n",
    "                    results = []\n",
    "                    for c, sample_id, img, possible_answers, query in \\\n",
    "                            zip(codes, batch['index'], batch['image'], batch['answer'], batch['query']):\n",
    "                        try:\n",
    "                            result = run_program([c, sample_id , img, possible_answers, query], queues_in, input_type)\n",
    "                            results.append(result)\n",
    "                        except Exception as e:\n",
    "                            console.print(f'Exception: {e}')\n",
    "                            code_header = f'def execute_command_{sample_id}(' \\\n",
    "                                f'{input_type}, possible_answers, query, ' \\\n",
    "                                f'ImagePatch, VideoSegment, ' \\\n",
    "                                'llm_query, bool_to_yesno, distance, best_image_match):\\n'\n",
    "                            lines = c.split(\"\\n\")\n",
    "                            lines[0] = code_header\n",
    "                            wrong_code = \"\\n\".join(lines)\n",
    "                            results.append((e , wrong_code))\n",
    "\n",
    "                else:\n",
    "                    results = list(pool.imap(partial(\n",
    "                        run_program, queues_in_=queues_in, input_type_=input_type),\n",
    "                        zip(codes, batch['index'], batch['image'], batch['answer'], batch['query'])))\n",
    "            else: \n",
    "                results = [(None, c) for c in codes]\n",
    "                warnings.warn(\"Not executing code! This is only generating the code. We set the flag \"\n",
    "                                \"'execute_code' to False by default, because executing code generated by a language \"\n",
    "                                \"model can be dangerous. Set the flag 'execute_code' to True if you want to execute \"\n",
    "                                \"it.\")\n",
    "                \n",
    "\n",
    "            all_results += [r[0] for r in results]\n",
    "            all_codes += [r[1] for r in results]\n",
    "            all_ids += batch['index']\n",
    "            all_answers += batch['answer']\n",
    "            all_possible_answers += batch['answer']\n",
    "            all_query_types += batch['index']\n",
    "            all_querys += batch['query']\n",
    "            all_img_paths += [dataset.get_sample_path(idx) for idx in batch['index']]\n",
    "            if i % config.log_every == 0:\n",
    "                try:\n",
    "                    accuracy = datasets.accuracy(all_results, all_answers, all_possible_answers, all_query_types)\n",
    "                    console.print(f'Accuracy at Batch {i}/{n_batches}: {accuracy}')\n",
    "\n",
    "                except Exception as e:\n",
    "                    console.print(f'Error computing accuracy: {e}')\n",
    "            if i % 2 == 0:\n",
    "                console.print('Saving results to sub_result' , i // 2)\n",
    "                df = pd.DataFrame([all_results, all_answers, all_codes, all_ids, all_querys, all_img_paths,\n",
    "                                all_possible_answers]).T\n",
    "                df.columns = ['result', 'answer', 'code', 'id', 'query', 'img_path', 'possible_answers']\n",
    "                # make the result column a string\n",
    "                df['result'] = df['result'].apply(str)\n",
    "                df.to_csv(\"/home/viper/results/sub_results/sub_results\"+str(i // 2)+\".csv\"\n",
    "                            , header=True, index=False, encoding='utf-8')\n",
    "\n",
    "    try:\n",
    "        accuracy = datasets.accuracy(all_results, all_answers, all_possible_answers, all_query_types)\n",
    "        console.print(f'Final accuracy: {accuracy}')\n",
    "    except Exception as e:\n",
    "        console.print(f'Error computing accuracy: {e}')\n",
    "\n",
    "    if config.save:\n",
    "        results_dir = pathlib.Path(config['results_dir'])\n",
    "        results_dir = results_dir / config.dataset.split\n",
    "        results_dir.mkdir(parents=True, exist_ok=True)\n",
    "        if not config.save_new_results:\n",
    "            filename = 'results.csv'\n",
    "        else:\n",
    "            existing_files = list(results_dir.glob('results_*.csv'))\n",
    "            if len(existing_files) == 0:\n",
    "                filename = 'results_0.csv'\n",
    "            else:\n",
    "                filename = 'results_' + str(max([int(ef.stem.split('_')[-1]) for ef in existing_files if\n",
    "                                                 str.isnumeric(ef.stem.split('_')[-1])]) + 1) + '.csv'\n",
    "        console.print('Saving results to', filename)\n",
    "        df = pd.DataFrame([all_results, all_answers, all_codes, all_ids, all_querys, all_img_paths,\n",
    "                           all_possible_answers]).T\n",
    "        df.columns = ['result', 'answer', 'code', 'id', 'query', 'img_path', 'possible_answers']\n",
    "        # make the result column a string\n",
    "        df['result'] = df['result'].apply(str)\n",
    "        df.to_csv(results_dir / filename, header=True, index=False, encoding='utf-8')\n",
    "        # torch.save([all_results, all_answers, all_codes, all_ids, all_querys, all_img_paths], results_dir/filename)\n",
    "\n",
    "        # if config.wandb:\n",
    "        #     wandb.log({'accuracy': accuracy})\n",
    "        #     wandb.log({'results': wandb.Table(dataframe=df, allow_mixed_types=True)})\n",
    "\n",
    "    finish_all_consumers()\n",
    "\n",
    "\n",
     "# Entry point: run the whole pipeline. In a Jupyter kernel __name__ is normally '__main__',\n",
     "# so executing this cell starts it.\n",
     "if __name__ == '__main__':\n",
     "    main()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vipergpt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
