import os
import shutil
import json
import logging
import time
from collections import deque
from typing import Tuple
import tree
import numpy as np
import jax
import jax.numpy as jnp
from jax.lax import stop_gradient
from flax import serialization as flax_serialization
from flax.training.train_state import TrainState
from flax.training import orbax_utils
import orbax.checkpoint
import optax
import wandb
import pickle
from functools import partial

from irl_baselines.algorithms.trirl_trpl.flax.general_properties import GeneralProperties
from irl_baselines.algorithms.trirl_trpl.flax.policy import get_policy
from irl_baselines.algorithms.trirl_trpl.flax.critic import get_critic
from irl_baselines.algorithms.trirl_trpl.flax.discriminator import get_discriminator, get_reward_approximator, DiscriminatorBuffer
from irl_baselines.algorithms.trirl_trpl.flax.batch import Batch
from irl_baselines.algorithms.ppo.flax.batch import Batch as BatchPPO
from irl_baselines.algorithms.trirl_trpl.flax.trust_region_layer import *
from irl_baselines.algorithms.data_utils import prepare_expert_data

rlx_logger = logging.getLogger("rl_x")
# jax.config.update("jax_default_matmul_precision", "high")

class TRIRL_TRPL:
    def __init__(self, config, env, eval_env, run_path, writer) -> None:
        """Read all hyperparameters from ``config`` and build networks and optimizers.

        Args:
            config: Configuration object with ``runner``, ``environment`` and
                ``algorithm`` sections; every hyperparameter below is read from it.
            env: Training environment. Its ``single_observation_space`` and
                ``single_action_space`` provide shapes, sample inputs for
                parameter initialization and the action bounds.
            eval_env: Not used in this constructor.
            run_path: Run directory; models are saved in its ``models`` sub-directory.
            writer: Stored as ``self.writer``.

        Raises:
            ValueError: If ``evaluation_frequency`` is neither ``-1`` nor a
                multiple of ``nr_steps * nr_envs``.
            FileExistsError: If ``save_model`` is set and the ``models``
                directory already exists (``os.makedirs`` without ``exist_ok``).
        """
        self.config = config
        self.env = env
        self.writer = writer

        self.save_model = config.runner.save_model
        self.save_path = os.path.join(run_path, "models")
        self.track_console = config.runner.track_console
        self.track_tb = config.runner.track_tb
        self.track_wandb = config.runner.track_wandb
        self.seed = config.environment.seed
        self.total_timesteps = config.algorithm.total_timesteps
        self.nr_envs = config.environment.nr_envs
        self.learning_rate = config.algorithm.learning_rate
        self.anneal_learning_rate = config.algorithm.anneal_learning_rate
        self.nr_steps = config.algorithm.nr_steps
        self.nr_epochs = config.algorithm.nr_epochs
        self.minibatch_size = config.algorithm.minibatch_size
        self.gamma = config.algorithm.gamma
        self.gae_lambda = config.algorithm.gae_lambda
        self.clip_range = config.algorithm.clip_range
        self.entropy_coef = config.algorithm.entropy_coef
        self.critic_coef = config.algorithm.critic_coef
        self.max_grad_norm = config.algorithm.max_grad_norm
        self.std_dev = config.algorithm.std_dev
        self.nr_hidden_units = config.algorithm.nr_hidden_units
        self.evaluation_frequency = config.algorithm.evaluation_frequency
        self.evaluation_episodes = config.algorithm.evaluation_episodes
        # One rollout = nr_steps transitions from each of the nr_envs environments
        self.batch_size = config.environment.nr_envs * config.algorithm.nr_steps
        self.nr_updates = config.algorithm.total_timesteps // self.batch_size
        self.nr_minibatches = self.batch_size // self.minibatch_size
        self.subsampling_cutoff = config.algorithm.get("subsampling_cutoff", 1)

        # Global Reward Experiment Flag
        self.global_rew_experiment = config.algorithm.global_rew_experiment

        # Projection layer
        self.mean_bound = config.algorithm.mean_bound
        self.cov_bound = config.algorithm.cov_bound
        self.trust_region_coef = config.algorithm.trust_region_coef

        # TRIRL Specific
        self.data_path = config.algorithm.data_path
        self.nr_epochs_disc = config.algorithm.nr_epochs_disc
        self.learning_rate_disc = config.algorithm.learning_rate_disc
        self.env_reward_frac = config.algorithm.env_reward_frac
        self.handle_absorbing_states = config.algorithm.handle_absorbing_states
        self.epsilon = config.algorithm.epsilon
        self.disc_buffer_capacity = config.algorithm.disc_buffer_capacity
        self.gp_lambda = config.algorithm.gp_lambda
        self.gp_alpha = config.algorithm.gp_alpha
        self.beta = 1/config.algorithm.entropy_coef
        self.disc_buffer = DiscriminatorBuffer(self.disc_buffer_capacity, (self.nr_steps, self.nr_envs))
        self.chunk_size_dict = {10:30, 50:8, 100:4} # dict mapping nr_steps to chunk size (in the case where eta's are not on demand)
        self.reward_type = config.algorithm.reward_type
        self.reward_approximator_type = config.algorithm.reward_approximator_type
        self.save_delay = 50

        # TRIRL - etas on demand
        self.on_demand_etas = config.algorithm.on_demand_etas
        if self.on_demand_etas:
            # Smaller chunks when etas are computed on demand; the entry for 10 steps is set explicitly
            self.chunk_size_dict = {k: int(v/2 - 1) for k, v in self.chunk_size_dict.items()}
            self.chunk_size_dict[10] = 20
            self.maximum_eta = False # always use per state eta if computing etas on demand
        else:
            self.maximum_eta = True

        # TRIRL - Reward Approximation
        self.reward_fn_approximator = config.algorithm.reward_fn_approximator
        self.nr_epochs_rew = config.algorithm.nr_epochs_rew
        self.learning_rate_reward_fn = config.algorithm.learning_rate_reward_fn

        if self.evaluation_frequency % (self.nr_steps * self.nr_envs) != 0 and self.evaluation_frequency != -1:
            raise ValueError("Evaluation frequency must be a multiple of the number of steps and environments.")

        rlx_logger.info(f"Using device: {jax.default_backend()}")
        
        self.key = jax.random.PRNGKey(self.seed)
        self.key, policy_key, critic_key, discriminator_key = jax.random.split(self.key, 4)

        self.os_shape = env.single_observation_space.shape
        self.as_shape = env.single_action_space.shape
        self.dim = np.prod(self.as_shape).item()
        
        self.policy, self.get_processed_action = get_policy(config, env)
        self.critic = get_critic(config, env)
        self.discriminator = get_discriminator(config, env, reward_type=self.reward_type)

        if self.reward_fn_approximator:
            self.reward_fn = get_reward_approximator(config, env, reward_approximator_type=self.reward_approximator_type)

        self.policy.apply = jax.jit(self.policy.apply)
        self.critic.apply = jax.jit(self.critic.apply)
        self.discriminator.apply = jax.jit(self.discriminator.apply)
        # Sum over action dimensions of log(high - low); used as the entropy term of absorbing states
        self.H_terminal = jnp.sum(jnp.log(env.single_action_space.high - env.single_action_space.low))

        def linear_schedule(count):
            # count is the optimizer step; integer division turns it into the index of the rollout update
            fraction = 1.0 - (count // (self.nr_minibatches * self.nr_epochs)) / self.nr_updates
            return self.learning_rate * fraction

        def linear_schedule_disc(count):
            fraction = 1.0 - (count // (self.nr_minibatches * self.nr_epochs_disc)) / ((self.nr_updates * self.nr_epochs) / self.nr_epochs_disc)
            # Anneal the discriminator's own base rate so that enabling annealing does not
            # silently swap it for the policy learning rate.
            return self.learning_rate_disc * fraction

        learning_rate = linear_schedule if self.anneal_learning_rate else self.learning_rate
        learning_rate_disc = linear_schedule_disc if self.anneal_learning_rate else self.learning_rate_disc

        # Dummy inputs with a leading batch axis of 1, only used to initialize network parameters
        state = jnp.array([env.single_observation_space.sample()])
        action = jnp.array([env.single_action_space.sample()])
        next_state = jnp.array([env.single_observation_space.sample()])
        absorbing = jnp.array([0.0])

        self.policy_state = TrainState.create(
            apply_fn=self.policy.apply,
            params=self.policy.init(policy_key, state),
            tx=optax.chain(
                optax.clip_by_global_norm(self.max_grad_norm),
                optax.inject_hyperparams(optax.adam)(learning_rate=learning_rate),
            )
        )

        self.critic_state = TrainState.create(
            apply_fn=self.critic.apply,
            params=self.critic.init(critic_key, state),
            tx=optax.chain(
                optax.clip_by_global_norm(self.max_grad_norm),
                optax.inject_hyperparams(optax.adam)(learning_rate=learning_rate),
            )
        )

        self.discriminator_state = TrainState.create(
            apply_fn=self.discriminator.apply,
            params=self.discriminator.init(discriminator_key, state, action, next_state, absorbing),
            tx=optax.chain(
                optax.clip_by_global_norm(self.max_grad_norm),
                optax.inject_hyperparams(optax.adam)(learning_rate=learning_rate_disc),
            )
        )

        if self.reward_fn_approximator:
            # The reward approximator always uses a constant learning rate (no annealing)
            learning_rate_reward_fn = self.learning_rate_reward_fn
            self.key, reward_fn_key = jax.random.split(self.key)
            self.reward_fn_state = TrainState.create(
                apply_fn=self.reward_fn.apply,
                params=self.reward_fn.init(reward_fn_key, state, action, next_state),
                tx=optax.chain(
                    optax.clip_by_global_norm(self.max_grad_norm),
                    optax.inject_hyperparams(optax.adam)(learning_rate=learning_rate_reward_fn),
                )
            )

        if self.on_demand_etas:
            # Save init policy weights
            self.disc_buffer.append_policy(self.policy_state.params)

        if self.save_model:
            os.makedirs(self.save_path)
            self.best_mean_return = -np.inf
            self.best_model_file_name = "best.model"
            self.best_model_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    
    def train_irl(self):    
        @jax.jit
        def get_action_and_value(policy_state: TrainState, critic_state: TrainState, state: np.ndarray, key: jax.random.PRNGKey):
            """Sample an action from the diagonal Gaussian policy and evaluate the critic.

            Returns, in order: the processed action, the raw sampled action,
            the flattened critic value, the log-probability summed over axis 1,
            the advanced PRNG key, the action mean and the action log-std.
            """
            key, noise_key = jax.random.split(key)

            action_mean, action_logstd = self.policy.apply(policy_state.params, state)
            action_std = jnp.exp(action_logstd)

            # Reparameterized sample: mean + std * standard normal noise
            noise = jax.random.normal(noise_key, shape=action_mean.shape)
            action = action_mean + action_std * noise

            # Per-dimension Gaussian log-density of the sampled action
            standardized = (action - action_mean) / action_std
            log_prob = -0.5 * standardized ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd

            value = self.critic.apply(critic_state.params, state)
            return (
                self.get_processed_action(action),
                action,
                value.reshape(-1),
                log_prob.sum(1),
                key,
                action_mean,
                action_logstd,
            )

        @jax.jit
        def get_log_density_ratio(discriminator_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray, absorbing: np.ndarray):
            """Evaluate the discriminator on the given transition inputs and return its raw output."""
            return self.discriminator.apply(discriminator_params, state, action, next_state, absorbing)
        
        get_log_density_ratio = jax.vmap(get_log_density_ratio, in_axes=(None, 0, 0, 0, 0), out_axes=0)

        @jax.jit
        def get_reward_prediction(reward_fn_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray):
            """Evaluate the reward approximator on the given transition inputs and return its raw output."""
            return self.reward_fn.apply(reward_fn_params, state, action, next_state)
        
        get_reward_prediction = jax.vmap(get_reward_prediction, in_axes=(None, 0, 0, 0), out_axes=0)

        @jax.jit
        def calculate_gae_advantages(critic_state: TrainState, next_states: np.ndarray, rewards: np.ndarray, terminations: np.ndarray, values: np.ndarray):
            """Compute GAE advantages and returns for one rollout.

            The TD residuals are formed for all steps at once; the advantage
            recursion then walks backwards from the last step, which is seeded
            with its own TD residual. Returns ``(advantages, returns)`` with
            ``returns = advantages + values``.
            """
            bootstrap_values = self.critic.apply(critic_state.params, next_states).squeeze(-1)
            delta = rewards + self.gamma * bootstrap_values * (1.0 - terminations) - values

            def backward_step(running_advantage, t):
                # A termination at step t cuts the dependence on later advantages
                advantage = delta[t] + self.gamma * self.gae_lambda * (1 - terminations[t]) * running_advantage
                return advantage, advantage

            last_advantage = delta[-1]
            reversed_steps = jnp.arange(self.nr_steps - 2, -1, -1)
            _, earlier_advantages = jax.lax.scan(backward_step, last_advantage, reversed_steps)

            # The scan output is in reverse time order; flip it and append the seed for the last step
            advantages = jnp.concatenate([earlier_advantages[::-1], jnp.array([last_advantage])])
            return advantages, advantages + values

        @jax.jit
        def calculate_gae_advantages_absorbing(critic_state: TrainState, next_states: np.ndarray, rewards: np.ndarray, rewards_next_state: np.ndarray, terminations: np.ndarray, values: np.ndarray):
            """
            GAE variant that gives absorbing states a value and entropy term (instead of setting them to 0.0).

            On terminated steps the critic bootstrap is replaced by a closed-form tail built
            from the next-state reward and the terminal entropy ``self.H_terminal``.
            Returns ``(advantages, returns)`` with ``returns = advantages + values``.
            """
            bootstrap_values = self.critic.apply(critic_state.params, next_states).squeeze(-1)

            # gamma / (1 - gamma) is the closed form of gamma + gamma^2 + ...
            discount_sum = self.gamma / (1.0 - self.gamma)
            terminal_tail = discount_sum * (rewards_next_state + self.entropy_coef * self.H_terminal)

            delta = rewards + self.gamma * bootstrap_values * (1.0 - terminations) + (terminations * terminal_tail) - values

            def backward_step(running_advantage, t):
                # A termination at step t cuts the dependence on later advantages
                advantage = delta[t] + self.gamma * self.gae_lambda * (1 - terminations[t]) * running_advantage
                return advantage, advantage

            last_advantage = delta[-1]
            reversed_steps = jnp.arange(self.nr_steps - 2, -1, -1)
            _, earlier_advantages = jax.lax.scan(backward_step, last_advantage, reversed_steps)

            # The scan output is in reverse time order; flip it and append the seed for the last step
            advantages = jnp.concatenate([earlier_advantages[::-1], jnp.array([last_advantage])])
            return advantages, advantages + values


        """
        TRIRL Update
        """
        @partial(jax.jit, static_argnames=('reward_type'))
        def trirl_update(discriminator_state: TrainState,
                   states: np.ndarray, actions: np.ndarray,
                   next_states: np.ndarray, absorbing: np.ndarray,
                   expert_states: np.ndarray, expert_actions: np.ndarray,
                   expert_next_states: np.ndarray, expert_absorbing: np.ndarray,
                   key: jax.random.PRNGKey, reward_type='state-action'):
            """Train the discriminator on rollout versus expert transitions.

            Runs ``nr_epochs_disc`` epochs of shuffled minibatch updates. The loss is
            binary cross-entropy (rollout label 0, expert label 1) plus a gradient
            penalty evaluated on an interpolation between the expert and rollout sample.

            Args:
                discriminator_state: Train state holding discriminator parameters and optimizer.
                states, actions, next_states, absorbing: Rollout data; flattened to
                    ``batch_size`` transitions.
                expert_states, expert_actions, expert_next_states, expert_absorbing:
                    Expert transitions indexed along axis 0. They are shuffled with a
                    common permutation and the first ``batch_size`` entries are used.
                key: PRNG key; split for the expert shuffle and the minibatch shuffle.
                reward_type: Static argument selecting which input gradients enter the
                    gradient penalty.

            Returns:
                ``(discriminator_state, mean_metrics, key)`` with the updated train
                state, metrics averaged over all minibatch updates and the advanced key.

            Raises:
                ValueError: At trace time, if ``reward_type`` is not one of the supported names.
            """

            def trirl_loss_fn(discriminator_params, state, action, expert_state, expert_action, label, expert_label, next_state=None, absorbing=None, expert_next_state=None, expert_absorbing=None):
                logits = self.discriminator.apply(discriminator_params, state, action, next_state, absorbing)
                expert_logits = self.discriminator.apply(discriminator_params, expert_state, expert_action, expert_next_state, expert_absorbing)
                bce_agent = optax.sigmoid_binary_cross_entropy(logits, label).mean()
                bce_expert = optax.sigmoid_binary_cross_entropy(expert_logits, expert_label).mean()

                # Gradient penalty on an interpolated sample between the expert and agent
                alpha = self.gp_alpha
                gp_lambda = self.gp_lambda
                interpolated_state = alpha * expert_state + (1 - alpha) * state
                interpolated_action = alpha * expert_action + (1 - alpha) * action
                interpolated_next_state = alpha * expert_next_state + (1 - alpha) * next_state
                interpolated_abs = 0.0 * expert_absorbing # assume interpolated state to be non-absorbing                
                grad_state, grad_action, grad_next_state = jax.grad(lambda s, a, sn, ab: jnp.sum(self.discriminator.apply(discriminator_params, s, a, sn, ab)), argnums=(0, 1, 2))(interpolated_state, interpolated_action, interpolated_next_state, interpolated_abs)                
                # Only the inputs selected by the reward type contribute to the penalised norm
                if reward_type in ["state-action", "uncorrelated"]:
                    grad_norm = jnp.sqrt(jnp.sum(jnp.square(grad_state)) + jnp.sum(jnp.square(grad_action)))
                elif reward_type in ["state-based", "shaped"]:
                    grad_norm = jnp.sqrt(jnp.sum(jnp.square(grad_state)) + jnp.sum(jnp.square(grad_next_state)))
                elif reward_type in ["shaped-sa"]:
                    grad_norm = jnp.sqrt(jnp.sum(jnp.square(grad_state)) + jnp.sum(jnp.square(grad_action)) + jnp.sum(jnp.square(grad_next_state)))
                else:
                    # reward_type is static, so this fails at trace time with a clear message
                    # instead of an unbound-variable error on grad_norm below.
                    raise ValueError(f"Unsupported reward_type: {reward_type!r}")
                gp = (grad_norm - 1.0) ** 2

                bce_loss = bce_agent + bce_expert + gp_lambda * gp
                metrics = {
                    "loss/discriminator_loss": bce_loss,
                    "loss/discriminator_agent_loss": bce_agent,
                    "loss/discriminator_expert_loss": bce_expert,
                    "loss/discriminator_gp": gp
                }
                return bce_loss, (metrics)


            batch_states = states.reshape((-1,) + self.os_shape)
            batch_actions = actions.reshape((-1,) + self.as_shape)

            # Expert batch
            # NOTE(review): this assumes at least batch_size expert transitions are available;
            # with fewer, the slices below are shorter than the minibatch indices expect — confirm.
            key, shuffle_key = jax.random.split(key)
            perm = jax.random.permutation(shuffle_key, expert_states.shape[0])
            expert_states = expert_states[perm]
            expert_actions = expert_actions[perm]
            batch_expert_states = expert_states[:self.batch_size]
            batch_expert_actions = expert_actions[:self.batch_size]
            expert_labels = jnp.ones((self.batch_size, 1), dtype=jnp.float32)
            rollout_labels = jnp.zeros((self.batch_size, 1), dtype=jnp.float32)

            batch_next_states = next_states.reshape((-1,) + self.os_shape)
            batch_absorbing = absorbing.reshape(-1)
            expert_next_states = expert_next_states[perm]
            expert_absorbing = expert_absorbing[perm]
            batch_expert_next_states = expert_next_states[:self.batch_size]
            batch_expert_absorbing = expert_absorbing[:self.batch_size]

            # The loss is written per sample; vmap it over the minibatch and average loss and metrics
            vmap_trirl_loss_fn = jax.vmap(trirl_loss_fn, in_axes=(None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), out_axes=0)
            safe_mean = lambda x: jnp.mean(x) if x is not None else x
            mean_vmapped_trirl_loss_fn = lambda *a, **k: tree.map_structure(safe_mean, vmap_trirl_loss_fn(*a, **k))
            grad_trirl_loss_fn = jax.value_and_grad(mean_vmapped_trirl_loss_fn, argnums=(0), has_aux=True)

            # One independently shuffled index row per epoch, cut into minibatches
            key, subkey = jax.random.split(key)
            batch_indices_disc = jnp.tile(jnp.arange(self.batch_size), (self.nr_epochs_disc, 1))
            batch_indices_disc = jax.random.permutation(subkey, batch_indices_disc, axis=1, independent=True)
            batch_indices_disc = batch_indices_disc.reshape((self.nr_epochs_disc * self.nr_minibatches, self.minibatch_size))

            def minibatch_update(carry, minibatch_indices_disc):

                discriminator_state = carry

                # TRIRL UPDATE
                (trirl_loss, (metrics)), (discriminator_gradients) = grad_trirl_loss_fn(
                    discriminator_state.params,
                    batch_states[minibatch_indices_disc],
                    batch_actions[minibatch_indices_disc],
                    batch_expert_states[minibatch_indices_disc],
                    batch_expert_actions[minibatch_indices_disc],
                    rollout_labels[minibatch_indices_disc],
                    expert_labels[minibatch_indices_disc],
                    batch_next_states[minibatch_indices_disc],
                    batch_absorbing[minibatch_indices_disc],
                    batch_expert_next_states[minibatch_indices_disc],
                    batch_expert_absorbing[minibatch_indices_disc],
                )

                discriminator_state = discriminator_state.apply_gradients(grads=discriminator_gradients)
                carry = (discriminator_state)

                return carry, (metrics)
            
            init_carry = (discriminator_state)
            carry, (metrics) = jax.lax.scan(minibatch_update, init_carry, batch_indices_disc)
            discriminator_state = carry

            # Calculate mean metrics (loop variable deliberately not called `key`, which is the PRNG key)
            mean_metrics = {name: jnp.mean(metrics[name]) for name in metrics}
            mean_metrics["lr/learning_rate_disc"] = discriminator_state.opt_state[1].hyperparams["learning_rate"]

            return discriminator_state, mean_metrics, key


        @jax.jit
        def update(policy_state: TrainState, critic_state: TrainState,
                   states: np.ndarray, actions: np.ndarray, advantages: np.ndarray, returns: np.ndarray, values: np.ndarray, log_probs: np.ndarray,
                   key: jax.random.PRNGKey,
                   old_action_means:np.ndarray, old_action_logstd:np.ndarray, global_step):
            """Run the policy and critic update for one rollout.

            Performs ``nr_epochs`` epochs of shuffled minibatch updates. The policy
            output is passed through the KL projection and an entropy projection; the
            loss combines the clipped surrogate objective, an entropy bonus, the critic
            regression loss and a trust-region regression term that pulls the raw policy
            output towards its projection.

            Args:
                policy_state, critic_state: Train states to update.
                states, actions, advantages, returns, values, log_probs: Rollout data;
                    flattened to ``batch_size`` samples.
                key: PRNG key; split once for the minibatch shuffle.
                old_action_means: Action means of the behaviour policy, one per sample.
                old_action_logstd: Log-std of the behaviour policy, shared by all samples.
                global_step: Current environment step; drives the entropy-bound schedule.

            Returns:
                ``(policy_state, critic_state, mean_metrics, etas, key)``. ``etas`` has
                shape ``(nr_steps, nr_envs)`` and holds, for every rollout sample, the
                maximum over epochs of ``max(eta_mu, eta_cov)``.
            """
            
            def loss_fn(policy_params, critic_params, state_b, action_b, log_prob_b, return_b, advantage_b, old_mean_b, old_logstd_b, global_step):
                # Policy eval
                action_mean, action_logstd = self.policy.apply(policy_params, state_b)

                # Entropy projection (project std to satisfy an entropy constraint)
                old_logstd_b = jnp.expand_dims(old_logstd_b, 0) # shape (1, as_shape)
                old_entropy = jnp.sum(old_logstd_b + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e))
                tau = 0.5
                # The bound decays from old_entropy by a factor tau ** 10 over training.
                # jnp.minimum caps the step exactly like an upper-bound-only clip, without the
                # `a_max` keyword that newer JAX releases deprecate.
                beta = old_entropy * tau ** (10 * jnp.minimum(global_step, self.total_timesteps) / self.total_timesteps)
                # Entropy projection before KL projection
                # action_logstd = entropy_projection(action_logstd, beta, dim=self.dim)

                # Projection layer
                action_std = jnp.exp(action_logstd)
                old_action_mean = old_mean_b
                old_action_std = jnp.exp(old_logstd_b)

                proj_action_mean, proj_action_std, eta_mu, eta_cov, kl_mean_part, post_proj_kl_mean_part, kl_cov_part, post_proj_kl_cov_part = kl_projection(action_mean, action_std, old_action_mean, old_action_std, self.mean_bound, self.cov_bound)
                # Entropy projection after KL projection
                proj_action_logstd = jnp.log(proj_action_std)
                proj_action_logstd = entropy_projection(proj_action_logstd, beta, dim=self.dim)

                # Trust Region Loss (amortized optimization)
                proj_action_mean_det = jax.lax.stop_gradient(proj_action_mean) # detaching to prevent gradient flow through projection from the regression term
                proj_action_std_det  = jax.lax.stop_gradient(proj_action_std)
                tr_loss_maha = 0.5 * jnp.sum(((proj_action_mean_det - action_mean)/ proj_action_std_det) ** 2, axis=1)
                tr_loss_cov_part = 0.5 * jnp.sum(2.0 * (jnp.log(proj_action_std_det) - jnp.log(action_std)) + (action_std/proj_action_std_det)**2 - 1.0, axis=1)
                trust_region_loss = tr_loss_maha + tr_loss_cov_part

                # From here on the projected distribution is the policy being optimized
                action_mean = proj_action_mean
                action_logstd = proj_action_logstd
                action_std = jnp.exp(action_logstd)

                new_log_prob = -0.5 * ((action_b - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
                new_log_prob = new_log_prob.sum(1)
                entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e)
                entropy = self.entropy_coef * entropy # scaling to reduce critic loss
                
                logratio = new_log_prob - log_prob_b
                ratio = jnp.exp(logratio)
                approx_kl_div = (ratio - 1) - logratio
                clip_fraction = jnp.float32((jnp.abs(ratio - 1) > self.clip_range))

                # Clipped surrogate objective
                pg_loss1 = -advantage_b * ratio
                pg_loss2 = -advantage_b * jnp.clip(ratio, 1 - self.clip_range, 1 + self.clip_range)
                pg_loss = jnp.maximum(pg_loss1, pg_loss2)
                
                entropy_loss = entropy.sum(1)
                
                # Critic loss
                new_value = self.critic.apply(critic_params, state_b) 
                # return_b = return_b * self.entropy_coef
                critic_loss = 0.5 * (new_value - return_b) ** 2

                # Combine losses
                # NOTE(review): entropy_loss is already scaled by entropy_coef above, so the
                # coefficient is applied twice here — confirm this is intended.
                loss = pg_loss - self.entropy_coef * entropy_loss + self.critic_coef * critic_loss + self.trust_region_coef * trust_region_loss

                # Create metrics
                metrics = {
                    "loss/policy_gradient_loss": pg_loss,
                    "loss/critic_loss": critic_loss,
                    "loss/entropy_loss": entropy_loss,
                    "loss/trust_region_loss": trust_region_loss,
                    "policy_ratio/approx_kl": approx_kl_div,
                    "policy_ratio/clip_fraction": clip_fraction,
                    "projection/eta_mu": eta_mu,
                    "projection/eta_cov": eta_cov,
                    "projection/unprojected_kl_mean": kl_mean_part,
                    "projection/projected_kl_mean": post_proj_kl_mean_part,
                    "projection/unprojected_kl_cov": kl_cov_part,
                    "projection/projected_kl_cov": post_proj_kl_cov_part,
                }

                return loss, (metrics, jnp.maximum(eta_mu, eta_cov))
            
            # Gradients must only flow through the predicted ones
            old_action_means = stop_gradient(old_action_means)
            old_action_logstd = stop_gradient(old_action_logstd)

            batch_states = states.reshape((-1,) + self.os_shape)
            batch_actions = actions.reshape((-1,) + self.as_shape)
            batch_advantages = advantages.reshape(-1)
            batch_returns = returns.reshape(-1)
            batch_log_probs = log_probs.reshape(-1)
            batch_action_means = old_action_means.reshape((-1,) + self.as_shape)

            # The loss is written per sample; old_logstd_b and global_step are shared across the minibatch.
            # Loss and metrics are averaged, while the etas are kept per sample.
            vmap_loss_fn = jax.vmap(loss_fn, in_axes=(None, None, 0, 0, 0, 0, 0, 0, None, None), out_axes=0)
            safe_mean = lambda x: jnp.mean(x) if x is not None else x
            mean_vmapped_loss_fn = lambda *a, **k: (lambda out: (tree.map_structure(safe_mean, out[0]), (tree.map_structure(safe_mean, out[1][0]), out[1][1])))(vmap_loss_fn(*a, **k))

            grad_loss_fn = jax.value_and_grad(mean_vmapped_loss_fn, argnums=(0, 1), has_aux=True)

            # One independently shuffled index row per epoch, cut into minibatches
            key, subkey = jax.random.split(key)
            batch_indices = jnp.tile(jnp.arange(self.batch_size), (self.nr_epochs, 1))
            batch_indices = jax.random.permutation(subkey, batch_indices, axis=1, independent=True)
            batch_indices = batch_indices.reshape((self.nr_epochs * self.nr_minibatches, self.minibatch_size))

            def minibatch_update(carry, minibatch_indices):
                policy_state, critic_state = carry

                # Advantages are normalized per minibatch
                minibatch_advantages = batch_advantages[minibatch_indices]
                minibatch_advantages = (minibatch_advantages - jnp.mean(minibatch_advantages)) / (jnp.std(minibatch_advantages) + 1e-8)

                (loss, (metrics, etas)), (policy_gradients, critic_gradients) = grad_loss_fn(
                    policy_state.params,
                    critic_state.params,
                    batch_states[minibatch_indices],
                    batch_actions[minibatch_indices],
                    batch_log_probs[minibatch_indices],
                    batch_returns[minibatch_indices],
                    minibatch_advantages,
                    batch_action_means[minibatch_indices],
                    old_action_logstd, global_step
                )

                policy_state = policy_state.apply_gradients(grads=policy_gradients)
                critic_state = critic_state.apply_gradients(grads=critic_gradients)

                metrics["gradients/policy_grad_norm"] = optax.global_norm(policy_gradients)
                metrics["gradients/critic_grad_norm"] = optax.global_norm(critic_gradients)

                carry = (policy_state, critic_state)

                return carry, (metrics, etas)
            
            init_carry = (policy_state, critic_state)
            carry, (metrics, etas) = jax.lax.scan(minibatch_update, init_carry, batch_indices)
            policy_state, critic_state = carry

            # Calculate mean metrics (loop variable deliberately not called `key`, which is the PRNG key)
            mean_metrics = {name: jnp.mean(metrics[name]) for name in metrics}
            mean_metrics["lr/learning_rate"] = policy_state.opt_state[1].hyperparams["learning_rate"]
            mean_metrics["v_value/explained_variance"] = 1 - jnp.var(returns - values) / (jnp.var(returns) + 1e-8)
            mean_metrics["policy/std_dev"] = jnp.mean(jnp.exp(policy_state.params["params"]["policy_logstd"]))

            # Calculate max per-epoch eta
            # The scan emits the etas in each epoch's shuffled minibatch order. Undo the
            # permutation of every epoch first, so that position (step, env) refers to the
            # same rollout sample in all epochs before the maximum is taken.
            epoch_indices = batch_indices.reshape((self.nr_epochs, self.batch_size))
            etas = etas.reshape((self.nr_epochs, self.batch_size))
            etas = jnp.take_along_axis(etas, jnp.argsort(epoch_indices, axis=1), axis=1)
            etas = etas.reshape((self.nr_epochs, self.nr_steps, self.nr_envs))
            etas = etas.max(axis=0)

            return policy_state, critic_state, mean_metrics, etas, key


        """
        Reward Approximator Update
        """

        @jax.jit
        def reward_fn_update(reward_fn_state: TrainState,
                   states: np.ndarray, actions: np.ndarray, next_states: np.ndarray,
                   target_reward: np.ndarray,
                   key: jax.random.PRNGKey):
            """Regress the reward approximator onto ``target_reward`` with a squared-error loss.

            Runs ``nr_epochs_rew`` epochs of shuffled minibatch updates over the
            flattened rollout.

            Args:
                reward_fn_state: Train state of the reward approximator.
                states, actions, next_states: Rollout data; flattened to ``batch_size`` samples.
                target_reward: Regression targets; flattened to shape ``(batch_size, 1)``.
                key: PRNG key; split once for the minibatch shuffle.

            Returns:
                ``(reward_fn_state, mean_metrics, key)`` with the updated train state,
                metrics averaged over all minibatch updates and the advanced key.
            """

            def loss_fn(reward_fn_params, state, action, target_reward, next_state=None):
                logits = self.reward_fn.apply(reward_fn_params, state, action, next_state)
                mse = optax.squared_error(logits, target_reward).mean()

                metrics = {
                    "loss/reward_fn_loss": mse,
                }
                return mse, (metrics)


            batch_states = states.reshape((-1,) + self.os_shape)
            batch_actions = actions.reshape((-1,) + self.as_shape)
            batch_target_reward = target_reward.reshape((-1,) + (1,))

            batch_next_states = next_states.reshape((-1,) + self.os_shape)
            # The loss is written per sample; vmap it over the minibatch and average loss and metrics
            vmap_loss_fn = jax.vmap(loss_fn, in_axes=(None, 0, 0, 0, 0), out_axes=0)
            safe_mean = lambda x: jnp.mean(x) if x is not None else x
            mean_vmapped_loss_fn = lambda *a, **k: tree.map_structure(safe_mean, vmap_loss_fn(*a, **k))
            grad_loss_fn = jax.value_and_grad(mean_vmapped_loss_fn, argnums=(0), has_aux=True)

            # One independently shuffled index row per epoch, cut into minibatches
            key, subkey = jax.random.split(key)
            batch_indices_rew = jnp.tile(jnp.arange(self.batch_size), (self.nr_epochs_rew, 1))
            batch_indices_rew = jax.random.permutation(subkey, batch_indices_rew, axis=1, independent=True)
            batch_indices_rew = batch_indices_rew.reshape((self.nr_epochs_rew * self.nr_minibatches, self.minibatch_size))

            def minibatch_update(carry, minibatch_indices_rew):

                reward_fn_state = carry

                # REWARD UPDATE
                (rew_fn_loss, (metrics)), (rew_fn_gradients) = grad_loss_fn(
                    reward_fn_state.params,
                    batch_states[minibatch_indices_rew],
                    batch_actions[minibatch_indices_rew],
                    batch_target_reward[minibatch_indices_rew],
                    batch_next_states[minibatch_indices_rew],
                )

                reward_fn_state = reward_fn_state.apply_gradients(grads=rew_fn_gradients)
                carry = (reward_fn_state)

                return carry, (metrics)
            
            init_carry = (reward_fn_state)
            carry, (metrics) = jax.lax.scan(minibatch_update, init_carry, batch_indices_rew)
            reward_fn_state = carry

            # Calculate mean metrics (loop variable deliberately not called `key`, which is the PRNG key)
            mean_metrics = {name: jnp.mean(metrics[name]) for name in metrics}
            # Logged under its own name so it cannot overwrite the discriminator's learning-rate metric
            mean_metrics["lr/learning_rate_reward_fn"] = reward_fn_state.opt_state[1].hyperparams["learning_rate"]

            return reward_fn_state, mean_metrics, key

        @jax.jit
        def get_deterministic_action(policy_state: TrainState, state: np.ndarray):
            """Return the processed policy mean for ``state``; the predicted log-std is ignored."""
            mean_action, _ = self.policy.apply(policy_state.params, state)
            return self.get_processed_action(mean_action)

        chunked_project = make_chunked_ensemble_rew_project(
            get_log_density_ratio, 
            nr_steps=self.nr_steps, 
            nr_envs=self.nr_envs, 
            epsilon=self.epsilon, 
            beta=self.beta,
            entropy_coef=self.entropy_coef,
            maximum_eta=self.maximum_eta
        )

        chunked_compute_etas = make_chunked_compute_etas(
            policy_apply_fn=lambda p, s: self.policy.apply(p, s),
            kl_cov_proj=kl_cov_proj,
            nr_steps=self.nr_steps,
            nr_envs=self.nr_envs,
            mean_bound=self.mean_bound,
            cov_bound=self.cov_bound,
        )


        """ Main Training Loop """

        self.set_train_mode()

        # Get demonstrations
        demonstrations = prepare_expert_data(self.data_path, self.subsampling_cutoff)
        self.expert_states = demonstrations["states"]
        self.expert_next_states = demonstrations["next_states"]
        self.expert_actions = demonstrations["actions"]
        self.expert_absorbing = demonstrations["absorbing"].flatten()

        batch = Batch(
            states=np.zeros((self.nr_steps, self.nr_envs) + self.os_shape),
            next_states=np.zeros((self.nr_steps, self.nr_envs) + self.os_shape),
            actions=np.zeros((self.nr_steps, self.nr_envs) + self.as_shape),
            rewards=np.zeros((self.nr_steps, self.nr_envs)),
            proj_rewards=np.zeros((self.nr_steps, self.nr_envs)),
            merged_rewards=np.zeros((self.nr_steps, self.nr_envs)),
            etas = np.zeros((self.nr_steps, self.nr_envs)),
            values=np.zeros((self.nr_steps, self.nr_envs)),
            terminations=np.zeros((self.nr_steps, self.nr_envs)),
            log_probs=np.zeros((self.nr_steps, self.nr_envs)),
            advantages=np.zeros((self.nr_steps, self.nr_envs)),
            returns=np.zeros((self.nr_steps, self.nr_envs)),
            old_action_means=np.zeros((self.nr_steps, self.nr_envs) + self.as_shape),
            old_action_logstd=np.zeros(self.as_shape),
        )

        saving_return_buffer = deque(maxlen=100 * self.nr_envs)

        state, _ = self.env.reset()
        global_step = 0
        nr_updates = 0
        nr_updates_disc = 0
        nr_episodes = 0
        steps_metrics = {}
        while global_step < self.total_timesteps:
            start_time = time.time()
            time_metrics = {}

            # Acting
            dones_this_rollout = 0
            step_info_collection = {}
            for step in range(self.nr_steps):
                processed_action, action, value, log_prob, self.key, action_mean, action_logstd = get_action_and_value(self.policy_state, self.critic_state, state, self.key)
                next_state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                done = terminated | truncated
                actual_next_state = next_state.copy()
                for i, single_done in enumerate(done):
                    if single_done:
                        actual_next_state[i] = np.array(self.env.get_final_observation_at_index(info, i))
                        saving_return_buffer.append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                        dones_this_rollout += 1
                for key, info_value in self.env.get_logging_info_dict(info).items():
                    step_info_collection.setdefault(key, []).extend(info_value)

                batch.states[step] = state
                batch.next_states[step] = actual_next_state
                batch.actions[step] = action
                batch.rewards[step] = reward
                batch.values[step] = value
                batch.terminations[step] = terminated
                batch.log_probs[step] = log_prob
                batch.old_action_means[step] = action_mean
                batch.old_action_logstd = action_logstd

                state = next_state
                global_step += self.nr_envs
            nr_episodes += dones_this_rollout
            acting_end_time = time.time()
            time_metrics["time/acting_time"] = acting_end_time - start_time


            """ Optimizing Discriminator """
            self.discriminator_state, trirl_optimization_metrics, self.key = trirl_update(
                self.discriminator_state,
                batch.states, batch.actions, batch.next_states, batch.terminations, self.expert_states, self.expert_actions, self.expert_next_states, self.expert_absorbing,
                self.key, self.reward_type
            )            

            projection_start_time = time.time()
            # Add optimized discriminator to buffer
            self.disc_buffer.append(self.discriminator_state.params)

            """ Reward Projection """
            if self.on_demand_etas:
                # Compute lagrangian multipliers for the current batch of s-a pairs using a buffer of past policies
                on_demand_etas = chunked_compute_etas(
                    self.disc_buffer.policy_buffer,
                    batch.states.reshape((-1,) + self.os_shape),
                    chunk_size=self.chunk_size_dict[self.nr_steps],
                )
                proj_etas = on_demand_etas
            else:
                proj_etas = self.disc_buffer.eta_buffer

            proj_rewards = chunked_project(
                self.disc_buffer.buffer,
                (batch.states.reshape((-1,) + self.os_shape),
                batch.actions.reshape((-1,) + self.as_shape),
                batch.next_states.reshape((-1,) + self.os_shape),
                batch.terminations.reshape(-1,)),
                proj_etas,
                chunk_size=self.chunk_size_dict[self.nr_steps],
            )
            batch.proj_rewards = proj_rewards

            # Fit a network to predict the projected reward
            if self.reward_fn_approximator:
                self.reward_fn_state, rew_fn_optimization_metrics, self.key = reward_fn_update(
                    self.reward_fn_state,
                    batch.states, batch.actions, batch.next_states, batch.proj_rewards,
                    self.key
                )
                proj_rewards = get_reward_prediction(self.reward_fn_state.params, batch.states.reshape((-1,) + self.os_shape), batch.actions.reshape((-1,) + self.as_shape), batch.next_states.reshape((-1,) + self.os_shape))
                proj_rewards = proj_rewards.reshape(batch.rewards.shape)
                batch.proj_rewards = proj_rewards

            batch.merged_rewards = batch.rewards * self.env_reward_frac + batch.proj_rewards * (1 - self.env_reward_frac)
            projection_end_time = time.time()
            time_metrics["time/projection_time"] = projection_end_time - projection_start_time
            trirl_optimization_metrics["projection/eta_buffer_length"] = jnp.array([self.disc_buffer.len_eta_buffer()])

            # Calculating advantages and returns (accounting for absorbing state reward and entropy)
            if self.handle_absorbing_states:
                if self.reward_fn_approximator:
                    proj_rewards_next_state = get_reward_prediction(self.reward_fn_state.params, batch.next_states.reshape((-1,) + self.os_shape), batch.actions.reshape((-1,) + self.as_shape), batch.next_states.reshape((-1,) + self.os_shape))
                    proj_rewards_next_state = proj_rewards_next_state.reshape(batch.rewards.shape)
                else:
                    proj_rewards_next_state = chunked_project(
                        self.disc_buffer.buffer,
                        (batch.next_states.reshape((-1,) + self.os_shape), # next state
                        batch.actions.reshape((-1,) + self.as_shape), # next actions don't matter in absorbing states
                        batch.next_states.reshape((-1,) + self.os_shape), # next next states are the same as next states if in absorbing state
                        np.ones_like(batch.terminations.reshape(-1,))), # absorbing is true
                        proj_etas, # should not make a difference since we use max over the while batch anyway
                        chunk_size=self.chunk_size_dict[self.nr_steps],
                    )

                # Calculating advantages and returns
                batch.advantages, batch.returns = calculate_gae_advantages_absorbing(self.critic_state, batch.next_states, batch.merged_rewards, proj_rewards_next_state, batch.terminations, batch.values)
                calc_adv_return_end_time = time.time()
                time_metrics["time/calc_adv_and_return_time"] = calc_adv_return_end_time - projection_end_time
            else:
                # Calculating advantages and returns
                batch.advantages, batch.returns = calculate_gae_advantages(self.critic_state, batch.next_states, batch.merged_rewards, batch.terminations, batch.values)
                calc_adv_return_end_time = time.time()
                time_metrics["time/calc_adv_and_return_time"] = calc_adv_return_end_time - projection_end_time

            """ Optimizing Policy """            
            self.policy_state, self.critic_state, optimization_metrics, etas, self.key = update(
                self.policy_state, self.critic_state,
                batch.states, batch.actions, batch.advantages, batch.returns, batch.values, batch.log_probs,
                self.key, batch.old_action_means, batch.old_action_logstd, global_step,
            )

            # Add optimized policy params and computed etas to buffer
            self.disc_buffer.append_policy(self.policy_state.params)
            self.disc_buffer.append_eta(etas)
            if self.on_demand_etas:
                trirl_optimization_metrics["projection/mean_on_demand_eta"] = jnp.mean(on_demand_etas)
            else:
                trirl_optimization_metrics["projection/max_eta"] = jnp.max(etas)


            optimization_metrics = optimization_metrics | trirl_optimization_metrics
            if self.reward_fn_approximator:
                optimization_metrics = optimization_metrics | rew_fn_optimization_metrics
            optimization_metrics = {key: value.item() for key, value in optimization_metrics.items()}
            nr_updates += self.nr_epochs * self.nr_minibatches
            nr_updates_disc += self.nr_epochs_disc * self.nr_minibatches

            optimizing_end_time = time.time()
            time_metrics["time/optimizing_time"] = optimizing_end_time - calc_adv_return_end_time

            # Evaluating
            evaluation_metrics = {}
            if global_step % self.evaluation_frequency == 0 and self.evaluation_frequency != -1:
                self.set_eval_mode()
                state, _ = self.env.reset()
                eval_nr_episodes = 0
                evaluation_metrics = {"eval/episode_return": [], "eval/episode_length": []}
                while True:
                    processed_action = get_deterministic_action(self.policy_state, state)
                    state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                    done = terminated | truncated
                    for i, single_done in enumerate(done):
                        if single_done:
                            eval_nr_episodes += 1
                            evaluation_metrics["eval/episode_return"].append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                            evaluation_metrics["eval/episode_length"].append(self.env.get_final_info_value_at_index(info, "episode_length", i))
                            if eval_nr_episodes == self.evaluation_episodes:
                                break
                    if eval_nr_episodes == self.evaluation_episodes:
                        break
                evaluation_metrics = {key: np.mean(value) for key, value in evaluation_metrics.items()}
                state, _ = self.env.reset()
                self.set_train_mode()
            
            evaluating_end_time = time.time()
            time_metrics["time/evaluating_time"] = evaluating_end_time - optimizing_end_time

            # Saving
            # Also only save when there were finished episodes this update
            if self.save_model and dones_this_rollout > 0:
                mean_return = np.mean(saving_return_buffer)
                if mean_return > self.best_mean_return:
                    self.best_mean_return = mean_return
                    self.save(global_step)
            
            saving_end_time = time.time()
            time_metrics["time/saving_time"] = saving_end_time - evaluating_end_time

            time_metrics["time/sps"] = int((self.nr_steps * self.nr_envs) / (saving_end_time - start_time))

            # Logging
            self.start_logging(global_step)

            steps_metrics["steps/nr_env_steps"] = global_step
            steps_metrics["steps/nr_updates"] = nr_updates
            steps_metrics["steps/nr_updates_disc"] = nr_updates_disc
            steps_metrics["steps/nr_episodes"] = nr_episodes

            rollout_info_metrics = {}
            env_info_metrics = {}
            if step_info_collection:
                info_names = list(step_info_collection.keys())
                for info_name in info_names:
                    metric_group = "rollout" if info_name in ["episode_return", "episode_length"] else "env_info"
                    metric_dict = rollout_info_metrics if metric_group == "rollout" else env_info_metrics
                    mean_value = np.mean(step_info_collection[info_name])
                    if mean_value == mean_value:  # Check if mean_value is NaN
                        metric_dict[f"{metric_group}/{info_name}"] = mean_value
            
            combined_metrics = {**rollout_info_metrics, **evaluation_metrics, **env_info_metrics, **steps_metrics, **time_metrics, **optimization_metrics}
            for key, value in combined_metrics.items():
                self.log(f"{key}", value, global_step)

            self.end_logging()


    #######################################################################
    #######################################################################
    """
    FROM HERE ON ONLY TESTING & SAVING FUNCTIONS
    """
    #######################################################################
    #######################################################################

    def train(self):
        """
        Dispatch to the configured run mode.

        If ``self.global_rew_experiment`` is truthy, the learnt reward is
        visualised with ``plot_point_maze_reward``; otherwise IRL training is
        started with ``train_irl``.
        """
        if not self.global_rew_experiment:
            self.train_irl()
            return

        # Other experiments that can be swapped in here instead of the plot:
        # self.learnt_reward_pearson_correlation()
        # self.train_ppo()
        self.plot_point_maze_reward()


    def plot_point_maze_reward(self):
        """
        Render the learnt reward over a 2D grid of positions as a heatmap.

        A 100 x 100 grid of (x, y) positions in [0.0, 0.3] x [0.0, 0.3] is
        turned into states of the form ``[x, y, target[0], target[1]]`` with
        ``target = (0.15, 0.0)``. Actions are all zeros and the next states are
        copies of the states. ``self.reward_fn`` is evaluated on these inputs
        with ``shaping=0.0`` and the negated predictions are drawn with
        ``imshow``, together with a few hard-coded overlay shapes. The figure
        is displayed with ``plt.show()``.
        """
        import matplotlib.pyplot as plt
        import matplotlib.patches as patches

        plt.rcParams.update({
            "font.family": "monospace",
            "axes.titlesize": 24,
            "axes.labelsize": 24,
            "legend.fontsize": 24,
        })

        shaping = 0.0
        grid_size = 100
        low, high = 0.0, 0.3
        nr_points = grid_size * grid_size

        # Regular grid of positions; with indexing="xy" the rows of the
        # meshgrid run along y and the columns along x.
        xs = np.linspace(low, high, grid_size, dtype=np.float32)
        ys = np.linspace(low, high, grid_size, dtype=np.float32)
        grid_x, grid_y = np.meshgrid(xs, ys, indexing="xy")

        # Every grid position is paired with the same target coordinates.
        target = np.array([0.15, 0.0], dtype=np.float32)
        state_columns = [
            grid_x.reshape(-1),
            grid_y.reshape(-1),
            np.full(nr_points, target[0], np.float32),
            np.full(nr_points, target[1], np.float32),
        ]
        states = np.stack(state_columns, axis=-1).astype(np.float32)
        next_states = states.copy()
        actions = np.zeros((nr_points,) + self.as_shape, np.float32)

        def predict_single(params, s, a, ns):
            return self.reward_fn.apply(params, s, a, ns, shaping=shaping)

        # Jit the per-sample prediction and map it over the leading axis of
        # states / actions / next states while sharing the parameters.
        predict_batch = jax.vmap(jax.jit(predict_single), in_axes=(None, 0, 0, 0), out_axes=0)

        predictions = predict_batch(
            self.reward_fn_state.params,
            states.reshape((-1,) + self.os_shape),
            actions.reshape((-1,) + self.as_shape),
            next_states.reshape((-1,) + self.os_shape),
        )
        # The sign of the network output is flipped before plotting.
        proj_rewards = -predictions

        reward_grid = np.asarray(proj_rewards).reshape(grid_size, grid_size)
        plt.figure(figsize=(8, 5))
        im = plt.imshow(reward_grid, extent=[low, high, low, high], origin="lower", aspect="equal", cmap="Blues")

        # Overlays: outer border, a horizontal bar, a green dot and a red star
        # (presumably maze wall, start and goal -- TODO confirm).
        ax = plt.gca()
        border = patches.Rectangle((low, low), high - low, high - low,
                                   fill=False, edgecolor="black", linewidth=8)
        ax.add_patch(border)
        ax.plot([0.0, 0.18], [0.15, 0.15], color="black", linewidth=8)
        ax.scatter(0.15, 0.28, s=150, color="green", zorder=3)
        plt.plot(0.15, 0.02, marker='*', markersize=20, color='red')

        # Colorbar with six evenly spaced ticks spanning the image's colour limits.
        cbar = plt.colorbar(im)
        cbar.ax.tick_params(labelsize=24)
        clim_low, clim_high = im.get_clim()
        cbar.set_ticks(np.linspace(clim_low, clim_high, 6))

        # Hide axis ticks and labels.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")

        plt.title("Learnt Reward")
        plt.tight_layout()
        # To write the figure to disk instead of only showing it:
        # plt.savefig("./point_maze_reward.pdf", dpi=200, bbox_inches='tight')
        plt.show()


    def learnt_reward_pearson_correlation(self):
        """
        Correlate the learnt reward with the rewards stored in the expert data.

        Loads the demonstrations from ``self.data_path`` and evaluates the learnt
        reward on every expert transition: with ``self.reward_fn`` if
        ``self.reward_fn_approximator`` is set, otherwise by projecting the
        discriminators held in ``self.disc_buffer``. Shaping is disabled in
        both cases. The Pearson correlation between the learnt rewards and the
        demonstration rewards is printed, and both are shown in a scatter plot.

        Both reward arrays are flattened to 1-D before they are passed to
        ``scipy.stats.pearsonr``, so a single scalar correlation / p-value is
        produced regardless of whether the rewards arrive as ``(N,)`` or
        ``(N, 1)``.
        """
        from scipy import stats
        import matplotlib.pyplot as plt

        # Get demonstrations
        demonstrations = prepare_expert_data(self.data_path)
        expert_states = demonstrations["states"]
        expert_next_states = demonstrations["next_states"]
        expert_actions = demonstrations["actions"]
        # Keep the rewards 1-D: pearsonr is meant to compare two flat samples.
        expert_rewards = np.asarray(demonstrations["rewards"]).reshape(-1)
        expert_absorbing = demonstrations["absorbing"].flatten()
        if self.handle_absorbing_states:
            # Transitions flagged as absorbing get a zeroed state, which is
            # then also used as their next state.
            abs_indices = np.where(expert_absorbing > 0.0)[0]
            expert_states[abs_indices] = expert_states[abs_indices] * 0.0
            expert_next_states[abs_indices] = expert_states[abs_indices]

        shaping = 0.0 # if shaped rewards are used then this removes the shaping part
        if self.reward_fn_approximator:
            @jax.jit
            def get_reward_prediction(reward_fn_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray):
                logits = self.reward_fn.apply(reward_fn_params, state, action, next_state, shaping=shaping)
                return logits

            # Map the per-sample prediction over the leading (sample) axis.
            get_reward_prediction = jax.vmap(get_reward_prediction, in_axes=(None, 0, 0, 0), out_axes=0)

            proj_rewards = get_reward_prediction(self.reward_fn_state.params, expert_states.reshape((-1,) + self.os_shape), expert_actions.reshape((-1,) + self.as_shape), expert_next_states.reshape((-1,) + self.os_shape))
        else:
            @jax.jit
            def get_log_density_ratio(discriminator_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray, absorbing: np.ndarray):
                logits = self.discriminator.apply(discriminator_params, state, action, next_state, absorbing, shaping=shaping)
                return logits

            get_log_density_ratio = jax.vmap(get_log_density_ratio, in_axes=(None, 0, 0, 0, 0), out_axes=0)
            # The expert data is treated as one rollout of a single environment.
            chunked_project = make_chunked_ensemble_rew_project(
                get_log_density_ratio,
                nr_steps=expert_states.shape[0],
                nr_envs=1,
                epsilon=self.epsilon,
                beta=self.beta,
                entropy_coef=self.entropy_coef,
                maximum_eta=True
            )

            proj_rewards = chunked_project(
                self.disc_buffer.buffer,
                (expert_states.reshape((-1,) + self.os_shape),
                expert_actions.reshape((-1,) + self.as_shape),
                expert_next_states.reshape((-1,) + self.os_shape),
                expert_absorbing.reshape(-1,)),
                self.disc_buffer.eta_buffer,
                chunk_size=10,
            )

        # Move the result to NumPy and flatten it so it lines up element-wise
        # with the flat expert rewards.
        proj_rewards = np.asarray(proj_rewards).reshape(-1)

        pearson_correlation = stats.pearsonr(proj_rewards, expert_rewards)
        print(f"pearson Correlation: {pearson_correlation[0]} p value: {pearson_correlation[1]}")

        plt.figure(figsize=(8, 6))
        plt.scatter(proj_rewards, expert_rewards, alpha=0.5, s=10)
        plt.xlabel("Projected Rewards")
        plt.ylabel("Expert Rewards")
        plt.title("Scatter Plot of Projected vs Expert Rewards")
        plt.grid(True)
        plt.show()


    def train_ppo(self):

        rlx_logger.info(f"Training PPO from scratch with a saved reward function")

        @jax.jit
        def get_action_and_value(policy_state: TrainState, critic_state: TrainState, state: np.ndarray, key: jax.random.PRNGKey):
            action_mean, action_logstd = self.policy.apply(policy_state.params, state)
            action_std = jnp.exp(action_logstd)
            key, subkey = jax.random.split(key)
            action = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
            log_prob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
            value = self.critic.apply(critic_state.params, state)
            processed_action = self.get_processed_action(action)
            return processed_action, action, value.reshape(-1), log_prob.sum(1), key

        @jax.jit
        def calculate_gae_advantages(critic_state: TrainState, next_states: np.ndarray, rewards: np.ndarray, terminations: np.ndarray, values: np.ndarray):
            def compute_advantages(carry, t):
                prev_advantage = carry[0]
                advantage = delta[t] + self.gamma * self.gae_lambda * (1 - terminations[t]) * prev_advantage
                return (advantage,), advantage

            next_values = self.critic.apply(critic_state.params, next_states).squeeze(-1)
            delta = rewards + self.gamma * next_values * (1.0 - terminations) - values
            init_advantages = delta[-1]
            _, advantages = jax.lax.scan(compute_advantages, (init_advantages,), jnp.arange(self.nr_steps - 2, -1, -1))
            advantages = jnp.concatenate([advantages[::-1], jnp.array([init_advantages])])
            returns = advantages + values
            return advantages, returns


        @jax.jit
        def calculate_gae_advantages_absorbing(critic_state: TrainState, next_states: np.ndarray, rewards: np.ndarray, rewards_next_state: np.ndarray, terminations: np.ndarray, values: np.ndarray):
            """
            Correctly handle absorbing state value and entropy (instead of setting to 0.0)
            """
            def compute_advantages(carry, t):
                prev_advantage = carry[0]
                advantage = delta[t] + self.gamma * self.gae_lambda * (1 - terminations[t]) * prev_advantage
                return (advantage,), advantage

            next_values = self.critic.apply(critic_state.params, next_states).squeeze(-1)
            terminal_tail = (self.gamma / (1.0 - self.gamma)) * (rewards_next_state + self.entropy_coef * self.H_terminal)
            delta = rewards + self.gamma * next_values * (1.0 - terminations) + (terminations * terminal_tail) - values
            init_advantages = delta[-1]
            _, advantages = jax.lax.scan(compute_advantages, (init_advantages,), jnp.arange(self.nr_steps - 2, -1, -1))
            advantages = jnp.concatenate([advantages[::-1], jnp.array([init_advantages])])
            returns = advantages + values
            return advantages, returns

        
        @jax.jit
        def get_log_density_ratio(discriminator_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray, absorbing: np.ndarray):
            logits = self.discriminator.apply(discriminator_params, state, action, next_state, absorbing)
            return logits
        
        get_log_density_ratio = jax.vmap(get_log_density_ratio, in_axes=(None, 0, 0, 0, 0), out_axes=0)

        chunked_project = make_chunked_ensemble_rew_project(
            get_log_density_ratio, 
            nr_steps=self.nr_steps, 
            nr_envs=self.nr_envs, 
            epsilon=self.epsilon, 
            beta=self.beta,
            entropy_coef=self.entropy_coef,
            maximum_eta=self.maximum_eta
        )

        @jax.jit
        def update(policy_state: TrainState, critic_state: TrainState,
                   states: np.ndarray, actions: np.ndarray, advantages: np.ndarray, returns: np.ndarray, values: np.ndarray, log_probs: np.ndarray,
                   key: jax.random.PRNGKey):
            def loss_fn(policy_params, critic_params, state_b, action_b, log_prob_b, return_b, advantage_b):
                # Policy loss
                action_mean, action_logstd = self.policy.apply(policy_params, state_b)
                action_std = jnp.exp(action_logstd)
                new_log_prob = -0.5 * ((action_b - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
                new_log_prob = new_log_prob.sum(1)
                entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e)
                entropy = self.entropy_coef * entropy # trained reward function predicts scaled rewards. scaling entropy in ppo to ensure that the beta ratio is the same as the reward learning problem
                
                logratio = new_log_prob - log_prob_b
                ratio = jnp.exp(logratio)
                approx_kl_div = (ratio - 1) - logratio
                clip_fraction = jnp.float32((jnp.abs(ratio - 1) > self.clip_range))

                pg_loss1 = -advantage_b * ratio
                pg_loss2 = -advantage_b * jnp.clip(ratio, 1 - self.clip_range, 1 + self.clip_range)
                pg_loss = jnp.maximum(pg_loss1, pg_loss2)
                
                entropy_loss = entropy.sum(1)
                
                # Critic loss
                new_value = self.critic.apply(critic_params, state_b)
                critic_loss = 0.5 * (new_value - return_b) ** 2

                # Combine losses
                loss = pg_loss - self.entropy_coef * entropy_loss + self.critic_coef * critic_loss

                # Create metrics
                metrics = {
                    "loss/policy_gradient_loss": pg_loss,
                    "loss/critic_loss": critic_loss,
                    "loss/entropy_loss": entropy_loss,
                    "policy_ratio/approx_kl": approx_kl_div,
                    "policy_ratio/clip_fraction": clip_fraction,
                }


                return loss, (metrics)
            

            batch_states = states.reshape((-1,) + self.os_shape)
            batch_actions = actions.reshape((-1,) + self.as_shape)
            batch_advantages = advantages.reshape(-1)
            batch_returns = returns.reshape(-1)
            batch_log_probs = log_probs.reshape(-1)

            vmap_loss_fn = jax.vmap(loss_fn, in_axes=(None, None, 0, 0, 0, 0, 0), out_axes=0)
            safe_mean = lambda x: jnp.mean(x) if x is not None else x
            mean_vmapped_loss_fn = lambda *a, **k: tree.map_structure(safe_mean, vmap_loss_fn(*a, **k))
            grad_loss_fn = jax.value_and_grad(mean_vmapped_loss_fn, argnums=(0, 1), has_aux=True)

            key, subkey = jax.random.split(key)
            batch_indices = jnp.tile(jnp.arange(self.batch_size), (self.nr_epochs, 1))
            batch_indices = jax.random.permutation(subkey, batch_indices, axis=1, independent=True)
            batch_indices = batch_indices.reshape((self.nr_epochs * self.nr_minibatches, self.minibatch_size))

            def minibatch_update(carry, minibatch_indices):
                policy_state, critic_state = carry

                minibatch_advantages = batch_advantages[minibatch_indices]
                minibatch_advantages = (minibatch_advantages - jnp.mean(minibatch_advantages)) / (jnp.std(minibatch_advantages) + 1e-8)

                (loss, (metrics)), (policy_gradients, critic_gradients) = grad_loss_fn(
                    policy_state.params,
                    critic_state.params,
                    batch_states[minibatch_indices],
                    batch_actions[minibatch_indices],
                    batch_log_probs[minibatch_indices],
                    batch_returns[minibatch_indices],
                    minibatch_advantages
                )

                policy_state = policy_state.apply_gradients(grads=policy_gradients)
                critic_state = critic_state.apply_gradients(grads=critic_gradients)

                metrics["gradients/policy_grad_norm"] = optax.global_norm(policy_gradients)
                metrics["gradients/critic_grad_norm"] = optax.global_norm(critic_gradients)

                carry = (policy_state, critic_state)

                return carry, (metrics)
            
            init_carry = (policy_state, critic_state)
            carry, (metrics) = jax.lax.scan(minibatch_update, init_carry, batch_indices)
            policy_state, critic_state = carry

            # Calculate mean metrics
            mean_metrics = {key: jnp.mean(metrics[key]) for key in metrics}
            mean_metrics["lr/learning_rate"] = policy_state.opt_state[1].hyperparams["learning_rate"]
            mean_metrics["v_value/explained_variance"] = 1 - jnp.var(returns - values) / (jnp.var(returns) + 1e-8)
            mean_metrics["policy/std_dev"] = jnp.mean(jnp.exp(policy_state.params["params"]["policy_logstd"]))

            return policy_state, critic_state, mean_metrics, key


        @jax.jit
        def get_deterministic_action(policy_state: TrainState, state: np.ndarray):
            action_mean, action_logstd = self.policy.apply(policy_state.params, state)
            return self.get_processed_action(action_mean)

        @jax.jit
        def get_reward_prediction(reward_fn_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray):
            logits = self.reward_fn.apply(reward_fn_params, state, action, next_state, shaping=0.0)
            return logits
        
        get_reward_prediction = jax.vmap(get_reward_prediction, in_axes=(None, 0, 0, 0), out_axes=0)


        self.set_train_mode()

        batch = BatchPPO(
            states=np.zeros((self.nr_steps, self.nr_envs) + self.os_shape),
            next_states=np.zeros((self.nr_steps, self.nr_envs) + self.os_shape),
            actions=np.zeros((self.nr_steps, self.nr_envs) + self.as_shape),
            rewards=np.zeros((self.nr_steps, self.nr_envs)),
            values=np.zeros((self.nr_steps, self.nr_envs)),
            terminations=np.zeros((self.nr_steps, self.nr_envs)),
            log_probs=np.zeros((self.nr_steps, self.nr_envs)),
            advantages=np.zeros((self.nr_steps, self.nr_envs)),
            returns=np.zeros((self.nr_steps, self.nr_envs)),
        )

        saving_return_buffer = deque(maxlen=100 * self.nr_envs)

        state, _ = self.env.reset()
        global_step = 0
        nr_updates = 0
        nr_episodes = 0
        steps_metrics = {}
        while global_step < self.total_timesteps:
            start_time = time.time()
            time_metrics = {}


            # Acting
            dones_this_rollout = 0
            step_info_collection = {}
            for step in range(self.nr_steps):
                processed_action, action, value, log_prob, self.key = get_action_and_value(self.policy_state, self.critic_state, state, self.key)
                next_state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                done = terminated | truncated
                actual_next_state = next_state.copy()
                for i, single_done in enumerate(done):
                    if single_done:
                        if self.handle_absorbing_states:
                            actual_next_state[i] = np.array(self.env.get_final_observation_at_index(info, i)) * 0.0 # setting one absorbing state (zero vector)
                        else:
                            actual_next_state[i] = np.array(self.env.get_final_observation_at_index(info, i))
                        saving_return_buffer.append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                        dones_this_rollout += 1
                for key, info_value in self.env.get_logging_info_dict(info).items():
                    step_info_collection.setdefault(key, []).extend(info_value)

                batch.states[step] = state
                batch.next_states[step] = actual_next_state
                batch.actions[step] = action
                batch.rewards[step] = reward
                batch.values[step] = value
                batch.terminations[step] = terminated
                batch.log_probs[step] = log_prob
                state = next_state
                global_step += self.nr_envs
            nr_episodes += dones_this_rollout
            
            acting_end_time = time.time()
            time_metrics["time/acting_time"] = acting_end_time - start_time


            if self.reward_fn_approximator:
                proj_rewards = get_reward_prediction(self.reward_fn_state.params, batch.states.reshape((-1,) + self.os_shape), batch.actions.reshape((-1,) + self.as_shape), batch.next_states.reshape((-1,) + self.os_shape))                    
                proj_rewards = proj_rewards.reshape(batch.rewards.shape)
            else:
                assert self.on_demand_etas == False, "When computing etas on demand, the reward function must always be approximated using a neural network"
                # Compute projected reward using the saved discriminators
                proj_rewards = chunked_project(
                    self.disc_buffer.buffer,
                    (batch.states.reshape((-1,) + self.os_shape),
                    batch.actions.reshape((-1,) + self.as_shape),
                    batch.next_states.reshape((-1,) + self.os_shape),
                    batch.terminations.reshape(-1,)),
                    self.disc_buffer.eta_buffer,
                    chunk_size=self.chunk_size_dict[self.nr_steps],
                )

            if self.handle_absorbing_states:
                if self.reward_fn_approximator:
                    proj_rewards_next_state = get_reward_prediction(self.reward_fn_state.params, batch.next_states.reshape((-1,) + self.os_shape), 0.0*batch.actions.reshape((-1,) + self.as_shape), batch.next_states.reshape((-1,) + self.os_shape)) 
                    proj_rewards_next_state = proj_rewards_next_state.reshape(batch.rewards.shape)
                else:
                    assert self.on_demand_etas == False, "When computing etas on demand, the reward function must always be approximated using a neural network"
                    # Compute projected reward using the saved discriminators
                    proj_rewards_next_state = chunked_project(
                        self.disc_buffer.buffer,
                        (batch.next_states.reshape((-1,) + self.os_shape), # next state
                        0.0*batch.actions.reshape((-1,) + self.as_shape), # next actions don't matter in absorbing states
                        batch.next_states.reshape((-1,) + self.os_shape), # next next states are the same as next states if in absorbing state
                        np.ones_like(batch.terminations.reshape(-1,))), # absorbing is true
                        self.disc_buffer.eta_buffer, # should not make a difference since we use max over the while batch anyway
                        chunk_size=self.chunk_size_dict[self.nr_steps],
                    )

                # Calculating advantages and returns
                batch.advantages, batch.returns = calculate_gae_advantages_absorbing(self.critic_state, batch.next_states, proj_rewards, proj_rewards_next_state, batch.terminations, batch.values)        
                calc_adv_return_end_time = time.time()
                time_metrics["time/calc_adv_and_return_time"] = calc_adv_return_end_time - acting_end_time
            else:
                # Calculating advantages and returns
                batch.advantages, batch.returns = calculate_gae_advantages(self.critic_state, batch.next_states, proj_rewards, batch.terminations, batch.values)            
                calc_adv_return_end_time = time.time()
                time_metrics["time/calc_adv_and_return_time"] = calc_adv_return_end_time - acting_end_time


            # Optimizing
            self.policy_state, self.critic_state, optimization_metrics, self.key = update(
                self.policy_state, self.critic_state,
                batch.states, batch.actions, batch.advantages, batch.returns, batch.values, batch.log_probs,
                self.key
            )
            optimization_metrics = {key: value.item() for key, value in optimization_metrics.items()}
            nr_updates += self.nr_epochs * self.nr_minibatches

            optimizing_end_time = time.time()
            time_metrics["time/optimizing_time"] = optimizing_end_time - calc_adv_return_end_time


            # Evaluating
            evaluation_metrics = {}
            if global_step % self.evaluation_frequency == 0 and self.evaluation_frequency != -1:
                self.set_eval_mode()
                state, _ = self.env.reset()
                eval_nr_episodes = 0
                evaluation_metrics = {"eval/episode_return": [], "eval/episode_length": []}
                while True:
                    processed_action = get_deterministic_action(self.policy_state, state)
                    state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                    done = terminated | truncated
                    for i, single_done in enumerate(done):
                        if single_done:
                            eval_nr_episodes += 1
                            evaluation_metrics["eval/episode_return"].append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                            evaluation_metrics["eval/episode_length"].append(self.env.get_final_info_value_at_index(info, "episode_length", i))
                            if eval_nr_episodes == self.evaluation_episodes:
                                break
                    if eval_nr_episodes == self.evaluation_episodes:
                        break
                evaluation_metrics = {key: np.mean(value) for key, value in evaluation_metrics.items()}
                state, _ = self.env.reset()
                self.set_train_mode()
            
            evaluating_end_time = time.time()
            time_metrics["time/evaluating_time"] = evaluating_end_time - optimizing_end_time
            

            # Saving
            # Also only save when there were finished episodes this update
            if self.save_model and dones_this_rollout > 0:
                mean_return = np.mean(saving_return_buffer)
                if mean_return > self.best_mean_return:
                    self.best_mean_return = mean_return
                    self.save(global_step)
            
            saving_end_time = time.time()
            time_metrics["time/saving_time"] = saving_end_time - evaluating_end_time
            time_metrics["time/sps"] = int((self.nr_steps * self.nr_envs) / (saving_end_time - start_time))


            # Logging
            self.start_logging(global_step)

            steps_metrics["steps/nr_env_steps"] = global_step
            steps_metrics["steps/nr_updates"] = nr_updates
            steps_metrics["steps/nr_episodes"] = nr_episodes

            rollout_info_metrics = {}
            env_info_metrics = {}
            if step_info_collection:
                info_names = list(step_info_collection.keys())
                for info_name in info_names:
                    metric_group = "rollout" if info_name in ["episode_return", "episode_length"] else "env_info"
                    metric_dict = rollout_info_metrics if metric_group == "rollout" else env_info_metrics
                    mean_value = np.mean(step_info_collection[info_name])
                    if mean_value == mean_value:  # Check if mean_value is NaN
                        metric_dict[f"{metric_group}/{info_name}"] = mean_value
            
            combined_metrics = {**rollout_info_metrics, **evaluation_metrics, **env_info_metrics, **steps_metrics, **time_metrics, **optimization_metrics}
            for key, value in combined_metrics.items():
                self.log(f"{key}", value, global_step)

            self.end_logging()


    def log(self, name, value, step):
        """Forward one scalar metric to every enabled sink.

        Args:
            name: Metric tag, e.g. ``"time/sps"``.
            value: Scalar value to record.
            step: Global step handed to the TensorBoard writer.
        """
        # (enabled flag, emitter) pairs, visited in TensorBoard -> console order.
        sinks = (
            (self.track_tb, lambda: self.writer.add_scalar(name, value, step)),
            (self.track_console, lambda: self.log_console(name, value)),
        )
        for enabled, emit in sinks:
            if enabled:
                emit()
    

    def log_console(self, name, value):
        """Write one metric as a row of the boxed console table.

        Args:
            name: Metric tag; left-justified to 30 characters.
            value: Scalar; rendered in positional (non-scientific) notation
                and padded/truncated to exactly 14 characters.
        """
        rendered = str(np.format_float_positional(value, trim="-"))
        name_cell = name.ljust(30)
        value_cell = rendered.ljust(14)[:14]
        rlx_logger.info(f"│ {name_cell}│ {value_cell} │", flush=False)


    def start_logging(self, step):
        """Open a logging section for one update.

        With console tracking enabled this emits the top border of the metric
        table; otherwise it emits a single ``Step: <step>`` line.

        Args:
            step: Global step, only used in the non-console message.
        """
        if not self.track_console:
            rlx_logger.info(f"Step: {step}")
            return

        top_border = "".join(("┌", "─" * 31, "┬", "─" * 16, "┐"))
        rlx_logger.info(top_border, flush=False)


    def end_logging(self):
        """Close the console metric table by emitting its bottom border.

        Does nothing when console tracking is disabled.
        """
        if not self.track_console:
            return

        bottom_border = "".join(("└", "─" * 31, "┴", "─" * 16, "┘"))
        rlx_logger.info(bottom_border)


    def save(self, global_step):
        """Write the current model as ``<save_path>/<best_model_file_name>.zip``.

        The checkpoint (policy, critic, discriminator and, when used, the
        reward-function approximator) plus the algorithm config are written to
        ``<save_path>/tmp``, zipped, and the temporary directory is removed
        again, also when one of the steps fails.

        When no reward-function approximator is used, the discriminator buffer
        and eta buffer are additionally stored, but only if at least
        ``self.save_delay * 4096`` steps passed since they were last stored.

        Args:
            global_step: Current number of environment steps; only used to
                rate-limit the buffer dump.
        """
        tmp_dir = f"{self.save_path}/tmp"
        archive_base = f"{self.save_path}/{self.best_model_file_name}"

        checkpoint = {
            "policy": self.policy_state,
            "critic": self.critic_state,
            "discriminator": self.discriminator_state,
        }

        if self.reward_fn_approximator:
            checkpoint["reward_fn"] = self.reward_fn_state

        try:
            save_args = orbax_utils.save_args_from_target(checkpoint)
            self.best_model_checkpointer.save(tmp_dir, checkpoint, save_args=save_args)
            with open(f"{tmp_dir}/config_algorithm.json", "w") as f:
                json.dump(self.config.algorithm.to_dict(), f)

            if not self.reward_fn_approximator:
                if not hasattr(self, "_last_buffer_save_step"):
                    self._last_buffer_save_step = -1
                save_interval = self.save_delay * 4096  # adjust frequency
                # NOTE(review): when the interval has not elapsed, the archive
                # written below replaces the previous one and then contains no
                # buffer files; load() falls back to empty buffers in that
                # case. Confirm this is intended.
                if global_step - self._last_buffer_save_step >= save_interval:
                    self._last_buffer_save_step = global_step
                    # Save pkl files for disc buffer
                    disc_state_list = [flax_serialization.to_state_dict(p) for p in self.disc_buffer.buffer]
                    with open(os.path.join(tmp_dir, "disc_buffer.pkl"), "wb") as f:
                        pickle.dump(disc_state_list, f)
                    np.save(os.path.join(tmp_dir, "eta_buffer.npy"), self.disc_buffer.eta_buffer)

            # make_archive appends ".zip" to the base name.
            shutil.make_archive(archive_base, "zip", tmp_dir)
        finally:
            # Never leave a stale tmp directory behind, even on failure.
            shutil.rmtree(tmp_dir, ignore_errors=True)

        if self.track_wandb:
            # The archive keeps its ".zip" suffix (it is not renamed), so the
            # upload has to reference the suffixed file name.
            wandb.save(f"{archive_base}.zip", base_path=self.save_path)


    def load(config, env, run_path, writer, explicitly_set_algorithm_params, eval_env=None):
        """Rebuild a ``TRIRL_TRPL`` model from a zipped checkpoint archive.

        Defined without ``self``; call it on the class. The archive is
        unpacked to ``<archive dir>/tmp``, which is removed again before
        returning, also when loading fails.

        Args:
            config: Run configuration. ``config.runner.load_model`` is the
                archive path; ``config.algorithm`` is updated in place with
                the stored algorithm settings.
            env: Environment handed to the constructor.
            run_path: Run directory handed to the constructor.
            writer: Writer handed to the constructor.
            explicitly_set_algorithm_params: Collection of
                ``"algorithm.<key>"`` names whose current value must win over
                the stored one.
            eval_env: Optional evaluation environment handed to the
                constructor; ``env`` is used when omitted.

        Returns:
            The constructed model with restored train states and buffers.
        """
        load_path = config.runner.load_model
        archive_dir = os.path.abspath(os.path.dirname(load_path))
        archive_name = os.path.basename(load_path)
        checkpoint_dir = f"{archive_dir}/tmp"

        try:
            shutil.unpack_archive(f"{archive_dir}/{archive_name}", checkpoint_dir, "zip")

            with open(f"{checkpoint_dir}/config_algorithm.json", "r") as f:
                loaded_algorithm_config = json.load(f)
            for key, value in loaded_algorithm_config.items():
                if f"algorithm.{key}" not in explicitly_set_algorithm_params:
                    config.algorithm[key] = value

            # The constructor signature is (config, env, eval_env, run_path, writer).
            # NOTE(review): falling back to `env` as evaluation environment
            # when none is given - confirm this matches the caller's intent.
            model = TRIRL_TRPL(config, env, eval_env if eval_env is not None else env, run_path, writer)

            target = {
                "policy": model.policy_state,
                "critic": model.critic_state,
                "discriminator": model.discriminator_state,
            }

            if model.reward_fn_approximator:
                target["reward_fn"] = model.reward_fn_state

            restore_args = orbax_utils.restore_args_from_target(target)
            checkpointer = orbax.checkpoint.PyTreeCheckpointer()
            checkpoint = checkpointer.restore(checkpoint_dir, item=target, restore_args=restore_args)

            # In the global-reward experiment the freshly initialised policy
            # and critic are kept; only the reward side is restored.
            if not model.global_rew_experiment:
                model.policy_state = checkpoint["policy"]
                model.critic_state = checkpoint["critic"]

            model.discriminator_state = checkpoint["discriminator"]
            if model.reward_fn_approximator:
                model.reward_fn_state = checkpoint["reward_fn"]

            capacity = config.algorithm.disc_buffer_capacity

            # Discriminator parameter history (absent in some archives).
            restored_disc = []
            disc_pkl_path = os.path.join(checkpoint_dir, "disc_buffer.pkl")
            if os.path.exists(disc_pkl_path):
                # SECURITY: pickle.load can execute arbitrary code; only load
                # checkpoint archives from a trusted source.
                with open(disc_pkl_path, "rb") as f:
                    disc_state_list = pickle.load(f)
                for sd in disc_state_list:
                    params = flax_serialization.from_state_dict(model.discriminator_state.params, sd)
                    restored_disc.append(jax.tree.map(lambda x: jnp.array(x), params))
            model.disc_buffer._buffer = deque(restored_disc, maxlen=capacity)

            # Eta history (absent in some archives); one entry per leading index.
            eta_entries = []
            eta_npy_path = os.path.join(checkpoint_dir, "eta_buffer.npy")
            if os.path.exists(eta_npy_path):
                eta_array = np.load(eta_npy_path, allow_pickle=False)
                if eta_array.size != 0:
                    eta_entries = [eta_array[i] for i in range(eta_array.shape[0])]
            model.disc_buffer._eta_buffer = deque(eta_entries, maxlen=capacity)
        finally:
            # Remove the unpacked files regardless of success.
            shutil.rmtree(checkpoint_dir, ignore_errors=True)

        return model


    def test(self, episodes):
        """Roll out the deterministic (mean-action) policy in ``self.env``.

        Two nested routines are defined:

        * ``evaluate`` records ``episodes - 1`` fixed-length rollouts and
          stores them as an ``.npz`` dataset under ``self.save_path``.
        * ``visualise`` runs a single rollout and logs its return.

        Only ``visualise`` is executed; the ``evaluate`` call at the bottom is
        commented out.

        Args:
            episodes: Number of rollouts for ``evaluate`` (the first one is a
                discarded warm-up). Not used by ``visualise``.
        """
        @jax.jit
        def get_action(policy_state: TrainState, state: np.ndarray):
            # Deterministic action: use the mean, ignore the log-std.
            action_mean, _ = self.policy.apply(policy_state.params, state)
            return self.get_processed_action(action_mean)

        def evaluate():
            nr_steps = 1000
            nr_recorded = episodes - 1  # episode 0 is only a warm-up
            mean_episode_return = []

            batch = Batch(
                states=np.zeros((nr_recorded, nr_steps, self.nr_envs) + self.os_shape),
                next_states=np.zeros((nr_recorded, nr_steps, self.nr_envs) + self.os_shape),
                actions=np.zeros((nr_recorded, nr_steps, self.nr_envs) + self.as_shape),
                rewards=np.zeros((nr_recorded, nr_steps, self.nr_envs)),
                terminations=np.zeros((nr_recorded, nr_steps, self.nr_envs)),

                values=np.zeros((nr_recorded, nr_steps, self.nr_envs)),
                log_probs=np.zeros((nr_recorded, nr_steps, self.nr_envs)),
                advantages=np.zeros((nr_recorded, nr_steps, self.nr_envs)),
                returns=np.zeros((nr_recorded, nr_steps, self.nr_envs)),
            )
            for i in range(episodes):
                episode_return = 0
                state, _ = self.env.reset()

                # UGLY FIX: for some reason the first episode's step returns nans!! Skip the first one!
                if i == 0:
                    dummy_return = 0
                    for step in range(10):
                        processed_action = get_action(self.policy_state, state)
                        next_state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                        state = next_state
                        dummy_return += reward
                    rlx_logger.info(f"Episode {i + 1} - Return: {dummy_return.mean()}")

                else:
                    state, _ = self.env.reset()

                    for step in range(nr_steps):
                        processed_action = get_action(self.policy_state, state)
                        next_state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                        done = terminated | truncated
                        actual_next_state = next_state.copy()

                        # For finished envs, store the real final observation
                        # instead of what step() returned.
                        for idx, single_done in enumerate(done):
                            if single_done:
                                actual_next_state[idx] = np.array(self.env.get_final_observation_at_index(info, idx))

                        batch.states[i-1, step] = state
                        batch.next_states[i-1, step] = actual_next_state
                        batch.actions[i-1, step] = processed_action
                        batch.rewards[i-1, step] = reward
                        batch.terminations[i-1, step] = terminated

                        state = next_state
                        episode_return += reward

                    mean_episode_return.append(episode_return.mean())
                    rlx_logger.info(f"Episode {i + 1} - Return: {episode_return.mean()}")

            mean_episode_return = np.array(mean_episode_return).mean()
            rlx_logger.info(f"Mean Episode Return: {mean_episode_return}")

            # One row per recorded transition (episode, step, env). Every
            # array gets the same number of rows so that a single NaN mask
            # can be applied to all of them and rows stay aligned.
            nr_transitions = nr_recorded * nr_steps * self.nr_envs

            def flatten(arr):
                # Explicit width (not -1) so that an empty batch still reshapes.
                width = int(np.prod(arr.shape[3:], dtype=int))
                return arr.reshape((nr_transitions, width))

            flat_states = flatten(batch.states)
            flat_actions = flatten(batch.actions)
            flat_next_states = flatten(batch.next_states)
            flat_absorbing = flatten(batch.terminations)
            flat_rewards = flatten(batch.rewards)

            # Drop a transition entirely if any of its fields contains NaN.
            keep = np.ones(nr_transitions, dtype=bool)
            for flat in (flat_states, flat_actions, flat_next_states, flat_absorbing, flat_rewards):
                keep &= ~np.isnan(flat).any(axis=1)

            exp_states = flat_states[keep]
            exp_actions = flat_actions[keep]
            exp_next_states = flat_next_states[keep]
            exp_absorbing = flat_absorbing[keep]
            exp_rewards = flat_rewards[keep]

            print(f"save path: {self.save_path}")
            print(f"states shape: {exp_states.shape}")
            print(f"rewards mean: {exp_rewards.mean()}")
            np.savez(f"{self.save_path}/expert_dataset_Humanoid-v5_{episodes-1}_PPO", states=exp_states, actions=exp_actions,
                next_states=exp_next_states, absorbing=exp_absorbing, rewards=exp_rewards)

        def visualise():
            nr_steps = 1000
            episode_return = 0
            state, _ = self.env.reset()

            for step in range(nr_steps):
                processed_action = get_action(self.policy_state, state)
                next_state, reward, terminated, truncated, _ = self.env.step(jax.device_get(processed_action))

                # The flags are per-env arrays; stop as soon as any env ends
                # (a bare `if done:` is ambiguous for more than one env).
                if np.any(terminated | truncated):
                    break

                state = next_state
                episode_return += reward

            # np.mean also handles the plain int 0 left when the first step ends the episode.
            rlx_logger.info(f"Episode Return: {np.mean(episode_return)}")
            rlx_logger.info(f"Step: {step}")

        # evaluate()
        visualise()

            
    def set_train_mode(self):
        """No-op hook; the body is intentionally empty."""
        return None


    def set_eval_mode(self):
        """No-op hook; the body is intentionally empty."""
        return None


    def general_properties():
        """Return the ``GeneralProperties`` class imported for this algorithm.

        Takes no ``self``; call it on the class.
        """
        properties_cls = GeneralProperties
        return properties_cls
