import json
import os
import random
from pathlib import Path
from typing import Tuple, Any, Dict
from typing import TypeGuard

import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter

from tame.data_handling.trace import Trace, RewardTrace
from tame.hierarchy.base_agent import LevelAgent
from tame.utils.utils import hasmethod

# Registry mapping the ``trace_type`` option of a level to the trace class it
# instantiates; the "none" entry maps to None, i.e. no trace object is created.
trace_types = {
    "none": None,
    "full": Trace,
    "reward": RewardTrace,
}


def is_nested_dict(
        obj: Dict[str, np.ndarray] | Dict[str, Dict[str, np.ndarray]],
) -> TypeGuard[Dict[str, Dict[str, np.ndarray]]]:
    """Check if a dictionary is a nested dictionary.

    This function determines if a dictionary contains dictionaries as values.
    It checks whether all values in the input dictionary are themselves dictionaries.

    Args:
        obj: A dictionary that contains either numpy arrays or dictionaries of numpy arrays.
            Can be either Dict[str, np.ndarray] or Dict[str, Dict[str, np.ndarray]].

    Returns:
        TypeGuard[Dict[str, Dict[str, np.ndarray]]]: True if all values in the dictionary
            are dictionaries, False otherwise. The TypeGuard ensures type narrowing
            in a type checker.
    """
    return all(isinstance(v, dict) for v in obj.values())


