import random


def get_sorted(init_atoms):
    """Return the atoms ordered by their textual form "<predicate> <arg1> <arg2> ...".

    Each atom must expose ``symbol.name`` and ``subterms`` (each with ``.name``).
    Sorting gives a deterministic order regardless of how the atoms were produced.
    """

    def _sort_key(atom):
        arg_names = " ".join([term.name for term in atom.subterms])
        return atom.symbol.name + " " + arg_names

    return sorted(init_atoms, key=_sort_key)


def parse_problem(problem, data, shuffle):
    """Translate a planning problem's initial state and goal into sentences.

    Args:
        problem: parsed problem exposing ``init.as_atoms()`` and ``goal``
            (either a formula with ``subformulas`` or a single atom).
        data: domain config. Reads ``"domain_name"``, ``"predicates"``
            (predicate name -> ``str.format`` template) and
            ``"encoded_objects"`` (object name -> display text / template).
        shuffle: if true, shuffle both atom lists in place with ``random``
            instead of keeping the sorted order.

    Returns:
        ``(INIT, GOAL, init_atoms, goal_preds)``: the rendered initial-state and
        goal sentences (``""`` when nothing could be rendered) and the atom lists
        in the order used to render them.
    """

    def parse(init_goal_preds, OBJS):
        """Render atoms as "p1, p2, ... and pN".

        Atoms whose predicate has no usable template are skipped (best effort).
        """
        domain_name = data["domain_name"]
        predicates = []
        for atom in init_goal_preds:
            objs = []
            for subterm in atom.subterms:
                name = subterm.name
                if "obfuscated" in domain_name:
                    objs.append(name.replace("o", "object_"))
                elif "blocksworld" in domain_name:
                    objs.append(OBJS[name])
                elif "logistics" in domain_name:
                    # The first character selects a template; its digits fill it in.
                    digits = [ch for ch in name if ch.isdigit()]
                    objs.append(OBJS[name[0]].format(*digits))
                elif "depots" in domain_name:
                    objs.append(name)
                # ADD SPECIFIC TRANSLATION FOR EACH DOMAIN HERE
            try:
                pred_string = data["predicates"][atom.symbol.name].format(*objs)
            except (KeyError, IndexError, ValueError):
                # No template for this predicate, or the template does not fit
                # the translated arguments: skip the atom.
                continue
            if pred_string:  # drop empty renderings before joining
                predicates.append(pred_string)

        if not predicates:
            return ""
        if len(predicates) == 1:
            return predicates[0]
        return ", ".join(predicates[:-1]) + f" and {predicates[-1]}"

    OBJS = data["encoded_objects"]

    init_atoms = get_sorted(problem.init.as_atoms())
    goal_preds = (
        get_sorted(problem.goal.subformulas)
        if hasattr(problem.goal, "subformulas")
        else [problem.goal]
    )

    if shuffle:
        random.shuffle(init_atoms)
        random.shuffle(goal_preds)

    # ----------- INIT STATE TO TEXT ----------- #
    INIT = parse(init_atoms, OBJS)

    # ----------- GOAL TO TEXT ----------- #
    GOAL = parse(goal_preds, OBJS)

    return INIT, GOAL, init_atoms, goal_preds


