import os
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from utils import load_config # Assuming you have this utility
import pickle
from functools import partial
import time
import gymnax
from wrappers import LogWrapper, FlattenObservationWrapper
from common import Transition, DiscTransitionData, plot_il_metrics
from minatar_networks import ActorCritic, Discriminator 
from utils import make_expert_transitions
from utils import compute_policy_expert_divergence_minatar as compute_policy_expert_divergence
import optuna
from optuna.samplers import TPESampler

def make_train(config, activation, sample_expert_transitions, get_rewards_fn):
    """Build a jittable adversarial imitation-learning training function.

    The returned ``train(rng)`` runs ``NUM_UPDATES`` iterations. Each iteration:
    collects ``NUM_STEPS`` x ``NUM_ENVS`` policy transitions; trains the
    discriminator against expert transitions for ``DISC_UPDATE_EPOCHS`` epochs;
    relabels the policy transitions with rewards derived from the
    discriminator logits; and runs ``UPDATE_EPOCHS`` epochs of PPO on them.
    After training, it rolls out the final policy for ``max_timesteps`` steps
    on ``EVAL_NUM_ENVS`` (default 10) freshly reset environments.

    Note: ``config`` is mutated in place (``NUM_UPDATES`` and
    ``MINIBATCH_SIZE`` are written into it).

    Args:
        config: Hyperparameter dict (ENV_NAME, NUM_ENVS, NUM_STEPS, LR, ...).
        activation: Activation function passed to both networks.
        sample_expert_transitions: Callable ``(batch_size, rng) -> (batch, rng)``
            whose batch exposes ``action``, ``unnorm_obs`` and
            ``unnorm_next_obs``.
        get_rewards_fn: Callable mapping discriminator logits to rewards.

    Returns:
        ``train(rng)``, which returns a dict with keys ``"runner_state"``,
        ``"metrics"`` and ``"eval_transitions"``.

    Raises:
        ValueError: If ``config['DISC_INP']`` is not one of 'sa', 'ss', 's'.
    """
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )

    # Upper bound for the random initial-timestep staggering and the length of
    # the evaluation rollout.
    max_timesteps = 2500 if config["ENV_NAME"] == "Freeway-MinAtar" else 1000

    # INIT ENV FOR MINATAR
    env, env_params = gymnax.make(config["ENV_NAME"])
    env = FlattenObservationWrapper(env)  # Flatten obs for MLP
    env = LogWrapper(env)  # Log returns and lengths

    # Get action and observation shapes
    num_actions = env.action_space(env_params).n
    obs_shape = env.observation_space(env_params).shape
    # Input shape for discriminator: depends on DISC_INP config
    if config['DISC_INP'] == 'sa':
        disc_input_shape = (np.prod(obs_shape) + num_actions,)
    elif config['DISC_INP'] == 'ss':
        disc_input_shape = (2 * np.prod(obs_shape),)
    elif config['DISC_INP'] == 's':
        disc_input_shape = (np.prod(obs_shape),)
    else:
        raise ValueError(f"Invalid DISC_INP: {config['DISC_INP']}")

    def build_disc_input(obs, action_onehot, next_obs):
        """Assemble the discriminator input for the configured ``DISC_INP``.

        'sa' -> concat(obs, one-hot action); 'ss' -> concat(obs, next_obs);
        's' -> obs. Shared by the CE loss, the gradient penalty and reward
        relabelling so all three feed the discriminator the same layout.
        Components not needed by the current mode are simply ignored.
        """
        if config['DISC_INP'] == 'sa':
            return jnp.concatenate([obs, action_onehot], axis=-1)
        if config['DISC_INP'] == 'ss':
            return jnp.concatenate([obs, next_obs], axis=-1)
        if config['DISC_INP'] == 's':
            return obs
        raise ValueError(f"Invalid DISC_INP: {config['DISC_INP']}")

    def linear_schedule(count):
        # Decays LR linearly to 0 over NUM_UPDATES outer updates; `count` is the
        # optimizer step count (one per minibatch per epoch).
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        # INIT ACTOR-CRITIC NETWORK
        network = ActorCritic(
            num_actions, activation=activation
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros((1, *obs_shape))  # Add batch dim for init
        network_params = network.init(_rng, init_x)['params']  # Get params dict

        if config["ANNEAL_LR"]:
            actor_critic_tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            actor_critic_tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        actor_critic_train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=actor_critic_tx,
        )

        # INIT DISCRIMINATOR
        disc_network = Discriminator(
            activation=activation,
            use_spectral_norm=config["USE_SPECTRAL_NORM"],
        )
        rng, _rng = jax.random.split(rng)
        init_disc_x = jnp.zeros(disc_input_shape)  # Use the calculated shape
        # Initialize with train=True to potentially create batch_stats
        disc_variables = disc_network.init(_rng, init_disc_x, train=True)
        disc_params = disc_variables['params']
        disc_batch_stats = disc_variables.get('batch_stats', {})  # Absent when no stateful layers

        disc_tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(learning_rate=float(config["DISC_LR"]), eps=1e-5),
        )
        disc_train_state = TrainState.create(
            apply_fn=disc_network.apply,
            params=disc_params,
            tx=disc_tx,
        )

        # INIT ENV using vmap for Gymnax
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

        def stagger_timesteps(rng, max_timesteps):
            return jax.random.randint(rng, (config["NUM_ENVS"],), 0, max_timesteps).astype(jnp.int32)

        # Randomize each env's starting timestep so episodes do not all end in
        # lock-step. Use the freshly split subkey (not the parent key, which is
        # split again below) to avoid reusing PRNG state.
        rng, _rng = jax.random.split(rng)
        env_state = env_state.replace(
            env_state=env_state.env_state.replace(time=stagger_timesteps(_rng, max_timesteps))
        )

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            actor_critic_train_state, disc_train_state, disc_batch_stats, env_state, last_obs, rng = runner_state

            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                actor_critic_train_state, env_state, last_obs, rng = runner_state

                # Since no normalization wrapper, last_obs is unnormalized
                last_obs_unnorm = last_obs

                # SELECT ACTION (Actor-Critic)
                rng, _rng = jax.random.split(rng)
                # last_obs already carries the env batch dim (from vmap)
                pi, value = network.apply({"params": actor_critic_train_state.params}, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV (Gymnax with vmap)
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
                    rng_step, env_state, action, env_params
                )

                # On terminal steps `obsv` is presumably already the post-reset
                # observation, so last_obs is used as a stand-in next-obs.
                next_obsv = jnp.where(done[..., None], last_obs, obsv)

                # Reward is replaced by discriminator later
                transition = Transition(
                    done, action, value, jnp.zeros_like(reward), log_prob, last_obs, last_obs_unnorm, next_obsv, info
                )
                runner_state = (actor_critic_train_state, env_state, obsv, rng)
                return runner_state, transition

            # Scan over env steps
            scan_runner_state = (actor_critic_train_state, env_state, last_obs, rng)
            scan_runner_state, traj_batch = jax.lax.scan(
                _env_step, scan_runner_state, None, config["NUM_STEPS"]
            )
            # Unpack runner state after scan
            actor_critic_train_state, env_state, last_obs, rng = scan_runner_state

            # --- DISCRIMINATOR UPDATE ---
            def _update_epoch_discriminator(disc_update_state, unused):
                disc_train_state, disc_batch_stats, traj_batch, rng = disc_update_state

                def _update_minibatch(disc_mini_update_state, disc_transitions):
                    disc_train_state, disc_batch_stats, rng = disc_mini_update_state

                    # --- Zero-centred gradient penalty on policy/expert interpolates ---
                    # Always computed; its contribution is scaled by GP_WEIGHT,
                    # independently of whether spectral norm is enabled.
                    def grad_pen(params, batch_stats, disc_transitions, rng):
                        rng, _rng = jax.random.split(rng)
                        alpha_shape = (disc_transitions.pi_action.shape[0],) + (1,) * (disc_transitions.pi_unnorm_obs.ndim - 1)
                        alpha = jax.random.uniform(_rng, alpha_shape)

                        def lerp(pi_x, expert_x):
                            # Per-sample interpolation between policy and expert data
                            return alpha * pi_x + (1 - alpha) * expert_x

                        # Actions are interpolated in one-hot space for 'sa'
                        interpolated_input = build_disc_input(
                            lerp(disc_transitions.pi_unnorm_obs, disc_transitions.expert_unnorm_obs),
                            lerp(
                                jax.nn.one_hot(disc_transitions.pi_action, num_actions),
                                jax.nn.one_hot(disc_transitions.expert_action, num_actions),
                            ),
                            lerp(disc_transitions.pi_unnorm_next_obs, disc_transitions.expert_unnorm_next_obs),
                        )

                        def disc_out(inputs):
                            # Need to handle potential batch stats update within grad
                            output, updates = disc_network.apply(
                                {'params': params, 'batch_stats': batch_stats},
                                inputs,
                                train=True,  # Important: Keep train=True for potential SN updates inside grad_pen
                                mutable=['batch_stats']
                            )
                            # We don't use the updated stats from here, just need the output for grad
                            return output.squeeze()

                        # We need the gradient w.r.t the *interpolated input*
                        grad_fn = jax.vmap(jax.grad(disc_out))
                        grads = grad_fn(interpolated_input)

                        grad_norms = jnp.linalg.norm(grads, axis=-1)
                        penalty = jnp.mean((grad_norms - 0.0) ** 2)
                        return penalty, rng

                    # --- Cross Entropy Loss ---
                    def cross_entropy_loss(params, batch_stats, disc_transitions, rng):
                        policy_input = build_disc_input(
                            disc_transitions.pi_unnorm_obs,
                            jax.nn.one_hot(disc_transitions.pi_action, num_actions),
                            disc_transitions.pi_unnorm_next_obs,
                        )
                        expert_input = build_disc_input(
                            disc_transitions.expert_unnorm_obs,
                            jax.nn.one_hot(disc_transitions.expert_action, num_actions),
                            disc_transitions.expert_unnorm_next_obs,
                        )

                        combined_input = jnp.concatenate([policy_input, expert_input], axis=0)

                        # Single forward pass potentially updating batch stats
                        mutable_collections = ['batch_stats'] if config["USE_SPECTRAL_NORM"] else []
                        variables = {'params': params, 'batch_stats': batch_stats}

                        all_logits, state_updates = disc_network.apply(
                            variables,
                            combined_input,
                            train=True,  # Enable updates for spectral norm / batch norm
                            mutable=mutable_collections
                        )

                        # Update batch_stats if they exist and were mutated
                        if config["USE_SPECTRAL_NORM"]:
                            batch_stats = state_updates['batch_stats']

                        # Split logits
                        pi_logits = all_logits[:policy_input.shape[0]]
                        expert_logits = all_logits[policy_input.shape[0]:]

                        # Labels: 0 for policy (fake), 1 for expert (real)
                        pi_labels = jnp.zeros_like(pi_logits)
                        expert_labels = jnp.ones_like(expert_logits)

                        # Calculate sigmoid cross-entropy loss
                        loss_pi = optax.sigmoid_binary_cross_entropy(pi_logits, pi_labels).mean()
                        loss_expert = optax.sigmoid_binary_cross_entropy(expert_logits, expert_labels).mean()
                        ce_loss = (loss_pi + loss_expert) / 2

                        return ce_loss, batch_stats, rng

                    # --- Combined Discriminator Loss Function ---
                    def _loss_fn(params, batch_stats, disc_transitions, rng):
                        ce_loss, updated_batch_stats, rng = cross_entropy_loss(params, batch_stats, disc_transitions, rng)
                        gp_loss, rng = grad_pen(params, updated_batch_stats, disc_transitions, rng)  # L2 penalty on gradient norm
                        loss = ce_loss + config["GP_WEIGHT"] * gp_loss

                        # Return updated batch stats from CE loss calculation
                        return loss, (ce_loss, gp_loss, updated_batch_stats, rng)

                    # Compute gradients and update discriminator state
                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    # Pass current batch_stats
                    (loss, (ce_loss, gp_loss, updated_batch_stats, rng)), grads = grad_fn(disc_train_state.params, disc_batch_stats, disc_transitions, rng)
                    disc_train_state = disc_train_state.apply_gradients(grads=grads)
                    # IMPORTANT: Use the updated batch_stats for the next minibatch/epoch
                    return (disc_train_state, updated_batch_stats, rng), (loss, ce_loss, gp_loss)

                # Generate minibatches for discriminator training
                batch_size = config["NUM_STEPS"] * config["NUM_ENVS"]
                assert batch_size == config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"], "Batch size mismatch"

                # Flatten (NUM_STEPS, NUM_ENVS, ...) -> (batch_size, ...)
                pi_batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), traj_batch
                )

                # Sample a fresh expert batch every epoch
                expert_batch, rng = sample_expert_transitions(batch_size, rng)

                # Expert obs must be stored flattened, like the wrapped policy obs
                assert pi_batch.unnorm_obs.shape[1:] == expert_batch.unnorm_obs.shape[1:], "Obs shapes must match"

                disc_transitions_batch = DiscTransitionData(
                    pi_action=pi_batch.action,
                    pi_unnorm_obs=pi_batch.unnorm_obs,  # Already flattened by wrapper
                    pi_unnorm_next_obs=pi_batch.unnorm_next_obs,  # Already flattened by wrapper
                    expert_action=expert_batch.action,
                    expert_unnorm_obs=expert_batch.unnorm_obs,  # Assumes expert obs is compatible (flattened)
                    expert_unnorm_next_obs=expert_batch.unnorm_next_obs,  # Assumes expert obs is compatible (flattened)
                )

                # Create minibatches
                rng, _rng = jax.random.split(rng)
                permutation = jax.random.permutation(_rng, batch_size)
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), disc_transitions_batch
                )
                disc_transitions_minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(x, [config["NUM_MINIBATCHES"], config["MINIBATCH_SIZE"]] + list(x.shape[1:])),
                    shuffled_batch,
                )

                # Scan over minibatches, threading batch_stats from the previous epoch
                disc_mini_update_state = (disc_train_state, disc_batch_stats, rng)
                disc_mini_update_state, losses = jax.lax.scan(
                    _update_minibatch, disc_mini_update_state, disc_transitions_minibatches
                )
                # Get the final state after minibatch updates
                disc_train_state, disc_batch_stats, rng = disc_mini_update_state

                disc_update_state = (disc_train_state, disc_batch_stats, traj_batch, rng)
                return disc_update_state, losses

            # Scan over discriminator update epochs
            disc_epoch_update_state = (disc_train_state, disc_batch_stats, traj_batch, rng)
            disc_epoch_update_state, disc_losses_per_epoch = jax.lax.scan(
                _update_epoch_discriminator, disc_epoch_update_state, None, config["DISC_UPDATE_EPOCHS"]
            )
            # Get final discriminator state and batch stats after all epochs
            disc_train_state, disc_batch_stats, _, rng = disc_epoch_update_state

            # --- RELABEL POLICY TRANSITIONS WITH DISCRIMINATOR REWARDS ---
            # Build the input from the same `unnorm_*` fields the discriminator
            # was trained on (Transition stores next obs as `unnorm_next_obs`).
            policy_input_for_reward = build_disc_input(
                traj_batch.unnorm_obs,
                jax.nn.one_hot(traj_batch.action, num_actions),
                traj_batch.unnorm_next_obs,
            )

            # Get logits from discriminator (train=False; returned stats are discarded)
            logits_pi, _ = disc_network.apply(
                {'params': disc_train_state.params, 'batch_stats': disc_batch_stats},
                policy_input_for_reward,
                train=False,  # Important: Don't update batch stats when just getting rewards
                mutable=['batch_stats']
            )
            # Calculate rewards using the specified reward function type
            calculated_rewards = get_rewards_fn(logits_pi)
            # Reshape rewards back to (num_steps, num_envs)
            rewards_reshaped = calculated_rewards.reshape(config["NUM_STEPS"], config["NUM_ENVS"])

            # Replace placeholder rewards in traj_batch
            traj_batch = traj_batch._replace(
                reward=jax.lax.stop_gradient(rewards_reshaped)
            )

            # --- CALCULATE ADVANTAGES (GAE) ---
            # Bootstrap value for the observation after the final collected step
            _, last_val = network.apply({"params": actor_critic_train_state.params}, last_obs)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    # Convert boolean done to float (0.0 or 1.0)
                    done_float = jnp.where(done, 1.0, 0.0)
                    delta = reward + config["GAMMA"] * next_value * (1.0 - done_float) - value
                    gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1.0 - done_float) * gae
                    return (gae, value), gae  # Return (new_carry, gae_for_this_step)

                # Scan backwards through transitions
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),  # Initial carry: (gae=0, next_value=last_val)
                    traj_batch,  # Transitions are scanned from last to first
                    reverse=True,
                    unroll=16,  # Optimization hint
                )
                # Advantages are calculated, targets are advantages + values
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # --- UPDATE ACTOR-CRITIC NETWORK (PPO) ---
            def _update_epoch_actor_critic(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # Rerun network
                        pi, value = network.apply({"params": params}, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # Value loss (clipped)
                        value_pred_clipped = traj_batch.value + jnp.clip(
                            value - traj_batch.value, -config["CLIP_EPS"], config["CLIP_EPS"]
                        )
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

                        # Policy loss (clipped PPO objective)
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)  # Normalize advantages per minibatch
                        loss_actor1 = ratio * gae
                        loss_actor2 = jnp.clip(
                            ratio, 1.0 - config["CLIP_EPS"], 1.0 + config["CLIP_EPS"]
                        ) * gae
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2).mean()

                        # Entropy bonus
                        entropy = pi.entropy().mean()

                        # Total loss
                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    # Compute gradients and update actor-critic state
                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    (loss, (vloss, aloss, eloss)), grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, (loss, vloss, aloss, eloss)

                # Prepare data for PPO update epochs
                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                permutation = jax.random.permutation(_rng, batch_size)

                # Flatten and shuffle data across environments and steps
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )

                # Create minibatches
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], config["MINIBATCH_SIZE"]] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )

                # Scan over minibatches
                train_state, ppo_losses = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                # Return average losses for the epoch (jax.tree_map is removed in
                # recent JAX; use the jax.tree_util spelling used elsewhere).
                return update_state, jax.tree_util.tree_map(lambda x: x.mean(), ppo_losses)

            # Scan over PPO update epochs
            ppo_update_state = (actor_critic_train_state, traj_batch, advantages, targets, rng)
            ppo_update_state, ppo_losses_per_epoch = jax.lax.scan(
                _update_epoch_actor_critic, ppo_update_state, None, config["UPDATE_EPOCHS"]
            )

            # Get final actor-critic state and rng
            actor_critic_train_state = ppo_update_state[0]
            rng = ppo_update_state[-1]

            # ppo_losses_per_epoch is (total_loss, value_loss, actor_loss, entropy),
            # each of shape (UPDATE_EPOCHS,); report the last epoch.
            last_epoch_ppo_total_loss = ppo_losses_per_epoch[0][-1]
            last_epoch_ppo_value_loss = ppo_losses_per_epoch[1][-1]
            last_epoch_ppo_actor_loss = ppo_losses_per_epoch[2][-1]
            last_epoch_ppo_entropy = ppo_losses_per_epoch[3][-1]

            # --- METRICS ---
            # Combine env info with calculated losses
            metric = traj_batch.info

            # disc_losses_per_epoch is (total_loss, ce_loss, gp_loss), each of shape
            # (DISC_UPDATE_EPOCHS, NUM_MINIBATCHES); average the last epoch.
            last_epoch_total_loss = disc_losses_per_epoch[0][-1].mean()
            last_epoch_ce_loss = disc_losses_per_epoch[1][-1].mean()
            last_epoch_gp_loss = disc_losses_per_epoch[2][-1].mean()

            # Add discriminator metrics
            metric['disc_total_loss'] = last_epoch_total_loss
            metric['disc_ce_loss'] = last_epoch_ce_loss
            metric['disc_gp_loss'] = last_epoch_gp_loss
            metric['disc_rewards'] = traj_batch.reward.mean()  # Mean of the rewards calculated by the discriminator
            metric['disc_logits'] = logits_pi

            # Add PPO metrics (averaged losses from the last PPO epoch)
            metric['ppo_total_loss'] = last_epoch_ppo_total_loss
            metric['value_loss'] = last_epoch_ppo_value_loss
            metric['actor_loss'] = last_epoch_ppo_actor_loss
            metric['entropy'] = last_epoch_ppo_entropy

            # Prepare runner state for the next iteration
            runner_state = (actor_critic_train_state, disc_train_state, disc_batch_stats, env_state, last_obs, rng)
            return runner_state, metric

        # --- Run the main training loop ---
        rng, _rng = jax.random.split(rng)
        # Initial runner state
        initial_runner_state = (actor_critic_train_state, disc_train_state, disc_batch_stats, env_state, obsv, _rng)

        # Scan over updates
        final_runner_state, metric_history = jax.lax.scan(
            _update_step, initial_runner_state, None, config["NUM_UPDATES"]
        )

        # Generate eval transitions with the final (stochastic) policy
        def _update_eval_step(runner_state, unused):
            actor_critic_train_state, env_state, last_obs, rng = runner_state

            # SELECT ACTION
            # Pass params wrapped in a dictionary as expected by Flax apply
            pi, value = network.apply({'params': actor_critic_train_state.params}, last_obs)
            rng, _rng = jax.random.split(rng)
            action = pi.sample(seed=_rng)
            log_prob = pi.log_prob(action)

            # STEP ENV
            rng, _rng = jax.random.split(rng)
            rng_step = jax.random.split(_rng, config.get("EVAL_NUM_ENVS", 10))
            obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
                rng_step, env_state, action, env_params
            )

            next_obsv = jnp.where(done[..., None], last_obs, obsv)

            transition = Transition(
                done=done,
                action=action,
                value=value,
                reward=reward,
                log_prob=log_prob,
                obs=last_obs,
                unnorm_obs=last_obs,
                unnorm_next_obs=next_obsv,
                info=info
            )

            runner_state = (actor_critic_train_state, env_state, obsv, rng)
            return runner_state, transition

        # Initialize eval environment
        rng, _rng = jax.random.split(final_runner_state[-1])
        reset_rng = jax.random.split(_rng, config.get("EVAL_NUM_ENVS", 10))
        eval_obsv, eval_env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

        # Create initial eval runner state
        eval_runner_state = (final_runner_state[0], eval_env_state, eval_obsv, rng)

        # Run evaluation steps
        _, eval_transitions = jax.lax.scan(
            _update_eval_step, eval_runner_state, None, length=max_timesteps
        )

        return {"runner_state": final_runner_state, "metrics": metric_history, "eval_transitions": eval_transitions}

    return train

