import os
import shutil
import json
import logging
import time
from collections import deque
import tree
import numpy as np
import jax
import jax.numpy as jnp
from jax.lax import stop_gradient
from flax.training.train_state import TrainState
from flax.training import orbax_utils
import orbax.checkpoint
import optax
import wandb
from functools import partial

from irl_baselines.algorithms.near_ppo.flax.general_properties import GeneralProperties
from irl_baselines.algorithms.near_ppo.flax.policy import get_policy
from irl_baselines.algorithms.near_ppo.flax.critic import get_critic
from irl_baselines.algorithms.near_ppo.flax.energy_function import get_energyfn
from irl_baselines.algorithms.near_ppo.flax.batch import Batch
from irl_baselines.algorithms.data_utils import prepare_expert_data

# Module-level handle on the shared "rl_x" logger; every console/progress message
# emitted by this file goes through it.
rlx_logger = logging.getLogger(name="rl_x")

class EnergyBuffer:
    """Fixed-size FIFO holding the most recently appended energy values/arrays.

    Once ``capacity`` items are stored, each new ``append`` silently evicts the
    oldest item (``deque`` with ``maxlen``).
    """

    def __init__(self, capacity):
        """Create an empty buffer that retains at most ``capacity`` items."""
        self.capacity = capacity
        # deque(iterable, maxlen): start empty, cap the length at `capacity`.
        self._buffer = deque((), capacity)

    def append(self, item):
        """Store ``item``; drops the oldest entry when the buffer is already full."""
        self._buffer.append(item)

    @property
    def buffer(self):
        """A fresh list of the stored items, oldest first (mutating it does not affect the buffer)."""
        return [*self._buffer]

    def __len__(self):
        """Number of items currently stored (never more than ``capacity``)."""
        return len(self._buffer)

    def mean(self):
        """Scalar mean over every element of every stored item.

        Items are stacked first, so they must all share one shape. An empty
        buffer makes ``jnp.stack`` raise.
        """
        stacked = jnp.stack(self.buffer)
        return stacked.mean()


