import json
from copy import deepcopy

import numpy
from treelib import Tree

from subgoal_graph import build_tree_with_treelib, postorder_traversal, preorder_traversal


def propagate_known_values(lists, unknown="XXX"):
    """Fill unknown placeholders in each row from the row directly below it.

    Rows are processed bottom-up (second-to-last row first), so a known value
    can cascade through several rows of placeholders. A placeholder at index
    ``j`` is replaced only when the row below has an index ``j`` and that
    value is itself known.

    Args:
        lists: list of rows (lists); rows may have different lengths. It is
            modified in place.
        unknown: the placeholder marking an unknown position (default
            ``"XXX"``, the marker used by the callers in this module).

    Returns:
        The same ``lists`` object, after in-place filling.
    """
    # Skip the last row: it has no row below to borrow values from.
    for i in range(len(lists) - 2, -1, -1):
        row, below = lists[i], lists[i + 1]
        for j, value in enumerate(row):
            # Compare against the real placeholder (previously checked 'x' by mistake).
            if value == unknown and j < len(below) and below[j] != unknown:
                row[j] = below[j]
    return lists

def create_string(entry):
    """Render one stack-entry dict (as built by ``stack_entry_to_info``) as a phrase.

    The subgoal kind is detected by substring match on ``str(entry)``. The
    object phrase ("<color> <type>") comes from:

    * ``entry["fwd"]`` for Pickup / Open / Close,
    * ``entry["carrying"]`` for Drop,
    * ``entry["object"]`` for Go next to. When the object is None, the action
      becomes ``"Go somewhere else"``.

    A recognised ``entry["reason"]`` on a Go-next-to entry adds a purpose
    clause, e.g. ``"to pick it up"``. A missing (None) object dict no longer
    raises; the object phrase is simply omitted.

    Args:
        entry: dict with ``fwd``, ``carrying``, ``object`` and ``reason`` keys
            for the recognised subgoal kinds.

    Returns:
        The space-joined description, e.g. ``"Go next to red ball to pick it up"``.
        Unrecognised kinds yield ``"Not implemented <str(entry)>"``.
    """
    def _describe(cell):
        # cell is a {"color": ..., "type": ...} dict, or None when unknown.
        if cell is None:
            return None
        return f"{cell['color']} {cell['type']}"

    # Purpose clause appended after "to " for Go-next-to entries.
    reason_phrases = {
        "Open": "open it",
        "Pickup": "pick it up",
        "Drop": "drop it",
        "PutNext": "PutNext",
    }

    name = str(entry)
    if "PickupSubgoal" in name:
        action = "Pickup"
    elif "DropSubgoal" in name:
        action = "Drop"
    elif "GoNextTo" in name:
        action = "Go next to"
    elif "OpenSubgoal" in name:
        action = "Open"
    elif "CloseSubgoal" in name:
        action = "Close"
    elif "FindDropLocationSubgoal" in name:
        action = "Go somewhere else"
    else:
        action = f"Not implemented {name}"

    info = [action]
    if action in ("Pickup", "Open", "Close"):
        obj = _describe(entry["fwd"])
    elif action == "Drop":
        obj = _describe(entry["carrying"])
    elif action == "Go next to":
        obj = _describe(entry["object"])
        if obj is None:
            # No target object: the agent just moves away.
            info[0] = "Go somewhere else"
    else:
        obj = None
    if obj is not None:
        info.append(obj)

    if action == "Go next to":
        # Previously the phrase was assigned to a typo'd variable and never emitted.
        phrase = reason_phrases.get(entry["reason"])
        if phrase is not None:
            info.append(f"to {phrase}")
    return " ".join(info)

def filter_xxx(entries):
    """Return a copy of each row without items whose string form contains "XXX".

    Args:
        entries: iterable of rows (iterables of arbitrary items).

    Returns:
        A new list of lists; the input rows are left untouched.
    """
    def _is_known(item):
        # Placeholders (and anything whose str() embeds one) are dropped.
        return "XXX" not in str(item)

    return [list(filter(_is_known, row)) for row in entries]


def remove_consecutive_duplicate_sequences(sequences, possible_actions=None):
    """Drop every sequence that equals the one immediately before it.

    Args:
        sequences: items to de-duplicate; neighbours are compared with ``!=``.
        possible_actions: values paired positionally with ``sequences``. When
            omitted, every sequence is paired with ``""``.

    Returns:
        ``(kept_sequences, kept_actions)``: two lists with the surviving items
        and the actions paired with them. Pairing stops at the shorter input.
    """
    if possible_actions is None:
        possible_actions = [""] * len(sequences)

    kept = []
    previous = None
    for seq, act in zip(sequences, possible_actions):
        if seq != previous:
            kept.append((seq, act))
        previous = seq

    return [seq for seq, _ in kept], [act for _, act in kept]


