from .finite_state import TaskStateValue
import uuid


class TrajectoryNode:
    """
        One node of a trajectory tree; a node is created for each state change.

        A node keeps the transition that produced it (task state, agent state and
        transition parameters), its own id, its parent's id and a list of
        children. It also carries two memory dicts: ``memory`` (intended to be
        stored in the db) and ``local_memory`` (not stored).
    """

    def __init__(self, task_state, agent_state, trans_param, parent_node_id, node_id):
        # Identity and position in the tree. TreeTrajectory appends child *ids*
        # (not node objects) to `children`.
        self._node_id = node_id
        self._p_node_id = parent_node_id
        self.children = []

        # The transition that led to this node.
        self.task_state, self.agent_state, self.trans_param = (
            task_state,
            agent_state,
            trans_param,
        )

        # `memory` is meant to be persisted to the db; `local_memory` is not.
        self.memory, self.local_memory = {}, {}

        # Nodes start active; nodes on error paths are expected to be deactivated.
        self.isActive = True

    def deactivate(self):
        """Mark this node as inactive."""
        self.isActive = False

    def get_parent_node_id(self):
        """Return the parent's node id (TreeTrajectory passes None for its root)."""
        return self._p_node_id

    def get_node_id(self):
        """Return this node's id."""
        return self._node_id

    def get_transition(self):
        """Return the ``(task_state, agent_state, trans_param)`` triple."""
        transition = (self.task_state, self.agent_state, self.trans_param)
        return transition

    def add_child(self, child):
        """Append ``child`` to this node's children list."""
        self.children.append(child)

    def set_trans_param(self, trans_param):
        """Replace this node's transition parameters."""
        self.trans_param = trans_param

    def set_memory(self, memory, local=False):
        """
            Merge the key/value pairs of ``memory`` into this node's memory.

            Args:
                memory: mapping whose ``items()`` are copied in; existing keys are overwritten.
                local: when true, write into ``local_memory`` instead of ``memory``.
        """
        target = self.local_memory if local else self.memory
        for key, value in memory.items():
            target[key] = value

    def get_memory(self, local=False):
        """Return the live ``local_memory`` dict if ``local`` else the live ``memory`` dict."""
        return self.local_memory if local else self.memory


class TreeTrajectory:
    """
        A tree of TrajectoryNode objects grown from a single initial node.

        Nodes are indexed by id in ``id2node``; each node stores the ids of its
        children, so traversals resolve child ids through ``id2node``. All
        traversals are iterative, so long trajectories cannot exceed Python's
        recursion limit.
    """

    def __init__(self, task_state, agent_state, trans_param, traj_id):
        # The root has no parent and uses the trajectory id as its node id.
        self.init_node = TrajectoryNode(task_state, agent_state, trans_param, None, traj_id)

        self.id2node = {}
        self.id2node[self.init_node.get_node_id()] = self.init_node

    def next_node(self, task_state, agent_state, trans_param, parent_traj, node_id):
        """
            Create a node for a new state change and attach it under ``parent_traj``.

            Args:
                task_state, agent_state, trans_param: transition stored on the new node.
                parent_traj: the parent TrajectoryNode object (not its id).
                node_id: id of the new node; an existing ``id2node`` entry with the
                    same id is overwritten.

            Returns:
                The newly created TrajectoryNode.
        """
        node = TrajectoryNode(task_state, agent_state, trans_param, parent_traj.get_node_id(), node_id)
        parent_traj.add_child(node_id)
        self.id2node[node_id] = node
        return node

    def _preorder(self, start):
        """
            Yield ``(node, depth)`` for the subtree under ``start`` in depth-first
            pre-order, visiting children in insertion order (``start`` has depth 0).

            Uses an explicit stack of child-id iterators instead of recursion.
            Child ids are looked up in ``id2node`` only when reached, and ``None``
            nodes are skipped.
        """
        if start is None:
            return
        yield start, 0
        stack = [iter(start.children)]
        while stack:
            for child_id in stack[-1]:
                child = self.id2node[child_id]
                if child is None:
                    continue
                yield child, len(stack)
                # Descend into this child before visiting its later siblings.
                stack.append(iter(child.children))
                break
            else:
                # Every child of the node on top of the stack has been visited.
                stack.pop()

    def find_leaves(self):
        """Return, in DFS order, the leaf nodes whose task state is not ``TaskStateValue.Solved``."""
        return [
            node
            for node, _ in self._preorder(self.init_node)
            if not node.children and node.task_state != TaskStateValue.Solved
        ]

    def find_all_leaves(self):
        """Return every leaf node (a node without children), in DFS order."""
        return [node for node, _ in self._preorder(self.init_node) if not node.children]

    def find_ancestors(self, traj_id):
        """
            Return the root-to-node path ending at the first node (in DFS pre-order)
            whose id equals ``traj_id``; both the root and that node are included.
            Returns an empty list when no reachable node has that id.
        """
        path = []
        for node, depth in self._preorder(self.init_node):
            # Entries at depth >= `depth` belong to finished branches, not ancestors.
            del path[depth:]
            path.append(node)
            if node.get_node_id() == traj_id:
                return path
        return []

    def find_paths(self, node, current_path, all_paths):
        """
            Append to ``all_paths`` one list per path from ``node`` down to each
            leaf of its subtree, each prefixed with the contents of ``current_path``.

            ``current_path`` is used as scratch space and is restored to its
            original contents before returning.
        """
        base = len(current_path)
        for cur, depth in self._preorder(node):
            # Keep the caller's prefix plus the ancestors of `cur` within the subtree.
            del current_path[base + depth:]
            current_path.append(cur)
            if not cur.children:
                all_paths.append(current_path[:])
        del current_path[base:]

    def get_all_paths(self, filter_stop=False):
        """
            Return every root-to-leaf path (a list of nodes), in DFS order.

            If ``filter_stop`` is true, keep only paths whose leaf has task state
            ``TaskStateValue.Solved``.
        """
        all_paths = []
        self.find_paths(self.init_node, [], all_paths)
        if not filter_stop:
            return all_paths
        return [path for path in all_paths if path[-1].task_state == TaskStateValue.Solved]

    def update(self, node_id, serial, trans_param=None, local=False):
        """
            Merge the mapping ``serial`` into node ``node_id``'s memory (its local
            memory when ``local`` is true). When ``trans_param`` is truthy, also
            replace the node's transition parameters.

            Raises:
                KeyError: if ``node_id`` is not in ``id2node``.
        """
        node = self.id2node[node_id]
        node.set_memory(serial, local)
        # Falsy values (None, {}, ...) leave the existing trans_param untouched.
        if trans_param:
            node.set_trans_param(trans_param)

    def get(self, node_id):
        """Return the node registered under ``node_id`` (KeyError if unknown)."""
        return self.id2node[node_id]

    def get_memory(self, node_id, local=False):
        """Return node ``node_id``'s memory dict, or its local memory when ``local`` is true."""
        return self.id2node[node_id].get_memory(local)