def fill_template(
    INIT,
    GOAL,
    PLAN,
    data,
    instruction=False,
    backward=False,
    flip=False,
    current=False,
    head_only=False,
):
    """Assemble one prompt section: a solved example or the current problem.

    Args:
        INIT: statement text. When non-empty it is placed under a
            "[STATEMENT]" header, stripped of surrounding whitespace.
        GOAL: goal sentence. It is only emitted when ``head_only`` is true.
        PLAN: plan text appended after the "[PLAN]" / "[REVERSED PLAN]" tag in
            example (non-``current``) sections.
        data: domain config. Only ``data["domain_name"]`` is read.
        instruction: end with a plain "What is the plan..." question. This
            overrides every other closing variant.
        backward: use the reversed-plan wording. Mutually exclusive with ``flip``.
        flip: use the flipped-problem wording. Only the ``head_only`` prompts
            differ from the default wording.
        current: render a request for the current problem rather than a solved
            example.
        head_only: use the "As initial conditions..." and summary wording.

    Returns:
        The assembled text. For blocksworld domains every "-" in the whole text
        becomes a space and "ontable" becomes "on the table".
    """
    assert not (backward and flip)

    def closing_text():
        if instruction:
            return "\n\nWhat is the plan to achieve my goal? Just give the actions in the plan."
        if backward:
            if current:
                return "\n\nPlease generate your reversed plan for the current problem by following the exact same format as the plan in the example."
            return f"\n\nYour reversed plan is as follows:\n\n[REVERSED PLAN]{PLAN}"
        if not head_only:
            # Plain and flipped problems share the same full-plan wording.
            if current:
                return "\n\nPlease follow the format and generate your plan for the current problem."
            return f"\n\nYour plan is as follows:\n\n[PLAN]{PLAN}"
        if flip:
            if current:
                return "\n\nPlease follow the format and generate the init state, goal, possible full final state, new init state, and new goal for the current problem. Choose the full final state that is the easiest to reach from the init state. Make sure the stacks are combined if they should, e.g., 'red on blue; yellow on red' should be combined as 'yellow on red on blue'."
            return f"\n\nFirst you summarize the init state, goal, and possible full final state, and then flip the problem:\n\n[PLAN]{PLAN}"
        if current:
            return "\n\nPlease follow the format and generate the init state and goal for the current problem. Make sure the stacks are combined if they should, e.g., 'red on blue; yellow on red' should be combined as 'yellow on red on blue'."
        return f"\n\nFirst you summarize the init state and goal:\n\n[PLAN]{PLAN}"

    parts = []
    if INIT != "":
        parts.append(
            "\n[[CURRENT PROBLEM]]\n\n[STATEMENT]\n"
            if current
            else "\n[[EXAMPLE]]\n\n[STATEMENT]\n"
        )
        statement = INIT.strip()
        if head_only:
            parts.append(f"As initial conditions you have that, {statement}.")
        else:
            parts.append(f"{statement}")
    if GOAL != "" and head_only:
        parts.append(f"\nYour goal is to have that {GOAL}.")
    parts.append(closing_text())
    text = "".join(parts)

    # TODO: Add this replacement to the yml file -- Use "Translations" dict in yml
    if "blocksworld" in data["domain_name"]:
        text = text.replace("-", " ").replace("ontable", "on the table")
    return text


def _translate_objs(objs, data, OBJS):
    """Map raw object names from a plan action to their display text.

    Domains with no translation rule (e.g. depots) keep the raw names.
    """
    domain_name = data["domain_name"]
    if "obfuscated" in domain_name:
        return [obj.replace("o", "object_") for obj in objs]
    if "blocksworld" in domain_name:
        return [OBJS[obj] for obj in objs]
    if "logistics" in domain_name:
        # The first character selects a template; the name's digits fill it in.
        return [
            OBJS[obj[0]].format(*[ch for ch in obj if ch.isdigit()]) for obj in objs
        ]
    # elif 'depots' in domain_name:  no formatting necessary
    # ADD SPECIFIC TRANSLATION FOR EACH DOMAIN HERE
    return objs


def _block_color(obj, OBJS):
    """Return the first word of a block's display text, used as its colour name."""
    return OBJS[obj].split(" ")[0]


def _stack_to_text(seq, OBJS):
    """Render a bottom-to-top stack as "top on ... on bottom"."""
    return " on ".join(_block_color(obj, OBJS) for obj in reversed(seq))