def get_how_many_subgoals_done(stack):
    """Count the distinct subgoals that appeared on top of the stack over time.

    Uses the same entry-building pipeline as ``entry_went_away``, but keeps the
    raw subgoal objects instead of rendering them to strings.

    Args:
        stack: chronological list of stack snapshots (each a list of subgoal
            objects, as stored in ``GraphCreator.stacks``).

    Returns:
        The number of entries left after propagation, placeholder filtering
        and removal of consecutive duplicates (0 for no snapshots).
    """
    if not stack:
        return 0

    # gradual_change returns (states, actions); only the states are needed here.
    states, _ = gradual_change(deepcopy(stack))

    entries = []
    last = len(states) - 1
    for i, state in enumerate(states):
        # Keep non-empty snapshots that differ from the next one (and the final one).
        if state and (i == last or state != states[i + 1]):
            # Only the top element is known; lower positions are placeholders.
            entries.append(["XXX"] * (len(state) - 1) + [state[-1]])

    entries = propagate_known_values(entries)
    entries = filter_xxx(entries)
    entries, _ = remove_consecutive_duplicate_sequences(entries)

    return len(entries)

def gradual_change(history, possible_actions=None):
    """Expand stack snapshots so that multi-element pops happen one step at a time.

    When a snapshot shrinks by more than one element relative to the previous
    one, the prefixes of the previous snapshot (from its full length down to
    one more than the new length) are inserted before the new snapshot. Other
    transitions are copied through unchanged.

    Args:
        history: chronological list of snapshots (lists).
        possible_actions: optional list aligned with ``history``. Inserted
            intermediate states reuse the action of the snapshot they came from.

    Returns:
        ``(states, actions)``. ``actions`` is None when ``possible_actions`` is
        None. An empty ``history`` yields ``([], None)`` or ``([], [])``.
    """
    track = possible_actions is not None
    if not history:
        return [], ([] if track else None)

    result = [history[0]]
    poss_result = [possible_actions[0]] if track else None
    for i in range(len(history) - 1):
        curr, next_state = history[i], history[i + 1]

        if len(next_state) < len(curr) - 1:
            # More than one element popped: emit each prefix of curr, starting
            # with curr itself. Shallow slices are enough because nothing
            # mutates the elements. Sharing them also keeps identity-based
            # equality between consecutive snapshots (deep copies broke that).
            for size in range(len(curr), len(next_state), -1):
                result.append(curr[:size])
                if track:
                    poss_result.append(possible_actions[i])

        result.append(next_state)
        if track:
            poss_result.append(possible_actions[i + 1])
    return result, poss_result

def filter_following_same_entries(entries):
    """Collapse runs of equal neighbouring entries into a single entry.

    Args:
        entries: indexable sequence; neighbours are compared with ``!=``.

    Returns:
        A new list that keeps the first element and every element that differs
        from its predecessor.
    """
    return [
        item
        for idx, item in enumerate(entries)
        if idx == 0 or item != entries[idx - 1]
    ]

def to_dict(tree, nid=None):
    """Convert a treelib tree (or the subtree at ``nid``) into nested dicts.

    Each node becomes ``{"id": <node id>, "name": <node tag>, "children": [...]}``.
    Children are listed only for nodes whose ``expanded`` flag is set.

    Args:
        tree: treelib ``Tree``.
        nid: node id to start from; defaults to ``tree.root``.

    Returns:
        The nested dict for the starting node.
    """
    node_id = tree.root if nid is None else nid
    node = tree[node_id]

    children = []
    if node.expanded:
        children = [to_dict(tree, child_id) for child_id in node.successors(tree._identifier)]

    return {"id": node_id, "name": node.tag, "children": children}

def get_possible_actions(entries, possible_actions):
    """Select the possible actions at steps where the top-of-stack entry changed.

    For each consecutive pair of states ``(curr, nxt)``, the action at the
    index of ``curr`` is kept when either:

    * ``nxt`` is shorter than ``curr`` (something was popped), or
    * the element of ``nxt`` at ``curr``'s top position has a different
      ``"action"`` than ``curr``'s top.

    A pure push onto an unchanged top is skipped. The last possible action is
    always appended.

    Args:
        entries: list of states, each a list of dicts with an ``"action"`` key.
        possible_actions: list aligned with ``entries``.

    Returns:
        The selected actions (an empty list when ``possible_actions`` is empty).
    """
    actions = []
    for i, (curr, nxt) in enumerate(zip(entries, entries[1:])):
        if not curr:
            # Empty state: there is no top to compare. This used to raise
            # IndexError; it is now treated like a push (no action). TODO confirm intent.
            continue
        popped = len(nxt) < len(curr)
        if popped or curr[-1]["action"] != nxt[len(curr) - 1]["action"]:
            actions.append(possible_actions[i])
    if possible_actions:
        actions.append(possible_actions[-1])
    return actions

