import os
import shutil
import json
import logging
import time
from collections import deque
import tree
import numpy as np
import jax
from jax.lax import stop_gradient
import jax.numpy as jnp
import flax
from flax.training.train_state import TrainState
from flax.training import orbax_utils
import orbax.checkpoint
import optax
import wandb
from functools import partial

from irl_baselines.algorithms.iq_sac.flax.general_properties import GeneralProperties
from irl_baselines.algorithms.iq_sac.flax.policy import get_policy
from irl_baselines.algorithms.iq_sac.flax.critic import get_critic
from irl_baselines.algorithms.iq_sac.flax.entropy_coefficient import EntropyCoefficient, ConstantEntropyCoefficient
from irl_baselines.algorithms.iq_sac.flax.replay_buffer import ReplayBuffer
from irl_baselines.algorithms.iq_sac.flax.rl_train_state import RLTrainState
from irl_baselines.algorithms.data_utils import prepare_expert_data

# Logger for this module; uses the shared "rl_x" logger name rather than __name__.
rlx_logger = logging.getLogger("rl_x")

class IQ_SAC:
    """IQ-Learn (optionally LS-IQ) imitation learning on a SAC actor-critic (Flax/JAX).

    After ``learning_starts`` env steps, every env step runs
    ``nr_q_updates_per_step`` IQ-Learn critic updates on a replay batch mixed
    with an equally sized random subset of expert transitions, followed by one
    SAC policy / entropy-coefficient update against that critic.
    """

    def __init__(self, config, env, eval_env, run_path, writer):
        """Build networks, optimizers and bookkeeping from ``config``.

        Args:
            config: Run configuration; reads ``config.runner.*``,
                ``config.environment.*`` and ``config.algorithm.*``.
            env: Vectorized training environment (also used for evaluation).
            eval_env: Accepted for interface compatibility; not used here.
            run_path: Run directory; models are saved under ``<run_path>/models``.
            writer: Scalar writer exposing ``add_scalar(name, value, step)``.
        """
        self.config = config
        self.env = env
        self.writer = writer

        self.save_model = config.runner.save_model
        self.save_path = os.path.join(run_path, "models")
        self.track_console = config.runner.track_console
        self.track_tb = config.runner.track_tb
        self.track_wandb = config.runner.track_wandb
        self.seed = config.environment.seed
        self.total_timesteps = config.algorithm.total_timesteps
        self.nr_envs = config.environment.nr_envs
        self.learning_rate = config.algorithm.learning_rate
        self.anneal_learning_rate = config.algorithm.anneal_learning_rate
        self.buffer_size = config.algorithm.buffer_size
        self.learning_starts = config.algorithm.learning_starts
        self.batch_size = config.algorithm.batch_size
        self.tau = config.algorithm.tau
        self.gamma = config.algorithm.gamma
        self.target_entropy = config.algorithm.target_entropy
        self.nr_hidden_units = config.algorithm.nr_hidden_units
        self.logging_frequency = config.algorithm.logging_frequency
        self.evaluation_frequency = config.algorithm.evaluation_frequency
        self.evaluation_episodes = config.algorithm.evaluation_episodes
        self.subsampling_cutoff = config.algorithm.get("subsampling_cutoff", 1)

        # IQ Learn specific
        self.reg_mult = config.algorithm.reg_mult
        self.data_path = config.algorithm.data_path
        self.gp_lambda = config.algorithm.gp_lambda
        self.learn_ent_coeff = config.algorithm.learn_ent_coeff
        self.v0_loss = config.algorithm.v0_loss
        self.nr_q_updates_per_step = config.algorithm.nr_q_updates_per_step
        self.use_target_q = config.algorithm.use_target_q
        self.max_grad_norm = config.algorithm.max_grad_norm
        self.use_lsiq = config.algorithm.use_lsiq
        # Bounds used by LS-IQ: the expert Q target and the clip range of V(s').
        self.Q_max = 1.0 / (self.reg_mult * (1 - self.gamma))
        self.Q_min = - 1.0 / (self.reg_mult * (1 - self.gamma))

        rlx_logger.info(f"Using device: {jax.default_backend()}")

        self.rng = np.random.default_rng(self.seed)
        self.key = jax.random.PRNGKey(self.seed)
        self.key, policy_key, critic_key, entropy_coefficient_key = jax.random.split(self.key, 4)

        self.env_as_low = env.single_action_space.low
        self.env_as_high = env.single_action_space.high
        self.os_shape = env.single_observation_space.shape
        self.as_shape = env.single_action_space.shape

        self.policy, self.get_processed_action = get_policy(config, env)
        self.critic = get_critic(config, env)

        if self.target_entropy == "auto":
            # -|A|: negative action dimensionality.
            self.target_entropy = -np.prod(env.single_action_space.shape).item()
        else:
            self.target_entropy = float(self.target_entropy)

        if self.learn_ent_coeff:
            self.entropy_coefficient = EntropyCoefficient(config.algorithm.init_ent_coeff)
        else:
            self.entropy_coefficient = ConstantEntropyCoefficient(config.algorithm.init_ent_coeff)

        self.policy.apply = jax.jit(self.policy.apply)
        self.critic.apply = jax.jit(self.critic.apply)
        self.entropy_coefficient.apply = jax.jit(self.entropy_coefficient.apply)
        self.nr_q_updates = 0

        def make_linear_schedule(updates_per_env_step=1):
            """Return an optax schedule that decays ``learning_rate`` linearly.

            ``count`` is the optimizer's own update counter. An optimizer that is
            stepped ``updates_per_env_step`` times per env step must divide that
            factor back out; otherwise the fraction overshoots and the learning
            rate becomes negative (gradient ascent).
            """
            updates_per_env_step = max(int(updates_per_env_step), 1)

            def linear_schedule(count):
                # NOTE(review): the `- learning_starts` offset is kept from the
                # original schedule. Updates only begin after learning_starts, so
                # it makes the rate start slightly above `learning_rate`. Confirm intent.
                step = (count * self.nr_envs) / updates_per_env_step - self.learning_starts
                total_steps = self.total_timesteps - self.learning_starts
                fraction = 1.0 - (step / total_steps)
                return self.learning_rate * fraction

            return linear_schedule

        if self.anneal_learning_rate:
            self.q_learning_rate = make_linear_schedule(self.nr_q_updates_per_step)
            self.policy_learning_rate = make_linear_schedule()
            self.entropy_learning_rate = make_linear_schedule()
        else:
            self.q_learning_rate = self.learning_rate
            self.policy_learning_rate = self.learning_rate
            self.entropy_learning_rate = self.learning_rate

        # Single dummy observation/action, used only to trace parameter shapes.
        state = jnp.array([self.env.single_observation_space.sample()])
        action = jnp.array([self.env.single_action_space.sample()])

        self.policy_state = TrainState.create(
            apply_fn=self.policy.apply,
            params=self.policy.init(policy_key, state),
            tx=optax.chain(
                optax.clip_by_global_norm(self.max_grad_norm),
                optax.inject_hyperparams(optax.adam)(learning_rate=self.policy_learning_rate),
            ),
        )

        critic_params = self.critic.init(critic_key, state, action)
        self.critic_state = RLTrainState.create(
            apply_fn=self.critic.apply,
            params=critic_params,
            # The target network starts as an exact copy of the online critic.
            target_params=critic_params,
            tx=optax.chain(
                optax.clip_by_global_norm(self.max_grad_norm),
                optax.inject_hyperparams(optax.adam)(learning_rate=self.q_learning_rate),
            ),
        )

        self.entropy_coefficient_state = TrainState.create(
            apply_fn=self.entropy_coefficient.apply,
            params=self.entropy_coefficient.init(entropy_coefficient_key),
            tx=optax.inject_hyperparams(optax.adam)(learning_rate=self.entropy_learning_rate)
        )

        if self.save_model:
            os.makedirs(self.save_path)
            self.best_mean_return = -np.inf
            self.best_model_file_name = "best.model"
            self.best_model_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

    @staticmethod
    def _masked_mean(values, mask):
        """Mean of ``values`` over the entries where the 0/1 float ``mask`` is 1.

        Returns 0 when the mask is all zeros (the denominator carries +1e-8).
        """
        return (values * mask).sum() / (mask.sum() + 1e-8)

    def train(self):
        """Run the acting / IQ critic update / SAC update / eval / save / log loop.

        The loop runs until ``total_timesteps`` env steps have been collected.
        """
        @jax.jit
        def get_action(policy_state: TrainState, state: np.ndarray, key: jax.random.PRNGKey):
            dist = self.policy.apply(policy_state.params, state)
            key, subkey = jax.random.split(key)
            action = dist.sample(seed=subkey)
            return action, key

        #######################
        # SAC policy update
        #######################

        @jax.jit
        def sac_update(
                policy_state: TrainState, critic_state: RLTrainState, entropy_coefficient_state: TrainState,
                states: np.ndarray, next_states: np.ndarray, actions: np.ndarray, key: jax.random.PRNGKey
            ):
            """One SAC step for the policy and entropy coefficient; the critic is left unchanged."""
            def loss_fn(policy_params: flax.core.FrozenDict, critic_params: flax.core.FrozenDict, entropy_coefficient_params: flax.core.FrozenDict,
                        state: np.ndarray, next_state: np.ndarray, action: np.ndarray,
                        key1: jax.random.PRNGKey
                ):
                # Per-sample loss; vmapped over the batch below.
                alpha_with_grad = self.entropy_coefficient.apply(entropy_coefficient_params)
                alpha = stop_gradient(alpha_with_grad)

                # Policy loss
                dist = self.policy.apply(policy_params, state)
                current_action = dist.sample(seed=key1)
                current_log_prob = dist.log_prob(current_action)
                entropy = stop_gradient(-current_log_prob)

                q = self.critic.apply(stop_gradient(critic_params), state, current_action)
                min_q = jnp.min(q)

                policy_loss = alpha * current_log_prob - min_q

                # Entropy loss: pushes alpha so that entropy tracks target_entropy.
                entropy_loss = alpha_with_grad * (entropy - self.target_entropy)

                loss = policy_loss + entropy_loss

                metrics = {
                    "loss/policy_loss": policy_loss,
                    "loss/entropy_loss": entropy_loss,
                    "entropy/entropy": entropy,
                    "entropy/alpha": alpha,
                    "q_value/q_value": min_q,
                }

                return loss, (metrics)

            vmap_loss_fn = jax.vmap(loss_fn, in_axes=(None, None, None, 0, 0, 0, 0), out_axes=0)
            safe_mean = lambda x: jnp.mean(x) if x is not None else x
            mean_vmapped_loss_fn = lambda *a, **k: tree.map_structure(safe_mean, vmap_loss_fn(*a, **k))
            # Gradients w.r.t. policy params (0) and entropy-coefficient params (2) only.
            grad_loss_fn = jax.value_and_grad(mean_vmapped_loss_fn, argnums=(0, 2), has_aux=True)

            keys = jax.random.split(key, (self.batch_size) + 1)
            key, keys1 = keys[0], keys[1:]

            (loss, (metrics)), (policy_gradients, entropy_gradients) = grad_loss_fn(
                policy_state.params, critic_state.params, entropy_coefficient_state.params,
                states, next_states, actions, keys1)

            policy_state = policy_state.apply_gradients(grads=policy_gradients)
            entropy_coefficient_state = entropy_coefficient_state.apply_gradients(grads=entropy_gradients)

            metrics["gradients/policy_grad_norm"] = optax.global_norm(policy_gradients)
            metrics["gradients/entropy_grad_norm"] = optax.global_norm(entropy_gradients)

            return policy_state, critic_state, entropy_coefficient_state, metrics, key

        #######################
        # IQ Learn Q function update
        #######################
        def get_v(critic_params, policy_params, ent_params, states, key):
            """Soft value V(s) = min_c Q_c(s, a) - alpha * log pi(a|s), with a ~ pi(.|s).

            Differentiable w.r.t. ``critic_params`` only. The sampled actions,
            their log-probs and alpha are constants, so the IQ-Learn value terms
            can train the critic without leaking into the policy.
            """
            alpha = stop_gradient(self.entropy_coefficient.apply(ent_params))
            dist = self.policy.apply(policy_params, states)
            actions = stop_gradient(dist.sample(seed=key))
            log_prob = stop_gradient(dist.log_prob(actions))[..., None]   # [B, 1]
            q_values = self.critic.apply(critic_params, states, actions)  # [C, B, 1]
            q_values = jnp.squeeze(q_values, axis=-1).min(axis=0)         # [B]
            q_values = q_values[:, None]                                   # [B, 1]
            return q_values - alpha * log_prob

        def regularizer_loss(absorbing, reward, gamma, reg_mult,
                             treat_absorbing_states=False):
            """Chi^2 regularizer reg_mult * r^2; absorbing rows are down-weighted by (1 - gamma) if treated."""
            if treat_absorbing_states:
                reg_absorbing = absorbing
            else:
                reg_absorbing = jnp.zeros_like(absorbing)
            chi2_loss = ((1.0 - reg_absorbing) * reg_mult * jnp.square(reward) +
                         reg_absorbing * (1.0 - gamma) * reg_mult * jnp.square(reward)).mean()
            return chi2_loss

        def gradient_penalty(critic_params, obs, act, num_rollout, key):
            """WGAN-GP-style penalty gp_lambda * E[(||grad_{s,a} Q|| - 1)^2].

            The penalty is evaluated on random interpolates between each rollout
            sample (first ``num_rollout`` rows) and its paired expert sample
            (remaining rows).
            """
            if self.gp_lambda <= 0.0:
                return 0.0
            obs_plcy, obs_exp = obs[:num_rollout], obs[num_rollout:]
            act_plcy, act_exp = act[:num_rollout], act[num_rollout:]
            alpha = jax.random.uniform(key, shape=(num_rollout, 1))
            while alpha.ndim < obs_exp.ndim:
                alpha = alpha[..., None]
            s_interp = alpha * obs_exp + (1.0 - alpha) * obs_plcy
            a_interp = alpha * act_exp + (1.0 - alpha) * act_plcy

            def q_single(s, a):
                q = self.critic.apply(critic_params, s[None, ...], a[None, ...])  # [C, 1, 1]
                q = jnp.squeeze(q, axis=-1).min(axis=0)                           # [1]
                return q[0]                                                        # scalar

            grad_s, grad_a = jax.vmap(
                jax.grad(lambda ss, aa: q_single(ss, aa), argnums=(0, 1))
            )(s_interp, a_interp)
            grad_s_flat = grad_s.reshape((grad_s.shape[0], -1))
            grad_a_flat = grad_a.reshape((grad_a.shape[0], -1))
            grad_norm = jnp.linalg.norm(
                jnp.concatenate([grad_s_flat, grad_a_flat], axis=-1), axis=-1
            )
            gp = self.gp_lambda * jnp.mean((grad_norm - 1.0) ** 2)
            return gp

        @jax.jit
        def iq_update(policy_state: TrainState, critic_state: RLTrainState,
                      entropy_coefficient_state: TrainState,
                      batch_states: jnp.ndarray, batch_actions: jnp.ndarray,
                      batch_next_states: jnp.ndarray, batch_terminations: jnp.ndarray,
                      expert_states: jnp.ndarray, expert_actions: jnp.ndarray,
                      expert_next_states: jnp.ndarray, expert_absorbing: jnp.ndarray,
                      key: jax.random.PRNGKey):
            """One IQ-Learn critic step plus a Polyak target update.

            The policy and entropy-coefficient states are returned unchanged.
            NOTE(review): assumes the expert dataset holds at least as many
            transitions as the flattened replay batch; otherwise the concatenated
            shapes below will not line up.
            """
            gamma = jnp.asarray(self.gamma, dtype=jnp.float32)

            rollout_states = batch_states.reshape((-1,) + self.os_shape)
            rollout_actions = batch_actions.reshape((-1,) + self.as_shape)
            rollout_next_states = batch_next_states.reshape((-1,) + self.os_shape)
            rollout_absorbing = batch_terminations.reshape(-1).astype(jnp.float32)
            num_rollout = rollout_states.shape[0]

            key, key_demo, key_v_next, key_v, key_gp = jax.random.split(key, 5)
            # Random expert subset with as many rows as the rollout batch.
            perm = jax.random.permutation(key_demo, expert_states.shape[0])
            idx = perm[:num_rollout]
            expert_s = expert_states[idx]
            expert_a = expert_actions[idx]
            expert_s_next = expert_next_states[idx]
            expert_abs = expert_absorbing[idx].astype(jnp.float32)

            obs = jnp.concatenate([rollout_states, expert_s], axis=0)           # [2N, S]
            act = jnp.concatenate([rollout_actions, expert_a], axis=0)         # [2N, A]
            next_obs = jnp.concatenate([rollout_next_states, expert_s_next], axis=0)
            absorbing = jnp.concatenate([rollout_absorbing, expert_abs], axis=0)
            absorbing = absorbing[:, None]                                      # [2N, 1]

            is_expert = jnp.concatenate(
                [jnp.zeros((num_rollout,), dtype=bool),
                 jnp.ones((num_rollout,), dtype=bool)],
                axis=0
            )
            is_expert_f = is_expert.astype(jnp.float32)                         # [2N]

            def loss_fn(critic_params):
                # δ(s,a) = Q(s,a) - γ(1 - d)V(s')  (IQ-Learn Eq. 9)
                q_values = self.critic.apply(critic_params, obs, act)          # [C, 2N, 1]
                q_values = jnp.squeeze(q_values, axis=-1).min(axis=0)         # [2N]
                q_flat = q_values
                q_values = q_values[:, None]                                   # [2N, 1]

                # `use_target_q` is a static config flag, so a Python branch suffices.
                target_params = critic_state.target_params if self.use_target_q else critic_params
                next_v = get_v(target_params, policy_state.params,
                               entropy_coefficient_state.params,
                               next_obs, key_v_next)                           # [2N, 1]
                if self.use_lsiq:
                    next_v = jnp.clip(next_v, self.Q_min, self.Q_max)
                y = (1.0 - absorbing) * gamma * next_v

                reward = q_values - y                                          # [2N, 1]
                reward_flat = reward.squeeze(-1)                               # [2N]

                # E_expert[δ]
                exp_reward_mean = self._masked_mean(reward_flat, is_expert_f)

                if self.use_lsiq:
                    # LS-IQ: regress expert Q-values onto Q_max. The mean runs over
                    # expert rows only; the rank-1 [2N] mask must not broadcast
                    # against the [2N, 1] Q-values.
                    loss_term1 = self._masked_mean(jnp.square(q_flat - self.Q_max), is_expert_f)
                else:
                    loss_term1 = -exp_reward_mean

                # E_ρ[V(s) - γV(s')]; differentiable through the critic via get_v.
                V = get_v(critic_params, policy_state.params,
                          entropy_coefficient_state.params,
                          obs, key_v)                                          # [2N, 1]
                value = (V - y).squeeze(-1)                                    # [2N]

                if self.v0_loss:
                    V_exp_mean = self._masked_mean(V.squeeze(-1), is_expert_f)
                    loss_term2 = (1.0 - gamma) * V_exp_mean                    # v0 loss
                else:
                    loss_term2 = value.mean()                                  # value loss

                chi2_loss = regularizer_loss(
                    absorbing, reward, gamma, self.reg_mult,
                    treat_absorbing_states=True,
                )
                loss_gp = gradient_penalty(critic_params, obs, act, num_rollout, key_gp)
                loss_Q = loss_term1 + loss_term2 + chi2_loss + loss_gp

                exp_reward_var = self._masked_mean((reward_flat - exp_reward_mean) ** 2, is_expert_f)
                exp_reward_std = jnp.sqrt(exp_reward_var)

                metrics = {
                    "iq/loss_q": loss_Q,
                    "iq/loss_term1_expert": loss_term1,
                    "iq/loss_term2_value": loss_term2,
                    "iq/chi2_loss": chi2_loss,
                    "iq/gp_loss": loss_gp,
                    "iq/reward_mean": reward_flat.mean(),
                    "iq/reward_expert_mean": exp_reward_mean,
                    "iq/reward_expert_std": exp_reward_std,
                }
                return loss_Q, metrics

            (loss_Q, metrics), critic_grads = jax.value_and_grad(
                loss_fn, has_aux=True
            )(critic_state.params)

            new_critic_state = critic_state.apply_gradients(grads=critic_grads)
            new_target_params = optax.incremental_update(
                new_critic_state.params, new_critic_state.target_params, self.tau
            )
            new_critic_state = new_critic_state.replace(target_params=new_target_params)

            metrics = {k: jnp.mean(v) for k, v in metrics.items()}
            metrics["gradients/q_grad_norm"] = optax.global_norm(critic_grads)
            return policy_state, new_critic_state, entropy_coefficient_state, metrics, key

        @jax.jit
        def get_deterministic_action(policy_state: TrainState, state: np.ndarray):
            dist = self.policy.apply(policy_state.params, state)
            action = dist.mode()
            return self.get_processed_action(action)

        self.set_train_mode()

        demonstrations = prepare_expert_data(self.data_path, self.subsampling_cutoff)
        self.expert_states = demonstrations["states"]
        self.expert_actions = demonstrations["actions"]
        self.expert_next_states = demonstrations["next_states"]
        self.expert_absorbing = demonstrations["absorbing"].flatten()

        # Transfer the (static) expert dataset to the device once, not per update.
        expert_states_dev = jnp.asarray(self.expert_states)
        expert_actions_dev = jnp.asarray(self.expert_actions)
        expert_next_states_dev = jnp.asarray(self.expert_next_states)
        expert_absorbing_dev = jnp.asarray(self.expert_absorbing)

        replay_buffer = ReplayBuffer(int(self.buffer_size), self.nr_envs, self.env.single_observation_space.shape, self.env.single_action_space.shape, self.rng)

        saving_return_buffer = deque(maxlen=100 * self.nr_envs)

        state, _ = self.env.reset()
        global_step = 0
        nr_episodes = 0
        time_metrics_collection = {}
        step_info_collection = {}
        optimization_metrics_collection = {}
        evaluation_metrics_collection = {}
        steps_metrics = {}
        while global_step < self.total_timesteps:
            start_time = time.time()

            # Acting
            dones_this_rollout = 0
            if global_step < self.learning_starts:
                processed_action = np.array([self.env.single_action_space.sample() for _ in range(self.nr_envs)])
                # Map env-range random actions back to the policy's [-1, 1] space.
                action = (processed_action - self.env_as_low) / (self.env_as_high - self.env_as_low) * 2.0 - 1.0
            else:
                action, self.key = get_action(self.policy_state, state, self.key)
                processed_action = self.get_processed_action(action)

            next_state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
            done = terminated | truncated
            actual_next_state = next_state.copy()
            for i, single_done in enumerate(done):
                if single_done:
                    actual_next_state[i] = np.array(self.env.get_final_observation_at_index(info, i))
                    saving_return_buffer.append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                    dones_this_rollout += 1
            for info_key, info_value in self.env.get_logging_info_dict(info).items():
                step_info_collection.setdefault(info_key, []).extend(info_value)

            replay_buffer.add(state, actual_next_state, action, reward, terminated)

            state = next_state
            global_step += self.nr_envs
            nr_episodes += dones_this_rollout

            acting_end_time = time.time()
            time_metrics_collection.setdefault("time/acting_time", []).append(acting_end_time - start_time)

            # What to do in this step after acting
            should_learning_start = global_step > self.learning_starts
            should_optimize = should_learning_start
            should_evaluate = global_step % self.evaluation_frequency == 0 and self.evaluation_frequency != -1
            should_try_to_save = should_learning_start and self.save_model and dones_this_rollout > 0
            should_log = global_step % self.logging_frequency == 0

            # Optimizing - Q-functions, policy and entropy coefficient
            if should_optimize:
                batch_states, batch_next_states, batch_actions, batch_rewards, batch_terminations = replay_buffer.sample(self.batch_size)

                # Stays empty if nr_q_updates_per_step == 0.
                iq_optimization_metrics = {}
                for _ in range(self.nr_q_updates_per_step):
                    self.policy_state, self.critic_state, self.entropy_coefficient_state, iq_optimization_metrics, self.key = iq_update(
                        self.policy_state,
                        self.critic_state,
                        self.entropy_coefficient_state,
                        jnp.array(batch_states),
                        jnp.array(batch_actions),
                        jnp.array(batch_next_states),
                        jnp.array(batch_terminations),
                        expert_states_dev,
                        expert_actions_dev,
                        expert_next_states_dev,
                        expert_absorbing_dev,
                        self.key,
                    )
                    self.nr_q_updates += 1

                self.policy_state, self.critic_state, self.entropy_coefficient_state, sac_optimization_metrics, self.key = sac_update(
                                self.policy_state, self.critic_state, self.entropy_coefficient_state, batch_states, batch_next_states,
                                batch_actions, self.key)

                optimization_metrics = iq_optimization_metrics | sac_optimization_metrics

                for metric_name, metric_value in optimization_metrics.items():
                    optimization_metrics_collection.setdefault(metric_name, []).append(metric_value)

            optimizing_end_time = time.time()
            time_metrics_collection.setdefault("time/optimizing_time", []).append(optimizing_end_time - acting_end_time)

            # Evaluating (runs on the training env, which is reset before and after)
            if should_evaluate:
                self.set_eval_mode()
                state, _ = self.env.reset()
                eval_nr_episodes = 0
                while True:
                    processed_action = get_deterministic_action(self.policy_state, state)
                    state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                    done = terminated | truncated
                    for i, single_done in enumerate(done):
                        if single_done:
                            eval_nr_episodes += 1
                            evaluation_metrics_collection.setdefault("eval/episode_return", []).append(self.env.get_final_info_value_at_index(info, "episode_return", i))
                            evaluation_metrics_collection.setdefault("eval/episode_length", []).append(self.env.get_final_info_value_at_index(info, "episode_length", i))
                            if eval_nr_episodes == self.evaluation_episodes:
                                break
                    if eval_nr_episodes == self.evaluation_episodes:
                        break
                state, _ = self.env.reset()
                self.set_train_mode()

            evaluating_end_time = time.time()
            time_metrics_collection.setdefault("time/evaluating_time", []).append(evaluating_end_time - optimizing_end_time)

            # Saving: keep the checkpoint with the best recent mean training return
            if should_try_to_save:
                mean_return = np.mean(saving_return_buffer)
                if mean_return > self.best_mean_return:
                    self.best_mean_return = mean_return
                    self.save()

            saving_end_time = time.time()
            time_metrics_collection.setdefault("time/saving_time", []).append(saving_end_time - evaluating_end_time)
            time_metrics_collection.setdefault("time/sps", []).append(self.nr_envs / (saving_end_time - start_time))

            # Logging
            if should_log:
                self.start_logging(global_step)

                steps_metrics["steps/nr_env_steps"] = global_step
                steps_metrics["steps/nr_episodes"] = nr_episodes

                rollout_info_metrics = {}
                env_info_metrics = {}
                if step_info_collection:
                    info_names = list(step_info_collection.keys())
                    for info_name in info_names:
                        metric_group = "rollout" if info_name in ["episode_return", "episode_length"] else "env_info"
                        metric_dict = rollout_info_metrics if metric_group == "rollout" else env_info_metrics
                        mean_value = np.mean(step_info_collection[info_name])
                        if mean_value == mean_value:  # False only for NaN
                            metric_dict[f"{metric_group}/{info_name}"] = mean_value

                time_metrics = {k: np.mean(v) for k, v in time_metrics_collection.items()}
                optimization_metrics = {k: np.mean(v) for k, v in optimization_metrics_collection.items()}
                evaluation_metrics = {k: np.mean(v) for k, v in evaluation_metrics_collection.items()}
                combined_metrics = {**rollout_info_metrics, **evaluation_metrics, **env_info_metrics, **steps_metrics, **time_metrics, **optimization_metrics}
                for metric_name, metric_value in combined_metrics.items():
                    self.log(f"{metric_name}", metric_value, global_step)

                time_metrics_collection = {}
                step_info_collection = {}
                optimization_metrics_collection = {}
                evaluation_metrics_collection = {}

                self.end_logging()

    def log(self, name, value, step):
        """Send one scalar to the writer and/or the console, per the tracking flags."""
        if self.track_tb:
            self.writer.add_scalar(name, value, step)
        if self.track_console:
            self.log_console(name, value)

    def log_console(self, name, value):
        """Print one table row "│ name │ value │" via the rl_x logger."""
        value = np.format_float_positional(value, trim="-")
        rlx_logger.info(f"│ {name.ljust(30)}│ {str(value).ljust(14)[:14]} │", flush=False)

    def start_logging(self, step):
        """Open a console table, or log just the step when console tracking is off."""
        if self.track_console:
            rlx_logger.info("┌" + "─" * 31 + "┬" + "─" * 16 + "┐", flush=False)
        else:
            rlx_logger.info(f"Step: {step}")

    def end_logging(self):
        """Close the console table opened by ``start_logging``."""
        if self.track_console:
            rlx_logger.info("└" + "─" * 31 + "┴" + "─" * 16 + "┘")

    def save(self):
        """Zip the train states and algorithm config to ``<save_path>/best.model.zip``.

        The archive is also registered with wandb when wandb tracking is enabled.
        """
        checkpoint = {
            "policy": self.policy_state,
            "critic": self.critic_state,
            "entropy_coefficient": self.entropy_coefficient_state,
        }
        tmp_dir = f"{self.save_path}/tmp"
        try:
            save_args = orbax_utils.save_args_from_target(checkpoint)
            self.best_model_checkpointer.save(tmp_dir, checkpoint, save_args=save_args)
            with open(f"{tmp_dir}/config_algorithm.json", "w") as f:
                json.dump(self.config.algorithm.to_dict(), f)
            # make_archive appends ".zip" and returns the real archive path.
            archive_path = shutil.make_archive(f"{self.save_path}/{self.best_model_file_name}", "zip", tmp_dir)
        finally:
            # Always clean up; a leftover tmp dir would make the next save fail.
            shutil.rmtree(tmp_dir, ignore_errors=True)

        if self.track_wandb:
            wandb.save(archive_path, base_path=self.save_path)

    @staticmethod
    def load(config, env, run_path, writer, explicitly_set_algorithm_params):
        """Build an ``IQ_SAC`` from the zip archive at ``config.runner.load_model``.

        Saved algorithm settings override ``config.algorithm`` except for keys
        listed in ``explicitly_set_algorithm_params`` (as ``"algorithm.<key>"``).
        The archive is unpacked to a sibling ``tmp`` directory, which is removed
        afterwards.
        """
        splitted_path = config.runner.load_model.split("/")
        checkpoint_dir = os.path.abspath("/".join(splitted_path[:-1]))
        checkpoint_file_name = splitted_path[-1]
        shutil.unpack_archive(f"{checkpoint_dir}/{checkpoint_file_name}", f"{checkpoint_dir}/tmp", "zip")
        checkpoint_dir = f"{checkpoint_dir}/tmp"

        try:
            with open(f"{checkpoint_dir}/config_algorithm.json", "r") as f:
                loaded_algorithm_config = json.load(f)
            for key, value in loaded_algorithm_config.items():
                if f"algorithm.{key}" not in explicitly_set_algorithm_params:
                    config.algorithm[key] = value
            # __init__ takes (config, env, eval_env, run_path, writer); evaluation
            # runs on the training env, so it doubles as eval_env.
            model = IQ_SAC(config, env, env, run_path, writer)

            target = {
                "policy": model.policy_state,
                "critic": model.critic_state,
                "entropy_coefficient": model.entropy_coefficient_state
            }
            restore_args = orbax_utils.restore_args_from_target(target)
            checkpointer = orbax.checkpoint.PyTreeCheckpointer()
            checkpoint = checkpointer.restore(checkpoint_dir, item=target, restore_args=restore_args)

            model.policy_state = checkpoint["policy"]
            model.critic_state = checkpoint["critic"]
            model.entropy_coefficient_state = checkpoint["entropy_coefficient"]
        finally:
            shutil.rmtree(checkpoint_dir, ignore_errors=True)

        return model

    def test(self, episodes):
        """Roll out ``episodes`` deterministic episodes and log each return.

        NOTE(review): ``while not done`` on the vectorized ``done`` array only
        works with a single environment. Confirm nr_envs == 1 when testing.
        """
        @jax.jit
        def get_action(policy_state: TrainState, state: np.ndarray):
            dist = self.policy.apply(policy_state.params, state)
            action = dist.mode()
            return self.get_processed_action(action)

        self.set_eval_mode()
        for i in range(episodes):
            done = False
            episode_return = 0
            state, _ = self.env.reset()
            while not done:
                processed_action = get_action(self.policy_state, state)
                state, reward, terminated, truncated, info = self.env.step(jax.device_get(processed_action))
                done = terminated | truncated
                episode_return += reward
            rlx_logger.info(f"Episode {i + 1} - Return: {episode_return}")

    def set_train_mode(self):
        """No-op hook; this algorithm has no train/eval mode switch."""
        ...

    def set_eval_mode(self):
        """No-op hook; this algorithm has no train/eval mode switch."""
        ...

    @staticmethod
    def general_properties():
        """Return the ``GeneralProperties`` class describing this algorithm."""
        return GeneralProperties
