import os
import yaml
import random
import shutil
import sys
import re

from pathlib import Path
from copy import deepcopy
from tarski.io import PDDLReader
from plan_bench.utils import *


def filter_state_from_pddl(s):
    """Locate the ``(:init ...)`` facts inside a PDDL problem string.

    Args:
        s: Full text of a PDDL problem file.

    Returns:
        ``(facts, head, tail)`` where ``facts`` is the list of lines between
        the ``(:init `` header and the closing ``)`` before ``(:goal``, and
        ``s[head:tail]`` is exactly that span (so callers can splice a new
        state back in).

    Raises:
        ValueError: if either marker is missing; the offending text is first
            dumped to ``logs.txt`` for debugging.
    """
    init_marker = "(:init \n"
    try:
        start = s.index(init_marker + "(") + len(init_marker)
        end = s.index("\n)\n(:goal")
    except ValueError:
        # Keep a copy of the unparsable PDDL around for post-mortem inspection.
        with open("logs.txt", "w") as log_file:
            log_file.write(s)
        raise

    return s[start:end].split("\n"), start, end


def filter_state(pddl_file):
    """Return the goal atoms of a PDDL problem file, one string per line.

    The text between ``(:goal\\n`` and the next ``)\\n)`` is taken; a leading
    ``(and\\n`` wrapper is dropped before splitting on newlines.
    """
    goal_marker = "(:goal\n"
    and_marker = "(and\n"

    with open(pddl_file, "r") as handle:
        contents = handle.read()

    start = contents.index(goal_marker)
    end = contents.index(")\n)", start)
    body = contents[start + len(goal_marker) : end]

    if and_marker in body:
        body = body[len(and_marker) :]
    return body.split("\n")


def step(
    run_name,
    original_pddl,
    action_list,
    domain_file="plan_bench/instances/blocksworld/generated_domain.pddl",
):
    """Simulate ``action_list`` starting from the state in ``original_pddl``.

    ``original_pddl`` is copied to ``f"{run_name}_temp_1.pddl"``. For every
    action, the ``lifted_pddl get_next_state`` CLI computes the successor
    state, which is written to a fresh ``f"{run_name}_temp_{k}.pddl"`` (k = 2,
    3, ...) and used as the input for the next action.

    If an action leaves the state unchanged it is treated as inapplicable:
    the string ``"issue"`` is appended to the returned states and simulation
    stops.

    Returns:
        ``(new_states, working_files)``: a list with one entry per simulated
        action (a list of PDDL fact lines, or ``"issue"``), and the paths of
        the temp files written after ``..._temp_1.pddl``. The caller is
        responsible for deleting all of these files, including
        ``..._temp_1.pddl``.
    """
    new_states = []
    working_files = []
    working_file = f"{run_name}_temp_1.pddl"
    shutil.copyfile(original_pddl, working_file)

    # Successor files are numbered from 2; _temp_1 holds the initial state.
    for working_ind, action in enumerate(action_list, 2):
        # Read and parse the current state once; the same text is reused to
        # splice in the successor state below.
        with open(working_file, "r") as f:
            pddl_state_str = f.read()
        prev_states, head, tail = filter_state_from_pddl(pddl_state_str)

        # Use the pipe as a context manager so it is closed and the child
        # process is waited on.
        with os.popen(
            f'lifted_pddl get_next_state -d {domain_file} -p {working_file} -a "{action}"'
        ) as proc:
            next_states = proc.read().split("\n")[:-1]

        if sorted(next_states) == sorted(prev_states):
            new_states.append("issue")
            break

        new_states.append(next_states)

        new_pddl_state_str = (
            pddl_state_str[:head] + "\n".join(next_states) + pddl_state_str[tail:]
        )
        new_working_file = f"{run_name}_temp_{working_ind}.pddl"
        working_files.append(new_working_file)
        with open(new_working_file, "w") as f:
            f.write(new_pddl_state_str)

        working_file = new_working_file

    return new_states, working_files


def read_config(config_file):
    """Load and return the YAML document stored at ``config_file``."""
    with open(config_file, "r") as handle:
        config = yaml.safe_load(handle)
    return config


def text_state_to_pddl(text, config_file):
    """Convert an LLM's textual state description into PDDL facts.

    The text is split into clauses on ``","`` (``" and "`` is treated as a
    comma), and the clauses are handed to ``text_to_state_blocksworld``
    together with the YAML config loaded from ``config_file``.
    """
    clauses = text.replace(" and ", ",").split(",")
    config = read_config(config_file)
    return text_to_state_blocksworld(clauses, config)


def get_problem(instance, domain):
    """Parse a PDDL domain and problem with tarski and return the problem."""
    pddl_reader = PDDLReader(raise_on_error=True)
    pddl_reader.parse_domain(domain)
    problem = pddl_reader.parse_instance(instance)
    return problem


def get_sorted(init_atoms):
    """Return atoms ordered by ``"<predicate> <arg1> <arg2> ..."`` text."""

    def _atom_key(atom):
        arg_names = " ".join(term.name for term in atom.subterms)
        return f"{atom.symbol.name} {arg_names}"

    return sorted(init_atoms, key=_atom_key)