def entry_went_away(stacks_str, possible_actions=None):
    """Build the subgoal tree from a history of stack snapshots.

    Args:
        stacks_str: chronological list of snapshots, each a list of entry dicts
            as produced by ``stack_entry_to_info``.
        possible_actions: list aligned with the snapshots (it may carry one
            extra trailing element, see ``GraphCreator.add_end_possible_actions``).
            When None, every snapshot is paired with ``""``, matching
            ``remove_consecutive_duplicate_sequences``.

    Returns:
        ``(t, g, possible_actions)``:

        * ``t``: the nested dict from ``to_dict``.
        * ``g``: node tags in the order yielded by ``postorder_traversal``.
        * ``possible_actions``: actions selected by ``get_possible_actions``,
          followed by the final input action.
    """
    if possible_actions is None:
        # Previously possible_actions[-1] raised TypeError on the default.
        possible_actions = ["" for _ in range(len(stacks_str))]

    last_poss_act = possible_actions[-1]
    states, state_actions = gradual_change(deepcopy(stacks_str), possible_actions)

    possible_actions = get_possible_actions(states, state_actions)
    # The final action is appended once more. This is presumably to pair with the
    # root node, which a post-order traversal yields last. TODO confirm.
    possible_actions.append(last_poss_act)

    entries = []
    last = len(states) - 1
    for i, state in enumerate(states):
        # Keep non-empty snapshots that differ from the next one (and the final one).
        if state and (i == last or state != states[i + 1]):
            # Only the rendered top element is known; lower positions are placeholders.
            entries.append(["XXX"] * (len(state) - 1) + [create_string(state[-1])])

    entries = propagate_known_values(entries)
    entries = filter_xxx(entries)

    tree = build_tree_with_treelib(entries)
    t = to_dict(tree)
    g = [node.tag for node in postorder_traversal(tree)]

    return t, g, possible_actions

def stack_entry_to_info(entry):
    """Summarise one subgoal stack entry as a plain dict.

    Returns a dict with these keys, in this order:

    * ``action``: ``str(entry)``.
    * ``object``: ``{"color", "type"}`` of ``entry.datum``. It is None when the
      datum is None, a tuple or a numpy array; a None colour becomes ``"the"``.
    * ``reason``: ``entry.reason`` unchanged.
    * ``fwd``: ``{"color", "type"}`` of ``entry.fwd_cell``, or None.
    * ``carrying``: ``{"color", "type"}`` of ``entry.carrying``, or None.
    """
    def _cell_info(cell):
        if cell is None:
            return None
        return {"color": cell.color, "type": cell.type}

    datum = entry.datum
    if datum is None or isinstance(datum, (tuple, numpy.ndarray)):
        obj = None
    else:
        obj = {
            "color": "the" if datum.color is None else datum.color,
            "type": datum.type,
        }

    fwd = _cell_info(entry.fwd_cell)
    carry = _cell_info(entry.carrying)
    return {
        "action": str(entry),
        "object": obj,
        "reason": entry.reason,
        "fwd": fwd,
        "carrying": carry,
    }