class NEAR_PPO:
    def __init__(self, config, env, eval_env, run_path, writer) -> None:
        """Read NEAR-PPO hyperparameters and build the policy, critic and energy-function train states.

        Args:
            config: Experiment config; reads fields from ``config.runner``,
                ``config.environment`` and ``config.algorithm`` (the latter must
                also support ``.get(key, default)``).
            env: Vectorized training environment; its ``single_observation_space``
                and ``single_action_space`` are sampled to initialise the networks.
            eval_env: Evaluation environment (not used by this constructor).
            run_path: Run directory; models are saved under ``<run_path>/models``.
            writer: Scalar writer kept for logging during training.

        Raises:
            ValueError: If ``evaluation_frequency`` is neither -1 nor a multiple of
                ``nr_steps * nr_envs``, or if the PPO batch (``nr_envs * nr_steps``)
                or the NCSN batch (``batch_size_ncsn``) is not an exact multiple of
                its minibatch size.
        """
        self.config = config
        self.env = env
        self.writer = writer

        self.save_model = config.runner.save_model
        self.save_path = os.path.join(run_path, "models")
        self.track_console = config.runner.track_console
        self.track_tb = config.runner.track_tb
        self.track_wandb = config.runner.track_wandb
        self.seed = config.environment.seed
        self.total_timesteps = config.algorithm.total_timesteps
        self.nr_envs = config.environment.nr_envs
        self.learning_rate = config.algorithm.learning_rate
        self.anneal_learning_rate = config.algorithm.anneal_learning_rate
        self.nr_steps = config.algorithm.nr_steps
        self.nr_epochs = config.algorithm.nr_epochs
        self.minibatch_size = config.algorithm.minibatch_size
        self.gamma = config.algorithm.gamma
        self.gae_lambda = config.algorithm.gae_lambda
        self.clip_range = config.algorithm.clip_range
        self.entropy_coef = config.algorithm.entropy_coef
        self.critic_coef = config.algorithm.critic_coef
        self.max_grad_norm = config.algorithm.max_grad_norm
        self.std_dev = config.algorithm.std_dev
        self.nr_hidden_units = config.algorithm.nr_hidden_units
        self.evaluation_frequency = config.algorithm.evaluation_frequency
        self.evaluation_episodes = config.algorithm.evaluation_episodes
        self.batch_size = config.environment.nr_envs * config.algorithm.nr_steps
        self.nr_updates = config.algorithm.total_timesteps // self.batch_size
        self.nr_minibatches = self.batch_size // self.minibatch_size
        self.subsampling_cutoff = config.algorithm.get("subsampling_cutoff", 1)

        # NEAR Specific
        self.batch_size_ncsn = config.algorithm.batch_size_ncsn
        self.minibatch_size_ncsn = config.algorithm.minibatch_size_ncsn
        self.total_samples_ncsn = config.algorithm.total_samples_ncsn
        self.nr_epochs_ncsn = config.algorithm.nr_epochs_ncsn  # Number of ncsn epochs
        self.anneal_power_ncsn = config.algorithm.anneal_power_ncsn
        self.sigma_begin_ncsn = config.algorithm.sigma_begin_ncsn
        self.sigma_end_ncsn = config.algorithm.sigma_end_ncsn
        self.L_ncsn = config.algorithm.L_ncsn
        self.nr_hidden_units_encoder_ncsn = config.algorithm.nr_hidden_units_encoder_ncsn
        self.nr_hidden_units_decoder_ncsn = config.algorithm.nr_hidden_units_decoder_ncsn
        self.learning_rate_ncsn = config.algorithm.learning_rate_ncsn
        self.sigma_inference_ncsn = config.algorithm.sigma_inference_ncsn
        self.annealing = False
        self.ncsnv1 = config.algorithm.ncsnv1
        # Annealing: a sigma_inference_ncsn of -1 is a sentinel meaning "start at
        # noise level 0 and anneal upwards during training".
        if self.sigma_inference_ncsn == -1:
            self.sigma_inference_ncsn = 0
            self.annealing = True
        self.anneal_threshold = config.algorithm.anneal_threshold
        self.env_reward_frac = config.algorithm.env_reward_frac
        self.data_path = config.algorithm.data_path
        self.handle_absorbing_states = config.algorithm.handle_absorbing_states
        self.nr_minibatches_ncsn = self.batch_size_ncsn // self.minibatch_size_ncsn
        self.state_based = config.algorithm.state_based

        # Global Reward Experiment Flag
        self.run_diagnostics = config.algorithm.run_diagnostics

        if self.evaluation_frequency % (self.nr_steps * self.nr_envs) != 0 and self.evaluation_frequency != -1:
            raise ValueError("Evaluation frequency must be a multiple of the number of steps and environments.")
        # The training loops reshape shuffled index arrays into
        # (epochs * nr_minibatches, minibatch_size); that only works when the batch
        # splits into whole minibatches. Fail here with a clear message instead of a
        # shape error raised from inside a jitted function much later.
        if self.batch_size % self.minibatch_size != 0:
            raise ValueError(
                f"PPO batch size (nr_envs * nr_steps = {self.batch_size}) must be a multiple of "
                f"minibatch_size ({self.minibatch_size})."
            )
        if self.batch_size_ncsn % self.minibatch_size_ncsn != 0:
            raise ValueError(
                f"batch_size_ncsn ({self.batch_size_ncsn}) must be a multiple of "
                f"minibatch_size_ncsn ({self.minibatch_size_ncsn})."
            )

        # Window of the most recent rollouts' energy arrays; its mean is used as a
        # baseline for the energy reward during annealing. Capacity defaults to 3.
        self.energy_buffer = EnergyBuffer(capacity=config.algorithm.get("energy_buffer_capacity", 3))

        rlx_logger.info(f"Using device: {jax.default_backend()}")

        self.key = jax.random.PRNGKey(self.seed)
        self.key, policy_key, critic_key, energyfn_key = jax.random.split(self.key, 4)

        self.os_shape = env.single_observation_space.shape
        self.as_shape = env.single_action_space.shape

        self.policy, self.get_processed_action = get_policy(config, env)
        self.critic = get_critic(config, env)
        self.energyfn = get_energyfn(config, env, ncsnv1=self.ncsnv1)

        self.policy.apply = jax.jit(self.policy.apply)
        self.critic.apply = jax.jit(self.critic.apply)
        self.energyfn.apply = jax.jit(self.energyfn.apply)

        def linear_schedule(count):
            # `count` is the optimizer step counter; one PPO update performs
            # nr_minibatches * nr_epochs optimizer steps, so the rate decays
            # linearly per update from learning_rate towards 0 over nr_updates.
            fraction = 1.0 - (count // (self.nr_minibatches * self.nr_epochs)) / self.nr_updates
            return self.learning_rate * fraction

        learning_rate = linear_schedule if self.anneal_learning_rate else self.learning_rate
        # The energy-function learning rate is constant (no annealing schedule).
        learning_rate_ncsn = self.learning_rate_ncsn

        # Dummy single-sample inputs, used only to initialise network parameters.
        state = jnp.array([env.single_observation_space.sample()])
        next_state = jnp.array([env.single_observation_space.sample()])
        action = jnp.array([env.single_action_space.sample()])
        cond = 1.0  # placeholder conditioning value for energyfn initialisation
        # The energy function scores (state, next_state) when state_based, otherwise (state, action).
        if self.state_based:
            init_input = jnp.concatenate([state.flatten(), next_state.flatten()])
        else:
            init_input = jnp.concatenate([state.flatten(), action.flatten()])

        # NOTE: train_irl reads opt_state[1].hyperparams["learning_rate"] from the
        # policy state, so inject_hyperparams must stay second in this chain.
        self.policy_state = TrainState.create(
            apply_fn=self.policy.apply,
            params=self.policy.init(policy_key, state),
            tx=optax.chain(
                optax.clip_by_global_norm(self.max_grad_norm),
                optax.inject_hyperparams(optax.adam)(learning_rate=learning_rate),
            )
        )

        self.critic_state = TrainState.create(
            apply_fn=self.critic.apply,
            params=self.critic.init(critic_key, state),
            tx=optax.chain(
                optax.clip_by_global_norm(self.max_grad_norm),
                optax.inject_hyperparams(optax.adam)(learning_rate=learning_rate),
            )
        )

        # Unlike policy/critic, the energy function has no gradient clipping stage.
        # train_irl reads opt_state[0].hyperparams["learning_rate"] from this state,
        # so the adamw stage must remain the first element of the chain.
        self.energyfn_state = TrainState.create(
            apply_fn=self.energyfn.apply,
            params=self.energyfn.init(energyfn_key, init_input, cond),
            tx=optax.chain(
                optax.inject_hyperparams(optax.adamw)(learning_rate=learning_rate_ncsn),
            )
        )

        if self.save_model:
            os.makedirs(self.save_path)
            self.best_mean_return = -np.inf
            self.best_model_file_name = "best.model"
            self.best_model_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

    
    def train_irl(self):
        @jax.jit
        def get_action_and_value(policy_state: TrainState, critic_state: TrainState, state: np.ndarray, key: jax.random.PRNGKey):
            action_mean, action_logstd = self.policy.apply(policy_state.params, state)
            action_std = jnp.exp(action_logstd)
            key, subkey = jax.random.split(key)
            action = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
            log_prob = -0.5 * ((action - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
            value = self.critic.apply(critic_state.params, state)
            processed_action = self.get_processed_action(action)
            return processed_action, action, value.reshape(-1), log_prob.sum(1), key
        

        # @jax.jit
        def get_energy_reward(energyfn_state: TrainState, sample: np.ndarray, cond: float, last_update_mean_energy: float):
            energy = stop_gradient(self.energyfn.apply(energyfn_state.params, sample, cond))
            energy_reward = 10 * jnp.tanh((energy - last_update_mean_energy)/10)

            return energy_reward, energy

        get_energy_reward = jax.vmap(get_energy_reward, in_axes=(None, 0, None, None), out_axes=0)

        @jax.jit
        def calculate_gae_advantages(critic_state: TrainState, next_states: np.ndarray, rewards: np.ndarray, terminations: np.ndarray, values: np.ndarray):
            def compute_advantages(carry, t):
                prev_advantage = carry[0]
                advantage = delta[t] + self.gamma * self.gae_lambda * (1 - terminations[t]) * prev_advantage
                return (advantage,), advantage

            next_values = self.critic.apply(critic_state.params, next_states).squeeze(-1)
            delta = rewards + self.gamma * next_values * (1.0 - terminations) - values
            init_advantages = delta[-1]
            _, advantages = jax.lax.scan(compute_advantages, (init_advantages,), jnp.arange(self.nr_steps - 2, -1, -1))
            advantages = jnp.concatenate([advantages[::-1], jnp.array([init_advantages])])
            returns = advantages + values
            return advantages, returns


        """
        Noise Conditioned Score Networks Update
        """

        # @jax.jit
        @partial(jax.jit, static_argnames=('state_based', 'ncsnv1'))
        def ncsn_update(energyfn_state: TrainState,
                   expert_states: np.ndarray, expert_actions: np.ndarray, expert_next_states: np.ndarray,
                   key: jax.random.PRNGKey, state_based=False, ncsnv1=False):

            def ncsn_loss_fn(energyfn_params, expert_state, expert_action, expert_next_state, key):
                """
                Denoising Score Matching
                """
                # geometric schedule sigmas
                key, label_key = jax.random.split(key)
                sigmas = jnp.exp(jnp.linspace(jnp.log(self.sigma_begin_ncsn), jnp.log(self.sigma_end_ncsn), self.L_ncsn))
                conds = jnp.arange(self.L_ncsn)
                used_cond = jax.random.choice(label_key, conds)
                # used_sigma = jax.random.choice(label_key, sigmas)
                used_sigma = sigmas[used_cond]
        
                # perturbing expert sample
                if state_based:
                    sample = jnp.concatenate([expert_state.flatten(), expert_next_state.flatten()])
                else:
                    sample = jnp.concatenate([expert_state.flatten(), expert_action.flatten()])
                perturbed_sample = sample + jax.random.normal(key, shape=sample.shape) * used_sigma
                target = - 1 / (used_sigma ** 2) * (perturbed_sample - sample)

                if ncsnv1:
                    pred_score = jax.grad(lambda x, cond: jnp.sum(self.energyfn.apply(energyfn_params, x, cond)), argnums=(0))(perturbed_sample, used_cond)
                else:
                    pred_score = jax.grad(lambda x, cond: jnp.sum(self.energyfn.apply(energyfn_params, x, cond)), argnums=(0))(perturbed_sample, used_sigma)
                dsm_loss = jnp.mean((1/2.) * ((pred_score - target) ** 2).sum() * (used_sigma ** self.anneal_power_ncsn))

                metrics = {
                    "loss/energyfn_loss": dsm_loss,
                }
                return dsm_loss, (metrics)


            # Expert batch
            key, shuffle_key = jax.random.split(key)
            perm = jax.random.permutation(shuffle_key, expert_states.shape[0])
            expert_states = expert_states[perm]
            expert_actions = expert_actions[perm]
            expert_next_states = expert_next_states[perm]
            batch_expert_states = expert_states[:self.batch_size_ncsn]
            batch_expert_actions = expert_actions[:self.batch_size_ncsn]
            batch_expert_next_states = expert_next_states[:self.batch_size_ncsn]

            vmap_ncsn_loss_fn = jax.vmap(ncsn_loss_fn, in_axes=(None, 0, 0, 0, 0), out_axes=0)
            safe_mean = lambda x: jnp.mean(x) if x is not None else x
            mean_vmapped_ncsn_loss_fn = lambda *a, **k: tree.map_structure(safe_mean, vmap_ncsn_loss_fn(*a, **k))
            grad_ncsn_loss_fn = jax.value_and_grad(mean_vmapped_ncsn_loss_fn, argnums=(0), has_aux=True)

            key, subkey = jax.random.split(key)
            batch_indices_ncsn = jnp.tile(jnp.arange(self.batch_size_ncsn), (self.nr_epochs_ncsn, 1))
            batch_indices_ncsn = jax.random.permutation(subkey, batch_indices_ncsn, axis=1, independent=True)
            batch_indices_ncsn = batch_indices_ncsn.reshape((self.nr_epochs_ncsn * self.nr_minibatches_ncsn, self.minibatch_size_ncsn))

            def minibatch_update(carry, minibatch_indices_ncsn):
                energyfn_state, key = carry

                key, label_key = jax.random.split(key)
                mb_keys = jax.random.split(label_key, self.minibatch_size_ncsn)

                # NEAR UPDATE
                (near_loss, (metrics)), (energyfn_gradients) = grad_ncsn_loss_fn(
                    energyfn_state.params,
                    batch_expert_states[minibatch_indices_ncsn],
                    batch_expert_actions[minibatch_indices_ncsn],
                    batch_expert_next_states[minibatch_indices_ncsn],
                    mb_keys,
                )

                energyfn_state = energyfn_state.apply_gradients(grads=energyfn_gradients)
                metrics["gradients/energyfn_grad_norm"] = optax.global_norm(energyfn_gradients)

                carry = (energyfn_state, key)

                return carry, (metrics)
            
            init_carry = (energyfn_state, key)
            carry, (metrics) = jax.lax.scan(minibatch_update, init_carry, batch_indices_ncsn)
            energyfn_state, key = carry

            # Calculate mean metrics
            mean_metrics = {key: jnp.mean(metrics[key]) for key in metrics}
            mean_metrics["lr/learning_rate_ncsn"] = energyfn_state.opt_state[0].hyperparams["learning_rate"]

            return energyfn_state, mean_metrics, key

        """
        PPO Update
        """

        @jax.jit
        def update(policy_state: TrainState, critic_state: TrainState,
                   states: np.ndarray, actions: np.ndarray, advantages: np.ndarray, returns: np.ndarray, values: np.ndarray, log_probs: np.ndarray,
                   key: jax.random.PRNGKey):
            def loss_fn(policy_params, critic_params, state_b, action_b, log_prob_b, return_b, advantage_b):
                # Policy loss
                action_mean, action_logstd = self.policy.apply(policy_params, state_b)
                action_std = jnp.exp(action_logstd)
                new_log_prob = -0.5 * ((action_b - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
                new_log_prob = new_log_prob.sum(1)
                entropy = action_logstd + 0.5 * jnp.log(2.0 * jnp.pi * jnp.e)
                
                logratio = new_log_prob - log_prob_b
                ratio = jnp.exp(logratio)
                approx_kl_div = (ratio - 1) - logratio
                clip_fraction = jnp.float32((jnp.abs(ratio - 1) > self.clip_range))

                pg_loss1 = -advantage_b * ratio
                pg_loss2 = -advantage_b * jnp.clip(ratio, 1 - self.clip_range, 1 + self.clip_range)
                pg_loss = jnp.maximum(pg_loss1, pg_loss2)
                
                entropy_loss = entropy.sum(1)
                
                # Critic loss
                new_value = self.critic.apply(critic_params, state_b)
                critic_loss = 0.5 * (new_value - return_b) ** 2

                # Combine losses
                loss = pg_loss - self.entropy_coef * entropy_loss + self.critic_coef * critic_loss

                # Create metrics
                metrics = {
                    "loss/policy_gradient_loss": pg_loss,
                    "loss/critic_loss": critic_loss,
                    "loss/entropy_loss": entropy_loss,
                    "policy_ratio/approx_kl": approx_kl_div,
                    "policy_ratio/clip_fraction": clip_fraction,
                }

                return loss, (metrics)
            

            batch_states = states.reshape((-1,) + self.os_shape)
            batch_actions = actions.reshape((-1,) + self.as_shape)
            batch_advantages = advantages.reshape(-1)
            batch_returns = returns.reshape(-1)
            batch_log_probs = log_probs.reshape(-1)

            vmap_loss_fn = jax.vmap(loss_fn, in_axes=(None, None, 0, 0, 0, 0, 0), out_axes=0)
            safe_mean = lambda x: jnp.mean(x) if x is not None else x
            mean_vmapped_loss_fn = lambda *a, **k: tree.map_structure(safe_mean, vmap_loss_fn(*a, **k))
            grad_loss_fn = jax.value_and_grad(mean_vmapped_loss_fn, argnums=(0, 1), has_aux=True)

            key, subkey = jax.random.split(key)
            batch_indices = jnp.tile(jnp.arange(self.batch_size), (self.nr_epochs, 1))
            batch_indices = jax.random.permutation(subkey, batch_indices, axis=1, independent=True)
            batch_indices = batch_indices.reshape((self.nr_epochs * self.nr_minibatches, self.minibatch_size))

            def minibatch_update(carry, minibatch_indices):
                policy_state, critic_state = carry

                minibatch_advantages = batch_advantages[minibatch_indices]
                minibatch_advantages = (minibatch_advantages - jnp.mean(minibatch_advantages)) / (jnp.std(minibatch_advantages) + 1e-8)

                (loss, (metrics)), (policy_gradients, critic_gradients) = grad_loss_fn(
                    policy_state.params,
                    critic_state.params,
                    batch_states[minibatch_indices],
                    batch_actions[minibatch_indices],
                    batch_log_probs[minibatch_indices],
                    batch_returns[minibatch_indices],
                    minibatch_advantages
                )

                policy_state = policy_state.apply_gradients(grads=policy_gradients)
                critic_state = critic_state.apply_gradients(grads=critic_gradients)

                metrics["gradients/policy_grad_norm"] = optax.global_norm(policy_gradients)
                metrics["gradients/critic_grad_norm"] = optax.global_norm(critic_gradients)

                carry = (policy_state, critic_state)

                return carry, (metrics)
            
            init_carry = (policy_state, critic_state)
            carry, (metrics) = jax.lax.scan(minibatch_update, init_carry, batch_indices)
            policy_state, critic_state = carry

            # Calculate mean metrics
            mean_metrics = {key: jnp.mean(metrics[key]) for key in metrics}
            mean_metrics["lr/learning_rate"] = policy_state.opt_state[1].hyperparams["learning_rate"]
            mean_metrics["v_value/explained_variance"] = 1 - jnp.var(returns - values) / (jnp.var(returns) + 1e-8)
            mean_metrics["policy/std_dev"] = jnp.mean(jnp.exp(policy_state.params["params"]["policy_logstd"]))

            return policy_state, critic_state, mean_metrics, key


        @jax.jit
        def get_deterministic_action(policy_state: TrainState, state: np.ndarray):
            action_mean, action_logstd = self.policy.apply(policy_state.params, state)
            return self.get_processed_action(action_mean)
        

        def learnt_reward_pearson_correlation():
            rlx_logger.info(f"Testing Correlation Between Expert & Learnt Rewards")

            from scipy import stats
            import matplotlib.pyplot as plt

            # Get demonstrations
            demonstrations = prepare_expert_data(self.data_path, self.subsampling_cutoff)
            expert_states = demonstrations["states"]
            expert_next_states = demonstrations["next_states"]
            expert_actions = demonstrations["actions"]
            expert_rewards = demonstrations["rewards"].reshape((-1,1))
            expert_absorbing = demonstrations["absorbing"].flatten()
            # if self.handle_absorbing_states:
            #     abs_indices = np.where(expert_absorbing > 0.0)[0]
            #     expert_states[abs_indices] = expert_states[abs_indices] * 0.0
            #     expert_next_states[abs_indices] = expert_states[abs_indices]

            sigmas = jnp.exp(jnp.linspace(jnp.log(self.sigma_begin_ncsn), jnp.log(self.sigma_end_ncsn), self.L_ncsn))
            cond = 4 # conditioning level corresponding to some sigma value
            sigma = sigmas[cond]
            if self.state_based:
                sample = jnp.concatenate([expert_states.reshape((-1,) + self.os_shape), expert_next_states.reshape((-1,) + self.os_shape)], axis=1)
            else:
                sample = jnp.concatenate([expert_states.reshape((-1,) + self.os_shape), expert_actions.reshape((-1,) + self.as_shape)], axis=1)

            if self.ncsnv1:
                energy = self.energyfn.apply(self.energyfn_state.params, sample, cond)
            else:
                energy = self.energyfn.apply(self.energyfn_state.params, sample, sigma)
            pearson_correlation = stats.pearsonr(energy, expert_rewards)
            rlx_logger.info(f"pearson correlation: {pearson_correlation[0]} p value: {pearson_correlation[1]}")

            plt.figure(figsize=(8, 6))
            plt.scatter(energy.flatten(), expert_rewards.flatten(), alpha=0.5, s=10)
            plt.xlabel("Energy Rewards")
            plt.ylabel("Expert Rewards")
            # plt.title("Scatter Plot of Projected vs Expert Rewards")
            plt.grid(True)
            # plt.show()

            return pearson_correlation[0][0]


        self.set_train_mode()

        demonstrations = prepare_expert_data(self.data_path)
        self.expert_states = demonstrations["states"]
        self.expert_actions = demonstrations["actions"]
        self.expert_next_states = demonstrations["next_states"]

        temp_start = time.time()

        """
        Reward Learning
        """
        rlx_logger.info(f"Training Energy Based Reward")
        @jax.jit
        def train_ncsn(runner_state, unused):
            energyfn_state, nr_samples_ncsn, key = runner_state

            energyfn_state, optimization_metrics_ncsn, key = ncsn_update(
                energyfn_state,
                self.expert_states, self.expert_actions, self.expert_next_states,
                key, self.state_based, self.ncsnv1
            )

            nr_samples_ncsn += self.nr_epochs_ncsn * self.batch_size_ncsn

            # LOGGING
            def metrics_callback(metric):
                rlx_logger.info("┌" + "─" * 31 + "┬" + "─" * 16 + "┐", flush=False)
                optimization_metrics_ncsn, nr_samples_ncsn = metric
                log_dict = {
                    "loss/energyfn_loss": optimization_metrics_ncsn["loss/energyfn_loss"],
                    "gradients/energyfn_grad_norm": optimization_metrics_ncsn["gradients/energyfn_grad_norm"],
                    "nr_samples_ncsn": nr_samples_ncsn,
                    } 
                for name, value in log_dict.items():
                    if isinstance(value, jax.Array):
                        value = value.__array__()
                    self.writer.add_scalar(name, value.__array__(), nr_samples_ncsn.__array__())
                    rlx_logger.info(f"│ {name.ljust(30)}│ {str(value).ljust(14)[:14]} │", flush=False)

                rlx_logger.info("└" + "─" * 31 + "┴" + "─" * 16 + "┘")

            jax.lax.cond(
                nr_samples_ncsn % (100 * self.nr_epochs_ncsn * self.batch_size_ncsn) == 0,
                lambda _: jax.debug.callback(metrics_callback, (optimization_metrics_ncsn, nr_samples_ncsn)),
                lambda _: None,
                operand=None,
            )

            runner_state = (energyfn_state, nr_samples_ncsn, key)
            return runner_state, None

        
        num_updates_ncsn = self.total_samples_ncsn // self.nr_epochs_ncsn // self.batch_size_ncsn
        ncsn_runner_state = (self.energyfn_state, 0, self.key)
        (self.energyfn_state, nr_samples_ncsn, self.key), _ = jax.lax.scan(
            train_ncsn, ncsn_runner_state, None, num_updates_ncsn
        )

        # """
        # Reward Learning
        # """
        # rlx_logger.info(f"Training Energy Based Reward")
        # nr_samples_ncsn = 0
        # steps_metrics_ncsn = {}
        # while nr_samples_ncsn < self.total_samples_ncsn:

        #     start_time_ncsn = time.time()
        #     time_metrics_ncsn = {}

        #     self.energyfn_state, optimization_metrics_ncsn, self.key = ncsn_update(
        #         self.energyfn_state,
        #         self.expert_states, self.expert_actions, self.expert_next_states,
        #         self.key, self.state_based, self.ncsnv1
        #     )

        #     optimization_metrics_ncsn = {key: value.item() for key, value in optimization_metrics_ncsn.items()}
        #     nr_samples_ncsn += self.nr_epochs_ncsn * self.batch_size_ncsn

        #     optimizing_end_time_ncsn = time.time()
        #     time_metrics_ncsn["time/optimizing_time_ncsn"] = optimizing_end_time_ncsn - start_time_ncsn


        #     # Logging
        #     if nr_samples_ncsn % (100 * self.nr_epochs_ncsn * self.batch_size_ncsn) == 0:
        #         self.start_logging(nr_samples_ncsn)
        #         steps_metrics_ncsn["steps/nr_samples_ncsn"] = nr_samples_ncsn
        #         combined_metrics_ncsn = {**steps_metrics_ncsn, **time_metrics_ncsn, **optimization_metrics_ncsn}
        #         for key, value in combined_metrics_ncsn.items():
        #             self.log(f"{key}", value, nr_samples_ncsn)
                
        #         self.end_logging()



        # # Saving
        self.save_ncsn()
        pearson_correlation = learnt_reward_pearson_correlation()
        # print(f"Took {time.time() - temp_start}")
        # return # quit

        """
        Policy Learning (PPO)
        """
        rlx_logger.info(f"Training PPO Policy")
        sigmas = jnp.exp(jnp.linspace(jnp.log(self.sigma_begin_ncsn), jnp.log(self.sigma_end_ncsn), self.L_ncsn))

        batch = Batch(
            states=np.zeros((self.nr_steps, self.nr_envs) + self.os_shape),
            next_states=np.zeros((self.nr_steps, self.nr_envs) + self.os_shape),
            actions=np.zeros((self.nr_steps, self.nr_envs) + self.as_shape),
            rewards=np.zeros((self.nr_steps, self.nr_envs)),
            energy=np.zeros((self.nr_steps, self.nr_envs)),
            energy_rewards=np.zeros((self.nr_steps, self.nr_envs)),
            merged_rewards=np.zeros((self.nr_steps, self.nr_envs)),
            values=np.zeros((self.nr_steps, self.nr_envs)),
            terminations=np.zeros((self.nr_steps, self.nr_envs)),
            log_probs=np.zeros((self.nr_steps, self.nr_envs)),
            advantages=np.zeros((self.nr_steps, self.nr_envs)),
            returns=np.zeros((self.nr_steps, self.nr_envs)),
        )

        saving_return_buffer = deque(maxlen=100 * self.nr_envs)

        state, _ = self.env.reset()
        global_step = 0
        nr_updates = 0
        nr_episodes = 0
        steps_metrics = {}
        while global_step < self.total_timesteps:
            start_time = time.time()
            time_metrics = {}


            # Acting
            dones_this_rollout = 0
            step_info_collection = {}
            for step in range(self.nr_steps):
                processed_action, action, value, log_prob, self.key = get_action_and_value(self.policy_state, self.critic_state, state, self.key)
                next_state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                done = terminated | truncated
                actual_next_state = next_state.copy()
                for i, single_done in enumerate(done):
                    if single_done:
                        actual_next_state[i] = np.array(self.env.get_final_observation_at_index(info, i))
                        saving_return_buffer.append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                        dones_this_rollout += 1
                for key, info_value in self.env.get_logging_info_dict(info).items():
                    step_info_collection.setdefault(key, []).extend(info_value)

                batch.states[step] = state
                batch.next_states[step] = actual_next_state
                batch.actions[step] = action
                batch.rewards[step] = reward
                batch.values[step] = value
                batch.terminations[step] = terminated
                batch.log_probs[step] = log_prob
                state = next_state
                global_step += self.nr_envs
            nr_episodes += dones_this_rollout
            # if self.handle_absorbing_states:
                # Do not explicitly set the value of absorbing states to 0.0 and insead let the critic predict a soft value for absorbing states
                # Also let the discriminator predict a reward for absorbing states
                # The goal is to avoid survival/termination bias
                # batch.terminations = 0.0 * batch.terminations
            acting_end_time = time.time()
            time_metrics["time/acting_time"] = acting_end_time - start_time

            # Compute energy reward and anneal noise level
            if self.state_based:
                samples = jnp.concatenate([batch.states.reshape((-1,) + self.os_shape), batch.next_states.reshape((-1,) + self.os_shape)], axis=1)
            else:
                samples = jnp.concatenate([batch.states.reshape((-1,) + self.os_shape), batch.actions.reshape((-1,) + self.as_shape)], axis=1)

            if self.annealing:
                if global_step == int(self.nr_steps * self.nr_envs):
                    if self.ncsnv1:
                        batch.energy = jnp.squeeze(stop_gradient(self.energyfn.apply(self.energyfn_state.params, samples, self.sigma_inference_ncsn))).reshape(batch.energy.shape)
                    else:
                        batch.energy = jnp.squeeze(stop_gradient(self.energyfn.apply(self.energyfn_state.params, samples, sigmas[self.sigma_inference_ncsn]))).reshape(batch.energy.shape)
                    self.energy_buffer.append(batch.energy)

                # energy_rewards, energy = get_energy_reward(self.energyfn_state, samples, sigmas[self.sigma_inference_ncsn], jnp.mean(batch.energy))
                if self.ncsnv1:
                    energy_rewards, energy = get_energy_reward(self.energyfn_state, samples, self.sigma_inference_ncsn, self.energy_buffer.mean())
                else:
                    energy_rewards, energy = get_energy_reward(self.energyfn_state, samples, sigmas[self.sigma_inference_ncsn], self.energy_buffer.mean())
                batch.energy = jnp.squeeze(energy).reshape(batch.energy.shape)
                self.energy_buffer.append(batch.energy)
                mean_energy = self.energy_buffer.mean()
                # mean_energy = batch.energy.mean()

                if global_step == int(self.nr_steps * self.nr_envs) and self.sigma_inference_ncsn == 0:
                    self.current_lvl_init_energy = mean_energy
                
                if mean_energy >= self.current_lvl_init_energy * (1 + self.anneal_threshold):
                    self.current_lvl_init_energy = mean_energy
                    self.sigma_inference_ncsn = jnp.clip(self.sigma_inference_ncsn + 1, 0, self.L_ncsn)
            else:
                if self.ncsnv1:
                    energy_rewards = jnp.squeeze(stop_gradient(self.energyfn.apply(self.energyfn_state.params, samples, self.sigma_inference_ncsn))).reshape(batch.energy.shape)
                else:
                    energy_rewards = jnp.squeeze(stop_gradient(self.energyfn.apply(self.energyfn_state.params, samples, sigmas[self.sigma_inference_ncsn]))).reshape(batch.energy.shape)

            batch.energy_rewards = jnp.squeeze(energy_rewards).reshape(batch.energy_rewards.shape)
            batch.merged_rewards = batch.rewards * self.env_reward_frac + batch.energy_rewards * (1 - self.env_reward_frac)

            # Calculating advantages and returns
            batch.advantages, batch.returns = calculate_gae_advantages(self.critic_state, batch.next_states, batch.merged_rewards, batch.terminations, batch.values)
            
            calc_adv_return_end_time = time.time()
            time_metrics["time/calc_adv_and_return_time"] = calc_adv_return_end_time - acting_end_time

            # Optimizing
            self.policy_state, self.critic_state, optimization_metrics, self.key = update(
                self.policy_state, self.critic_state,
                batch.states, batch.actions, batch.advantages, batch.returns, batch.values, batch.log_probs,
                self.key
            )
            optimization_metrics["reward/sigma_level"] = np.array([self.sigma_inference_ncsn])
            if self.annealing:
                optimization_metrics["reward/mean_energy"] = mean_energy
            else:
                optimization_metrics["reward/mean_energy"] = batch.energy_rewards.mean()
            optimization_metrics["reward/pearson_correlation"] = np.array([pearson_correlation])

            optimization_metrics = {key: value.item() for key, value in optimization_metrics.items()}
            nr_updates += self.nr_epochs * self.nr_minibatches

            optimizing_end_time = time.time()
            time_metrics["time/optimizing_time"] = optimizing_end_time - calc_adv_return_end_time


            # Evaluating
            evaluation_metrics = {}
            if global_step % self.evaluation_frequency == 0 and self.evaluation_frequency != -1:
                self.set_eval_mode()
                state, _ = self.env.reset()
                eval_nr_episodes = 0
                evaluation_metrics = {"eval/episode_return": [], "eval/episode_length": []}
                while True:
                    processed_action = get_deterministic_action(self.policy_state, state)
                    state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                    done = terminated | truncated
                    for i, single_done in enumerate(done):
                        if single_done:
                            eval_nr_episodes += 1
                            evaluation_metrics["eval/episode_return"].append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                            evaluation_metrics["eval/episode_length"].append(self.env.get_final_info_value_at_index(info, "episode_length", i))
                            if eval_nr_episodes == self.evaluation_episodes:
                                break
                    if eval_nr_episodes == self.evaluation_episodes:
                        break
                evaluation_metrics = {key: np.mean(value) for key, value in evaluation_metrics.items()}
                state, _ = self.env.reset()
                self.set_train_mode()
            
            evaluating_end_time = time.time()
            time_metrics["time/evaluating_time"] = evaluating_end_time - optimizing_end_time
            

            # Saving
            # Also only save when there were finished episodes this update
            if self.save_model and dones_this_rollout > 0:
                mean_return = np.mean(saving_return_buffer)
                if mean_return > self.best_mean_return:
                    self.best_mean_return = mean_return
                    self.save(global_step)
            
            saving_end_time = time.time()
            time_metrics["time/saving_time"] = saving_end_time - evaluating_end_time

            time_metrics["time/sps"] = int((self.nr_steps * self.nr_envs) / (saving_end_time - start_time))


            # Logging
            self.start_logging(global_step)

            steps_metrics["steps/nr_env_steps"] = global_step
            steps_metrics["steps/nr_updates"] = nr_updates
            steps_metrics["steps/nr_episodes"] = nr_episodes

            rollout_info_metrics = {}
            env_info_metrics = {}
            if step_info_collection:
                info_names = list(step_info_collection.keys())
                for info_name in info_names:
                    metric_group = "rollout" if info_name in ["episode_return", "episode_length"] else "env_info"
                    metric_dict = rollout_info_metrics if metric_group == "rollout" else env_info_metrics
                    mean_value = np.mean(step_info_collection[info_name])
                    if mean_value == mean_value:  # Check if mean_value is NaN
                        metric_dict[f"{metric_group}/{info_name}"] = mean_value
            
            combined_metrics = {**rollout_info_metrics, **evaluation_metrics, **env_info_metrics, **steps_metrics, **time_metrics, **optimization_metrics}
            for key, value in combined_metrics.items():
                self.log(f"{key}", value, global_step)

            self.end_logging()


    ######
    #################################################################
    #######################################################################
    """
    FROM HERE ON ONLY TESTING & SAVING FUNCTIONS
    """
    #######################################################################
    #######################################################################

    def train(self):
        """Dispatch to the diagnostic plot or to the IRL training loop.

        When ``self.run_diagnostics`` is truthy only the point-maze reward plot
        is produced; otherwise ``train_irl`` runs.
        """
        if not self.run_diagnostics:
            self.train_irl()
            return
        # Alternative diagnostic: self.learnt_reward_pearson_correlation()
        self.plot_point_maze_reward()


    def plot_point_maze_reward(self, cond=10, grid_size=100, low=0.0, high=0.3):
        """Plot the learnt energy over an (x, y) grid as a heat map.

        The energy function is evaluated on a ``grid_size`` x ``grid_size`` grid
        covering ``[low, high]`` in both x and y, at noise-level index ``cond``.
        The result is shown with ``imshow``. Maze decorations (outer box, inner
        wall, a green marker and a red star) are drawn on top at fixed
        coordinates chosen for the default extent.

        Args:
            cond: Noise-level index (default 10, the previously hard-coded value).
                Passed as-is when ``self.ncsnv1`` is set; otherwise used to pick
                ``sigmas[cond]``.
            grid_size: Number of grid points per axis.
            low: Lower bound of the plotted square (both axes).
            high: Upper bound of the plotted square (both axes).

        Raises:
            ValueError: If ``cond`` is not in ``[0, self.L_ncsn)``.
        """
        import matplotlib.pyplot as plt
        import matplotlib.patches as patches

        if not 0 <= cond < self.L_ncsn:
            # JAX clamps out-of-range indices instead of raising, which would
            # silently evaluate a different noise level than requested.
            raise ValueError(f"cond must be in [0, {self.L_ncsn}), got {cond}")

        plt.rcParams.update({
            "font.family": "monospace",
            "axes.titlesize": 24,
            "axes.labelsize": 24,
            "legend.fontsize": 24,
        })

        xs = np.linspace(low, high, grid_size, dtype=np.float32)
        ys = np.linspace(low, high, grid_size, dtype=np.float32)
        xx, yy = np.meshgrid(xs, ys, indexing="xy")
        num = grid_size * grid_size

        # Rows are (x, y, target_x, target_y). This presumably matches the
        # point-maze observation layout (os_shape == (4,)); TODO confirm.
        target = np.array([0.15, 0.0], dtype=np.float32)
        states = np.stack([xx.reshape(-1), yy.reshape(-1),
                           np.full(num, target[0], np.float32),
                           np.full(num, target[1], np.float32)], axis=-1).astype(np.float32)

        sigmas = jnp.exp(jnp.linspace(jnp.log(self.sigma_begin_ncsn), jnp.log(self.sigma_end_ncsn), self.L_ncsn))
        sigma = sigmas[cond]
        if self.state_based:
            # Static transitions: the next state equals the current state.
            next_states = states.copy()
            sample = jnp.concatenate([states.reshape((-1,) + self.os_shape), next_states.reshape((-1,) + self.os_shape)], axis=1)
        else:
            actions = np.zeros((num,) + self.as_shape, np.float32)
            sample = jnp.concatenate([states.reshape((-1,) + self.os_shape), actions.reshape((-1,) + self.as_shape)], axis=1)

        if self.ncsnv1:
            energy = self.energyfn.apply(self.energyfn_state.params, sample, cond)
        else:
            energy = self.energyfn.apply(self.energyfn_state.params, sample, sigma)

        reward_grid = np.asarray(energy).reshape(grid_size, grid_size)
        fig = plt.figure(figsize=(8, 5))
        im = plt.imshow(reward_grid, extent=[low, high, low, high], origin="lower", aspect="equal", cmap="Blues")

        ax = plt.gca()
        ax.add_patch(patches.Rectangle((low, low), high - low, high - low,
                                       fill=False, edgecolor="black", linewidth=8))
        ax.plot([0.0, 0.18], [0.15, 0.15], color="black", linewidth=8)
        ax.scatter(0.15, 0.28, s=150, color="green", zorder=3)
        plt.plot(0.15, 0.02, marker='*', markersize=20, color='red')
        cbar = plt.colorbar(im)
        cbar.ax.tick_params(labelsize=24)
        cbar.set_ticks(np.linspace(im.get_clim()[0], im.get_clim()[1], 6))

        ax.set_xticks([]); ax.set_yticks([])
        ax.set_xlabel(""); ax.set_ylabel("")

        plt.title("Learnt Reward")
        plt.tight_layout()
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


    def learnt_reward_pearson_correlation(self, cond=10):
        """Correlate the learnt energy with the ground-truth expert rewards.

        Loads the demonstrations from ``self.data_path`` and evaluates the energy
        function on them at noise-level index ``cond``. It then prints the Pearson
        correlation (and p-value) between energies and expert rewards and shows
        a scatter plot of the two.

        Args:
            cond: Noise-level index (default 10, the previously hard-coded value).

        Returns:
            ``(r, p_value)`` as Python floats.

        Raises:
            ValueError: If ``cond`` is not in ``[0, self.L_ncsn)``.
        """
        rlx_logger.info("Testing Correlation Between Expert & Learnt Rewards")

        from scipy import stats
        import matplotlib.pyplot as plt

        if not 0 <= cond < self.L_ncsn:
            # JAX would clamp an out-of-range index and silently use another sigma.
            raise ValueError(f"cond must be in [0, {self.L_ncsn}), got {cond}")

        demonstrations = prepare_expert_data(self.data_path)
        expert_states = demonstrations["states"]
        expert_next_states = demonstrations["next_states"]
        expert_actions = demonstrations["actions"]
        expert_rewards = demonstrations["rewards"].reshape((-1, 1))

        sigmas = jnp.exp(jnp.linspace(jnp.log(self.sigma_begin_ncsn), jnp.log(self.sigma_end_ncsn), self.L_ncsn))
        sigma = sigmas[cond]
        if self.state_based:
            sample = jnp.concatenate([expert_states.reshape((-1,) + self.os_shape), expert_next_states.reshape((-1,) + self.os_shape)], axis=1)
        else:
            sample = jnp.concatenate([expert_states.reshape((-1,) + self.os_shape), expert_actions.reshape((-1,) + self.as_shape)], axis=1)

        if self.ncsnv1:
            energy = self.energyfn.apply(self.energyfn_state.params, sample, cond)
        else:
            energy = self.energyfn.apply(self.energyfn_state.params, sample, sigma)

        # pearsonr correlates two 1-D samples. expert_rewards is an (N, 1) column
        # here (and energy may carry a trailing axis), so flatten both first.
        energy_flat = np.asarray(energy).reshape(-1)
        rewards_flat = np.asarray(expert_rewards).reshape(-1)
        result = stats.pearsonr(energy_flat, rewards_flat)
        r, p_value = float(result[0]), float(result[1])
        print(f"pearson correlation: {r} p value: {p_value}")

        plt.figure(figsize=(8, 6))
        plt.scatter(energy_flat, rewards_flat, alpha=0.5, s=10)
        plt.xlabel("Energy Rewards")
        plt.ylabel("Expert Rewards")
        plt.grid(True)
        plt.show()

        return r, p_value

    def log(self, name, value, step):
        """Forward one scalar metric to every enabled sink.

        TensorBoard receives ``(name, value, step)`` when ``track_tb`` is set;
        the console table row is written afterwards when ``track_console`` is set.
        """
        if self.track_tb:
            self.writer.add_scalar(name, value, step)
        if not self.track_console:
            return
        self.log_console(name, value)
    

    def log_console(self, name, value):
        """Write one row of the console metrics table.

        The name is left-justified to 30 characters. The value is rendered
        positionally by numpy (trailing zeros and dot trimmed) and squeezed
        into a 14-character cell.
        """
        rendered = str(np.format_float_positional(value, trim="-"))
        value_cell = rendered.ljust(14)[:14]
        rlx_logger.info(f"│ {name.ljust(30)}│ {value_cell} │", flush=False)


    def start_logging(self, step):
        """Begin a logging round.

        With console tracking, the top border of the metrics table is drawn;
        without it, a plain ``Step: <step>`` line is logged instead.
        """
        if not self.track_console:
            rlx_logger.info(f"Step: {step}")
            return
        top_border = "┌" + "─" * 31 + "┬" + "─" * 16 + "┐"
        rlx_logger.info(top_border, flush=False)


    def end_logging(self):
        """Draw the bottom border of the console metrics table (console tracking only)."""
        if not self.track_console:
            return
        bottom_border = "└{}┴{}┘".format("─" * 31, "─" * 16)
        rlx_logger.info(bottom_border)

    
    def save(self, global_step, save_interval=40 * 4096):
        """Write policy, critic and energy function to a zip archive, rate-limited.

        A checkpoint is written only when at least ``save_interval`` environment
        steps have passed since the last checkpoint written by this method.

        NOTE(review): the training loop updates ``best_mean_return`` before
        calling this, so an improvement that arrives inside the interval is
        recorded but its model is never written.

        Args:
            global_step: Current number of environment steps.
            save_interval: Minimum step gap between two checkpoints
                (default ``40 * 4096``, the previously hard-coded value).
        """
        if not hasattr(self, "_last_buffer_save_step"):
            self._last_buffer_save_step = -1
        if global_step - self._last_buffer_save_step < save_interval:
            return
        self._last_buffer_save_step = global_step

        checkpoint = {
            "policy": self.policy_state,
            "critic": self.critic_state,
            "energy_function": self.energyfn_state,
        }
        tmp_dir = f"{self.save_path}/tmp"
        save_args = orbax_utils.save_args_from_target(checkpoint)
        try:
            self.best_model_checkpointer.save(tmp_dir, checkpoint, save_args=save_args)
            with open(f"{tmp_dir}/config_algorithm.json", "w") as f:
                json.dump(self.config.algorithm.to_dict(), f)
            # make_archive appends ".zip" and returns the path of the file it
            # wrote; that is the file that actually exists on disk.
            archive_path = shutil.make_archive(f"{self.save_path}/{self.best_model_file_name}", "zip", tmp_dir)
        finally:
            # Always drop the staging dir: orbax refuses by default to save into
            # an existing directory, so a leftover one would break later saves.
            shutil.rmtree(tmp_dir, ignore_errors=True)

        if self.track_wandb:
            wandb.save(archive_path, base_path=self.save_path)


    def save_ncsn(self):
        """Unconditionally write policy, critic and energy function to a zip archive.

        The checkpoint is staged in ``<save_path>/tmp`` together with the
        algorithm config as JSON. It is then zipped to
        ``<save_path>/<best_model_file_name>.zip`` and optionally uploaded to
        wandb.
        """
        checkpoint = {
            "policy": self.policy_state,
            "critic": self.critic_state,
            "energy_function": self.energyfn_state,
        }
        tmp_dir = f"{self.save_path}/tmp"
        save_args = orbax_utils.save_args_from_target(checkpoint)
        try:
            self.best_model_checkpointer.save(tmp_dir, checkpoint, save_args=save_args)
            with open(f"{tmp_dir}/config_algorithm.json", "w") as f:
                json.dump(self.config.algorithm.to_dict(), f)
            # make_archive appends ".zip" and returns the path of the file it wrote.
            archive_path = shutil.make_archive(f"{self.save_path}/{self.best_model_file_name}", "zip", tmp_dir)
        finally:
            # A leftover staging dir would make the next orbax save fail.
            shutil.rmtree(tmp_dir, ignore_errors=True)

        if self.track_wandb:
            wandb.save(archive_path, base_path=self.save_path)


    def load(config, env, eval_env, run_path, writer, explicitly_set_algorithm_params):
        """Rebuild a NEAR_PPO instance from a zip archive written by ``save``/``save_ncsn``.

        Called on the class (there is no ``self``). The archive at
        ``config.runner.load_model`` is unpacked into a sibling ``tmp`` dir. The
        stored algorithm config then overrides every ``config.algorithm`` key
        that was not explicitly set by the user. A fresh model is built and its
        policy, critic and energy-function states are restored from the
        checkpoint. The ``tmp`` dir is removed afterwards, even when loading fails.

        Returns:
            The restored ``NEAR_PPO`` instance.
        """
        archive_dir, checkpoint_file_name = os.path.split(config.runner.load_model)
        archive_dir = os.path.abspath(archive_dir)
        tmp_dir = f"{archive_dir}/tmp"
        shutil.unpack_archive(f"{archive_dir}/{checkpoint_file_name}", tmp_dir, "zip")

        try:
            with open(f"{tmp_dir}/config_algorithm.json", "r") as f:
                loaded_algorithm_config = json.load(f)
            for key, value in loaded_algorithm_config.items():
                if f"algorithm.{key}" not in explicitly_set_algorithm_params:
                    config.algorithm[key] = value
            # __init__ takes (config, env, eval_env, run_path, writer).
            model = NEAR_PPO(config, env, eval_env, run_path, writer)

            target = {
                "policy": model.policy_state,
                "critic": model.critic_state,
                "energy_function": model.energyfn_state,
            }
            restore_args = orbax_utils.restore_args_from_target(target)
            checkpointer = orbax.checkpoint.PyTreeCheckpointer()
            checkpoint = checkpointer.restore(tmp_dir, item=target, restore_args=restore_args)
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)

        model.policy_state = checkpoint["policy"]
        model.critic_state = checkpoint["critic"]
        model.energyfn_state = checkpoint["energy_function"]

        return model
    

    def test(self, episodes):
        """Roll out the trained deterministic policy in ``self.env``.

        Currently runs ``visualise``, a single rollout of up to 1000 steps whose
        return is logged. ``evaluate`` instead records rollouts as an expert
        dataset (``.npz``) and can be swapped in at the bottom of this method.

        Args:
            episodes: Number of episodes for ``evaluate``. The first one is a
                short unrecorded warm-up, so ``episodes - 1`` are recorded.
                Unused by ``visualise``.
        """

        @jax.jit
        def get_action(policy_state: TrainState, state: np.ndarray):
            # Deterministic action from the policy mean; the log-std is ignored.
            action_mean, _ = self.policy.apply(policy_state.params, state)
            return self.get_processed_action(action_mean)

        def _nan_per_sample(arr):
            """Boolean mask over the leading (sample) axis: True where any entry is NaN."""
            nan = np.isnan(arr)
            if nan.ndim > 1:
                return nan.any(axis=tuple(range(1, nan.ndim)))
            return nan

        def evaluate():
            """Record ``episodes - 1`` rollouts of 100 steps and save them as an ``.npz`` dataset."""
            nr_steps = 100
            nr_recorded = max(episodes - 1, 0)
            buffer_shape = (nr_recorded, nr_steps, self.nr_envs)
            states = np.zeros(buffer_shape + self.os_shape)
            next_states = np.zeros(buffer_shape + self.os_shape)
            actions = np.zeros(buffer_shape + self.as_shape)
            rewards = np.zeros(buffer_shape)
            terminations = np.zeros(buffer_shape)
            episode_returns = []

            for i in range(episodes):
                state, _ = self.env.reset()

                if i == 0:
                    # HACK: the first episode's steps return NaNs for an unknown
                    # reason, so it only serves as a short, unrecorded warm-up.
                    warmup_return = 0.0
                    for _ in range(10):
                        processed_action = get_action(self.policy_state, state)
                        state, reward, _, _, _ = self.env.step(jax.device_get(processed_action))
                        warmup_return = warmup_return + np.asarray(reward)
                    rlx_logger.info(f"Episode {i + 1} - Return: {np.mean(warmup_return)}")
                    continue

                episode_return = 0.0
                for step in range(nr_steps):
                    processed_action = get_action(self.policy_state, state)
                    next_state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                    done = terminated | truncated
                    # For finished sub-envs the true final observation is taken
                    # from ``info`` (presumably the env auto-resets them).
                    actual_next_state = next_state.copy()
                    for idx, single_done in enumerate(done):
                        if single_done:
                            actual_next_state[idx] = np.array(self.env.get_final_observation_at_index(info, idx))

                    states[i - 1, step] = state
                    next_states[i - 1, step] = actual_next_state
                    actions[i - 1, step] = processed_action
                    rewards[i - 1, step] = reward
                    terminations[i - 1, step] = terminated

                    state = next_state
                    episode_return = episode_return + np.asarray(reward)

                episode_returns.append(np.mean(episode_return))
                rlx_logger.info(f"Episode {i + 1} - Return: {np.mean(episode_return)}")

            mean_episode_return = np.mean(episode_returns) if episode_returns else float("nan")
            rlx_logger.info(f"Mean Episode Return: {mean_episode_return}")

            # Flatten (episode, step, env) into a single sample axis for every
            # array so rows stay aligned. Then drop each transition containing a
            # NaN anywhere, using one shared mask for all arrays.
            nr_samples = nr_recorded * nr_steps * self.nr_envs
            flat = {
                "states": states.reshape((nr_samples,) + self.os_shape),
                "actions": actions.reshape((nr_samples,) + self.as_shape),
                "next_states": next_states.reshape((nr_samples,) + self.os_shape),
                "absorbing": terminations.reshape(nr_samples),
                "rewards": rewards.reshape(nr_samples),
            }
            keep = np.ones(nr_samples, dtype=bool)
            for arr in flat.values():
                keep &= ~_nan_per_sample(arr)
            dataset = {key: arr[keep] for key, arr in flat.items()}

            os.makedirs(self.save_path, exist_ok=True)
            print(f"save path: {self.save_path}")
            np.savez(f"{self.save_path}/NEAR_{episodes-1}_episodes", **dataset)

        def visualise():
            """Run one rollout of at most 1000 steps, stopping when any sub-env finishes."""
            nr_steps = 1000
            episode_return = 0.0
            step = 0
            state, _ = self.env.reset()

            for step in range(nr_steps):
                processed_action = get_action(self.policy_state, state)
                next_state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                # Include the reward of the transition that ends the episode.
                episode_return = episode_return + np.asarray(reward)
                # The env is vectorised, so reduce the per-env done flags explicitly.
                if np.any(terminated | truncated):
                    break
                state = next_state

            rlx_logger.info(f"Episode Return: {np.mean(episode_return)}")
            rlx_logger.info(f"Step: {step}")

        # evaluate()  # swap in to generate an expert dataset instead
        visualise()




            
    def set_train_mode(self):
        """Switch to training mode; currently a no-op hook that returns None."""


    def set_eval_mode(self):
        """Switch to evaluation mode; currently a no-op hook that returns None."""


    def general_properties():
        """Return the ``GeneralProperties`` class for this algorithm (called without an instance)."""
        properties_cls = GeneralProperties
        return properties_cls



        # def evaluate():
        #     evaluation_metrics = {}
        #     self.set_eval_mode()
        #     state, _ = self.env.reset()
        #     eval_nr_episodes = 0
        #     evaluation_metrics = {"eval/episode_return": [], "eval/episode_length": []}
        #     while True:
        #         processed_action = get_action(self.policy_state, state)
        #         state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
        #         done = terminated | truncated
        #         for i, single_done in enumerate(done):
        #             if single_done:
        #                 eval_nr_episodes += 1
        #                 evaluation_metrics["eval/episode_return"].append(self.env.get_final_info_value_at_index(info, "episode_return", i))
        #                 evaluation_metrics["eval/episode_length"].append(self.env.get_final_info_value_at_index(info, "episode_length", i))
        #                 if eval_nr_episodes == episodes:
        #                     break
        #         if eval_nr_episodes == episodes:
        #             break

        #     evaluation_metrics = {key: np.mean(value) for key, value in evaluation_metrics.items()}
        #     print(evaluation_metrics)