def build_in_context_ex(
    run_name,
    original_pddl,
    config_file,
    domain_file="plan_bench/instances/blocksworld/generated_domain.pddl",
    original=True,
):
    """Build a few-shot prompt example from a PDDL problem instance.

    The instance is solved with Fast Downward and rendered as text using the
    templates in the YAML config ``config_file``.

    Args:
        run_name: Prefix for temporary files; also used as the plan file path.
        original_pddl: Path to the PDDL problem file.
        config_file: Path to the YAML prompt config (this function reads the
            ``domain_intro``, ``predicates``, ``actions``, ``encoded_objects``
            and ``domain_name`` keys).
        domain_file: Path to the PDDL domain file.
        original: If true, return the plain "[STATEMENT] ... [PLAN]" prompt.
            Otherwise the ground-truth plan is simulated step by step and the
            prompt interleaves "Currently, I have that, ..." states with the
            next plan step.

    Returns:
        ``(prompt, INIT, GOAL, PLAN, data)``.

    Temporary files (the plan file ``run_name``, Fast Downward's SAS file
    ``f"{run_name}_"`` and the ``{run_name}_temp_*.pddl`` files written by
    ``step``) are deleted before returning.
    """

    def compute_plan(domain, instance, plan_cache_file=None):
        """Run Fast Downward (A* + LM-cut); return the plan text, or "" on failure."""
        if not plan_cache_file:
            plan_cache_file = "sas_plan"

        fast_downward_path = "downward"
        # Per-run SAS file name: a shared fixed name would be clobbered by
        # concurrent runs and was never cleaned up.
        sas_file = f"{plan_cache_file}_"
        # Remove > /dev/null to see the output of fast-downward
        assert os.path.exists(f"{fast_downward_path}/fast-downward.py")
        cmd = f'{fast_downward_path}/fast-downward.py --sas-file {sas_file} --plan-file {plan_cache_file} {domain} {instance} --search "astar(lmcut())" > /dev/null 2>&1'
        os.system(cmd)
        # The SAS file only exists if translation got far enough to write it.
        if os.path.exists(sas_file):
            os.remove(sas_file)

        if not os.path.exists(plan_cache_file):
            return ""
        return Path(plan_cache_file).read_text()

    def get_plan_as_text(data, given_plan=None, plan_file=None):
        """Return the plan as "(action arg ...)" lines, one per action."""
        PLAN = ""
        if given_plan:
            for action in given_plan:
                act_name, objs = action.split("_")[0], action.split("_")[1:]
                PLAN += "(" + act_name + " " + " ".join(objs) + ")\n"
            return PLAN

        if not plan_file:
            plan_file = "sas_plan"
        with open(plan_file) as f:
            # The last line of a Fast Downward plan file is dropped (not an action).
            plan = [line.rstrip() for line in f][:-1]

        for action in plan:
            action = action.strip("(").strip(")")
            act_name, objs = action.split(" ")[0], action.split(" ")[1:]
            PLAN += "(" + act_name + " " + " ".join(objs) + ")\n"
        return PLAN

    def parse_problem(problem, data, shuffle):
        """Render the problem's initial state and goal as text; return (INIT, GOAL)."""

        def parse(init_goal_preds, OBJS):
            """Render atoms through data["predicates"] templates as "a, b and c"."""
            predicates = []
            for atom in init_goal_preds:
                objs = []
                for subterm in atom.subterms:
                    if "obfuscated" in data["domain_name"]:
                        objs.append(subterm.name.replace("o", "object_"))
                    elif "blocksworld" in data["domain_name"]:
                        objs.append(OBJS[subterm.name])
                    elif "logistics" in data["domain_name"]:
                        obj = subterm.name
                        objs.append(
                            f"{OBJS[obj[0]].format(*[ch for ch in obj if ch.isdigit()])}"
                        )
                    elif "depots" in data["domain_name"]:
                        objs.append(subterm.name)
                    # ADD SPECIFIC TRANSLATION FOR EACH DOMAIN HERE
                try:
                    predicates.append(
                        data["predicates"][atom.symbol.name].format(*objs)
                    )
                except (KeyError, IndexError):
                    # No text template for this predicate (or too few objects
                    # for its template): leave the atom out of the description.
                    pass

            # Drop empty renderings *before* choosing the join form; otherwise
            # a single survivor would be rendered as " and X".
            predicates = [item for item in predicates if item]
            if not predicates:
                return ""
            if len(predicates) == 1:
                return predicates[0]
            return ", ".join(predicates[:-1]) + f" and {predicates[-1]}"

        OBJS = data["encoded_objects"]

        init_atoms = get_sorted(problem.init.as_atoms())
        goal_preds = (
            get_sorted(problem.goal.subformulas)
            if hasattr(problem.goal, "subformulas")
            else [problem.goal]
        )

        if shuffle:
            random.shuffle(init_atoms)
            random.shuffle(goal_preds)

        INIT = parse(init_atoms, OBJS)
        GOAL = parse(goal_preds, OBJS)
        return INIT, GOAL

    def instance_to_text(problem, get_plan, data, plan_file=None, shuffle=False):
        """Return (INIT, GOAL, PLAN, data) with PLAN rendered via data["actions"]."""
        OBJS = data["encoded_objects"]

        INIT, GOAL = parse_problem(problem, data, shuffle)

        PLAN = ""
        if not plan_file:
            plan_file = "sas_plan"
        if get_plan:
            PLAN = "\n"
            with open(plan_file) as f:
                plan = [line.rstrip() for line in f][:-1]

            for action in plan:
                action = action.strip("(").strip(")")
                act_name, objs = action.split(" ")[0], action.split(" ")[1:]
                if "obfuscated" in data["domain_name"]:
                    objs = [j.replace("o", "object_") for j in objs]
                elif "blocksworld" in data["domain_name"]:
                    objs = [OBJS[obj] for obj in objs]
                elif "logistics" in data["domain_name"]:
                    objs = [
                        f"{OBJS[obj[0]].format(*[ch for ch in obj if ch.isdigit()])}"
                        for obj in objs
                    ]
                # elif 'depots' in data['domain_name']:  no formatting necessary
                # ADD SPECIFIC TRANSLATION FOR EACH DOMAIN HERE

                PLAN += data["actions"][act_name].format(*objs) + "\n"
            PLAN += "[PLAN END]\n"

        return INIT, GOAL, PLAN, data

    def fill_template(INIT, GOAL, PLAN, data, instruction=False):
        """Assemble the "[STATEMENT] ... [PLAN]" section of the prompt."""
        text = ""
        if INIT != "":
            text += "\n[STATEMENT]\n"
            text += f"As initial conditions I have that, {INIT.strip()}."
        if GOAL != "":
            text += f"\nMy goal is to have that {GOAL}."
        if not instruction:
            text += f"\n\nMy plan is as follows:\n\n[PLAN]{PLAN}"
        else:
            text += "\n\nWhat is the plan to achieve my goal? Just give the actions in the plan."

        # TODO: Add this replacement to the yml file -- Use "Translations" dict in yml
        if "blocksworld" in data["domain_name"]:
            text = text.replace("-", " ").replace("ontable", "on the table")
        return text

    data = read_config(config_file)
    query = data["domain_intro"]
    problem = get_problem(original_pddl, domain_file)

    # Solve the instance; the plan is written to ``run_name``.
    compute_plan(domain_file, original_pddl, run_name)
    gt_plan_text = get_plan_as_text(data, plan_file=run_name)
    INIT, GOAL, PLAN, data = instance_to_text(problem, True, data, plan_file=run_name)
    query += fill_template(INIT, GOAL, PLAN, data)
    if original:
        os.remove(run_name)
        return query, INIT, GOAL, PLAN, data

    # Simulate the plan with the same domain that was used to solve it.
    _, pddl_files = step(
        run_name,
        original_pddl,
        [act[1:-1] for act in gt_plan_text.split("\n")[:-1]],
        domain_file=domain_file,
    )

    init_format = "Currently, I have that, {}."
    text_states = [init_format.format(INIT.strip())]
    for pddl_file in pddl_files:
        step_problem = get_problem(pddl_file, domain_file)
        step_init, _ = parse_problem(step_problem, data, False)
        text_states.append(init_format.format(step_init.strip()))
        os.remove(pddl_file)

    os.remove(f"{run_name}_temp_1.pddl")
    if os.path.exists(run_name):
        os.remove(run_name)

    prompt = (
        query[: query.index("\n[STATEMENT]")] + f"\nMy goal is to have that {GOAL}.\n\n"
    )
    # PLAN is "\n<step 1>\n...<step n>\n[PLAN END]\n"; [1:-2] keeps just the steps.
    for state_string, step_string in zip(text_states, PLAN.split("\n")[1:-2]):
        prompt += (
            state_string
            + f"\nThe goal is: {GOAL}. I have not reached it."
            + "\nMy next step is: "
            + step_string
            + "\n\n"
        )

    return (
        prompt
        + f"{text_states[-1]}\nI have reached the goal since the final state has: {GOAL}",
        INIT,
        GOAL,
        PLAN,
        data,
    )


