import os
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from utils import load_config
from wrappers import (
    LogWrapper,
    ILBraxGymnaxWrapper,
    VecEnv,
    NormalizeVecObservation,
    ClipAction,
)
from functools import partial
import time
from utils import make_expert_transitions, compute_policy_expert_divergence
from brax_networks import ActorCritic, Discriminator
from common import Transition, DiscTransitionData, plot_il_metrics
import pickle
import optuna
from optuna.samplers import TPESampler

def make_train(config, activation, sample_expert_transitions, get_rewards_fn):
    """Build the training function for adversarial imitation learning with PPO.

    Every update of the returned ``train`` function
      1. rolls the current policy out for ``NUM_STEPS`` steps in ``NUM_ENVS``
         environments (environment rewards are discarded),
      2. trains the discriminator to separate policy transitions from expert
         transitions,
      3. relabels the rollout with ``get_rewards_fn(discriminator_logits)``,
      4. runs clipped-PPO epochs on the actor-critic.
    After the last update the final policy is rolled out in a separate,
    un-normalized evaluation environment.

    Side effect: ``config`` is mutated in place (``NUM_UPDATES`` and
    ``MINIBATCH_SIZE`` are derived from the other entries and stored on it).

    Args:
        config: Dict of hyperparameters. Keys read here: ``TOTAL_TIMESTEPS``,
            ``NUM_STEPS``, ``NUM_ENVS``, ``NUM_MINIBATCHES``, ``ENV_NAME``,
            ``BACKEND``, ``NORMALIZE_ENV``, ``LR``, ``ANNEAL_LR``,
            ``MAX_GRAD_NORM``, ``UPDATE_EPOCHS``, ``USE_SPECTRAL_NORM``,
            ``DISC_INP`` (``'sa'``, ``'ss'`` or ``'s'``), ``DISC_LR``,
            ``GP_WEIGHT``, ``DISC_UPDATE_EPOCHS``, ``GAMMA``, ``GAE_LAMBDA``,
            ``CLIP_EPS``, ``VF_COEF``, ``ENT_COEF``, ``EVAL_NUM_ENVS`` and,
            optionally, ``DEBUG``.
        activation: Forwarded as ``activation=`` to ``ActorCritic`` and
            ``Discriminator``.
        sample_expert_transitions: Callable ``(batch_size, rng)`` returning
            ``(expert_batch, rng)``; the batch must expose ``obs``, ``action``,
            ``unnorm_obs`` and ``unnorm_next_obs``.
        get_rewards_fn: Callable mapping discriminator logits to rewards.

    Returns:
        ``train(rng)``, which returns a dict with the keys ``"runner_state"``,
        ``"metrics"`` and ``"eval_transitions"``.

    Raises:
        ValueError: Raised by ``train`` when ``config['DISC_INP']`` is not one
            of ``'sa'``, ``'ss'`` or ``'s'``.
    """
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env, env_params = ILBraxGymnaxWrapper(config["ENV_NAME"], backend=config['BACKEND']), None
    env = LogWrapper(env)
    env = ClipAction(env)
    env = VecEnv(env)
    if config["NORMALIZE_ENV"]:
        env = NormalizeVecObservation(env)

    # The evaluation environment is never wrapped in NormalizeVecObservation;
    # the evaluation loop normalizes observations itself with the statistics
    # gathered during training.
    eval_env, eval_env_params = ILBraxGymnaxWrapper(config["ENV_NAME"], backend=config['BACKEND']), None
    eval_env = LogWrapper(eval_env)
    eval_env = ClipAction(eval_env)
    eval_env = VecEnv(eval_env)

    def linear_schedule(count):
        # `count` is the step counter handed to the schedule by the optimizer;
        # dividing by minibatches * epochs converts it to a PPO update index.
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def build_disc_input(obs, action, next_obs):
        # Single place that decides what the discriminator sees, so network
        # initialisation, both loss terms and reward relabelling always agree.
        disc_inp = config['DISC_INP']
        if disc_inp == 'sa':
            return jnp.concatenate([obs, action], axis=-1)
        if disc_inp == 'ss':
            return jnp.concatenate([obs, next_obs], axis=-1)
        if disc_inp == 's':
            return obs
        raise ValueError(f"Invalid input type: {config['DISC_INP']}")

    def unnormalize_obs(obs, env_state):
        # Only the NormalizeVecObservation state carries running mean/var.
        # Without that wrapper the observations are already raw, so they are
        # returned untouched instead of reading attributes that do not exist.
        if not config["NORMALIZE_ENV"]:
            return obs
        # NOTE(review): assumes the wrapper normalizes with sqrt(var) and no
        # epsilon -- confirm against NormalizeVecObservation.
        return obs * jnp.sqrt(env_state.var) + env_state.mean

    def train(rng):
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).shape[0], activation=activation
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT DISCRIMINATOR
        disc_network = Discriminator(activation=activation, use_spectral_norm=config["USE_SPECTRAL_NORM"])
        rng, _rng = jax.random.split(rng)
        obs_dim = env.observation_space(env_params).shape[0]
        action_dim = env.action_space(env_params).shape[0]
        # Raises ValueError for an unsupported DISC_INP instead of silently
        # initialising the discriminator with the policy's input shape.
        init_x = build_disc_input(
            jnp.zeros((obs_dim,)), jnp.zeros((action_dim,)), jnp.zeros((obs_dim,))
        )
        disc_variables = disc_network.init(_rng, init_x, train=True)
        disc_params = disc_variables['params']

        # only nonempty if using spectral normalization
        disc_batch_stats = disc_variables.get('batch_stats', {})

        disc_tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(config["DISC_LR"], eps=1e-5),
        )
        disc_train_state = TrainState.create(
            apply_fn=disc_network.apply,
            params=disc_params,
            tx=disc_tx,
        )

        def stagger_timesteps(rng):
            # One random step counter per environment in [0, 1000).
            # NOTE(review): 1000 presumably matches the episode length -- confirm.
            return jax.random.randint(rng, (config['NUM_ENVS'],), 0, 1000).astype(jnp.float32)

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params)

        # Overwrite the per-env step counters with staggered values. The
        # normalization wrapper adds one extra level of state nesting.
        rng, _rng = jax.random.split(rng)
        staggered_timesteps = stagger_timesteps(_rng)
        if config['NORMALIZE_ENV']:
            env_state.env_state.env_state.info['steps'] = staggered_timesteps
        else:
            env_state.env_state.info['steps'] = staggered_timesteps

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # GET UNNORMALIZED OBS (FOR DISCRIMINATOR)
                last_obs_unnorm = unnormalize_obs(last_obs, env_state)

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = env.step(
                    rng_step, env_state, action, env_params
                )

                # GET UNNORMALIZED NEXT OBS (FOR DISCRIMINATOR)
                # uses the statistics of the state returned by this step
                obsv_unnorm = unnormalize_obs(obsv, env_state)
                # where done use last unnorm obs for next obs
                obsv_unnorm = jnp.where(done[..., None], last_obs_unnorm, obsv_unnorm)

                transition = Transition(
                    done, action, value, jnp.zeros_like(reward), log_prob, last_obs, last_obs_unnorm, obsv_unnorm, info   # reward filled in by discriminator later
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            train_state, disc_train_state, disc_batch_stats, env_state, last_obs, rng = runner_state

            # data collection
            runner_state = (train_state, env_state, last_obs, rng)
            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # DISCRIMINATOR UPDATE
            def _update_epoch_discriminator(disc_update_state, unused):
                disc_train_state, disc_batch_stats, traj_batch, rng = disc_update_state

                def _update_minibatch(disc_mini_update_state, disc_transitions):
                    disc_train_state, disc_batch_stats, rng = disc_mini_update_state

                    def grad_pen(params, disc_transitions, rng):
                        # Penalise the squared norm of the discriminator's
                        # input gradient at random policy/expert interpolates.
                        rng, _rng = jax.random.split(rng)
                        alpha = jax.random.uniform(_rng, (disc_transitions.pi_action.shape[0], 1))
                        interpolates_obs = jax.tree_util.tree_map(
                            lambda x, y: alpha * x + (1 - alpha) * y,
                            disc_transitions.pi_unnorm_obs,
                            disc_transitions.expert_unnorm_obs,
                        )
                        interpolates_action = jax.tree_util.tree_map(
                            lambda x, y: alpha * x + (1 - alpha) * y,
                            disc_transitions.pi_action,
                            disc_transitions.expert_action,
                        )
                        interpolates_next_obs = jax.tree_util.tree_map(
                            lambda x, y: alpha * x + (1 - alpha) * y,
                            disc_transitions.pi_unnorm_next_obs,
                            disc_transitions.expert_unnorm_next_obs,
                        )
                        concatenated = build_disc_input(
                            interpolates_obs, interpolates_action, interpolates_next_obs
                        )

                        def disc_out(inputs):
                            output = disc_network.apply(
                                {'params': params, 'batch_stats': disc_batch_stats},
                                inputs,
                                train=True,
                                mutable=['batch_stats']
                            )
                            # mutable collections return (output, mutated_variables)
                            output, _ = output
                            return output.squeeze()

                        # per-sample gradient of the scalar logit w.r.t. its input
                        grad_fn = jax.vmap(jax.grad(disc_out))
                        grads = grad_fn(concatenated)

                        grad_norms = jnp.linalg.norm(grads, axis=-1)

                        # target gradient norm is 0
                        penalty = jnp.mean((grad_norms - 0.0) ** 2)
                        return penalty, rng

                    def cross_entropy_loss(params, disc_transitions, rng, disc_batch_stats):
                        # Concatenate all inputs first
                        policy_input = build_disc_input(
                            disc_transitions.pi_unnorm_obs,
                            disc_transitions.pi_action,
                            disc_transitions.pi_unnorm_next_obs,
                        )
                        expert_input = build_disc_input(
                            disc_transitions.expert_unnorm_obs,
                            disc_transitions.expert_action,
                            disc_transitions.expert_unnorm_next_obs,
                        )
                        combined_input = jnp.concatenate([policy_input, expert_input], axis=0)

                        # Single forward pass that also updates batch stats
                        if config["USE_SPECTRAL_NORM"]:
                            all_logits, new_batch_stats = disc_network.apply(
                                {'params': params, 'batch_stats': disc_batch_stats},
                                combined_input,
                                train=True,
                                mutable=['batch_stats']
                            )
                            disc_batch_stats = new_batch_stats['batch_stats']
                        else:
                            all_logits = disc_network.apply(
                                {'params': params, 'batch_stats': disc_batch_stats},
                                combined_input,
                                train=False
                            )

                        # Split logits back into policy and expert
                        pi_logits = all_logits[:policy_input.shape[0]]
                        expert_logits = all_logits[policy_input.shape[0]:]

                        # label 0 = policy sample, label 1 = expert sample
                        disc_logits = jnp.concatenate([pi_logits, expert_logits], axis=0)
                        disc_labels = jnp.concatenate([jnp.zeros((pi_logits.shape[0], 1)), jnp.ones((expert_logits.shape[0], 1))], axis=0)

                        # shuffle the data (kept so the rng stream is unchanged;
                        # the mean below does not depend on the order)
                        rng, _rng = jax.random.split(rng)
                        permutation = jax.random.permutation(_rng, disc_logits.shape[0])
                        disc_logits = jnp.take(disc_logits, permutation, axis=0)
                        disc_labels = jnp.take(disc_labels, permutation, axis=0)

                        ce_loss = jnp.mean(optax.sigmoid_binary_cross_entropy(disc_logits, disc_labels))
                        return ce_loss, disc_batch_stats, rng

                    def _loss_fn(params, disc_transitions, rng, disc_batch_stats):
                        ce_loss, disc_batch_stats, rng = cross_entropy_loss(params, disc_transitions, rng, disc_batch_stats)
                        # gradient penalty is only applied when spectral
                        # normalization is not used
                        if not config["USE_SPECTRAL_NORM"]:
                            gp_loss, rng = grad_pen(params, disc_transitions, rng)  # L2 penalty on gradient norm
                            loss = ce_loss + config["GP_WEIGHT"] * gp_loss
                        else:
                            loss = ce_loss
                            gp_loss = jnp.zeros_like(ce_loss)
                        return loss, (ce_loss, gp_loss, disc_batch_stats, rng)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    (loss, (ce_loss, gp_loss, disc_batch_stats, rng)), grads = grad_fn(disc_train_state.params, disc_transitions, rng, disc_batch_stats)

                    disc_train_state = disc_train_state.apply_gradients(grads=grads)
                    return (disc_train_state, disc_batch_stats, rng), (loss, ce_loss, gp_loss)

                # generate minibatches for discriminator training
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"

                # create batch of pi transitions (flatten the step and env axes)
                pi_batch = jax.tree_util.tree_map(lambda x: x.reshape((batch_size,) + x.shape[2:]), traj_batch)

                # create batch of expert transitions
                expert_batch, rng = sample_expert_transitions(batch_size, rng)

                assert pi_batch.obs.shape == expert_batch.obs.shape, "pi and expert batch shapes must match"

                disc_transitions_batch = DiscTransitionData(
                    pi_action=pi_batch.action,
                    pi_unnorm_obs=pi_batch.unnorm_obs,
                    pi_unnorm_next_obs=pi_batch.unnorm_next_obs,
                    expert_action=expert_batch.action,
                    expert_unnorm_obs=expert_batch.unnorm_obs,
                    expert_unnorm_next_obs=expert_batch.unnorm_next_obs,
                )

                disc_transitions_minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])),
                    disc_transitions_batch,
                )

                disc_mini_update_state = (disc_train_state, disc_batch_stats, rng)
                disc_mini_update_state, losses = jax.lax.scan(
                    _update_minibatch, disc_mini_update_state, disc_transitions_minibatches
                )

                disc_train_state, disc_batch_stats, rng = disc_mini_update_state

                disc_update_state = (disc_train_state, disc_batch_stats, traj_batch, rng)
                return disc_update_state, losses

            disc_update_state = (disc_train_state, disc_batch_stats, traj_batch, rng)
            disc_update_state, disc_losses = jax.lax.scan(
                _update_epoch_discriminator, disc_update_state, None, config["DISC_UPDATE_EPOCHS"]
            )
            disc_train_state, disc_batch_stats, _, rng = disc_update_state

            # relabel pi transitions with discriminator rewards
            disc_input = build_disc_input(
                traj_batch.unnorm_obs, traj_batch.action, traj_batch.unnorm_next_obs
            )
            disc_logits = jax.lax.stop_gradient(
                disc_network.apply(
                    {'params': disc_train_state.params, 'batch_stats': disc_batch_stats},
                    disc_input,
                    train=False,  # inference
                ).squeeze()
            )
            traj_batch = traj_batch._replace(
                reward=get_rewards_fn(disc_logits)
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                # scanned backwards in time, bootstrapping from last_val
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        # advantages are standardised per minibatch
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            # metrics start from the env info dict and are extended in place
            metric = traj_batch.info
            metric['disc_total_loss'] = disc_losses[0].mean()
            metric['disc_ce_loss'] = disc_losses[1].mean()
            metric['disc_gp_loss'] = disc_losses[2].mean()
            metric['disc_rewards'] = traj_batch.reward.mean()
            # loss_info is (total_loss, (value_loss, loss_actor, entropy))
            metric['entropy'] = loss_info[-1][2].mean()
            metric['actor_loss'] = loss_info[-1][1].mean()
            metric['value_loss'] = loss_info[-1][0].mean()
            metric['disc_logits'] = disc_logits
            rng = update_state[-1]
            if config.get("DEBUG"):

                def callback(info):
                    return_values = info["returned_episode_returns"][
                        info["returned_episode"]
                    ]
                    timesteps = (
                        info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
                    )
                    # both arrays are selected with the same mask, so they
                    # have the same length
                    for global_step, episodic_return in zip(timesteps, return_values):
                        print(
                            f"global step={global_step}, episodic return={episodic_return}"
                        )

                jax.debug.callback(callback, metric)

            runner_state = (train_state, disc_train_state, disc_batch_stats, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, disc_train_state, disc_batch_stats, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )

        def _update_eval_step(runner_state, unused):
            params, mean, var, env_state, last_obs, rng = runner_state

            # GET NORMALIZED OBS
            # NOTE(review): no epsilon under the square root -- confirm this
            # matches what NormalizeVecObservation does during training.
            last_obs_norm = (last_obs - mean) / jnp.sqrt(var)

            # SELECT ACTION
            pi, _ = network.apply(params, last_obs_norm)
            rng, _rng = jax.random.split(rng)
            action = pi.sample(seed=_rng)

            # STEP ENV
            rng, _rng = jax.random.split(rng)
            rng_step = jax.random.split(_rng, config["EVAL_NUM_ENVS"])
            obsv, env_state, reward, done, info = eval_env.step(
                rng_step, env_state, action, eval_env_params
            )

            # value and log_prob are not tracked during evaluation
            transition = Transition(
                done, action, None, reward, None, last_obs_norm, last_obs, obsv, info
            )

            runner_state = (params, mean, var, env_state, obsv, rng)
            return runner_state, transition

        params = runner_state[0].params
        if config["NORMALIZE_ENV"]:
            # statistics of the first training env are used for every eval env
            mean, var = runner_state[3].mean[0], runner_state[3].var[0]
        else:
            # The policy was trained on raw observations: zero mean and unit
            # variance make the normalization in _update_eval_step an identity.
            obs_shape = env.observation_space(env_params).shape
            mean, var = jnp.zeros(obs_shape), jnp.ones(obs_shape)
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["EVAL_NUM_ENVS"])
        eval_obsv, eval_env_state = eval_env.reset(reset_rng, eval_env_params)

        eval_runner_state = (params, mean, var, eval_env_state, eval_obsv, rng)
        # NOTE(review): fixed 1000-step evaluation rollout -- presumably the
        # environment's episode length; confirm.
        _, eval_transitions = jax.lax.scan(
            _update_eval_step, eval_runner_state, None, length=1000
        )

        return {"runner_state": runner_state, "metrics": metric, "eval_transitions": eval_transitions}

    return train