"""
Reward function
logits: ratio of densities \log(\frac{\rho^{E}(s,a)}{\rho^{\pi}(s,a)})
"""
def get_rewards(logits, config):
    if config['REWARD_TYPE'] == 'gail':
        return jax.nn.softplus(logits)          
    elif config['REWARD_TYPE'] == 'airl':
        return logits
    elif config['REWARD_TYPE'] == 'fairl':
        return -logits*jnp.exp(logits)
    elif config['REWARD_TYPE'] == 'logd':
        return -jax.nn.softplus(-logits)
    elif config['REWARD_TYPE'] == 'meta-disc':
        sigmoid_logits = jax.nn.sigmoid(logits)           # Smooth gating, values in (0,1)
        tanh_mod = 0.5 * (jnp.tanh(logits) + 1.0)         # Scales tanh from [-1,1] to [0,1]
        reward = sigmoid_logits * tanh_mod                 # Product yields reward in [0,1]
        return reward
    elif config['REWARD_TYPE'] == 'sigmoid':
        return jax.nn.sigmoid(logits)
    elif config['REWARD_TYPE'] == 'tanh_plus_1':
        return (jnp.tanh(logits) + 1) * 0.5
    else:
        raise ValueError(f"Unknown reward type: {config['REWARD_TYPE']}")
    
