import os
import shutil
import json
import logging
import time
from collections import deque
import tree
import numpy as np
import jax
import jax.numpy as jnp
from jax.lax import stop_gradient
from flax.training.train_state import TrainState
from flax.training import orbax_utils
from flax import serialization as flax_serialization
import orbax.checkpoint
import optax
import wandb
import pickle
from functools import partial

from irl_baselines.algorithms.trirl_ppo.flax.general_properties import GeneralProperties
from irl_baselines.algorithms.trirl_ppo.flax.policy import get_policy
from irl_baselines.algorithms.trirl_ppo.flax.critic import get_critic
from irl_baselines.algorithms.trirl_ppo.flax.discriminator import get_discriminator, get_reward_approximator, DiscriminatorBuffer, make_chunked_ensemble_rew_project
from irl_baselines.algorithms.trirl_ppo.flax.batch import Batch
from irl_baselines.algorithms.ppo.flax.batch import Batch as BatchPPO
from irl_baselines.algorithms.data_utils import prepare_expert_data

rlx_logger = logging.getLogger("rl_x")

class TRIRL_PPO:
    def __init__(self, config, env, eval_env, run_path, writer) -> None:
        """Build the TRIRL-PPO agent: hyperparameters, networks and optimizer states.

        Args:
            config: Nested config; values are read from ``config.runner``,
                ``config.environment`` and ``config.algorithm``.
            env: Vectorized training environment. Its ``single_observation_space`` and
                ``single_action_space`` provide the dummy inputs used to initialize
                the networks.
            eval_env: Accepted for interface compatibility; not used here.
            run_path: Run directory; models are saved under ``<run_path>/models``.
            writer: Logging writer, stored as ``self.writer``.

        Raises:
            ValueError: If ``evaluation_frequency`` is neither -1 nor a multiple of
                ``nr_steps * nr_envs``.
            ZeroDivisionError: If ``config.algorithm.entropy_coef`` is 0 (``beta`` is its
                reciprocal).
            FileExistsError: If ``save_model`` is set and the models directory already
                exists (``os.makedirs`` is called without ``exist_ok``).
        """
        self.config = config
        self.env = env
        self.writer = writer

        self.save_model = config.runner.save_model
        self.save_path = os.path.join(run_path, "models")
        self.track_console = config.runner.track_console
        self.track_tb = config.runner.track_tb
        self.track_wandb = config.runner.track_wandb
        self.seed = config.environment.seed
        self.total_timesteps = config.algorithm.total_timesteps
        self.nr_envs = config.environment.nr_envs
        self.learning_rate = config.algorithm.learning_rate
        self.anneal_learning_rate = config.algorithm.anneal_learning_rate
        self.nr_steps = config.algorithm.nr_steps
        self.nr_epochs = config.algorithm.nr_epochs
        self.minibatch_size = config.algorithm.minibatch_size
        self.gamma = config.algorithm.gamma
        self.gae_lambda = config.algorithm.gae_lambda
        self.clip_range = config.algorithm.clip_range
        self.entropy_coef = config.algorithm.entropy_coef
        self.critic_coef = config.algorithm.critic_coef
        self.max_grad_norm = config.algorithm.max_grad_norm
        self.std_dev = config.algorithm.std_dev
        self.nr_hidden_units = config.algorithm.nr_hidden_units
        self.evaluation_frequency = config.algorithm.evaluation_frequency
        self.evaluation_episodes = config.algorithm.evaluation_episodes
        # One rollout holds nr_steps transitions from each of nr_envs environments.
        self.batch_size = config.environment.nr_envs * config.algorithm.nr_steps
        self.nr_updates = config.algorithm.total_timesteps // self.batch_size
        self.nr_minibatches = self.batch_size // self.minibatch_size
        self.subsampling_cutoff = config.algorithm.get("subsampling_cutoff", 1)

        # Global Reward Experiment Flag
        self.global_rew_experiment = config.algorithm.global_rew_experiment

        # TRIRL Specific
        self.data_path = config.algorithm.data_path
        self.nr_epochs_disc = config.algorithm.nr_epochs_disc
        self.learning_rate_disc = config.algorithm.learning_rate_disc
        self.env_reward_frac = config.algorithm.env_reward_frac
        self.handle_absorbing_states = config.algorithm.handle_absorbing_states
        self.epsilon = config.algorithm.epsilon
        self.disc_buffer_capacity = config.algorithm.disc_buffer_capacity
        self.init_eta = config.algorithm.init_eta
        self.const_eta = config.algorithm.const_eta
        self.gp_lambda = config.algorithm.gp_lambda
        self.gp_alpha = config.algorithm.gp_alpha
        # beta is derived as the reciprocal of the entropy coefficient rather than read
        # from the config.
        self.beta = 1/config.algorithm.entropy_coef
        self.disc_buffer = DiscriminatorBuffer(self.disc_buffer_capacity, (self.nr_steps, self.nr_envs))
        self.chunk_size_dict = {10:50, 50:8, 100:4} # dict mapping nr_steps to chunk size (in the case where eta's are not on demand)
        self.reward_type = config.algorithm.reward_type
        self.reward_approximator_type = config.algorithm.reward_approximator_type
        self.save_delay = 50

        # TRIRL - Reward Approximation
        self.reward_fn_approximator = config.algorithm.reward_fn_approximator
        self.nr_epochs_rew = config.algorithm.nr_epochs_rew
        self.learning_rate_reward_fn = config.algorithm.learning_rate_reward_fn

        if self.evaluation_frequency % (self.nr_steps * self.nr_envs) != 0 and self.evaluation_frequency != -1:
            raise ValueError("Evaluation frequency must be a multiple of the number of steps and environments.")

        rlx_logger.info(f"Using device: {jax.default_backend()}")

        self.key = jax.random.PRNGKey(self.seed)
        self.key, policy_key, critic_key, discriminator_key = jax.random.split(self.key, 4)

        self.os_shape = env.single_observation_space.shape
        self.as_shape = env.single_action_space.shape

        self.policy, self.get_processed_action = get_policy(config, env)
        self.critic = get_critic(config, env)
        self.discriminator = get_discriminator(config, env, reward_type=self.reward_type)

        if self.reward_fn_approximator:
            self.reward_fn = get_reward_approximator(config, env, reward_approximator_type=self.reward_approximator_type)

        self.policy.apply = jax.jit(self.policy.apply)
        self.critic.apply = jax.jit(self.critic.apply)
        self.discriminator.apply = jax.jit(self.discriminator.apply)
        # Sum of log side lengths of the action box, i.e. the differential entropy of a
        # uniform distribution over the action bounds.
        self.H_terminal = jnp.sum(jnp.log(env.single_action_space.high - env.single_action_space.low))

        def linear_schedule(count):
            # ``count`` is the optimizer step counter; each PPO update performs
            # nr_minibatches * nr_epochs steps, so this decays linearly over nr_updates.
            fraction = 1.0 - (count // (self.nr_minibatches * self.nr_epochs)) / self.nr_updates
            return self.learning_rate * fraction

        def linear_schedule_disc(count):
            # Each discriminator update performs nr_minibatches * nr_epochs_disc optimizer
            # steps and runs once per PPO update, so the number of completed discriminator
            # updates decays linearly over nr_updates, mirroring ``linear_schedule``.
            # The schedule must start from the *discriminator* learning rate; the
            # non-annealed branch below also uses learning_rate_disc. The previous
            # horizon, (nr_updates * nr_epochs) / nr_epochs_disc, drove the LR negative
            # whenever nr_epochs_disc > nr_epochs.
            fraction = 1.0 - (count // (self.nr_minibatches * self.nr_epochs_disc)) / self.nr_updates
            return self.learning_rate_disc * fraction

        learning_rate = linear_schedule if self.anneal_learning_rate else self.learning_rate
        learning_rate_disc = linear_schedule_disc if self.anneal_learning_rate else self.learning_rate_disc

        # Single-sample dummy inputs used only to initialize network parameters.
        state = jnp.array([env.single_observation_space.sample()])
        action = jnp.array([env.single_action_space.sample()])
        next_state = jnp.array([env.single_observation_space.sample()])
        absorbing = jnp.array([0.0])

        # inject_hyperparams sits at index 1 of each chain, which is where the update
        # functions read the current learning rate from (opt_state[1]).
        self.policy_state = TrainState.create(
            apply_fn=self.policy.apply,
            params=self.policy.init(policy_key, state),
            tx=optax.chain(
                optax.clip_by_global_norm(self.max_grad_norm),
                optax.inject_hyperparams(optax.adam)(learning_rate=learning_rate),
            )
        )

        self.critic_state = TrainState.create(
            apply_fn=self.critic.apply,
            params=self.critic.init(critic_key, state),
            tx=optax.chain(
                optax.clip_by_global_norm(self.max_grad_norm),
                optax.inject_hyperparams(optax.adam)(learning_rate=learning_rate),
            )
        )

        self.discriminator_state = TrainState.create(
            apply_fn=self.discriminator.apply,
            params=self.discriminator.init(discriminator_key, state, action, next_state, absorbing),
            tx=optax.chain(
                optax.clip_by_global_norm(self.max_grad_norm),
                optax.inject_hyperparams(optax.adam)(learning_rate=learning_rate_disc),
            )
        )

        if self.reward_fn_approximator:
            learning_rate_reward_fn = self.learning_rate_reward_fn
            self.key, reward_fn_key = jax.random.split(self.key)
            self.reward_fn_state = TrainState.create(
                apply_fn=self.reward_fn.apply,
                params=self.reward_fn.init(reward_fn_key, state, action, next_state),
                tx=optax.chain(
                    optax.clip_by_global_norm(self.max_grad_norm),
                    optax.inject_hyperparams(optax.adam)(learning_rate=learning_rate_reward_fn),
                )
            )

        if self.save_model:
            os.makedirs(self.save_path)
            self.best_mean_return = -np.inf
            self.best_model_file_name = "best.model"
            self.best_model_checkpointer = orbax.checkpoint.PyTreeCheckpointer()


    def train_irl(self):

        @jax.jit
        def get_action_and_value(policy_state: TrainState, critic_state: TrainState, state: np.ndarray, key: jax.random.PRNGKey):
            """Sample a Gaussian action for ``state`` and evaluate the critic.

            Returns ``(processed_action, raw_action, value, log_prob, key, action_mean,
            action_logstd)``. Here ``raw_action`` is the unprocessed sample, ``value`` is
            flattened to 1-D, ``log_prob`` is summed over axis 1, and ``key`` is the
            advanced PRNG key.
            """
            mean, log_std = self.policy.apply(policy_state.params, state)
            std = jnp.exp(log_std)
            key, noise_key = jax.random.split(key)
            raw_action = mean + std * jax.random.normal(noise_key, shape=mean.shape)
            # Per-dimension diagonal-Gaussian log-density; summed over axis 1 on return.
            per_dim_log_prob = -0.5 * ((raw_action - mean) / std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - log_std
            critic_value = self.critic.apply(critic_state.params, state)
            return (
                self.get_processed_action(raw_action),
                raw_action,
                critic_value.reshape(-1),
                per_dim_log_prob.sum(1),
                key,
                mean,
                log_std,
            )

        @jax.jit
        def get_log_density_ratio(discriminator_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray, absorbing: np.ndarray):
            """Return the raw discriminator logits for one transition.

            The name is rebound right after definition to a ``jax.vmap`` over the
            leading axis of every non-parameter argument.
            """
            return self.discriminator.apply(discriminator_params, state, action, next_state, absorbing)
        
        get_log_density_ratio = jax.vmap(get_log_density_ratio, in_axes=(None, 0, 0, 0, 0), out_axes=0)
        

        @jax.jit
        def get_reward_prediction(reward_fn_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray):
            """Return the reward approximator's raw output for one transition.

            The name is rebound right after definition to a ``jax.vmap`` over the
            leading axis of ``state``, ``action`` and ``next_state``.
            """
            return self.reward_fn.apply(reward_fn_params, state, action, next_state)
        
        get_reward_prediction = jax.vmap(get_reward_prediction, in_axes=(None, 0, 0, 0), out_axes=0)
    

        @jax.jit
        def calculate_gae_advantages(critic_state: TrainState, next_states: np.ndarray, rewards: np.ndarray, terminations: np.ndarray, values: np.ndarray):
            """Compute GAE advantages and returns for one rollout.

            With ``delta_t = r_t + gamma * V(s'_t) * (1 - term_t) - V(s_t)``, the
            advantages follow ``A_t = delta_t + gamma * lambda * (1 - term_t) * A_{t+1}``.
            They are evaluated backwards over axis 0 (length ``nr_steps``), and the last
            step's advantage is its own delta.

            Returns:
                ``(advantages, returns)``, where ``returns = advantages + values``.

            NOTE(review): only ``terminations`` cut the recursion. The rollout loop stores
            ``terminated`` (not ``truncated``) into ``batch.terminations``, so at a
            time-limit truncation the next episode's advantage leaks into this one.
            Confirm this is acceptable.
            """
            def compute_advantages(carry, t):
                # Backward GAE step: carry holds A_{t+1}. ``delta`` is read from the
                # enclosing scope and is defined below, before the scan runs.
                prev_advantage = carry[0]
                advantage = delta[t] + self.gamma * self.gae_lambda * (1 - terminations[t]) * prev_advantage
                return (advantage,), advantage

            next_values = self.critic.apply(critic_state.params, next_states).squeeze(-1)
            delta = rewards + self.gamma * next_values * (1.0 - terminations) - values
            init_advantages = delta[-1]
            # Scan t = nr_steps-2 ... 0, then reverse to chronological order and append
            # the final step's advantage.
            _, advantages = jax.lax.scan(compute_advantages, (init_advantages,), jnp.arange(self.nr_steps - 2, -1, -1))
            advantages = jnp.concatenate([advantages[::-1], jnp.array([init_advantages])])
            returns = advantages + values
            return advantages, returns


        @jax.jit
        def calculate_gae_advantages_absorbing(critic_state: TrainState, next_states: np.ndarray, rewards: np.ndarray, rewards_next_state: np.ndarray, terminations: np.ndarray, values: np.ndarray):
            """
            Correctly handle absorbing state value and entropy (instead of setting to 0.0)

            Same GAE recursion as ``calculate_gae_advantages``, but on terminated steps the
            bootstrap value is replaced by a closed-form tail:
            ``gamma / (1 - gamma) * (rewards_next_state + entropy_coef * H_terminal)``.
            This is the discounted sum, starting one step later, of a constant per-step
            reward plus entropy bonus. ``H_terminal`` is set in ``__init__`` from the
            action-space bounds.

            Returns:
                ``(advantages, returns)``, where ``returns = advantages + values``.
            """
            def compute_advantages(carry, t):
                # Backward GAE step: carry holds A_{t+1}; ``delta`` comes from the enclosing scope.
                prev_advantage = carry[0]
                advantage = delta[t] + self.gamma * self.gae_lambda * (1 - terminations[t]) * prev_advantage
                return (advantage,), advantage

            next_values = self.critic.apply(critic_state.params, next_states).squeeze(-1)
            terminal_tail = (self.gamma / (1.0 - self.gamma)) * (rewards_next_state + self.entropy_coef * self.H_terminal)
            # Non-terminal steps bootstrap from V(s'); terminal steps use the absorbing tail instead.
            delta = rewards + self.gamma * next_values * (1.0 - terminations) + (terminations * terminal_tail) - values
            init_advantages = delta[-1]
            # Scan t = nr_steps-2 ... 0, then restore chronological order and append the last step.
            _, advantages = jax.lax.scan(compute_advantages, (init_advantages,), jnp.arange(self.nr_steps - 2, -1, -1))
            advantages = jnp.concatenate([advantages[::-1], jnp.array([init_advantages])])
            returns = advantages + values
            return advantages, returns

        dim = np.prod(self.as_shape).item()

        """
        TRIRL Update
        """
        @partial(jax.jit, static_argnames=('reward_type',))
        def trirl_update(discriminator_state: TrainState,
                   states: np.ndarray, actions: np.ndarray,
                   next_states: np.ndarray, absorbing: np.ndarray,
                   expert_states: np.ndarray, expert_actions: np.ndarray,
                   expert_next_states: np.ndarray, expert_absorbing: np.ndarray,
                   key: jax.random.PRNGKey, reward_type='state-action'):
            """Train the discriminator on one rollout against a shuffled expert batch.

            Runs ``nr_epochs_disc`` epochs of minibatch steps inside one ``jax.lax.scan``.
            Each minibatch loss is the sum of:

            * the agent BCE (label 0),
            * the expert BCE (label 1),
            * ``gp_lambda`` times a gradient penalty on an interpolation between expert
              and agent samples, using the fixed weight ``gp_alpha``.

            Args:
                discriminator_state: Current discriminator ``TrainState``.
                states, actions, next_states, absorbing: Rollout tensors, flattened here
                    over their leading axes.
                expert_states, expert_actions, expert_next_states, expert_absorbing:
                    Expert dataset. All four are shuffled with one shared permutation and
                    truncated to the first ``batch_size`` rows.
                key: PRNG key; split for the expert shuffle and the minibatch permutation.
                reward_type: Static. Selects which input gradients enter the penalty norm.

            Returns:
                ``(discriminator_state, mean_metrics, key)``.

            Raises:
                ValueError: At trace time, if ``reward_type`` is not one of
                    "state-action", "uncorrelated", "state-based", "shaped", "shaped-sa".
            """

            def trirl_loss_fn(discriminator_params, state, action, expert_state, expert_action, label, expert_label, next_state=None, absorbing=None, expert_next_state=None, expert_absorbing=None):
                # Per-sample loss; vmapped over the minibatch and averaged below.
                logits = self.discriminator.apply(discriminator_params, state, action, next_state, absorbing)
                expert_logits = self.discriminator.apply(discriminator_params, expert_state, expert_action, expert_next_state, expert_absorbing)
                bce_agent = optax.sigmoid_binary_cross_entropy(logits, label).mean()
                bce_expert = optax.sigmoid_binary_cross_entropy(expert_logits, expert_label).mean()

                # Gradient penalty on an interpolated sample between the expert and agent.
                # The interpolation weight is the fixed gp_alpha (not sampled), and the
                # penalty pushes the input-gradient norm towards 1.
                alpha = self.gp_alpha
                gp_lambda = self.gp_lambda
                interpolated_state = alpha * expert_state + (1 - alpha) * state
                interpolated_action = alpha * expert_action + (1 - alpha) * action
                interpolated_next_state = alpha * expert_next_state + (1 - alpha) * next_state
                interpolated_abs = 0.0 * expert_absorbing  # assume interpolated state to be non-absorbing
                grad_state, grad_action, grad_next_state = jax.grad(lambda s, a, sn, ab: jnp.sum(self.discriminator.apply(discriminator_params, s, a, sn, ab)), argnums=(0, 1, 2))(interpolated_state, interpolated_action, interpolated_next_state, interpolated_abs)
                # Only the inputs the chosen reward parameterization uses enter the norm.
                if reward_type in ["state-action", "uncorrelated"]:
                    grad_norm = jnp.sqrt(jnp.sum(jnp.square(grad_state)) + jnp.sum(jnp.square(grad_action)))
                elif reward_type in ["state-based", "shaped"]:
                    grad_norm = jnp.sqrt(jnp.sum(jnp.square(grad_state)) + jnp.sum(jnp.square(grad_next_state)))
                elif reward_type in ["shaped-sa"]:
                    grad_norm = jnp.sqrt(jnp.sum(jnp.square(grad_state)) + jnp.sum(jnp.square(grad_action)) + jnp.sum(jnp.square(grad_next_state)))
                else:
                    # Without this branch an unknown reward_type left grad_norm unbound
                    # and failed with an opaque UnboundLocalError. reward_type is static,
                    # so this check runs once at trace time.
                    raise ValueError(f"Unsupported reward_type for the discriminator gradient penalty: {reward_type!r}")
                gp = (grad_norm - 1.0) ** 2

                bce_loss = bce_agent + bce_expert + gp_lambda * gp
                metrics = {
                    "loss/discriminator_loss": bce_loss,
                    "loss/discriminator_agent_loss": bce_agent,
                    "loss/discriminator_expert_loss": bce_expert,
                    "loss/discriminator_gp": gp
                }
                return bce_loss, (metrics)


            batch_states = states.reshape((-1,) + self.os_shape)
            batch_actions = actions.reshape((-1,) + self.as_shape)

            # Expert batch: one shared permutation keeps the (s, a, s', absorbing) rows aligned.
            key, shuffle_key = jax.random.split(key)
            perm = jax.random.permutation(shuffle_key, expert_states.shape[0])
            expert_states = expert_states[perm]
            expert_actions = expert_actions[perm]
            # NOTE(review): with fewer than batch_size expert rows these slices are shorter
            # than batch_size. The minibatch gathers below would then index past their end,
            # and JAX clamps such indices silently. Confirm the dataset is large enough.
            batch_expert_states = expert_states[:self.batch_size]
            batch_expert_actions = expert_actions[:self.batch_size]
            expert_labels = jnp.ones((self.batch_size, 1), dtype=jnp.float32)
            rollout_labels = jnp.zeros((self.batch_size, 1), dtype=jnp.float32)

            batch_next_states = next_states.reshape((-1,) + self.os_shape)
            batch_absorbing = absorbing.reshape(-1)
            expert_next_states = expert_next_states[perm]
            expert_absorbing = expert_absorbing[perm]
            batch_expert_next_states = expert_next_states[:self.batch_size]
            batch_expert_absorbing = expert_absorbing[:self.batch_size]

            vmap_trirl_loss_fn = jax.vmap(trirl_loss_fn, in_axes=(None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), out_axes=0)
            safe_mean = lambda x: jnp.mean(x) if x is not None else x
            mean_vmapped_trirl_loss_fn = lambda *a, **k: tree.map_structure(safe_mean, vmap_trirl_loss_fn(*a, **k))
            grad_trirl_loss_fn = jax.value_and_grad(mean_vmapped_trirl_loss_fn, argnums=(0), has_aux=True)

            # nr_epochs_disc independent shuffles of the rollout indices, cut into minibatches.
            key, subkey = jax.random.split(key)
            batch_indices_disc = jnp.tile(jnp.arange(self.batch_size), (self.nr_epochs_disc, 1))
            batch_indices_disc = jax.random.permutation(subkey, batch_indices_disc, axis=1, independent=True)
            batch_indices_disc = batch_indices_disc.reshape((self.nr_epochs_disc * self.nr_minibatches, self.minibatch_size))

            def minibatch_update(carry, minibatch_indices_disc):
                # One discriminator gradient step; the scan carry is the TrainState.
                discriminator_state = carry

                (trirl_loss, (metrics)), (discriminator_gradients) = grad_trirl_loss_fn(
                    discriminator_state.params,
                    batch_states[minibatch_indices_disc],
                    batch_actions[minibatch_indices_disc],
                    batch_expert_states[minibatch_indices_disc],
                    batch_expert_actions[minibatch_indices_disc],
                    rollout_labels[minibatch_indices_disc],
                    expert_labels[minibatch_indices_disc],
                    batch_next_states[minibatch_indices_disc],
                    batch_absorbing[minibatch_indices_disc],
                    batch_expert_next_states[minibatch_indices_disc],
                    batch_expert_absorbing[minibatch_indices_disc],
                )

                discriminator_state = discriminator_state.apply_gradients(grads=discriminator_gradients)
                carry = (discriminator_state)

                return carry, (metrics)

            init_carry = (discriminator_state)
            carry, (metrics) = jax.lax.scan(minibatch_update, init_carry, batch_indices_disc)
            discriminator_state = carry

            # Average the per-step metrics; `name` avoids visually shadowing the PRNG `key`.
            mean_metrics = {name: jnp.mean(metrics[name]) for name in metrics}
            mean_metrics["lr/learning_rate_disc"] = discriminator_state.opt_state[1].hyperparams["learning_rate"]

            return discriminator_state, mean_metrics, key


        @jax.jit
        def update(policy_state: TrainState, critic_state: TrainState,
                   states: np.ndarray, actions: np.ndarray, advantages: np.ndarray, returns: np.ndarray, values: np.ndarray, log_probs: np.ndarray, 
                   old_action_means:np.ndarray, old_action_logstd:np.ndarray, eta: float,
                   key: jax.random.PRNGKey):
            """Run PPO epochs on one rollout with an added trust-region (KL) penalty.

            The rollout is flattened, and ``nr_epochs`` independent shuffles of the
            ``batch_size`` indices are cut into minibatches of ``minibatch_size``. Inside
            ``jax.lax.scan``, each minibatch gets one policy and one critic gradient step
            on ``pg_loss - entropy_coef * entropy_loss + critic_coef * critic_loss
            + eta * trust_region_loss``. Advantages are normalized per minibatch.

            Args:
                policy_state, critic_state: Current ``TrainState``s.
                states, actions, advantages, returns, values, log_probs, old_action_means:
                    Rollout tensors, flattened here over their leading axes.
                old_action_logstd: One log-std array shared by every sample
                    (``in_axes=None`` below). The rollout loop stores the value from its
                    last acting step.
                eta: Weight of the trust-region term.
                key: PRNG key for the minibatch shuffles.

            Returns:
                ``(policy_state, critic_state, mean_metrics, key)``, with metrics averaged
                over all minibatch steps.
            """
            def loss_fn(policy_params, critic_params, state_b, action_b, log_prob_b, return_b, advantage_b, old_mean_b, old_logstd_b, eta):
                # Per-sample loss; vmapped over the minibatch and averaged by the caller.
                # Policy loss
                action_mean, action_logstd = self.policy.apply(policy_params, state_b)
                action_std = jnp.exp(action_logstd)
                new_log_prob = -0.5 * ((action_b - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
                new_log_prob = new_log_prob.sum(1)
                entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e)
                # NOTE(review): entropy is multiplied by entropy_coef here and again in
                # ``loss`` below, so the effective entropy weight is entropy_coef ** 2.
                # The logged "loss/entropy_loss" is also the pre-scaled value. Confirm
                # this is intended.
                entropy = self.entropy_coef * entropy # scaling to reduce critic loss

                # PPO clipped surrogate; approx_kl uses the (r - 1) - log r estimator.
                logratio = new_log_prob - log_prob_b
                ratio = jnp.exp(logratio)
                approx_kl_div = (ratio - 1) - logratio
                clip_fraction = jnp.float32((jnp.abs(ratio - 1) > self.clip_range))

                pg_loss1 = -advantage_b * ratio
                pg_loss2 = -advantage_b * jnp.clip(ratio, 1 - self.clip_range, 1 + self.clip_range)
                pg_loss = jnp.maximum(pg_loss1, pg_loss2)
                entropy_loss = entropy.sum(1)
                
                # Critic loss
                new_value = self.critic.apply(critic_params, state_b)
                critic_loss = 0.5 * (new_value - return_b) ** 2

                # Trust Region Loss
                # Analytical KL (across all samples)
                # Closed-form diagonal-Gaussian KL(current || rollout-time policy): the
                # Mahalanobis term divides by the old std, and the covariance term is
                # log(old_std / std) + (std / old_std)^2 / 2 - 1/2 per dimension.
                old_logstd_b = jnp.expand_dims(old_logstd_b, 0) # shape (1, as_shape)
                old_action_std = jnp.exp(old_logstd_b)
                tr_loss_maha = 0.5 * jnp.sum(((old_mean_b - action_mean)/old_action_std) ** 2, axis=1)
                tr_loss_cov_part = 0.5 * jnp.sum(2.0 * (jnp.log(old_action_std) - jnp.log(action_std)) + (action_std/old_action_std)**2 - 1.0, axis=1)
                trust_region_loss = tr_loss_maha + tr_loss_cov_part

                # Combine losses
                loss = pg_loss - self.entropy_coef * entropy_loss + self.critic_coef * critic_loss + eta * trust_region_loss

                # Create metrics
                metrics = {
                    "loss/policy_gradient_loss": pg_loss,
                    "loss/critic_loss": critic_loss,
                    "loss/entropy_loss": entropy_loss,
                    "loss/trust_region_loss": trust_region_loss,
                    "policy_ratio/approx_kl": approx_kl_div,
                    "policy_ratio/clip_fraction": clip_fraction,
                }

                return loss, (metrics)
            

            batch_states = states.reshape((-1,) + self.os_shape)
            batch_actions = actions.reshape((-1,) + self.as_shape)
            batch_advantages = advantages.reshape(-1)
            batch_returns = returns.reshape(-1)
            batch_log_probs = log_probs.reshape(-1)
            batch_action_means = old_action_means.reshape((-1,) + self.as_shape)

            # Per-sample vmap (old_action_logstd and eta are broadcast), then mean over
            # every metric leaf so value_and_grad sees a scalar loss.
            vmap_loss_fn = jax.vmap(loss_fn, in_axes=(None, None, 0, 0, 0, 0, 0, 0, None, None), out_axes=0)
            safe_mean = lambda x: jnp.mean(x) if x is not None else x
            mean_vmapped_loss_fn = lambda *a, **k: tree.map_structure(safe_mean, vmap_loss_fn(*a, **k))
            grad_loss_fn = jax.value_and_grad(mean_vmapped_loss_fn, argnums=(0, 1), has_aux=True)

            # nr_epochs independent shuffles of the indices, cut into minibatches.
            key, subkey = jax.random.split(key)
            batch_indices = jnp.tile(jnp.arange(self.batch_size), (self.nr_epochs, 1))
            batch_indices = jax.random.permutation(subkey, batch_indices, axis=1, independent=True)
            batch_indices = batch_indices.reshape((self.nr_epochs * self.nr_minibatches, self.minibatch_size))

            def minibatch_update(carry, minibatch_indices):
                # One policy + critic gradient step; the scan carry is both TrainStates.
                policy_state, critic_state = carry
                minibatch_advantages = batch_advantages[minibatch_indices]
                minibatch_advantages = (minibatch_advantages - jnp.mean(minibatch_advantages)) / (jnp.std(minibatch_advantages) + 1e-8)

                (loss, (metrics)), (policy_gradients, critic_gradients) = grad_loss_fn(
                    policy_state.params,
                    critic_state.params,
                    batch_states[minibatch_indices],
                    batch_actions[minibatch_indices],
                    batch_log_probs[minibatch_indices],
                    batch_returns[minibatch_indices],
                    minibatch_advantages,
                    batch_action_means[minibatch_indices],
                    old_action_logstd, eta
                )

                policy_state = policy_state.apply_gradients(grads=policy_gradients)
                critic_state = critic_state.apply_gradients(grads=critic_gradients)

                metrics["gradients/policy_grad_norm"] = optax.global_norm(policy_gradients)
                metrics["gradients/critic_grad_norm"] = optax.global_norm(critic_gradients)
                metrics = metrics  # no-op; kept as-is

                carry = (policy_state, critic_state)

                return carry, (metrics)
            
            init_carry = (policy_state, critic_state)
            carry, (metrics) = jax.lax.scan(minibatch_update, init_carry, batch_indices)
            policy_state, critic_state = carry

            # Calculate mean metrics
            # opt_state[1] is the inject_hyperparams state of the chain built in __init__.
            mean_metrics = {key: jnp.mean(metrics[key]) for key in metrics}
            mean_metrics["lr/learning_rate"] = policy_state.opt_state[1].hyperparams["learning_rate"]
            mean_metrics["v_value/explained_variance"] = 1 - jnp.var(returns - values) / (jnp.var(returns) + 1e-8)
            mean_metrics["policy/std_dev"] = jnp.mean(jnp.exp(policy_state.params["params"]["policy_logstd"]))
            mean_metrics["policy/entropy"] = jnp.mean(jnp.sum(policy_state.params["params"]["policy_logstd"] + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e)))

            return policy_state, critic_state, mean_metrics, key


        """
        Reward Approximator Update
        """

        @jax.jit
        def reward_fn_update(reward_fn_state: TrainState,
                   states: np.ndarray, actions: np.ndarray, next_states: np.ndarray,
                   target_reward: np.ndarray,
                   key: jax.random.PRNGKey):
            """Regress the reward approximator onto ``target_reward`` with an MSE loss.

            Runs ``nr_epochs_rew`` epochs of shuffled minibatch steps inside one
            ``jax.lax.scan``.

            Args:
                reward_fn_state: Current reward-approximator ``TrainState``.
                states, actions, next_states: Rollout tensors, flattened over their
                    leading axes.
                target_reward: Regression targets, flattened to shape ``(-1, 1)``.
                key: PRNG key for the minibatch shuffles.

            Returns:
                ``(reward_fn_state, mean_metrics, key)``.
            """

            def loss_fn(reward_fn_params, state, action, target_reward, next_state=None):
                # Per-sample squared error; vmapped and averaged below.
                logits = self.reward_fn.apply(reward_fn_params, state, action, next_state)
                mse = optax.squared_error(logits, target_reward).mean()

                metrics = {
                    "loss/reward_fn_loss": mse,
                }
                return mse, (metrics)


            batch_states = states.reshape((-1,) + self.os_shape)
            batch_actions = actions.reshape((-1,) + self.as_shape)
            batch_target_reward = target_reward.reshape((-1,) + (1,))

            batch_next_states = next_states.reshape((-1,) + self.os_shape)
            vmap_loss_fn = jax.vmap(loss_fn, in_axes=(None, 0, 0, 0, 0), out_axes=0)
            safe_mean = lambda x: jnp.mean(x) if x is not None else x
            mean_vmapped_loss_fn = lambda *a, **k: tree.map_structure(safe_mean, vmap_loss_fn(*a, **k))
            grad_loss_fn = jax.value_and_grad(mean_vmapped_loss_fn, argnums=(0), has_aux=True)

            # nr_epochs_rew independent shuffles of the indices, cut into minibatches.
            key, subkey = jax.random.split(key)
            batch_indices_rew = jnp.tile(jnp.arange(self.batch_size), (self.nr_epochs_rew, 1))
            batch_indices_rew = jax.random.permutation(subkey, batch_indices_rew, axis=1, independent=True)
            batch_indices_rew = batch_indices_rew.reshape((self.nr_epochs_rew * self.nr_minibatches, self.minibatch_size))

            def minibatch_update(carry, minibatch_indices_rew):
                # One reward-approximator gradient step; the scan carry is the TrainState.
                reward_fn_state = carry

                (rew_fn_loss, (metrics)), (rew_fn_gradients) = grad_loss_fn(
                    reward_fn_state.params,
                    batch_states[minibatch_indices_rew],
                    batch_actions[minibatch_indices_rew],
                    batch_target_reward[minibatch_indices_rew],
                    batch_next_states[minibatch_indices_rew],
                )

                reward_fn_state = reward_fn_state.apply_gradients(grads=rew_fn_gradients)
                carry = (reward_fn_state)

                return carry, (metrics)

            init_carry = (reward_fn_state)
            carry, (metrics) = jax.lax.scan(minibatch_update, init_carry, batch_indices_rew)
            reward_fn_state = carry

            # Average the per-step metrics; `name` avoids visually shadowing the PRNG `key`.
            mean_metrics = {name: jnp.mean(metrics[name]) for name in metrics}
            # Log under the reward approximator's own key. This previously reused
            # "lr/learning_rate_disc", which mislabels the value and collides with the
            # discriminator's learning-rate metric from trirl_update.
            mean_metrics["lr/learning_rate_reward_fn"] = reward_fn_state.opt_state[1].hyperparams["learning_rate"]

            return reward_fn_state, mean_metrics, key

        @jax.jit
        def get_deterministic_action(policy_state: TrainState, state: np.ndarray):
            """Return the processed policy mean for ``state`` (no sampling)."""
            mean, _unused_logstd = self.policy.apply(policy_state.params, state)
            return self.get_processed_action(mean)

        chunked_project = make_chunked_ensemble_rew_project(
            get_log_density_ratio, 
            nr_steps=self.nr_steps, 
            nr_envs=self.nr_envs, 
            epsilon=self.epsilon, 
            beta=self.beta,
            entropy_coef=self.entropy_coef,
            maximum_eta=True
        )

        def linear_schedule_eta(global_step):
            """Return the trust-region weight eta for ``global_step``.

            Returns ``init_eta`` unchanged when ``const_eta`` is set. Otherwise eta decays
            linearly from ``init_eta`` to 0 as ``global_step`` approaches
            ``total_timesteps``, and stays at 0 after that.
            """
            if not self.const_eta:
                progress = np.clip(global_step, a_min=None, a_max=self.total_timesteps) / self.total_timesteps
                return self.init_eta * (1.0 - progress)
            return self.init_eta


        """ Main Training Loop """

        self.set_train_mode()

        # Get demonstrations
        demonstrations = prepare_expert_data(self.data_path, self.subsampling_cutoff)
        self.expert_states = demonstrations["states"]
        self.expert_next_states = demonstrations["next_states"]
        self.expert_actions = demonstrations["actions"]
        self.expert_absorbing = demonstrations["absorbing"].flatten()

        batch = Batch(
            states=np.zeros((self.nr_steps, self.nr_envs) + self.os_shape),
            next_states=np.zeros((self.nr_steps, self.nr_envs) + self.os_shape),
            actions=np.zeros((self.nr_steps, self.nr_envs) + self.as_shape),
            rewards=np.zeros((self.nr_steps, self.nr_envs)),
            proj_rewards=np.zeros((self.nr_steps, self.nr_envs)),
            merged_rewards=np.zeros((self.nr_steps, self.nr_envs)),
            etas = np.zeros((self.nr_steps, self.nr_envs)),
            values=np.zeros((self.nr_steps, self.nr_envs)),
            terminations=np.zeros((self.nr_steps, self.nr_envs)),
            log_probs=np.zeros((self.nr_steps, self.nr_envs)),
            advantages=np.zeros((self.nr_steps, self.nr_envs)),
            returns=np.zeros((self.nr_steps, self.nr_envs)),
            old_action_means=np.zeros((self.nr_steps, self.nr_envs) + self.as_shape),
            old_action_logstd=np.zeros(self.as_shape),
        )

        saving_return_buffer = deque(maxlen=100 * self.nr_envs)

        state, _ = self.env.reset()
        global_step = 0
        nr_updates = 0
        nr_updates_disc = 0
        nr_episodes = 0
        steps_metrics = {}
        while global_step < self.total_timesteps:
            start_time = time.time()
            time_metrics = {}

            # Acting
            dones_this_rollout = 0
            step_info_collection = {}
            for step in range(self.nr_steps):
                processed_action, action, value, log_prob, self.key, action_mean, action_logstd = get_action_and_value(self.policy_state, self.critic_state, state, self.key)
                next_state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                done = terminated | truncated
                actual_next_state = next_state.copy()
                for i, single_done in enumerate(done):
                    if single_done:
                        # if self.handle_absorbing_states:
                        #     actual_next_state[i] = np.array(self.env.get_final_observation_at_index(info, i)) * 0.0 # setting one absorbing state (zero vector)
                        # else:
                        actual_next_state[i] = np.array(self.env.get_final_observation_at_index(info, i))
                        saving_return_buffer.append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                        dones_this_rollout += 1
                for key, info_value in self.env.get_logging_info_dict(info).items():
                    step_info_collection.setdefault(key, []).extend(info_value)

                batch.states[step] = state
                batch.next_states[step] = actual_next_state
                batch.actions[step] = action
                batch.rewards[step] = reward
                batch.values[step] = value
                batch.terminations[step] = terminated
                batch.log_probs[step] = log_prob
                batch.old_action_means[step] = action_mean
                batch.old_action_logstd = action_logstd

                state = next_state
                global_step += self.nr_envs
            nr_episodes += dones_this_rollout
            acting_end_time = time.time()
            time_metrics["time/acting_time"] = acting_end_time - start_time


            """ Optimizing Discriminator """
            # Optimizing Discriminator
            self.discriminator_state, trirl_optimization_metrics, self.key = trirl_update(
                self.discriminator_state,
                batch.states, batch.actions, batch.next_states, batch.terminations, self.expert_states, self.expert_actions, self.expert_next_states, self.expert_absorbing,
                self.key, self.reward_type
            )

            projection_start_time = time.time()
            # Add optimized discriminator to buffer
            self.disc_buffer.append(self.discriminator_state.params)

            # Compute projected reward
            proj_rewards = chunked_project(
                self.disc_buffer.buffer,
                (batch.states.reshape((-1,) + self.os_shape),
                batch.actions.reshape((-1,) + self.as_shape),
                batch.next_states.reshape((-1,) + self.os_shape),
                batch.terminations.reshape(-1,)),
                self.disc_buffer.eta_buffer,
                chunk_size=self.chunk_size_dict[self.nr_steps],
            )
            batch.proj_rewards = proj_rewards

            # Fit a network to predict the projected reward
            if self.reward_fn_approximator:
                self.reward_fn_state, rew_fn_optimization_metrics, self.key = reward_fn_update(
                    self.reward_fn_state,
                    batch.states, batch.actions, batch.next_states, batch.proj_rewards,
                    self.key
                )
                proj_rewards = get_reward_prediction(self.reward_fn_state.params, batch.states.reshape((-1,) + self.os_shape), batch.actions.reshape((-1,) + self.as_shape), batch.next_states.reshape((-1,) + self.os_shape))
                proj_rewards = proj_rewards.reshape(batch.rewards.shape)
                batch.proj_rewards = proj_rewards

            batch.merged_rewards = batch.rewards * self.env_reward_frac + batch.proj_rewards * (1 - self.env_reward_frac)
            projection_end_time = time.time()
            time_metrics["time/projection_time"] = projection_end_time - projection_start_time

            if self.handle_absorbing_states:
                if self.reward_fn_approximator:
                    proj_rewards_next_state = get_reward_prediction(self.reward_fn_state.params, batch.next_states.reshape((-1,) + self.os_shape), batch.actions.reshape((-1,) + self.as_shape), batch.next_states.reshape((-1,) + self.os_shape))
                    proj_rewards_next_state = proj_rewards_next_state.reshape(batch.rewards.shape)
                else:
                    proj_rewards_next_state = chunked_project(
                        self.disc_buffer.buffer,
                        (batch.next_states.reshape((-1,) + self.os_shape), # next state
                        batch.actions.reshape((-1,) + self.as_shape), # next actions don't matter in absorbing states
                        batch.next_states.reshape((-1,) + self.os_shape), # next next states are the same as next states if in absorbing state
                        np.ones_like(batch.terminations.reshape(-1,))), # absorbing is true
                        self.disc_buffer.eta_buffer, # should not make a difference since we use max over the while batch anyway
                        chunk_size=self.chunk_size_dict[self.nr_steps],
                    )

                # Calculating advantages and returns
                batch.advantages, batch.returns = calculate_gae_advantages_absorbing(self.critic_state, batch.next_states, batch.merged_rewards, proj_rewards_next_state, batch.terminations, batch.values)
                calc_adv_return_end_time = time.time()
                time_metrics["time/calc_adv_and_return_time"] = calc_adv_return_end_time - acting_end_time
            else:
                batch.advantages, batch.returns = calculate_gae_advantages(self.critic_state, batch.next_states, batch.merged_rewards, batch.terminations, batch.values)
                calc_adv_return_end_time = time.time()
                time_metrics["time/calc_adv_and_return_time"] = calc_adv_return_end_time - projection_end_time


            """ Optimizing Policy """            
            # Optimizing Policy
            eta = linear_schedule_eta(global_step)
            self.policy_state, self.critic_state, optimization_metrics, self.key = update(
                self.policy_state, self.critic_state,
                batch.states, batch.actions, batch.advantages, batch.returns, batch.values, batch.log_probs, batch.old_action_means, batch.old_action_logstd, eta,
                self.key,
            )

            # Add eta to the buffer
            self.disc_buffer.append_eta(eta * np.ones((self.nr_steps, self.nr_envs)))
            trirl_optimization_metrics["projection/eta"] = np.array([eta])
            trirl_optimization_metrics["projection/eta_buffer_length"] = jnp.array([self.disc_buffer.len_eta_buffer()])

            optimization_metrics = optimization_metrics | trirl_optimization_metrics
            if self.reward_fn_approximator:
                optimization_metrics = optimization_metrics | rew_fn_optimization_metrics
            optimization_metrics = {key: value.item() for key, value in optimization_metrics.items()}
            nr_updates += self.nr_epochs * self.nr_minibatches
            nr_updates_disc += self.nr_epochs_disc * self.nr_minibatches

            optimizing_end_time = time.time()
            time_metrics["time/optimizing_time"] = optimizing_end_time - calc_adv_return_end_time

            # Evaluating
            evaluation_metrics = {}
            if global_step % self.evaluation_frequency == 0 and self.evaluation_frequency != -1:
                self.set_eval_mode()
                state, _ = self.env.reset()
                eval_nr_episodes = 0
                evaluation_metrics = {"eval/episode_return": [], "eval/episode_length": []}
                while True:
                    processed_action = get_deterministic_action(self.policy_state, state)
                    state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                    done = terminated | truncated
                    for i, single_done in enumerate(done):
                        if single_done:
                            eval_nr_episodes += 1
                            evaluation_metrics["eval/episode_return"].append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                            evaluation_metrics["eval/episode_length"].append(self.env.get_final_info_value_at_index(info, "episode_length", i))
                            if eval_nr_episodes == self.evaluation_episodes:
                                break
                    if eval_nr_episodes == self.evaluation_episodes:
                        break
                evaluation_metrics = {key: np.mean(value) for key, value in evaluation_metrics.items()}
                state, _ = self.env.reset()
                self.set_train_mode()
            
            evaluating_end_time = time.time()
            time_metrics["time/evaluating_time"] = evaluating_end_time - optimizing_end_time
            

            # Saving
            # Also only save when there were finished episodes this update
            if self.save_model and dones_this_rollout > 0:
                mean_return = np.mean(saving_return_buffer)
                if mean_return > self.best_mean_return:
                    self.best_mean_return = mean_return
                    self.save(global_step)
            
            saving_end_time = time.time()
            time_metrics["time/saving_time"] = saving_end_time - evaluating_end_time

            time_metrics["time/sps"] = int((self.nr_steps * self.nr_envs) / (saving_end_time - start_time))


            # Logging
            self.start_logging(global_step)

            steps_metrics["steps/nr_env_steps"] = global_step
            steps_metrics["steps/nr_updates"] = nr_updates
            steps_metrics["steps/nr_updates_disc"] = nr_updates_disc
            steps_metrics["steps/nr_episodes"] = nr_episodes

            rollout_info_metrics = {}
            env_info_metrics = {}
            if step_info_collection:
                info_names = list(step_info_collection.keys())
                for info_name in info_names:
                    metric_group = "rollout" if info_name in ["episode_return", "episode_length"] else "env_info"
                    metric_dict = rollout_info_metrics if metric_group == "rollout" else env_info_metrics
                    mean_value = np.mean(step_info_collection[info_name])
                    if mean_value == mean_value:  # Check if mean_value is NaN
                        metric_dict[f"{metric_group}/{info_name}"] = mean_value
            
            combined_metrics = {**rollout_info_metrics, **evaluation_metrics, **env_info_metrics, **steps_metrics, **time_metrics, **optimization_metrics}
            for key, value in combined_metrics.items():
                self.log(f"{key}", value, global_step)

            self.end_logging()


    #######################################################################
    #######################################################################
    """
    FROM HERE ON ONLY TESTING & SAVING FUNCTIONS
    """
    #######################################################################
    #######################################################################

    def train(self):
        """Entry point: run the experiment selected by the configuration.

        When ``self.global_rew_experiment`` is truthy, only the correlation
        between the learnt reward and the expert reward is evaluated;
        otherwise the full IRL training loop (``self.train_irl``) is run.
        """
        if not self.global_rew_experiment:
            self.train_irl()
            return
        # Alternative for this mode: self.train_ppo() retrains a policy on the saved reward.
        self.learnt_reward_pearson_correlation()


    def learnt_reward_pearson_correlation(self):
        """Report how well the learnt reward correlates with the expert reward.

        Loads the expert demonstrations from ``self.data_path`` and evaluates
        the learnt reward on every expert transition. The reward comes from
        the reward network ``self.reward_fn`` when ``self.reward_fn_approximator``
        is truthy. Otherwise it comes from projecting the discriminators stored
        in ``self.disc_buffer``.

        Shaping is disabled (``shaping=0.0``) so that only the unshaped reward
        is compared. The method prints the Pearson correlation coefficient and
        its p-value, then shows a scatter plot of projected versus expert
        rewards.
        """
        from scipy import stats
        import matplotlib.pyplot as plt

        # Get demonstrations
        demonstrations = prepare_expert_data(self.data_path)
        expert_states = demonstrations["states"]
        expert_next_states = demonstrations["next_states"]
        expert_actions = demonstrations["actions"]
        expert_rewards = demonstrations["rewards"].reshape((-1,1))
        expert_absorbing = demonstrations["absorbing"].flatten()
        # if self.handle_absorbing_states:
        #     abs_indices = np.where(expert_absorbing > 0.0)[0]
        #     expert_states[abs_indices] = expert_states[abs_indices] * 0.0
        #     expert_next_states[abs_indices] = expert_states[abs_indices]

        shaping = 0.0  # if shaped rewards are used then this removes the shaping part
        if self.reward_fn_approximator:
            @jax.jit
            def get_reward_prediction(reward_fn_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray):
                logits = self.reward_fn.apply(reward_fn_params, state, action, next_state, shaping=shaping)
                return logits

            # Evaluate the reward network per transition over the leading (batch) axis.
            get_reward_prediction = jax.vmap(get_reward_prediction, in_axes=(None, 0, 0, 0), out_axes=0)

            proj_rewards = get_reward_prediction(self.reward_fn_state.params, expert_states.reshape((-1,) + self.os_shape), expert_actions.reshape((-1,) + self.as_shape), expert_next_states.reshape((-1,) + self.os_shape))
        else:
            @jax.jit
            def get_log_density_ratio(discriminator_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray, absorbing: np.ndarray):
                logits = self.discriminator.apply(discriminator_params, state, action, next_state, absorbing, shaping=shaping)
                return logits

            # Evaluate the discriminator per transition over the leading (batch) axis.
            get_log_density_ratio = jax.vmap(get_log_density_ratio, in_axes=(None, 0, 0, 0, 0), out_axes=0)
            # All expert transitions are treated as one "rollout" of a single env.
            chunked_project = make_chunked_ensemble_rew_project(
                get_log_density_ratio,
                nr_steps=expert_states.shape[0],
                nr_envs=1,
                epsilon=self.epsilon,
                beta=self.beta,
                entropy_coef=self.entropy_coef,
                maximum_eta=True
            )

            proj_rewards = chunked_project(
                self.disc_buffer.buffer,
                (expert_states.reshape((-1,) + self.os_shape),
                expert_actions.reshape((-1,) + self.as_shape),
                expert_next_states.reshape((-1,) + self.os_shape),
                expert_absorbing.reshape(-1,)),
                self.disc_buffer.eta_buffer,
                chunk_size=10,
            )

        # scipy.stats.pearsonr expects two 1-D samples of equal length. The
        # predictions may come back as (N,) or (N, 1), and the expert rewards
        # were reshaped to (N, 1) above. With 2-D inputs pearsonr may raise
        # (older SciPy) or return per-column arrays / broadcast mismatched
        # shapes (newer SciPy), so flatten both to plain 1-D numpy arrays.
        proj_rewards_flat = np.asarray(proj_rewards).reshape(-1)
        expert_rewards_flat = np.asarray(expert_rewards).reshape(-1)
        pearson_correlation = stats.pearsonr(proj_rewards_flat, expert_rewards_flat)
        print(f"pearson Correlation: {pearson_correlation[0]} p value: {pearson_correlation[1]}")

        plt.figure(figsize=(8, 6))
        plt.scatter(proj_rewards_flat, expert_rewards_flat, alpha=0.5, s=10)
        plt.xlabel("Projected Rewards")
        plt.ylabel("Expert Rewards")
        plt.title("Scatter Plot of Projected vs Expert Rewards")
        plt.grid(True)
        plt.show()


    def train_ppo(self):
        """Train a PPO policy from scratch on a previously learnt (saved) reward.

        The reward used for the advantages is never the environment reward.
        It is either predicted by ``self.reward_fn`` (when
        ``self.reward_fn_approximator`` is truthy) or obtained by projecting
        the discriminators stored in ``self.disc_buffer`` with
        ``chunked_project``. No discriminator or reward-network update happens
        here. The environment reward is still written to ``batch.rewards``,
        but that array is only used for its shape.

        Each iteration of the outer loop:

        - collects ``self.nr_steps`` steps from ``self.nr_envs`` environments;
        - computes the rewards and GAE advantages;
        - runs ``self.nr_epochs * self.nr_minibatches`` PPO minibatch updates;
        - optionally evaluates the deterministic policy;
        - saves a checkpoint when the running mean episode return improves;
        - logs all metrics.

        Training stops once ``global_step`` reaches ``self.total_timesteps``.
        """

        rlx_logger.info(f"Training PPO from scratch with a saved reward function")

        @jax.jit
        def get_action_and_value(policy_state: TrainState, critic_state: TrainState, state: np.ndarray, key: jax.random.PRNGKey):
            """Sample a stochastic action and evaluate the critic for a batch of states.

            The action is drawn from a diagonal Gaussian given by the policy's
            mean and log-std. Returns ``(processed_action, action, value,
            log_prob, key)``. ``log_prob`` is summed over action dimensions,
            ``value`` is flattened, and ``key`` is the advanced PRNG key.
            """
            action_mean, action_logstd = self.policy.apply(policy_state.params, state)
            action_std = jnp.exp(action_logstd)
            key, subkey = jax.random.split(key)
            action = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
            # Per-dimension Gaussian log-density of the sampled action.
            log_prob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
            value = self.critic.apply(critic_state.params, state)
            processed_action = self.get_processed_action(action)
            return processed_action, action, value.reshape(-1), log_prob.sum(1), key
        

        @jax.jit
        def calculate_gae_advantages(critic_state: TrainState, next_states: np.ndarray, rewards: np.ndarray, terminations: np.ndarray, values: np.ndarray):
            """Compute GAE advantages and returns for one rollout.

            Arrays are laid out as ``(nr_steps, nr_envs)``; ``next_states``
            additionally carries the observation shape. Terminations cut both
            the value bootstrap and the advantage recursion. Returns
            ``(advantages, returns)`` with ``returns = advantages + values``.
            """
            def compute_advantages(carry, t):
                # Backward recursion: A_t = delta_t + gamma * lambda * (1 - term_t) * A_{t+1}
                prev_advantage = carry[0]
                advantage = delta[t] + self.gamma * self.gae_lambda * (1 - terminations[t]) * prev_advantage
                return (advantage,), advantage

            next_values = self.critic.apply(critic_state.params, next_states).squeeze(-1)
            # One-step TD errors; no bootstrap from the next state on termination.
            delta = rewards + self.gamma * next_values * (1.0 - terminations) - values
            # The last step's advantage is its TD error. Scan the remaining steps
            # from nr_steps - 2 down to 0, then restore chronological order.
            init_advantages = delta[-1]
            _, advantages = jax.lax.scan(compute_advantages, (init_advantages,), jnp.arange(self.nr_steps - 2, -1, -1))
            advantages = jnp.concatenate([advantages[::-1], jnp.array([init_advantages])])
            returns = advantages + values
            return advantages, returns
        
        @jax.jit
        def calculate_gae_advantages_absorbing(critic_state: TrainState, next_states: np.ndarray, rewards: np.ndarray, rewards_next_state: np.ndarray, terminations: np.ndarray, values: np.ndarray):
            """
            Correctly handle absorbing state value and entropy (instead of setting to 0.0)

            Same as ``calculate_gae_advantages`` except for terminated steps.
            There the critic bootstrap is dropped and replaced by
            ``gamma / (1 - gamma) * (rewards_next_state + entropy_coef * H_terminal)``.
            That is the discounted sum sum_{k>=1} gamma^k * c of a constant
            per-step reward c received forever in the absorbing state.
            Returns ``(advantages, returns)``.
            """
            def compute_advantages(carry, t):
                # Backward recursion: A_t = delta_t + gamma * lambda * (1 - term_t) * A_{t+1}
                prev_advantage = carry[0]
                advantage = delta[t] + self.gamma * self.gae_lambda * (1 - terminations[t]) * prev_advantage
                return (advantage,), advantage

            next_values = self.critic.apply(critic_state.params, next_states).squeeze(-1)
            # Value of staying in the absorbing state forever (geometric series).
            terminal_tail = (self.gamma / (1.0 - self.gamma)) * (rewards_next_state + self.entropy_coef * self.H_terminal)
            delta = rewards + self.gamma * next_values * (1.0 - terminations) + (terminations * terminal_tail) - values
            init_advantages = delta[-1]
            _, advantages = jax.lax.scan(compute_advantages, (init_advantages,), jnp.arange(self.nr_steps - 2, -1, -1))
            advantages = jnp.concatenate([advantages[::-1], jnp.array([init_advantages])])
            returns = advantages + values
            return advantages, returns

        @jax.jit
        def get_log_density_ratio(discriminator_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray, absorbing: np.ndarray):
            """Discriminator output for one transition (vmapped over the batch below)."""
            logits = self.discriminator.apply(discriminator_params, state, action, next_state, absorbing)
            return logits
        
        get_log_density_ratio = jax.vmap(get_log_density_ratio, in_axes=(None, 0, 0, 0, 0), out_axes=0)

        # Turns the stored discriminator ensemble into a per-transition reward.
        # NOTE(review): exact semantics (incl. maximum_eta=True) are defined in
        # make_chunked_ensemble_rew_project; confirm there.
        chunked_project = make_chunked_ensemble_rew_project(
            get_log_density_ratio, 
            nr_steps=self.nr_steps, 
            nr_envs=self.nr_envs, 
            epsilon=self.epsilon, 
            beta=self.beta,
            entropy_coef=self.entropy_coef,
            maximum_eta=True
        )

        @jax.jit
        def update(policy_state: TrainState, critic_state: TrainState,
                   states: np.ndarray, actions: np.ndarray, advantages: np.ndarray, returns: np.ndarray, values: np.ndarray, log_probs: np.ndarray,
                   key: jax.random.PRNGKey):
            """Run the PPO clipped-objective update on one rollout.

            Steps:

            - flattens the rollout;
            - builds ``self.nr_epochs`` independent permutations of the
              ``self.batch_size`` sample indices;
            - splits them into ``self.nr_epochs * self.nr_minibatches``
              minibatches of ``self.minibatch_size``;
            - applies one policy and one critic gradient step per minibatch
              inside ``jax.lax.scan``.

            Returns the updated policy and critic states, the metrics averaged
            over all minibatches, and the advanced PRNG key.
            """
            def loss_fn(policy_params, critic_params, state_b, action_b, log_prob_b, return_b, advantage_b):
                """Per-sample PPO loss and metrics (vmapped and averaged below)."""
                # Policy loss
                action_mean, action_logstd = self.policy.apply(policy_params, state_b)
                action_std = jnp.exp(action_logstd)
                new_log_prob = -0.5 * ((action_b - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
                new_log_prob = new_log_prob.sum(1)
                # Per-dimension entropy of a Gaussian: log(std) + 0.5 * log(2 * pi * e)
                entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e)
                
                logratio = new_log_prob - log_prob_b
                ratio = jnp.exp(logratio)
                # KL estimator (r - 1) - log(r); reported as a metric only.
                approx_kl_div = (ratio - 1) - logratio
                clip_fraction = jnp.float32((jnp.abs(ratio - 1) > self.clip_range))

                # Clipped surrogate objective (negated, so it is minimized).
                pg_loss1 = -advantage_b * ratio
                pg_loss2 = -advantage_b * jnp.clip(ratio, 1 - self.clip_range, 1 + self.clip_range)
                pg_loss = jnp.maximum(pg_loss1, pg_loss2)
                
                entropy_loss = entropy.sum(1)
                
                # Critic loss
                new_value = self.critic.apply(critic_params, state_b)
                critic_loss = 0.5 * (new_value - return_b) ** 2

                # Combine losses
                loss = pg_loss - self.entropy_coef * entropy_loss + self.critic_coef * critic_loss

                # Create metrics
                metrics = {
                    "loss/policy_gradient_loss": pg_loss,
                    "loss/critic_loss": critic_loss,
                    "loss/entropy_loss": entropy_loss,
                    "policy_ratio/approx_kl": approx_kl_div,
                    "policy_ratio/clip_fraction": clip_fraction,
                }


                return loss, (metrics)
            

            # Flatten (nr_steps, nr_envs, ...) into one sample axis.
            batch_states = states.reshape((-1,) + self.os_shape)
            batch_actions = actions.reshape((-1,) + self.as_shape)
            batch_advantages = advantages.reshape(-1)
            batch_returns = returns.reshape(-1)
            batch_log_probs = log_probs.reshape(-1)

            # Vectorize the per-sample loss, then average the loss and every metric
            # over the minibatch before differentiating w.r.t. both parameter sets.
            vmap_loss_fn = jax.vmap(loss_fn, in_axes=(None, None, 0, 0, 0, 0, 0), out_axes=0)
            safe_mean = lambda x: jnp.mean(x) if x is not None else x
            mean_vmapped_loss_fn = lambda *a, **k: tree.map_structure(safe_mean, vmap_loss_fn(*a, **k))
            grad_loss_fn = jax.value_and_grad(mean_vmapped_loss_fn, argnums=(0, 1), has_aux=True)

            # One independent shuffle of all sample indices per epoch.
            key, subkey = jax.random.split(key)
            batch_indices = jnp.tile(jnp.arange(self.batch_size), (self.nr_epochs, 1))
            batch_indices = jax.random.permutation(subkey, batch_indices, axis=1, independent=True)
            batch_indices = batch_indices.reshape((self.nr_epochs * self.nr_minibatches, self.minibatch_size))

            def minibatch_update(carry, minibatch_indices):
                """One policy + critic gradient step on a single minibatch."""
                policy_state, critic_state = carry

                # Advantages are normalized per minibatch.
                minibatch_advantages = batch_advantages[minibatch_indices]
                minibatch_advantages = (minibatch_advantages - jnp.mean(minibatch_advantages)) / (jnp.std(minibatch_advantages) + 1e-8)

                (loss, (metrics)), (policy_gradients, critic_gradients) = grad_loss_fn(
                    policy_state.params,
                    critic_state.params,
                    batch_states[minibatch_indices],
                    batch_actions[minibatch_indices],
                    batch_log_probs[minibatch_indices],
                    batch_returns[minibatch_indices],
                    minibatch_advantages
                )

                policy_state = policy_state.apply_gradients(grads=policy_gradients)
                critic_state = critic_state.apply_gradients(grads=critic_gradients)

                metrics["gradients/policy_grad_norm"] = optax.global_norm(policy_gradients)
                metrics["gradients/critic_grad_norm"] = optax.global_norm(critic_gradients)

                carry = (policy_state, critic_state)

                return carry, (metrics)
            
            init_carry = (policy_state, critic_state)
            carry, (metrics) = jax.lax.scan(minibatch_update, init_carry, batch_indices)
            policy_state, critic_state = carry

            # Calculate mean metrics
            mean_metrics = {key: jnp.mean(metrics[key]) for key in metrics}
            # NOTE(review): assumes opt_state[1] holds injected hyperparams
            # (e.g. optax.inject_hyperparams); confirm against the optimizer setup.
            mean_metrics["lr/learning_rate"] = policy_state.opt_state[1].hyperparams["learning_rate"]
            mean_metrics["v_value/explained_variance"] = 1 - jnp.var(returns - values) / (jnp.var(returns) + 1e-8)
            mean_metrics["policy/std_dev"] = jnp.mean(jnp.exp(policy_state.params["params"]["policy_logstd"]))

            return policy_state, critic_state, mean_metrics, key


        @jax.jit
        def get_deterministic_action(policy_state: TrainState, state: np.ndarray):
            """Return the processed mean action (no sampling); used for evaluation."""
            action_mean, action_logstd = self.policy.apply(policy_state.params, state)
            return self.get_processed_action(action_mean)

        @jax.jit
        def get_reward_prediction(reward_fn_params, state: np.ndarray, action: np.ndarray, next_state: np.ndarray):
            """Reward-network prediction for one transition with shaping disabled."""
            logits = self.reward_fn.apply(reward_fn_params, state, action, next_state, shaping=0.0)
            return logits
        
        get_reward_prediction = jax.vmap(get_reward_prediction, in_axes=(None, 0, 0, 0), out_axes=0)
    

        self.set_train_mode()

        # Rollout storage; every field is laid out as (nr_steps, nr_envs[, ...]).
        batch = BatchPPO(
            states=np.zeros((self.nr_steps, self.nr_envs) + self.os_shape),
            next_states=np.zeros((self.nr_steps, self.nr_envs) + self.os_shape),
            actions=np.zeros((self.nr_steps, self.nr_envs) + self.as_shape),
            rewards=np.zeros((self.nr_steps, self.nr_envs)),
            values=np.zeros((self.nr_steps, self.nr_envs)),
            terminations=np.zeros((self.nr_steps, self.nr_envs)),
            log_probs=np.zeros((self.nr_steps, self.nr_envs)),
            advantages=np.zeros((self.nr_steps, self.nr_envs)),
            returns=np.zeros((self.nr_steps, self.nr_envs)),
        )

        # Returns of the most recent finished episodes (at most 100 per env);
        # their mean decides whether a new checkpoint is saved.
        saving_return_buffer = deque(maxlen=100 * self.nr_envs)

        state, _ = self.env.reset()
        global_step = 0
        nr_updates = 0
        nr_episodes = 0
        steps_metrics = {}
        while global_step < self.total_timesteps:
            start_time = time.time()
            time_metrics = {}


            # Acting
            dones_this_rollout = 0
            step_info_collection = {}
            for step in range(self.nr_steps):
                processed_action, action, value, log_prob, self.key = get_action_and_value(self.policy_state, self.critic_state, state, self.key)
                next_state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                done = terminated | truncated
                # For finished envs, store the episode's final observation as the
                # transition's next state instead of next_state[i] (presumably the
                # post-reset observation of an auto-resetting env; confirm).
                actual_next_state = next_state.copy()
                for i, single_done in enumerate(done):
                    if single_done:
                        if self.handle_absorbing_states:
                            actual_next_state[i] = np.array(self.env.get_final_observation_at_index(info, i)) * 0.0 # setting one absorbing state (zero vector)
                        else:
                            actual_next_state[i] = np.array(self.env.get_final_observation_at_index(info, i))
                        saving_return_buffer.append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                        dones_this_rollout += 1
                for key, info_value in self.env.get_logging_info_dict(info).items():
                    step_info_collection.setdefault(key, []).extend(info_value)

                batch.states[step] = state
                batch.next_states[step] = actual_next_state
                batch.actions[step] = action
                batch.rewards[step] = reward
                batch.values[step] = value
                # Only true terminations (not truncations) stop the bootstrap in GAE.
                batch.terminations[step] = terminated
                batch.log_probs[step] = log_prob
                state = next_state
                global_step += self.nr_envs
            nr_episodes += dones_this_rollout
            
            acting_end_time = time.time()
            time_metrics["time/acting_time"] = acting_end_time - start_time

            # Rewards for the advantages come from the learnt reward, not the env.
            if self.reward_fn_approximator:
                proj_rewards = get_reward_prediction(self.reward_fn_state.params, batch.states.reshape((-1,) + self.os_shape), batch.actions.reshape((-1,) + self.as_shape), batch.next_states.reshape((-1,) + self.os_shape))                    
                proj_rewards = proj_rewards.reshape(batch.rewards.shape)
            else:
                # Compute projected reward using the saved discriminators
                proj_rewards = chunked_project(
                    self.disc_buffer.buffer,
                    (batch.states.reshape((-1,) + self.os_shape),
                    batch.actions.reshape((-1,) + self.as_shape),
                    batch.next_states.reshape((-1,) + self.os_shape),
                    batch.terminations.reshape(-1,)),
                    self.disc_buffer.eta_buffer,
                    chunk_size=self.chunk_size_dict[self.nr_steps],
                )

            if self.handle_absorbing_states:
                # Reward received while staying in the absorbing state: evaluated at
                # (next_state, zero action, next_state) for every transition. Only the
                # entries at terminated steps are used by the absorbing-GAE helper.
                if self.reward_fn_approximator:
                    proj_rewards_next_state = get_reward_prediction(self.reward_fn_state.params, batch.next_states.reshape((-1,) + self.os_shape), 0.0*batch.actions.reshape((-1,) + self.as_shape), batch.next_states.reshape((-1,) + self.os_shape)) 
                    proj_rewards_next_state = proj_rewards_next_state.reshape(batch.rewards.shape)
                else:
                    # Compute projected reward using the saved discriminators
                    proj_rewards_next_state = chunked_project(
                        self.disc_buffer.buffer,
                        (batch.next_states.reshape((-1,) + self.os_shape), # next state
                        0.0*batch.actions.reshape((-1,) + self.as_shape), # next actions don't matter in absorbing states
                        batch.next_states.reshape((-1,) + self.os_shape), # next next states are the same as next states if in absorbing state
                        np.ones_like(batch.terminations.reshape(-1,))), # absorbing is true
                        self.disc_buffer.eta_buffer, # should not make a difference since we use max over the whole batch anyway
                        chunk_size=self.chunk_size_dict[self.nr_steps],
                    )

                # Calculating advantages and returns
                batch.advantages, batch.returns = calculate_gae_advantages_absorbing(self.critic_state, batch.next_states, proj_rewards, proj_rewards_next_state, batch.terminations, batch.values)        
                calc_adv_return_end_time = time.time()
                # Measured from acting_end_time, so this also includes reward computation.
                time_metrics["time/calc_adv_and_return_time"] = calc_adv_return_end_time - acting_end_time
            else:
                # Calculating advantages and returns
                batch.advantages, batch.returns = calculate_gae_advantages(self.critic_state, batch.next_states, proj_rewards, batch.terminations, batch.values)
                calc_adv_return_end_time = time.time()
                # Measured from acting_end_time, so this also includes reward computation.
                time_metrics["time/calc_adv_and_return_time"] = calc_adv_return_end_time - acting_end_time


            # Optimizing
            self.policy_state, self.critic_state, optimization_metrics, self.key = update(
                self.policy_state, self.critic_state,
                batch.states, batch.actions, batch.advantages, batch.returns, batch.values, batch.log_probs,
                self.key
            )
            optimization_metrics = {key: value.item() for key, value in optimization_metrics.items()}
            nr_updates += self.nr_epochs * self.nr_minibatches

            optimizing_end_time = time.time()
            time_metrics["time/optimizing_time"] = optimizing_end_time - calc_adv_return_end_time


            # Evaluating
            # evaluation_frequency == -1 disables evaluation. x % -1 is always 0,
            # so the second condition is what actually turns it off.
            # NOTE(review): evaluation runs on self.env (the training env) and
            # resets it before and after. In-progress training episodes are
            # therefore discarded at every evaluation.
            evaluation_metrics = {}
            if global_step % self.evaluation_frequency == 0 and self.evaluation_frequency != -1:
                self.set_eval_mode()
                state, _ = self.env.reset()
                eval_nr_episodes = 0
                evaluation_metrics = {"eval/episode_return": [], "eval/episode_length": []}
                while True:
                    processed_action = get_deterministic_action(self.policy_state, state)
                    state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                    done = terminated | truncated
                    for i, single_done in enumerate(done):
                        if single_done:
                            eval_nr_episodes += 1
                            evaluation_metrics["eval/episode_return"].append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                            evaluation_metrics["eval/episode_length"].append(self.env.get_final_info_value_at_index(info, "episode_length", i))
                            if eval_nr_episodes == self.evaluation_episodes:
                                break
                    # The inner break only leaves the per-env loop; leave the while too.
                    if eval_nr_episodes == self.evaluation_episodes:
                        break
                evaluation_metrics = {key: np.mean(value) for key, value in evaluation_metrics.items()}
                state, _ = self.env.reset()
                self.set_train_mode()
            
            evaluating_end_time = time.time()
            time_metrics["time/evaluating_time"] = evaluating_end_time - optimizing_end_time
            

            # Saving
            # Also only save when there were finished episodes this update
            if self.save_model and dones_this_rollout > 0:
                mean_return = np.mean(saving_return_buffer)
                if mean_return > self.best_mean_return:
                    self.best_mean_return = mean_return
                    self.save(global_step)
            
            saving_end_time = time.time()
            time_metrics["time/saving_time"] = saving_end_time - evaluating_end_time

            time_metrics["time/sps"] = int((self.nr_steps * self.nr_envs) / (saving_end_time - start_time))


            # Logging
            self.start_logging(global_step)

            steps_metrics["steps/nr_env_steps"] = global_step
            steps_metrics["steps/nr_updates"] = nr_updates
            steps_metrics["steps/nr_episodes"] = nr_episodes

            rollout_info_metrics = {}
            env_info_metrics = {}
            if step_info_collection:
                info_names = list(step_info_collection.keys())
                for info_name in info_names:
                    metric_group = "rollout" if info_name in ["episode_return", "episode_length"] else "env_info"
                    metric_dict = rollout_info_metrics if metric_group == "rollout" else env_info_metrics
                    mean_value = np.mean(step_info_collection[info_name])
                    if mean_value == mean_value:  # Skip NaN means (NaN != NaN)
                        metric_dict[f"{metric_group}/{info_name}"] = mean_value
            
            combined_metrics = {**rollout_info_metrics, **evaluation_metrics, **env_info_metrics, **steps_metrics, **time_metrics, **optimization_metrics}
            for key, value in combined_metrics.items():
                self.log(f"{key}", value, global_step)

            self.end_logging()


    def log(self, name, value, step):
        """Route one scalar metric to every enabled output.

        TensorBoard receives it first (via ``self.writer.add_scalar``) when
        ``self.track_tb`` is set; the console table receives it afterwards (via
        ``self.log_console``) when ``self.track_console`` is set.

        Args:
            name: Metric name, e.g. ``"rollout/episode_return"``.
            value: Scalar value to record.
            step: Step index attached to the TensorBoard entry.
        """
        if self.track_tb:
            self.writer.add_scalar(name, value, step)
        if not self.track_console:
            return
        self.log_console(name, value)
    

    def log_console(self, name, value):
        """Emit one row of the console metrics table.

        The name column is left-justified to 30 characters; the value is
        rendered in positional (non-scientific) notation with trailing zeros
        and a trailing decimal point trimmed, then padded/truncated to exactly
        14 characters so the table borders line up.
        """
        rendered = str(np.format_float_positional(value, trim="-"))
        value_cell = rendered.ljust(14)[:14]
        # NOTE(review): `flush` is not a stdlib Logger.info keyword; presumably the
        # rl_x logger is a custom class that accepts it — confirm.
        rlx_logger.info(f"│ {name.ljust(30)}│ {value_cell} │", flush=False)


    def start_logging(self, step):
        """Open a logging section for one update.

        With console tracking enabled this prints the top border of the metrics
        table; otherwise it prints a single ``Step: <step>`` line.
        """
        if not self.track_console:
            rlx_logger.info(f"Step: {step}")
            return
        top_border = "┌" + "─" * 31 + "┬" + "─" * 16 + "┐"
        rlx_logger.info(top_border, flush=False)


    def end_logging(self):
        """Close the console metrics table by printing its bottom border.

        Does nothing when console tracking is disabled.
        """
        if not self.track_console:
            return
        bottom_border = "└" + "─" * 31 + "┴" + "─" * 16 + "┘"
        rlx_logger.info(bottom_border)

    
    def save(self, global_step):
        """Write the current model and discriminator buffers to a zip archive.

        Calls are throttled: an archive is written only when at least
        ``self.save_delay * 4096`` steps (in ``global_step`` units) have elapsed
        since the last written archive; otherwise the call is a no-op.

        The archive ``{save_path}/{best_model_file_name}.zip`` contains:
          * the orbax checkpoint of the policy, critic and discriminator train
            states (plus ``reward_fn`` when ``self.reward_fn_approximator`` is set),
          * ``config_algorithm.json`` with the algorithm config,
          * ``disc_buffer.pkl`` with the pickled state dicts of ``disc_buffer.buffer``,
          * ``eta_buffer.npy`` with ``disc_buffer.eta_buffer``.

        Args:
            global_step: Current environment-step counter, used only for throttling.
        """
        last_saved_step = getattr(self, "_last_buffer_save_step", -1)
        # NOTE(review): 4096 is a hard-coded multiplier on save_delay — confirm intended units.
        save_interval = self.save_delay * 4096
        if global_step - last_saved_step < save_interval:
            return
        self._last_buffer_save_step = global_step

        tmp_dir = os.path.join(self.save_path, "tmp")
        # Clear leftovers of an interrupted save: orbax does not overwrite an
        # existing destination directory by default.
        shutil.rmtree(tmp_dir, ignore_errors=True)

        checkpoint = {
            "policy": self.policy_state,
            "critic": self.critic_state,
            "discriminator": self.discriminator_state,
        }
        if self.reward_fn_approximator:
            checkpoint["reward_fn"] = self.reward_fn_state

        try:
            save_args = orbax_utils.save_args_from_target(checkpoint)
            self.best_model_checkpointer.save(tmp_dir, checkpoint, save_args=save_args)
            with open(os.path.join(tmp_dir, "config_algorithm.json"), "w") as f:
                json.dump(self.config.algorithm.to_dict(), f)

            # Discriminator parameter snapshots are stored as plain state dicts so
            # they can be rebuilt against a params template on load.
            disc_state_list = [flax_serialization.to_state_dict(p) for p in self.disc_buffer.buffer]
            with open(os.path.join(tmp_dir, "disc_buffer.pkl"), "wb") as f:
                pickle.dump(disc_state_list, f)
            np.save(os.path.join(tmp_dir, "eta_buffer.npy"), self.disc_buffer.eta_buffer)

            # make_archive appends ".zip" to the base name.
            shutil.make_archive(os.path.join(self.save_path, self.best_model_file_name), "zip", tmp_dir)
        finally:
            # Always drop the staging directory, even if writing failed part-way.
            shutil.rmtree(tmp_dir, ignore_errors=True)

        if self.track_wandb:
            # Upload the file that actually exists on disk (with the ".zip" suffix).
            wandb.save(f"{self.save_path}/{self.best_model_file_name}.zip", base_path=self.save_path)


    @staticmethod
    def load(config, env, run_path, writer, explicitly_set_algorithm_params, eval_env=None):
        """Rebuild a ``TRIRL_PPO`` instance from an archive written by ``save``.

        The zip archive at ``config.runner.load_model`` is unpacked into a
        ``tmp`` directory next to it; that directory is removed before
        returning, also when restoring fails.

        Algorithm config values stored in the archive override
        ``config.algorithm`` (in place), except for keys listed as
        ``"algorithm.<key>"`` in ``explicitly_set_algorithm_params``.

        Args:
            config: Run config; ``config.algorithm`` is updated in place.
            env: Training environment forwarded to the constructor.
            run_path: Run directory forwarded to the constructor.
            writer: Metrics writer forwarded to the constructor.
            explicitly_set_algorithm_params: Collection of ``"algorithm.<key>"``
                names whose current values must not be overridden.
            eval_env: Evaluation environment forwarded to the constructor.
                NOTE(review): defaults to ``None`` — confirm ``__init__``
                tolerates that.

        Returns:
            The constructed model with restored train states and discriminator
            buffers. When ``model.global_rew_experiment`` is truthy, the freshly
            initialised policy and critic are kept and only the discriminator
            (and reward-function) states are restored.
        """
        archive_path = os.path.abspath(config.runner.load_model)
        tmp_dir = os.path.join(os.path.dirname(archive_path), "tmp")
        shutil.unpack_archive(archive_path, tmp_dir, "zip")

        try:
            with open(os.path.join(tmp_dir, "config_algorithm.json"), "r") as f:
                loaded_algorithm_config = json.load(f)
            for key, value in loaded_algorithm_config.items():
                if f"algorithm.{key}" not in explicitly_set_algorithm_params:
                    config.algorithm[key] = value

            # Constructor signature is (config, env, eval_env, run_path, writer).
            model = TRIRL_PPO(config, env, eval_env, run_path, writer)

            # The restore target must mirror the structure written by `save`.
            target = {
                "policy": model.policy_state,
                "critic": model.critic_state,
                "discriminator": model.discriminator_state,
            }
            if model.reward_fn_approximator:
                target["reward_fn"] = model.reward_fn_state

            restore_args = orbax_utils.restore_args_from_target(target)
            checkpointer = orbax.checkpoint.PyTreeCheckpointer()
            checkpoint = checkpointer.restore(tmp_dir, item=target, restore_args=restore_args)

            if not model.global_rew_experiment:
                model.policy_state = checkpoint["policy"]
                model.critic_state = checkpoint["critic"]
            model.discriminator_state = checkpoint["discriminator"]
            if model.reward_fn_approximator:
                model.reward_fn_state = checkpoint["reward_fn"]

            capacity = config.algorithm.disc_buffer_capacity

            restored_disc = []
            disc_pkl_path = os.path.join(tmp_dir, "disc_buffer.pkl")
            if os.path.exists(disc_pkl_path):
                # SECURITY: pickle.load can execute arbitrary code — only load
                # checkpoints from trusted sources.
                with open(disc_pkl_path, "rb") as f:
                    disc_state_list = pickle.load(f)
                template = model.discriminator_state.params
                restored_disc = [
                    jax.tree.map(jnp.array, flax_serialization.from_state_dict(template, sd))
                    for sd in disc_state_list
                ]
            model.disc_buffer._buffer = deque(restored_disc, maxlen=capacity)

            eta_entries = []
            eta_npy_path = os.path.join(tmp_dir, "eta_buffer.npy")
            if os.path.exists(eta_npy_path):
                eta_array = np.load(eta_npy_path, allow_pickle=False)
                if eta_array.size > 0:
                    # One buffer entry per row along the first axis.
                    eta_entries = list(eta_array)
            model.disc_buffer._eta_buffer = deque(eta_entries, maxlen=capacity)
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)

        return model
    

    def test(self, episodes, nr_steps=100):
        """Roll out the current policy with its mean action and log returns.

        Each of the ``episodes`` iterations resets the environment and runs
        exactly ``nr_steps`` steps. Termination/truncation flags are not used to
        stop early; presumably the environment auto-resets, so one iteration's
        return may span several underlying episodes — confirm.

        Args:
            episodes: Number of rollouts to run.
            nr_steps: Fixed rollout horizon (defaults to the previous hard-coded 100).

        Returns:
            The mean over rollouts of each rollout's mean return.
        """
        @jax.jit
        def get_action(policy_state: TrainState, state: np.ndarray):
            # Deterministic evaluation: the log-std output is ignored.
            action_mean, _ = self.policy.apply(policy_state.params, state)
            return self.get_processed_action(action_mean)

        self.set_eval_mode()

        rollout_returns = []
        for episode_index in range(episodes):
            episode_return = 0
            state, _ = self.env.reset()
            for _ in range(nr_steps):
                processed_action = get_action(self.policy_state, state)
                state, reward, _, _, _ = self.env.step(jax.device_get(processed_action))
                episode_return += reward

            # np.mean also copes with scalar (non-vectorized) rewards.
            rollout_mean = np.mean(episode_return)
            rollout_returns.append(rollout_mean)
            rlx_logger.info(f"Episode {episode_index + 1} - Return: {rollout_mean}")

        mean_episode_return = np.array(rollout_returns).mean()
        rlx_logger.info(f"Mean Episode Return: {mean_episode_return}")
        return mean_episode_return

            
    def set_train_mode(self):
        """Hook for switching back to training mode after evaluation.

        Currently a no-op; kept so callers can toggle modes uniformly.
        """
        pass


    def set_eval_mode(self):
        """Hook for switching into evaluation mode before rollouts.

        Currently a no-op; kept so callers can toggle modes uniformly.
        """
        pass


    @staticmethod
    def general_properties():
        """Return the ``GeneralProperties`` class describing this algorithm.

        Declared static because it takes no ``self``; without the decorator,
        calling it on an instance would raise ``TypeError``.
        """
        return GeneralProperties