def create_react_ex(
    run_name,
    original_pddl,
    config_file,
    domain_file="plan_bench/instances/blocksworld/generated_domain.pddl",
    original=True,
):
    """Build a ReAct-style few-shot example from a PDDL problem instance.

    The instance is solved with Fast Downward and rendered as text using the
    templates in the YAML config ``config_file``.

    Args:
        run_name: Prefix for temporary files; also used as the plan file path.
        original_pddl: Path to the PDDL problem file.
        config_file: Path to the YAML prompt config.
        domain_file: Path to the PDDL domain file.
        original: If true, return only the plain "[STATEMENT] ... [PLAN]" prompt.

    Returns:
        If ``original``: ``(query, INIT, GOAL, PLAN, data)``.
        Otherwise: ``(prompt, message_chain, INIT, GOAL, PLAN, data)``, where
        ``message_chain`` alternates user turns (a numbered "[State N]
        Currently, ..." description plus which goal atoms are / are not yet
        satisfied) with assistant turns ("My next step is: ...").

    Temporary files (``run_name``, ``f"{run_name}_"`` and the
    ``{run_name}_temp_*.pddl`` files written by ``step``) are deleted before
    returning.
    """

    def compute_plan(domain, instance, plan_cache_file=None):
        """Run Fast Downward (A* + LM-cut); return the plan text, or "" on failure."""
        if not plan_cache_file:
            plan_cache_file = "sas_plan"

        fast_downward_path = "downward"
        sas_plan_dummy = f"{plan_cache_file}_"
        # Remove > /dev/null to see the output of fast-downward
        assert os.path.exists(f"{fast_downward_path}/fast-downward.py")
        cmd = f'{fast_downward_path}/fast-downward.py --sas-file {sas_plan_dummy} --plan-file {plan_cache_file} {domain} {instance} --search "astar(lmcut())" > /dev/null 2>&1'
        os.system(cmd)
        # The SAS file only exists if translation got far enough to write it.
        if os.path.exists(sas_plan_dummy):
            os.remove(sas_plan_dummy)

        if not os.path.exists(plan_cache_file):
            return ""
        return Path(plan_cache_file).read_text()

    def get_plan_as_text(data, given_plan=None, plan_file=None):
        """Return the plan as "(action arg ...)" lines, one per action."""
        PLAN = ""
        if given_plan:
            for action in given_plan:
                act_name, objs = action.split("_")[0], action.split("_")[1:]
                PLAN += "(" + act_name + " " + " ".join(objs) + ")\n"
            return PLAN

        if not plan_file:
            plan_file = "sas_plan"
        with open(plan_file) as f:
            # The last line of a Fast Downward plan file is dropped (not an action).
            plan = [line.rstrip() for line in f][:-1]

        for action in plan:
            action = action.strip("(").strip(")")
            act_name, objs = action.split(" ")[0], action.split(" ")[1:]
            PLAN += "(" + act_name + " " + " ".join(objs) + ")\n"
        return PLAN

    def parse_problem(problem, data, shuffle):
        """Render the problem's initial state and goal as text; return (INIT, GOAL)."""

        def parse(init_goal_preds, OBJS):
            """Render atoms through data["predicates"] templates as "a, b and c"."""
            predicates = []
            for atom in init_goal_preds:
                objs = []
                for subterm in atom.subterms:
                    if "obfuscated" in data["domain_name"]:
                        objs.append(subterm.name.replace("o", "object_"))
                    elif "blocksworld" in data["domain_name"]:
                        objs.append(OBJS[subterm.name])
                    elif "logistics" in data["domain_name"]:
                        obj = subterm.name
                        objs.append(
                            f"{OBJS[obj[0]].format(*[ch for ch in obj if ch.isdigit()])}"
                        )
                    elif "depots" in data["domain_name"]:
                        objs.append(subterm.name)
                    # ADD SPECIFIC TRANSLATION FOR EACH DOMAIN HERE
                try:
                    predicates.append(
                        data["predicates"][atom.symbol.name].format(*objs)
                    )
                except (KeyError, IndexError):
                    # No text template for this predicate (or too few objects
                    # for its template): leave the atom out of the description.
                    pass

            # Drop empty renderings *before* choosing the join form; otherwise
            # a single survivor would be rendered as " and X".
            predicates = [item for item in predicates if item]
            if not predicates:
                return ""
            if len(predicates) == 1:
                return predicates[0]
            return ", ".join(predicates[:-1]) + f" and {predicates[-1]}"

        OBJS = data["encoded_objects"]

        init_atoms = get_sorted(problem.init.as_atoms())
        goal_preds = (
            get_sorted(problem.goal.subformulas)
            if hasattr(problem.goal, "subformulas")
            else [problem.goal]
        )

        if shuffle:
            random.shuffle(init_atoms)
            random.shuffle(goal_preds)

        INIT = parse(init_atoms, OBJS)
        GOAL = parse(goal_preds, OBJS)
        return INIT, GOAL

    def instance_to_text(problem, get_plan, data, plan_file=None, shuffle=False):
        """Return (INIT, GOAL, PLAN, data) with PLAN rendered via data["actions"]."""
        OBJS = data["encoded_objects"]

        INIT, GOAL = parse_problem(problem, data, shuffle)

        PLAN = ""
        if not plan_file:
            plan_file = "sas_plan"
        if get_plan:
            PLAN = "\n"
            with open(plan_file) as f:
                plan = [line.rstrip() for line in f][:-1]

            for action in plan:
                action = action.strip("(").strip(")")
                act_name, objs = action.split(" ")[0], action.split(" ")[1:]
                if "obfuscated" in data["domain_name"]:
                    objs = [j.replace("o", "object_") for j in objs]
                elif "blocksworld" in data["domain_name"]:
                    objs = [OBJS[obj] for obj in objs]
                elif "logistics" in data["domain_name"]:
                    objs = [
                        f"{OBJS[obj[0]].format(*[ch for ch in obj if ch.isdigit()])}"
                        for obj in objs
                    ]
                # elif 'depots' in data['domain_name']:  no formatting necessary
                # ADD SPECIFIC TRANSLATION FOR EACH DOMAIN HERE

                PLAN += data["actions"][act_name].format(*objs) + "\n"
            PLAN += "[PLAN END]\n"

        return INIT, GOAL, PLAN, data

    def fill_template(INIT, GOAL, PLAN, data, instruction=False):
        """Assemble the "[STATEMENT] ... [PLAN]" section of the prompt."""
        text = ""
        if INIT != "":
            text += "\n[STATEMENT]\n"
            text += f"As initial conditions I have that, {INIT.strip()}."
        if GOAL != "":
            text += f"\nMy goal is to have that {GOAL}."
        if not instruction:
            text += f"\n\nMy plan is as follows:\n\n[PLAN]{PLAN}"
        else:
            text += "\n\nWhat is the plan to achieve my goal? Just give the actions in the plan."

        # TODO: Add this replacement to the yml file -- Use "Translations" dict in yml
        if "blocksworld" in data["domain_name"]:
            text = text.replace("-", " ").replace("ontable", "on the table")
        return text

    def split_conjunction(text):
        """Split "a, b and c" into ["a", "b", "c"]."""
        parts = text.split(", ")
        return parts[:-1] + parts[-1].split(" and ")

    data = read_config(config_file)
    query = data["domain_intro"]
    problem = get_problem(original_pddl, domain_file)

    # Solve the instance; the plan is written to ``run_name``.
    compute_plan(domain_file, original_pddl, run_name)
    gt_plan_text = get_plan_as_text(data, plan_file=run_name)
    INIT, GOAL, PLAN, data = instance_to_text(problem, True, data, plan_file=run_name)
    query += fill_template(INIT, GOAL, PLAN, data)
    if original:
        os.remove(run_name)
        return query, INIT, GOAL, PLAN, data

    # Simulate the plan with the same domain that was used to solve it.
    _, pddl_files = step(
        run_name,
        original_pddl,
        [act[1:-1] for act in gt_plan_text.split("\n")[:-1]],
        domain_file=domain_file,
    )

    init_format = "[State {}] Currently, {}."
    goal_split = split_conjunction(GOAL)
    init_split = split_conjunction(INIT)
    text_states = [init_format.format(1, INIT.strip())]
    # One dict per state: goal atom text -> whether it holds in that state.
    state_dicts = [{k: k in init_split for k in goal_split}]

    # The initial state is [State 1], so the state after the i-th action is
    # [State i + 1]; starting at 1 here would number two states "[State 1]".
    for state_num, pddl_file in enumerate(pddl_files, 2):
        step_problem = get_problem(pddl_file, domain_file)
        step_init, _ = parse_problem(step_problem, data, False)
        step_init = step_init.strip()
        step_split = split_conjunction(step_init)

        text_states.append(init_format.format(state_num, step_init))
        state_dicts.append({k: k in step_split for k in goal_split})

        os.remove(pddl_file)

    os.remove(f"{run_name}_temp_1.pddl")
    if os.path.exists(run_name):
        os.remove(run_name)

    prompt = (
        query[: query.index("\n[STATEMENT]")]
        + f"\nThe goal is to have that {GOAL}.\n\n"
    )
    message_chain = []
    # PLAN is "\n<step 1>\n...<step n>\n[PLAN END]\n"; [1:-2] keeps just the steps.
    for state_string, state_dict, step_string in zip(
        text_states, state_dicts, PLAN.split("\n")[1:-2]
    ):
        reached_preds = "\n".join(s for s, reached in state_dict.items() if reached)
        state_string_formatted = (
            f"You have reached the following parts of the goal:\n{reached_preds}\n"
            if reached_preds
            else ""
        )
        not_reached_preds = "\n".join(
            s for s, reached in state_dict.items() if not reached
        )
        state_string_formatted += f"You have not reached the following parts of the goal:\n{not_reached_preds}"

        message_chain += [
            {"role": "user", "content": f"{state_string}\n{state_string_formatted}"},
            {
                "role": "assistant",
                "content": f"My next step is: {step_string}",
            },
        ]

    # All state dicts share the same keys (the goal atoms). Use the initial
    # one so this also works when the plan has no steps and the loop above
    # never ran.
    all_goal_preds = "\n".join(state_dicts[0])
    message_chain += [
        {"role": "user", "content": text_states[-1]},
        {
            "role": "assistant",
            "content": f"You have reached the following parts of the goal:\n{all_goal_preds}\nYou are done.",
        },
    ]

    return (
        prompt + "\nMy next step is:",
        message_chain,
        INIT,
        GOAL,
        PLAN,
        data,
    )