def objective(trial, base_config):
    """
    Optuna objective function for hyperparameter tuning.

    Samples hyper-parameters from ``trial`` and writes them into a copy of
    ``base_config``. Each Optuna parameter name is the lowercase form of the
    config key it sets, e.g. ``'lr'`` -> ``config['LR']``. The function then
    trains ``NUM_SEEDS`` runs with ``make_train``: ``pmap`` spreads them over
    devices and ``vmap`` over the seeds on each device. Each run's evaluation
    rollout is scored with ``compute_policy_expert_divergence``.

    Args:
        trial: Optuna trial used to sample hyper-parameters.
        base_config: Base hyper-parameter dict. It is shallow-copied; the copy
            (not ``base_config``) receives the sampled values and the keys that
            ``make_train`` adds.

    Returns:
        The mean divergence across seeds (lower is better).

    Raises:
        ValueError: if ``NUM_SEEDS`` is not a positive multiple of
            ``jax.device_count()``.
    """
    # Define hyperparameter search space
    config = base_config.copy()

    # Learning rates
    config['LR'] = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    config['DISC_LR'] = trial.suggest_float('disc_lr', 1e-5, 1e-3, log=True)

    # PPO parameters
    config['CLIP_EPS'] = trial.suggest_float('clip_eps', 0.1, 0.3)
    config['VF_COEF'] = trial.suggest_float('vf_coef', 0.25, 1.0)
    config['ENT_COEF'] = trial.suggest_float('ent_coef', 0.01, 0.1)

    # Discriminator parameters
    config['GP_WEIGHT'] = trial.suggest_float('gp_weight', 0.01, 100.0, log=True)
    # config['USE_SPECTRAL_NORM'] = trial.suggest_categorical('use_spectral_norm', [True, False])

    # Training parameters
    # config['NUM_STEPS'] = trial.suggest_categorical('num_steps', [128, 256, 512])
    # config['NUM_ENVS'] = trial.suggest_categorical('num_envs', [4, 8, 16])
    config['UPDATE_EPOCHS'] = trial.suggest_int('update_epochs', 1, 8)
    config['DISC_UPDATE_EPOCHS'] = trial.suggest_int('disc_update_epochs', 1, 10)

    # GAE parameters
    config['GAE_LAMBDA'] = trial.suggest_float('gae_lambda', 0.9, 0.99)

    print(f"Trial {trial.number}: Testing config with LR={config['LR']:.2e}, DISC_LR={config['DISC_LR']:.2e}")

    # pmap needs a leading axis of exactly device_count. Validate up front.
    # jnp.reshape reports a size mismatch as TypeError, which the old
    # `except ValueError` guard never caught.
    num_seeds = config['NUM_SEEDS']
    device_count = jax.device_count()
    if num_seeds <= 0 or num_seeds % device_count != 0:
        raise ValueError(
            f"NUM_SEEDS ({num_seeds}) must be a positive multiple of the "
            f"device count ({device_count})."
        )
    num_batches = num_seeds // device_count

    rng = jax.random.PRNGKey(42)
    sample_expert_transitions = make_expert_transitions(config)
    sample_expert_transitions = jax.jit(sample_expert_transitions, static_argnums=(0,))
    # config is bound by closure, so it is static from jit's point of view.
    get_rewards_fn = jax.jit(partial(get_rewards, config=config))

    dummy_env, dummy_env_params = gymnax.make(config["ENV_NAME"])
    num_actions = dummy_env.action_space(dummy_env_params).n

    train = make_train(config, jax.nn.relu, sample_expert_transitions, get_rewards_fn)

    def single_rollout(rng):
        """Train one seed; return its results and the divergence of its eval rollout."""
        results = train(rng)
        _, divergence_rng = jax.random.split(rng)
        policy_states = results['eval_transitions'].unnorm_obs
        policy_actions = results['eval_transitions'].action
        if config['SUB_SAMPLE_RATE'] > 1:
            # Keep every SUB_SAMPLE_RATE-th entry along the leading axis, plus the last one.
            step = config['SUB_SAMPLE_RATE']
            policy_states = jnp.concatenate([policy_states[::step], policy_states[-1:]], axis=0)
            policy_actions = jnp.concatenate([policy_actions[::step], policy_actions[-1:]], axis=0)
        divergence = compute_policy_expert_divergence(policy_states, policy_actions, divergence_rng, sample_expert_transitions, num_actions)
        return results, divergence

    rng, _rng = jax.random.split(rng)
    rngs = jax.random.split(_rng, num_seeds)

    # Reshape rngs to (device, seeds_per_device, ...) for pmap(vmap(...)).
    batched_rngs = jnp.reshape(rngs, (device_count, num_batches) + rngs.shape[1:])
    print('Batched rngs shape:', batched_rngs.shape)

    pmap_train = jax.pmap(jax.vmap(single_rollout, in_axes=(0,)), in_axes=(0,))

    results, divergences = pmap_train(batched_rngs)

    # Merge the leading (device, batch) axes into a single seed axis.
    results = jax.tree_util.tree_map(
        lambda x: jnp.reshape(x, (num_seeds,) + x.shape[2:]),
        results
    )
    divergences = jnp.reshape(divergences, (num_seeds,))

    mean_divergence = float(divergences.mean())
    mean_return = float(results['metrics']["returned_episode_returns"].mean(axis=(-1, -2, 0))[-1])
    print(f"Trial {trial.number}: Mean divergence = {mean_divergence:.4f} Mean return = {mean_return:.4f}")

    return mean_divergence
        