r"""
Reward function.
logits: log-ratio of expert to policy densities, \log(\frac{\rho^{E}(s,a)}{\rho^{\pi}(s,a)})
"""
def get_rewards(logits, config):
    """Turn discriminator logits into rewards.

    Args:
        logits: Discriminator logits (any array-like accepted by jax.numpy).
        config: Mapping whose ``'REWARD_TYPE'`` entry selects the transform:
            ``'gail'``, ``'airl'``, ``'fairl'``, ``'logd'``, ``'meta-disc'``,
            ``'sigmoid'`` or ``'tanh_plus_1'``.

    Returns:
        The transformed logits, element-wise.

    Raises:
        ValueError: If ``config['REWARD_TYPE']`` is not one of the names above.
    """
    requested = config['REWARD_TYPE']

    def _meta_disc(z):
        gate = jax.nn.sigmoid(z)                 # smooth gate in (0, 1)
        rescaled_tanh = 0.5 * (jnp.tanh(z) + 1.0)  # tanh mapped from [-1, 1] to [0, 1]
        return gate * rescaled_tanh              # product stays within [0, 1]

    # (name, transform) pairs, checked in order with ``==``
    transforms = (
        ('gail', lambda z: jax.nn.softplus(z)),
        ('airl', lambda z: z),
        ('fairl', lambda z: -z * jnp.exp(z)),
        ('logd', lambda z: -jax.nn.softplus(-z)),
        ('meta-disc', _meta_disc),
        ('sigmoid', lambda z: jax.nn.sigmoid(z)),
        ('tanh_plus_1', lambda z: (jnp.tanh(z) + 1) * 0.5),
    )
    for name, transform in transforms:
        if requested == name:
            return transform(logits)

    raise ValueError(f"Unknown reward type: {config['REWARD_TYPE']}")