class LevelEnv:
    """A hierarchical environment wrapper that manages multiple agents at a specific level."""

    def __init__(
            self,
            agents: Dict[str, LevelAgent],
            uplinks: Dict[str, list] | None,
            downlinks: Dict[str, list],
            env: Any,
            trace_type: str = "none",
            name: str = "level_0",
            verbose: int = 0,
            action_freq: int = 1,
            action_space: Dict | None = None,
            save_period: int = 100000
    ):
        """Initialize LevelEnv - action_ifreq is kept for config compatibility but not used internally."""
        self.verbose = verbose
        self.obs_from_bottom = {}
        self.rewards_from_bottom = {}
        self.agents = agents
        self.env = env
        self.uplinks = uplinks
        self.downlinks = downlinks
        self.action_freq = max(1, action_freq)
        self.action_space = action_space

        if trace_type not in trace_types:
            raise ValueError(
                f"Wrong trace type specified for {name}. Given: {trace_type} - Available: {trace_types.keys()}"
            )
        self.trace = trace_types[trace_type]
        if self.trace is not None:
            self.trace = self.trace()

        self.level_ts = 0  # This one is not reset
        self.episode_ts = 0  # This one is reset
        self.episode_returns = None
        self.episode_idx = -1
        self.action = None
        self.directives = None

        self.name = name
        self.logger = None
        self.save_path = None
        self.save_period = save_period

    def seed(self, seed):
        """
        Sets the random seed for this environment and its components.

        This method initializes random number generators for the environment by setting
        the seed for Python's random module, NumPy's random number generator, all agents
        in the environment, and the underlying environment if it supports seeding.

        Args:
            seed (int): The seed value to use for random number generation.

        Returns:
            None
        """
        random.seed(seed)
        np.random.seed(seed)
        for agent in self.agents.values():
            agent.seed(seed=seed)

        if hasmethod(self.env, "seed"):
            self.env.seed(seed)

    def set_logger(self, logger: SummaryWriter | None, save_path: str | Path):
        """Sets logger and save path for the environment.

        This method initializes or updates the logging configuration for the environment by setting
        the logger instance and creating the necessary directory structure for saving data.

        Args:
            logger (SummaryWriter | None): The logger instance to be used for logging. Can be None.
            save_path (str | Path): The base path where environment data will be saved. Will be
                combined with the environment name to create the final save location.

        Notes:
            - Creates the save directory if it doesn't exist
            - Stores the logger instance as self.logger
            - Constructs the final save path by joining save_path with environment name
            - Prints status messages if verbose level is >= 2
        """
        self.save_path = Path(save_path) / self.name
        if self.verbose >= 2:
            print(
                f"{self.name} -  Updating logger and save path: {self.save_path}")

        if not self.save_path.exists():
            if self.verbose >= 2:
                print(
                    f"{self.name} -  Save path {self.save_path} does not exists. Creating it."
                )

            os.makedirs(self.save_path, exist_ok=True)
        self.logger = logger

    def flatten_rewards(self, env_rew: Dict[str, Dict[str, float]]):
        result = {}
        for agent_name, rewards in env_rew.items():
            reward_vector = np.array(list(rewards.values()))
            result[agent_name] = self.agents[agent_name].target_reward(
                reward_vector)
        return result

    def reset(self, training: bool = True) -> Tuple[
        Dict[str, Dict[str, np.ndarray]] | None, dict | None]:
        """Reset the level environment."""
        if self.verbose >= 1:
            print(f"{self.name} - Reset")

        # Log episode return
        if self.episode_idx >= 0 and self.episode_returns is not None:
            if self.logger is not None:
                self.logger.add_scalar(
                    f"{self.name}/total_return",
                    np.sum([ret for ret in self.episode_returns.values()]),
                    self.level_ts,
                )
            if self.verbose >= 2:
                print(
                    f"{self.name} - level_step: {self.level_ts} - Ep. return={np.sum([ret for ret in self.episode_returns.values()])}")

        # Save trace and empty it
        if self.save_path is not None and self.trace is not None and len(
                self.trace):
            self.trace.save_trace(save_path=self.save_path,
                                  episode=self.episode_idx)
            self.trace.empty()

        # Reset agents
        for agent in self.agents.values():
            if hasmethod(agent, "reset"):
                agent.reset()

        # Reset lower level env
        if isinstance(self.env, LevelEnv):
            # If lower level is also a LevelEnv, pass training parameter
            env_obs, env_inf = self.env.reset(training=training)
        else:
            # If it's a base environment, call reset without training parameter
            env_obs, env_inf = self.env.reset()

        env_obs = self.save_obs_from_bottom(env_obs=env_obs)
        env_rew = self.save_rewards_from_bottom(env_rew=None)

        msg_obs, _, _, _, msg_infos = self.upward_inference(
            env_obs=env_obs,
            env_rew=env_rew,
            env_term=None,
            env_trunc=None,
            env_inf=env_inf,
            training=training
        )

        # Reset episode counters
        self.episode_returns = {ag: 0.0 for ag in self.agents}
        self.episode_ts = 0
        self.episode_idx += 1

        return msg_obs, msg_infos

    @staticmethod
    def init_reward_accumulator(env_rew):
        """Initialize reward accumulator based on the structure of env_rew."""
        if env_rew is None:
            return {}

        if isinstance(env_rew, dict):
            if all(isinstance(v, dict) for v in env_rew.values()):
                # Nested reward structure: {agent: {subagent: reward}}
                return {agent: {sub: 0.0 for sub in sub_rewards.keys()}
                        for agent, sub_rewards in env_rew.items()}
            else:
                # Flat reward structure: {agent: reward}
                return {agent: 0.0 for agent in env_rew.keys()}
        return {}

    @staticmethod
    def accumulate_rewards(accumulated_rewards, env_rew):
        """Add env_rew to accumulated_rewards, maintaining the same structure."""
        if env_rew is None or accumulated_rewards is None:
            return accumulated_rewards

        if isinstance(env_rew, dict) and isinstance(accumulated_rewards, dict):
            if all(isinstance(v, dict) for v in env_rew.values()):
                # Nested structure: {agent: {subagent: reward}}
                for agent, sub_rewards in env_rew.items():
                    if agent in accumulated_rewards:
                        for sub_agent, reward in sub_rewards.items():
                            if sub_agent in accumulated_rewards[agent]:
                                accumulated_rewards[agent][sub_agent] += reward
                            else:
                                accumulated_rewards[agent][sub_agent] = reward
                    else:
                        accumulated_rewards[agent] = {sub_agent: reward
                                                      for sub_agent, reward in
                                                      sub_rewards.items()}
            else:
                # Flat structure: {agent: reward}
                for agent, reward in env_rew.items():
                    if agent in accumulated_rewards:
                        accumulated_rewards[agent] += reward
                    else:
                        accumulated_rewards[agent] = reward

        return accumulated_rewards

    @staticmethod
    def accumulate_dones(accumulated_dones, env_dones):
        """Accumulate done flags using OR operation."""
        if env_dones is None:
            return accumulated_dones

        if accumulated_dones is None:
            return {agent: done for agent, done in env_dones.items()}

        for agent, done in env_dones.items():
            if agent in accumulated_dones:
                accumulated_dones[agent] = accumulated_dones[agent] or done
            else:
                accumulated_dones[agent] = done

        return accumulated_dones

    def make_env_step(self, action: Dict[str, Any], training: bool = True) -> \
            Tuple[
                Dict[str, np.ndarray] | Dict[str, Dict[str, np.ndarray]],
                Dict[str, float] | Dict[str, Dict[str, float]],
                Dict[str, bool],
                Dict[str, bool],
                dict
            ]:
        """
        Execute multiple steps of the lower level environment according to action_ifreq,
        aggregating results and potentially returning early on episode termination.
        """
        accumulated_rewards = None
        accumulated_terminations = None
        accumulated_truncations = None
        final_obs = None
        final_info = None

        for step_i in range(self.action_freq):
            # Execute one step
            if isinstance(self.env, LevelEnv):
                env_obs, env_rew, env_term, env_trunc, env_inf = self.env.step(
                    action, training=training)
            else:
                env_obs, env_rew, env_term, env_trunc, env_inf = self.env.step(
                    action)

            # Always keep the latest observations and info
            final_obs = env_obs
            final_info = env_inf

            # Initialize accumulators on first step
            if step_i == 0:
                accumulated_rewards = self.init_reward_accumulator(env_rew)
                accumulated_terminations = {agent: False for agent in
                                            env_term.keys()} if env_term else {}
                accumulated_truncations = {agent: False for agent in
                                           env_trunc.keys()} if env_trunc else {}

            # Accumulate rewards
            accumulated_rewards = self.accumulate_rewards(accumulated_rewards,
                                                          env_rew)

            # Accumulate termination flags
            accumulated_terminations = self.accumulate_dones(
                accumulated_terminations, env_term)
            accumulated_truncations = self.accumulate_dones(
                accumulated_truncations, env_trunc)

            # Early termination if episode ended
            if any(env_term.values()) or any(env_trunc.values()):
                break

        return final_obs, accumulated_rewards, \
            accumulated_terminations, accumulated_truncations, final_info

    def save_obs_from_bottom(
            self,
            env_obs: Dict[str, np.ndarray] | Dict[str, Dict[str, np.ndarray]]
    ) -> Dict[str, Dict[str, np.ndarray]]:
        """Wraps and stores the observation from the bottom level of the hierarchy.

        This method takes observations from a lower level and reformats them if necessary to maintain
        a consistent hierarchical structure. It supports both nested and flat observation dictionaries.

        Args:
            env_obs (Dict[str, np.ndarray] | Dict[str, Dict[str, np.ndarray]]): Observations from the lower level.
                Can be either a flat dictionary mapping agent IDs to observations,
                or a nested dictionary with agent-subagent structure.

        Returns:
            Dict[str, Dict[str, np.ndarray]]: A nested dictionary where:
                - First level keys are agent IDs
                - Second level keys are subagent IDs
                - Values are the corresponding observations as numpy arrays

        Note:
            The method stores the processed observations in self.obs_from_bottom before returning them.
        """
        if is_nested_dict(env_obs):
            self.obs_from_bottom = env_obs
            return env_obs
        else:
            reformatted_obs = {}
            for agent, subagents in self.downlinks.items():
                reformatted_obs[agent] = {subagent: env_obs[subagent] for
                                          subagent in subagents}
            self.obs_from_bottom = reformatted_obs
            return reformatted_obs

    def save_rewards_from_bottom(
            self,
            env_rew: Dict[str, float] | Dict[str, Dict[str, float]] | None
    ) -> Dict[str, Dict[str, float]]:
        """
        Wrap rewards from the bottom level into a nested dict consistently.

        Returns nested rewards shaped as {this_level_agent: {child_agent: float}}.
        If env_rew is None, returns zeros. If env_rew is flat, wraps via downlinks.
        If env_rew is already nested, returns it unchanged.
        """
        if env_rew is None:
            self.rewards_from_bottom = {ag: {sub: 0.0 for sub in subs} for
                                        ag, subs in self.downlinks.items()}
            return self.rewards_from_bottom

        # Detect nested by inspecting a value
        if is_nested_dict(env_rew):  # type: ignore
            self.rewards_from_bottom = env_rew  # type: ignore
            return self.rewards_from_bottom

        nested: Dict[str, Dict[str, float]] = {}
        for agent, subagents in self.downlinks.items():
            nested[agent] = {subagent: float(env_rew.get(subagent, 0.0)) for
                             subagent in
                             subagents}
        self.rewards_from_bottom = nested
        return self.rewards_from_bottom

    def step(
            self,
            directive: Dict[str, np.ndarray] | None,
            training: bool = True,
    ) -> Tuple[
        Dict[str, Dict[str, np.ndarray]] | None,
        Dict[str, Dict[str, float]] | None,
        Dict[str, bool] | None,
        Dict[str, bool] | None,
        dict | None,
    ]:
        """Advance this level by one step.

        Sequence of operations:
            1. Compute this level's actions from the last stored observations
               and the directive received from above (``downward_inference``).
            2. Run the lower environment through ``make_env_step``.
            3. Store the new observations/rewards and build the upward
               messages (``upward_inference``).
            4. When ``training`` is true, store one transition per agent (plus
               the optional psi/phi transitions and a terminal entry when the
               agent is done) and call ``update_step``.
            5. Record the step in the trace, update the episode returns,
               periodically save the agents and increment the step counters.

        Args:
            directive: Directives from the higher level, keyed by higher-level
                agent name, or None.
            training: Selects the training code paths of the agents and
                enables transition storage and agent updates.

        Returns:
            The tuple produced by ``upward_inference``: observations, proxy
            rewards, terminations, truncations and infos for the level above.

        Note:
            ``self.episode_returns`` is None until ``reset`` has been called, so
            ``reset`` must run before the first ``step``.
        """

        # Sample action
        self.action = self.downward_inference(
            observation=self.obs_from_bottom,
            directive=directive,
            training=training
        )

        # Execute up to action_freq lower-environment steps with this action
        env_obs, env_rew, env_term, env_trunc, env_inf = self.make_env_step(
            self.action, training)

        # Keep the observations/rewards the action was computed from; the
        # save_* calls below replace self.obs_from_bottom / self.rewards_from_bottom.
        prev_env_obs = self.obs_from_bottom.copy()
        prev_env_rew = self.rewards_from_bottom.copy()

        env_obs = self.save_obs_from_bottom(env_obs=env_obs)
        env_rew = self.save_rewards_from_bottom(env_rew=env_rew)

        # An agent is done when it is terminated or truncated. Both dicts must
        # contain every agent name of this level (KeyError otherwise).
        dones = {agent: (env_term[agent] or env_trunc[agent]) for agent in
                 self.agents}

        # Generate upward messages
        msg_obs, msg_rew, msg_term, msg_trunc, msg_infos = self.upward_inference(
            env_obs=env_obs,
            env_rew=env_rew,
            env_trunc=env_trunc,
            env_term=env_term,
            env_inf=env_inf,
            training=training
        )

        # Training: store transitions and update agents
        if training:
            for agent_name, agent in self.agents.items():
                # Aggregate the rewards of the agent's children into one value.
                reward_vector = np.array(list(env_rew[agent_name].values()))
                aggregated_reward = agent.target_reward(reward=reward_vector)

                # Prepare state: previous observations, followed by the
                # directive (set by downward_inference) when there is one.
                state_components = list(prev_env_obs[agent_name].values())
                if self.directives and agent_name in self.directives and \
                        self.directives[agent_name] is not None:
                    directive_val = self.directives[agent_name]
                    # Integer directives become a 1-element float array; anything
                    # else is only promoted to at least one dimension.
                    if isinstance(directive_val, (int, np.integer)):
                        directive_array = np.array([float(directive_val)])
                    else:
                        directive_array = np.atleast_1d(directive_val)
                    state_components.append(directive_array)

                # Store main transition (axis=None flattens all components)
                agent.store(
                    state=np.concatenate(state_components, axis=None),
                    action=self.action.get(agent_name),
                    reward=aggregated_reward,
                    done=dones[agent_name],
                )

                # Store for psi if exists. The state is the previous
                # observations followed by the previous child rewards.
                if hasattr(agent, 'psi') and agent.psi is not None:
                    psi_state = np.concatenate(
                        list(prev_env_obs[agent_name].values()) +
                        [np.array(list(prev_env_rew[agent_name].values()))],
                        axis=None)

                    # The proxy reward sent upward for this agent (first match
                    # in msg_rew, 0.0 when absent) goes in the ``action`` slot.
                    proxy_reward = 0.0
                    for hl_agent, rewards_dict in msg_rew.items():
                        if agent_name in rewards_dict:
                            proxy_reward = rewards_dict[agent_name]
                            break
                    agent.psi_store(
                        state=psi_state,
                        action=proxy_reward,
                        reward=aggregated_reward,
                        done=dones[agent_name]
                    )

                # Store for phi if exists (same state layout as psi)
                if hasattr(agent, 'phi') and agent.phi is not None:
                    phi_state = np.concatenate(
                        list(prev_env_obs[agent_name].values()) +
                        [np.array(list(prev_env_rew[agent_name].values()))],
                        axis=None)
                    # Find the message for this agent; it stays None when no
                    # higher-level agent receives a message from it.
                    agent_message = None
                    for hl_agent, messages in msg_obs.items():
                        if agent_name in messages:
                            agent_message = messages[agent_name]
                            break

                    agent.phi_store(
                        state=phi_state,
                        action=agent_message,
                        reward=aggregated_reward,
                        done=dones[agent_name]
                    )

                # Store final state if episode ended: an extra entry built from
                # the *new* observations, with no action and no reward.
                if dones[agent_name]:
                    final_state_components = list(
                        self.obs_from_bottom[agent_name].values())
                    if self.directives and agent_name in self.directives and \
                            self.directives[agent_name] is not None:
                        directive_val = self.directives[agent_name]
                        if isinstance(directive_val, (int, np.integer)):
                            directive_array = np.array([float(directive_val)])
                        else:
                            directive_array = np.atleast_1d(directive_val)
                        final_state_components.append(directive_array)

                    agent.store(
                        state=np.concatenate(final_state_components,
                                             axis=None),
                        action=None,
                        reward=None,
                        done=True
                    )

                    if hasattr(agent, 'psi') and agent.psi is not None:
                        final_psi_state = np.concatenate(
                            list(self.obs_from_bottom[agent_name].values()) +
                            [np.array(list(self.rewards_from_bottom[
                                               agent_name].values()))],
                            axis=None)
                        agent.psi_store(
                            state=final_psi_state,
                            action=None,
                            reward=None,
                            done=True
                        )

                    if hasattr(agent, 'phi') and agent.phi is not None:
                        final_phi_state = np.concatenate(
                            list(self.obs_from_bottom[agent_name].values()) +
                            [np.array(list(self.rewards_from_bottom[
                                               agent_name].values()))],
                            axis=None)
                        agent.phi_store(
                            state=final_phi_state,
                            action=None,
                            reward=None,
                            done=True
                        )

            # Update agents
            self.update_step()

        # Save important infos
        if self.trace is not None:
            self.trace.add(
                episode=self.episode_idx,
                actions=self.action,
                observations=prev_env_obs,
                rewards=self.flatten_rewards(env_rew),
                terminations=env_term,
                truncations=env_trunc,
                return_observations=msg_obs,
                return_rew=msg_rew,
            )

            if any(dones.values()):
                self.trace.add_final_obs(self.obs_from_bottom)

        # Update episode returns (agents without rewards contribute an empty vector)
        for agent in self.episode_returns:
            reward_vector = np.array(list(env_rew.get(agent, {}).values()))
            self.episode_returns[agent] += float(
                self.agents[agent].target_reward(reward=reward_vector))

        # Periodic save
        # NOTE(review): level_ts starts at 0, so this condition is also true on
        # the very first step -- confirm that saving at step 0 is intended.
        if self.level_ts % self.save_period == 0 and self.save_path is not None:
            if self.verbose >= 2:
                print(
                    f"{self.name} - Saving model at level step {self.level_ts} in {self.save_path}")
            with open(self.save_path / "last_model.json", "w") as f:
                json.dump({"save_step": self.level_ts}, f)
            for agent_name, agent in self.agents.items():
                agent.save_agent(save_path=self.save_path, name=agent_name)

        self.level_ts += 1
        self.episode_ts += 1
        return msg_obs, msg_rew, msg_term, msg_trunc, msg_infos

    def update_step(self):
        """
        This function is called for each step and iterates through all agents in the level,
        calling their individual update_step methods. The level's global timestep and logger
        are passed to each agent's update step.

        This function is called for each step.
        """
        for agent in self.agents.values():
            agent.update_step(global_step=self.level_ts, writer=self.logger)

    def save(self, save_path: str | Path):
        """Save the level and its agents to disk.

        This method saves all agents in the level to the specified directory. The directory
        will be created if it doesn't exist. Each agent is saved with its corresponding name
        as a prefix.

        Args:
            save_path (str | Path): Base directory where to save the level data. A subdirectory
                                   with the level's name will be created inside this path.

        Example:
            >>> level.save("/path/to/save/directory")
            # Creates /path/to/save/directory/level_name/ and saves agents there

        Notes:
            - If verbose >= 1, prints information about the saving process
            - If verbose >= 2, also prints if the save path differs from the original one
        """
        save_path = Path(save_path) / self.name
        if not save_path.exists():
            os.makedirs(save_path, exist_ok=True)

        if self.verbose >= 1:
            print_message = f"{self.name} -  Save function called with {save_path}. "
            if self.verbose >= 2 and save_path != self.save_path:
                print_message += f"Original save path was different: {self.save_path}. "
            print_message += "Saving agents."
            print(print_message)

        for agent_name, agent in self.agents.items():
            agent.save_agent(save_path=save_path, name=f"{agent_name}")

    def load(self, load_path: str | Path) -> bool:
        """Loads all the agents of the level from the specified path.

        Args:
            load_path (str | Path): The path from where to load the agents.

        Returns:
            bool: True if all agents were loaded successfully, False if any agent failed to load.

        Example:
            >>> level.load("/path/to/agents")
            True
        """
        all_loaded = True
        for agent_name, agent in self.agents.items():
            loaded = agent.load_agent(load_path=load_path, name=agent_name)
            if not loaded:
                print(f"Model {agent_name} cannot be loaded")
            all_loaded = all_loaded and loaded
        return all_loaded

    def upward_inference(
            self,
            env_obs: Dict[str, Dict[str, np.ndarray]],
            env_rew: Dict[str, Dict[str, float]],
            env_term: Dict[str, bool] | None,
            env_trunc: Dict[str, bool] | None,
            env_inf: dict | None,
            training: bool = False
    ) -> (
            Tuple[
                Dict[str, Dict[str, np.ndarray]],
                Dict[str, Dict[str, float]],
                Dict[str, bool],
                Dict[str, bool],
                dict,
            ]
            | Tuple[None, None, None, None, None]
    ):
        """Processes information upward in the hierarchy by generating communications and proxy rewards.

        Args:
            env_obs: Nested observations from lower level
            env_rew: Nested rewards from lower level
            env_term: Termination flags
            env_trunc: Truncation flags
            env_inf: Info dictionaries
            training: Whether to use training mode for agent communications

        Returns:
            Tuple of processed observations, rewards, terminations, truncations, and infos for higher level
        """
        if self.uplinks is None:
            return {}, {}, env_term or {}, env_trunc or {}, {}

        observations = {hl_agent: {} for hl_agent in self.uplinks}
        rewards: Dict[str, Dict[str, float]] = {hl_agent: {} for hl_agent in
                                                self.uplinks}
        terminations = {hl_agent: False for hl_agent in self.uplinks}
        truncations = {hl_agent: False for hl_agent in self.uplinks}
        infos = {hl_agent: {} for hl_agent in self.uplinks}

        flat_observations = {}
        flat_proxies = {}

        for agent in self.agents:
            reward_vector = np.array(list(env_rew.get(agent, {}).values()))
            observation_array = np.concatenate(list(env_obs[agent].values()),
                                               axis=None)

            if training:
                if hasattr(self.agents[agent], 'comm_train'):
                    flat_observations[agent] = self.agents[agent].comm_train(
                        observation_array, reward_vector,
                        global_step=self.level_ts
                    )
                else:
                    flat_observations[agent] = self.agents[agent].comm(
                        observation_array, reward_vector
                    )

                if hasattr(self.agents[agent], 'proxy_reward_train'):
                    flat_proxies[agent] = self.agents[
                        agent].proxy_reward_train(
                        observation_array, reward_vector,
                        global_step=self.level_ts
                    )
                else:
                    flat_proxies[agent] = self.agents[agent].proxy_reward(
                        observation_array, reward_vector
                    )
            else:
                flat_observations[agent] = self.agents[agent].comm(
                    observation_array, reward_vector
                )
                flat_proxies[agent] = self.agents[agent].proxy_reward(
                    observation_array, reward_vector
                )

        for hl_agent, agent_names in self.uplinks.items():
            for agent in agent_names:
                observations[hl_agent][agent] = flat_observations[agent]
                rewards[hl_agent][agent] = flat_proxies[agent]

                if env_term is not None:
                    terminations[hl_agent] = terminations[
                                                 hl_agent] or env_term.get(
                        agent, False)
                if env_trunc is not None:
                    truncations[hl_agent] = truncations[
                                                hl_agent] or env_trunc.get(
                        agent, False)

                if env_inf is not None:
                    infos[hl_agent][agent] = env_inf.get(agent, {})
                else:
                    infos[hl_agent][agent] = {}

        return observations, rewards, terminations, truncations, infos

    def downward_inference(
            self, observation: Dict[str, Dict[str, Any]],
            directive: Dict[str, Any] | None,
            training: bool = False
    ) -> Dict[str, Any]:
        """
        Generates this level actions for the lower level.
        This is the top->bottom direction in the hierarchy.

        Args:
            observation (Dict[str, Dict[str, Any]]): Nested dictionary containing observations.
                First level key is this level's agent name,
                Second level key is bottom level agent name,
                Value is the corresponding observation.
            directive (Dict[str, Any]): Actions/directives from higher level agents.
            training (bool, optional): Whether the action is generated during training.
                Defaults to False.

        Returns:
            Dict[str, Any]: Actions dictionary from this level for the lower one.
                Keys are agent names and values are their corresponding actions.
        """
        agent_actions = {}

        # Process directives from higher level
        flat_directives = {agent: None for agent in self.agents}
        if directive is not None and self.uplinks is not None:
            for sup_agent, agents in self.uplinks.items():
                for agent in agents:
                    if sup_agent in directive:
                        flat_directives[agent] = directive[sup_agent]

        # Store directives for later use in training
        self.directives = flat_directives

        for agent_name in self.agents:
            ag_observation = np.concatenate(
                list(observation[agent_name].values()), axis=None)
            ag_directive = flat_directives[agent_name]

            if training:
                agent_act = self.agents[agent_name].act_train(
                    ag_observation,
                    ag_directive,
                    global_step=self.level_ts,
                )
            else:
                agent_act = self.agents[agent_name].act(
                    ag_observation,
                    ag_directive
                )

            agent_actions[agent_name] = agent_act
        return agent_actions

    def act(self, observation: Dict[str, Any]) -> Dict[str, Any]:
        """Generate actions for this level based on observations.

        Args:
            observation: Observations for agents at this level

        Returns:
            Actions to be sent to lower level
        """
        return self.downward_inference(
            observation=observation,
            directive=None,
            training=False
        )