def _state_to_text(seqs, OBJS):
    """Render the simulated state as one "(...)" line: stacks, then the held block.

    ``seqs[0]`` is the hand: the ``"h"`` sentinel plus at most one held block.
    """
    line = "("
    for seq_ind, seq in enumerate(seqs):
        if seq[0] != "h":
            line += _stack_to_text(seq, OBJS)
            if seq_ind != len(seqs) - 1:
                line += "; "
        else:
            assert len(seq) <= 2
            if len(seq) == 2:
                line += f"{_block_color(seq[1], OBJS)} on hand; "
    return line + ")\n"


def _apply_step(seqs, action):
    """Apply one forward blocksworld action (e.g. "unstack b c") to the state.

    Inner lists are mutated in place; the returned outer list drops emptied stacks.

    Raises:
        ValueError: if the action is not unstack/put-down/pick-up/stack.
    """
    args = action.split(" ")
    # "unstack" must be tested before "stack" because it contains it.
    if "unstack" in action:
        block = args[1]
        for seq in seqs:
            if seq[-1] == block:
                seq.remove(block)
        seqs[0].append(block)  # add to hand
    elif "put-down" in action:
        block = args[1]
        seqs[0].remove(block)
        seqs.append([block])  # new single-block stack
    elif "pick-up" in action:
        block = args[1]
        for seq in seqs:
            if seq[-1] == block:
                assert len(seq) == 1
                seq.remove(block)
        seqs[0].append(block)  # add to hand
    elif "stack" in action:
        top, bottom = args[1], args[2]
        for seq in seqs:
            if seq[-1] == bottom:
                seq.append(top)
        seqs[0].remove(top)
    else:
        raise ValueError(f"Unexpected action: {action}!")
    return [seq for seq in seqs if len(seq) > 0]


def _revert_step(seqs, action):
    """Undo one blocksworld action, walking a plan backwards from its final state.

    Inner lists are mutated in place; the returned outer list drops emptied stacks.

    Raises:
        ValueError: if the action is not unstack/stack/pick-up/put-down.
    """
    args = action.split(" ")
    if "unstack" in action:  # always check unstack before stack
        top, bottom = args[1], args[2]
        for seq in seqs:
            if seq[-1] == bottom:
                seq.append(top)
        seqs[0].remove(top)
    elif "stack" in action:  # undoing stack lifts the top block back into the hand
        block = args[1]
        for seq in seqs:
            if seq[-1] == block:
                seq.remove(block)
        seqs[0].append(block)
    elif "pick-up" in action:  # undoing pick-up puts the held block on the table
        block = args[1]
        seqs[0].remove(block)
        seqs.append([block])
    elif "put-down" in action:  # undoing put-down takes the lone block back
        block = args[1]
        for seq in seqs:
            if seq[-1] == block:
                assert len(seq) == 1
                seq.remove(block)
        seqs[0].append(block)
    else:
        raise ValueError(f"Unexpected action: {action}!")
    return [seq for seq in seqs if len(seq) > 0]