class GraphCreator:
    """Collects subgoal-stack snapshots over an episode and renders them.

    Each ``add_stack`` call records one snapshot: the filtered raw subgoals in
    ``stacks``, their dict summaries in ``stacks_str``, and the possible
    actions at that step in ``possible_actions``.
    """

    def __init__(self):
        self.tree = Tree()  # not read by any method of this class; possibly used externally
        self.stacks = []  # raw stack snapshots, "Explore" entries removed
        self.stacks_str = []  # per-snapshot lists of stack_entry_to_info dicts
        self.to_tree = []  # not read by any method of this class
        self.possible_actions = []  # one item per add_stack / add_end_possible_actions call

    def add_end_possible_actions(self, possible_actions):
        """Append a trailing possible-actions item that has no matching snapshot."""
        self.possible_actions.append(possible_actions)

    def add_stack(self, stack, possible_actions=None):
        """Record one stack snapshot.

        Args:
            stack: iterable of subgoal entries. Entries whose ``str()`` contains
                "Explore" are dropped.
            possible_actions: value stored alongside this snapshot. It defaults
                to a fresh empty list; the previous ``[]`` default was a single
                list shared by every call.
        """
        if possible_actions is None:
            possible_actions = []
        stack = [x for x in stack if "Explore" not in str(x)]
        stack_str = [stack_entry_to_info(x) for x in stack]
        self.possible_actions.append(possible_actions)
        self.stacks_str.append(stack_str)
        self.stacks.append(stack)

    def check_subgoals_completion(self, action=None):
        """
        Checks which subgoals from the mission have been completed after an agent step.
        Also propagates information between subgoals.

        Args:
            action: The last action taken by the agent (optional)

        Returns:
            dict: Information about completed and remaining subgoals
        """
        # NOTE(review): this class never sets ``mistake_handler_agent`` or
        # ``mission_subgoals``, and does not define
        # ``get_non_finished_mission_subgoals``. As written, this method raises
        # AttributeError unless those are attached from outside. It looks like it
        # was copied from another class. TODO confirm where it belongs.

        # Update the agent's state with the latest action if provided
        if action is not None:
            self.mistake_handler_agent.replan(action_taken=action)

        # Get the current stack and make a copy of the original mission subgoals
        current_stack = self.mistake_handler_agent.stack

        # Create a graph to track the stack changes
        graph = GraphCreator()
        graph.add_stack(current_stack)

        # Use the graph method to find non-finished subgoals
        remaining_subgoals = graph.get_non_finished_mission_subgoals(self.mission_subgoals)

        # Calculate completed subgoals
        completed_subgoals = [sg for sg in self.mission_subgoals if sg not in remaining_subgoals]

        # Propagate information between subgoals (existing or remaining ones)
        enriched_subgoals = self.propagate_subgoal_info(self.mission_subgoals)

        return {
            "completed": completed_subgoals,
            "remaining": remaining_subgoals,
            "progress_percentage": len(completed_subgoals) / len(
                self.mission_subgoals) * 100 if self.mission_subgoals else 100,
            "enriched_subgoals": enriched_subgoals
        }

    def propagate_subgoal_info(self, subgoals):
        """
        Propagates missing information between subgoals.

        Works on a deep copy, so ``subgoals`` itself is not modified. Kinds are
        detected by substring match on ``str(subgoal)``.

        1. Right to left: each ``GoNextToSubgoal`` with a non-None ``object``
           copies that object into the ``datum`` of the nearest preceding
           ``PickupSubgoal``/``OpenSubgoal`` whose ``datum`` is missing or None.
        2. Left to right: each ``DropSubgoal`` (except the last element) gets
           ``carrying`` set to the ``datum`` of the nearest following
           ``PickupSubgoal`` that has a non-None ``datum``.

        Args:
            subgoals: List of subgoals to process

        Returns:
            List of subgoals with propagated information
        """
        # Create a deep copy to avoid modifying the original subgoals
        enriched_subgoals = deepcopy(subgoals)

        # First pass: right to left, propagating GoNextToSubgoal objects.
        # (An unused pickup_info record, which raised AttributeError for datums
        # without .color/.type such as tuples, was removed.)
        for i in range(len(enriched_subgoals) - 1, -1, -1):
            sg = enriched_subgoals[i]

            if "GoNextToSubgoal" in str(sg) and hasattr(sg, 'object') and sg.object is not None:
                # Find the preceding subgoal that needs this information
                for j in range(i - 1, -1, -1):
                    prev_sg = enriched_subgoals[j]

                    if ("PickupSubgoal" in str(prev_sg) or "OpenSubgoal" in str(prev_sg)) and \
                            (not hasattr(prev_sg, 'datum') or prev_sg.datum is None):
                        prev_sg.datum = sg.object
                        break

        # Second pass: left to right to propagate information to DropSubgoal
        for i in range(len(enriched_subgoals) - 1):
            sg = enriched_subgoals[i]

            if "DropSubgoal" in str(sg):
                # Look for the next PickupSubgoal that knows its object
                for j in range(i + 1, len(enriched_subgoals)):
                    next_sg = enriched_subgoals[j]

                    if "PickupSubgoal" in str(next_sg) and hasattr(next_sg, 'datum') and next_sg.datum is not None:
                        sg.carrying = next_sg.datum
                        break

        return enriched_subgoals

    def print(self):
        """Return ``entry_went_away`` over the recorded snapshots and actions."""
        return entry_went_away(self.stacks_str, self.possible_actions)

    def subgoals_done(self):
        """Return ``get_how_many_subgoals_done`` over the recorded raw snapshots."""
        return get_how_many_subgoals_done(self.stacks)
