import sys
import time
import copy
import pickle
import pathlib

from typing import Dict, List
from collections import OrderedDict, Counter
from collections.abc import Iterable
from types import SimpleNamespace as SN

# for the path and for the analysis code
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))

import numpy as np

from utils.hash_tools import HASH_TOOLS
from utils.memory_utils import hash_obs, custom_round


# Hash scheme handed to the SC2 hash tool (other schemes mentioned here: 'simple_3', simhash).
SC2_HASH_SCHEME = 'simple_4'

# Per-map 'info' passed to the SC2 hash tool together with SC2_HASH_SCHEME; looked up
# with ``args.env_args['map_name']`` when building the ``infos`` dicts for hashing.
# NOTE(review): the meaning of each (x, y) pair is defined by utils.hash_tools — confirm there.
INFOs = dict([
    ('2m_vs_1z', [(4, 1), (1, 6), (1, 5), (1, 1)]),
    ('3s_vs_3z', [(4, 1), (3, 6), (2, 6), (2, 1)]),
    ('3s_vs_4z', [(4, 1), (4, 6), (2, 6), (2, 1)]),
])



class GraphNode:
    """One node of a TrajGraph: a (state, action) key observed at a fixed timestep.

    The node keeps links to its neighbours (``next`` / ``prev``), the number of
    trajectories that passed through it (``visit_count``) and the sum of the
    rewards recorded for it (``accumated_rewards``).
    """

    def __init__(self, key, value, index_of_node_list, traj_len):
        """
        Args:
            key: ``(s, a)`` pair identifying the node; ``s`` is an array or,
                for SC2 envs, a hash of it.
            value: object with ``timestep`` and ``reward`` attributes
                (a SimpleNamespace in this file).
            index_of_node_list: position of this node in its timestep's node list.
            traj_len: length of the graph the node belongs to.
                FIXME traj_len is a bit tricky. For varying trajectories, be careful.
        """
        self._key = key
        self._timestep = value.timestep
        self._value = value
        self._next = list()
        self._prev = list()
        self._visit_count = 1
        self._traj_len = traj_len
        self._accumated_rewards = value.reward
        self.index_of_node_list = index_of_node_list

    @property
    def timestep(self):
        return self._timestep

    @property
    def traj_len(self):
        return self._traj_len

    @property
    def key(self):
        return self._key

    @property
    def value(self):
        return self._value

    @property
    def next(self):
        """Successor nodes (timestep + 1)."""
        return self._next

    def append_next(self, next_node):
        self._next.append(next_node)

    @property
    def prev(self):
        """Predecessor nodes (timestep - 1); contains ``None`` for a root node."""
        return self._prev

    def append_prev(self, prev_node):
        self._prev.append(prev_node)

    @property
    def visit_count(self) -> int:
        return self._visit_count

    @visit_count.setter
    def visit_count(self, visit_count: int):
        self._visit_count = visit_count

    def increase_visit_count(self, default=1):
        self._visit_count += default

    def update_accumated_rewards(self, reward):
        self._accumated_rewards += reward

    @property
    def accumated_rewards(self):
        """Sum of the rewards recorded for every visit of this node."""
        return self._accumated_rewards

    @property
    def averaged_rewards(self):
        """Mean reward per visit."""
        return self.accumated_rewards / self.visit_count

    @staticmethod
    def compare_two_nodes(n1, n2, args):
        """Return True if ``n1`` and ``n2`` hold the same (state, action) key.

        Can be overridden; new metrics are welcome.

        For SC2 envs the state part of the key is a hash, compared with ``==``.
        Otherwise states (and actions, when iterable) are compared with
        ``np.array_equal``, which also requires equal shapes.  The former
        ``all(np.equal(...))`` broadcast mismatching shapes (``[1]`` vs
        ``[1, 1, 1]`` compared equal) or raised ``ValueError`` on
        incompatible ones.
        """
        key1, key2 = n1.key, n2.key
        if args.env.startswith('sc2'):
            return (key1[0] == key2[0]) and (key1[1] == key2[1])

        if not np.array_equal(key1[0], key2[0]):
            return False
        # TODO check for hash values
        if isinstance(key1[1], Iterable):
            return np.array_equal(key1[1], key2[1])
        return key1[1] == key2[1]