def instance_to_text(
    problem,
    get_plan,
    data,
    shuffle=False,
    add_init=False,
    add_goal=False,
    add_step=False,
    flip=False,
    backward=False,
    head_only=False,
):
    """Make a problem instance (and optionally its plan) human-readable.

    Args:
        problem: parsed problem, forwarded to ``parse_problem``.
        get_plan: if true, read the plan from the planner output file
            ``sas_plan`` in the working directory and render it into ``PLAN``.
        data: domain config. Reads ``"encoded_objects"``, ``"domain_name"``,
            ``"actions"``, and, when ``flip``, ``"revert_actions"`` (action name ->
            ``str.format`` template).
        shuffle: forwarded to ``parse_problem``.
        add_init, add_goal: prepend the stack summaries of the init state and
            goal to ``PLAN``.
        add_step: after every action, add a "(...)" line with the simulated
            stacks. This assumes a valid blocksworld plan.
        flip: render the plan reversed with ``revert_actions``, starting from
            the simulated final state. Mutually exclusive with ``backward``.
        backward: render the reversed plan. Mutually exclusive with ``flip``.
        head_only: only emit the summaries, not the plan lines. ``INIT`` is then
            kept as produced by ``parse_problem``.

    Returns:
        ``(INIT, GOAL, PLAN, data)``. When ``get_plan`` is set and ``head_only``
        is not, ``INIT`` is replaced by stack summaries of the init state and goal.

    Raises:
        ValueError: on a goal atom other than ``on(...)`` or an unknown action,
            when ``get_plan`` is set.

    The stack simulation takes single-character object names (it indexes
    ``[0]`` / ``[-1]`` into atom strings).
    """
    assert not (flip and backward)

    OBJS = data["encoded_objects"]

    # ----------- PARSE THE PROBLEM ----------- #
    INIT, GOAL, init_atoms, goal_preds = parse_problem(problem, data, shuffle)

    # ----------- PLAN TO TEXT ----------- #
    PLAN = "\n"
    plan_file = "sas_plan"
    plan_lines = []
    if get_plan:
        init_strs = [str(atom) for atom in init_atoms]

        seqs = [["h"]]  # seqs[0] is the hand: "h" sentinel plus the held block
        seqs_rest = []  # table stacks, each listed bottom-to-top
        for atom in init_strs:
            if "ontable" in atom:
                seqs_rest.append([atom.split("(")[1][0]])
        for seq in seqs_rest:
            bottom = seq[0]
            # Climb the stack via "...(top,bottom)" atoms until nothing is on top.
            while True:
                found = False
                for atom in init_strs:
                    if f",{bottom})" in atom:
                        top = atom.split(f",{bottom})")[0][-1]
                        seq.append(top)
                        bottom = top
                        found = True
                if not found:
                    break
        seqs += seqs_rest
        # hand always empty at init and end!
        init_stacks_text = "; ".join(_stack_to_text(seq, OBJS) for seq in seqs_rest)
        init_line = "init state (each clause is a stack): " + init_stacks_text + "\n"

        # The goal may be partial, so on(top,bottom) atoms are merged into
        # bottom-to-top stacks that need not cover every block.
        final_seqs = []
        for goal in goal_preds:
            goal_str = str(goal)
            if "on(" not in goal_str:
                raise ValueError(f"Unexpected goal predicate: {goal}!")
            bottom = goal_str.split(",")[1][0]
            top = goal_str.split("(")[1][0]
            if not (
                any(bottom == seq[-1] for seq in final_seqs)
                or any(top == seq[0] for seq in final_seqs)
            ):
                final_seqs.append([bottom, top])
                continue
            # Extend the matching stack and fuse it with any stack that now
            # continues it. NOTE(review): mutates final_seqs while iterating; relies
            # on at most one stack matching each end.
            for seq_ind, seq in enumerate(final_seqs):
                if seq[-1] == bottom:
                    seq.append(top)
                    for seq_1 in final_seqs:
                        if seq_1[0] == top:
                            final_seqs[seq_ind] = seq + seq_1[1:]
                            final_seqs.remove(seq_1)
                elif seq[0] == top:
                    seq.insert(0, bottom)
                    for seq_1 in final_seqs:
                        if seq_1[-1] == bottom:
                            final_seqs[seq_ind] = seq_1[:-1] + seq
                            final_seqs.remove(seq_1)
        goal_stacks_text = "; ".join(_stack_to_text(seq, OBJS) for seq in final_seqs)
        goal_line = "goal: " + goal_stacks_text + "\n"

        with open(plan_file, encoding="utf-8") as f:
            # The last line is dropped (presumably the planner's cost footer).
            plan = [line.rstrip() for line in f][:-1]

        for raw_action in plan:
            action = raw_action.strip("(").strip(")")
            act_name, *objs = action.split(" ")
            objs = _translate_objs(objs, data, OBJS)
            plan_lines.append(data["actions"][act_name].format(*objs) + "\n")
            if add_step:  # assume correct plan
                seqs = _apply_step(seqs, action)
                plan_lines.append(_state_to_text(seqs, OBJS))

        if flip:
            plan_lines = []  # overwrite with the reverted plan

            # NOTE(review): seqs only reflects the final state when add_step is
            # set; otherwise it still holds the initial stacks.
            new_init_line = ""
            for seq_ind, seq in enumerate(seqs):
                if seq[0] != "h":
                    new_init_line += _stack_to_text(seq, OBJS)
                    if seq_ind != len(seqs) - 1:
                        new_init_line += "; "

            for raw_action in reversed(plan):
                action = raw_action.strip("(").strip(")")
                act_name, *objs = action.split(" ")
                objs = _translate_objs(objs, data, OBJS)
                plan_lines.append(data["revert_actions"][act_name].format(*objs) + "\n")
                if add_step:  # assume correct plan
                    seqs = _revert_step(seqs, action)
                    plan_lines.append(_state_to_text(seqs, OBJS))

        if backward:
            if add_goal:
                PLAN += goal_line
                PLAN += "(suppose the final state is: " + plan_lines[-1][1:]
            if add_init:
                PLAN += init_line.replace("init state", "trying to reach init state")
        elif flip and head_only:
            PLAN += init_line
            PLAN += goal_line
            PLAN += "suppose one possible full final state is: " + new_init_line + "\n"
            PLAN += "==== now flip the problem ====\n"
            PLAN += (
                "use the original full final state as the new init state: "
                + new_init_line
                + "\n"
            )
            PLAN += (
                "use the original init state as the new goal state: "
                + init_stacks_text
                + "\n"
            )
        else:
            if add_init:
                PLAN += init_line
            if add_goal:
                PLAN += goal_line

        if backward:
            if add_step:
                # plan_lines alternates (action, state-after); emit them from the
                # end as (action, state-before), ending at the init state.
                new_plan_lines = []
                for i in range(len(plan_lines), 0, -2):
                    new_plan_lines.append(plan_lines[i - 2])
                    if i - 3 >= 0:
                        new_plan_lines.append(plan_lines[i - 3])
                    else:
                        new_plan_lines.append("(" + init_stacks_text + ")\n")
                plan_lines = new_plan_lines
                plan_lines[-1] = plan_lines[-1][:-1] + " Matching init state\n"
            else:
                plan_lines = plan_lines[::-1]
        else:
            if add_step and not head_only:
                plan_lines[-1] = plan_lines[-1][:-1] + " Goal satisfied\n"
        if not head_only:
            PLAN += "".join(plan_lines)
        if backward:
            PLAN += "[REVERSED PLAN END]\n"
        else:  # also for flip
            PLAN += "[PLAN END]\n"

        # overwrite header state
        if not head_only:
            if flip:
                INIT = (
                    "init state (each clause is a stack): "
                    + new_init_line
                    + "\n"
                    + "goal: "
                    + init_stacks_text
                )
            else:
                INIT = (
                    "init state (each clause is a stack): "
                    + init_stacks_text
                    + "\n"
                    + "goal: "
                    + goal_stacks_text
                )

    return INIT, GOAL, PLAN, data


def get_plan_as_text(data, given_plan=None):
    """Return a plan as one "(action arg1 arg2 ...)" line per step.

    Args:
        data: domain config. It is currently unused and kept for interface
            compatibility.
        given_plan: optional sequence of underscore-joined actions such as
            ``"unstack_a_b"``. When falsy, the plan is read from the planner
            output file ``sas_plan`` in the current working directory instead.

    Returns:
        The plan text, with every action on its own newline-terminated line.
    """
    if given_plan:
        actions = [action.split("_") for action in given_plan]
    else:
        with open("sas_plan", encoding="utf-8") as f:
            # The last line is dropped (presumably the planner's cost footer).
            plan = [line.rstrip() for line in f][:-1]
        actions = [action.strip("(").strip(")").split(" ") for action in plan]

    return "".join(
        "(" + act_name + " " + " ".join(objs) + ")\n" for act_name, *objs in actions
    )