def parse_react_response(
    state_counter,
    response,
    problem_file,
    working_file,
    run_name,
    config_file,
    message_list=None,
    prev_feedbacks=None,
    domain_file="plan_bench/instances/blocksworld/generated_domain.pddl",
):
    """Interpret one ReAct-style LLM turn and advance the simulated state.

    The text after ``"step is: "`` in ``response`` is the LLM's next step.
    It is either a jump (``"return to [State N]"``) or a blocksworld action.

    - Jump: the PDDL state recorded for ``[State N]`` is looked up in
      ``prev_feedbacks`` and its text in ``message_list``. ``working_file`` is
      rewritten from that text. An unknown ``N`` leaves ``bad_action`` True.
    - Action: the action is parsed to PDDL and applied with ``lifted_pddl``.
      If the state changes, ``working_file`` is updated and the new state and
      goal progress are described in ``feedback``.

    Args:
        state_counter: Number of the current state; a successful action yields
            ``[State state_counter + 1]``.
        response: Raw LLM output.
        problem_file: PDDL problem (used for its action schema).
        working_file: PDDL file holding the current simulated state; modified
            in place.
        run_name: Prefix for temporary files.
        config_file: YAML prompt config.
        message_list: Chat messages (dicts with ``"content"``) searched on jumps.
        prev_feedbacks: Earlier feedback dicts searched on jumps.
        domain_file: PDDL domain file.

    Returns:
        ``(feedback, state_counter_change)``, where ``state_counter_change`` is
        1 after an applied action and 0 otherwise.
    """
    data = read_config(config_file)
    feedback = {
        "malformed": True,
        "raw_response": response,
        "parsed_action": None,
        "pddl_action": None,
        "bad_action": True,
        "next_state_pddl": None,
        "next_state_text": "",
        "reached_goal": False,
        "goal_pred_dict": None,  # is a string
        "is_jump": False,
    }

    def split_conjunction(text):
        """Split "a, b and c" into ["a", "b", "c"]."""
        parts = text.split(", ")
        return parts[:-1] + parts[-1].split(" and ")

    problem = get_problem(problem_file, domain_file)
    step_header = "step is: "
    jump_marker = "return to [State "
    if step_header not in response:
        return feedback, 0

    # Named ``step_text`` rather than ``step`` so the module-level ``step``
    # function is not shadowed.
    step_text = response[response.index(step_header) + len(step_header) :]
    step_text = step_text.replace(".", "")
    feedback["parsed_action"] = step_text

    if jump_marker in step_text:
        feedback["is_jump"] = True
        feedback["malformed"] = False
        try:
            return_state_num = step_text[
                step_text.index(jump_marker) + len(jump_marker) : step_text.rfind("]")
            ]
            state_tag = f"[State {return_state_num}]"
            # Missing histories are treated as empty so an unknown state
            # number is reported via bad_action instead of a TypeError.
            feedback["next_state_pddl"] = [
                f
                for f in (prev_feedbacks or [])
                if "next_state_text" in f and state_tag in f["next_state_text"]
            ][0]["next_state_pddl"]
            # NOTE(review): takes the *second* message mentioning the tag,
            # presumably to skip an in-context example -- confirm with caller.
            feedback["next_state_text"] = [
                info for info in (message_list or []) if state_tag in info["content"]
            ][1]["content"]
            feedback["bad_action"] = False

            state_text = feedback["next_state_text"]
            pddl_llm_state = text_state_to_pddl(
                state_text[state_text.index("] Currently, ") + len("] Currently, ") :],
                config_file,
            )
            pddl_llm_state = ["(" + s.replace("_", " ") + ")" for s in pddl_llm_state]
            with open(working_file, "r") as f:
                existing_file_states = f.read()
            _, head, tail = filter_state_from_pddl(existing_file_states)
            new_file_states = (
                existing_file_states[:head]
                + "\n".join(pddl_llm_state)
                + existing_file_states[tail:]
            )
            with open(working_file, "w") as f:
                f.write(new_file_states)

        except IndexError:
            # Unknown state number: bad_action stays True.
            pass

        return feedback, 0

    plan_file = f"{run_name}_temp.txt"
    try:
        llm_plan, _ = text_to_plan_blocksworld(
            reformat_response(step_text + "\n[PLAN END]"),
            problem.actions,
            plan_file,
            data,
        )
    except Exception:
        # ``llm_plan`` is unbound when the call above fails; printing it used
        # to raise UnboundLocalError and hide the real error. Show the text
        # that failed to parse and re-raise the original exception.
        print()
        print(step_text)
        print()
        raise

    feedback["malformed"] = False
    os.remove(plan_file)
    feedback["pddl_action"] = [s[1:-1] for s in llm_plan.split("\n")[:-1]][0]

    with open(working_file, "r") as f:
        prev_state_str = f.read()
    prev_states, head, tail = filter_state_from_pddl(prev_state_str)
    # Context manager closes the pipe and waits for the child process.
    with os.popen(
        f'lifted_pddl get_next_state -d {domain_file} -p {working_file} -a "{feedback["pddl_action"]}"'
    ) as proc:
        next_states = proc.read().split("\n")[:-1]

    if sorted(next_states) == sorted(prev_states):
        # Unchanged state: the action was not applicable (bad_action stays True).
        return feedback, 0

    feedback["bad_action"] = False
    feedback["next_state_pddl"] = next_states
    new_state_str = prev_state_str[:head] + "\n".join(next_states) + prev_state_str[tail:]
    with open(working_file, "w") as f:
        f.write(new_state_str)
    new_problem = get_problem(working_file, domain_file)
    # presumably ``parse_problem`` comes from the plan_bench.utils star import
    state_init, goal_text = parse_problem(new_problem, data, False)
    feedback["next_state_text"] = state_init.strip()

    goal_predicates = filter_state(working_file)
    feedback["reached_goal"] = all(pred in next_states for pred in goal_predicates)

    split_next_state_text = split_conjunction(feedback["next_state_text"])
    goal_split = split_conjunction(goal_text)

    reached = [pred for pred in goal_split if pred in split_next_state_text]
    not_reached = [pred for pred in goal_split if pred not in reached]
    reached_lines = "\n".join(reached)
    not_reached_lines = "\n".join(not_reached)
    reached_str = (
        f"You have reached the following parts of the goal:\n{reached_lines}\n"
        if reached
        else ""
    )
    not_reached_str = (
        f"You have not reached the following parts of the goal:\n{not_reached_lines}"
        if not_reached
        else ""
    )
    feedback["goal_pred_dict"] = reached_str + not_reached_str

    feedback["next_state_text"] = (
        f"[State {state_counter + 1}] Currently, " + feedback["next_state_text"]
    )
    return feedback, 1