if __name__ == '__main__':

    def main(config):
        """Train NUM_SEEDS imitation-learning runs for one config and save artifacts.

        Seeds run under ``pmap`` over devices and ``vmap`` over the seeds on
        each device. Each seed's evaluation rollout is scored against the expert
        with ``compute_policy_expert_divergence``. Plots, returns, divergences,
        subsampled discriminator logits, entropy and the config are written to
        ``../main_results/<ENV_NAME>/<DISC_INP>/<N_EXPERT_TRAJS>/<REWARD_TYPE>/``.

        Args:
            config: Hyper-parameter dict. ``make_train`` adds keys to it in place
                (e.g. NUM_UPDATES, MINIBATCH_SIZE), and that mutated dict is
                what gets pickled.

        Raises:
            ValueError: if NUM_SEEDS is not a positive multiple of the device count.
        """
        print('Training: ', config['ENV_NAME'])

        # pmap needs a leading axis of exactly device_count. Validate before any
        # expensive setup; a mismatched jnp.reshape raises TypeError, not ValueError.
        num_seeds = config['NUM_SEEDS']
        device_count = jax.device_count()
        if num_seeds <= 0 or num_seeds % device_count != 0:
            raise ValueError(
                f"NUM_SEEDS ({num_seeds}) must be a positive multiple of the "
                f"device count ({device_count})."
            )
        num_batches = num_seeds // device_count

        rng = jax.random.PRNGKey(42)
        sample_expert_transitions = make_expert_transitions(config)
        sample_expert_transitions = jax.jit(sample_expert_transitions, static_argnums=(0,))
        get_rewards_fn = jax.jit(partial(get_rewards, config=config))

        dummy_env, dummy_env_params = gymnax.make(config["ENV_NAME"])
        # Number of discrete actions, needed by the divergence computation.
        num_actions = dummy_env.action_space(dummy_env_params).n

        train = make_train(config, jax.nn.relu, sample_expert_transitions, get_rewards_fn)

        def single_rollout(rng):
            """Train one seed; return its results and the divergence of its eval rollout."""
            results = train(rng)
            _, divergence_rng = jax.random.split(rng)
            policy_states = results['eval_transitions'].unnorm_obs
            policy_actions = results['eval_transitions'].action
            if config['SUB_SAMPLE_RATE'] > 1:
                # Keep every SUB_SAMPLE_RATE-th entry along the leading axis, plus the last one.
                step = config['SUB_SAMPLE_RATE']
                policy_states = jnp.concatenate([policy_states[::step], policy_states[-1:]], axis=0)
                policy_actions = jnp.concatenate([policy_actions[::step], policy_actions[-1:]], axis=0)
            divergence = compute_policy_expert_divergence(policy_states, policy_actions, divergence_rng, sample_expert_transitions, num_actions)
            return results, divergence

        rng, _rng = jax.random.split(rng)
        rngs = jax.random.split(_rng, num_seeds)

        # Reshape rngs to (device, seeds_per_device, ...) for pmap(vmap(...)).
        batched_rngs = jnp.reshape(rngs, (device_count, num_batches) + rngs.shape[1:])
        print('Batched rngs shape:', batched_rngs.shape)

        pmap_train = jax.pmap(jax.vmap(single_rollout, in_axes=(0,)), in_axes=(0,))  # pmap over the device dimension
        print(f'Using Pmap with {device_count} devices.')

        results, divergences = pmap_train(batched_rngs)

        # Combine results by flattening the first two axes (device, batch).
        results = jax.tree_util.tree_map(
            lambda x: jnp.reshape(x, (num_seeds,) + x.shape[2:]),
            results
        )
        divergences = jnp.reshape(divergences, (num_seeds,))

        print(f"Divergence: {divergences.mean()} +- {divergences.std()}")

        save_dir = f"../main_results/{config['ENV_NAME']}/{config['DISC_INP']}/{config['N_EXPERT_TRAJS']}/{config['REWARD_TYPE']}/"
        os.makedirs(save_dir, exist_ok=True)

        plot_il_metrics(results, save_dir)

        # Returns averaged over the last two axes, one row per seed.
        returns = results['metrics']["returned_episode_returns"].mean(axis=(-1, -2))
        print('Returns shape:', returns.shape)
        jnp.save(f"{save_dir}/returns.npy", returns)

        print('Divergences shape:', divergences.shape)
        jnp.save(f"{save_dir}/divergences.npy", divergences)

        # Seed-averaged discriminator logits, subsampled to 50k evenly spaced points.
        disc_logits = results['metrics']['disc_logits']
        disc_logits = disc_logits.reshape(num_seeds, -1).mean(axis=0)
        # linspace includes its endpoint, so the upper bound must be the last
        # valid index. JAX silently clamps out-of-range gather indices instead
        # of raising.
        indices = jnp.linspace(0, disc_logits.shape[0] - 1, 50000).astype(jnp.int32)
        disc_logits = disc_logits[indices]
        jnp.save(f"{save_dir}/disc_logits.npy", disc_logits)

        entropy = results['metrics']['entropy']
        jnp.save(f"{save_dir}/entropy.npy", entropy)

        with open(f"{save_dir}/config.pkl", "wb") as f:
            pickle.dump(config, f)

    config = {
        'LR': 0.005,
        'NUM_ENVS': 64,
        'NUM_STEPS': 128,
        'TOTAL_TIMESTEPS': 10_000_000,
        'UPDATE_EPOCHS': 4,
        'NUM_MINIBATCHES': 8,
        'GAMMA': 0.99,
        'GAE_LAMBDA': 0.95,
        'CLIP_EPS': 0.2,
        'ENT_COEF': 0.01,
        'VF_COEF': 0.5,
        'MAX_GRAD_NORM': 0.5,
        'ANNEAL_LR': True,
        'NUM_SEEDS': 16,
        'DEBUG': False,
        'REWARD_TYPE': 'gail',
        'DISC_LR': 3e-4,
        'GP_WEIGHT': 0.1,
        'DISC_UPDATE_EPOCHS': 1,
        'USE_SPECTRAL_NORM': False,
        'SUB_SAMPLE_RATE': 20,
        'N_EXPERT_TRAJS': 10,
        'BACKEND': 'positional',   # not used
        'HP_TUNE': False,
    }

    envs = ['SpaceInvaders-MinAtar', 'Breakout-MinAtar', 'Asterix-MinAtar']
    reward_types = ['meta-disc', 'tanh_plus_1', 'sigmoid']
    disc_inp_types = ['sa']

    # Check if hyperparameter tuning is enabled
    if config.get('HP_TUNE', False):
        print("Starting hyperparameter tuning with Optuna...")

        for env in envs:
            for reward_type in reward_types:
                for disc_inp in disc_inp_types:
                    print('\n')
                    print('Hyperparameter tuning for:')
                    print(f'Environment: {env}')
                    print(f'Reward type: {reward_type}')
                    print(f'Disc input: {disc_inp}')
                    config['ENV_NAME'] = env
                    config['REWARD_TYPE'] = reward_type
                    config['DISC_INP'] = disc_inp

                    # Create study for hyperparameter optimization
                    study_name = f"hp_tune_{env}_{reward_type}_{disc_inp}"
                    study = optuna.create_study(
                        study_name=study_name,
                        direction='minimize',  # Minimize divergence
                        sampler=TPESampler(seed=42),
                        storage=None  # Use in-memory storage
                    )

                    # Run hyperparameter optimization
                    n_trials = config.get('HP_N_TRIALS', 200)
                    print(f"Running {n_trials} trials...")

                    study.optimize(
                        partial(objective, base_config=config),
                        n_trials=n_trials,
                    )

                    # Print best results
                    print(f"\nBest trial for {env}_{reward_type}_{disc_inp}:")
                    print(f"  Value (divergence): {study.best_value:.4f}")
                    print("  Best hyperparameters:")
                    for key, value in study.best_params.items():
                        print(f"    {key}: {value}")

                    # Save study results
                    hp_save_dir = f'./hp_tune_results/{env}/{reward_type}/{disc_inp}/'
                    os.makedirs(hp_save_dir, exist_ok=True)

                    # Save best hyperparameters. Optuna names them in lowercase
                    # ('lr', 'disc_lr', ...), while training reads the uppercase
                    # config keys they were sampled for, so map them back.
                    best_config = config.copy()
                    best_config.update(
                        {name.upper(): value for name, value in study.best_params.items()}
                    )
                    best_config['ENV_NAME'] = env
                    best_config['REWARD_TYPE'] = reward_type
                    best_config['DISC_INP'] = disc_inp
                    best_config['TOTAL_TIMESTEPS'] = 10_000_000  # Restore full training
                    # NOTE(review): the base config uses 16 seeds; this saves 8 — confirm intended.
                    best_config['NUM_SEEDS'] = 8

                    with open(f"{hp_save_dir}/best_config.pkl", "wb") as f:
                        pickle.dump(best_config, f)

                    # Save study object
                    with open(f"{hp_save_dir}/study.pkl", "wb") as f:
                        pickle.dump(study, f)

    else:
        # Original training loop without HP tuning
        for env in envs:
            for reward_type in reward_types:
                for disc_inp in disc_inp_types:
                    print('\n')
                    print(f'Environment: {env}')
                    print(f'Reward type: {reward_type}')
                    print(f'Disc input: {disc_inp}')
                    config['ENV_NAME'] = env
                    config['REWARD_TYPE'] = reward_type
                    config['DISC_INP'] = disc_inp

                    start = time.time()
                    print('Config:', config)
                    main(config)
                    print('Time taken:', time.time() - start)