class TrajGraph:
    """
    Build a graph for associative memory to find the key events.

    Nodes are shared per timestep by their (state, action) key, so the graph is
    a DAG whose edges always go from timestep t to t + 1.

    1. use while training.
    2. classify trajectories given the length of each trajectory.
    3. build the graph from t=0 to t=T-1.
    4. while building the graph, count the visit number.
    5. next searching, when a new trajectory comes,
       find the corresponding ending node, and traverse backward,
       while traversing, the visit count is monotonically increasing and
       find all nodes that ends such monotonicity.
    6. node utilization strategy: use the right nodes to do
       (i) reward swapping,
       (ii) pulling back the future return estimation to the present.
    7. the strategy: find the right nodes.
    8. if the number of some type (with high return) of trajectory is sparse,
       conducting imitation learning to explore, aka exploring while learning.
    """
    def __init__(self, args, trajectory, node_key=('state', 'action'), traj_length=1, hash_tool=None) -> None:
        """
        Build the graph from a first trajectory.

        Args:
            args: namespace; reads ``use_return``, ``env``, ``episode_limit``
                (only when ``use_return``) and ``env_args['map_name']`` (SC2 only).
            trajectory: mapping with per-step sequences under ``node_key[0]``,
                ``node_key[1]``, ``'reward'`` and ``'obs'``.
            node_key: ('state', 'action') or ('obs', 'action').
            traj_length: length of the trajectories stored in this graph.
            hash_tool: callable ``hash_tool(s, infos=...)`` used for SC2 observations.
        """
        self.args = args
        self._traj_length = traj_length
        self.node_key = node_key
        # with use_return, graphs are keyed by return and hold trajectories of
        # different lengths, so the timestep tables are sized for the longest episode
        self.max_traj_len = args.episode_limit if args.use_return else self._traj_length
        self.hash_tool = hash_tool
        self._create_graph(trajectory, self.max_traj_len)

    def _node_keys(self, s, a):
        """Return ``(node_key, index_key, stored_obs)`` for state ``s`` and action ``a``.

        node_key: key kept on the GraphNode (used by ``GraphNode.compare_two_nodes``).
        index_key: hashable key for ``_time_nodes_index`` / ``_time_nodes_index_obs``.
        stored_obs: observation recorded in ``_time_nodes_index_obs`` (shown in
            error reports when a lookup fails).

        For SC2 the observation is hashed once and both keys are the same
        ``(hash, action)`` pair; otherwise the node keeps the raw array while the
        index uses ``hash_obs`` of the int-cast bytes.
        """
        if self.args.env.startswith('sc2'):
            infos = {'scheme': SC2_HASH_SCHEME, 'info': INFOs[self.args.env_args['map_name']]}
            hashed_key = (self.hash_tool(s, infos=infos), int(a))
            return hashed_key, hashed_key, s
        obs_int = s.astype(int)
        return (s, int(a)), (hash_obs(obs_int.tobytes()), int(a)), obs_int

    def _create_graph(self, traj, traj_length):
        """Initialise the per-timestep tables and insert ``traj`` as one chain of nodes."""
        # 1. time-dependent tables:
        #    timestep -> list of nodes
        self._time_nodes = OrderedDict({t: list() for t in range(traj_length)})
        #    timestep -> {index_key: position in the node list}
        self._time_nodes_index = OrderedDict({t: OrderedDict() for t in range(traj_length)})
        #    timestep -> {index_key: observation of that node}
        self._time_nodes_index_obs = OrderedDict({t: OrderedDict() for t in range(traj_length)})

        if not self.args.use_return:
            assert traj_length == np.array(traj[self.node_key[0]]).shape[0], 'length not equal'

        # 2. loop over the trajectory and chain the nodes
        # NOTE(review): iterates len(traj['obs']) even when node_key[0] is 'state';
        #               assumes both sequences have the same length.
        prev_node = None
        for t in range(len(traj['obs'])):
            s, a, r = traj[self.node_key[0]][t], traj[self.node_key[1]][t], traj['reward'][t]
            s = np.array(s)
            node_key, index_key, stored_obs = self._node_keys(s, a)

            curr_node = GraphNode(key=node_key,
                                  value=SN(**{'reward': r, 'timestep': t}),
                                  index_of_node_list=len(self._time_nodes[t]),
                                  traj_len=traj_length)
            self._time_nodes[t].append(curr_node)
            if prev_node is not None:
                prev_node.append_next(curr_node)
            curr_node.append_prev(prev_node)  # None for the root at t = 0
            prev_node = curr_node

            # 2.1 index every level's node by (hashed state, action)
            assert index_key not in self._time_nodes_index[t], "duplicated leaf node"
            self._time_nodes_index[t][index_key] = curr_node.index_of_node_list
            self._time_nodes_index_obs[t][index_key] = stored_obs

    def append_traj(self, trajectory):
        """
        Merge a new trajectory into the current graph.

        A step whose (state, action) already exists at its timestep reuses that
        node (visit count + 1, reward accumulated); otherwise a new node is
        created.  Consecutive steps are always linked, unless that edge already
        exists.

        Returns:
            (new_traj_indices, new_traj_indices_mask): for every timestep, the
            node's position in that timestep's node list, and whether the node
            was newly created (True) or reused (False).
        """
        if not self.args.use_return:
            assert self._traj_length == np.array(trajectory[self.node_key[0]]).shape[0], 'length not equal'

        prev_node = None
        new_traj_indices = []
        new_traj_indices_mask = []
        for t in range(len(trajectory['obs'])):
            s, a, r = trajectory[self.node_key[0]][t], trajectory[self.node_key[1]][t], trajectory['reward'][t]
            s = np.array(s)
            node_key, index_key, stored_obs = self._node_keys(s, a)
            temp_node = GraphNode(key=node_key,
                                  value=SN(**{'reward': r, 'timestep': t}),
                                  index_of_node_list=len(self._time_nodes[t]),
                                  traj_len=self.max_traj_len)

            if index_key in self._time_nodes_index[t]:
                # 1. known node: update it
                node = self._time_nodes[t][self._time_nodes_index[t][index_key]]
                assert GraphNode.compare_two_nodes(node, temp_node, self.args), "node not equal"
                node.increase_visit_count(1)
                node.update_accumated_rewards(r)
                new_traj_indices.append(node.index_of_node_list)
                new_traj_indices_mask.append(False)
                # Link unless the edge already exists.  Nodes are shared by
                # (s, a, t) rather than by prefix, so two known nodes need not be
                # connected yet; only linking after a freshly created node would
                # drop such transitions.
                if prev_node is not None and not any(n is node for n in prev_node.next):
                    prev_node.append_next(node)
                    node.append_prev(prev_node)
                prev_node = node
            else:
                # 2. unseen node: add it to timestep t
                curr_node = temp_node
                self._time_nodes[t].append(curr_node)
                new_traj_indices.append(curr_node.index_of_node_list)
                new_traj_indices_mask.append(True)
                if prev_node is not None:
                    prev_node.append_next(curr_node)
                curr_node.append_prev(prev_node)  # None for the root at t = 0
                prev_node = curr_node

                # index_key is absent here (checked by the branch condition)
                self._time_nodes_index[t][index_key] = curr_node.index_of_node_list
                self._time_nodes_index_obs[t][index_key] = stored_obs
        return new_traj_indices, new_traj_indices_mask

    @staticmethod
    def visit_count_check(graph):
        """
        Given a graph, check all nodes:
            if the total visit count of the current node 
            equals to the sum of visit counts of its next nodes
        """
        check_results = []
        for timestep, node_list in graph.timestep_nodes.items():
            for node in node_list:
                if all([n is None for n in node.next]):
                    continue
                if node.visit_count != sum([n.visit_count for n in node.next if n is not None]):
                    # f"node.visit_count != sum_visit_count(node.next). " \
                    # f"graph info: traj_length: {graph.traj_length} timestep: {timestep}, " \
                    # f"index_of_node_list: node.index_of_node_list."
                    check_results.append({
                        'node': node,
                        'graph': graph
                    })
        return check_results

    def __len__(self):
        """
        return the length of the graph, aka the length of the timestep
        """
        return self._traj_length

    @property
    def traj_length(self):
        return self._traj_length

    @property
    def timestep_nodes(self):
        return self._time_nodes


