import json
import openai
import os
import re
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import networkx as nx

from image import Image

class Task:
    # Per-digit color codes: position i in each list is the token used for
    # color digit i (ten entries, digits 0-9).
    # One-letter code per color digit.
    digit_to_alphabet = ["b", "l", "r", "g", "y", "a", "p", "o", "n", "w"]
    # Full color name per color digit.
    digit_to_word = ["black", "blue", "red", "green", "yellow", "gray", "purple", "orange", "cyan", "brown"]
    # Regex fragment matching one color token, keyed by encoding name:
    # a single digit, or an alternation group built from the lists above.
    regex_mask = {
        "number": r"\d",
        "alphabet": r"(" + "|".join(digit_to_alphabet) + ")",
        "word": r"(" + "|".join(digit_to_word) + ")"
    }
    # Regex fragment matching the literal delimiter, keyed by delimiter name.
    # Regex metacharacters ("." and "|") are escaped so that each fragment
    # matches only the delimiter character itself; an unescaped "." would
    # match any character.
    regex_delimiter = {
        "none": r"",
        ".": r"\.",
        ",": r",",
        "|": r"\|",
        ";": r";"
    }
    # Same delimiters, but optional (zero or one occurrence).
    # NOTE(review): "dend" presumably means "delimiter at end" - confirm
    # against the code that assembles the full pattern.
    regex_dend = {
        "none": r"",
        ".": r"\.?",
        ",": r",?",
        "|": r"\|?",
        ";": r";?"
    }
    # Names of the object/graph-based encodings (as opposed to the raw grid
    # encodings "number"/"alphabet"/"word" keyed in regex_mask).
    # NOTE(review): the code that tests membership in this set is not visible here.
    graph_encodings = {"node_edge_set", "graph_json", "object_json", "object_json_1d", "object_json_words", "object_descriptor", "object_json_w_edge", "object_descriptor_w_edge"}
    graph_encoding_examples = {
        # "object_descriptor": "You will be asked to solve a few tasks, each task providing a few examples of input {self.mode}s containing objects transforming into output {self.mode}s containing objects. The objects are represented by its attributes including the coordinates it contains on the {self.mode}, the color represented as a number and its size . Identify transformation(s) applied to the input {self.mode}s to obtain output {self.mode}s.\nTask 1:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 0: coordinates=[(2, 3)], color=1, size=1\nObject 1: coordinates=[(4, 3), (4, 4)], color=2, size=2\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 0: coordinates=[(2, 3), (3, 3)], color=1, size=2\nObject 1: coordinates=[(4, 3), (4, 4)], color=2, size=2\n\nAnswer:\nTransformation applied:\nExtend size 1 object towards color 2 object until they touch.\n\nTask 2:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 0: coordinates=[(2, 1), (3, 1)], color=1, size=2\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject Node 0: coordinates=[(2, 2), (3, 2)], color=2, size=2\n\nInput {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:\nObject 0: coordinates=[(2, 3), (2, 4), (3, 4)], color=1, size=3\n\nBecomes output {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:\nObject 0: coordinates=[(2, 4), (2, 5), (3, 5)], color=2, size=3\n\nAnswer:\nTransformation applied:\nMove color 1 object 1 pixel to the right\nRecolor color 1 object to color 2\n\nTask 3:"
        "object_descriptor": "You will be asked to solve a few tasks, each task providing a few examples of input {self.mode}s containing objects transforming into output {self.mode}s containing objects. The objects are represented by its attributes including the coordinates it contains on the {self.mode}, the color represented as a number and its size . Identify transformation(s) applied to the input {self.mode}s to obtain output {self.mode}s.\nTask 1:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 1: coordinates=[(1, 2)], color=6, size=1\nObject 2: coordinates=[(3, 2), (3, 3)], color=2, size=2\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 1: coordinates=[(1, 2), (2, 2)], color=6, size=2\nObject 2: coordinates=[(3, 2), (3, 3)], color=2, size=2\n\nAnswer:\nTransformation applied:\nExtend size 1 object towards color 2 object until they touch.\n\nTask 2:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 1: coordinates=[(1, 0), (2, 0)], color=1, size=2\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject Node 1: coordinates=[(1, 1), (2, 1)], color=2, size=2\n\nInput {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:\nObject 1: coordinates=[(1, 2), (1, 3), (2, 3)], color=1, size=3\n\nBecomes output {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:\nObject 1: coordinates=[(1, 3), (1, 4), (2, 4)], color=2, size=3\n\nAnswer:\nTransformation applied:\nMove color 1 object 1 pixel to the right\nRecolor color 1 object to color 2\n\nTask 3:",
        # "object_json": "You will be asked to solve a few logical reasoning tasks, each task providing a few examples of input {self.mode}s containing objects transforming into output {self.mode}s containing objects. The objects are represented by its attributes including the coordinates it contains on the {self.mode}, the color represented as a number and its size . Identify transformation(s) applied to the input {self.mode}s to obtain output {self.mode}s.\nTask 1:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[1, 2]], \"color\": 6, \"size\": 1}, {\"coordinates\": [[3, 2], [3, 3]], \"color\": 2, \"size\": 2}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[1, 2], [2, 2]], \"color\": 6, \"size\": 2}, {\"coordinates\": [[3, 2], [3, 3]], \"color\": 2, \"size\": 2}]\nAnswer:\nTransformation applied:\n1.Extend size-1 color-6 object towards color-2 object until they touch.\n\nTask 2:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[2, 1], [3, 1]], \"color\": 1, \"size\": 2}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[2, 2], [3, 2]], \"color\": 2, \"size\": 2}]\n\nInput {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:[{\"coordinates\": [[1, 2], [1, 3], [2, 3]], \"color\": 1, \"size\": 2}]\n\nBecomes output {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:[{\"coordinates\": [[1, 3], [1, 4], [2, 4]], \"color\": 2, \"size\": 2}]\n\nAnswer:\nTransformation applied:\n1.Move color 1 object 1 pixel to the right\n2.Recolor color 1 object to color 2\n\nTask 3:"
        "object_json": "You will be asked to solve a few logical reasoning tasks, each task providing a few examples of input {self.mode}s containing objects transforming into output {self.mode}s containing objects. The objects are represented by its attributes including the coordinates it contains on the {self.mode}, the color represented as a number and its size . Identify transformation(s) applied to the input {self.mode}s to obtain output {self.mode}s.\nTask 1:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[1, 2]], \"color\": 6, \"size\": 1}, {\"coordinates\": [[3, 2], [3, 3]], \"color\": 2, \"size\": 2}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[1, 2], [2, 2]], \"color\": 6, \"size\": 2}, {\"coordinates\": [[3, 2], [3, 3]], \"color\": 2, \"size\": 2}]\nAnswer:\nTransformation applied:\n1.Extend size-1 color-6 object towards color-2 object until they touch.\n\nTask 2:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[2, 1], [3, 1]], \"color\": 1, \"size\": 2}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[2, 2], [3, 2]], \"color\": 2, \"size\": 2}]\n\nInput {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:[{\"coordinates\": [[1, 2], [1, 3], [2, 3]], \"color\": 1, \"size\": 3}]\n\nBecomes output {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:[{\"coordinates\": [[1, 3], [1, 4], [2, 4]], \"color\": 2, \"size\": 3}]\n\nAnswer:\nTransformation applied:\n1.Move color 1 object 1 pixel down\n2.Recolor color 1 object to color 2\n\nTask 3:",
        # "object_json": "You will be asked to solve a few logical reasoning tasks, each task providing a few examples of input {self.mode}s containing objects transforming into output {self.mode}s containing objects. The objects are represented by its attributes including the coordinates it contains on the {self.mode}, the color represented as a number and its size . Identify transformation(s) applied to the input {self.mode}s to obtain output {self.mode}s.\nTask 1:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\n{self.mode} background color:0\nObjects:[{\"coordinates\": [[1, 2]], \"color\": 6, \"size\": 1}, {\"coordinates\": [[3, 2], [3, 3]], \"color\": 2, \"size\": 2}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\n{self.mode} background color:0\nObjects:[{\"coordinates\": [[1, 2], [2, 2]], \"color\": 6, \"size\": 2}, {\"coordinates\": [[3, 2], [3, 3]], \"color\": 2, \"size\": 2}]\nAnswer:\nTransformation applied:\n1.Extend size-1 color-6 object towards color-2 object until they touch.\n\nTask 2:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\n{self.mode} background color:0\nObjects:[{\"coordinates\": [[2, 1], [3, 1]], \"color\": 1, \"size\": 2}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\n{self.mode} background color:0\nObjects:[{\"coordinates\": [[2, 2], [3, 2]], \"color\": 2, \"size\": 2}]\n\nInput {self.mode} 2:\n{self.mode} size:(6, 5)\n{self.mode} background color:0\nObjects:[{\"coordinates\": [[1, 2], [1, 3], [2, 3]], \"color\": 1, \"size\": 3}]\n\nBecomes output {self.mode} 2:\n{self.mode} size:(6, 5)\n{self.mode} background color:0\nObjects:[{\"coordinates\": [[1, 3], [1, 4], [2, 4]], \"color\": 2, \"size\": 3}]\n\nAnswer:\nTransformation applied:\n1.Move color 1 object 1 pixel down\n2.Recolor color 1 object to color 2\n\nTask 3:",
        "object_json_w_edge": "You will be asked to solve a few logical reasoning tasks, each task providing a few examples of input {self.mode}s containing objects transforming into output {self.mode}s containing objects. The objects are represented by its attributes including the coordinates it contains on the {self.mode}, the color represented as a number and its size . Identify transformation(s) applied to the input {self.mode}s to obtain output {self.mode}s.\nTask 1:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[1, 2]], \"color\": 6, \"size\": 1, \"id\": 1, \"neighbors\": [2]}, {\"coordinates\": [[3, 2], [3, 3]], \"color\": 2, \"size\": 2, \"id\": 2, \"neighbors\": [1]}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[1, 2], [2, 2]], \"color\": 6, \"size\": 2, \"id\": 1, \"neighbors\": [2]}, {\"coordinates\": [[3, 2], [3, 3]], \"color\": 2, \"size\": 2, \"id\": 2, \"neighbors\": [1]}]\nAnswer:\nTransformation applied:\n1.Extend size-1 color-6 object towards its color-2 neighbor until they touch.\n\nTask 2:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[2, 1], [3, 1]], \"color\": 1, \"size\": 2, \"id\": 1, \"neighbors\": []}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:[{\"coordinates\": [[2, 2], [3, 2]], \"color\": 2, \"size\": 2, \"id\": 1, \"neighbors\": []}]\n\nInput {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:[{\"coordinates\": [[1, 2], [1, 3], [2, 3]], \"color\": 1, \"size\": 2, \"id\": 1, \"neighbors\": []}]\n\nBecomes output {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:[{\"coordinates\": [[1, 3], [1, 4], [2, 4]], \"color\": 2, \"size\": 2, \"id\": 1, \"neighbors\": []}]\n\nAnswer:\nTransformation applied:\n1.Move color 1 object 1 pixel down\n2.Recolor color 1 object to color 2\n\nTask 3:",
        "object_descriptor_w_edge": "You will be asked to solve a few tasks, each task providing a few examples of input {self.mode}s containing objects transforming into output {self.mode}s containing objects. The objects are represented by its attributes including the coordinates it contains on the {self.mode}, the color represented as a number and its size . Identify transformation(s) applied to the input {self.mode}s to obtain output {self.mode}s.\nTask 1:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 1: coordinates=[(1, 2)], color=6, size=1, neighbors=[Object 2]\nObject 2: coordinates=[(3, 2), (3, 3)], color=2, size=2, neighbors=[Object 1]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 1: coordinates=[(1, 2), (2, 2)], color=6, size=2, neighbors=[Object 1]\nObject 2: coordinates=[(3, 2), (3, 3)], color=2, size=2, neighbors=[Object 2]\n\nAnswer:\nTransformation applied:\nExtend size 1 object towards its color 2 neighbor until they touch.\n\nTask 2:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 1: coordinates=[(1, 0), (2, 0)], color=1, size=2, neighbors=[]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject Node 1: coordinates=[(1, 1), (2, 1)], color=2, size=2, neighbors=[]\n\nInput {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:\nObject 1: coordinates=[(1, 2), (1, 3), (2, 3)], color=1, size=3, neighbors=[]\n\nBecomes output {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:\nObject 1: coordinates=[(1, 3), (1, 4), (2, 4)], color=2, size=3, neighbors=[]\n\nAnswer:\nTransformation applied:\nMove color 1 object 1 pixel to the right\nRecolor color 1 object to color 2\n\nTask 3:",
        "object_json_words": "You will be asked to solve a few logical reasoning tasks, each task providing a few examples of input {self.mode}s containing objects transforming into output {self.mode}s containing objects. The objects are represented by its attributes including the coordinates it contains on the {self.mode}, the color represented as a number and its size . Identify transformation(s) applied to the input {self.mode}s to obtain output {self.mode}s.\nTask 1:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\n{self.mode} background color: black\nObjects:[{\"coordinates\": [[1, 2]], \"color\": purple, \"size\": 1}, {\"coordinates\": [[3, 2], [3, 3]], \"color\": red, \"size\": 2}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\n{self.mode} background color: black\nObjects:[{\"coordinates\": [[1, 2], [2, 2]], \"color\": purple, \"size\": 2}, {\"coordinates\": [[3, 2], [3, 3]], \"color\": red, \"size\": 2}]\nAnswer:\nTransformation applied:\n1.Extend size 1 color purple object towards color red object until they touch.\n\nTask 2:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\n{self.mode} background color: black\nObjects:[{\"coordinates\": [[2, 1], [3, 1]], \"color\": blue, \"size\": 2}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\n{self.mode} background color: black\nObjects:[{\"coordinates\": [[2, 2], [3, 2]], \"color\": red, \"size\": 2}]\n\nInput {self.mode} 2:\n{self.mode} size:(6, 5)\n{self.mode} background color: black\nObjects:[{\"coordinates\": [[1, 2], [1, 3], [2, 3]], \"color\": blue, \"size\": 3}]\n\nBecomes output {self.mode} 2:\n{self.mode} size:(6, 5)\n{self.mode} background color: black\nObjects:[{\"coordinates\": [[1, 3], [1, 4], [2, 4]], \"color\": red, \"size\": 3}]\n\nAnswer:\nTransformation applied:\n1.Move color blue object 1 pixel down\n2.Recolor color blue object to color red\n\nTask 3:"
    }
    # Abstraction names treated as "multicolor".
    # NOTE(review): presumably abstractions whose objects may contain several
    # colors - confirm against the code that consumes this list.
    multicolor_abs = ["na", "mcccg", "mcccg_d"]

    def __init__(self, filepath):
        self.gpt_models = {"3.5": "gpt-3.5-turbo", "4": "gpt-4"}
        self.task_id = filepath.split("/")[-1].split(".")[0]
        self.db_path = "results/arc_solutions.csv"

        self.encoding = None
        self.delimiter = None
        self.mode = None

        self.train_input, self.train_output, self.test_input, self.test_output = [], [], [], []
        self.raw_data = None
        with open(filepath) as f:
            self.raw_data = f.read()
            data = json.loads(self.raw_data)
        for i, data_pair in enumerate(data["train"]):
            self.train_input.append(Image(grid=data_pair["input"], name=self.task_id + "_" + str(i+1) + "_train_in"))
            self.train_output.append(Image(grid=data_pair["output"], name=self.task_id + "_" + str(i+1) + "_train_out"))
        for i, data_pair in enumerate(data["test"]):
            self.test_input.append(Image(grid=data_pair["input"], name=self.task_id + "_" + str(i+1) + "_test_in"))
            self.test_output.append(Image(grid=data_pair["output"], name=self.task_id + "_" + str(i+1) + "_test_out"))
        self.test_answer_str = None

        self.few_shot_examples = []
        self.few_shot_task_files = []

        self.abstraction = "nbccg"

    def load_few_shot_examples(self):
        for task_file in self.few_shot_task_files:
            with open(task_file) as f:
                data = json.load(f)
            example = {"train_input": [], "train_output": [], "test_input": [], "test_output": []}
            for i, data_pair in enumerate(data["train"]):
                example["train_input"].append(Image(grid=data_pair["input"], name=self.task_id + "_" + str(i + 1) + "_train_in"))
                example["train_output"].append(Image(grid=data_pair["output"], name=self.task_id + "_" + str(i + 1) + "_train_out"))
            for i, data_pair in enumerate(data["test"]):
                example["test_input"].append(Image(grid=data_pair["input"], name=self.task_id + "_" + str(i + 1) + "_test_in"))
                example["test_output"].append(Image(grid=data_pair["output"], name=self.task_id + "_" + str(i + 1) + "_test_out"))
            self.few_shot_examples.append(example)

    def load_graph_representations(self, dictionary=None, abstraction="nbccg"):
        if dictionary is None:
            for train_in in self.train_input:
                getattr(train_in, Image.abstraction_ops[abstraction])()
                # train_in.get_non_black_components_graph()
            for train_out in self.train_output:
                # train_out.get_non_black_components_graph()
                getattr(train_out, Image.abstraction_ops[abstraction])()
            for test_in in self.test_input:
                # test_in.get_non_black_components_graph()
                getattr(test_in, Image.abstraction_ops[abstraction])()
            for test_out in self.test_output:
                # test_out.get_non_black_components_graph()
                getattr(test_out, Image.abstraction_ops[abstraction])()
        else:
            for train_in in dictionary["train_input"]:
                # train_in.get_non_black_components_graph()
                getattr(train_in, Image.abstraction_ops[abstraction])()
            for train_out in dictionary["train_output"]:
                # train_out.get_non_black_components_graph()
                getattr(train_out, Image.abstraction_ops[abstraction])()
            for test_in in dictionary["test_input"]:
                # test_in.get_non_black_components_graph()
                getattr(test_in, Image.abstraction_ops[abstraction])()
            for test_out in dictionary["test_output"]:
                # test_out.get_non_black_components_graph()
                getattr(test_out, Image.abstraction_ops[abstraction])()

    def get_encoded_prompt_str(self, is_str=False, version=1):

        # if is_str:
        #     json_dict = json.loads(json_dict)
        if version == 1:
            prompt = "Demonstrations:\n"
            prompt_sample = ""
            for demo_in, demo_out in zip(self.train_input, self.train_output):
                prompt += "\nInput " + self.mode + ":\n"
                prompt += demo_in.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                prompt += "\nBecomes output " + self.mode + ":\n"
                prompt += demo_out.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)

            for test_in, test_out in zip(self.test_input, self.test_output):
                prompt += "\nTest:\n\nInput " + self.mode + ":\n"
                prompt += test_in.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                prompt_sample += test_in.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                prompt += "\nWhat does this input " + self.mode + " become?"
                break  # only use the first test input
        elif version == 2:
            prompt = ""
            prompt_sample = ""
            index = 1
            for demo_in, demo_out in zip(self.train_input, self.train_output):
                prompt += "\nInput " + self.mode + f" {index}" + ":\n"
                prompt += demo_in.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                prompt += "\nBecomes output " + self.mode + f" {index}" + ":\n"
                prompt += demo_out.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                index += 1

            for test_in, test_out in zip(self.test_input, self.test_output):
                prompt += "\nInput " + self.mode + f" {index}" + ":\n"
                prompt += test_in.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                prompt_sample += test_in.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                prompt += "\nBecomes output " + self.mode + f" {index}" + ":\n"
                break  # only use the first test input
        if version == 3:
            prompt = ""
            prompt_sample = ""
            for demo_in, demo_out in zip(self.train_input, self.train_output):
                prompt += "\nInput " + self.mode + ":\n"
                prompt += demo_in.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                prompt += "\nBecomes output " + self.mode + ":\n"
                prompt += demo_out.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)

            for test_in, test_out in zip(self.test_input, self.test_output):
                prompt += "\nInput " + self.mode + ":\n"
                prompt += test_in.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                prompt_sample += test_in.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                prompt += "\nWhat does this input " + self.mode + " become?"
                break  # only use the first test input

        return prompt, prompt_sample

    def get_graph_encoded_prompt_str(self, version=1):
        if version == 1:
            prompt = "Demonstrations:\n"
            prompt_sample = ""
            for index, (demo_in, demo_out) in enumerate(zip(self.train_input, self.train_output)):
                prompt += "\nInput " + self.mode + f" {index + 1}" + ":\n"
                prompt += self.mode + " size:" + str(demo_in.image_size) + "\n"
                prompt += demo_in.get_graph_encoded_string(encoding=self.encoding)
                prompt += "\nBecomes output " + self.mode + f" {index + 1}" + ":\n"
                prompt += self.mode + " size:" + str(demo_out.image_size) + "\n"
                prompt += demo_out.get_graph_encoded_string(encoding=self.encoding)

            for test_in, test_out in zip(self.test_input, self.test_output):
                prompt += "\nTest:\n\nInput " + self.mode + ":\n"
                prompt += self.mode + " size:" + str(test_in.image_size) + "\n"
                prompt += test_in.get_graph_encoded_string(encoding=self.encoding)
                prompt_sample += test_in.get_graph_encoded_string(encoding=self.encoding)
                prompt += "\nWhat does this input " + self.mode + " become?"
                break  # only use the first test input
        elif version == 2:
            prompt = ""
            prompt_sample = ""
            # if self.encoding == "object_json_words":
            #     convert = lambda i: Image.digit_to_word[i]
            # else:
            #     convert = lambda i: str(i)
            index = 1
            for demo_in, demo_out in zip(self.train_input, self.train_output):
                prompt += "\nInput " + self.mode + f" {index}" + ":\n"
                prompt += self.mode + " size:" + str(demo_in.image_size) + "\n"
                # prompt += self.mode + " background color:" + convert(demo_in.background_color) + "\n"
                prompt += demo_in.get_graph_encoded_string(encoding=self.encoding)
                prompt += "\nBecomes output " + self.mode + f" {index}" + ":\n"
                prompt += self.mode + " size:" + str(demo_out.image_size) + "\n"
                # prompt += self.mode + " background color:" + convert(demo_out.background_color) + "\n"
                prompt += demo_out.get_graph_encoded_string(encoding=self.encoding)
                index += 1

            for test_in, test_out in zip(self.test_input, self.test_output):
                prompt += "\nInput " + self.mode + f" {index}" + ":\n"
                prompt += self.mode + " size:" + str(test_in.image_size) + "\n"
                # prompt += self.mode + " background color:" + convert(test_in.background_color) + "\n"
                prompt += test_in.get_graph_encoded_string(encoding=self.encoding)
                prompt_sample += test_in.get_graph_encoded_string(encoding=self.encoding)
                prompt += "\nBecomes output " + self.mode + f" {index}" + ":\n"
                break  # only use the first test input
        elif version == 3:
            prompt = "Demonstrations:\n"
            prompt_sample = ""
            for index, (demo_in, demo_out) in enumerate(zip(self.train_input, self.train_output)):
                prompt += "\nInput " + self.mode + f" {index + 1}" + ":\n"
                prompt += self.mode + " size:" + str(demo_in.image_size) + "\n"
                prompt += self.mode + " background color:" + str(demo_in.background_color) + "\n"
                prompt += demo_in.get_graph_encoded_string(encoding=self.encoding)
                prompt += "\nBecomes output " + self.mode + f" {index + 1}" + ":\n"
                prompt += self.mode + " size:" + str(demo_out.image_size) + "\n"
                prompt += self.mode + " background color:" + str(demo_out.background_color) + "\n"
                prompt += demo_out.get_graph_encoded_string(encoding=self.encoding)

            for test_in, test_out in zip(self.test_input, self.test_output):
                prompt += "\nTest:\n\nInput " + self.mode + ":\n"
                prompt += self.mode + " size:" + str(test_in.image_size) + "\n"
                prompt += self.mode + " background color:" + str(test_in.background_color) + "\n"
                prompt += test_in.get_graph_encoded_string(encoding=self.encoding)
                prompt_sample += test_in.get_graph_encoded_string(encoding=self.encoding)
                prompt += "\nWhat does this input " + self.mode + " become?"
                break  # only use the first test input
        elif version == 4:
            prompt = ""
            prompt_sample = ""
            if self.encoding == "object_json_words":
                convert = lambda i: Image.digit_to_word[i]
            else:
                convert = lambda i: str(i)
            index = 1
            for demo_in, demo_out in zip(self.train_input, self.train_output):
                prompt += "\nInput " + self.mode + f" {index}" + ":\n"
                prompt += self.mode + " size:" + str(demo_in.image_size) + "\n"
                prompt += self.mode + " background color:" + convert(demo_in.background_color) + "\n"
                prompt += demo_in.get_graph_encoded_string(encoding=self.encoding)
                prompt += "\nBecomes output " + self.mode + f" {index}" + ":\n"
                prompt += self.mode + " size:" + str(demo_out.image_size) + "\n"
                prompt += self.mode + " background color:" + convert(demo_out.background_color) + "\n"
                prompt += demo_out.get_graph_encoded_string(encoding=self.encoding)
                index += 1

            for test_in, test_out in zip(self.test_input, self.test_output):
                prompt += "\nInput " + self.mode + f" {index}" + ":\n"
                prompt += self.mode + " size:" + str(test_in.image_size) + "\n"
                prompt += self.mode + " background color:" + convert(test_in.background_color) + "\n"
                prompt += test_in.get_graph_encoded_string(encoding=self.encoding)
                prompt_sample += test_in.get_graph_encoded_string(encoding=self.encoding)
                prompt += "\nBecomes output " + self.mode + f" {index}" + ":\n"
                break  # only use the first test input

        return prompt, prompt_sample

    def generate_prompt(self, prompt_type="direct_grid_few_shot"):

        messages, prompt_sample = [], ""

        if prompt_type == "direct_grid_few_shot":
            prompt_str, prompt_sample = self.get_encoded_prompt_str()
            messages = [
                {"role": "system",
                 "content": "Assistant is a chatbot that is capable of doing human-level reasoning and inference. Assistant will try to solve some puzzles and answer the steps as concisely as possible.\nKnowledge cutoff: 2021-09-01\nCurrent date: 2023-03-16"},
                {"role": "user",
                 "content": "Let's play some puzzles that focus on reasoning and logic. In each puzzle, you will be provided a few demonstrations of how an \"input " + self.mode + "\" gets transformed into a corresponding \"output " + self.mode + "\". At the end, you will get a brand new \"input " + self.mode + "\", then you must answer the corresponding \"output " + self.mode + "\" and describe the transformations used step by step starting from the \"input " + self.mode + "\". Is it clear?"},
                {"role": "assistant", "content": "Yes, it's clear. Let's get started with the first puzzle!"},
                {"role": "user", "content": "Here's the first puzzle:\n" + prompt_str}]

        elif prompt_type == "direct_grid_in_context_few_shot_cot":
            encoding_mapper = {
                "number": lambda i: str(i),
                "alphabet": lambda i: Image.digit_to_alphabet[i],
                "word": lambda i: Image.digit_to_word[i]
            }
            convert = encoding_mapper[self.encoding]

            sol = [f"Answer:\nTransformation applied:\n1.Extend size 1 color {convert(6)} object towards color {convert(2)} object until they touch.\n",
                   f"Answer:\nTransformation applied:\n1.Move color {convert(1)} object 1 pixel down\n2.Recolor color {convert(1)} object to color {convert(2)}\n"]

            self.few_shot_task_files = ["dataset/fewshot/fewshot1.json", "dataset/fewshot/fewshot2.json"]
            self.load_few_shot_examples()

            prompt = "You will be asked to solve a few tasks, each task providing a few examples of input {self.mode}s containing objects transforming into output {self.mode}s containing objects. Identify transformation(s) applied to the input {self.mode}s to obtain output {self.mode}s.\n"
            for ex_idx, example in enumerate(self.few_shot_examples):
                prompt += f"Task {ex_idx + 1}:"
                index = 1
                for demo_in, demo_out in zip(example["train_input"], example["train_output"]):
                    prompt += "\nInput " + self.mode + f" {index}" + ":\n"
                    prompt += demo_in.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                    prompt += "\nBecomes output " + self.mode + f" {index}" + ":\n"
                    prompt += demo_out.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)
                    index += 1
                prompt += sol[ex_idx]
            prompt += "Task 3:\n"

            example_str = prompt.replace("{self.mode}", self.mode)
            prompt_str, prompt_sample = self.get_encoded_prompt_str(version=2)

            messages = [
                {"role": "user",
                 "content": example_str + prompt_str}]

        elif prompt_type == "object_based_few_shot":
            with open(f"dataset/subset/ARGA-solutions/solutions_{self.task_id}.json") as f:
                task_solutions = json.load(f)
            self.abstraction = task_solutions["abstraction"]
            self.load_graph_representations(abstraction=self.abstraction)  # load ARGA graph representations

            prompt_str, prompt_sample = self.get_graph_encoded_prompt_str()
            messages = [
                {"role": "system",
                 "content": "Assistant is a chatbot that is capable of doing human-level reasoning and inference. Assistant will try to solve some puzzles and answer the steps as concisely as possible.\nKnowledge cutoff: 2021-09-01\nCurrent date: 2023-03-16"},
                {"role": "user",
                 "content": "Let's play some puzzles that focus on reasoning and logic. In each puzzle, you will be provided a few demonstrations of how an \"input " + self.mode + "\" gets transformed into a corresponding \"output " + self.mode + "\". At the end, you will get a brand new \"input " + self.mode + "\", then you must answer the corresponding \"output " + self.mode + "\" and describe the transformations used step by step starting from the \"input " + self.mode + "\". Is it clear?"},
                {"role": "assistant", "content": "Yes, it's clear. Let's get started with the first puzzle!"},
                {"role": "user", "content": "Here's the first puzzle:\n" + prompt_str + "\nBased on the patterns observed in the demonstrations, the output image should be as follows:\nOutput image:"}]

        elif prompt_type == "object_based_in_context_few_shot_cot":

            with open(f"dataset/subset/ARGA-solutions/solutions_{self.task_id}.json") as f:
                task_solutions = json.load(f)
            self.abstraction = task_solutions["abstraction"]

            self.load_graph_representations(abstraction=self.abstraction)  # load ARGA graph representations
            example_str = self.graph_encoding_examples[self.encoding].replace("{self.mode}", self.mode)
            prompt_str, prompt_sample = self.get_graph_encoded_prompt_str(2)

            messages = [
                {"role": "user",
                 "content": example_str + prompt_str}]

        elif prompt_type == "object_based_in_context_few_shot_cot_1d":
            graph_encoding_examples_detailed = {
                # "object_descriptor": "You will be asked to solve a few tasks, each task providing a few examples of input {self.mode}s containing objects transforming into output {self.mode}s containing objects. The objects are represented by its attributes including the coordinates it contains on the {self.mode}, the color represented as a number and its size . Identify transformation(s) applied to the input {self.mode}s to obtain output {self.mode}s.\nTask 1:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 1: coordinates=[(1, 2)], color=6, size=1\nObject 2: coordinates=[(3, 2), (3, 3)], color=2, size=2\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 1: coordinates=[(1, 2), (2, 2)], color=6, size=2\nObject 2: coordinates=[(3, 2), (3, 3)], color=2, size=2\n\nAnswer:\nLet's think step by step\nQ: Which objects change from the input image to the output image?\nA: Object 1 in Input Image 1 changed.\nQ: What are the shared attributes of changed objects?\nA: The changed objects have color 6 and size 1.\nQ: What are the changes made to the objects?\nA: The object is extended towards the color 2 object until they touch.\n\n\nTask 2:\nInput {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject 1: coordinates=[(1, 0), (2, 0)], color=1, size=2\n\nBecomes output {self.mode} 1:\n{self.mode} size:(5, 5)\nObjects:\nObject Node 1: coordinates=[(1, 1), (2, 1)], color=2, size=2\n\nInput {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:\nObject 1: coordinates=[(1, 2), (1, 3), (2, 3)], color=1, size=3\n\nBecomes output {self.mode} 2:\n{self.mode} size:(6, 5)\nObjects:\nObject 1: coordinates=[(1, 3), (1, 4), (2, 4)], color=2, size=3\n\nAnswer:\nLet's think step by step\nQ: Which objects change from the input image to the output image?\nA: Object 1 in Input Image 1, Object 1 in Input Image 2 changed.\nQ: What are the shared attributes of changed objects?\nA: The changed objects have color 1.\nQ: What are the changes made to the 
objects?\nA: The objects are moved 1 coordinate down, and recolored to color 2\n\nTask 3:",
                "object_json": "You will be asked to solve a few tasks, each task providing a few examples of input {self.mode}s containing objects transforming into output {self.mode}s containing objects. The objects are represented by its attributes including the coordinates it contains on the {self.mode}, the color represented as a number and its size. Identify transformation(s) applied to the input {self.mode}s to obtain output {self.mode}s.\nTask 1:\nInput {self.mode} 1:\n{self.mode} size:(1, 7)\nObjects:[{\"coordinates\": [[1, 1]], \"color\": 6, \"size\": 1}, {\"coordinates\": [[1, 5]], \"color\": 2, \"size\": 1}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(1, 7)\nObjects:[{\"coordinates\": [[1, 4]], \"color\": 6, \"size\": 1}, {\"coordinates\": [[1, 5]], \"color\": 2, \"size\": 1}]\nAnswer:\nTransformation applied:\n1.Extend size-1 color-6 object towards color-2 object until they touch.\n\nTask 2:\nInput {self.mode} 1:\n{self.mode} size:(1, 9)\nObjects:[{\"coordinates\": [[1, 2], [1, 3], [1, 4]], \"color\": 1, \"size\": 3}]\n\nBecomes output {self.mode} 1:\n{self.mode} size:(1, 9)\nObjects:[{\"coordinates\": [[1, 3], [1, 4], [1, 5]], \"color\": 2, \"size\": 3}]\n\nInput {self.mode} 2:\n{self.mode} size:(1, 9)\nObjects:[{\"coordinates\": [[1, 1], [1, 2], [1, 3], [1, 4], [1, 5]], \"color\": 1, \"size\": 5}]\n\nBecomes output {self.mode} 2:\n{self.mode} size:(1, 9)\nObjects:[{\"coordinates\": [[1, 2], [1, 3], [1, 4], [1, 5], [1, 6]], \"color\": 2, \"size\": 5}]\n\nAnswer:\nTransformation applied:\n1.Move color 1 object 1 pixel to the right\n2.Recolor color 1 object to color 2\n\nTask 3:"
            }
            self.load_graph_representations(abstraction=self.abstraction)  # load ARGA graph representations
            example_str = graph_encoding_examples_detailed[self.encoding].replace("{self.mode}", self.mode)
            prompt_str, prompt_sample = self.get_graph_encoded_prompt_str(2)

            messages = [
                {"role": "user",
                 "content": example_str + prompt_str}]

        return messages, prompt_sample

    def run_GPT(self, mode="grid", encoding="number", delimiter="", json_str="", gpt="3.5",
                task_type="2d", few_shot_task_files=None, prompt_type="direct_grid_few_shot", temp=0, subset_name=None):
        """
        Query a GPT chat model on this task, score its reply and append one row to a results CSV.

        mode: grid|sequence, we tell GPT the inputs/outputs are in the form of grid|sequence. LM might have better results treating this encoding as sequence.
        encoding: cell encoding (a key of regex_mask) or one of the graph encodings in graph_encodings
        delimiter: none|.|,|||;, delimiter between pixels
        json_str: customized json input for crafted new tasks (as string); not read by this method
        gpt: key into self.gpt_models (converted with str()); also part of the results file name
        task_type: label written to the "Task_type" column
        few_shot_task_files: when truthy, stored on self.few_shot_task_files before the prompt is built
        prompt_type: forwarded to generate_prompt and used in the results file name
        temp: sampling temperature sent to the API and recorded in the "GPT_temperature" column
        subset_name: optional prefix of the results file name

        Returns None. Side effects: sets self.encoding/self.delimiter/self.mode/self.test_answer_str,
        prints progress to stdout and writes "results/<...>.csv".
        """
        self.encoding = encoding
        self.delimiter = delimiter
        self.mode = mode

        # load few-shot/one-shot example tasks
        if few_shot_task_files:
            self.few_shot_task_files = few_shot_task_files

        # get prompt for GPT
        full_prompt, prompt_sample = self.generate_prompt(prompt_type=prompt_type)
        print("Full prompt constructed:")
        print(full_prompt)

        # Call GPT
        gpt_version = str(gpt)
        print("Calling GPT version:", gpt_version)

        gpt_error = False
        try:
            completion = openai.ChatCompletion.create(
                model=self.gpt_models[gpt_version],
                # Try deterministic
                temperature=temp,
                messages=full_prompt
            )
        except Exception as e:
            # Deliberately broad: any API failure is recorded in the CSV row instead of aborting the run.
            completion = {"choices": [{"message": {"content": f"{e}"}}]}
            gpt_answer = "GPT error: " + str(e)
            gpt_error = True

        if encoding not in self.graph_encodings:  # if direct grid encoding
            self.test_answer_str = self.test_output[0].get_encoded_string(delimiter=self.delimiter, encoding=self.encoding)
        else:
            self.test_answer_str = self.test_output[0].get_encoded_string(delimiter="", encoding="number")

        raw_answer = completion['choices'][0]['message']['content']
        print("GPT answer:")
        print(raw_answer)

        match_flag, continuous_score, continuous_score_br, correct_size = 0, 0, 0, 0
        if not gpt_error:
            if encoding not in self.graph_encodings:
                gpt_answer = self.parse_completion_regex(completion=raw_answer)
            else:
                gpt_answer = self.parse_graph_completion_regex(completion=raw_answer)
            # parse_completion_regex reports failure as "error", the graph parser as "parsing_error";
            # neither sentinel must reach score_answer.
            if isinstance(gpt_answer, str) and gpt_answer in ("parsing_error", "error"):
                # No regex match, try manual parse (or GPT parse)
                print("Regex failed, need manual inspect...")
            else:
                match_flag, continuous_score, continuous_score_br, correct_size = self.score_answer(gpt_answer)

            if isinstance(gpt_answer, list):
                # Graph parser returns a list-of-rows grid; flatten it to text for the CSV.
                gpt_answer = "\n".join("".join(str(col) for col in row) for row in gpt_answer)

        print(f"GPT Answer Continuous Accuracy: {continuous_score * 100}%")
        # Save to csv
        row_data = {"Task_ID": self.task_id, "Task_json": self.raw_data, "Task_type": task_type, "Mode": mode, "Encoding": encoding,
                    "Delimiter": delimiter, "LLM_model": "GPT",
                    "GPT_version": gpt, "GPT_temperature": temp, "Prompt_sample": prompt_sample,
                    "Full_prompt": json.dumps(full_prompt),
                    "LLM_json_return": json.dumps(completion),
                    "LLM_full_answer": raw_answer,
                    "LLM_extracted_answer": gpt_answer, "True_answer": self.test_answer_str,
                    "Match_flag": str(match_flag), "Continuous_score": str(continuous_score),
                    "Continuous_score_br": str(continuous_score_br), "Correct_size": str(correct_size)}

        # "|" and "," are spelled out in the file name; every other delimiter maps to an empty segment.
        delim = {"|": "pipe", ",": "comma"}.get(self.delimiter, "")
        name_parts = [prompt_type, self.encoding, delim, self.mode, gpt_version, "temp" + str(temp)]
        if subset_name:
            name_parts.insert(0, subset_name)
        save_file_path = "results/" + "_".join(name_parts) + ".csv"

        # Make sure the output directory exists so the (paid) API result is not lost on a fresh checkout.
        os.makedirs("results", exist_ok=True)

        if os.path.isfile(save_file_path):
            # Read existing file into a dataframe
            db_df = pd.read_csv(save_file_path)
        else:
            # Create a new dataframe with headers
            db_df = pd.DataFrame(columns=list(row_data.keys()))

        # Create a dataframe with the new row data
        row_df = pd.DataFrame([row_data])
        # Concatenate the existing and new dataframes, drop duplicates, and reset index
        db_df_upd = pd.concat([db_df, row_df]).drop_duplicates(ignore_index=True).reset_index(drop=True)
        # Write the updated dataframe to the csv file
        db_df_upd.to_csv(save_file_path, index=False)

        return

    def parse_completion_regex(self, completion):
        if self.delimiter == "":
            delimiter = "none"
        else:
            delimiter = self.delimiter

        # trim whitespaces, since GPT4 sometimes "prettifies" the output
        completion = completion.replace(' ', '')

        regex_str = '(((mask_delimiter_){nrows_m_1}(mask_dend){1}\n){ncols_m_1}((mask_delimiter_){nrows_m_1}(mask_dend){1}\n?){1})'

        nrows, ncols = self.test_output[0].image_size

        regex_value = regex_str.replace("nrows_m_1", str(nrows - 1))
        regex_value = regex_value.replace("ncols_m_1", str(ncols - 1))
        regex_value = regex_value.replace("mask_", self.regex_mask[self.encoding])
        regex_value = regex_value.replace("delimiter_", self.regex_delimiter[delimiter])
        regex_value = regex_value.replace("dend", self.regex_dend[delimiter])

        ret_grid_str_tup = re.findall(r'' + regex_value, completion)
        if not ret_grid_str_tup:
            # Try ignore predicted grid size, assume result grid is sth larger than 2x2
            regex_value_nosize = regex_str.replace("nrows_m_1", "1,")
            regex_value_nosize = regex_value_nosize.replace("ncols_m_1", "1,")
            regex_value_nosize = regex_value_nosize.replace("mask_", self.regex_mask[self.encoding])
            regex_value_nosize = regex_value_nosize.replace("delimiter_", self.regex_delimiter[delimiter])
            regex_value_nosize = regex_value_nosize.replace("dend", self.regex_dend[delimiter])
            ret_grid_str_tup2 = re.findall(r'' + regex_value_nosize, completion)

            # print(repr(regex_value_nosize))
            if not ret_grid_str_tup2:
                # No match by regex
                grid_str = "error"
            else:
                grid_str = ret_grid_str_tup2[-1][0]
        else:
            grid_str = ret_grid_str_tup[-1][0]

        return grid_str

    def parse_graph_completion_regex(self, completion):

        if self.encoding == "object_descriptor":
            completion = completion.replace(' ', '')


            size_pattern = rf'{self.mode}size:\((\d+),(\d+)\)'
            size_match = re.findall(size_pattern, completion)
            height, width = int(size_match[0][0]), int(size_match[0][1])
            grid = [[0 for _ in range(width)] for _ in range(height)]


            if self.abstraction in self.multicolor_abs:
                object_pattern = r'Object(\d+):coordinates=\[(.*?)\],color=\[(.*?)\],size=(\d+)'
            else:
                object_pattern = r'Object(\d+):coordinates=\[(.*?)\],color=(\d+),size=(\d+)'


            object_matches = re.findall(object_pattern, completion)

            if not object_matches:
                return "parsing_error"

            for match in object_matches:
                node_id, coordinates_str, color, size = match
                coordinates = [tuple(map(int, coord.split(','))) for coord in coordinates_str.strip(")(").split('),(')]

                # node_id, size = int(node_id), int(size)
                if self.abstraction in self.multicolor_abs:
                    color = [int(c) for c in color.split(',')]
                    for indx, (rnum, cnum) in enumerate(coordinates):
                        try:
                            grid[rnum][cnum] = color[indx]
                        except IndexError:
                            pass
                else:
                    color = int(color)
                    for (rnum, cnum) in coordinates:
                        try:
                            grid[rnum][cnum] = color
                        except IndexError:
                            pass

        elif self.encoding == "object_json" or self.encoding == "object_json_words":
            completion = completion.replace(' ', '')

            word_to_digit = {'black': 0, 'blue': 1, 'red': 2, 'green': 3, 'yellow': 4, 'gray': 5, 'purple': 6,
                             'orange': 7, 'cyan': 8, 'brown': 9, "0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6,
                             "7": 7, "8": 8, "9": 9}

            # initialize output grid
            size_pattern = rf'{self.mode}size:\((\d+),(\d+)\)'
            size_match = re.findall(size_pattern, completion)
            height, width = int(size_match[0][0]), int(size_match[0][1])
            grid = [[0 for _ in range(width)] for _ in range(height)]

            # Find object lines in the input string
            if self.abstraction in self.multicolor_abs:
                object_pattern = r'"coordinates":\[(.*?)\],"color":\[(.*?)\],"size":(\d+)'
            else:
                if self.encoding != "object_json_words":
                    object_pattern = r'"coordinates":\[(.*?)\],"color":(\d+),"size":(\d+)'
                else:
                    object_pattern = r'"coordinates":\[(.*?)\],"color":(.*?),"size":(\d+)'

            object_matches = re.findall(object_pattern, completion)


            if not object_matches:
                return "parsing_error"


            for match in object_matches:
                coordinates_str, color, size = match
                # color, size = int(color), int(size)
                coordinates = [tuple(map(int, coord.split(','))) for coord in coordinates_str.strip("][").split('],[')]

                if self.abstraction in self.multicolor_abs:

                    color = [word_to_digit[c.strip("\"\'")] for c in color.split(',')]
                    for indx, (rnum, cnum) in enumerate(coordinates):
                        try:
                            grid[rnum][cnum] = color[indx]
                        except IndexError:
                            pass
                else:
                    color = word_to_digit[color.strip("\"\'")]
                    for (rnum, cnum) in coordinates:
                        try:
                            # if self.abstraction == "nbvcg3":
                            #     grid[height - rnum - 1][cnum] = color
                            # else:
                            grid[rnum][cnum] = color
                        except IndexError:
                            pass
        elif self.encoding == "object_json_w_edge":
            completion = completion.replace(' ', '')

            # initialize output grid
            size_pattern = rf'{self.mode}size:\((\d+),(\d+)\)'
            size_match = re.findall(size_pattern, completion)
            height, width = int(size_match[0][0]), int(size_match[0][1])
            grid = [[0 for _ in range(width)] for _ in range(height)]

            # Find object lines in the input string
            if self.abstraction in self.multicolor_abs:
                object_pattern = r'"coordinates":\[(.*?)\],"color":\[(.*?)\],"size":(\d+),"id":(\d+),"neighbors":\[(.*?)\]'
            else:
                object_pattern = r'"coordinates":\[(.*?)\],"color":(\d+),"size":(\d+),"id":(\d+),"neighbors":\[(.*?)\]'

            object_matches = re.findall(object_pattern, completion)

            if not object_matches:
                return "parsing_error"

            for match in object_matches:
                coordinates_str, color, size, id, neighbor = match
                # color, size = int(color), int(size)
                coordinates = [tuple(map(int, coord.split(','))) for coord in coordinates_str.strip("][").split('],[')]

                if self.abstraction in self.multicolor_abs:
                    color = [int(c) for c in color.split(',')]
                    for indx, (rnum, cnum) in enumerate(coordinates):
                        try:
                            grid[rnum][cnum] = color[indx]
                        except IndexError:
                            pass
                else:
                    color = int(color)
                    for (rnum, cnum) in coordinates:
                        try:
                            grid[rnum][cnum] = color
                        except IndexError:
                            pass
        elif self.encoding == "object_descriptor_w_edge":
            completion = completion.replace(' ', '')

            # initialize output grid
            # size_pattern = rf'(?<=[Oo]utput{self.mode})[\s\S]*{self.mode}size:\((\d+),(\d+)\)'
            size_pattern = rf'{self.mode}size:\((\d+),(\d+)\)'
            size_match = re.findall(size_pattern, completion)
            height, width = int(size_match[0][0]), int(size_match[0][1])
            grid = [[0 for _ in range(width)] for _ in range(height)]

            # # Find object lines in the input string
            # output_image_index = completion.lower().find(f'output{self.mode}')
            # relevant_text = completion[output_image_index:]

            if self.abstraction in self.multicolor_abs:
                object_pattern = r'Object(\d+):coordinates=\[(.*?)\],color=\[(.*?)\],size=(\d+),neighbors=\[(.*?)\]'
            else:
                object_pattern = r'Object(\d+):coordinates=\[(.*?)\],color=(\d+),size=(\d+),neighbors=\[(.*?)\]'
            # object_pattern = rf'(?<=[Oo]utput{self.mode})[\s\S]*Object(\d+):coordinates=\[(.*?)\],color=(\d+),size=(\d+)'
            object_matches = re.findall(object_pattern, completion)

            if not object_matches:
                return "parsing_error"

            for match in object_matches:
                node_id, coordinates_str, color, size, neighbors = match
                coordinates = [tuple(map(int, coord.split(','))) for coord in coordinates_str.strip(")(").split('),(')]

                # node_id, size = int(node_id), int(size)
                if self.abstraction in self.multicolor_abs:
                    color = [int(c) for c in color.split(',')]
                    for indx, (rnum, cnum) in enumerate(coordinates):
                        try:
                            grid[rnum][cnum] = color[indx]
                        except IndexError:
                            pass
                else:
                    color = int(color)
                    for (rnum, cnum) in coordinates:
                        try:
                            grid[rnum][cnum] = color
                        except IndexError:
                            pass

        return grid

    def score_answer(self, answer):

        if self.encoding not in self.graph_encodings:
            grid_predicted = answer.strip().split('\n')
            grid_actual = self.test_answer_str.strip().split('\n')

            # Determine the number of rows and columns in the grids
            num_rows = len(grid_predicted)
            if self.delimiter != "":
                grid_predicted = [row.split(self.delimiter) for row in grid_predicted]
                grid_actual = [row.split(self.delimiter) for row in grid_actual]
                num_cols = len(grid_predicted[0])
            else:
                num_cols = len(grid_predicted[0])
        else:
            grid_actual = self.test_output[0].grid
            grid_predicted = answer
            num_rows, num_cols = len(grid_predicted), len(grid_predicted[0])

        if self.test_output[0].image_size != (num_rows, num_cols):
            return 0, 0, 0, 0

        # Count the number of matching cells and return the percentage
        continuous_accuracy = 0
        continuous_accuracy_nb = 0
        continuous_accuracy_nb_count = 0
        for i in range(num_rows):
            for j in range(num_cols):
                if grid_actual[i][j] != '0' and grid_actual[i][j] != 'b' and grid_actual[i][j] != 'black':
                    continuous_accuracy_nb_count += 1
                    if grid_predicted[i][j] == grid_actual[i][j]:
                        continuous_accuracy_nb += 1
                if grid_predicted[i][j] == grid_actual[i][j]:
                    continuous_accuracy += 1

        match_flag = 1 if continuous_accuracy / (num_rows * num_cols) == 1 else 0
        return match_flag, continuous_accuracy / (num_rows * num_cols), continuous_accuracy_nb / continuous_accuracy_nb_count, 1

    def print_prompt(self, mode="grid", encoding="number", delimiter="", prompt_type="single_stage_one_shot", few_shot_task_files=None):
        """
        Build the prompt for this task and print it together with the expected answer.

        Stores mode/encoding/delimiter (and few_shot_task_files when truthy) on the instance,
        encodes the first test output into self.test_answer_str, then prints each prompt
        message as "<role>: " followed by its content, and finally the expected answer.
        """
        self.encoding, self.delimiter, self.mode = encoding, delimiter, mode

        if few_shot_task_files:
            self.few_shot_task_files = few_shot_task_files

        if encoding in self.graph_encodings:
            # Graph encodings need the graph representations loaded before the answer is encoded.
            self.load_graph_representations()
            expected_answer = self.test_output[0].get_graph_encoded_string(encoding=self.encoding)
        else:
            expected_answer = self.test_output[0].get_encoded_string(delimiter=self.delimiter, encoding=self.encoding)
        self.test_answer_str = expected_answer

        messages, _prompt_sample = self.generate_prompt(prompt_type=prompt_type)

        for message in messages:
            print(message["role"] + ": ")
            print(message["content"])

        print("Answer: " + self.test_answer_str)