def objective(trial, base_config):
    """
    Optuna objective function for hyperparameter tuning.

    Samples one set of hyperparameters, trains ``config['NUM_SEEDS']`` seeds in
    parallel (``pmap`` over devices, ``vmap`` within each device) and scores the
    trial by the policy/expert divergence.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        base_config: Base configuration dict. It is shallow-copied, so the
            sampled values are not written back into it.

    Returns:
        The mean divergence across seeds as a Python float (lower is better).

    Raises:
        ValueError: If the seed keys cannot be reshaped to
            ``(device_count, NUM_SEEDS // device_count, ...)``, e.g. when
            ``NUM_SEEDS`` is not a multiple of the device count.
    """
    # Define hyperparameter search space
    config = base_config.copy()

    # Learning rates
    config['LR'] = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    config['DISC_LR'] = trial.suggest_float('disc_lr', 1e-5, 1e-3, log=True)

    # PPO parameters
    config['CLIP_EPS'] = trial.suggest_float('clip_eps', 0.1, 0.3)
    config['VF_COEF'] = trial.suggest_float('vf_coef', 0.25, 1.0)
    config['ENT_COEF'] = trial.suggest_float('ent_coef', 0.01, 0.1)

    # Discriminator parameters
    config['GP_WEIGHT'] = trial.suggest_float('gp_weight', 0.01, 100.0, log=True)

    # Training parameters
    config['UPDATE_EPOCHS'] = trial.suggest_int('update_epochs', 1, 8)
    config['DISC_UPDATE_EPOCHS'] = trial.suggest_int('disc_update_epochs', 1, 10)

    # GAE parameters
    config['GAE_LAMBDA'] = trial.suggest_float('gae_lambda', 0.9, 0.99)

    print(f"Trial {trial.number}: Testing config with LR={config['LR']:.2e}, DISC_LR={config['DISC_LR']:.2e}")

    rng = jax.random.PRNGKey(42)
    sample_expert_transitions = make_expert_transitions(config)
    sample_expert_transitions = jax.jit(sample_expert_transitions, static_argnums=(0,))
    get_rewards_fn = partial(get_rewards, config=config)
    get_rewards_fn = jax.jit(get_rewards_fn)

    train = make_train(config, jax.nn.relu, sample_expert_transitions, get_rewards_fn)

    def single_rollout(rng):
        """Train one seed and return its results and policy/expert divergence."""
        # Split BEFORE training so that the key consumed by `train` is never
        # reused for the divergence estimate (same ordering as `main`).
        rng, _rng = jax.random.split(rng)
        results = train(_rng)
        policy_states = results['eval_transitions'].unnorm_obs
        policy_actions = results['eval_transitions'].action
        if config['SUB_SAMPLE_RATE'] > 1:
            # Keep every SUB_SAMPLE_RATE-th step and always append the last one.
            _policy_states = policy_states[::config['SUB_SAMPLE_RATE']]
            policy_states = jnp.concatenate([_policy_states, policy_states[-1:]], axis=0)
            _policy_actions = policy_actions[::config['SUB_SAMPLE_RATE']]
            policy_actions = jnp.concatenate([_policy_actions, policy_actions[-1:]], axis=0)
        # The last argument mirrors the call in `main`; the previous
        # `num_actions` name was not defined anywhere in scope.
        divergence = compute_policy_expert_divergence(
            policy_states, policy_actions, rng, sample_expert_transitions, config['DISC_INP']
        )
        return results, divergence

    rng, _rng = jax.random.split(rng)
    rngs = jax.random.split(_rng, config['NUM_SEEDS'])

    device_count = jax.device_count()
    num_batches = config['NUM_SEEDS'] // device_count

    # Reshape rngs for batching
    try:
        batched_rngs = jnp.reshape(rngs, (device_count, num_batches) + rngs.shape[1:])
        print('Batched rngs shape:', batched_rngs.shape)
    except ValueError as e:
        print(f"Error reshaping RNGs: {e}. RNGs shape: {rngs.shape}, Target shape: {(device_count, num_batches) + rngs.shape[1:]}")
        raise

    pmap_train = jax.pmap(jax.vmap(single_rollout, in_axes=(0,)), in_axes=(0,))  # pmap over the device dimension
    results, divergences = pmap_train(batched_rngs)

    # Combine results by flattening the first two axes (device, batch)
    results = jax.tree_util.tree_map(
        lambda x: jnp.reshape(x, (config['NUM_SEEDS'],) + x.shape[2:]),
        results
    )
    divergences = jnp.reshape(divergences, (config['NUM_SEEDS'],))

    mean_divergence = float(divergences.mean())
    mean_return = float(results['metrics']["returned_episode_returns"].mean(axis=(-1, -2, 0))[-1])
    print(f"Trial {trial.number}: Mean divergence = {mean_divergence:.4f} Mean return = {mean_return:.4f}")

    return mean_divergence
    

