from subgoal_graph import build_tree_with_treelib, postorder_traversal, preorder_traversal


def propagate_known_values(lists, unknown="XXX"):
    """Back-fill placeholder cells from the row that follows them.

    Rows are visited from the second-to-last one up to the first, so a known
    value can travel upward through several consecutive rows. The last row
    is never modified. The rows are updated in place.

    Args:
        lists: List of rows (each a list); mutated in place.
        unknown: Placeholder marking a cell whose value is not known yet.
            The same marker is used both to detect a cell that needs filling
            and to decide whether the cell below carries a usable value.

    Returns:
        The same ``lists`` object that was passed in.
    """
    for i in range(len(lists) - 2, -1, -1):
        row, below = lists[i], lists[i + 1]
        # zip stops at the shorter row, so columns that do not exist in the
        # row below are left untouched.
        for j, (value, candidate) in enumerate(zip(row, below)):
            if value == unknown and candidate != unknown:
                row[j] = candidate
    return lists

def create_string(entry):
    """Build a short human-readable label for a subgoal entry.

    The action is chosen by searching ``str(entry)`` for a known subgoal
    marker; the first marker that matches (in the order listed below) wins.
    "Pickup" and "Open" labels are followed by the colour and type found
    under ``entry['fwd']``; "Drop" labels use ``entry['carrying']`` instead.
    Any other label is returned without an object description.

    The entry is also printed to stdout as a side effect.

    Args:
        entry: Subgoal description. It must support ``entry[key]`` lookups
            whenever the resolved action is "Pickup", "Open" or "Drop".

    Returns:
        The label, e.g. ``"Pickup red ball"`` or ``"Go next to"``.
    """
    text = str(entry)
    # Order matters: it mirrors the precedence of the marker checks.
    known_actions = (
        ("PickupSubgoal", "Pickup"),
        ("DropSubgoal", "Drop"),
        ("GoNextTo", "Go next to"),
        ("OpenSubgoal", "Open"),
    )
    action = next(
        (label for marker, label in known_actions if marker in text),
        f"Not implemented {text}",
    )

    print(entry)

    if action in ("Pickup", "Open"):
        target = entry['fwd']
    elif action == "Drop":
        target = entry['carrying']
    else:
        return action
    return f"{action} {target['color']} {target['type']}"

def filter_go_next_to(entries):
    """Return a copy of ``entries`` without the "Go next to" items.

    An item is dropped when its string form contains ``"Go next to"``.
    Every input row yields one output row, even if it ends up empty; the
    input is not modified.

    Args:
        entries: Iterable of rows, each an iterable of items.

    Returns:
        A new list of new lists holding the surviving items in order.
    """
    kept_rows = []
    for row in entries:
        kept = []
        for item in row:
            if "Go next to" in str(item):
                continue
            kept.append(item)
        kept_rows.append(kept)
    return kept_rows


def remove_consecutive_duplicate_sequences(sequences):
    """Collapse runs of equal neighbouring sequences into a single item.

    Each element is compared with the element directly before it and kept
    only when the two differ. The first element is compared with ``None``.
    Equal sequences that are not adjacent are all kept.

    Args:
        sequences: Iterable of sequences to clean up.

    Returns:
        A new list containing the kept sequences in their original order.
    """
    items = list(sequences)
    # Pair every element with its predecessor; the first one gets None.
    predecessors = [None] + items[:-1]
    return [
        current
        for previous, current in zip(predecessors, items)
        if current != previous
    ]


def entry_went_away(stacks_str):
    """Print a tree built from the stacks whose depth changes at the next step.

    A stack is selected when it is the last one, or when its length differs
    from the length of the stack that follows it. For each selected stack a
    row is created that holds an "XXX" placeholder for every element except
    the last, followed by the label of the last element. The rows are then
    back-filled, stripped of "Go next to" items and de-duplicated before
    being handed to ``build_tree_with_treelib``.

    Everything is reported through ``print``; nothing is returned.

    Args:
        stacks_str: Sequence of stacks, each a non-empty sequence of subgoal
            entries accepted by ``create_string``.
    """
    final_index = len(stacks_str) - 1
    rows = []
    for index, stack in enumerate(stacks_str):
        # Skip stacks whose depth is unchanged at the following step.
        if index != final_index and len(stack) == len(stacks_str[index + 1]):
            continue

        row = ["XXX"] * (len(stack) - 1)
        row.append(create_string(stack[-1]))
        rows.append(row)

        print("Taking:", stack[-1])

    rows = remove_consecutive_duplicate_sequences(
        filter_go_next_to(propagate_known_values(rows))
    )
    for row in rows:
        print(row)

    tree = build_tree_with_treelib(rows)
    print(tree.show(stdout=False, key=False))
    for node in postorder_traversal(tree):
        print(node.tag)


if __name__ == "__main__":
    import json

    # Decode explicitly as UTF-8 rather than relying on the platform's
    # locale-dependent default encoding.
    with open("stack_str.json", "r", encoding="utf-8") as f:
        stacks_str = json.load(f)

    # The data is fully loaded, so the file no longer needs to stay open
    # while the stacks are processed.
    entry_went_away(stacks_str)