def split_prompt_new(s, goal_header="My goal is"):
    """Split a step-by-step prompt into ``(context, instance, plan)``.

    ``context`` is everything before ``goal_header``; ``instance`` runs from
    ``goal_header`` up to ``"My next step"``; ``plan`` is the rest. All three
    parts are stripped.

    Args:
        s: The prompt text.
        goal_header: Marker that starts the instance section (defaults to
            ``"My goal is"``; pass e.g. ``"The goal is"`` for ReAct prompts).

    Raises:
        ValueError: if either marker is missing.
    """
    goal_pos = s.index(goal_header)
    next_pos = s.index("My next step")
    # The instance/plan slices start one character early (the separator,
    # removed by strip()). Clamp at 0 so a marker at index 0 doesn't turn into
    # index -1, which would wrap to the end of the string.
    goal_start = max(goal_pos - 1, 0)
    next_start = max(next_pos - 1, 0)

    context = s[:goal_pos]
    instance = s[goal_start:next_start]
    plan = s[next_start:]

    return context.strip(), instance.strip(), plan.strip()


def split_prompt_new_the(s):
    """Split a ReAct prompt into ``(context, instance, plan)``.

    ``context`` is everything before ``"The goal is"``; ``instance`` runs from
    there up to ``"My next step"``; ``plan`` is the rest. All parts are
    stripped.

    Raises:
        ValueError: if either marker is missing.
    """
    goal_pos = s.index("The goal is")
    next_pos = s.index("My next step")
    # Slices start one character early (the separator, removed by strip());
    # clamp at 0 so a marker at index 0 doesn't wrap to the end of the string.
    goal_start = max(goal_pos - 1, 0)
    next_start = max(next_pos - 1, 0)

    context = s[:goal_pos]
    instance = s[goal_start:next_start]
    plan = s[next_start:]

    return context.strip(), instance.strip(), plan.strip()


