import json
import openai
import re
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os.path

from image import Image

class Task:
    """
    One ARC-style puzzle loaded from a JSON file, with helpers to turn it into
    a chat prompt, query a GPT model, and parse and score the model's answer.
    """

    # Index-aligned symbol tables: position i holds the letter / word used to
    # write cell value i when the "alphabet" / "word" encoding is selected.
    digit_to_alphabet = ["b", "l", "r", "g", "y", "a", "p", "o", "n", "w"]
    digit_to_word = ["black", "blue", "red", "green", "yellow", "gray", "purple", "orange", "cyan", "brown"]

    # Regex fragment matching exactly one cell, keyed by encoding name.
    regex_mask = {
        "number": r"\d",
        "alphabet": r"(" + "|".join(digit_to_alphabet) + ")",
        "word": r"(" + "|".join(digit_to_word) + ")"
    }
    # Regex fragment matching the delimiter between two cells. Every entry must
    # match the delimiter literally, so regex metacharacters are escaped
    # ("." would otherwise match any character).
    regex_delimiter = {
        "none": r"",
        ".": r"\.",
        ",": r",",
        "|": r"\|",
        ";": r";"
    }
    # Same delimiters, but optional: used after the last cell of a row, where
    # a trailing delimiter may or may not be present.
    regex_dend = {
        "none": r"",
        ".": r"\.?",
        ",": r",?",
        "|": r"\|?",
        ";": r";?"
    }

    def __init__(self, filepath, db_path="results/arc_solutions.csv"):
        """
        Load a task from its JSON file and wrap every grid in an Image.

        filepath: path to the task JSON; the file name up to its first "." becomes task_id
        db_path: CSV file that run_GPT reads and rewrites with each new result
        """
        # File name without directories, cut at the first dot.
        file_name = filepath.split("/")[-1]
        self.task_id = file_name.split(".")[0]
        self.db_path = db_path
        self.gpt_models = {"3.5": "gpt-3.5-turbo", "4": "gpt-4"}

        # Prompt settings; they stay None until run_GPT assigns them.
        self.encoding = self.delimiter = self.mode = self.task_type = None

        # Keep the raw text as well as the parsed structure.
        with open(filepath) as f:
            self.raw_data = f.read()
        data = json.loads(self.raw_data)

        self.train_input, self.train_output = [], []
        self.test_input, self.test_output = [], []
        splits = (
            ("train", self.train_input, self.train_output),
            ("test", self.test_input, self.test_output),
        )
        for split, inputs, outputs in splits:
            # Image names are 1-based: <task_id>_<n>_<split>_in / _out
            for number, pair in enumerate(data[split], start=1):
                stem = self.task_id + "_" + str(number) + "_" + split
                inputs.append(Image(grid=pair["input"], name=stem + "_in"))
                outputs.append(Image(grid=pair["output"], name=stem + "_out"))

        # Encoded expected answer; filled in by run_GPT before scoring.
        self.test_answer_str = None

    def get_encoded_prompt_str(self, is_str=False):
        """
        Render the task as prompt text using the current mode/encoding/delimiter.

        is_str: accepted but not used by this method.

        Returns (prompt, prompt_sample): prompt holds every demonstration pair
        followed by the first test input and the closing question;
        prompt_sample is the encoded first test input alone (empty string when
        there is no test pair).
        """
        def encode(image):
            return image.get_encoded_string(encoding=self.encoding, delimiter=self.delimiter)

        pieces = ["Demonstrations:\n"]
        for shown_in, shown_out in zip(self.train_input, self.train_output):
            pieces.extend([
                "\nInput " + self.mode + ":\n",
                encode(shown_in),
                "\nBecomes output " + self.mode + ":\n",
                encode(shown_out),
            ])

        sample_pieces = []
        # Only the first test pair is ever shown to the model.
        first_pair = next(zip(self.test_input, self.test_output), None)
        if first_pair is not None:
            query = first_pair[0]
            pieces.append("\nTest:\n\nInput " + self.mode + ":\n")
            pieces.append(encode(query))
            sample_pieces.append(encode(query))
            pieces.append("\nWhat does this input " + self.mode + " become?")

        return "".join(pieces), "".join(sample_pieces)

    def trial_prompt(self, prompt_type="single_stage"):
        """
        Build the chat conversation that presents the puzzle to the model.

        prompt_type: accepted but not used by this method.

        Returns (messages, prompt_sample): messages is a list of four
        role/content dicts (system, user, assistant, user) ending with the
        encoded puzzle; prompt_sample is passed through from
        get_encoded_prompt_str.
        """
        puzzle_text, prompt_sample = self.get_encoded_prompt_str()

        # Quoted labels reused throughout the instructions.
        input_label = "\"input " + self.mode + "\""
        output_label = "\"output " + self.mode + "\""

        system_text = "Assistant is a chatbot that capable of doing human-level reasoning and inference. Solver will try to solve some puzzles and answer the steps as concisely as possible.\nKnowledge cutoff: 2021-09-01\nCurrent date: 2023-03-16"
        rules_text = "".join([
            "Let's play some puzzles that focus on reasoning and logic. In each puzzle, you will be provided a few demonstrations of how an ",
            input_label,
            " gets transformed into a corresponding ",
            output_label,
            ". At the end, you will get a brand new ",
            input_label,
            ", then you must answer the corresponding ",
            output_label,
            " and describe the transformations used step by step starting from the ",
            input_label,
            ". Is it clear?",
        ])
        ack_text = "Yes, it's clear. Let's get started with the first puzzle!"

        conversation = (
            ("system", system_text),
            ("user", rules_text),
            ("assistant", ack_text),
            ("user", "Here's the first puzzle:\n" + puzzle_text),
        )
        messages = [{"role": role, "content": content} for role, content in conversation]
        return messages, prompt_sample

    def run_GPT(self, mode="grid", encoding="number", delimiter="", json_str="", gpt="3.5",
                task_type="2d", overwrite=False):
        """
        Ask a GPT chat model to solve the first test input of this task, score
        the reply and store the outcome in the CSV database at self.db_path.

        mode: grid|sequence, we tell GPT the inputs/outputs are in the form of grid|sequence. LM might have better results treating this encoding as sequence.
        encoding: number|alphabet|word, write cells as digits or mask them into alphabet/word grids, to avoid a direct search from github json files
        delimiter: ""|.|,|||;, delimiter between pixels ("" means none)
        json_str: currently unused; kept so existing callers keep working
        gpt: key of self.gpt_models selecting the model ("3.5" or "4")
        task_type: "1d" makes the answer parser look for a single row; any other value is treated as a multi-row grid
        overwrite: if True, replace an earlier row with the same Task_ID and GPT_version
            (a one-time backup of the existing database is written next to it as
            *_ow_old.csv); if False, only exact duplicate rows are dropped

        Returns None. Creates the CSV database if it does not exist yet.
        """
        self.encoding = encoding
        self.delimiter = delimiter
        self.mode = mode
        self.task_type = task_type

        print(self.task_id)

        # get prompt for GPT
        full_prompt, prompt_sample = self.trial_prompt()
        print("Full prompt constructed:")
        print(full_prompt)

        # Call GPT
        print("Calling GPT version:", str(gpt))
        temperature = 0  # try deterministic
        # For offline debugging, replace this call with a canned dict shaped like
        # {"choices": [{"message": {"content": "<model reply text>"}}]}.
        completion = openai.ChatCompletion.create(
            model=self.gpt_models[str(gpt)],
            temperature=temperature,
            messages=full_prompt
        )
        reply_text = completion['choices'][0]['message']['content']

        print("GPT answer:")
        print(completion)
        print(reply_text)
        gpt_answer = self.parse_completion_regex(completion=reply_text)

        self.test_answer_str = self.test_output[0].get_encoded_string(delimiter=self.delimiter, encoding=self.encoding)
        match_flag, continuous_score, continuous_score_br, correct_size = 0, 0, 0, 0
        if gpt_answer == "error":
            # No regex match, try manual parse (or GPT parse)
            print("Regex failed, need manual inspect...")
        else:
            match_flag, continuous_score, continuous_score_br, correct_size = self.score_answer(gpt_answer)

        print("\nTest answer:", self.test_answer_str)
        print("GPT answer:", gpt_answer)
        print("Scores:", str(match_flag), str(continuous_score), str(continuous_score_br), str(correct_size))

        # Save to csv
        row_data = {"Task_ID": self.task_id, "Task_json": self.raw_data, "Task_type": task_type, "Mode": mode, "Encoding": encoding,
                    "Delimiter": delimiter, "LLM_model": "GPT",
                    # Stored as text so it compares equal to values read back from the CSV.
                    "GPT_version": str(gpt), "GPT_temperature": temperature, "Prompt_sample": prompt_sample,
                    "Full_prompt": json.dumps(full_prompt),
                    "LLM_json_return": json.dumps(completion),
                    "LLM_full_answer": reply_text,
                    "LLM_extracted_answer": gpt_answer, "True_answer": self.test_answer_str,
                    "Match_flag": str(match_flag), "Continuous_score": str(continuous_score),
                    "Continuous_score_br": str(continuous_score_br), "Correct_size": str(correct_size)}
        row_df = pd.DataFrame([row_data])

        if os.path.isfile(self.db_path):
            # Read the de-duplication keys as text: otherwise pandas turns "3.5"
            # into the float 3.5 (and all-digit task ids into ints), and the
            # overwrite below never recognises the existing row.
            db_df = pd.read_csv(self.db_path, dtype={"Task_ID": str, "GPT_version": str})
            print(self.db_path, db_df.shape)
            if overwrite:
                backup_db_path = self.db_path.replace('.csv', '_ow_old.csv')
                if not os.path.isfile(backup_db_path):
                    db_df.to_csv(backup_db_path, index=False)
            combined = pd.concat([db_df, row_df])
        else:
            # First run: do not lose an already paid-for completion just
            # because the database file has not been created yet.
            print(self.db_path, "not found, creating it")
            combined = row_df

        if overwrite:
            db_df_upd = combined.drop_duplicates(subset=['Task_ID', 'GPT_version'], keep='last')
        else:
            # NOTE(review): columns other than the two keys are still type-inferred
            # when read back (e.g. scores become numbers, "" becomes NaN), so an
            # exact-duplicate row may not be detected here -- confirm whether
            # that matters for the analysis.
            db_df_upd = combined.drop_duplicates()
        db_df_upd = db_df_upd.reset_index(drop=True)
        print(db_df_upd.shape)
        db_df_upd.to_csv(self.db_path, index=False)

        return

    def parse_completion_regex(self, completion):
        """
        Extract the last grid-shaped piece of text from a model reply.

        The reply (with spaces removed) is first searched for a grid whose size
        equals self.test_output[0].image_size. If there is none, the size
        requirement is relaxed to "at least 2 rows and 2 columns", or, for
        task_type "1d", "one row of at least 3 cells".

        completion: text of the model reply

        Returns the matched text (rows separated by newlines, possibly ending
        with a newline and possibly with a delimiter at the end of a row), or
        the string "error" when nothing matches.
        """
        # trim whitespaces, since GPT4 sometimes "prettifies" the output
        text = completion.replace(' ', '')

        cell = self.regex_mask[self.encoding]
        if self.delimiter in ("", "none"):
            sep = sep_at_row_end = ""
        else:
            # Escape so that regex metacharacters (notably ".") are matched
            # literally instead of standing for "any character".
            sep = re.escape(self.delimiter)
            sep_at_row_end = "(?:" + sep + ")?"

        def grid_pattern(extra_rows, extra_cols):
            # A row is `extra_cols` delimited cells plus a final cell that may
            # or may not be followed by a delimiter; a grid is `extra_rows`
            # newline-terminated rows plus a final row whose newline is optional.
            row = "(?:" + cell + sep + "){" + extra_cols + "}(?:" + cell + sep_at_row_end + ")"
            return "(?:" + row + "\n){" + extra_rows + "}" + row + "\n?"

        ncols, nrows = self.test_output[0].image_size
        is_1d = self.task_type == "1d"
        exact_size = grid_pattern("0" if is_1d else str(max(nrows - 1, 0)), str(max(ncols - 1, 0)))
        # Fallback ignoring the expected size: at least 2x2, or 1x3 for 1d tasks.
        any_size = grid_pattern("0", "2,") if is_1d else grid_pattern("1,", "1,")

        for pattern in (exact_size, any_size):
            found = [m.group(0) for m in re.finditer(pattern, text)]
            if found:
                # The reply may restate the input first; the answer comes last.
                return found[-1]

        # No match by regex
        return "error"

    def score_answer(self, answer):
        """
        Compare an extracted answer grid with the expected answer cell by cell.

        answer: grid text (rows separated by newlines, cells separated by
            self.delimiter, or one character per cell when the delimiter is "")

        Uses self.test_answer_str as the expected grid and
        self.test_output[0].image_size as the expected (columns, rows).

        Returns (match_flag, continuous_score, continuous_score_br, correct_size):
            match_flag: 1 if every cell matches, else 0
            continuous_score: fraction of all cells that match
            continuous_score_br: fraction of the expected non-background cells
                (anything but '0', 'b', 'black') that match; when the expected
                grid has no such cell this falls back to continuous_score
            correct_size: 1 if the answer has the expected size, else 0
        A size mismatch yields (0, 0, 0, 0); rows of uneven length yield
        (0, 0.0, 0.0, 1).
        """
        background = ('0', 'b', 'black')

        def to_rows(grid_text):
            rows = grid_text.strip().split('\n')
            if self.delimiter == "":
                return rows
            cells_per_row = []
            for row in rows:
                # The answer parser accepts an optional delimiter after the
                # last cell of a row; drop it so it is not counted as a cell.
                if row.endswith(self.delimiter):
                    row = row[:-len(self.delimiter)]
                cells_per_row.append(row.split(self.delimiter))
            return cells_per_row

        rows_predicted = to_rows(answer)
        rows_actual = to_rows(self.test_answer_str)

        # Determine the number of rows and columns in the predicted grid
        num_rows = len(rows_predicted)
        num_cols = len(rows_predicted[0])
        total_cells = num_rows * num_cols

        if self.test_output[0].image_size != (num_cols, num_rows) or total_cells == 0:
            return 0, 0, 0, 0

        matched = 0
        foreground_matched = 0
        foreground_total = 0
        try:
            for i in range(num_rows):
                for j in range(num_cols):
                    expected = rows_actual[i][j]
                    is_hit = rows_predicted[i][j] == expected
                    if expected not in background:
                        foreground_total += 1
                        foreground_matched += is_hit
                    matched += is_hit
        except IndexError:
            # Sometimes GPT returns answers in varying length (e.g. row lengths do not match)
            return 0, 0.0, 0.0, 1

        overall = matched / total_cells
        # Without any non-background cell the ratio is undefined (it used to
        # raise ZeroDivisionError); use the overall score instead.
        foreground = foreground_matched / foreground_total if foreground_total else overall
        match_flag = 1 if matched == total_cells else 0

        return match_flag, overall, foreground, 1