class TrajMemory:
    """
    The Trajecotory Memory to store the the trajecotories
    """
    def __init__(self, args, hash_tool=None) -> None:
        
        self.hash_tool = hash_tool

        if isinstance(args, dict):
            if 'debug' not in args:
                args['debug'] = False
            if 'filter_strategy' not in args:
                args['filter_strategy'] = 'strict'
            if 'env' not in args:
                args['env'] = 'default'
            if 'trie_search_max_traj_num' not in args:
                args['trie_search_max_traj_num'] = 30
            if 'round_scheme' not in args:
                args['round_scheme'] = 'ceil'  # do not use return by default

            args = SN(**args)
        else:  # SimpleNamespace
            if 'debug' not in vars(args):
                args.debug = False
            if 'filter_strategy' not in vars(args):
                args.filter_strategy = 'strict'
            if 'env' not in vars(args):
                args.env = 'default'
            if 'trie_search_max_traj_num' not in vars(args):
                args.trie_search_max_traj_num = 30
            if 'round_scheme' not in vars(args):
                args.round_scheme = 'ceil'  # do not use return by default
        
        self.args = args
        self.graph_dict = {}  # all return types share a graph
        self.graph_return_dict = {}  # length: return_val: graph
        self.graph_return_raw_dict = {}  # length: return_val: graph
        self.graph_summary = {}  # summarize the results length: return_val: time_step: counts: total number

    def update_memory(self, trajectories: List, info: Dict) -> None:
        """
        use trajectories (a batch) to build the graph.
        info is used to help building the graph.
        """
        # NOTE trajectories is a list of trajectories, the default value is 1
        # do not contain any DL-related calculation, use numpy instead
        length = info['length']  # shape: [bs]
        # TODO be careful on the float numbers, use str instead
        retruns = info['return']  # shape: [bs], return is the sum of total rewards of trajectory
        retruns_raw = info['return_raw']  # shape: [bs], return is the sum of total rewards of trajectory
        for idx in range(len(length)):
            # use mask to get the length of trajectories
            self._build_graph(trajectories[idx], length[idx], retruns[idx], retruns_raw[idx])

    def _build_seperate_return_graph(self, trajectory: List, length: int, return_val: str, return_val_raw: float):
        """Add ``trajectory`` to the two-level index ``graph_return_dict``.

        The outer key is ``return_val`` when ``args.use_return`` is set and
        ``length`` otherwise; the inner key is the other value.  Each entry is a
        TrajGraph that is created on first use and extended afterwards;
        ``graph_return_raw_dict`` mirrors the index and collects ``return_val_raw``.
        With ``use_return`` the graph is checked to have exactly ``length``
        non-empty timesteps.
        """
        if self.args.use_return:
            outer, inner = return_val, length
        else:
            outer, inner = length, return_val

        if outer not in self.graph_return_dict:
            self.graph_return_dict[outer] = {}
            self.graph_return_raw_dict[outer] = {}

        graphs = self.graph_return_dict[outer]
        if inner in graphs:
            graphs[inner].append_traj(trajectory)
            self.graph_return_raw_dict[outer][inner].append(return_val_raw)
        else:
            graphs[inner] = TrajGraph(self.args, trajectory, self.args.node_key, length, hash_tool=self.hash_tool)
            self.graph_return_raw_dict[outer][inner] = [return_val_raw]

        # post checking the trajectories: here inner == length
        if self.args.use_return:
            filled_steps = sum(1 for nodes in graphs[length]._time_nodes.values() if nodes)
            assert filled_steps == length, "length not match"

    def _build_graph(self, trajectory: List, length: int, return_val: str, return_val_raw: float) -> None:
        """Insert one trajectory into ``graph_dict`` and the two-level return index.

        ``graph_dict`` holds one TrajGraph per ``return_val`` when
        ``args.use_return`` is set and per ``length`` otherwise;
        ``_build_seperate_return_graph`` keeps ``graph_return_dict`` /
        ``graph_return_raw_dict`` in sync.
        """
        key = return_val if self.args.use_return else length

        if key in self.graph_dict:
            if not self.args.use_return:
                assert key == len(self.graph_dict[key]), 'length is not equal'
            # append a new trajectory to the current graph
            self.graph_dict[key].append_traj(trajectory)

            # add the graph data into self.graph_return_dict
            self._build_seperate_return_graph(trajectory, length, return_val, return_val_raw)

            # NOTE debug: visit-count consistency check for the length-10 graph.
            # A deepcopy of the graph used to be taken before appending for a
            # (commented-out) comparison; it was never read, and deep-copying the
            # linked next/prev node chains recurses per edge, so it could exceed
            # the recursion limit on long trajectories.
            if self.args.debug and not self.args.use_return and key == 10:
                if TrajGraph.visit_count_check(self.graph_dict[key]):
                    print('after appending, the trajectory changed')
        else:
            self.graph_dict[key] = TrajGraph(self.args, trajectory, self.args.node_key, length, hash_tool=self.hash_tool)
            # start a fresh two-level index for this key, then add the graph data
            self.graph_return_dict[key] = {}
            self.graph_return_raw_dict[key] = {}
            self._build_seperate_return_graph(trajectory, length, return_val, return_val_raw)

    def update_summary(self):
        """
        NOTE: do not call this function frequently

        update the self.graph_summary = {
            length_val: {
                return_val: {
                    timestep_val: {
                        pick_timestep: {
                            count_val: value,
                            total_val: value,
                        }
                    }
                }
            }
        }  
        """
        # TODO 
        # check if self.graph_summary is empty
        # 1. loop the length
        # 2.    loop the return_val
        # 3.        search each trajectory
        for length, data in self.graph_return_dict.items():
            if length not in self.graph_summary:
                self.graph_summary[length] = {}
            
            for retunr_val, graph in data.items():
                if retunr_val not in self.graph_summary[length]:
                    self.graph_summary[length][retunr_val] = {}
                results = {}
                for leaf_node in graph.timestep_nodes[list(graph.timestep_nodes.keys())[-1]]:
                    pass

    def search(self, one_traj, traj_timesteps=None) -> List:
        """
        search from leaf node, find trajectories
        1. which have monotonically increasing visit_count of nodes
        2. if there is a another seperate increasing visit_count of
           sub-trajectory after the node which has the max visit count
           of the previous sub-trajectory
        3. sort the trajectories given the visit count
        """
        if traj_timesteps is not None:
            assert self.args.env.startswith('sc2'), 'only support sc2 envs when traj_timesteps is not None:'

        stats_info = {}  # for debug, store the stats info, e.g. time cost of searching
        # 0.1 get the return of the trajectory
        return_val = custom_round(np.sum(one_traj['reward']), round_scheme=self.args.round_scheme) \
            if self.args.env.startswith('sc2') else f"{sum(one_traj['reward']):.2f}"
        # 0.2 get the length of the trajectory
        # one_traj['state'], shape: [1 x max_timestep_len]
        length = np.array(one_traj[self.args.node_key[0]]).shape[0]
        # 1. find the leaf nodes
        st = time.time()
        leaf_node = self.get_leaf_node(one_traj, length, return_val)  # the leaf node is indentical
        stats_info['timecost_get_leaf_node'] = time.time() - st
        # 2. search the trajectories
        st = time.time()
        # NOTE search from the leaf node
        #      so the first node of the trajectory is the leaf node
        trajectories = self.search_trajectories(self, leaf_node, length)
        stats_info['timecost_search_trajectories'] = time.time() - st
        # 3. find the trajectories satifying the given condition
        st = time.time()
        if self.args.env.startswith('sc2'):
            # By using quick filtering, parallelization can be implemented
            filtering_results = self._quick_filtering(trajectories, traj_timesteps=traj_timesteps, curr_traj=one_traj)
        else:
            filtering_results = self.filtering_trajectories(trajectories, traj_timesteps=traj_timesteps)
            # filtering_results = self._quick_filtering(trajectories, traj_timesteps=traj_timesteps, curr_traj=one_traj)
        stats_info['timecost_filtering_trajectories'] = time.time() - st
        return filtering_results, stats_info

    def get_leaf_node(self, one_traj, length, return_val='0.1'):
        # NOTE s is obs, not state
        s, a, r = one_traj[self.args.node_key[0]][-1], one_traj[self.args.node_key[1]][-1], one_traj['reward'][-1]
        s = np.array(s)
        assert isinstance(s, np.ndarray), 'state (obs) is not ndarray'

        if self.args.use_return:
            key, val = return_val, length            
        else:
            key, val = length, return_val

        # get the corresponding graph
        try:
            # NOTE caveat: if retrun_val is wrongly calculated but accidnetly can return a graph
            #              it will raise the exception in
            #              finding the _index = graph._time_nodes_index[last_timestep][_temp_key]
            graph = self.graph_return_dict[key][val]
        except KeyError:
            if not self.args.use_return:
                raise KeyError(f'return_val {return_val} not in graph_return_dict[length] (length: {length}), \
                available return_val: {self.graph_return_dict[length].keys()},\
                one_traj["reward"]: {one_traj["reward"]},\
                all keys for all length: {[f"{key}: {list(val.keys())}" for key, val in sorted(self.graph_return_dict.items(), key=lambda x: x[0])]}, \
                print all return_raw: {[(length, val_dict) for length, val_dict in sorted(self.graph_return_raw_dict.items(), key=lambda x: x[0])]}')
                # TODO add function to avoid this issue
            else:
                raise KeyError(f'length {length} not in graph_return_dict[return_val] (return_val: {return_val}), \
                available length: {self.graph_return_dict[return_val].keys()},\
                one_traj["reward"]: {one_traj["reward"]},\
                all keys for all return_val: {[f"{key}: {list(val.keys())}" for key, val in sorted(self.graph_return_dict.items(), key=lambda x: x[0])]}, \
                print all return_raw: {[(return_val, val_dict) for return_val, val_dict in sorted(self.graph_return_raw_dict.items(), key=lambda x: x[0])]}')
                # TODO add function to avoid this issue

        last_timestep = sum([int(len(dd)!=0) for _, dd in graph.timestep_nodes.items()]) - 1 \
            if self.args.use_return else list(graph.timestep_nodes.keys())[-1]

        if self.args.env.startswith('sc2'):
            infos = {'scheme': SC2_HASH_SCHEME, 'info': INFOs[self.args.env_args['map_name']]}
            node_key =(self.hash_tool(s, infos=infos), int(a))
            _temp_key =(self.hash_tool(s, infos=infos), int(a))
        else:
            node_key = (s, a)
            _temp_key = (hash_obs(s.astype(int).tobytes()), int(a))
        temp_node = GraphNode(node_key, SN(**{'reward': r, 'timestep': length-1}), -1, length)
        leaf_node, leaf_nodes = None, []

        # find the leaf node by using the node_key
        # NOTE s and a should be discrete, 
        #      else the search will be not right, since the node_key may not right
        try:
            _index = graph._time_nodes_index[last_timestep][_temp_key]
        except KeyError:
            _keys = graph._time_nodes_index[last_timestep].keys()
            _keys_obs = {k: graph._time_nodes_index_obs[last_timestep][k] for k in _keys}
            if not self.args.use_return:
                raise KeyError(f'key {_temp_key} not in graph._time_nodes_index[last_timestep]. \
                    current traj length: {length} \
                    last_timestep: {last_timestep}. \n s: {s}, a: {a}. \
                    Available keys: {graph._time_nodes_index[last_timestep].keys()},\
                    show all keys_obs: {_keys_obs} \
                    --------------------------------------------- \
                    *** show more usefull info *** \
                    return_val {return_val} not in graph_return_dict[length] (length: {length}), \
                    available return_val: {self.graph_return_dict[length].keys()},\
                    one_traj["reward"]: {one_traj["reward"]},\
                    all keys for all length: {[f"{k}: {list(v.keys())}" for k, v in sorted(self.graph_return_dict.items(), key=lambda x: x[0])]}, \
                    print all return_raw: {[(length, val_dict) for length, val_dict in sorted(self.graph_return_raw_dict.items(), key=lambda x: x[0])]}')
            else:
                # print('hello')
                # return_val_bak = custom_round(np.sum(one_traj['reward']), round_scheme=self.args.round_scheme) \
                #     if self.args.env.startswith('sc2') else f"{sum(one_traj['reward']):.2f}"
                raise KeyError(f'key {_temp_key} not in graph._time_nodes_index[last_timestep]. \
                    current traj length: {length} \
                    last_timestep: {last_timestep}. \n s: {s}, a: {a}. \
                    Available keys: {graph._time_nodes_index[last_timestep].keys()},\
                    show all keys_obs: {_keys_obs}')

        leaf_node = graph.timestep_nodes[last_timestep][_index]
        assert GraphNode.compare_two_nodes(leaf_node, temp_node, self.args), \
            "leaf node is not the same, please double check!!!!"

        # # TODO replace it with the above code (better hash with md5)
        # for node in graph.timestep_nodes[last_timestep]:
        #     if GraphNode.compare_two_nodes(node, temp_node, self.args):
        #         leaf_node = node
        #         leaf_nodes.append(node)
        if leaf_node is None:
            # make sure the leaf node is in the graph
            raise KeyError(f'Node (key=({temp_node.key})) not found in leaf list in the graph')
        # assert len(leaf_nodes) == 1, 'found multiple leaf nodes, which is illegal!'
        return leaf_node

    @staticmethod
    def search_trajectories(this, leaf_node, length):
        chosen_trajectories = []
        # traverse. DFS recursive version
        def _dfs(this, node, path, results, length):
            # FIXME maybe there are cilcles in the graph
            path.append(node)
            if all([n is None for n in node.prev]):  # node is root node
                results.append(path[:])  # copy the path
                # FIXME be carefully on varying length graph
                assert len(results[-1]) == length, \
                    "It seems len(results[-1]) != length"
                if len(results) == this.args.trie_search_max_traj_num:  # early stop
                    return
            else:
                for n in node.prev:  # FIXME make sure there is no circle
                    if n is not None:
                        _dfs(this, n, path, results, length)
                        if len(results) == this.args.trie_search_max_traj_num:  # early stop
                            break
            path.pop()
            # early stop for all environments
            if len(results) == this.args.trie_search_max_traj_num: return
        # search the tree and find the all the trajectories
        _dfs(this, leaf_node, list(), chosen_trajectories, length)
        return chosen_trajectories

    def _quick_filtering(self, trajectories, traj_timesteps=None, curr_traj=None):
        res = OrderedDict()
        # for each traj_timesteps, calculate the merged nodes and find the best one

        # for sc2, actions should be zero
        _traj_timesteps = [1] if (traj_timesteps is None or len(traj_timesteps) == 0) else traj_timesteps
        for timestep in _traj_timesteps:
            res[f'pivot_timestep_{timestep}'] = OrderedDict({
                'time_step':  None,
                'data': OrderedDict(),
                'node_visit_count': OrderedDict(),
                'action': curr_traj['action'],
                'right_time_step':  len(trajectories[0])-timestep-(self.args.max_delay-1),
                'zero_action_timesteps': set(),
            })
            # NOTE: t = 0 of the trajectory is the leafnode
            for i, traj in  enumerate(trajectories):
                for t, node in enumerate(reversed(traj[timestep:])):  # from leaf+1 to starting node
                    if res[f'pivot_timestep_{timestep}']['action'][t+timestep-1] == 0:
                        res[f'pivot_timestep_{timestep}']['zero_action_timesteps'].add(t+timestep-1)
                    # 1. find if the node is in the dict
                    if t not in res[f'pivot_timestep_{timestep}']['data']:
                        res[f'pivot_timestep_{timestep}']['data'][t] = OrderedDict({node: 1})
                    else:
                        # find if the node is in the dict
                        for _node in res[f'pivot_timestep_{timestep}']['data'][t].keys():
                            if node == _node:  # same node, same instance
                                res[f'pivot_timestep_{timestep}']['data'][t][_node] += 1
                                break
                        else:
                            res[f'pivot_timestep_{timestep}']['data'][t][node] = 1
            
            for t, node_count_dict in res[f'pivot_timestep_{timestep}']['data'].items():
                if t not in res[f'pivot_timestep_{timestep}']['node_visit_count']:
                    res[f'pivot_timestep_{timestep}']['node_visit_count'][t] = OrderedDict()
                for node, _ in node_count_dict.items():
                    res[f'pivot_timestep_{timestep}']['node_visit_count'][t][node] = node.visit_count

            # 2. find the best timestep
            # 2.1 sort the dict first
            counter = OrderedDict()
            for t, node_dict in res[f'pivot_timestep_{timestep}']['node_visit_count'].items():
                if t in res[f'pivot_timestep_{timestep}']['zero_action_timesteps']:
                    continue
                # how many duplicated nodes, 1 is the best; len(trajectories) is the worst
                # check if all nodes have visit_count of 1, 2 or 3
                check_res = [v for v in node_dict.values() if v not in {1, 2, 3}]
                if len(check_res) != 0:
                    counter[t] = len(check_res)
            # 2.2 save the result
            _set = {len(traj[timestep:])-1, len(traj[timestep:])-2, len(traj[timestep:])-3}
            if len(counter) != 0:
                min_count = sorted(counter.items(), key=lambda v: v[1], reverse=False)[0][1]
                min_v_timesteps = list(sorted([t for t, v in counter.items() if v == min_count and t not in _set], reverse=False))
            else:
                min_v_timesteps = [len(traj[timestep:]) - 1]
            res[f'pivot_timestep_{timestep}']['min_v_timesteps'] = min_v_timesteps
            res[f'pivot_timestep_{timestep}']['counter'] = counter

            if len(min_v_timesteps) == 0:
                min_v_timesteps = [len(traj[timestep:]) - 1]

            chosen_time = min_v_timesteps[0]

            res[f'pivot_timestep_{timestep}']['time_step'] = chosen_time  # proximal to the leaf node
            res[f'pivot_timestep_{timestep}']['max_visit_count'] = \
                list(reversed(trajectories[0]))[chosen_time].visit_count  # proximal to the leaf node
            res[f'pivot_timestep_{timestep}']['weight'] = sum(res[f'pivot_timestep_{timestep}']['data'][chosen_time].values())

        summarized_res = OrderedDict({0: res })
        return summarized_res

    def filtering_trajectories(self, trajectories, traj_timesteps=None):
        """
        Filtering the trajectories and find the right timeslot
        """
        res = OrderedDict()
        # TODO use multiprocessing to speed up the process, for example using Numba
        # NOTE for traj: traj[0] is the leaft node!!! very important
        for i, traj in enumerate(trajectories):
            reversed_checked = False  # reverse checking -1

            if len(traj) == 1:
                continue
            # sequential reversed, save the data of the first node
            res[i] = OrderedDict()
            visit_counts = list(sorted(set([node.visit_count for node in traj])))
            visit_counts_table = {value: idx for idx, value in enumerate(visit_counts)}
            visit_count_mask = [visit_counts_table[node.visit_count] for node in traj]
            
            # NOTE traverse from the node of the chosen timestep to the root node
            # NOTE 1 is the first timestep, aka skip the leafnode (index 0)
            _traj_timesteps = [1] if (traj_timesteps is None or len(traj_timesteps) == 0) else traj_timesteps
            for timestep in _traj_timesteps:
                res[i][f'pivot_timestep_{timestep}'] = OrderedDict({
                    # NOTE it subjects to the environments, may not the exact timestep
                    'right_time_step':  len(traj)-timestep-(self.args.max_delay-1),
                    # 'visit_count_mask': visit_count_mask[timestep:],  # NOTE skip the exact timestep
                    'visit_count_mask': visit_count_mask,
                    'mask': [1],
                    'visit_count': [ traj[timestep-1].visit_count ],
                    'action': [_node.key[1] for _node in traj]  # save actions for this trajectory
                })
                for idx, node in enumerate(traj[timestep:]):
                    # NOTE make sure the prev_node can be correctly found
                    prev_node = traj[idx + (timestep-1)]

                    # FIXME delete it as we have already checked the code and this line takes some time
                    assert  any([n == prev_node for n in node.next]), \
                        "something wrong in traversing the trajectory???"

                    flag = self._filtering_strategy(curr_visit_count=node.visit_count,
                                                    prev_visit_count=prev_node.visit_count,
                                                    prev_flag=res[i][f'pivot_timestep_{timestep}']['mask'][-1])

                    res[i][f'pivot_timestep_{timestep}']['visit_count'].append(node.visit_count)
                    res[i][f'pivot_timestep_{timestep}']['mask'].append(flag)
            res, pre_processed = self._get_the_timesteps(res, i, traj, reversed_checked, _traj_timesteps)

        # postprocess the results
        postprocessed_res = self._post_process(res, trajectories, pre_processed, traj_timesteps) # TODO
        # summarize the results across all trajectories
        summarized_res = self._summarize(postprocessed_res)  # TODO if 0 is the best check the second one
        return summarized_res

    def _summarize(self, res):
        # summarize is the result, for each pivot timestep, select the timestep with the max count
        # there is only one trajectory
        if len(res) == 1: return res

        new_res = OrderedDict({0: OrderedDict()})
        temp_counter = OrderedDict()

        # new_res[0][f'pivot_timestep_{pivot_timestep+1}']['time_step'] = 0
        # new_res[0][f'pivot_timestep_{pivot_timestep+1}']['max_visit_count'] = common_max_visit_count

        for traj_idx, timestep_dict in res.items():
            for pivot_timestep_key, timestep_data in timestep_dict.items():
                if pivot_timestep_key not in temp_counter:
                    temp_counter[pivot_timestep_key] = [timestep_data]
                else:
                    temp_counter[pivot_timestep_key].append(timestep_data)

        # summarize the results
        for pivot_timestep_key, timestep_data_list in temp_counter.items():
            # create the counter
            counter = Counter([timestep_data['time_step'] for timestep_data in timestep_data_list])
            counter_max = counter.most_common(1)[0][0]
            counter_max_index = [timestep_data['time_step'] for timestep_data in timestep_data_list].index(counter_max)
            new_res[0][pivot_timestep_key] = OrderedDict({
                'time_step': counter_max,
                'max_visit_count': timestep_data_list[counter_max_index]['max_visit_count']
            })
        return new_res

    @staticmethod
    def _reverse_check_decrease(timestep, traj_idx, res):
        # reversed order 0, 0, 1, 1, 2, 3
        # reversed order 2, 0, 1, 2, 3, 4
        # NOTE the reverse here should be careful
        mask = reversed(res[traj_idx][f'pivot_timestep_{timestep}']['mask'])
        mask = list(mask)[:-timestep]
        res_idx = []
        idx = 1 # starts from 1
        while idx < len(mask):
            if mask[idx] == 1 and mask[idx-1] == 1:
                idx += 1
            elif mask[idx] == -1 and (mask[idx-1] == 1 or mask[idx-1] == 0):
                idx += 1
            elif mask[idx] == -1 and (mask[idx-1] == -1 or mask[idx-1] == 0):
                idx += 1
            elif (mask[idx] == 1 or mask[idx] == 0) and mask[idx-1] == -1 and (idx - 2 >=0 and mask[idx-2] == -1):
                # check its next step
                if idx + 1 < len(mask) and (mask[idx+1] == 1 or mask[idx+1] == 0):
                    # check there are two -1 -1 before 1
                    res_idx.append(idx)
                    break
                else:
                    idx += 1  # it mean -1 0 -1 or -1 1 -1, skip this step
            else:
                idx += 1

        return_info = {
            'reversed_checked': True,
            'res_idx': res_idx,
            'pre_processed': False
        }
        if len(res_idx) == 0:
            return_info['the_index'] = -1
            return_info['max_visit_count'] = -1
            max_visit_count = -1
        else:
            return_info['pre_processed'] = True
            # NOTE should add timestep
            the_index = res_idx[-1]
            max_visit_count = max(res[traj_idx][f'pivot_timestep_{timestep}']['visit_count'][:the_index+1-timestep])
            
            return_info['the_index'] = the_index
            return_info['max_visit_count'] = max_visit_count
        return return_info

    @staticmethod
    def _reverse_check(timestep, traj_idx, res):
        # reversed order 0, 0, 1, 1, 2, 3
        # reversed order 2, 0, 1, 2, 3, 4
        # NOTE the reverse here should be careful
        _visit_count_mask = reversed(res[traj_idx][f'pivot_timestep_{timestep}']['visit_count_mask'])
        visit_count_mask = list(_visit_count_mask)[:-timestep]
        res_idx = []
        left, right = 0, 1
        while right < len(visit_count_mask):
            if visit_count_mask[right] == visit_count_mask[left]:
                right += 1
            elif visit_count_mask[right] > visit_count_mask[left]:
                res_idx.append(right)   # NOTE no need to add timestep as it is reversed
                left = right
                right += 1
            elif visit_count_mask[right] < visit_count_mask[left]:
                # if it starts decrease, stop it
                    break

        return_info = {
            'reversed_checked': True,
            'res_idx': res_idx,
            'pre_processed': False
        }
        if len(res_idx) == 0:
            return_info['the_index'] = -1
            return_info['max_visit_count'] = -1
            max_visit_count = -1
        else:
            return_info['pre_processed'] = True
            # NOTE should add timestep
            the_index = (timestep + len(res[traj_idx][f'pivot_timestep_{timestep}']['mask'])) - 1 - res_idx[-1]
            max_visit_count = max(res[traj_idx][f'pivot_timestep_{timestep}']['visit_count'][:the_index+1-timestep])
            
            return_info['the_index'] = the_index
            return_info['max_visit_count'] = max_visit_count
        return return_info

    def _get_the_timesteps(self, res, traj_idx, traj, reversed_checked, traj_timesteps=None):
        """
        Pick the chosen timestep of one trajectory for every pivot timestep.

        For each ``timestep`` in ``traj_timesteps``, this scans the ``mask``,
        ``visit_count`` and ``action`` lists stored in
        ``res[traj_idx][f'pivot_timestep_{timestep}']``. Those lists are built
        by ``filtering_trajectories``. Two keys are written back there:

        - ``'time_step'``: ``timestep + len(mask) - 1 - the_index``, or 0 when
          every node of ``traj`` has the same visit count and it is > 3.
        - ``'max_visit_count'``.

        Index convention: ``mask[k]`` / ``visit_count[k]`` describe
        ``traj[k + timestep - 1]``, so ``action[idx + timestep - 1]`` is the
        action of the node that ``mask[idx]`` describes.

        Args:
            res: results dict; mutated in place and also returned.
            traj_idx: index of the current trajectory in ``res``.
            traj: the current trajectory (leaf node first).
            reversed_checked: whether ``_reverse_check`` has already been run.
                Once it runs inside this method, the flag stays True for the
                remaining pivot timesteps.
            traj_timesteps: the pivot timesteps to process (for sc2, the
                timesteps when the reward is positive). It is iterated
                directly, so the default ``None`` would raise ``TypeError``.

        Returns:
            ``(res, pre_processed)``. ``pre_processed`` is never set to True in
            this method.
        """
        # select the first non-decreasing monotonic sub-trajectory
        # then return the time step when the first maxstep occurs
        # slicing, get the index
        # update the res with max_visit_count and time_step

        # NOTE strict strategy is preferred
        # strict strategy:
        #   inputs:  1  2  3  9  9  8  7  6  5  5  8  9  9
        #   flags:   1  1  1  1  0 -1 -1 -1 -1  0  1  1  0
        #   choose:  -------- 1 --------------------------
        # normal strategy (equal counts repeat the previous flag):
        #   inputs:  1  2  3  9  9  8  7  6  5  5  8  9  9
        #   flags:   1  1  1  1  1 -1 -1 -1 -1 -1  1  1  1
        #   choose:  ----------- 1 -----------------------
        pre_processed = False  # processed with heuristic method
        metric = self._filtering_metric()
        revsered_results = OrderedDict()
        for timestep in traj_timesteps:
            # add reversed results
            revsered_results[timestep] = OrderedDict()
            # get the -1 set
            continous_flag = False
            minus_set = set(res[traj_idx][f'pivot_timestep_{timestep}']['mask'][2:])
            counts = Counter(res[traj_idx][f'pivot_timestep_{timestep}']['mask'][2:])
            agent_dead = False  # TODO early stop, if there is no 0 in the mask[2:]
            # The scan below only runs when mask[2:] contains more than one 0.
            # Otherwise it iterates an empty list and falls straight into the
            # `else:` clause (which also runs whenever the scan ends without `break`).
            for idx, m in (enumerate(res[traj_idx][f'pivot_timestep_{timestep}']['mask']) if counts.get(0, 0) > 1 else []): # NOTE it seems there is not enough 0
                # NOTE agent dies (action 0 is treated as "dead")
                if res[traj_idx][f'pivot_timestep_{timestep}']['action'][idx+timestep-1] == 0:
                    agent_dead = True
                    continue

                # NOTE if agent dead and the next action is not 0, then select that idx
                # (the `!= 0` test is always true here because of the `continue` above)
                if agent_dead and res[traj_idx][f'pivot_timestep_{timestep}']['action'][idx+timestep-1] != 0:
                    # get the index
                    the_index = idx + timestep - 1  # TODO delete it, it is not right as no pattern
                    if the_index < 0:
                        the_index = 0
                    max_visit_count = max(res[traj_idx][f'pivot_timestep_{timestep}']['visit_count'][:the_index+1-timestep])
                    break

                if idx == 0: continue  # skip the leaf node

                if idx == 1 and (m == metric or m == -1):  # HACK skip some instances
                    continue
                
                # HACK skip nodes with visit count 1, 2 and 3
                if res[traj_idx][f'pivot_timestep_{timestep}']['visit_count'][idx] in set({1, 2, 3}):
                    continue

                # when m == -1 check it in the reversed order with two pointers
                # (runs at most once per call, because `reversed_checked` becomes True)
                if not reversed_checked and -1 in minus_set:
                    return_info = self._reverse_check(timestep, traj_idx, res)  # find increasing pattern from t=0
                    reversed_checked = return_info['reversed_checked']
                    res_idx = return_info['res_idx']
                    if len(res_idx) != 0:
                        # do not use continue, instead let it go to the next subsection
                        the_index = return_info['the_index']  # revserd index
                        max_visit_count = return_info['max_visit_count']

                        # pre convert
                        # NOTE(review): `_reverse_check` already maps its index with this same
                        # formula, so applying it again maps the value back; confirm which
                        # frame `revsered_results[...]['the_index']` is meant to be in.
                        the_index = timestep + len(res[traj_idx][f'pivot_timestep_{timestep}']['mask']) - 1 - the_index
                        # save and compare the maxinum the_index
                        revsered_results[timestep]['the_index'] = the_index
                        revsered_results[timestep]['max_visit_count'] = max_visit_count

                # add new strategy find the continous 1 (increasing) and then continous -1 (at least 2)
                # Condition: mask[2:idx+1] is all 1 (vacuously true for idx == 1) and mask[idx+1] == -1.
                # NOTE when this outer condition holds but the -3 check below fails, the
                # `elif` branches further down are NOT tried for this idx.
                if sum(res[traj_idx][f'pivot_timestep_{timestep}']['mask'][2:idx+1]) == idx-2+1 and \
                    idx + 1 < len(res[traj_idx][f'pivot_timestep_{timestep}']['mask']) and \
                    res[traj_idx][f'pivot_timestep_{timestep}']['mask'][idx+1] == -1:
                    # check that the next three flags are all -1 (fewer than 3 left never sums to -3)
                    if sum(res[traj_idx][f'pivot_timestep_{timestep}']['mask'][idx+1:idx+1+3]) == -3:
                        the_index = idx
                        # slice includes visit_count[idx]
                        max_visit_count = max(res[traj_idx][f'pivot_timestep_{timestep}']['visit_count'][:the_index+1])
                        the_index = idx + timestep  # NOTE a new index due to timestep
                        continous_flag = True
                        break

                # NOTE check -1 first above see the reversed code
                elif m == metric and reversed_checked and idx + 1 < len(res[traj_idx][f'pivot_timestep_{timestep}']['mask']):
                    # check
                    if res[traj_idx][f'pivot_timestep_{timestep}']['mask'][idx+1] == -1:
                        the_index = idx
                        max_visit_count = max(res[traj_idx][f'pivot_timestep_{timestep}']['visit_count'][:the_index+1])
                        the_index = idx + timestep  # NOTE a new index due to timestep
                        break
                    else:
                        continue

                elif m == metric:
                    # example:
                    #       data: 8  19  34  87  155  309  309  309  309  309
                    #       mask: 1  1   1   1   1    1    0    0    0    0
                    the_index = idx + timestep - 1
                    
                    if the_index < 0:
                        the_index = 0
                    # slicing and get the maximal visit count
                    # TODO: make sure the visit count is right
                    # NOTE(review): this slice ends at visit_count[idx - 1] (it excludes idx),
                    # unlike the two branches above, which include visit_count[idx]; confirm.
                    max_visit_count = max(res[traj_idx][f'pivot_timestep_{timestep}']['visit_count'][:the_index+1-timestep])
                    break
            else:
                # Reached when the scan above did not `break` (including when it never ran).
                flag = False
                if counts.get(0, 0) <= 1 and self.args.env.startswith('afforest'):
                    return_info = self._reverse_check_decrease(timestep, traj_idx, res)  # look for a "-1, -1, then non-decreasing" run in reversed order
                    res_idx = return_info['res_idx']
                    if len(res_idx) != 0:
                        flag = True
                        # do not use continue, instead let it go to the next subsection
                        the_index = return_info['the_index']  # revserd index
                        max_visit_count = return_info['max_visit_count']

                        # pre convert (reversed index -> forward index)
                        the_index = timestep + len(res[traj_idx][f'pivot_timestep_{timestep}']['mask']) - 1 - the_index
                        # save and compare the maxinum the_index
                        revsered_results[timestep]['the_index'] = the_index
                        revsered_results[timestep]['max_visit_count'] = max_visit_count
                if not flag:
                    # fallback: this makes the stored time_step 0 below
                    the_index = timestep + len(res[traj_idx][f'pivot_timestep_{timestep}']['mask']) - 1
                    # slicing and get the maximal visit count
                    max_visit_count = max(res[traj_idx][f'pivot_timestep_{timestep}']['visit_count'])   
            
            # NOTE check the visit count of the trajectory
            # if the visit count of all the nodes are identical and > 3
            # then the time step should be 0
            visit_count_set = list(set([node.visit_count for node in traj]))
            if len(visit_count_set) == 1 and visit_count_set[0] > 3:
                res[traj_idx][f'pivot_timestep_{timestep}']['time_step'] = 0
                res[traj_idx][f'pivot_timestep_{timestep}']['max_visit_count'] = visit_count_set[0]
            else:
                # NOTE get the max_visit_count's timestep, reverse it n-1-i
                # given i, find its reversed index: n-1-i

                # compare the the_index with that of the reversed index
                # (only when the "continuous 1 then -1 -1 -1" pattern was found)
                if len(revsered_results[timestep]) != 0 and revsered_results[timestep]['the_index'] != -1 and continous_flag:
                    the_index = max(revsered_results[timestep]['the_index'], the_index)
                    max_visit_count = max(revsered_results[timestep]['max_visit_count'], max_visit_count)

                res[traj_idx][f'pivot_timestep_{timestep}']['time_step'] = (timestep + len(res[traj_idx][f'pivot_timestep_{timestep}']['mask'])) - 1 - the_index
                res[traj_idx][f'pivot_timestep_{timestep}']['max_visit_count'] = max_visit_count

                assert res[traj_idx][f'pivot_timestep_{timestep}']['time_step'] <= len(traj) - 1, \
                    f'{res[traj_idx][f"pivot_timestep_{timestep}"]["time_step"]} > {len(traj) - 1}'

        return res, pre_processed

    def _post_process(self, res, trajectories, pre_processed=False, traj_timesteps=None):
        # for each pivot timestep, update the results
        res_data = copy.deepcopy(res)
        _traj_pivot_timesteps = [1] if (traj_timesteps is None or len(traj_timesteps) == 0) else traj_timesteps
        for pivot_timestep in _traj_pivot_timesteps:
            _res, _flag = self.__post_process(res_data, trajectories, pre_processed, pivot_timestep-1)
            if _flag:
                res_data[0][f'pivot_timestep_{pivot_timestep}'] = _res[0][f'pivot_timestep_{pivot_timestep}']
                # delete 1 and the rest
                for key in list(res_data.keys()):
                    if key != 0:  # and f'pivot_timestep_{pivot_timestep}' in res[key]:
                        del res_data[key]

                assert 'max_visit_count' in res_data[0][f'pivot_timestep_{pivot_timestep}'], \
                    f'max_visit_count not found in {res_data[0][f"pivot_timestep_{pivot_timestep}"]}. Show res_data: {res_data}'

        return res_data

    def __post_process(self, res, trajectories, pre_processed=False, pivot_timestep=0):
        """
        Try to replace the per-trajectory results with one shared answer.

        Every trajectory is first sliced to ``traj[pivot_timestep:]``. The
        common-path analysis runs only when there is more than one trajectory
        and ``pre_processed`` is False. The "common path" is built as follows:

        - it starts with the leaf of the last trajectory;
        - it then adds every position ``1 .. min_len-1`` where all
          trajectories hold equal nodes (compared through a ``set``) whose
          action ``key[1]`` is not 0.

        The path need not be contiguous.

        Args:
            res: per-trajectory results. Returned unchanged when no shared
                answer is produced.
            trajectories: node sequences, leaf node first.
            pre_processed: when True, skip the common-path analysis.
            pivot_timestep: number of leading nodes dropped from each
                trajectory (``_post_process`` passes ``pivot - 1``).

        Returns:
            ``(new_res, True)`` if the common path has at least one non-leaf
            node. ``new_res`` is
            ``{0: {f'pivot_timestep_{pivot_timestep+1}': {'time_step', 'max_visit_count'}}}``.
            ``time_step`` is the ``timestep`` of a common-path node holding the
            maximal visit count. It is 0 instead when some trajectory has a
            different maximum off the common path.

            Otherwise ``(res, False)``.
        """
        # get the common nodes
        common_nodes = []
        trajectories = [traj[pivot_timestep:] for traj in trajectories]

        if len(trajectories) > 1 and not pre_processed:
            # find the common path of trajectories
            # starts from the leaf node
            length = min([len(traj) for traj in trajectories])            
            common_nodes = [trajectories[-1][0]]  # add leaf node
            for idx in range(1, length):
                if len(set([traj[idx] for traj in trajectories])) == 1 \
                    and trajectories[-1][idx].key[1] != 0:  # NOTE action should not be 0
                    common_nodes.append(trajectories[-1][idx])  # no need to break
            
            if len(common_nodes) > 1:
                # get the max visit count of the common path (excluding the leaf)
                common_max_visit_count = max([node.visit_count for node in common_nodes[1:]])
                if common_nodes[-1].timestep == 0: 
                    reverse_common_nodes = False
                else:
                    reverse_common_nodes = True
                # get the timestep of the max visit count of the common path
                # NOTE 1. common_nodes is ordered leaf -> root. When common_nodes[-1] is at
                #         timestep 0, scan from the leaf side, so the timestep-0 node is only
                #         picked if no other common node reaches the max; otherwise scan from
                #         the root side.
                #      2. there are many types of trajectories, A, B, C, D and E
                # A node with the max always exists (the max is taken over common_nodes[1:]),
                # so common_max_visit_count_timestep is always bound after this loop.
                for node in (reversed(common_nodes) if reverse_common_nodes else common_nodes):
                    # TODO delete: if idx == 0: continue
                    # TODO delete: it would be great if the leafnode happens to be the proximal node
                    if node.visit_count == common_max_visit_count:
                        common_max_visit_count_timestep = node.timestep
                        break

                # check common_max_visit_count is the max visit count of all trajectories, excluding the leafnode
                # (for 'afforest' envs the last node of each trajectory is excluded too)
                # NOTE(review): for 'afforest' with a 2-node trajectory, traj[1:-1] is empty and
                # max() raises ValueError; confirm such trajectories cannot reach this point.
                is_common_max_visit_count = True
                for traj in trajectories:
                    _max_count = max([node.visit_count for node in \
                        (traj[1:-1] if self.args.env.startswith('afforest') else traj[1:])])
                    if _max_count != common_max_visit_count:
                        is_common_max_visit_count = False
                        break

                new_res = OrderedDict({0: OrderedDict()})
                new_res[0][f'pivot_timestep_{pivot_timestep+1}'] = OrderedDict()
                # NOTE
                if is_common_max_visit_count:
                    new_res[0][f'pivot_timestep_{pivot_timestep+1}']['time_step'] = common_max_visit_count_timestep
                    new_res[0][f'pivot_timestep_{pivot_timestep+1}']['max_visit_count'] = common_max_visit_count
                else:
                    # current results are not valid, return an dummy result
                    new_res[0][f'pivot_timestep_{pivot_timestep+1}']['time_step'] = 0
                    new_res[0][f'pivot_timestep_{pivot_timestep+1}']['max_visit_count'] = common_max_visit_count
                return new_res, True
        return res, False

    def _filtering_metric(self):
        # if self.args.env.startswith('sc2'):
        #     return -1
        
        if self.args.filter_strategy == 'strict':
            # the node of a prefix tree, and it starts new branches
            metric = 0
        elif self.args.filter_strategy == 'normal':
            # not non-decreasing monotonic now.
            metric = -1
        else:
            raise ValueError(f'filter_strategy: {self.args.filter_strategy} not available!')
        return metric

    def _filtering_strategy(self, curr_visit_count, prev_visit_count, prev_flag):
        """
        return the signal (-1, 0, 1)

        strict strategy:
            inputs:  1  2  3  9  9  8  7  6  5  5  8  9  9
            outputs: 1  1  1  1  0 -1 -1 -1 -1  0  1  1  0
        
        normal strategy:
            inputs:  1  2  3  9  9  8  7  6  5  5  8  9  9
            outputs: 1  1  1  1  1 -1 -1 -1 -1 -1  1  1  1
        """
        flag = curr_visit_count - prev_visit_count

        if flag > 0:
            flag = 1
        elif flag == 0:
            if self.args.filter_strategy == 'strict':
                flag = 0
            elif self.args.filter_strategy == 'lazy':
                if prev_flag == -1:
                    flag = -1
                else:  # prev_flag is 1
                    flag = 1
            else:
                raise ValueError(f'filter_strategy: {self.args.filter_strategy} not available!')
        else:
            flag = -1

        return flag

    @staticmethod
    def new(traj_memory):
        """
        Build a new TrajMemory from ``traj_memory.args``.

        The new object shares (does not copy) ``graph_dict`` and
        ``graph_return_dict`` with ``traj_memory``.
        """
        fresh = TrajMemory(args=traj_memory.args)
        for attr in ('graph_dict', 'graph_return_dict'):
            setattr(fresh, attr, getattr(traj_memory, attr))
        return fresh

    def sanity_check(self, visit_count=True):
        """
        Do some check:
        1. visit count
        """
        check_results = {}
        if visit_count:
            check_results['visit_count'] = dict()
            for traj_len, graph in self.graph_dict.items():
                check_results['visit_count'][traj_len] = TrajGraph.visit_count_check(graph)
        return check_results