def split_prompt_old(s):
    """Split a "[STATEMENT] ... [PLAN]" prompt into ``(context, instance, plan)``.

    ``instance`` spans from ``[STATEMENT]`` through the end of the ``[PLAN]``
    tag; ``plan`` is whatever follows the tag. All parts are stripped.
    """
    statement_start = s.index("[STATEMENT]")
    plan_tag = "[PLAN]"
    plan_start = s.index(plan_tag) + len(plan_tag)

    pieces = (s[:statement_start], s[statement_start:plan_start], s[plan_start:])
    return tuple(piece.strip() for piece in pieces)


def reformat_response(response):
    """Tag "on top of the" lines with the blocksworld action they describe.

    A line mentioning "on top of the" gets `` unstack`` appended if it also
    contains "from", otherwise `` stack``; other lines are left as they are.
    """

    def _tag(line):
        if "on top of the" not in line:
            return line
        suffix = " unstack" if "from" in line else " stack"
        return line + suffix

    return "\n".join(_tag(line) for line in response.split("\n"))


def evaluate_new(
    run_name,
    actual_plan,
    response,
    data_true_plan,
    problem_file,
    config_file,
    domain_file="plan_bench/instances/blocksworld/generated_domain.pddl",
):
    """Score a step-by-step LLM response against the ground-truth plan.

    The response is expected to contain "My next step is: ..." and
    "Currently, I have that, ..." lines. Steps are parsed to PDDL and simulated
    with ``step``, and each claimed state is compared with the simulated one.

    Returns:
        A dict with keys ``malformed``, ``issue_step_num`` (1-based index of the
        first inapplicable action, or -1), ``issue_state_num`` (1-based index of
        the first state the LLM described incorrectly, or -1),
        ``steps_taken_llm``, ``steps_taken_true``, ``reached_goal``,
        ``raw_response``, ``parsed_steps``, ``parsed_states`` and
        ``llm_induced_states``. Any exception other than KeyboardInterrupt is
        printed and the partially filled dict is returned.
    """
    step_prefix = "My next step is: "
    state_prefix = "Currently, I have that, "
    steps, next_states = [], []
    for line in response.split("\n"):
        if line.startswith(step_prefix):
            steps.append(line[len(step_prefix) :])
        elif line.startswith(state_prefix):
            next_states.append(line[len(state_prefix) :])

    plan_file = f"{run_name}_temp.txt"
    # Parsed once and reused for both the ground-truth and the LLM plan.
    problem = get_problem(problem_file, domain_file)
    true_plan, _ = text_to_plan_blocksworld(
        reformat_response(actual_plan.strip()),
        problem.actions,
        plan_file,
        data_true_plan,
    )
    os.remove(plan_file)

    output_template = {
        "malformed": True,
        "issue_step_num": -1,
        "issue_state_num": -1,
        "steps_taken_llm": -1,
        "steps_taken_true": len(true_plan.split("\n")) - 1,
        "reached_goal": False,
        "raw_response": response,
        "parsed_steps": None,
        "parsed_states": [],
        "llm_induced_states": None,
    }

    # Every step must come with a state description.
    if len(next_states) != len(steps) or not steps:
        return output_template  # malformed response

    try:
        output_template["malformed"] = False
        llm_plan, _ = text_to_plan_blocksworld(
            reformat_response("\n".join(steps) + "\n[PLAN END]"),
            problem.actions,
            plan_file,
            data_true_plan,
        )
        steps = [s[1:-1] for s in llm_plan.split("\n")[:-1]]
        output_template["parsed_steps"] = steps

        # Simulate with the same domain used for parsing and validation.
        new_states, working_files = step(
            run_name, problem_file, steps, domain_file=domain_file
        )
        output_template["steps_taken_llm"] = len(new_states)
        output_template["llm_induced_states"] = new_states
        if "issue" in new_states:  # took illegal action
            output_template["issue_step_num"] = new_states.index("issue") + 1
        for working_file in working_files:
            os.remove(working_file)
        os.remove(f"{run_name}_temp_1.pddl")

        # Check whether the LLM's claimed states match the simulated ones.
        for i, (llm_state, actual_state) in enumerate(zip(next_states, new_states), 1):
            pddl_llm_state = sorted(text_state_to_pddl(llm_state, config_file))
            output_template["parsed_states"].append(pddl_llm_state)
            # ``actual_state`` is a list of "(pred args)" lines, or the string
            # "issue" (which in practice never matches a parsed state).
            actual_pddl = sorted([s[1:-1].replace(" ", "_") for s in actual_state])
            if actual_pddl != pddl_llm_state and output_template["issue_state_num"] == -1:
                output_template["issue_state_num"] = i
        # States beyond the simulated ones are still recorded.
        for llm_state in next_states[len(output_template["parsed_states"]) :]:
            output_template["parsed_states"].append(
                sorted(text_state_to_pddl(llm_state, config_file))
            )

        output_template["reached_goal"] = validate_plan(
            domain_file, problem_file, plan_file
        )
        os.remove(plan_file)
        return output_template
    except KeyboardInterrupt:
        sys.exit("KeyboardInterrupt")
    except Exception as e:
        print(e)
        return output_template


