"""Forward-Only Deep Q-Network (DQN) agent."""
import copy
import os
import random
import time
from collections import deque
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union
from pathlib import Path

import flax
import flax.linen as nn
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
import optax
import tqdm
from flax.training import orbax_utils
from flax.training.train_state import TrainState
from gymnasium import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import VecNormalize
from tensorboardX import SummaryWriter

import forward_only.fwd_layers as fo
from forward_only.features_extractors import (BaseFeaturesExtractor,
                                              FlattenExtractor)
from forward_only.utils import EnvSpec, linear_schedule, preprocess_obs
from forward_only.utils.eval import evaluate_policy
from forward_only.utils.logging import console

#########################################################
### Global config variables for the Fwd-DQN algorithm ####
#########################################################
# NOTE(review): this flag is not read anywhere in the part of the file visible
# here; presumably a debugging switch -- confirm where it is consumed.
BREAKPOINT_WEIGHTS = False
#########################################################

class RecurrentBufferSamples(NamedTuple):
    """A batch of transitions as returned by ``RecurrentBuffer._get_samples``.

    Field order matters: the buffer builds this tuple positionally.
    """
    # Observations the sampled transitions started from.
    observations: np.ndarray
    # Actions stored with the sampled transitions.
    actions: np.ndarray
    # Observations that followed the sampled transitions.
    next_observations: np.ndarray
    # Done flags with timeouts masked out, reshaped to a column (-1, 1).
    dones: np.ndarray
    # Rewards reshaped to a column (-1, 1).
    rewards: np.ndarray
    # One array per hidden layer: the activations stored with each transition.
    last_activations: List[np.ndarray]

class RecurrentBuffer(ReplayBuffer):
    """Replay buffer that also stores per-layer activations with every transition.

    On top of what the parent :class:`ReplayBuffer` stores, one array per
    hidden layer (see ``net_arch``) is kept so that the activations passed to
    :meth:`add` can be returned together with the sampled transitions.
    """

    def __init__(
            self,
            buffer_size: int,
            observation_space: spaces.Space,
            action_space: spaces.Space,
            net_arch: list[int],
            device: str = "cpu",
            n_envs: int = 1,
            optimize_memory_usage: bool = False,
            handle_timeout_termination: bool = True,
    ):
        """Create the buffer.

        Args:
            buffer_size: Forwarded to the parent buffer.
            observation_space: Forwarded to the parent buffer.
            action_space: Forwarded to the parent buffer.
            net_arch: Number of units of each hidden layer; one activation
                store of that width is allocated per entry.
            device: Forwarded to the parent buffer.
            n_envs: Number of parallel environments.
            optimize_memory_usage: Forwarded to the parent buffer.
            handle_timeout_termination: Forwarded to the parent buffer.
        """
        super().__init__(buffer_size, observation_space, action_space, device, n_envs, optimize_memory_usage,
                         handle_timeout_termination)

        # Sized with ``self.buffer_size`` (set by the parent) rather than the
        # raw argument, so these arrays are indexed exactly like the parent's.
        self.last_activations = [
            np.zeros((self.buffer_size, n_envs, hidden_size))
            for hidden_size in net_arch
        ]

    # Not following ABC's signature for add() so python will complain, but for now whatever
    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        last_activations: List[np.ndarray],
        infos: List[Dict[str, Any]],
    ) -> None:
        """Store one transition together with the per-layer activations.

        Args:
            obs: Forwarded to the parent ``add``.
            next_obs: Forwarded to the parent ``add``.
            action: Forwarded to the parent ``add``.
            reward: Forwarded to the parent ``add``.
            done: Forwarded to the parent ``add``.
            last_activations: One array per hidden layer, in layer order.
            infos: Forwarded to the parent ``add``.

        Raises:
            ValueError: If the number of activation arrays does not match the
                number of layers this buffer was created for.
        """
        if len(last_activations) != len(self.last_activations):
            raise ValueError(
                f'Expected {len(self.last_activations)} activation arrays, '
                f'got {len(last_activations)}')

        # Write at ``self.pos`` *before* delegating, so the activations land in
        # the same slot as the transition the parent is about to store.
        for storage, layer_activations in zip(self.last_activations, last_activations):
            # Assigning into the numpy storage already copies the data.
            storage[self.pos] = np.asarray(layer_activations)

        super().add(obs, next_obs, action, reward, done, infos)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RecurrentBufferSamples:
        """Gather the transitions (and stored activations) at ``batch_inds``.

        Args:
            batch_inds: Buffer positions to read.
            env: Optional normalization wrapper passed to the parent's
                ``_normalize_obs`` / ``_normalize_reward`` helpers.

        Returns:
            The sampled batch, with one randomly chosen env per index.
        """
        # Sample randomly the env idx
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        if self.optimize_memory_usage:
            # In this mode the next observation is read from the following slot
            # of ``observations`` instead of a separate array.
            next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :], env)
        else:
            next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env)

        # Only use dones that are not due to timeouts
        # deactivated by default (timeouts is initialized as an array of False)
        dones = (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1)

        return RecurrentBufferSamples(
            observations=self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
            actions=self.actions[batch_inds, env_indices, :],
            next_observations=next_obs,
            dones=dones,
            rewards=self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
            last_activations=[layer_store[batch_inds, env_indices] for layer_store in self.last_activations],
        )


class FwdDQNTrainState(TrainState):
    """Custom DQN training state that also contains target network parameters."""
    # Parameters of the target network, stored next to the online ``params``;
    # ``FwdDQN.forward`` reads these when called with ``use_target_net=True``.
    target_params: flax.core.FrozenDict[str, Any]


class FwdDQN:
    r"""Forward-Only Deep Q-Network (DQN) agent.

    Default hyperparameters are taken from Nature DQN paper, and modified
    as needed to work with forward-only layers.

    To improve learning stability, we found that the following changes
    helped: using a larger batch size, updating the target network less
    frequently, and annealing the learning rate over time.

    Args:
        env_spec: The environment to train on. Either the ID of a registered
            environment or an :class:`EnvSpec` object.
        features_extractor_cls: The feature extractor class.
            Default: ``FlattenExtractor``.
        net_arch: A list of integers specifying the number of units in each
            hidden layer of the forward-only network. Default: ``[64, 64]``.
        q_net_kwargs: Keyword arguments to pass to the forward-only layers.
        learning_rate: A float or :class:`optax.Schedule` specifying the
            learning rate. Default: 1e-4. If a float is given, a constant
            learning rate schedule will be created using that value.
        huber_loss: Whether to use the Huber loss function instead of the
            MSE loss function. Default: ``True``.
        double_q: Whether to use the double Q-learning algorithm. Default:
            ``True``.
        buffer_size: The size of the replay buffer. Default: 1_000_000.
        gamma: The discount factor. Default: 0.99.
        tau: The soft update factor. Default: 1.0.
        target_network_frequency: The frequency (in number of steps) at which
            to update the target network. Default: 10000.
        max_grad_norm: The maximum norm of the gradients. Default: 10.
            Use ``None`` to disable gradient clipping.
        batch_size: The batch size to use for training. Default: 32.
        start_eps: The initial exploration rate. Default: 1.0.
        end_eps: The final exploration rate. Default: 0.05.
        exploration_fraction: The fraction of the total number of steps over
            which the exploration rate is annealed from `start_eps` to
            `end_eps`. Default: 0.5.
        learning_starts: The number of steps to wait before starting training.
            Default: 10000.
        train_frequency: The frequency (in number of steps) at which to train
            the network. Default: 4.
        seed: A seed to use for the environment and Jax. Default: 1.
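        device: The device to use for training. Use None to automatically
            select a device (CPU or CUDA). Default: None. Note that
            ``__init__`` does not currently accept this argument; the replay
            buffer is always created with ``device='cpu'``.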
    """

    features_extractor: BaseFeaturesExtractor
    huber_loss: bool
    double_q: bool
    gamma: float
    tau: float
    target_network_frequency: int
    max_grad_norm: Optional[float]
    batch_size: int
    start_eps: float
    end_eps: float
    exploration_fraction: float
    learning_starts: int
    train_frequency: int
    seed: int

    # Private Instance Attributes:
    #   _lr_schedule: The learning rate schedule.
    #   _rb: The replay buffer.
    #   _env_spec: The environment specification.
    #   _q_net_layers: The layers of the forward-only Q-network.
    #   _q_net_states: The training states of the forward-only Q-network.
    _lr_schedule: optax.Schedule
    _rb: ReplayBuffer
    _env_spec: EnvSpec
    _q_net_layers: list[nn.Module]
    _q_net_states: list[FwdDQNTrainState]

    def __init__(
            self,
            env_spec: Union[str, EnvSpec],
            features_extractor_cls: Type[BaseFeaturesExtractor] = FlattenExtractor,
            net_arch: list[int] = [64, 64],
            q_net_kwargs: Any = {},
            learning_rate: Union[float, optax.Schedule] = 1e-4,
            recurrent_connections: bool = False,
            backward_connections: bool = True,
            average_q_values: bool = True,
            huber_loss: bool = True,
            double_q: bool = True,
            buffer_size: int = 1_000_000,
            gamma: float = 0.99,
            tau: float = 1.0,
            target_network_frequency: int = 10000,
            max_grad_norm: Optional[float] = 10,
            batch_size: int = 32,
            start_eps: float = 1.0,
            end_eps: float = 0.05,
            exploration_fraction: float = 0.5,
            learning_starts: int = 10000,
            train_frequency: int = 4,
            seed: Optional[int] = 1
    ) -> None:
        """Initialize the DQN agent.

        Builds one forward-only layer and train state per entry of
        ``net_arch``, then creates the replay buffer.

        Args:
            env_spec: Environment ID (wrapped in an :class:`EnvSpec`) or an
                :class:`EnvSpec`. Its action space must be discrete.
            features_extractor_cls: Class instantiated with the observation
                space to turn observations into features.
            net_arch: Number of units of each hidden layer.
            q_net_kwargs: Extra keyword arguments for every
                ``fo.FoldingFwdLinear`` layer.
            learning_rate: A constant or a callable schedule.
            recurrent_connections: If ``True``, a layer's input also contains
                that layer's own activations from the previous time step.
            backward_connections: If ``True``, every layer except the last also
                receives the next layer's activations from the previous time
                step.
            average_q_values: If ``True``, :meth:`forward` returns the mean of
                the Q-values produced by all layers; otherwise the Q-values of
                the last layer.
            max_grad_norm: Global-norm clipping threshold, or ``None`` to
                disable clipping.
            buffer_size: Size passed to the replay buffer.
            seed: PRNG seed. ``None`` seeds from the current time; ``0`` is a
                valid seed and is used as given.
            huber_loss, double_q, gamma, tau, target_network_frequency,
            batch_size, start_eps, end_eps, exploration_fraction,
            learning_starts, train_frequency: Stored unchanged as attributes
                of the same name.
        """
        super().__init__()

        # The signature's defaults are mutable objects shared by every call;
        # work on private copies so instances never share (or leak) state.
        net_arch = list(net_arch)
        q_net_kwargs = dict(q_net_kwargs)

        self._env_spec = EnvSpec(env_spec) if isinstance(
            env_spec, str) else env_spec

        assert isinstance(self._env_spec.action_space, gym.spaces.Discrete), \
            'Only discrete action space is supported, but got: ' \
            f'{self._env_spec.action_space}'

        self.features_extractor = features_extractor_cls(
            self._env_spec.observation_space)
        num_actions = int(self._env_spec.action_space.n)

        # Create a JAX PRNG key.
        # ``seed or ...`` would silently replace an explicit seed of 0.
        self.seed = seed if seed is not None else int(time.time())
        key = jax.random.PRNGKey(self.seed)
        _, q_key = jax.random.split(key)
        # One independent key per layer: reusing a single key would make the
        # layers' initial weights correlated.
        layer_keys = jax.random.split(q_key, len(net_arch))

        # Define an update rule for the gradients based on the Adam optimizer
        self._lr_schedule = learning_rate if callable(
            learning_rate) else optax.constant_schedule(learning_rate)
        grad_clip_fn = optax.clip_by_global_norm(
            max_grad_norm) if max_grad_norm is not None else optax.identity()
        grad_transform = optax.chain(
            grad_clip_fn,  # Clip the gradients by their global norm
            optax.scale_by_adam(b2=0.99),  # Use the updates from Adam
            #optax.add_decayed_weights(1e-3),  # AdamW weight decay
            # Scale the updates by the learning rate schedule
            optax.scale_by_schedule(self._lr_schedule),
            # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.
            optax.scale(-1),
        )

        self.recurrent_connections = recurrent_connections
        self.backward_connections = backward_connections
        self.average_q_values = average_q_values

        # Create the Q-network and layer training states
        self._q_net_layers, self._q_net_states = [], []
        for i, hidden_size in enumerate(net_arch):
            # Create a dummy input to initialize the parameters.
            # The input always contains the previous layer's output (or the
            # features for the first layer). Recurrent connections add this
            # layer's own activations; backward connections add the next
            # layer's activations (nothing for the last layer).
            input_size = net_arch[i - 1] if i > 0 else self.features_extractor.features_dim
            if self.recurrent_connections:
                input_size += hidden_size
            if self.backward_connections and i < len(net_arch) - 1:
                input_size += net_arch[i + 1]
            dummy_x = jnp.zeros((input_size,))

            layer = fo.FoldingFwdLinear(
                hidden_size, num_actions, **q_net_kwargs)

            # Both parameter sets come from the same key so the target network
            # starts identical to the online one; they are initialized
            # separately so the two trees do not share buffers.
            train_state = FwdDQNTrainState.create(
                apply_fn=layer.apply,
                params=layer.init(layer_keys[i], dummy_x),
                target_params=layer.init(layer_keys[i], dummy_x),
                tx=grad_transform,
            )
            # Apply jit to the layer for faster inference
            layer.apply = jax.jit(layer.apply)

            self._q_net_layers.append(layer)
            self._q_net_states.append(train_state)

        self._rb = RecurrentBuffer(
            buffer_size,
            self._env_spec.observation_space,
            self._env_spec.action_space,
            net_arch,
            device='cpu',
            optimize_memory_usage=True,
            handle_timeout_termination=False,
            n_envs=self._env_spec.num_envs
        )

        # Store the hyperparameters
        self.net_arch = net_arch
        self.q_net_kwargs = q_net_kwargs
        self.learning_rate = learning_rate
        self.huber_loss = huber_loss
        self.double_q = double_q
        self.gamma = gamma
        self.tau = tau
        self.target_network_frequency = target_network_frequency
        self.max_grad_norm = max_grad_norm
        self.batch_size = batch_size
        self.start_eps = start_eps
        self.end_eps = end_eps
        self.exploration_fraction = exploration_fraction
        self.learning_starts = learning_starts
        self.train_frequency = train_frequency

        # Per-layer all-zero activations with a leading batch dimension of 1,
        # used whenever no previous activations are available.
        self._zero_activations = [
            jnp.zeros((1, layer_activation_size))
            for layer_activation_size in self.net_arch
        ]

    def forward(self, obs: jax.typing.ArrayLike,
                last_activations: Optional[List[jax.typing.ArrayLike]],
                layer_index: Optional[int] = None,
                use_target_net: bool = False) -> tuple[jax.Array, list[jax.Array]]:
        """Predict the Q-values for the given observation.

        Args:
            obs: The observation to predict an action for.
            last_activations: The per-layer activations from the previous time
                step. If ``None``, all-zero activations are used.
            layer_index: The index of the layer to return the output of.
                If None, the index of the last layer is used. Default: ``None``.
            use_target_net: Whether to use the target network instead of the
                Q-network. Default: ``False``.

        Returns:
            A tuple ``(q_values, activations)``: the Q-values for the given
            observation and the list of activations produced by each layer.

        Raises:
            AssertionError: If the layer index is out of range.
        """
        if last_activations is None:
            last_activations = self._zero_activations

        # An explicit ``is None`` test: ``layer_index or ...`` would treat the
        # perfectly valid index 0 as "not given".
        if layer_index is None:
            layer_index = len(self._q_net_layers) - 1
        assert 0 <= layer_index < len(
            self._q_net_layers), 'Invalid layer index'

        # NOTE(review): ``layer_index`` is only validated; the parameters of
        # every layer are passed on, so the pass below always runs the full
        # network. Confirm whether truncating at ``layer_index`` is intended.
        params = [q_state.target_params if use_target_net else q_state.params
                  for q_state in self._q_net_states]
        # The helper is a jitted function stored on the class: ``self`` is
        # bound as its first argument and passed again explicitly as
        # ``fwd_dqn``.
        return self._forward_helper(self,
                                    self.recurrent_connections,
                                    self.backward_connections,
                                    self.average_q_values,
                                    params,
                                    obs,
                                    last_activations)

    @partial(jax.jit, static_argnums=(0, 1, 2, 3, 4))
    def _forward_helper(
        cls,
        fwd_dqn: 'FwdDQN',
        recurrent_connections: bool,
        backward_connections: bool,
        average_q_values: bool,
        params: list[flax.core.FrozenDict[str, Any]],
        obs: jax.Array,
        last_activations: list[jax.Array],
    ) -> tuple[jax.Array, list[jax.Array]]:
        """A JIT-compiled forward pass helper function.

        This will iterate through the first ``len(params)`` layers of the
        Q-network and return the output of the last layer it iterates through.

        Args:
            cls: Unused. Receives the bound instance when called as
                ``self._forward_helper(self, ...)``; static under jit.
            fwd_dqn: The FwdDQN instance (static under jit).
            recurrent_connections: Whether each layer's input also contains
                that layer's entry of ``last_activations`` (static).
            backward_connections: Whether each layer but the last iterated one
                also receives the next layer's entry of ``last_activations``
                (static).
            average_q_values: Whether to return the mean of all layers'
                Q-values instead of the last layer's (static).
            params: The parameters of the Q-network for as many layers as
                should be used in the forward pass. The length of this list
                must be between 1 and the number of layers in the Q-network.
            obs: The observation to predict an action for.
            last_activations: Per-layer activations from the previous step.

        Returns:
            A tuple ``(q_values, activations)`` where ``activations`` holds the
            output of every layer that was iterated.
        """
        x = fwd_dqn._get_features(obs)
        activations = []
        last_index = len(params) - 1

        # Sum the q_values from every layer
        if average_q_values:
            q_values_sum = jnp.zeros((x.shape[0], fwd_dqn._env_spec.action_space.n))

        for i, layer in enumerate(fwd_dqn._q_net_layers[:len(params)]):
            # Assemble the layer input: the running activation, optionally
            # followed by this layer's previous activations (recurrent) and,
            # for every layer but the last, the next layer's previous
            # activations (backward).
            layer_inputs = [x]
            if recurrent_connections:
                layer_inputs.append(last_activations[i])
            if backward_connections and i < last_index:
                layer_inputs.append(last_activations[i + 1])
            if len(layer_inputs) > 1:
                x = jnp.concatenate(layer_inputs, axis=1)

            x, q_values = layer.apply(params[i], x)

            if average_q_values:
                q_values_sum += q_values

            activations.append(x)

        if average_q_values:
            # Divide by the number of layers to get the average q_values
            return q_values_sum / len(params), activations

        # Return the q_values from the last layer
        return q_values, activations

    @partial(jax.jit, static_argnums=0)
    def _get_features(self, obs: jax.typing.ArrayLike) -> jax.Array:
        """Preprocess ``obs`` for the observation space and extract its features."""
        observation_space = self._env_spec.observation_space
        preprocessed = preprocess_obs(obs, observation_space)
        return self.features_extractor(preprocessed)

    def predict(self, obs: jax.Array, last_activations: Optional[List[jax.Array]],
                **kwargs: Any) -> tuple[np.ndarray, list[jax.Array]]:
        """Select an action for the given observation.

        Args:
            obs: The observation to select an action for.
            last_activations: The per-layer activations from the last time step.
            **kwargs: Additional keyword arguments to pass :meth:`forward`.

        Returns:
            A tuple ``(actions, activations)``: the greedy action (argmax of
            the Q-values over the last axis) fetched to the host, and the
            per-layer activations returned by :meth:`forward`.
        """
        q_values, activations = self.forward(obs, last_activations, **kwargs)
        greedy_actions = jnp.argmax(q_values, axis=-1)
        return jax.device_get(greedy_actions), activations

    def get_policy(self, layer_index: Optional[int] = None) \
            -> Callable[[np.ndarray, List[jax.Array]], tuple[np.ndarray, list[jax.Array]]]:
        """Return a function that selects an action given an observation.

        Args:
            layer_index: The index of the layer to return the output of.
                If None, return the output of the last layer. Default: ``None``.

        Returns:
            A function that takes an observation and the per-layer activations
            from the previous step, and returns whatever :meth:`predict`
            returns for them (the action and the new activations).

        Remarks:
            This is a convenience method that wraps the :meth:`predict` method
            in a closure with the given ``layer_index``.
        """
        def policy(obs, last_activations):
            # ``layer_index`` is captured from the enclosing call.
            return self.predict(obs, last_activations, layer_index=layer_index)

        return policy

    def evaluate(
        self,
        global_step: int = 0,
        writer: Optional[SummaryWriter] = None,
        verbose: int = 0,
        num_episodes: int = 10,
        eval_envs: Optional[list[gym.Env]] = None
    ) -> tuple[list[tuple[float, float]], list[gym.Env]]:
        r"""Evaluate the agent's performance layer-by-layer and log the results.

        One policy per evaluation environment is obtained through
        ``self.get_policy(i)`` and run for ``num_episodes`` episodes; afterwards
        the default policy (``self.get_policy()``) is evaluated once more on the
        last environment.

        NOTE(review): the per-layer policies are meant to use only the first
        ``i`` layers, but :meth:`forward` currently runs every layer whatever
        ``layer_index`` is -- confirm the per-layer numbers really differ.

        Args:
            global_step: The current global step. This is used for logging
                purposes only. Default: ``0``.
            writer: The TensorBoard writer to log the results and game videos
                to. If ``None``, no results or videos will be logged.
                Default: ``None``.
            verbose: The verbosity level. If 0, do not print any messages.
                If 1, print a message at the start of each evaluation. If 2,
                print a message for each episode, show the progress bar for
                each episode, and print a summary at the end of each evaluation.
                Default: 0.
            num_episodes: The number of episodes to evaluate for each layer.
                Default: 10.
            eval_envs: A list of environments to use for evaluation, one for
                each layer of the network. If ``None``, a new environment will
                be created for each layer using the env spec of this agent.
                Default: ``None``.

        Returns:
            A tuple ``(eval_results, eval_envs)``. ``eval_results`` holds one
            ``(mean, std)`` tuple of episodic returns per evaluation
            environment, followed by one final tuple for the default (full
            network) policy. ``eval_envs`` is the list of environments used.

        Raises:
            ValueError: If there is no evaluation environment to run on.
        """
        if eval_envs is None:
            eval_envs = [
                self._env_spec.make_env(
                    record_video=writer is not None,
                    record_video_freq=1,  # Record every episode
                    run_log_dir=f'{writer.logdir}/eval/layer_{i + 1}' \
                    if writer is not None else None,
                    # Avoid seed collisions with training env(s)
                    seed=self.seed + self._env_spec.num_envs + 1 + i,
                )
                for i in range(self.num_layers)
            ]

        # The final full-network evaluation needs at least one environment;
        # fail with a clear message instead of an unbound-variable error.
        if not eval_envs:
            raise ValueError('At least one evaluation environment is required.')

        if verbose > 0:
            console.log(f'Evaluating agent at global step {global_step}...')

        show_details = verbose > 1

        # Evaluate each layer of the network
        eval_results = []
        for i, eval_env in enumerate(eval_envs):
            if show_details:
                console.log(f'\tEvaluating layer {i + 1}...')

            mean_ep_return, std_ep_return = evaluate_policy(
                self.get_policy(i),
                eval_env,
                num_episodes=num_episodes,
                show_progress=show_details,
            )

            eval_results.append((mean_ep_return, std_ep_return))
            if writer is not None:
                # Log the results to TensorBoard
                writer.add_scalar(
                    f'eval/layer_{i + 1}/mean_episodic_return', mean_ep_return, global_step)
                writer.add_scalar(
                    f'eval/layer_{i + 1}/std_episodic_return', std_ep_return, global_step)

            if show_details:
                console.log(
                    f'\t\tMean episodic return: {mean_ep_return:.3f} ± {std_ep_return:.3f}')

        # Evaluate the full network on the last environment (named explicitly
        # rather than relying on the loop variable leaking out of the loop).
        mean_ep_return, std_ep_return = evaluate_policy(self.get_policy(),
                                                        eval_envs[-1],
                                                        num_episodes=num_episodes,
                                                        show_progress=show_details)
        eval_results.append((mean_ep_return, std_ep_return))

        return eval_results, eval_envs

    def learn(  # noqa: C901
            self,
            total_timesteps: int = 500000,
            log_frequency: int = 100,
            eval: bool = True,
            eval_episodes: int = 10,
            eval_frequency: Optional[int] = None,
            exp_name: str = 'FwdDQN',
            save_checkpoints: bool = True,
            track: bool = False,
            wandb_project_name: Optional[str] = None,
            wandb_entity: Optional[str] = None,
            record_video: bool = False,
            show_progress: bool = True,
            verbose: int = 1,
            callback: Optional[EvalCallback] = None
    ) -> dict[str, Any]:
        """Train the agent for the given number of timesteps.

        Args:
            total_timesteps: The total number of timesteps to train for.
                Default: 500000.
            log_frequency: The number of timesteps between logging progress.
                Default: 100.
            eval: Whether to evaluate the agent's performance before, during,
                and after training. Default: ``True``. Each evaluation consists
                of a number of episodes equal to ``eval_episodes``.
            eval_episodes: The number of episodes to evaluate the agent for.
                Default: 10. This is ignored if ``eval`` is ``False``.
            eval_frequency: The number of timesteps between evaluations.
                Evaluations will be performed before training begins (i.e. on
                the network as it was initialized or last trained), every
                ``eval_frequency`` timesteps during training, and after
                training completes. If ``None``, evaluations will only be
                performed before training begins and after training completes.
                Default: ``None``. This is ignored if ``eval`` is ``False``.
            exp_name: The name of the experiment.
            track: Whether to track the experiment using Weights & Biases.
                Default: ``False``.
            wandb_project_name: The name of the Weights & Biases project to
                log to. If ``None``, wandb will generate a project name.
                Default: ``None``.
            wandb_entity: The Weights & Biases entity to log to. Default: ``None``.
            record_video: Whether to record a video of the agent's performance
                every ``log_frequency`` timesteps. Default: ``False``.
                The video will be saved to the TensorBoard log directory and,
                if tracking is enabled, uploaded to Weights & Biases.
            show_progress: Whether to show a progress bar during training.
                Default: ``True``.
            verbose: The verbosity level: 0 none, 1 training information,
                2 debug. Default: 1. If set to 0, the progress bar will not
                be shown, even if ``show_progress`` is ``True``.
            callback: An optional callback to use during training. Default:
                ``None``.

        Returns:
            A dictionary containing the training statistics and metadata:

                - ``total_timesteps``: The total number of timesteps trained for.
                - ``wall_time``: The total training time, in seconds.
                - ``total_episodes``: The total number of episodes completed.
                - ``episode_infos``: A list of dictionaries containing info
                    about each episode.
                - ``log_dir``: The path to the TensorBoard log directory.
                - ``eval_results``: A dictionary containing the results of the
                    evaluations, if ``eval`` is ``True``. Each key corresponds
                    to a global step at which the evaluation was performed,
                    and each value is a list of tuples containing the mean and
                    standard deviation of the episodic returns for each layer
                    of the network.
        """
        run_name = f'{self._env_spec.env_id}__{exp_name}__{self.seed}__{int(time.time())}_{self.q_net_kwargs["folding_mode"]}_{self.q_net_kwargs.get("goodness_type", "variance")}'
        if track:
            import wandb
            wandb.init(
                project=wandb_project_name,
                entity=wandb_entity,
                sync_tensorboard=True,
                config=self.config,
                name=run_name,
                monitor_gym=True,
                save_code=True,
            )

        log_dir = f'runs/{run_name}'
        writer = SummaryWriter(log_dir)
        writer.add_text(
            'hyperparameters',
            '|param|value|\n|-|-|\n%s' % ('\n'.join([
                f'|{key}|{value}|'for key, value in self.config.items()
            ])),
        )

        if save_checkpoints:
            # checkpoint_path = os.path.join(
            #     writer.logdir, 'checkpoints', 'model_{step}.pt')
            checkpoint_path = (Path(log_dir) / 'checkpoints').absolute().as_posix()

            # Save the model with the best performance on the evaluation environment every eval_frequency timesteps
            options = ocp.CheckpointManagerOptions(
                max_to_keep=3,
                best_fn=lambda eval_results: eval_results[-1][-1],
                best_mode='max',
                create=True,  # Allow creating the directory if it doesn't exist
            )

            checkpoint_manager = ocp.CheckpointManager(
                checkpoint_path,
                ocp.PyTreeCheckpointer(),
                options,
            )

        if verbose > 0:
            console.log(
                f'Training {self.__class__.__name__} on {self._env_spec.env_id}...')
            console.log('Run name:', run_name)
            console.log('Logging to:', writer.logdir)
            console.log('Logging to Weights & Biases:',
                        'Yes' if track else 'No')
            if track:
                console.log(f'\tWeights & Biases project: {wandb_project_name}')
                console.log(f'\tWeights & Biases entity: {wandb_entity}')
            console.log('Logging video:', 'Yes' if record_video else 'No')
            console.log('Total timesteps:', total_timesteps)
            console.log('Log frequency:', log_frequency)
            console.log('Hyperparameters:', self.config)

        # Seeding
        random.seed(self.seed)
        np.random.seed(self.seed)

        env = self._env_spec.make_env(
            record_video=record_video, run_log_dir=writer.logdir, seed=self.seed)

        ep_infos = deque(maxlen=log_frequency)
        start_time = time.time()
        obs, _ = env.reset(seed=self.seed)

        eval_results = []
        eval_envs = None

        def _eval(global_step: int) -> None:
            """Wrapper for conditionally evaluating the agent and mutating ``eval_results``."""
            if not eval:
                return

            # Evaluate if on the first timestep, the last timestep, or at a multiple of ``eval_frequency``
            is_first_timestep = global_step == 0
            is_last_timestep = global_step == total_timesteps - 1
            is_multiple_of_eval_frequency = eval_frequency is not None and global_step % eval_frequency == 0
            if is_first_timestep or is_last_timestep or is_multiple_of_eval_frequency:
                nonlocal eval_envs
                current_eval_results, eval_envs = self.evaluate(
                    global_step,
                    writer,
                    verbose,
                    eval_episodes,
                    eval_envs
                )
                eval_results.append(current_eval_results)

                if save_checkpoints: 
                    # Save the model if it's the best so far
                    checkpoint_manager.save(global_step, self._q_net_states, metrics=eval_results)


        show_progress = show_progress and verbose > 0
        with tqdm.trange(total_timesteps, desc='Training', disable=not show_progress) as progress:
            if callback:
                callback.init_callback(self, env)
                callback.on_training_start(
                    locals(), globals(), current_timestep=0)

            # Initialize the last activations to zeros
            last_activations = self._zero_activations

            #last_100_activations = deque(maxlen=100)
            #last_100_layer_grads = [deque(maxlen=100) for _ in range(self.num_layers)]

            for global_step in progress:
                if callback:
                    callback.on_rollout_start()

                # if save_checkpoints and checkpoint_manager.latest_step() is not None:

                # Get the current epsilon value according to a linear schedule
                epsilon = self._get_epsilon(global_step, total_timesteps)


                # Randomly decide whether to explore or exploit the environment:
                # if the random draw is below epsilon --> exploration (random action),
                # otherwise --> exploitation (greedy action from the Q-network)
                if random.random() < epsilon:
                    # Take a random action
                    actions = env.action_space.sample()
                    activations = self._zero_activations

                else:
                    # Take the best action according to the current Q-network
                    actions, activations = self.predict(obs, last_activations)

                # Execute the game and log data.
                next_obs, rewards, terminated, truncated, infos = env.step(
                    actions)

                if callback:
                    callback.update_locals(locals())
                    # If the callback returns False, this learning trial is pruned; end the training
                    if callback.on_step(global_step) is False:
                        break

                # Record rewards for plotting purposes
                if 'final_info' in infos:
                    for info in infos['final_info']:
                        # Skip the envs that are not done
                        if 'episode' not in info:
                            continue

                        episodic_return = info['episode']['r']
                        episodic_length = info['episode']['l']
                        ep_infos.append(info['episode'])

                        writer.add_scalar(
                            'charts/episodic_return', episodic_return, global_step)
                        writer.add_scalar(
                            'charts/episodic_length', episodic_length, global_step)
                        writer.add_scalar('charts/epsilon',
                                          epsilon, global_step)

                # Save data to reply buffer; handle `final_observation`
                real_next_obs = next_obs.copy()
                for i, d in enumerate(truncated):
                    if d:
                        real_next_obs[i] = infos['final_observation'][i]

                self._rb.add(obs, real_next_obs, actions,
                             rewards, terminated, last_activations, infos)
                obs = next_obs

                # If the episode is terminated or truncated, set last_activations to zeros
                if terminated or truncated:
                    last_activations = self._zero_activations
                else:
                    last_activations = activations

                # Train the Q-network
                if global_step > self.learning_starts:
                    if global_step % self.train_frequency == 0:
                        if callback:
                            callback.on_rollout_end()

                        # Perform a gradient-descent step on the sampled transitions
                        data = self._rb.sample(self.batch_size)
                        layer_losses, layer_q_values, layer_grads, layer_activations = self._opt_step(
                            data.observations,
                            data.actions,
                            data.next_observations,
                            data.rewards.flatten(),
                            data.dones.flatten(),
                            data.last_activations
                        )

                        # [last_100_layer_grads[i].append(
                        #     layer_grads[i]) for i in range(self.num_layers)]

                        if global_step % log_frequency == 0:
                            metrics = {}

                            if len(ep_infos) > 0:
                                metrics['charts/mean_episodic_return'] = np.mean([
                                    ep_info['r'] for ep_info in ep_infos]).item()
                                metrics['charts/mean_episodic_length'] = np.mean([
                                    ep_info['l'] for ep_info in ep_infos]).item()

                            for i, (loss, q_values) in enumerate(zip(layer_losses, layer_q_values)):
                                metrics[f'losses/layer_{i+1}/td_loss'] = jax.device_get(
                                    loss).item()
                                metrics[f'losses/layer_{i+1}/q_values'] = jax.device_get(
                                    q_values).mean()

                            steps_per_second = int(
                                global_step / (time.time() - start_time))
                            metrics['charts/sps'] = steps_per_second

                            metrics['charts/epsilon'] = epsilon
                            metrics['charts/learning_rate'] = self._lr_schedule(
                                global_step)

                            # metrics['activations/mean_last_100_activations'] = np.mean(
                            #     last_100_activations).item()
                            # metrics['activations/median_last_100_activations'] = np.median(
                            #     last_100_activations).item()
                            # metrics['activations/max_last_100_activations'] = np.max(
                            #     last_100_activations).item()
                            # metrics['activations/var_last_100_activations'] = np.var(
                            #     last_100_activations).item()
                            #
                            # for i, layer_grads in enumerate(last_100_layer_grads):
                            #     metrics[f'grads/mean_last_100_layer_{i+1}_grads'] = np.mean(
                            #         layer_grads).item()
                            #     metrics[f'grads/median_last_100_layer_{i+1}_grads'] = np.median(
                            #         layer_grads).item()
                            #     metrics[f'grads/max_last_100_layer_{i+1}_grads'] = np.max(
                            #         layer_grads).item()
                            #     metrics[f'grads/var_last_100_layer_{i+1}_grads'] = np.var(
                            #         layer_grads).item()

                            for k, v in metrics.items():
                                writer.add_scalar(k, v, global_step)

                            progress.set_postfix({
                                k.split('/')[-1]: str(round(v, 4))
                                for k, v in metrics.items()
                            })

                    # update target network
                    if global_step % self.target_network_frequency == 0:
                        for i, q_state in enumerate(self._q_net_states):
                            self._q_net_states[i] = q_state.replace(
                                target_params=optax.incremental_update(
                                    q_state.params,
                                    q_state.target_params,
                                    self.tau
                                )
                            )

                # Evaluate the Q-network
                _eval(global_step)

        if callback:
            callback.on_training_end()

        _eval(global_step)
        if eval_envs is not None:
            for eval_env in eval_envs:
                eval_env.close()

        # Close the env and write monitor result info to disk
        env.close()
        writer.close()
        if track:
            wandb.finish()

        return dict(
            total_timesteps=total_timesteps,
            wall_time=time.time() - start_time,
            total_episodes=len(ep_infos),
            episode_infos=ep_infos,
            log_dir=writer.logdir,
            eval_results=eval_results
        )

    def _opt_step(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        next_observations: np.ndarray,
        rewards: np.ndarray,
        dones: np.ndarray,
        last_activations: list[np.ndarray]
    ) -> tuple[list[jax.Array], list[jax.Array], list[jax.Array], list[jax.Array]]:
        """Perform an in-place optimization step on the online Q-network.

        This will perform a gradient-descent step on each layer of the online
        Q-network using local TD targets computed on a per-layer basis. This
        means that a single 'global' optimization step actually consists of
        :math:`L` layer-wise gradient-descent steps. The updated layer states
        are written back into ``self._q_net_states``.

        The sequences passed in (and the lists returned by ``forward()``) are
        not modified; padded copies are used internally.

        Args:
            observations: The observations from the replay buffer.
            actions: The actions from the replay buffer.
            next_observations: The next observations from the replay buffer.
            rewards: The rewards from the replay buffer.
            dones: The dones from the replay buffer.
            last_activations: The per-layer activations stored in the replay
                buffer, one entry per layer.

        Returns:
            A 4-tuple of per-layer lists: the TD losses, the Q-values of the
            taken actions, the kernel gradients, and the online-network
            activations returned by ``forward()`` followed by one trailing
            empty ``(batch_size, 0)`` placeholder.
        """
        # First do a pass through the network to get the activations that are used to compute the TD-targets,
        # since computing the TD targets need the backward connections that carry the activations of the last time step
        _, activations = self.forward(observations, last_activations)
        _, activations_target = self.forward(observations, last_activations, use_target_net=True)

        # Preprocess the observations and next observations into features. We do this after the first pass because
        # forward() expects raw observations.
        observations = self._get_features(observations)
        next_observations = self._get_features(next_observations)
        next_observations_target = next_observations.copy()

        # The last layer has no "next layer", so pad every per-layer list with
        # an empty placeholder to keep the `[i + 1]` lookups below uniform.
        # New lists are built (instead of `.append`) so that neither the
        # caller's `last_activations` nor the lists returned by forward() are
        # mutated as a side effect.
        no_next_layer = jnp.empty((self.batch_size, 0))
        last_activations = [*last_activations, no_next_layer]
        activations = [*activations, no_next_layer]
        activations_target = [*activations_target, no_next_layer]

        losses, q_values, grads = [], [], []
        for i, (layer, q_state) in enumerate(zip(self._q_net_layers, self._q_net_states)):
            # Each local step also returns this layer's outputs, which replace
            # `observations` / `next_observations` / `next_observations_target`
            # and thereby become the inputs of the following layer.
            (loss, old_val, q_state, observations, next_observations,
             next_observations_target, layer_grads) = self._local_opt_step(
                layer,
                q_state,
                observations,
                actions,
                next_observations,
                next_observations_target,
                rewards,
                dones,
                last_activations[i],
                last_activations[i + 1],
                activations[i],
                activations[i + 1],
                activations_target[i],
                activations_target[i + 1],
                self.recurrent_connections,
                self.backward_connections,
                self.double_q,
                self.gamma,
                self.huber_loss)

            losses.append(loss)
            q_values.append(old_val)
            grads.append(layer_grads)
            # Persist the updated optimizer/parameter state for this layer
            self._q_net_states[i] = q_state

        return losses, q_values, grads, activations

    @classmethod
    @partial(jax.jit, static_argnames=('cls', 'layer', 'double_q', 'gamma', 'huber_loss', 'recurrent_connections', 'backward_connections'))
    def _local_opt_step(
        cls,
        layer: nn.Module,
        q_state: FwdDQNTrainState,
        observations: jax.typing.ArrayLike,
        actions: jax.typing.ArrayLike,
        next_observations: jax.typing.ArrayLike,
        next_observations_target: jax.typing.ArrayLike,
        rewards: jax.typing.ArrayLike,
        dones: jax.typing.ArrayLike,
        layer_last_activations: jax.typing.ArrayLike,
        next_layer_last_activations: jax.typing.ArrayLike,
        layer_activations: jax.typing.ArrayLike,
        next_layer_activations: jax.typing.ArrayLike,
        layer_activations_target: jax.typing.ArrayLike,
        next_layer_activations_target: jax.typing.ArrayLike,
        recurrent_connections: bool,
        backward_connections: bool,
        double_q: bool = True,
        gamma: float = 0.99,
        huber_loss: bool = True
    ) -> tuple[jax.Array, jax.Array, FwdDQNTrainState, jax.Array, jax.Array, jax.Array, jax.Array]:
        """Perform a local optimization step on the given Q-network layer.

        This is a JIT-compiled helper function for the :meth:`_opt_step` method
        that performs a single gradient-descent step on the given Q-network
        layer using local TD targets.

        Args:
            layer: The Q-network layer to optimize. This is treated as a
                constant value by the JAX JIT compiler.
            q_state: The Q-network layer state.
            observations: The features for time :math:`t`. This is either the
                preprocessed observations or the hidden features from the
                previous layer.
            actions: The actions from the replay buffer.
            next_observations: The features for time :math:`t+1`. Like
                ``observations``, this is either the preprocessed observations
                or the hidden features from the previous layer.
            next_observations_target: The features for time :math:`t+1` for
                the target network. Like ``observations``, this is either the
                preprocessed observations or the hidden features from the
                previous layer.
            rewards: The rewards from the replay buffer.
            dones: The dones from the replay buffer.
            layer_last_activations: Activations of this layer that are
                concatenated to ``observations`` when ``recurrent_connections``
                is enabled.
            next_layer_last_activations: Activations of the following layer
                that are concatenated to ``observations`` when
                ``backward_connections`` is enabled.
            layer_activations: Activations of this layer that are concatenated
                to ``next_observations`` when ``recurrent_connections`` is
                enabled.
            next_layer_activations: Activations of the following layer that
                are concatenated to ``next_observations`` when
                ``backward_connections`` is enabled.
            layer_activations_target: Counterpart of ``layer_activations``
                that is concatenated to ``next_observations_target``.
            next_layer_activations_target: Counterpart of
                ``next_layer_activations`` that is concatenated to
                ``next_observations_target``.
            recurrent_connections: Whether this layer's own activations are
                part of its input. Static under JIT.
            backward_connections: Whether the following layer's activations
                are part of this layer's input. Static under JIT.
            double_q: Whether to use double Q-learning. Default: ``True``.
            gamma: The discount factor. Defaults to ``0.99``.
            huber_loss: Whether to use Huber loss instead of MSE. Default: ``True``.

        Returns:
            A 7-tuple ``(loss, q_pred, q_state, outputs, next_outputs,
            next_target_outputs, kernel_grads)``: the TD-loss, the Q-values of
            the taken actions, the updated layer state, the layer's first
            ``apply`` output for the time-:math:`t` input, for the
            time-:math:`t+1` input (online parameters), and for the
            time-:math:`t+1` input (target parameters), and the gradient of
            the loss with respect to ``params['dense']['kernel']``.
        """

        def assemble_inputs(features, same_layer_acts, next_layer_acts):
            """Concatenate the extra inputs enabled by the (static) connection flags."""
            parts = [features]
            if recurrent_connections:
                parts.append(same_layer_acts)
            if backward_connections:
                parts.append(next_layer_acts)
            if len(parts) == 1:
                # No extra connections: use the features as they are
                return features
            return jnp.concatenate(parts, axis=1)

        next_x = assemble_inputs(next_observations, layer_activations, next_layer_activations)
        next_x_target = assemble_inputs(
            next_observations_target, layer_activations_target, next_layer_activations_target)

        # The online pass at t+1 is always evaluated: its first output is
        # returned as `next_outputs` regardless of `double_q`. Previously it
        # was only computed in the double-Q branch, so regular DQN
        # (double_q=False) failed with a NameError on `next_outputs`.
        next_outputs, next_q_preds = layer.apply(q_state.params, next_x)
        next_target_outputs, next_target_q_preds = layer.apply(q_state.target_params, next_x_target)

        if double_q:
            # Double DQN: select actions for the next states using the
            # online network; a_t+1 = argmax_a Q(s_t+1, a; theta)
            next_actions = jnp.argmax(next_q_preds, axis=-1)  # (batch_size,)
            # ... and evaluate those actions using the target network
            q_next_target = jnp.take_along_axis(
                next_target_q_preds, next_actions[:, None], axis=-1)[:, 0]
        else:
            # Regular DQN: compute the Q-values in the next state using the target network
            q_next_target = jnp.max(next_target_q_preds, axis=-1)  # (batch_size,)

        # Compute the temporal difference (TD) targets using the Bellman equation
        # Note that we use a mask to zero out the Q-values for the terminal states
        next_q_value = rewards + (1 - dones) * gamma * q_next_target

        def loss_fn(params: flax.core.FrozenDict[str, Any]) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
            """The Q-learning objective function to minimize."""
            x = assemble_inputs(observations, layer_last_activations, next_layer_last_activations)

            # Compute Q(s_t, a) - the Q values for the current state
            #
            # The model computes Q(s_t), then we select the columns
            # of the actions that were actually taken in the batch
            outputs, q_all = layer.apply(params, x)  # q_all: (batch_size, num_actions)
            q_pred = jnp.take_along_axis(q_all, actions, axis=-1).squeeze()

            if huber_loss:
                # Use Huber (smooth L1) loss
                td_loss = optax.huber_loss(q_pred, next_q_value).mean()
            else:
                # Use MSE loss
                td_loss = optax.squared_error(q_pred, next_q_value).mean()

            return td_loss, (q_pred, outputs)

        # Compute the gradients of the objective function with respect to the
        # Q-network parameters for the given layer
        (loss_value, (q_pred, outputs)), grads = jax.value_and_grad(
            loss_fn, has_aux=True)(q_state.params)

        q_state = q_state.apply_gradients(grads=grads)
        return loss_value, q_pred, q_state, outputs, next_outputs, next_target_outputs, grads['params']['dense']['kernel']

    def _get_epsilon(self, step: int, total_timesteps: int) -> float:
        """Compute the exploration rate for a given training step.

        Epsilon is annealed linearly from ``start_eps`` to ``end_eps``. The
        annealing window spans ``exploration_fraction * total_timesteps``
        steps and only starts counting once ``learning_starts`` steps have
        elapsed; afterwards epsilon stays constant.

        Args:
            step: The current training step.
            total_timesteps: The total number of training steps.

        Returns:
            The epsilon value for the given training step.
        """
        # Length of the annealing window, in steps
        annealing_steps = int(self.exploration_fraction * total_timesteps)
        # Steps elapsed since learning began (clamped at zero before that)
        elapsed_steps = max(step - self.learning_starts, 0)

        return linear_schedule(self.start_eps, self.end_eps, annealing_steps, elapsed_steps)

    @property
    def config(self) -> dict[str, Any]:
        """Return the public instance attributes plus the replay buffer size.

        Attributes whose name starts with an underscore are left out; the
        size of ``self._rb`` is added under the ``'buffer_size'`` key.
        """
        public_attrs: dict[str, Any] = {}
        for attr_name, attr_value in vars(self).items():
            if attr_name.startswith('_'):
                # Private attribute: not part of the configuration
                continue
            public_attrs[attr_name] = attr_value

        public_attrs['buffer_size'] = self._rb.buffer_size
        return public_attrs

    @property
    def num_layers(self) -> int:
        """Number of entries in the Q-network's layer list."""
        layers = self._q_net_layers
        return len(layers)