class AgentsTrajMemories:
    """
    Keep a separate TrajMemory instance for each agent.

    Each agent only builds its trajectory graph memory from its own past
    (obs, action, reward) pairs.
    """
    def __init__(self, args) -> None:
        """
        Args:
            args: run configuration read through attribute access (``env``,
                ``hash_tool``, ``n_agents``; ``search`` also reads
                ``max_delay`` and ``debug_no_trie``). It is NOT converted into
                ``SN(**args)`` here, and ``args.use_return`` is set on it in
                place.
        """
        # do not convert args into SN(**args)
        self.args = args
        # True only for SC2 environments
        self.args.use_return = self.args.env.startswith('sc2')

        try:
            self.n_agents = args.n_agents
        except (AttributeError, KeyError):
            # Mapping-style fallback. This was a bare `except:`, which also
            # swallowed KeyboardInterrupt/SystemExit and unrelated errors.
            self.n_agents = args['n_agents']

        self.hash_tool = HASH_TOOLS[args.hash_tool](args)
        self._create_traj_memory()

    @staticmethod
    def load_graph(data_path=''):
        """
        Load the per-agent trajectory graph memories from a pickle file.

        Args:
            data_path: path of a pickle holding a dict with the key
                ``'traj_graph_mem_agents'``. Empty means nothing to load.

        Returns:
            ``obj['traj_graph_mem_agents']``, or ``None`` when ``data_path`` is
            empty.
        """
        if not data_path:
            return None
        # Make this file's directory importable (presumably so pickled classes
        # defined next to this module can be resolved). Add it only once
        # instead of growing sys.path on every call.
        dir_path = str(pathlib.Path(__file__).parent)
        if dir_path not in sys.path:
            sys.path.append(dir_path)
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
        with open(data_path, 'rb') as f:
            obj = pickle.load(f)
        return obj['traj_graph_mem_agents']

    def _create_traj_memory(self):
        """
        Create a separate, empty TrajMemory instance for each agent id.
        """
        self._agents_traj_memory = {
            agent_id: TrajMemory(self.args, hash_tool=self.hash_tool)
            for agent_id in range(self.n_agents)
        }

    def update_memory(self, agent_trajectories: List, info: Dict) -> None:
        """
        Feed the i-th entry of ``agent_trajectories`` to agent i's memory.

        ``info`` is passed through unchanged to every agent's memory.
        """
        assert len(agent_trajectories) == self.n_agents, "length is not equal"

        for agent_id, trajectories in enumerate(agent_trajectories):
            self._agents_traj_memory[agent_id].update_memory(trajectories, info)

    def search(self, agent_trajs, traj_reward=None):
        """
        Search the right timestep for each agent.

        Args:
            agent_trajs: mapping whose values are the agents' synchronized
                trajectories. The i-th value (in iteration order) is searched
                in the i-th agent's memory.
            traj_reward: optional per-timestep rewards. For SC2 environments it
                is split at positive rewards into pivot timesteps.

        Returns:
            ``(results, stats_info, reward_index)``:

            - ``results``: one search result per agent.
            - ``stats_info``: per-key lists of the agents' stats, plus
              ``'<key>_peragent'`` means and ``'timecost_search_per_agent_trajs'``.
            - ``reward_index``: the positive-reward indices shifted by
              ``max_delay - 1``, latest first.
        """
        results = []
        stats_info = {}  # for debug and stats (e.g. time cost for searching)
        reward_index = []
        traj_timesteps = None

        if isinstance(self.args, dict):
            self.args = SN(**self.args)

        if traj_reward is not None and self.args.env.startswith('sc2'):
            # split the trajectories backwards given the reward signals
            traj_timesteps = self.create_reward_intervals(traj_reward)  # timesteps of the trajectory
            delay_offset = self.args.max_delay - 1
            reward_index = list(reversed([idx - delay_offset for idx, r in enumerate(traj_reward) if r > 0]))

        # add no_trie mode for sc2 to debug the code
        if self.args.debug_no_trie:
            return self._debug_search_no_trie(agent_trajs, traj_reward=traj_reward,
                                              reward_index=reward_index, traj_timesteps=traj_timesteps)

        st = time.time()
        for agent_id, traj in enumerate(agent_trajs.values()):
            agent_graph_memory = self.agents_traj_memory[agent_id]
            result, info = agent_graph_memory.search(traj, traj_timesteps=traj_timesteps)
            results.append(result)

            # NOTE save stats info, for debug and stats
            for key, val in info.items():
                stats_info.setdefault(key, []).append(val)

        for key, val_list in list(stats_info.items()):  # calculate the average
            stats_info[key + '_peragent'] = np.mean(val_list)

        stats_info['timecost_search_per_agent_trajs'] = time.time() - st
        return results, stats_info, reward_index

    def _debug_search_no_trie(self, agent_trajs, traj_reward=None, reward_index=None, traj_timesteps=None):
        """
        Debug path for sc2: skip the trie and use the reward timesteps directly.

        Every agent gets the same result. For the i-th pivot in
        ``traj_timesteps``, ``time_step`` is the i-th positive-reward index,
        taken latest first. ``max_visit_count`` is fixed at 10.
        """
        results, stats_info = [], {}

        _rew_indices = [t for t, r in enumerate(traj_reward) if r > 0]
        assert len(_rew_indices) == len(traj_timesteps) and self.args.env.startswith('sc2'), \
            "debug no trie is only allowed for sc2 envs"

        if len(traj_timesteps) == 0:
            traj_timesteps, _rew_indices = [len(traj_reward)-1], [1]

        for _ in agent_trajs.values():
            res = OrderedDict({0: OrderedDict()})
            for i, _idx in enumerate(reversed(_rew_indices)):
                res[0][f'pivot_timestep_{traj_timesteps[i]}'] = OrderedDict({
                    'time_step': _idx,
                    'max_visit_count': 10,
                    "traj_len": len(traj_reward),
                    "rewards": traj_reward,
                })
            results.append(res)
        return results, stats_info, reward_index

    @property
    def agents_traj_memory(self):
        return self._agents_traj_memory

    def create_reward_intervals(self, traj_reward):
        """
        Turn the positive rewards of a trajectory into pivot timesteps.

        Positions are counted from the end of the trajectory. A positive reward
        at reversed position ``t`` yields the pivot ``t + 1``. A reward at the
        very first timestep (reversed position ``len - 1``) is ignored.

        Example:
            index:       [0, 1, 2, 3, 4, 5, 6, 7]
            traj_reward: [1, 0, 1, 0, 1, 0, 1, 0]
            reversed:    [0, 1, 0, 1, 0, 1, 0, 1]
            positive reversed positions (7 excluded): [1, 3, 5]
            returned pivots: [2, 4, 6]   (ascending; not reversed)
        """
        last_position = len(traj_reward) - 1
        return [t + 1 for t, r in enumerate(reversed(traj_reward)) if r > 0 and t < last_position]