def evaluate_old(
    run_name,
    actual_plan,
    response,
    data_true_plan,
    problem_file,
    config_file,
    domain_file="plan_bench/instances/blocksworld/generated_domain.pddl",
):
    """Score an LLM-generated Blocksworld plan against the ground-truth plan.

    Legacy evaluator: it simulates the LLM's actions with ``step`` and checks
    whether the plan reaches the goal with ``validate_plan``. It does not
    parse or compare any LLM-predicted states.

    Args:
        run_name: Prefix for scratch files written to the current working
            directory (``{run_name}_temp.txt`` and ``{run_name}_temp_*.pddl``).
        actual_plan: Ground-truth plan text. It is parsed to fill ``true_plan``
            and ``steps_taken_true``.
        response: Raw LLM response text containing the proposed plan.
        data_true_plan: Forwarded unchanged to ``text_to_plan_blocksworld``.
        problem_file: Path to the PDDL problem file.
        config_file: Not used by this function; kept for signature
            compatibility with the other evaluator.
        domain_file: Path to the PDDL domain file.

    Returns:
        A dict with these keys:

        - ``issue_step_num``: 1-based index of the first step that ``step``
          marked as ``"issue"``, or -1 if none.
        - ``steps_taken_llm``
        - ``steps_taken_true``
        - ``reached_goal``
        - ``true_plan``
        - ``raw_response``
        - ``parsed_steps``
        - ``llm_induced_states``

        If evaluating the response raises an ``Exception``, the error is
        printed and the partially filled dict is returned.
    """
    plan_file = f"{run_name}_temp.txt"

    problem = get_problem(problem_file, domain_file)
    true_plan, _ = text_to_plan_blocksworld(
        reformat_response(actual_plan.strip()),
        problem.actions,
        plan_file,
        data_true_plan,
    )
    os.remove(plan_file)

    output_template = {
        "issue_step_num": -1,
        "steps_taken_llm": -1,
        "steps_taken_true": len(true_plan.split("\n")) - 1,
        "reached_goal": False,
        "true_plan": true_plan,
        "raw_response": response,
        "parsed_steps": None,
        "llm_induced_states": None,
    }
    response = reformat_response(response)
    try:
        llm_plan, _ = text_to_plan_blocksworld(
            response, problem.actions, plan_file, data_true_plan
        )
        # Drop the trailing element after the final newline, and strip the
        # first/last character of each line (presumably the surrounding
        # parentheses of each action; confirm against text_to_plan_blocksworld).
        steps = [s[1:-1] for s in llm_plan.split("\n")[:-1]]
        output_template["parsed_steps"] = steps
        new_states, working_files = step(run_name, problem_file, steps)
        output_template["steps_taken_llm"] = len(new_states)
        output_template["llm_induced_states"] = new_states
        if "issue" in new_states:  # took illegal action
            output_template["issue_step_num"] = new_states.index("issue") + 1
        for working_file in working_files:
            os.remove(working_file)
        os.remove(f"{run_name}_temp_1.pddl")

        output_template["reached_goal"] = validate_plan(
            domain_file, problem_file, plan_file
        )
        os.remove(plan_file)

        return output_template
    except KeyboardInterrupt:
        sys.exit("KeyboardInterrupt")
    except Exception as e:
        print(e)
        return output_template
    finally:
        # Best-effort cleanup, so a failure part-way through does not leave
        # scratch files behind. On the success path these files are already
        # gone. Only files whose names are known here can be removed.
        for leftover in (plan_file, f"{run_name}_temp_1.pddl"):
            try:
                os.remove(leftover)
            except OSError:
                pass


def tryint(s):
    """Return ``int(s)`` if *s* converts to an integer, otherwise *s* unchanged.

    Used by ``alphanum_key`` to turn digit runs into numbers for natural
    sorting. Non-numeric chunks are returned as-is.
    """
    try:
        return int(s)
    except (ValueError, TypeError):
        # Fall back only on conversion failures. A bare ``except`` would also
        # swallow KeyboardInterrupt, SystemExit, and unrelated bugs.
        return s


def alphanum_key(s):
    """Build a natural-sort key for the string *s*.

    The string is split into alternating non-digit and digit runs. Each run
    goes through ``tryint``, so digit runs compare by numeric value:
    "z23a" -> ["z", 23, "a"]
    """
    chunks = re.split("([0-9]+)", s)
    return list(map(tryint, chunks))


def sort_nicely(l):
    """Reorder the list *l* in place into natural ("human") order.

    Keys come from ``alphanum_key``, so embedded digit runs compare
    numerically (e.g. "item2" sorts before "item10"). Like ``list.sort``,
    this mutates *l* and returns None.
    """
    l[:] = sorted(l, key=alphanum_key)
