# Code mostly taken from https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool_xla_jax.py
import gymnasium as gym
import numpy as np
import jax
import jax.numpy as jnp
import flax
from distrax import Normal, MultivariateNormalDiag
from flax import linen as nn
from functools import partial

from flax.training.train_state import TrainState
import optax
import optuna
import functools
import os
import random
from typing import Any
#from torch.utils.tensorboard import SummaryWriter
from PIL import Image, ImageDraw, ImageFont
from moviepy.editor import ImageSequenceClip
from matplotlib import pyplot as plt

import datetime

from utils import build_env, ActorTrainState, EpisodeStatistics, Storage, ObservationActionBuffer, convert_to_discrete_tree, plot_decision_tree, plot_decision_tree_soft, OBSERVATION_LABELS

from args import get_args
import configs_A2C
from sympol import SYMPOL_RL
from mlp import Critic_MLP, Actor_MLP, Actor_MLP_Continuous
from sdt import Critic_SDT, Actor_SDT


from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree
import distrax

import graphviz
import wandb
import pickle
import time

import jax
import jax.numpy as jnp
from jax import lax
from gymnasium.wrappers import FlattenObservation

#os.environ['MUJOCO_GL'] = 'egl'


# Fix weird OOM https://github.com/google/jax/discussions/6332#discussioncomment-1279991
# Disabling preallocation makes the XLA client allocate GPU memory on demand
# instead of reserving a large fraction of it up front.
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.1"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "0"
#os.environ["XLA_FLAGS"] = "--xla_dump_to=~/tmp/foo"

# Fix cuDNN non-determinism; https://github.com/google/jax/issues/4823#issuecomment-952835771
# (left disabled; enable both lines below when reproducible GPU runs are required)
#os.environ["TF_XLA_FLAGS"] = "--xla_gpu_autotune_level=2 --xla_gpu_deterministic_reductions"
#os.environ["TF_CUDNN_DETERMINISTIC"] = "1"

def train_agent(args, trial=None, queue=None):
    start_time = time.time()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_number)
    print('CUDA_VISIBLE_DEVICES', os.environ['CUDA_VISIBLE_DEVICES'])
    
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")

    if trial is not None:
        timestamp = 'TRIAL_NO' + str(trial.number) + '_' + timestamp
        
        if args.actor == 'mlp':         
            suggested_params = configs.suggest_config_mlp(trial, args.env_id)
        elif args.actor == 'sympol':         
            suggested_params = configs.suggest_config_sympol(trial, args.env_id)
        elif args.actor == 'sdt':         
            suggested_params = configs.suggest_config_sdt(trial, args.env_id)
        elif args.actor == 'd-sdt':         
            suggested_params = configs.suggest_config_dsdt(trial, args.env_id)
        elif args.actor == 'stateActionDT':         
            suggested_params = configs.suggest_config_stateActionDT(trial, args.env_id)            
        else:
            suggested_params = {}
        
        args.__dict__.update(suggested_params)    

        n_steps = args.n_steps
        
    elif not args.use_best_config:
        if False:
            if args.env_id == 'Hopper-v4':
                n_steps = 512
                args.n_envs = 2
            elif args.env_id == 'CartPole-v1':
                n_steps = 32
                args.n_envs = 8   
            elif args.env_id == 'Pendulum-v1':
                n_steps = 1024
                args.n_envs = 4
            elif args.env_id == 'BipedalWalker-v3':
                n_steps = 512
                args.n_envs = 16          
            elif args.env_id == 'LunarLander-v2':
                n_steps = 512
                args.n_envs = 8   
        else:
            n_steps = 512
            args.n_envs = 8 
    else:
        n_steps = args.n_steps
            
    if args.dynamic_buffer:
        n_steps = max(2, n_steps // 8)
        
    accumulate_gradients_every = args.accumulate_gradients_every
    accumulate_gradients_every_initial = accumulate_gradients_every
            
    initial_steps = n_steps
    # these parameters are defined dynamically
    batch_size = int(args.n_envs * n_steps)
    #minibatch_size = int(batch_size // args.n_minibatches)
    minibatch_size = args.minibatch_size
    while batch_size // minibatch_size < 2:
        minibatch_size = minibatch_size // 2
    n_iterations = args.total_steps // batch_size
    #eval_freq = max(args.eval_freq // batch_size, 1)

    trial_scores = []
    for random_trial_number in range(1, args.random_trials+1):
        
        if args.actor == 'sympol':
            model_identifier = '-'.join([str(args.depth), str(args.n_estimators), str(args.seed)])
        elif args.actor != 'mlp':
            model_identifier = '-'.join([str(args.depth),  str(args.seed)])            
        else:
            model_identifier = str(args.seed)
            
        run_name = '-'.join([args.run_name, args.actor, str(np.round(args.learning_rate_actor, 6)), model_identifier, timestamp])
            
        group_name = run_name
        run_name = run_name + '_' + str(random_trial_number)
        
        envs = build_env(args.env_id, n_env=args.n_envs, view_size=args.view_size)
        
    
        obs_dim = envs.single_observation_space.shape[-1]
            
        print('Observations:', obs_dim)
        if isinstance(envs.single_action_space, gym.spaces.Discrete):
            if any(substring in args.env_id for substring in ['Crossing', 'DistShift', 'Empty', 'LavaGap', 'FourRooms', 'Dynamic-Obstacles']):
                action_dim = 3
                action_indices = [0,1,2]
            elif any(substring in args.env_id for substring in ['MultiRoom', 'Unlock', 'GoToDoor', 'RedBlueDoors']):
                action_dim = 4
                if 'GoToDoor' in args.env_id:
                    action_indices = [0,1,2,6]
                else:
                    action_indices = [0,1,2,5]
            elif any(substring in args.env_id for substring in ['UnlockPickup', 'DoorKey']):
                action_dim = 5    
                action_indices = [0,1,2,3,5]
            #elif any(substring in args.env_id for substring in ['forex']):
            #    action_dim = 2        
            else:
                action_dim = envs.single_action_space.n
                action_indices = [i for i in range(action_dim)]
            print('Actions:', action_dim)
        elif isinstance(envs.single_action_space, gym.spaces.Box):
            action_dim = envs.single_action_space.shape[-1]
            action_indices = [i for i in range(action_dim)]
            print('Actions:', action_dim)
            #raise NotImplementedError


        
        if args.track:
            wandb_run = wandb.init(
                project=f"{args.exp_name}_{args.env_id}",
                group=group_name,
                tags=[args.run_name],
                #sync_tensorboard=True,
                config=vars(args),
                name=run_name,
                monitor_gym=True,
                save_code=False, 
            )
        env_seed = args.seed + (random_trial_number * 100)

        if True:
            seed_training = args.seed + (random_trial_number * 100)
        else:
            seed_training = args.seed
            
        key = jax.random.PRNGKey(seed_training)
        model_key = jax.random.PRNGKey(args.seed)
        model_key, actor_key, critic_key = jax.random.split(model_key, 3)
    
        # agent setup
        if args.critic == "mlp":
            critic = Critic_MLP(num_layers=args.num_layers, neurons_per_layer=args.neurons_per_layer)
        elif args.critic == "sdt":
            critic = Critic_SDT(depth=args.depth, temperature=args.temperature)#, temp=1)
        critic.apply = jax.jit(critic.apply)
        if args.actor == "mlp" or args.actor == "stateActionDT":
            if args.action_type == "discrete":
                actor = Actor_MLP(action_dim=action_dim, num_layers=args.num_layers, neurons_per_layer=args.neurons_per_layer)
            else:
                actor = Actor_MLP_Continuous(action_dim=action_dim, num_layers=args.num_layers, neurons_per_layer=args.neurons_per_layer)
            #args.learning_rate_actor = args.learning_rate_critic  # same lr for MLP's
            args.accumulate_gradients_every = 1  # do not accumulate gradients for MLP's
            learning_rate_actor = args.learning_rate_actor
            actor.apply = jax.jit(actor.apply)
        elif args.actor == "sympol":
            learning_rate_actor_weights = args.learning_rate_actor_weights
            learning_rate_actor_split_values = args.learning_rate_actor_split_values
            learning_rate_actor_split_idx_array = args.learning_rate_actor_split_idx_array
            learning_rate_actor_leaf_array = args.learning_rate_actor_leaf_array
            learning_rate_actor_log_std = args.learning_rate_actor_log_std

            
            actor = SYMPOL_RL(
                obs_dim=obs_dim,
                action_dim=action_dim,
                action_type=args.action_type,
                depth=args.depth,
                n_estimators=args.n_estimators,
            )
        elif args.actor == "sdt" or args.actor == "d-sdt":
            actor = Actor_SDT(action_dim=action_dim, depth=args.depth, temperature=args.temperature, action_type=args.action_type)#, temp=1)
            
            #args.learning_rate_actor = args.learning_rate_critic  # same lr for SDT's
            learning_rate_actor = args.learning_rate_actor
            args.accumulate_gradients_every = 1  # do not accumulate gradients for SDT's
            actor.apply = jax.jit(actor.apply)

        if args.adamW:
            critic_state = TrainState.create(
                apply_fn=None,
                params=critic.init(critic_key, jnp.array([envs.single_observation_space.sample()])),
                tx=optax.chain(
                    optax.clip_by_global_norm(args.max_grad_norm), optax.adamw(learning_rate=args.learning_rate_critic)
                ),
            )
        else:
            critic_state = TrainState.create(
                apply_fn=None,
                params=critic.init(critic_key, jnp.array([envs.single_observation_space.sample()])),
                tx=optax.chain(
                    optax.clip_by_global_norm(args.max_grad_norm), optax.adam(learning_rate=args.learning_rate_critic)
                ),
            )            

        if args.actor == "sympol":
            def map_nested_fn(fn):
                """Build a mapper that applies `fn(key, value)` to every leaf of a nested dict.

                Args:
                    fn: Callable invoked as ``fn(key, value)`` for each entry whose
                        value is not a ``dict``.

                Returns:
                    A function taking a (possibly nested) dict and returning a new
                    dict with the same keys: ``dict`` values are mapped recursively,
                    every other value is replaced by ``fn(key, value)``.
                """
                def map_fn(nested_dict):
                    mapped = {}
                    for name, entry in nested_dict.items():
                        if isinstance(entry, dict):
                            # Descend so the nesting structure is preserved.
                            mapped[name] = map_fn(entry)
                        else:
                            mapped[name] = fn(name, entry)
                    return mapped
                return map_fn
            if args.SWA:    
                from optax_swag import swag
                if args.adamW:
                    actor_state = ActorTrainState.create(
                        apply_fn=None,
                        params=actor.init(actor_key, jnp.array([envs.single_observation_space.sample()])),
                        tx=optax.chain(
                            optax.clip_by_global_norm(args.max_grad_norm),
                            optax.multi_transform(
                                {'estimator_weights': optax.inject_hyperparams(optax.adam)(learning_rate_actor_weights), 
                                 'split_values': optax.inject_hyperparams(optax.adam)(learning_rate_actor_split_values), 
                                 'split_idx_array': optax.inject_hyperparams(optax.adamw)(learning_rate_actor_split_idx_array), 
                                 'leaf_array': optax.inject_hyperparams(optax.adamw)(learning_rate_actor_leaf_array), 
                                 'log_std': optax.inject_hyperparams(optax.adamw)(learning_rate_actor_log_std),}, 
                                map_nested_fn(lambda k, _: k)),
                                swag(10, 2),
                        ),
                        grad_accum=jax.tree.map(
                            jnp.zeros_like, actor.init(actor_key, jnp.array([envs.single_observation_space.sample()]))
                        ),
                        indices=actor.init_indices(actor_key) if args.actor == "sympol" else None,
                    )
                else:
                    actor_state = ActorTrainState.create(
                        apply_fn=None,
                        params=actor.init(actor_key, jnp.array([envs.single_observation_space.sample()])),
                        tx=optax.chain(
                            optax.clip_by_global_norm(args.max_grad_norm),
                            optax.multi_transform(
                                {'estimator_weights': optax.inject_hyperparams(optax.adam)(learning_rate_actor_weights), 
                                 'split_values': optax.inject_hyperparams(optax.adam)(learning_rate_actor_split_values), 
                                 'split_idx_array': optax.inject_hyperparams(optax.adam)(learning_rate_actor_split_idx_array), 
                                 'leaf_array': optax.inject_hyperparams(optax.adam)(learning_rate_actor_leaf_array), 
                                 'log_std': optax.inject_hyperparams(optax.adam)(learning_rate_actor_log_std),}, 
                                map_nested_fn(lambda k, _: k)),
                                swag(10, 2),
                        ),
                        grad_accum=jax.tree.map(
                            jnp.zeros_like, actor.init(actor_key, jnp.array([envs.single_observation_space.sample()]))
                        ),
                        indices=actor.init_indices(actor_key) if args.actor == "sympol" else None,
                    )                    
            else:
                if args.adamW:
                    actor_state = ActorTrainState.create(
                        apply_fn=None,
                        params=actor.init(actor_key, jnp.array([envs.single_observation_space.sample()])),
                        tx=optax.chain(
                            optax.clip_by_global_norm(args.max_grad_norm),
                            optax.multi_transform(
                                {'estimator_weights': optax.inject_hyperparams(optax.adam)(learning_rate_actor_weights), 
                                 'split_values': optax.inject_hyperparams(optax.adam)(learning_rate_actor_split_values), 
                                 'split_idx_array': optax.inject_hyperparams(optax.adamw)(learning_rate_actor_split_idx_array), 
                                 'leaf_array': optax.inject_hyperparams(optax.adamw)(learning_rate_actor_leaf_array), 
                                 'log_std': optax.inject_hyperparams(optax.adamw)(learning_rate_actor_log_std),}, 
                                map_nested_fn(lambda k, _: k)),
                        ),
                        grad_accum=jax.tree.map(
                            jnp.zeros_like, actor.init(actor_key, jnp.array([envs.single_observation_space.sample()]))
                        ),
                        indices=actor.init_indices(actor_key) if args.actor == "sympol" else None,
                    )
                else:
                    actor_state = ActorTrainState.create(
                        apply_fn=None,
                        params=actor.init(actor_key, jnp.array([envs.single_observation_space.sample()])),
                        tx=optax.chain(
                            optax.clip_by_global_norm(args.max_grad_norm),
                            optax.multi_transform(
                                {'estimator_weights': optax.inject_hyperparams(optax.adam)(learning_rate_actor_weights), 
                                 'split_values': optax.inject_hyperparams(optax.adam)(learning_rate_actor_split_values), 
                                 'split_idx_array': optax.inject_hyperparams(optax.adam)(learning_rate_actor_split_idx_array), 
                                 'leaf_array': optax.inject_hyperparams(optax.adam)(learning_rate_actor_leaf_array), 
                                 'log_std': optax.inject_hyperparams(optax.adam)(learning_rate_actor_log_std),}, 
                                map_nested_fn(lambda k, _: k)),
                        ),
                        grad_accum=jax.tree.map(
                            jnp.zeros_like, actor.init(actor_key, jnp.array([envs.single_observation_space.sample()]))
                        ),
                        indices=actor.init_indices(actor_key) if args.actor == "sympol" else None,
                    ) 
                    
        else:
            if args.adamW:
                actor_state = ActorTrainState.create(
                    apply_fn=None,
                    params=actor.init(actor_key, jnp.array([envs.single_observation_space.sample()])),
                    tx=optax.chain(
                        optax.clip_by_global_norm(args.max_grad_norm),
                        optax.inject_hyperparams(optax.adamw)(learning_rate_actor),
                    ),
                    grad_accum=jax.tree.map(
                        jnp.zeros_like, actor.init(actor_key, jnp.array([envs.single_observation_space.sample()]))
                    ),
                    indices=None,
                )  
            else:
                actor_state = ActorTrainState.create(
                    apply_fn=None,
                    params=actor.init(actor_key, jnp.array([envs.single_observation_space.sample()])),
                    tx=optax.chain(
                        optax.clip_by_global_norm(args.max_grad_norm),
                        optax.inject_hyperparams(optax.adam)(learning_rate_actor),
                    ),
                    grad_accum=jax.tree.map(
                        jnp.zeros_like, actor.init(actor_key, jnp.array([envs.single_observation_space.sample()]))
                    ),
                    indices=None,
                )  

        
        lr_scheduler = optax.contrib.reduce_on_plateau(patience=3, factor=0.5)
        lr_scheduler_state = lr_scheduler.init(actor_state.params)
        #actor.apply = jax.jit(actor.apply)
        #critic.apply = jax.jit(critic.apply)
            
        episode_stats = EpisodeStatistics(
            episode_returns=jnp.zeros(args.n_envs, dtype=jnp.float32),
            episode_lengths=jnp.zeros(args.n_envs, dtype=jnp.int32),
            returned_episode_returns=jnp.zeros(args.n_envs, dtype=jnp.float32),
            returned_episode_lengths=jnp.zeros(args.n_envs, dtype=jnp.int32),
        )
            
        @jax.jit
        def get_action_and_value(
            actor_state: TrainState,
            critic_state: TrainState,
            next_obs: np.ndarray,
            next_done: np.ndarray,
            storage: Storage,
            step: int,
            key: jax.random.PRNGKey,
        ):
            """Sample an action for `next_obs` and record the step in `storage`.

            The policy distribution depends on `args.action_type`:
            ``"discrete"`` treats the actor output as categorical logits, anything
            else indexes the actor output as ``(loc, log_scale)`` of a diagonal
            Gaussian.

            Args:
                actor_state: Actor train state; `params` and `indices` are read.
                critic_state: Critic train state; `params` is read.
                next_obs: Observations the action is sampled for.
                next_done: Done flags stored alongside the observations.
                storage: Rollout buffer; row `step` of obs, dones, actions,
                    logprobs and values is overwritten.
                step: Row index inside the rollout buffer.
                key: PRNG key; split once, the sub-key is consumed for sampling.

            Returns:
                ``(storage, action, key)`` with the updated buffer, the sampled
                action and the advanced PRNG key.
            """
            actor_out = actor.apply(actor_state.params, next_obs, indices=actor_state.indices)
            if args.action_type == "discrete":
                # Actor output is used directly as categorical logits.
                policy = distrax.Categorical(logits=actor_out)
            else:
                # Second actor output is exponentiated to obtain the diagonal scale.
                policy = distrax.MultivariateNormalDiag(actor_out[0], jnp.exp(actor_out[1]))

            state_value = critic.apply(critic_state.params, next_obs)

            # Draw the action from the policy distribution with a fresh sub-key.
            key, sample_key = jax.random.split(key)
            action = policy.sample(seed=sample_key)
            logprob = policy.log_prob(action)

            storage = storage.replace(
                obs=storage.obs.at[step].set(next_obs),
                dones=storage.dones.at[step].set(next_done),
                actions=storage.actions.at[step].set(action),
                logprobs=storage.logprobs.at[step].set(logprob),
                values=storage.values.at[step].set(state_value.squeeze()),
            )
            return storage, action, key


        
        @jax.jit
        def get_action_and_value2(
            actor_state_params: flax.core.FrozenDict,
            critic_state_params: flax.core.FrozenDict,
            x: np.ndarray,
            action: np.ndarray,
        ):
            """Evaluate `action` under the current policy and value function.

            Args:
                actor_state_params: Actor parameters to evaluate.
                critic_state_params: Critic parameters to evaluate.
                x: Observations.
                action: Actions whose log-probability is requested.

            Returns:
                ``(logprob, entropy, value)`` where `value` is the squeezed critic
                output.
            """
            # NOTE(review): `indices` comes from the `actor_state` of the enclosing
            # scope, not from an argument; under jit it is presumably captured at
            # trace time - confirm the indices never change during training.
            actor_out = actor.apply(actor_state_params, x, indices=actor_state.indices)
            if args.action_type == "discrete":
                policy = distrax.Categorical(logits=actor_out)
            else:
                # Second actor output is exponentiated to obtain the diagonal scale.
                policy = distrax.MultivariateNormalDiag(actor_out[0], jnp.exp(actor_out[1]))

            value = critic.apply(critic_state_params, x).squeeze()
            logprob = policy.log_prob(action)
            entropy = policy.entropy()

            return logprob, entropy, value
         
        def compute_gae_once(carry, inp, gamma, gae_lambda):
            """Perform one step of the generalized advantage estimation recursion.

            Written as a scan body: the returned advantage is both the new carry
            and the per-step output.

            Args:
                carry: Advantage carried over from the previously processed step.
                inp: Tuple ``(nextdone, nextvalues, curvalues, reward)``.
                gamma: Discount factor.
                gae_lambda: GAE smoothing factor.

            Returns:
                ``(advantage, advantage)``.
            """
            next_done_flag, next_value, current_value, reward = inp
            # Zero where the next step is terminal, so nothing is bootstrapped
            # across an episode boundary.
            not_terminal = 1.0 - next_done_flag

            td_error = reward + gamma * next_value * not_terminal - current_value
            gae = td_error + gamma * gae_lambda * not_terminal * carry
            return gae, gae
    
        compute_gae_once = partial(compute_gae_once, gamma=args.gamma, gae_lambda=args.gae_lambda)
        
        @jax.jit
        def compute_gae(
            critic_state: TrainState,
            next_obs: np.ndarray,
            next_done: np.ndarray,
            storage: Storage,
        ):
            """Fill `storage.advantages` and `storage.returns` for a finished rollout.

            Args:
                critic_state: Critic train state used to value `next_obs`.
                next_obs: Observations following the last stored step.
                next_done: Done flags following the last stored step.
                storage: Rollout buffer; dones, values and rewards are read.

            Returns:
                `storage` with `advantages` replaced by the scanned GAE values and
                `returns` set to ``advantages + storage.values``.
            """
            bootstrap_value = critic.apply(critic_state.params, next_obs).squeeze()

            # Append the post-rollout done flag / value so step t can look at t + 1.
            dones_ext = jnp.concatenate([storage.dones, next_done[None, :]], axis=0)
            values_ext = jnp.concatenate([storage.values, bootstrap_value[None, :]], axis=0)

            # Per step: (next done, next value, current value, reward).
            scan_inputs = (dones_ext[1:], values_ext[1:], values_ext[:-1], storage.rewards)
            initial_advantage = jnp.zeros((args.n_envs,))

            # reverse=True walks the rollout from the last step back to the first.
            _, gae_per_step = jax.lax.scan(
                compute_gae_once, initial_advantage, scan_inputs, reverse=True
            )

            return storage.replace(
                advantages=gae_per_step,
                returns=gae_per_step + storage.values,
            )

        @jax.jit
        def ppo_loss_base(actor_state_params, critic_state_params, x, a, logp, mb_advantages, mb_returns):
            """Compute the combined actor-critic loss for one minibatch.

            Despite the name, the active policy term is the unclipped
            ``-advantage * ratio`` (A2C-style); the clipped PPO surrogate is kept
            behind a flag that is switched off.

            Args:
                actor_state_params: Actor parameters.
                critic_state_params: Critic parameters.
                x: Minibatch observations.
                a: Minibatch actions.
                logp: Log-probabilities recorded when the actions were sampled.
                mb_advantages: Minibatch advantages (normalized here when
                    `args.norm_adv` is set).
                mb_returns: Minibatch return targets for the critic.

            Returns:
                ``(loss, (pg_loss, v_loss, entropy_loss, approx_kl))`` where
                `approx_kl` is wrapped in `stop_gradient`.
            """
            newlogprob, entropy, newvalue = get_action_and_value2(actor_state_params, critic_state_params, x, a)
            log_ratio = newlogprob - logp
            prob_ratio = jnp.exp(log_ratio)
            approx_kl = ((prob_ratio - 1) - log_ratio).mean()

            if args.norm_adv:
                adv_mean = mb_advantages.mean()
                adv_std = mb_advantages.std()
                mb_advantages = (mb_advantages - adv_mean) / (adv_std + 1e-8)

            # Policy loss: flip the flag to use the PPO clipped surrogate instead
            # of the plain A2C term.
            use_clipped_surrogate = False
            surrogate = -mb_advantages * prob_ratio
            if use_clipped_surrogate:
                clipped_surrogate = -mb_advantages * jnp.clip(prob_ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = jnp.maximum(surrogate, clipped_surrogate).mean()
            else:
                pg_loss = surrogate.mean()

            # Value loss: half mean squared error against the return targets.
            squared_error = (newvalue - mb_returns) ** 2
            v_loss = 0.5 * squared_error.mean()

            entropy_loss = entropy.mean()
            loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
            aux = (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
            return loss, aux
        ppo_loss_base_grad_fn = jax.value_and_grad(ppo_loss_base, argnums=(0, 1), has_aux=True)
        
        @jax.jit
        def update_ppo(
            actor_state: TrainState,
            critic_state: TrainState,
            storage: Storage,
            key: jax.random.PRNGKey,
            accumulate_gradients_every: int,
        ):
            """Run one epoch of minibatch updates over a collected rollout.

            The rollout is flattened over its first two axes, shuffled, cut into
            minibatches of `minibatch_size` (taken from the enclosing scope; a
            trailing remainder is dropped) and scanned. The critic is updated on
            every minibatch; the actor gradient is additionally summed into
            `actor_state.grad_accum`.

            Args:
                actor_state: Actor train state (params, optimizer state, grad_accum).
                critic_state: Critic train state.
                storage: Rollout buffer with advantages and returns already filled.
                key: PRNG key; split once per epoch for the shuffle.
                accumulate_gradients_every: Divisor applied to the accumulated actor
                    gradient and period of the accumulated-update branch.

            Returns:
                ``(actor_state, critic_state, loss, pg_loss, v_loss, entropy_loss,
                approx_kl, key)``; the metrics are stacked over the single epoch
                and over the minibatches by the two scans.
            """
            def update_epoch(carry, unused_inp):
                actor_state, critic_state, key = carry
                key, subkey = jax.random.split(key)

                def flatten(x):
                    # Merge the two leading axes into one sample axis.
                    return x.reshape((-1,) + x.shape[2:])

                # taken from: https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py
                def convert_data(x: jnp.ndarray):
                    num_minibatches = int(np.floor(x.shape[0] / minibatch_size))
                    size = num_minibatches * minibatch_size
                    # The same subkey is used for every leaf, so all fields are
                    # permuted consistently; samples beyond `size` are dropped.
                    x = jax.random.permutation(subkey, x)[:size]
                    x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])
                    return x

                # `jax.tree_map` is a deprecated top-level alias; use the
                # `jax.tree_util.tree_map` spelling already used below.
                flatten_storage = jax.tree_util.tree_map(flatten, storage)
                shuffled_storage = jax.tree_util.tree_map(convert_data, flatten_storage)

                def update_minibatch(carry, minibatch):
                    actor_state, critic_state = carry
                    (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), (actor_grads, critic_grads) = ppo_loss_base_grad_fn(
                        actor_state.params,
                        critic_state.params,
                        minibatch.obs,
                        minibatch.actions,
                        minibatch.logprobs,
                        minibatch.advantages,
                        minibatch.returns,
                    )
                    critic_state = critic_state.apply_gradients(grads=critic_grads)
                    actor_grad_accum = jax.tree_util.tree_map(lambda x, y: x + y, actor_grads, actor_state.grad_accum)
                    # NOTE(review): the actor is stepped with the raw minibatch
                    # gradient here AND again with the averaged accumulated
                    # gradient in `update_fn` below - confirm that both updates
                    # are intended.
                    actor_state = actor_state.apply_gradients(grads=actor_grads)

                    def update_fn():
                        # Average the accumulated gradient, apply it and reset the accumulator.
                        grads = jax.tree_util.tree_map(lambda x: x / accumulate_gradients_every, actor_grad_accum)
                        new_state = actor_state.apply_gradients(
                            grads=grads,
                            grad_accum=jax.tree_util.tree_map(jnp.zeros_like, grads),
                        )
                        return new_state

                    actor_state = jax.lax.cond(
                        actor_state.step % accumulate_gradients_every == 0,
                        lambda _: update_fn(),
                        # Otherwise only store the running sum and advance the step counter.
                        lambda _: actor_state.replace(grad_accum=actor_grad_accum, step=actor_state.step + 1),
                        None,
                    )

                    return (actor_state, critic_state), (loss, pg_loss, v_loss, entropy_loss, approx_kl, actor_grad_accum)

                (actor_state, critic_state), (loss, pg_loss, v_loss, entropy_loss, approx_kl, actor_grad_accum) = jax.lax.scan(
                    update_minibatch, (actor_state, critic_state), shuffled_storage
                )
                return (actor_state, critic_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, actor_grad_accum)

            (actor_state, critic_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, actor_grad_accum) = jax.lax.scan(
                update_epoch, (actor_state, critic_state, key), (), length=1 # only update a single epoch (args.n_update_epochs = 1)
            )
            return actor_state, critic_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key

        
        
        def create_rollout(n_steps, envs):
            """Build a rollout function that collects `n_steps` steps from `envs`.

            Args:
                n_steps: Number of environment steps collected per call.
                envs: Vectorized environment; `envs.step(action)` is expected to
                    return ``(obs, reward, done, trunc, info)``.

            Returns:
                ``rollout_(actor_state, critic_state, episode_stats, next_obs,
                next_done, storage, key, global_step)``, which returns the same
                eight values after the rollout.
            """
            # Environment families whose policy outputs are indices into
            # `action_indices` rather than raw environment action ids.
            remapped_env_families = ['MultiRoom', 'Unlock', 'GoToDoor', 'UnlockPickup', 'DoorKey', 'RedBlueDoors']

            def rollout_(actor_state, critic_state, episode_stats, next_obs, next_done, storage, key, global_step):
                # The env id does not change within a rollout, so decide once.
                needs_action_remap = any(family in args.env_id for family in remapped_env_families)

                for step in range(n_steps):
                    global_step += args.n_envs
                    storage, action, key = get_action_and_value(
                        actor_state, critic_state, next_obs, next_done, storage, step, key
                    )

                    # TRY NOT TO MODIFY: execute the game and log data.
                    action = np.array(action)
                    if needs_action_remap:
                        action = np.array([action_indices[policy_action] for policy_action in action])
                    next_obs, reward, next_done, trunc, _ = envs.step(action)

                    running_return = episode_stats.episode_returns + reward
                    running_length = episode_stats.episode_lengths + 1
                    # An episode is over when it terminated or was truncated.
                    episode_over = next_done + trunc

                    episode_stats = episode_stats.replace(
                        # Running counters are reset to zero for finished episodes.
                        episode_returns=(running_return) * (1 - next_done) * (1 - trunc),
                        episode_lengths=(running_length) * (1 - next_done) * (1 - trunc),
                        # Reported statistics only change when an episode finishes.
                        returned_episode_returns=jnp.where(
                            episode_over,
                            running_return,
                            episode_stats.returned_episode_returns,
                        ),
                        returned_episode_lengths=jnp.where(
                            episode_over,
                            running_length,
                            episode_stats.returned_episode_lengths,
                        ),
                    )
                    storage = storage.replace(rewards=storage.rewards.at[step].set(reward))

                return actor_state, critic_state, episode_stats, next_obs, next_done, storage, key, global_step

            return rollout_
        
        global_step = 0
        next_obs, _ = envs.reset(seed=env_seed)
        
        next_done = np.zeros(args.n_envs).astype(bool)
 
        #hyperparameters = {key: value for key, value in vars(args).items()}
    
        # Save hyperparameters to wandb
        #wandb.config.update(hyperparameters)
        avg_score_list = []
        iteration = 1
        last_eval = 0
        n_steps_old = 0

        avg_episodic_return_list = []
        total_time_cleaned = 0
        
        # Main training loop: one rollout, one GAE pass and one update per
        # iteration, until `args.total_steps` environment steps were consumed.
        while global_step < args.total_steps:
            #for iteration in range(1, n_iterations + 1):            
            wandb_log = {}
            # ALGO Logic: Storage setup
            #increase_index = global_step // (args.total_steps//len(increase_factor_list))

            if args.dynamic_buffer:
                # Dynamic rollout length: training is split into 8 equally long
                # phases (measured in global steps) and the factor doubles from
                # one phase to the next, i.e. 1, 2, 4, ..., 128.
                #increase_index = global_step // (args.total_steps//sum(increase_factor_list))
                increase_factor = int(2**(np.ceil((((global_step+1)*8)/(1+args.total_steps)))-1)) # int(increase_factor_list_long[increase_index])
                # NOTE(review): same expression as `increase_factor`; not used
                # anywhere in the visible part of the loop.
                increase_factor_batch = int(2**(np.ceil((((global_step+1)*8)/(1+args.total_steps)))-1)) # int(increase_factor_list_long[increase_index])
                # NOTE(review): this inner check repeats the enclosing condition,
                # so the `else` branch below can never run.
                if args.dynamic_buffer:
                    n_steps = initial_steps * increase_factor 
                else:
                    n_steps = initial_steps                       
                batch_size = int(args.n_envs * n_steps)
                #n_iterations = args.total_steps // batch_size
                #eval_freq = max(args.eval_freq // batch_size, 1)
                # Evaluation window of the step count *before* this rollout.
                current_eval = global_step // args.eval_freq
                # Rebuild the rollout function whenever the rollout length
                # changed (including the very first iteration, as
                # `n_steps_old` starts at 0).
                if n_steps != n_steps_old:               
                    #compute_gae  = create_compute_gae(n_steps)
                    #update_ppo = create_update_ppo(batch_size, minibatch_size, accumulate_gradients_every)
                    rollout = create_rollout(n_steps, envs)
                    n_steps_old = n_steps
            else:
                # Fixed rollout length: build the rollout function once.
                # NOTE(review): `n_steps` and `batch_size` must already be
                # defined before the loop in this case (not visible here).
                if global_step == 0:
                    #compute_gae  = create_compute_gae(n_steps)
                    #update_ppo = create_update_ppo(batch_size, minibatch_size, accumulate_gradients_every)
                    rollout = create_rollout(n_steps, envs)
                current_eval = global_step // args.eval_freq   

            # Number of minibatches in one batch (a float, since np.ceil
            # returns a floating-point value).
            accumulate_gradients_every = np.ceil(batch_size / minibatch_size) # accumulate over complete batch!!! single update
            start_time_cleaned = time.time()
            
            # Fresh, zero-initialised rollout storage of shape
            # (n_steps, n_envs, ...) for this iteration.
            # NOTE(review): `actions` is allocated as int32 regardless of
            # `args.action_type` - confirm this is intended for continuous
            # action spaces.
            storage = Storage(
                obs=jnp.zeros((n_steps, args.n_envs) + envs.single_observation_space.shape),
                actions=jnp.zeros((n_steps, args.n_envs) + envs.single_action_space.shape, dtype=jnp.int32),
                logprobs=jnp.zeros((n_steps, args.n_envs)),
                dones=jnp.zeros((n_steps, args.n_envs)),
                values=jnp.zeros((n_steps, args.n_envs)),
                advantages=jnp.zeros((n_steps, args.n_envs)),
                returns=jnp.zeros((n_steps, args.n_envs)),
                rewards=jnp.zeros((n_steps, args.n_envs)),
            )
            # Collect experience, compute advantages/returns, update the networks.
            actor_state, critic_state, episode_stats, next_obs, next_done, storage, key, global_step = rollout(
                actor_state, critic_state, episode_stats, next_obs, next_done, storage, key, global_step
            )
            storage = compute_gae(critic_state, next_obs, next_done, storage)
            actor_state, critic_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key = update_ppo(
                actor_state,
                critic_state,
                storage,
                key,
                accumulate_gradients_every,
            )
            elapsed_time_cleaned = time.time() - start_time_cleaned
            total_time_cleaned += elapsed_time_cleaned
            
            # Mean of the per-environment returned episode returns.
            avg_episodic_return = np.mean(np.array(episode_stats.returned_episode_returns))
            avg_episodic_return_list.append(avg_episodic_return)
            #writer.add_scalar("charts/avg_train_episodic_return", avg_episodic_return, global_step)
            # Evaluate on the first iteration, whenever a new `eval_freq`
            # window was entered, or when one more batch would reach
            # `args.total_steps`. `global_step` here is the value returned by
            # the rollout above, while `current_eval` was computed before it.
            if iteration == 1 or current_eval > last_eval or global_step + batch_size >= args.total_steps:
           
                last_eval = current_eval
                # Render on every evaluation if requested, otherwise only when
                # one more batch would reach `args.total_steps`.
                render_now = True if args.render_each_eval else True if global_step + batch_size >= args.total_steps else False

                def fit_stateActionDT(actor_state, env_id, n_episodes, name_appendix, seed=1_000):
                    """Distil the current policy into scikit-learn decision tree(s).

                    Runs the deterministic policy (argmax of the logits for
                    discrete actions, mean of the Gaussian for continuous
                    actions) for `n_episodes` episodes, records every visited
                    (observation, action) pair and fits the tree(s) on exactly
                    those pairs.

                    The previous implementation wrote each step into the PPO
                    rollout `storage` instead of its own buffer, so the tree
                    was fit on training-rollout data rather than on the
                    episodes collected here. Collecting into plain lists also
                    removes the fixed-size buffer, which silently dropped
                    steps beyond `n_steps`, padded the data set with all-zero
                    rows and truncated continuous actions to int32.

                    Args:
                        actor_state: train state providing `params` and
                            `indices` for `actor.apply`.
                        env_id: environment id forwarded to `build_env`.
                        n_episodes: number of episodes to collect; episode `i`
                            is reset with `seed + i`.
                        name_appendix: currently unused; kept so existing
                            callers keep working.
                        seed: base seed for the environment resets.

                    Returns:
                        A fitted `DecisionTreeClassifier` for discrete
                        actions, a fitted `DecisionTreeRegressor` when
                        `action_dim == 1`, otherwise a list with one fitted
                        `DecisionTreeRegressor` per action dimension.

                    Raises:
                        ValueError: if no transition was collected (for
                            example `n_episodes == 0`).
                    """
                    # Loop-invariant lookups, hoisted out of the step loop.
                    actor_params = actor_state.params
                    is_discrete = args.action_type == 'discrete'
                    remap_actions = any(substring in args.env_id for substring in ['MultiRoom', 'Unlock', 'GoToDoor', 'UnlockPickup', 'DoorKey', 'RedBlueDoors'])

                    collected_obs = []
                    collected_actions = []

                    for episode_index in range(n_episodes):
                        temp_env = build_env(env_id, n_env=1, view_size=args.view_size)
                        try:
                            done, trunc = False, False
                            obs, info = temp_env.reset(seed=seed + episode_index)
                            while not done and not trunc:
                                if is_discrete:
                                    action_logits = actor.apply(actor_params, np.array([obs]), indices=actor_state.indices)
                                    action = jnp.argmax(action_logits, axis=1)
                                    action = jnp.squeeze(action, axis=0)
                                else:
                                    result = actor.apply(actor_params, np.array([obs]), indices=actor_state.indices)
                                    action_distribution = distrax.MultivariateNormalDiag(result[0], jnp.exp(result[1]))
                                    action = action_distribution.mean()
                                    action = jnp.squeeze(action, axis=0)

                                action = np.array(action)

                                # Record the policy's own output (before any
                                # environment-specific remapping): the
                                # evaluation code applies `action_indices` to
                                # the tree's prediction itself.
                                collected_obs.append(np.array(obs))
                                collected_actions.append(action)

                                env_action = action_indices[action] if remap_actions else action
                                obs, rewards, done, trunc, info = temp_env.step(env_action)
                        finally:
                            # Release the environment even if a step fails.
                            temp_env.close()

                    if not collected_obs:
                        raise ValueError("fit_stateActionDT collected no transitions; n_episodes must be >= 1")

                    # One flattened observation per row, matching the
                    # `obs.reshape(1, -1)` used when the tree predicts.
                    X = np.stack(collected_obs).reshape(len(collected_obs), -1)
                    actions = np.array(collected_actions)

                    if is_discrete:
                        decision_tree = DecisionTreeClassifier(max_depth=args.depth)
                        decision_tree.fit(X, actions.reshape(-1))
                    elif action_dim == 1:
                        decision_tree = DecisionTreeRegressor(max_depth=args.depth)
                        decision_tree.fit(X, actions.reshape(-1))
                    else:
                        # One independent regressor per action dimension.
                        y = actions.reshape(-1, action_dim)
                        decision_tree = [DecisionTreeRegressor(max_depth=args.depth) for _ in range(action_dim)]
                        for i in range(action_dim):
                            decision_tree[i].fit(X, y[:, i])

                    return decision_tree
                
                def evaluate_agent(actor_state, env_id, n_episodes, name_appendix, seed=100, decision_tree=None):
                    video_folder = 'videos/wandb'
                    if not os.path.exists(video_folder):
                        os.makedirs(video_folder)
                    #temp_env = Monitor(temp_env, video_folder) #, force=True        
    
                    score = []
                    score_interpretable = []
                    node_count = 0
                    
                    for episode_index in range(n_episodes):
                        if args.actor == "stateActionDT":
                            #temp_env = build_env(env_id, n_env=1)
                            temp_env = build_env(env_id, n_env=1, view_size=args.view_size)               
                            video_path = os.path.join(video_folder, run_name + "-" + "-" + env_id  + str(episode_index) + ".mp4")
                            image_path = os.path.join(video_folder, run_name + "-" + "-" + env_id)#  + str(episode_index))
                            
                            done, trunc = False, False
                            obs, info = temp_env.reset(seed=seed + episode_index)#random.randint(0, 1000))
                            running_reward = 0
                            frames = []
                            dones = False
                            step_counter = 0
                            while not done and not trunc:
                                if args.render_env and render_now:
                                    frame = temp_env.render()
                            
                                    image = Image.fromarray(frame)
                                    draw = ImageDraw.Draw(image)
                                    text_step = f'Step: {step_counter}'
                                    font_size = frame.shape[0] // 20
                                    draw.text((font_size, font_size*0.5), text_step, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                    text_reward = f'Reward: {running_reward}'
                                    draw.text((font_size, font_size*2.0), text_reward, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                    
                                    frames.append(np.array(image))
                                    
                                flat_obs = obs.reshape(1, -1)
                                if args.action_type == 'discrete':
                                    action = decision_tree.predict(flat_obs)[0]
                                else:
                                    if action_dim == 1:
                                        action = decision_tree.predict(flat_obs)
                                    else:
                                        action_list = []
                                        for i in range(action_dim):
                                            action_by_tree = decision_tree[i].predict(flat_obs)[0]
                                            action_list.append(action_by_tree)
                                        action = np.array(action_list)
                                
                                if any(substring in args.env_id for substring in ['MultiRoom', 'Unlock', 'GoToDoor', 'UnlockPickup', 'DoorKey', 'RedBlueDoors']):
                                    action = action_indices[action]
                                next_obs, rewards, done, trunc, info = temp_env.step(action)
                                
                                running_reward += rewards
    
                                obs = next_obs
                                step_counter += 1
                                
                            score_interpretable.append(running_reward)
                            if (args.render_env and render_now):
                                if False:
                                    frame = temp_env.render()
                            
                                    image = Image.fromarray(frame)
                                    draw = ImageDraw.Draw(image)
                                    text_step = f'Step: {step_counter}'
                                    font_size = frame.shape[0] // 20
                                    draw.text((font_size, font_size*0.5), text_step, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                    text_reward = f'Reward: {running_reward}'
                                    draw.text((font_size, font_size*2.0), text_reward, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                    
                                    frames.append(np.array(image))
            
                                    numpy_clip = np.transpose(np.array(frames), (0, 3, 1, 2)) 
                                    fps = 5 if 'MiniGrid' in env_id else 25
                                    if args.track:
                                        wandb.log({"gameplay_" + name_appendix + '_trial' + str(episode_index): wandb.Video(numpy_clip, fps=fps, format="mp4")}, commit=False)
                                    #wandb_log["gameplay_" + name_appendix + '_trial' + str(episode_index)] = wandb.Video(numpy_clip, fps=fps, format="mp4")

                                if episode_index==0:
                                    if args.action_type == 'discrete' or action_dim == 1:
                                        # Plot the decision tree
                                        plt.figure(figsize=(20, 10))
                                        plot_tree(decision_tree, filled=True)
                                        plt.title("Decision Tree")
                                
                                        image_path = os.path.join(video_folder, run_name + "-" + "-" + args.env_id)
                                        plot_filename = image_path + "state_action_DT.png"
                                        plt.savefig(plot_filename)
                                        plt.close()
                                
                                        # Log the image to wandb
                                        if args.track:
                                            wandb.log({"state_action_DT": wandb.Image(plot_filename)})
                                        node_count += decision_tree.tree_.node_count
                                    else:
                                        node_count = 0
                                        for i in range(action_dim):                    
                                            # Plot the decision tree
                                            plt.figure(figsize=(20, 10))
                                            plot_tree(decision_tree[i], filled=True)
                                            plt.title("Decision Tree " + str(i))
                                    
                                            image_path = os.path.join(video_folder, run_name + "-" + "-" + args.env_id)
                                            plot_filename = image_path + "state_action_DT_reg" + str(i) + ".png"
                                            plt.savefig(plot_filename)
                                            plt.close()
                                            node_count += decision_tree[i].tree_.node_count
                                            # Log the image to wandb
                                            if args.track:
                                                wandb.log({"state_action_DT_" + str(i): wandb.Image(plot_filename)})
                                
                            temp_env.close()

                        
                        elif args.actor == "d-sdt":

                            #temp_env = build_env(env_id, n_env=1)
                            temp_env = build_env(env_id, n_env=1, view_size=args.view_size)               
                            video_path = os.path.join(video_folder, run_name + "-" + "-" + env_id  + str(episode_index) + ".mp4")
                            image_path = os.path.join(video_folder, run_name + "-" + "-" + env_id)#  + str(episode_index))
                            
                            done, trunc = False, False
                            obs, info = temp_env.reset(seed=seed + episode_index)#random.randint(0, 1000))
                            running_reward = 0
                            frames = []
                            dones = False
                            step_counter = 0
                            while not done and not trunc:
                                if args.render_env and render_now:
                                    frame = temp_env.render()
                            
                                    image = Image.fromarray(frame)
                                    draw = ImageDraw.Draw(image)
                                    text_step = f'Step: {step_counter}'
                                    font_size = frame.shape[0] // 20
                                    draw.text((font_size, font_size*0.5), text_step, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                    text_reward = f'Reward: {running_reward}'
                                    draw.text((font_size, font_size*2.0), text_reward, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                    
                                    frames.append(np.array(image))
                                    
                                actor_params = actor_state.params
                                actor_params_discrete = convert_to_discrete_tree(actor_params, args.action_type, temperature=args.temperature)
                                if args.action_type == 'discrete':
                                    action_logits = actor.apply(actor_params_discrete, np.array([obs]), max_path=True, indices=actor_state.indices)
                                    
                                    action = jnp.argmax(action_logits, axis=1)
                                    action = jnp.squeeze(action, axis=0) #jnp.squeeze(action, axis=0) if action.shape[0] == 1 else action #action[0]
                                else:
                                    result = actor.apply(actor_params_discrete, np.array([obs]), max_path=True, indices=actor_state.indices)

                                    action_distribution = distrax.MultivariateNormalDiag(result[0], jnp.exp(result[1]))
                                    action = action_distribution.mean()
                                    action = jnp.squeeze(action, axis=0)    
                                        
                                action = np.array(action)
    
                                if any(substring in args.env_id for substring in ['MultiRoom', 'Unlock', 'GoToDoor', 'UnlockPickup', 'DoorKey', 'RedBlueDoors']):
                                    action = action_indices[action]
                                next_obs, rewards, done, trunc, info = temp_env.step(action)
                                
                                running_reward += rewards
    
                                obs = next_obs
                                step_counter += 1
                                
                            score_interpretable.append(running_reward)
                            if (args.render_env and render_now):
                                if False:
                                    frame = temp_env.render()
                            
                                    image = Image.fromarray(frame)
                                    draw = ImageDraw.Draw(image)
                                    text_step = f'Step: {step_counter}'
                                    font_size = frame.shape[0] // 20
                                    draw.text((font_size, font_size*0.5), text_step, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                    text_reward = f'Reward: {running_reward}'
                                    draw.text((font_size, font_size*2.0), text_reward, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                    
                                    frames.append(np.array(image))
            
                                    numpy_clip = np.transpose(np.array(frames), (0, 3, 1, 2)) 
                                    fps = 5 if 'MiniGrid' in env_id else 25
                                    if args.track:
                                        wandb.log({"gameplay_" + name_appendix + '_trial' + str(episode_index): wandb.Video(numpy_clip, fps=fps, format="mp4")}, commit=False)
                                        #wandb_log["gameplay_" + name_appendix + '_trial' + str(episode_index)] = wandb.Video(numpy_clip, fps=fps, format="mp4")
                                if episode_index==0:
                                    split_values = actor_params_discrete['params']['SDT_0']['inner_nodes']['layers_0']['kernel'].T * jnp.expand_dims(actor_params_discrete['params']['SDT_0']['inner_nodes']['layers_0']['bias'],1)
                                    split_indices = actor_params_discrete['params']['SDT_0']['inner_nodes']['layers_0']['kernel'].T
                                    leaf_values = actor_params_discrete['params']['SDT_0']['leaf_nodes']['kernel']
                                    
                                    image_path, node_count = plot_decision_tree(
                                                                    split_values=split_values, 
                                                                    split_indices=split_indices, 
                                                                    leaf_values=leaf_values,
                                                                    features_by_estimator=[i for i in range(obs_dim)],
                                                                    image_path=image_path,
                                                                    observation_labels=None if args.env_id not in OBSERVATION_LABELS.keys() else OBSERVATION_LABELS[args.env_id],
                                                                    filename_appendix = 'D-SDT',
                                                                    env=temp_env, 
                                                                    #env_params=None, 
                                                                    prune=True,
                                                                    continuous = args.action_type != 'discrete'
                                                                   )
                                    image_path_plot = image_path + '.png'   
                                    if args.track:  
                                        wandb.log({"D-SDT_"+ name_appendix + '_trial' + str(episode_index): wandb.Image(image_path_plot)}, commit=False)
                                    
                                    image_path_complete = image_path + '_COMPLETE'
                                    image_path_complete, _ = plot_decision_tree(
                                                                    split_values=split_values, 
                                                                    split_indices=split_indices, 
                                                                    leaf_values=leaf_values,
                                                                    features_by_estimator=[i for i in range(obs_dim)],
                                                                    image_path=image_path_complete,
                                                                    observation_labels=None if args.env_id not in OBSERVATION_LABELS.keys() else OBSERVATION_LABELS[args.env_id],
                                                                    filename_appendix = 'D-SDT',
                                                                    env=temp_env, 
                                                                    #env_params=None,
                                                                    prune=False,
                                                                    continuous = args.action_type != 'discrete'
                                                                   )
                                    
                                    image_path_plot = image_path_complete + '.png'     
                                    if args.track:
                                        wandb.log({"D-SDT_COMPLETE"+ name_appendix + '_trial' + str(episode_index): wandb.Image(image_path_plot)}, commit=False)
                                
                            temp_env.close()                            

                        #temp_env = build_env(env_id, n_env=1)
                        temp_env = build_env(env_id, n_env=1, view_size=args.view_size)               
                        video_path = os.path.join(video_folder, run_name + "-" + "-" + env_id  + str(episode_index) + ".mp4")
                        image_path = os.path.join(video_folder, run_name + "-" + "-" + env_id)#  + str(episode_index))
                        
                        done, trunc = False, False
                        obs, info = temp_env.reset(seed=seed + episode_index)#random.randint(0, 1000))
                        running_reward = 0
                        frames = []
                        dones = False
                        step_counter = 0
                        while not done and not trunc:
                            if args.render_env and render_now:
                                frame = temp_env.render()
                        
                                image = Image.fromarray(frame)
                                draw = ImageDraw.Draw(image)
                                text_step = f'Step: {step_counter}'
                                font_size = frame.shape[0] // 20
                                draw.text((font_size, font_size*0.5), text_step, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                text_reward = f'Reward: {running_reward}'
                                draw.text((font_size, font_size*2.0), text_reward, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                
                                frames.append(np.array(image))
                                
                            actor_params = actor_state.params
                            if args.action_type == 'discrete':
                                action_logits = actor.apply(actor_params, np.array([obs]), indices=actor_state.indices)
                                
                                action = jnp.argmax(action_logits, axis=1)
                                action = jnp.squeeze(action, axis=0) #jnp.squeeze(action, axis=0) if action.shape[0] == 1 else action #action[0]
                            else:
                                result = actor.apply(actor_params, np.array([obs]), indices=actor_state.indices)
                                action_distribution = distrax.MultivariateNormalDiag(result[0], jnp.exp(result[1]))
                                action = action_distribution.mean()
                                action = jnp.squeeze(action, axis=0)    
                                
                            action = np.array(action)

                            if any(substring in args.env_id for substring in ['MultiRoom', 'Unlock', 'GoToDoor', 'UnlockPickup', 'DoorKey', 'RedBlueDoors']):
                                action = action_indices[action]
                            next_obs, rewards, done, trunc, info = temp_env.step(action)
                            
                            running_reward += rewards

                            obs = next_obs
                            step_counter += 1
                            
                        score.append(running_reward)
                        if (args.render_env and render_now):
                            if False:
                                frame = temp_env.render()
                        
                                image = Image.fromarray(frame)
                                draw = ImageDraw.Draw(image)
                                text_step = f'Step: {step_counter}'
                                font_size = frame.shape[0] // 20
                                draw.text((font_size, font_size*0.5), text_step, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                text_reward = f'Reward: {running_reward}'
                                draw.text((font_size, font_size*2.0), text_reward, (200, 200, 200), font=ImageFont.truetype("DejaVuSansMono-Bold.ttf", font_size))
                                
                                frames.append(np.array(image))
        
                                numpy_clip = np.transpose(np.array(frames), (0, 3, 1, 2)) 
                                fps = 5 if 'MiniGrid' in env_id else 25
                                if args.track:
                                    wandb.log({"gameplay_" + name_appendix + '_trial' + str(episode_index): wandb.Video(numpy_clip, fps=fps, format="mp4")}, commit=False)
                            #wandb_log["gameplay_" + name_appendix + '_trial' + str(episode_index)] = wandb.Video(numpy_clip, fps=fps, format="mp4")
                                
                            if args.n_estimators <= 5 and args.actor == "sympol" and episode_index==0:
                                for estimator_number in range(args.n_estimators):
                                    filename_appendix = '_' + str(estimator_number)
                                    image_path, node_count = plot_decision_tree(
                                                                    split_values=actor_params['split_values'][estimator_number], 
                                                                    split_indices=actor_params['split_idx_array'][estimator_number], 
                                                                    leaf_values=actor_params['leaf_array'][estimator_number],
                                                                    features_by_estimator=actor_state.indices['features_by_estimator'][estimator_number],
                                                                    image_path=image_path,
                                                                    observation_labels=None if args.env_id not in OBSERVATION_LABELS.keys() else OBSERVATION_LABELS[args.env_id],
                                                                    filename_appendix = filename_appendix,
                                                                    env=temp_env, 
                                                                    #env_params=None, 
                                                                    prune=True,
                                                                    continuous = args.action_type != 'discrete'
                                                                   )
                                image_path_plot = image_path + '.png'  
                                if args.track:   
                                    wandb.log({"DT_"+ name_appendix + '_trial' + str(episode_index) + '_estNumber' + str(estimator_number): wandb.Image(image_path_plot)}, commit=False)
                                for estimator_number in range(args.n_estimators):
                                    filename_appendix = '_' + str(estimator_number)                                       
                                    image_path_complete = image_path + '_COMPLETE'
                                    image_path_complete, _ = plot_decision_tree(
                                                                    split_values=actor_params['split_values'][estimator_number], 
                                                                    split_indices=actor_params['split_idx_array'][estimator_number], 
                                                                    leaf_values=actor_params['leaf_array'][estimator_number],
                                                                    features_by_estimator=actor_state.indices['features_by_estimator'][estimator_number],
                                                                    image_path=image_path_complete,
                                                                    observation_labels=None if args.env_id not in OBSERVATION_LABELS.keys() else OBSERVATION_LABELS[args.env_id],
                                                                    filename_appendix = filename_appendix,
                                                                    env=temp_env, 
                                                                    #env_params=None,
                                                                    prune=False,
                                                                    continuous = args.action_type != 'discrete'
                                                                   )
                                    
                                image_path_plot = image_path_complete + '.png'     
                                if args.track:
                                    wandb.log({"DT_COMPLETE"+ name_appendix + '_trial' + str(episode_index) + '_estNumber' + str(estimator_number): wandb.Image(image_path_plot)}, commit=False)
                            elif (args.actor == 'sdt' or args.actor == 'd-sdt') and episode_index==0:
    
                                split_values = actor_params['params']['SDT_0']['inner_nodes']['layers_0']['bias']
                                split_indices = actor_params['params']['SDT_0']['inner_nodes']['layers_0']['kernel'].T
                                leaf_values = actor_params['params']['SDT_0']['leaf_nodes']['kernel']
    
                                
                                image_path, node_count_sdt = plot_decision_tree_soft(
                                                                split_values=split_values, 
                                                                split_indices=split_indices, 
                                                                leaf_values=leaf_values,
                                                                image_path=image_path,
                                                                observation_labels=None if args.env_id not in OBSERVATION_LABELS.keys() else OBSERVATION_LABELS[args.env_id],
                                                                filename_appendix =  'SDT', 
                                                                temperature=args.temperature,
                                                               )
                                if args.actor == 'sdt':
                                    node_count = node_count_sdt                                
                                image_path_plot = image_path + '.png'    
                                if args.track: 
                                    wandb.log({"SDT_"+ name_appendix + '_trial' + str(episode_index): wandb.Image(image_path_plot)}, commit=False)
                            

                        
                        temp_env.close()
                        
                    #avg_score = np.mean(score)
                    #avg_score_interpretable = np.mean(score_interpretable)
                    #return avg_score.item(), avg_score_interpretable.item(), node_count
                    return score, score_interpretable, node_count
                    
                end_time = time.time()
                elapsed_time = end_time - start_time
                
                if args.actor == "stateActionDT":
                       
                    decision_tree = fit_stateActionDT(actor_state, 
                                                      args.env_id,
                                                      n_episodes=25,
                                                      name_appendix='',
                                                      seed=args.seed)                     
                else:
                    decision_tree = None
                
                score, score_interpretable, node_count = evaluate_agent(actor_state, args.env_id,
                                           n_episodes=args.n_eval_episodes,
                                           name_appendix='',
                                           decision_tree=decision_tree,
                                           seed=env_seed)

                avg_score = np.mean(score).item()
                avg_score_interpretable = np.mean(score_interpretable).item()
                std_score = np.std(score).item()
                std_score_interpretable = np.std(score_interpretable).item()                
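                # use the negative avg score, since reduce-on-plateau normally treats a non-decreasing loss as a plateau,
                # but here we have a plateau when the score is no longer increasing.
                # NOTE(review): the update call below passes `value=avg_score` un-negated -- confirm which sign is intended.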
                if args.reduce_lr:
                    _, lr_scheduler_state = lr_scheduler.update(
                        updates=actor_state.params, state=lr_scheduler_state, value=avg_score
                    )
                    # [-1] is the adamw optimizer, while [0] would be the gradient clipping of the tx.chain
                    if args.actor != "sympol":
                        actor_state.opt_state[1].hyperparams["learning_rate"] = learning_rate_actor * lr_scheduler_state.scale
                    else:
                        actor_state.opt_state[1][0]['estimator_weights'][0].hyperparams["learning_rate"] = learning_rate_actor_weights * lr_scheduler_state.scale
                        actor_state.opt_state[1][0]['split_values'][0].hyperparams["learning_rate"] = learning_rate_actor_split_values * lr_scheduler_state.scale
                        actor_state.opt_state[1][0]['split_idx_array'][0].hyperparams["learning_rate"] = learning_rate_actor_split_idx_array * lr_scheduler_state.scale
                        actor_state.opt_state[1][0]['leaf_array'][0].hyperparams["learning_rate"] = learning_rate_actor_leaf_array * lr_scheduler_state.scale
                        actor_state.opt_state[1][0]['log_std'][0].hyperparams["learning_rate"] = learning_rate_actor_log_std * lr_scheduler_state.scale

                end_time = time.time()
                elapsed_time = end_time - start_time
                start_time = end_time
                if args.actor == "stateActionDT" or args.actor == "d-sdt":
                    print(f"global_step={global_step}, avg_eval_episodic_return={avg_score}, avg_eval_episodic_return_discrete={avg_score_interpretable} (Elapsed time: {elapsed_time} seconds)")
                    wandb_log['charts/avg_score'] = avg_score_interpretable
                    wandb_log['charts/avg_score_fully_complexity'] = avg_score
                    wandb_log['charts/std_score'] = std_score_interpretable
                    wandb_log['charts/std_score_fully_complexity'] = std_score
                    wandb_log['charts/score_interpretable_list'] = score_interpretable
                    wandb_log['charts/score_list'] = score
                    avg_score_list.append(avg_score_interpretable)                    
                else:
                    print(f"global_step={global_step}, avg_eval_episodic_return={avg_score} (Elapsed time: {elapsed_time} seconds)")
                    wandb_log['charts/avg_score'] = avg_score
                    wandb_log['charts/std_score'] = std_score
                    wandb_log['charts/score_list'] = score
                    
                    avg_score_list.append(avg_score)

                wandb_log['charts/node_count'] = node_count
                wandb_log['charts/total_time_cleaned'] = total_time_cleaned
                
                if global_step + batch_size >= args.total_steps: #TEST EVAL
                    test_seed = 123456
                    if args.actor == "stateActionDT":
                        decision_tree = fit_stateActionDT(actor_state, 
                                                          args.env_id,
                                                          n_episodes=25,
                                                          name_appendix='',
                                                          seed=args.seed)          
                    else:
                        decision_tree = None
                    
                    
                    score_test, score_interpretable_test, node_count_test = evaluate_agent(actor_state, args.env_id,
                                               n_episodes=args.n_eval_episodes,
                                               name_appendix='',
                                               decision_tree=decision_tree,
                                               seed=test_seed)
    
                    avg_score_test = np.mean(score_test).item()
                    avg_score_interpretable_test = np.mean(score_interpretable_test).item()
                    std_score_test = np.std(score_test).item()
                    std_score_interpretable_test = np.std(score_interpretable_test).item()                
                    # use the negative avg score, since reduce on plataeu normally considers non-decreasing losses as a plataeu,
                    # but we have a plataeu when the score is not increasing anymore
                    if args.actor == "stateActionDT" or args.actor == "d-sdt":
                        print(f"global_step={global_step}, avg_eval_episodic_return={avg_score_test}, avg_eval_episodic_return_discrete={avg_score_interpretable_test} (Elapsed time: {elapsed_time} seconds)")
                        wandb_log['charts/avg_score_test'] = avg_score_interpretable_test
                        wandb_log['charts/avg_score_fully_complexity_test'] = avg_score_test
                        wandb_log['charts/std_score_test'] = std_score_interpretable_test
                        wandb_log['charts/std_score_fully_complexity_test'] = std_score_test
                        wandb_log['charts/score_list_test'] = score_interpretable_test
                        wandb_log['charts/score_fully_complexity_list_test'] =  score_test 

                    else:
                        print(f"global_step={global_step}, avg_eval_episodic_return={avg_score_test} (Elapsed time: {elapsed_time} seconds)")
                        wandb_log['charts/avg_score_test'] = avg_score_test
                        wandb_log['charts/std_score_test'] = std_score
                        wandb_log['charts/score_list_test'] = score_test
                            
                    wandb_log['charts/node_count_test'] = node_count_test

                    
                
                try:
                    complexity_add = 1
                    while False:
                        #Evaluate next complexity level
                        string_list = args.env_id.split("-")
                        complexity_level_new = str(int(string_list[-2][-1])+complexity_add)
                        string_list[-2] = string_list[-2][:-1] + complexity_level_new
                        env_id_new = "-".join(string_list)
                        avg_score = evaluate_agent(actor_state, env_id_new, n_episodes=args.n_eval_episodes, name_appendix='complexity+' + str(complexity_add))
                        # [-1] is the adamw optimizer, while [0] would be the gradient clipping of the tx.chain

                        print(f"global_step={global_step}, complexity={complexity_level_new} avg_eval_episodic_return={avg_score}")
                        #writer.add_scalar("charts/avg_score_complexity" + complexity_level_new, avg_score, global_step)
                        wandb_log["charts/avg_score_complexity" + complexity_level_new] = avg_score
                        complexity_add += 1
                except:
                    pass
    
            if args.checkpoint:
                end_time = time.time()
                elapsed_time = end_time - start_time
                checkpoint_path = os.path.join(args.path, args.run_name)
                os.makedirs(checkpoint_path, exist_ok=True)
                ckpt = {'sympol': actor_state}
                orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
                save_args = orbax_utils.save_args_from_target(ckpt)
                orbax_checkpointer.save(checkpoint_path, ckpt, save_args=save_args)    
                end_time = time.time()
                elapsed_time = end_time - start_time
            # TRY NOT TO MODIFY: record rewards for plotting purposes
            #writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
            #writer.add_scalar(
            #    "charts/avg_episodic_length", np.mean(np.array(episode_stats.returned_episode_lengths)), global_step
            #)
            #writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
            #writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
            #writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
            #writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
            #writer.add_scalar("losses/loss", loss.item(), global_step)
            wandb_log['charts/global_step'] = global_step
            wandb_log['charts/avg_episodic_return'] = avg_episodic_return
            wandb_log['charts/avg_episodic_return_100'] = np.mean(avg_episodic_return_list[-100:])
            wandb_log['charts/avg_episodic_return_10'] = np.mean(avg_episodic_return_list[-10:])
            wandb_log['charts/avg_episodic_length'] = np.mean(np.array(episode_stats.returned_episode_lengths))
            try:
                wandb_log['losses/value_loss'] = np.mean(v_loss[-1])#.item()
                wandb_log['losses/policy_loss'] = np.mean(pg_loss[-1])#.item()
                wandb_log['losses/entropy'] = np.mean(entropy_loss[-1])#.item()
                wandb_log['losses/approx_kl'] = np.mean(approx_kl[-1])#.item()
                wandb_log['losses/loss'] = np.mean(loss[-1])#.item()
            except:
                wandb_log['losses/value_loss'] = v_loss#.item()
                wandb_log['losses/policy_loss'] = pg_loss#.item()
                wandb_log['losses/entropy'] = entropy_loss#.item()
                wandb_log['losses/approx_kl'] = approx_kl#.item()
                wandb_log['losses/loss'] = loss#.item()           
            if args.track:     
                wandb.log(wandb_log)   
        
            iteration = iteration + 1
        if args.track:
            wandb_run.finish()
        envs.close()
        
        #trial_scores.append(np.mean(avg_score_list[-5:]))
        trial_scores.append(avg_score_list[-1])
        
    if queue is None:
        return np.mean(trial_scores)
    else:
        queue.put(np.mean(trial_scores))  # Put the result in the queue

import multiprocessing

def multiprocessing_objective_fn(args, trial):
    """Run one `train_agent` call in a child process and return its result.

    `train_agent` is started with a fresh `multiprocessing.Queue`; when it is
    given a queue it puts the mean trial score on it instead of returning it.
    This wrapper waits for that value and hands it back to the caller (used
    as the Optuna objective in the `__main__` block).

    Args:
        args: Parsed run configuration, forwarded unchanged to `train_agent`.
        trial: The Optuna trial object, forwarded unchanged to `train_agent`.

    Returns:
        The single value the child process put on the queue.

    Raises:
        RuntimeError: If the child process exits without putting a result on
            the queue (e.g. it crashed), instead of blocking forever.
    """
    # Imported locally so the block stays self-contained; only the exception type is needed.
    from queue import Empty

    result_queue = multiprocessing.Queue()
    worker = multiprocessing.Process(target=train_agent, args=(args, trial, result_queue))
    worker.start()
    try:
        # Fetch the result BEFORE joining: joining a process that still has
        # buffered queue items can deadlock, and a plain blocking get() would
        # hang forever if the worker died without producing a result.
        while True:
            try:
                return result_queue.get(timeout=1.0)
            except Empty:
                if worker.is_alive():
                    continue
            # The worker has exited. Make one last attempt in case the result
            # was flushed to the pipe right before the process terminated.
            try:
                return result_queue.get(timeout=1.0)
            except Empty:
                raise RuntimeError(
                    f"train_agent worker exited with code {worker.exitcode} without returning a result"
                ) from None
    finally:
        worker.join()

if __name__ == "__main__":
    # Script entry point: parse the configuration, then either train once or
    # run an Optuna hyperparameter search over repeated training runs.
    from optuna.storages import RDBStorage
    from sqlalchemy import create_engine
    import socket

    args = get_args()
    print(args)

    if not args.optimize_config:
        # Single training run in the current process; the result is not collected.
        train_agent(args, trial=None, queue=None)
    else:
        # Trial information is persisted in a per-host SQLite database. Delete
        # that file to start fresh trials, e.g. after changing the range of
        # hyperparameters or adding/removing hyperparameters.
        storage_url = "sqlite:///hpopt_database_" + socket.gethostname() + ".db"
        study_name = '__'.join([args.exp_name, args.env_id, args.run_name, args.actor])

        study = optuna.create_study(
            direction="maximize",
            storage=storage_url,
            load_if_exists=True,
            study_name=study_name,
        )

        # Each trial runs through the multiprocessing wrapper rather than
        # calling train_agent directly.
        objective = functools.partial(multiprocessing_objective_fn, args)

        # wandb only works for n_jobs = 1! See README.md for more information.
        study.optimize(objective, n_trials=args.n_trials, n_jobs=1)