if __name__ == '__main__':

    def main(config):
        """
        Train all seeds for one configuration and save the resulting metrics.

        Runs ``config['NUM_SEEDS']`` training seeds in parallel (``pmap`` over
        devices, ``vmap`` within each device), computes the policy/expert
        divergence per seed, drops seeds whose divergence is NaN, and writes
        plots plus ``returns.npy``, ``divergences.npy``, ``disc_logits.npy``
        and ``entropy.npy`` under
        ``../main_results/<ENV_NAME>/<DISC_INP>/<N_EXPERT_TRAJS>/<REWARD_TYPE>/``.

        Args:
            config: Configuration dict; read here and passed on to
                ``make_expert_transitions`` and ``make_train``.

        Raises:
            ValueError: If the seed keys cannot be reshaped to
                ``(device_count, NUM_SEEDS // device_count, ...)``.
        """
        print('Training: ', config['ENV_NAME'])
        print('Config:', config)
        sample_expert_transitions = make_expert_transitions(config)
        sample_expert_transitions = jax.jit(sample_expert_transitions, static_argnums=(0,))
        get_rewards_fn = partial(get_rewards, config=config)
        get_rewards_fn = jax.jit(get_rewards_fn)

        train = make_train(config, jax.nn.relu, sample_expert_transitions, get_rewards_fn)

        def single_rollout(rng):
            """Train one seed and return its results and policy/expert divergence."""
            rng, _rng = jax.random.split(rng)
            results = train(_rng)
            policy_states = results['eval_transitions'].unnorm_obs
            policy_actions = results['eval_transitions'].action
            if config['SUB_SAMPLE_RATE'] > 1:
                # Keep every SUB_SAMPLE_RATE-th step and always append the last one.
                _policy_states = policy_states[::config['SUB_SAMPLE_RATE']]
                policy_states = jnp.concatenate([_policy_states, policy_states[-1:]], axis=0)
                _policy_actions = policy_actions[::config['SUB_SAMPLE_RATE']]
                policy_actions = jnp.concatenate([_policy_actions, policy_actions[-1:]], axis=0)

            divergence = compute_policy_expert_divergence(
                policy_states, policy_actions, rng, sample_expert_transitions, config['DISC_INP']
            )
            return results, divergence

        rng = jax.random.PRNGKey(42)
        rng, _rng = jax.random.split(rng)
        rngs = jax.random.split(_rng, config['NUM_SEEDS'])

        device_count = jax.device_count()
        num_batches = config['NUM_SEEDS'] // device_count
        # Reshape rngs for batching
        try:
            batched_rngs = jnp.reshape(rngs, (device_count, num_batches) + rngs.shape[1:])
            print('Batched rngs shape:', batched_rngs.shape)
        except ValueError as e:
            print(f"Error reshaping RNGs: {e}. RNGs shape: {rngs.shape}, Target shape: {(device_count, num_batches) + rngs.shape[1:]}")
            raise

        pmap_train = jax.pmap(jax.vmap(single_rollout, in_axes=(0,)), in_axes=(0,))  # pmap over the device dimension
        print(f'Using Pmap with {device_count} devices.')

        results, divergences = pmap_train(batched_rngs)

        # Combine results by flattening the first two axes (device, batch)
        results = jax.tree_util.tree_map(
            lambda x: jnp.reshape(x, (config['NUM_SEEDS'],) + x.shape[2:]),
            results
        )
        divergences = jnp.reshape(divergences, (config['NUM_SEEDS'],))

        # Move everything that is post-processed below onto the CPU.
        cpu = jax.devices("cpu")[0]
        divergences = jax.device_put(divergences, cpu)
        results['metrics'] = jax.device_put(results['metrics'], cpu)
        results['eval_transitions'] = jax.device_put(results['eval_transitions'], cpu)

        # Drop every seed whose divergence came out as NaN.
        nan_ids = jnp.isnan(divergences)
        print('Divergences:', divergences)
        print('Nan ids:', nan_ids)
        results['metrics'] = jax.tree_util.tree_map(lambda x: x[~nan_ids], results['metrics'])
        results['eval_transitions'] = jax.tree_util.tree_map(lambda x: x[~nan_ids], results['eval_transitions'])
        divergences = divergences[~nan_ids]

        # Fraction of seeds that produced a non-NaN divergence.
        quality = 1 - nan_ids.sum() / config['NUM_SEEDS']

        print('Quality:', quality)

        print(f"Divergence: {divergences.mean()} +- {divergences.std()}")

        save_dir = f"../main_results/{config['ENV_NAME']}/{config['DISC_INP']}/{config['N_EXPERT_TRAJS']}/{config['REWARD_TYPE']}/"
        os.makedirs(save_dir, exist_ok=True)

        # save results
        plot_il_metrics(results, save_dir)

        # save returns
        returns = results['metrics']["returned_episode_returns"].mean(axis=(-1, -2))
        print('Returns shape:', returns.shape)
        jnp.save(f"{save_dir}/returns.npy", returns)

        # save divergences
        print('Divergences shape:', divergences.shape)
        jnp.save(f"{save_dir}/divergences.npy", divergences)

        # save disc logits
        disc_logits = results['metrics']['disc_logits']
        # Flatten per seed using the number of seeds that survived the NaN
        # filter; the leading axis is no longer NUM_SEEDS if any were dropped.
        disc_logits = disc_logits.reshape(disc_logits.shape[0], -1)
        disc_logits = disc_logits.mean(axis=0)
        # Sample 50k uniformly spaced indices; the endpoint is the last valid
        # index (shape[0] - 1), not shape[0], which would be out of range.
        indices = jnp.linspace(0, disc_logits.shape[0] - 1, 50000).astype(jnp.int32)
        disc_logits = disc_logits[indices]
        jnp.save(f"{save_dir}/disc_logits.npy", disc_logits)

        # save entropy
        entropy = results['metrics']['entropy']
        jnp.save(f"{save_dir}/entropy.npy", entropy)

        # Other artifacts (policy/discriminator params, actor/value losses,
        # eval transitions, eval returns, config) are not saved here.


    # Default experiment configuration.
    # NOTE(review): main() and objective() read config['DISC_INP'], which is not
    # set in this dict — confirm it is populated elsewhere before running.
    config = {
        'LR': 3.0e-4,
        'NUM_ENVS': 2048,
        'NUM_STEPS': 10,
        'TOTAL_TIMESTEPS': 10_000_000,
        'ENV_NAME': "hopper",
        'UPDATE_EPOCHS': 4,
        'NUM_MINIBATCHES': 32,
        'GAMMA': 0.99,
        'GAE_LAMBDA': 0.95,
        'CLIP_EPS': 0.2,
        'ENT_COEF': 0.0,
        'VF_COEF': 0.5,
        'MAX_GRAD_NORM': 0.5,
        'ANNEAL_LR': False,
        'NORMALIZE_ENV': True,
        'DEBUG': False,
        'NUM_SEEDS': 16,
        'N_EXPERT_TRAJS': 10,
        'DISC_LR': 3.0e-4,
        'DISC_UPDATE_EPOCHS': 1,
        'GP_WEIGHT': 1.0,
        'REWARD_TYPE': "gail",
        'EVAL_NUM_ENVS': 10,
        'USE_SPECTRAL_NORM': False,
        'SUB_SAMPLE_RATE': 20,
        'BACKEND': 'mjx',
        'HP_TUNE': False,

    }
    # Overrides the TOTAL_TIMESTEPS value given in the dict above.
    config['TOTAL_TIMESTEPS'] = 50_000_000
    envs = ['ant', 'halfcheetah', 'hopper', 'walker2d', 'reacher', 'humanoid']
    reward_types = ['meta-disc', 'gail', 'airl', 'fairl']

    if config['HP_TUNE']:
        # Hyperparameter search: minimise the mean divergence returned by `objective`.
        study = optuna.create_study(direction='minimize', sampler=TPESampler(seed=42), storage=None)
        study.optimize(partial(objective, base_config=config), n_trials=200)
        print("Best trial:")
        trial = study.best_trial
        print(f"  Value: {trial.value}")
        print("  Params: ")
        # List the sampled hyperparameters of the best trial under the header.
        for param_name, param_value in trial.params.items():
            print(f"    {param_name}: {param_value}")

    else:
        # Full sweep: every environment with every reward type.
        for env in envs:
            for reward_type in reward_types:
                print('\n')
                config['ENV_NAME'] = env
                config['REWARD_TYPE'] = reward_type
                # log all vars
                print('Reward type:', reward_type)
                print('Environment:', env)

                start = time.time()
                main(config)
                print('Time taken:', time.time()-start)