# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_xla_jaxpy
import os
import random
import time
import warnings
from dataclasses import dataclass, field
from functools import partial
from typing import Tuple, Optional, Callable, Dict, Union

import envpool
import flax
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tyro
from flax import struct
from flax.training.train_state import TrainState
from torch.utils.tensorboard import SummaryWriter

from utils.gp import lipschitz_gp, wgan_gp
from utils.logs import log
from utils.loss import smooth_l1_loss, hamming_distance, CategoricalCost
from networks.architectures import NetworkConv, NetworkFCOutput, Actor, Critic, DiscreteActionTransitionCNN, \
    LipDiscreteActionTransitionCNN, LipDiscreteActionRewardNetwork, \
    CategoricalEncoder, NetworkAttentionOutput, DiscreteActionTransitionNetworkSoftMoE, \
    DiscreteActionRewardNetworkSoftMoE, NetType, DiscreteActionTransitionNetwork, DiscreteActionRewardNetwork, \
    AutoregressiveDiscreteActionTransitionTransformer, DonePredictor, BoundedDiscriminator, LipschitzDiscriminator
from utils.distributions import TransitionDensity
from utils.scores import load_env_mean_std, atari_human_normalized_scores
from envs.confounding_gridworld import ConfoundingGridEnvPoolLike
import envs.confounding_gridworld as confounding_grid

# Fix weird OOM https://github.com/google/jax/discussions/6332#discussioncomment-1279991
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6"
# Fix CUDNN non-determinism; https://github.com/google/jax/issues/4823#issuecomment-952835771
os.environ["TF_XLA_FLAGS"] = "--xla_gpu_autotune_level=2 --xla_gpu_deterministic_reductions"
# The variable name is spelled with underscores only; the previous key contained a
# space ("TF_CUDNN DETERMINISTIC"), so this setting could never be picked up.
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"


@dataclass
class Args:
    # Experiment configuration. Each field is followed by a bare-string attribute
    # docstring describing it; the three fields in the last section are placeholders
    # that `check_and_process_args` overwrites, and that function also overrides a
    # few user-facing options (see the notes on the individual fields below).
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    # NOTE(review): the help text below describes the *disabled* state although the
    # default is True; no use of this flag is visible in this part of the file — TODO confirm.
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    track_params: bool = False
    """if toggled, parameters and grads will be tracked with Weights and Biases"""
    track_grads: bool = False
    """if toggled, gradients will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    # NOTE(review): annotated `str` but defaults to None (effectively Optional[str]).
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    wandb_tags: list[str] = field(default_factory=list)
    """the tags of the wandb's run"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = False
    """whether to save model into the `runs/{run_name}` folder"""
    upload_model: bool = False
    """whether to upload the saved model to huggingface"""
    hf_entity: str = ""
    """the user or org name of the model repository from the Hugging Face Hub"""
    # `make_env` maps this flag to repeat_action_probability=0.3 / noop_max=60
    # (0. / 30 when False).
    stochastic_env: bool = False
    """whether to use stochastic environment (default is False, which means deterministic environment)"""
    reward_clip: bool = True
    """whether to clip the reward to [-1, 1] in the environment (default is True, which means rewards are clipped)"""
    compare_scores_csv: str = "utils/cleanrl_result_table.csv"
    """the csv file to save the comparison scores of this experiment with other experiments"""
    compare_scores_column: str = "CleanRL's ppo_atari_envpool_xla_jax.py"
    """the column name of the comparison scores in the csv file"""
    compare_scores_max: bool = False
    """whether to compare scores with the max value reported (default is False, which means *last* score reported)"""
    # In this mode `check_and_process_args` adjusts `use_wgan_gp`, `lipschitz_nets`
    # and `transition_loss_coef` for some configurations.
    hp_tuning_mode: bool = False
    """if toggled, the script will run in parameter tuning mode, which means it will not run the full experiment"""

    # Algorithm specific arguments
    env_id: str = "Breakout-v5"
    """the id of the environment"""
    # When set, `check_and_process_args` draws the id from the keys of
    # `atari_human_normalized_scores`.
    random_env_id: bool = False
    """if toggled, a random environment id will be used from the envpool registry; ignore the env_id argument"""
    total_timesteps: int = 10000000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 128
    """the number of parallel game environments"""
    num_steps: int = 8
    """the number of steps to run in each environment per policy rollout"""
    # NOTE(review): annotated `str` but defaults to None (effectively Optional[str]).
    parallel_envs_config: str = None
    """string of the form 'n_env=8,n_steps=128' to override num_envs and num_steps"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    activation: str = "relu"
    """the activation function to use"""
    hadamard_representation: bool = False
    """(experimental) use hadamard representation after the conv layers"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    num_minibatches: int = 4
    """the number of mini-batches"""
    update_epochs: int = 4
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.1
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.01
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    # NOTE(review): annotated `float` but defaults to None (effectively Optional[float]).
    target_kl: float = None
    """the target KL divergence threshold"""
    decoupled_repr: bool = False
    """whether to use decoupled actor and critic representations"""
    drift_formulation: bool = False
    """whether to use the drift formulation of PPO (default is the standard PPO)"""
    # Forced to 1. by `check_and_process_args` when `drift_formulation` is False.
    # NOTE(review): the help text below spells the option "drif_coef".
    drift_coef: float = 1.0
    """drift coefficient drif_coef * D_{π_n}(π_{n + 1} | s) (only used if using drift formulation)"""
    transition_loss_coef: float = 0.0005
    """coefficient for the transition loss (only if using world model)"""
    reward_loss_coef: float = 0.01
    """coefficient for the reward loss (only if using world model)"""
    transition_density: TransitionDensity = TransitionDensity.MIXTURE_NORMAL
    """Choice of transition distribution: DETERMINISTIC | NORMAL | MIXTURE_NORMAL | CATEGORICAL"""
    # `check_and_process_args` resets this to 0 (with a warning) whenever
    # `lipschitz_nets` is True or the transition density is CATEGORICAL; with the
    # defaults below (lipschitz_nets=True) the 0.01 is therefore always discarded.
    lambda_gp: float = 0.01
    """coefficient for the gradient penalty (only if using world model)"""
    use_wgan_gp: bool = False
    """whether to use WGAN gradient penalty for enforcing the Lipschitzness of the world model"""
    lipschitz_nets: bool = True
    """whether to use Lipschitz networks for the world model (only if using world model)"""
    # Forced to False by `check_and_process_args` for NORMAL / MIXTURE_NORMAL densities.
    use_gumbel_softmax: bool = False
    """whether to use Gumbel-Softmax for the actor network (only if transition_density is CATEGORICAL)"""
    num_categories: int = 32
    """number of categories per dimension (only if transition_density is CATEGORICAL)"""
    num_classes: int = 32
    """number of classes per category (only if transition_density is CATEGORICAL)"""
    categorical_cost: CategoricalCost = CategoricalCost.L2
    """Choice of the categorical cost: L2 | CROSS_ENTROPY | HAMMING | JENSEN_SHANNON (only if transition_density is CATEGORICAL)"""
    use_attention: bool = False
    """whether to use attention in the actor network (only if transition_density is CATEGORICAL)"""
    # For any non-CATEGORICAL density `check_and_process_args` replaces this with
    # NetType.CONV and warns, so the FC default only survives with CATEGORICAL.
    auxiliary_task_net_type: NetType = NetType.FC
    """Type of auxiliary task network: CONV | FC | SOFTMOE | TRANSFORMER (only if transition_density is CATEGORICAL)"""
    # Forced to False by `check_and_process_args` when the density is CATEGORICAL.
    layer_norm_cnn_output: bool = False
    """whether to use layer normalization before the CNN output"""
    use_layer_norm: bool = False
    """whether to use layer normalization in all the layers of the networks"""
    dreamer_architecture: bool = False
    """whether to use the same architecture as DreamerV2 for the actor and critic networks"""
    use_feature_group: bool = True
    """whether to use feature group convolution in the CNN transition network (only if auxiliary_task_net_type is CONV)"""
    w_balancing: bool = False
    """whether to use 'Wasserstein' balancing for the auxiliary task, when computing the transition loss"""
    w_balancing_weight: float = .8
    """the weight (say alpha) for the 'Wasserstein' balancing, set on the prior: (1 - alpha) * z_prime + alpha * z_prime_sampled"""
    wasserstein_discriminator: bool = False
    """whether to use a discriminator to approximate Wasserstein for the auxiliary task, when computing the transition loss"""
    discriminator_obs_encoder: bool = False
    """whether to learn a representation through the discriminator, i.e., learn an additional obs encoder for the discriminator (only if wasserstein_discriminator is True)"""
    piecewise_auxiliary_ratio: bool = True
    """Whether to apply the policy ratio sample-wise to reweight the auxiliary losses."""
    clip_reward_net_output: bool = False
    """whether to clip the output of the reward network to [-1, 1]"""
    use_huber: bool = False
    """Whether to use Huber loss instead of MSE for the value loss."""
    deep_mdp: bool = False
    """whether to use DeepMDP formulation for the auxiliary tasks -- i.e., no representation/model update constraint."""

    # confounding policy update didactic environment
    toy_confounding_env: bool = False
    """whether to use the toy confounding environment"""
    toy_confounding_env_epsilon: float = 0.2
    """the epsilon parameter for the toy confounding environment"""
    toy_confounding_env_n_path: int = 5
    """the n_path parameter for the toy confounding environment"""

    # to be filled in runtime
    # (`check_and_process_args` sets batch_size = num_envs * num_steps,
    #  minibatch_size = batch_size // num_minibatches and
    #  num_iterations = total_timesteps // batch_size)
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""

def make_env(env_id, seed, num_envs, reward_clip: bool, stochastic_env: bool):
    """Return a zero-argument factory that builds an envpool vectorized environment.

    Nothing is created until the returned callable is invoked. The stochastic
    variant uses sticky actions (probability 0.3) and up to 60 no-op steps;
    the deterministic one uses no sticky actions and up to 30 no-ops.
    """
    sticky_action_prob = 0.3 if stochastic_env else 0.
    max_noops = 60 if stochastic_env else 30

    def thunk():
        pool = envpool.make(
            env_id,
            env_type="gym",
            num_envs=num_envs,
            episodic_life=True,
            reward_clip=reward_clip,
            seed=seed,
            repeat_action_probability=sticky_action_prob,
            noop_max=max_noops,
        )
        # expose the attributes a gym-style vector env is expected to carry
        pool.num_envs = num_envs
        pool.single_action_space = pool.action_space
        pool.single_observation_space = pool.observation_space
        pool.is_vector_env = True
        return pool

    return thunk


@flax.struct.dataclass
class AgentParams:
    """Parameters of the actor and critic networks."""
    # pair unpacked by the agent as (actor_conv params, actor_fc params)
    actor_network_params: Tuple[flax.core.FrozenDict, flax.core.FrozenDict]
    # pair unpacked by the agent as (critic_conv params, critic_fc params)
    critic_network_params: Tuple[flax.core.FrozenDict, flax.core.FrozenDict]
    # parameters passed to `actor.apply` to produce the policy logits
    actor_params: flax.core.FrozenDict
    # parameters passed to `critic.apply` to produce the value
    critic_params: flax.core.FrozenDict

@flax.struct.dataclass
class WorldModelParams:
    """Parameters of the world-model networks used by the auxiliary losses."""
    # parameters passed to `transition_network.apply`
    transition_network_params: flax.core.FrozenDict
    # parameters passed to `reward_network.apply`
    reward_network_params: flax.core.FrozenDict
    # parameters passed to `done_predictor.apply`; None when no done predictor is set
    done_predictor_params: Optional[flax.core.FrozenDict] = None

@flax.struct.dataclass
class FullParams:
    """Top-level parameter container: agent, world model and optional discriminator."""
    agent: AgentParams
    world_model: WorldModelParams
    # parameters passed to `discriminator.apply` in the transition loss; None by default
    discriminator: Optional[flax.core.FrozenDict] = None

@flax.struct.dataclass
class Storage:
    """Container of per-step arrays; the two trailing fields are optional.

    NOTE(review): presumably the rollout buffer filled during data collection and
    consumed by the GAE / PPO update — the code that reads and writes it is not
    visible in this part of the file; confirm shapes and semantics there.
    """
    obs: jnp.array
    actions: jnp.array
    logprobs: jnp.array
    dones: jnp.array
    values: jnp.array
    advantages: jnp.array
    returns: jnp.array
    rewards: jnp.array
    next_obs: jnp.array
    next_dones: Optional[jnp.array] = None
    hist_logprobs: Optional[jnp.array] = None

@flax.struct.dataclass
class EpisodeStatistics:
    """Running and last-completed episode statistics, updated by `step_env_wrapped`."""
    # return accumulated so far in the current episode; zeroed when it terminates or is truncated
    episode_returns: jnp.array
    # steps taken so far in the current episode; zeroed when it terminates or is truncated
    episode_lengths: jnp.array
    # return of the most recently finished episode (only overwritten when an episode ends)
    returned_episode_returns: jnp.array
    # length of the most recently finished episode (only overwritten when an episode ends)
    returned_episode_lengths: jnp.array

# utils

def linear_schedule(count, args: Args) -> float:
    """Linearly decayed learning rate for optimizer step ``count``.

    One training iteration spans ``args.num_minibatches * args.update_epochs``
    gradient updates, so the rate only changes once per iteration and reaches
    zero after ``args.num_iterations`` iterations.
    """
    updates_per_iteration = args.num_minibatches * args.update_epochs
    completed_iterations = count // updates_per_iteration
    remaining_fraction = 1.0 - completed_iterations / args.num_iterations
    return args.learning_rate * remaining_fraction


def step_env_wrapped(episode_stats, handle, action, step_env_fn):
    """Step the environment through ``step_env_fn`` and refresh the episode statistics.

    Running returns/lengths are accumulated and reset to zero for environments whose
    episode terminated or was truncated; the ``returned_*`` fields are overwritten
    only for those environments, with the totals of the episode that just ended.

    Returns ``(episode_stats, handle, (next_obs, reward, next_done, info))``.
    """
    handle, transition = step_env_fn(handle, action)
    next_obs, reward, next_done, info = transition

    terminated = info["terminated"]
    truncated = info["TimeLimit.truncated"]
    episode_over = terminated + truncated

    running_return = episode_stats.episode_returns + info["reward"]
    running_length = episode_stats.episode_lengths + 1

    updated_stats = episode_stats.replace(
        episode_returns=running_return * (1 - terminated) * (1 - truncated),
        episode_lengths=running_length * (1 - terminated) * (1 - truncated),
        # keep the previously reported values unless the episode just finished
        returned_episode_returns=jnp.where(
            episode_over, running_return, episode_stats.returned_episode_returns
        ),
        returned_episode_lengths=jnp.where(
            episode_over, running_length, episode_stats.returned_episode_lengths
        ),
    )
    return updated_stats, handle, (next_obs, reward, next_done, info)


def _parse_parallel_envs_config(config: str) -> Dict[str, str]:
    """Parse a 'key=value,key=value' string into a dict, ignoring surrounding whitespace."""
    parsed = {}
    for item in config.split(','):
        item = item.strip()
        if not item:
            # tolerate stray / trailing commas
            continue
        name, separator, value = item.partition('=')
        if not separator:
            raise ValueError(
                f"Malformed `parallel_envs_config` item {item!r}; expected the form 'n_env=8,n_steps=128'")
        parsed[name.strip()] = value.strip()
    return parsed


def check_and_process_args(args: Args):
    """Validate `args` in place and fill in the values derived at runtime.

    Mutates `args`:
      * applies the `parallel_envs_config` overrides to `num_envs` / `num_steps`;
      * computes `batch_size`, `minibatch_size` and `num_iterations`;
      * optionally draws a random `env_id`;
      * resets mutually incompatible options, emitting a warning for each reset.

    Raises:
        ValueError: if `parallel_envs_config` is malformed or holds non-integer values.
    """
    if args.parallel_envs_config:
        overrides = _parse_parallel_envs_config(args.parallel_envs_config)
        unknown = sorted(set(overrides) - {'n_env', 'n_steps'})
        if unknown:
            warnings.warn(f"Ignoring unknown `parallel_envs_config` keys: {unknown}")
        args.num_envs = int(overrides.get('n_env', args.num_envs))
        args.num_steps = int(overrides.get('n_steps', args.num_steps))

    # derived sizes, computed once the final num_envs / num_steps are known
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size

    if args.random_env_id:
        env_list = list(atari_human_normalized_scores.keys())
        args.env_id = env_list[np.random.randint(len(env_list))]
        print(f'Environment drawn: {args.env_id}')

    if args.lipschitz_nets and args.lambda_gp > 0.:
        # Lipschitz networks make the gradient penalty redundant
        warnings.warn("You should not use both `lipschitz_nets` and `lambda_gp > 0` at the same time. "
                      "Setting `lambda_gp` to 0 to avoid unnecessary computation.")
        args.lambda_gp = 0.

    if args.transition_density in [TransitionDensity.MIXTURE_NORMAL, TransitionDensity.NORMAL]:
        args.use_gumbel_softmax = False

    if args.transition_density == TransitionDensity.CATEGORICAL:
        if args.lambda_gp > 0. or args.lipschitz_nets:
            warnings.warn("You should not use `CATEGORICAL` transition density with `lambda_gp > 0` or Lipschitz networks. "
                          "Setting `lambda_gp` to 0 and lipschitz_net to False"
                          " to avoid unnecessary computation.")
            args.lipschitz_nets = False
            args.lambda_gp = 0.
        if args.layer_norm_cnn_output:
            # the flag is forced off below, so the message must say it is unsupported
            warnings.warn("Using `layer_norm_cnn_output` with `CATEGORICAL` transition density is not supported. "
                          "Setting `layer_norm_cnn_output` to False.")
            args.layer_norm_cnn_output = False

    if args.transition_density != TransitionDensity.CATEGORICAL and args.auxiliary_task_net_type != NetType.CONV:
        # every non-CONV auxiliary network type requires the CATEGORICAL density
        warnings.warn("For non-CATEGORICAL transition density, only `CONV` is implemented for `auxiliary_task_net_type`")
        args.auxiliary_task_net_type = NetType.CONV

    if not args.drift_formulation:
        args.drift_coef = 1.

    if args.hp_tuning_mode:
        # NOTE(review): this runs after the `lipschitz_nets` / `lambda_gp` consistency
        # check above, so re-enabling `lipschitz_nets` here leaves a positive
        # `lambda_gp` untouched — confirm this ordering is intended.
        if args.transition_density == TransitionDensity.MIXTURE_NORMAL:
            args.use_wgan_gp = False
            args.lipschitz_nets = True  # otherwise too slow
        if args.wasserstein_discriminator:
            args.transition_loss_coef *= 100

@flax.struct.dataclass
class DeepSPIAgent:
    """Bundle of the agent's network modules, its train state and the functions bound to them.

    Instances are meant to be built with `create`, which obtains the module fields
    from `_build_networks` and fills every callable field below.
    """
    # policy / value heads and their conv + fc representation networks
    actor: Actor
    critic: Critic
    actor_conv: NetworkConv
    actor_fc: NetworkFCOutput
    critic_conv: NetworkConv
    critic_fc: NetworkFCOutput
    # world-model networks used by the auxiliary losses
    transition_network: DiscreteActionTransitionNetwork
    reward_network: DiscreteActionRewardNetwork
    done_predictor: DonePredictor
    discriminator: Optional[Union[LipschitzDiscriminator, BoundedDiscriminator]]
    train_state: TrainState

    # Functions bound in `create`; excluded from the pytree (pytree_node=False).
    # In `create`, get_action_and_value, compute_auxiliary_losses, get_action_and_value2,
    # compute_gae and update_ppo are wrapped in `jax.jit`; ppo_loss_grad_fn is a
    # `jax.value_and_grad` function; the remaining ones are plain `functools.partial`s.
    pi_sample: Callable = struct.field(pytree_node=False)
    get_action_and_value: Callable = struct.field(pytree_node=False)
    compute_auxiliary_losses: Callable = struct.field(pytree_node=False)
    get_action_and_value2: Callable = struct.field(pytree_node=False)
    compute_gae_once: Callable = struct.field(pytree_node=False)
    compute_gae: Callable = struct.field(pytree_node=False)
    ppo_loss_grad_fn: Callable = struct.field(pytree_node=False)
    update_ppo: Callable = struct.field(pytree_node=False)
    step_once: Callable = struct.field(pytree_node=False)
    rollout: Callable = struct.field(pytree_node=False)

    @classmethod
    def create(
            cls,
            args: Args,
            envs: gym.vector.VectorEnv,
            raw_step_env: Callable,
            key: jax.random.PRNGKey,
            use_done_predictor: bool = False,
            latent_obs: bool = False
    ) -> "DeepSPIAgent":
        """
        Build a fully wired DeepSPIAgent.

        The networks (and remaining non-callable fields) come from `_build_networks`;
        this method then binds every static method to those modules with
        `functools.partial`, jit-compiling the ones used on the hot path, and
        instantiates the class from the resulting field mapping.
        """
        network_key, actor_key, critic_key, transition_key, reward_key = jax.random.split(key, 5)
        fields = cls._build_networks(
            args, envs, transition_key, reward_key, network_key, actor_key, critic_key,
            use_done_predictor, latent_obs)

        # modules shared by several of the bound functions below
        actor, critic = fields["actor"], fields["critic"]
        actor_conv, actor_fc = fields["actor_conv"], fields["actor_fc"]
        critic_conv, critic_fc = fields["critic_conv"], fields["critic_fc"]

        # acting
        pi_sample = partial(cls._pi_sample, actor=actor)
        get_action_and_value = jax.jit(partial(
            cls._get_action_and_value,
            args=args,
            critic=critic,
            actor_conv=actor_conv,
            actor_fc=actor_fc,
            critic_conv=critic_conv,
            critic_fc=critic_fc,
            pi_sample_fn=pi_sample,
            latent_obs=latent_obs,
        ))

        # world-model losses and the evaluation pass used by the PPO loss
        compute_auxiliary_losses = jax.jit(partial(
            cls._compute_auxiliary_losses,
            args=args,
            transition_network=fields["transition_network"],
            reward_network=fields["reward_network"],
            envs=envs,
            pi_sample_fn=pi_sample,
            actor_fc=actor_fc,
        ))
        get_action_and_value2 = jax.jit(partial(
            cls._get_action_and_value2,
            args=args,
            actor=actor,
            actor_conv=actor_conv,
            actor_fc=actor_fc,
            critic=critic,
            critic_conv=critic_conv,
            critic_fc=critic_fc,
            compute_auxiliary_losses=compute_auxiliary_losses,
            latent_obs=latent_obs,
        ))

        # advantage estimation
        compute_gae_once = partial(cls._compute_gae_once, gamma=args.gamma, gae_lambda=args.gae_lambda)
        compute_gae = jax.jit(partial(
            cls._compute_gae,
            args=args,
            actor_conv=actor_conv,
            actor_fc=actor_fc,
            critic=critic,
            critic_conv=critic_conv,
            critic_fc=critic_fc,
            compute_gae_once=compute_gae_once,
            latent_obs=latent_obs,
        ))

        # PPO objective and parameter update
        ppo_loss_grad_fn = jax.value_and_grad(
            partial(
                cls._ppo_loss,
                hist_logprobs=None,
                args=args,
                get_action_and_value2=get_action_and_value2,
                off_policy_correction=False,
            ),
            has_aux=True,
        )
        update_ppo = jax.jit(partial(cls._update_ppo, args=args, ppo_loss_grad_fn=ppo_loss_grad_fn))

        # environment interaction
        step_once = partial(
            cls._step_once,
            env_step_fn=partial(step_env_wrapped, step_env_fn=raw_step_env),
            get_action_and_value_fn=get_action_and_value,
        )
        rollout = partial(cls._rollout, step_once_fn=step_once, max_steps=args.num_steps)

        fields.update(
            pi_sample=pi_sample,
            get_action_and_value=get_action_and_value,
            compute_auxiliary_losses=compute_auxiliary_losses,
            get_action_and_value2=get_action_and_value2,
            compute_gae_once=compute_gae_once,
            compute_gae=compute_gae,
            ppo_loss_grad_fn=ppo_loss_grad_fn,
            update_ppo=update_ppo,
            step_once=step_once,
            rollout=rollout,
        )
        return cls(**fields)

    @staticmethod
    def _pi_sample(
            params: FullParams,
            actor_hidden: np.ndarray,
            key: jax.random.PRNGKey,
            actor: Actor,
    ):
        """Sample one action per row of the actor logits with the Gumbel-max trick.

        Returns ``(action, logprob, key)`` where ``logprob`` is the log-softmax of the
        logits at the sampled action and ``key`` is the advanced PRNG key.
        """
        logits = actor.apply(params.agent.actor_params, actor_hidden)

        key, noise_key = jax.random.split(key)
        uniform_noise = jax.random.uniform(noise_key, shape=logits.shape)
        # argmax over (logits + Gumbel noise), with the noise written as -log(-log(u))
        perturbed_logits = logits - jnp.log(-jnp.log(uniform_noise))
        action = jnp.argmax(perturbed_logits, axis=1)

        log_probs = jax.nn.log_softmax(logits)
        row_index = jnp.arange(action.shape[0])
        logprob = log_probs[row_index, action]

        return action, logprob, key

    @staticmethod
    def _get_action_and_value(
            agent_state: TrainState,
            obs: np.ndarray,
            key: jax.random.PRNGKey,
            args: Args,
            critic: Critic,
            actor_conv: NetworkConv,
            actor_fc: NetworkFCOutput,
            critic_conv: NetworkConv,
            critic_fc: NetworkFCOutput,
            pi_sample_fn: Callable,
            latent_obs: bool = False,
    ):
        """Sample an action and compute its log-probability and the state value.

        When ``latent_obs`` is True, ``obs`` is used directly in place of the
        conv-encoder output. With ``args.decoupled_repr`` the critic runs its own
        conv/fc stack; otherwise it reuses the actor's hidden representation.

        Returns ``(action, logprob, value, key)``.
        """
        params = agent_state.params
        actor_conv_params, actor_fc_params = params.agent.actor_network_params
        critic_conv_params, critic_fc_params = params.agent.critic_network_params

        z = obs if latent_obs else actor_conv.apply(actor_conv_params, obs)
        actor_hidden = actor_fc.apply(actor_fc_params, z)

        if args.decoupled_repr:
            critic_input = z if latent_obs else critic_conv.apply(critic_conv_params, obs)
            critic_hidden = critic_fc.apply(critic_fc_params, critic_input)
        else:
            # shared representation: the critic head reads the actor features
            critic_hidden = actor_hidden

        action, logprob, key = pi_sample_fn(params, actor_hidden, key)

        value = critic.apply(params.agent.critic_params, critic_hidden).squeeze(1)
        return action, logprob, value, key

    @staticmethod
    def compute_gradient_penalty(
            params: FullParams,
            z: np.ndarray,
            action: np.ndarray,
            z_prime: np.ndarray,
            z_prime_sampled: np.ndarray,
            key: jax.random.PRNGKey,
            args: Args,
            transition_network_params: flax.core.FrozenDict,
            reward_network_params: flax.core.FrozenDict,
            envs: gym.vector.VectorEnv,
            transition_network: DiscreteActionTransitionNetwork,
            reward_network: DiscreteActionRewardNetwork,
            actor_fc: NetworkFCOutput,
            pi_sample_fn: Callable,
    ):
        """Gradient penalty of the transition and reward networks (returned as their sum).

        Latents and actions are packed into flat ``[latent, one-hot action]`` vectors
        (with gradients stopped) so that the penalty helpers can treat each network as
        a function of a single vector. With ``args.use_wgan_gp`` the penalty is
        computed by `wgan_gp` on the pair built from ``z_prime`` / ``z_prime_sampled``
        and actions sampled from the policy; otherwise by `lipschitz_gp` on ``(z, action)``.

        ``z`` is unpacked as a 4-D array ``(B, H, W, C)``. The advanced PRNG key is
        not returned.
        """
        batch, height, width, channels = z.shape
        latent_size = height * width * channels
        num_actions = envs.action_space.n

        def flatten_za(latent, act):
            # (B, H, W, C) latent + integer actions -> (B, latent_size + num_actions)
            return jnp.concatenate(
                [jnp.reshape(latent, (batch, latent_size)), jax.nn.one_hot(act, num_actions)], -1)

        def decode(za_vec):
            # inverse of `flatten_za` for one vector; re-adds a leading batch axis of 1
            latent = za_vec[..., :latent_size].reshape((height, width, channels))[None, ...]
            act = jnp.argmax(za_vec[..., latent_size:], axis=-1)[None]
            return latent, act

        def transition_network_apply(_za, key):
            dist = transition_network.apply(transition_network_params, decode(_za))
            return dist.sample(seed=key)

        def reward_network_apply(_za, key):
            # `key` is unused; the parameter keeps both callables interchangeable
            return reward_network.apply(reward_network_params, decode(_za))

        if not args.use_wgan_gp:
            z_a = jax.lax.stop_gradient(flatten_za(z, action))
            key, sub = jax.random.split(key)
            per_sample_keys = jax.random.split(sub, z_a.shape[0])  # one key per batch element

            transition_gp = lipschitz_gp(transition_network_apply, z_a, per_sample_keys)
            reward_gp = lipschitz_gp(reward_network_apply, z_a, per_sample_keys)
            return transition_gp + reward_gp

        actor_fc_params = params.agent.actor_network_params[1]
        hidden_encoded = actor_fc.apply(actor_fc_params, z_prime)
        hidden_sampled = actor_fc.apply(actor_fc_params, z_prime_sampled)
        action_encoded, _, key = pi_sample_fn(params, hidden_encoded, key)
        action_sampled, _, key = pi_sample_fn(params, hidden_sampled, key)

        z_prime_a = jax.lax.stop_gradient(
            flatten_za(z_prime, jax.lax.stop_gradient(action_encoded)))
        z_prime_sampled_a = jax.lax.stop_gradient(
            flatten_za(z_prime_sampled, jax.lax.stop_gradient(action_sampled)))

        transition_gp, key = wgan_gp(transition_network_apply, z_prime_a, z_prime_sampled_a, key)
        reward_gp, key = wgan_gp(reward_network_apply, z_prime_a, z_prime_sampled_a, key)
        return transition_gp + reward_gp

    @staticmethod
    def _compute_auxiliary_losses(
            params: FullParams,
            obs: np.ndarray,
            z: np.ndarray,
            action: np.ndarray,
            reward: np.ndarray,
            z_prime: np.ndarray,
            next_done: Optional[np.ndarray],
            key: jax.random.PRNGKey,
            args: Args,
            transition_network: DiscreteActionTransitionNetwork,
            reward_network: DiscreteActionRewardNetwork,
            actor_fc: NetworkFCOutput,
            envs: gym.vector.VectorEnv,
            pi_sample_fn: Callable,
            done_predictor: Optional[DonePredictor] = None,
            discriminator: Optional[Union[LipschitzDiscriminator, BoundedDiscriminator]] = None
    ):
        """Compute the auxiliary (world-model) losses.

        Args:
            params: full parameter container (agent, world model, discriminator).
            obs: observations; only forwarded to the discriminator.
            z: latent of the current state.
            action: actions taken.
            reward: observed rewards, target of the reward network.
            z_prime: latent of the next state, target of the transition network.
            next_done: target of the done predictor (only read when one is given).
            key: PRNG key.
            args: experiment configuration selecting the loss variants.
            done_predictor: optional module whose loss is added to the reward loss.
            discriminator: optional module used for the transition loss when
                ``args.wasserstein_discriminator`` is set.

        Returns:
            ``(transition_loss, reward_loss, gradient_penalty, key)``. A loss whose
            coefficient is zero (``transition_loss_coef`` / ``reward_loss_coef`` /
            ``lambda_gp``) is returned as the Python float ``0.``.
        """
        transition_network_params = params.world_model.transition_network_params
        reward_network_params = params.world_model.reward_network_params
        done_params = params.world_model.done_predictor_params

        # Transition loss
        dist = transition_network.apply(transition_network_params, (z, action))

        # the relaxed categorical costs work on the distribution itself, not on a sample
        dont_sample_now = args.transition_density == TransitionDensity.CATEGORICAL and \
                          args.categorical_cost in [CategoricalCost.CROSS_ENTROPY, CategoricalCost.JENSEN_SHANNON]
        if dont_sample_now:
            z_prime_sampled = None
        else:
            key, subkey = jax.random.split(key)
            z_prime_sampled = dist.sample(seed=subkey)

        if args.transition_loss_coef == 0.:
            transition_loss = 0.
        else:
            if args.wasserstein_discriminator and discriminator is not None:
                discriminator_params = params.discriminator
                def transition_loss_fn(z_1, z_2):
                    # note that here, z_1 will be attributed to the "fake" state, coming from the world model, while
                    # z_2 will be attributed to z_prime, coming from the true environment
                    # discriminator loss
                    f_1 = discriminator.apply(discriminator_params, (obs, action, z, z_2))  # true
                    f_2 = discriminator.apply(discriminator_params, (obs, action, z, z_1))  # fake
                    return f_1 - f_2  # distinguish fake from true
            elif args.transition_density in [TransitionDensity.NORMAL, TransitionDensity.MIXTURE_NORMAL] \
                    and not args.lipschitz_nets:
                # closed form loss for Normal/MixtureNormal of E_{z_sampled ~ p(z_sampled|z,a)}[z' - z_sampled]
                # see https://en.wikipedia.org/wiki/Folded_normal_distribution
                def transition_loss_fn(_, z_2):
                    if args.transition_density == TransitionDensity.NORMAL:
                        mu, sigma = dist.distribution.loc, dist.distribution.scale  # (B, H, W, C)
                        _z_prime = z_2
                    else:
                        mu, sigma = dist.loc, dist.scale  # (B, K, H, W, C); K is number of components
                        _z_prime = z_2[:, None, ...]

                    delta = (_z_prime - mu) / (sigma + 1e-8)
                    term_1 = delta * (2. * jax.scipy.stats.norm.cdf(delta) - 1.)
                    term_2 = jnp.sqrt(2. / jnp.pi) * jnp.exp(-.5 * delta**2.)
                    expected_abs_diff = sigma * (term_1 + term_2)
                    if args.transition_density == TransitionDensity.NORMAL:
                        return smooth_l1_loss(
                            expected_abs_diff, 0., reduction='sum', axis=[1, 2, 3])  # norm; (B, )
                    else:
                        weights = dist.weights  # (B, K,)
                        # sum over components K
                        weighted_sum = jnp.einsum('bkhwc,bk->bhwc', expected_abs_diff, weights)
                        # norm; (B, H, W, C) -> (B, )
                        return smooth_l1_loss(weighted_sum, 0., reduction='sum', axis=[1, 2, 3])

            elif args.transition_density == TransitionDensity.CATEGORICAL:
                # only the relaxed costs (CROSS_ENTROPY / JENSEN_SHANNON) consume a key
                needs_relaxed_key = dont_sample_now
                key, subkey = jax.random.split(key) if needs_relaxed_key else (key, key)
                # NOTE(review): both relaxed costs are evaluated against `z` (the current
                # latent) and ignore their arguments, rather than using `z_2` / `z_prime`
                # — confirm this is the intended target.
                transition_loss_fn = {
                    CategoricalCost.HAMMING: lambda z_1, z_2: hamming_distance(z_1, z_2),
                    # balancing is not implemented for CROSS_ENTROPY and JENSEN_SHANNON
                    CategoricalCost.CROSS_ENTROPY: lambda z_1, z_2: jnp.sum(dist.relaxed_cross_entropy(z, subkey), axis=-1),
                    CategoricalCost.JENSEN_SHANNON: lambda z_1, z_2: jnp.sum(dist.relaxed_js_distance(z, subkey), axis=-1),
                    CategoricalCost.L2: lambda z_1, z_2: smooth_l1_loss(z_1, z_2, reduction='sum', axis=[1, 2]) / z_2.shape[1],
                }[args.categorical_cost]
            else:
                def transition_loss_fn(z_1, z_2):
                    reduction = 'mean' if args.transition_density == TransitionDensity.DETERMINISTIC else 'sum'
                    return smooth_l1_loss(z_1, z_2, reduction=reduction, axis=[1, 2, 3])

            if args.w_balancing:
                # Wasserstein balancing: each term lets gradients flow through one side only
                transition_loss = \
                    args.w_balancing_weight * transition_loss_fn(z_prime_sampled, jax.lax.stop_gradient(z_prime)) + \
                    (1 - args.w_balancing_weight) * transition_loss_fn(jax.lax.stop_gradient(z_prime_sampled), z_prime)
            else:
                transition_loss = transition_loss_fn(z_prime_sampled, z_prime)

        # Reward loss
        if args.reward_loss_coef == 0.:
            reward_loss = 0.
        else:
            reward_pred = reward_network.apply(reward_network_params, (z, action))
            reward_loss = smooth_l1_loss(reward_pred, reward, reduction='none')
            if done_predictor is not None:
                done_dist = done_predictor.apply(done_params, (z_prime))
                # `subkey` is only needed by the sampling variant below (currently disabled)
                key, subkey = jax.random.split(key)
                # next_done_pred = done_dist.sample(seed=subkey).astype(jnp.float32)
                next_done_pred = done_dist.probs_parameter() + 1e-8
                done_loss = smooth_l1_loss(next_done_pred, next_done, reduction='none')
                reward_loss += done_loss

        # Gradient penalty
        if args.lambda_gp > 0.:
            # The penalty consumes randomness internally and does not hand back an
            # advanced key. Give it a dedicated sub-key so that the key returned
            # below is never one that has already been used.
            key, gp_key = jax.random.split(key)
            gradient_penalty = DeepSPIAgent.compute_gradient_penalty(
                params, z, action, z_prime, z_prime_sampled, gp_key,
                args, transition_network_params, reward_network_params,
                envs=envs,
                transition_network=transition_network,
                reward_network=reward_network,
                actor_fc=actor_fc,
                pi_sample_fn=pi_sample_fn,
            )

        else:
            gradient_penalty = 0.

        return transition_loss, reward_loss, gradient_penalty, key

    @staticmethod
    def _get_action_and_value2(
            params: FullParams,
            obs: np.ndarray,
            action: np.ndarray,
            reward: np.ndarray,
            done: np.array,
            next_obs: np.ndarray,
            next_done: np.ndarray,
            key: jax.random.PRNGKey,
            args: Args,
            actor: Actor,
            actor_conv: NetworkConv,
            actor_fc: NetworkFCOutput,
            critic: Critic,
            critic_conv: NetworkConv,
            critic_fc: NetworkFCOutput,
            compute_auxiliary_losses: Callable,
            latent_obs: bool = False,
    ):
        """Re-evaluate a batch of transitions under the current parameters.

        Encodes `obs` / `next_obs` with the actor encoder (or uses them directly
        as latents when `latent_obs` is True), then computes the log-probability
        of the given `action`, the policy entropy, the critic value and, unless
        both auxiliary coefficients are zero, the auxiliary losses returned by
        `compute_auxiliary_losses`.

        Returns:
            `(logprob, entropy, value, transition_loss, reward_loss,
            gradient_penalty, key)`; `key` is the one handed back by
            `compute_auxiliary_losses` when it is called, the input key otherwise.
        """
        agent_params = params.agent
        actor_conv_params, actor_fc_params = agent_params.actor_network_params
        critic_conv_params, critic_fc_params = agent_params.critic_network_params

        # Latent states for the current and next observation.
        if latent_obs:
            z, z_prime = obs, next_obs
        else:
            z = actor_conv.apply(actor_conv_params, obs)
            z_prime = actor_conv.apply(actor_conv_params, next_obs)

        # Append singleton axes to `done` so it broadcasts against z, then
        # substitute z for z_prime wherever `done` is set.
        done_mask = done.reshape(done.shape + (1,) * (z.ndim - 1))
        z_prime = jnp.where(done_mask, z, z_prime)

        policy_features = actor_fc.apply(actor_fc_params, z)

        if args.decoupled_repr:
            # The critic owns its encoder; in latent mode it consumes z directly.
            critic_latent = z if latent_obs else critic_conv.apply(critic_conv_params, obs)
            value_features = critic_fc.apply(critic_fc_params, critic_latent)
        else:
            value_features = policy_features

        logits = actor.apply(agent_params.actor_params, policy_features)
        batch_index = jnp.arange(action.shape[0])
        logprob = jax.nn.log_softmax(logits)[batch_index, action]

        # Entropy of the categorical policy, from normalised (log-prob) logits
        # clipped to the smallest finite value of their dtype.
        log_probs = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
        log_probs = log_probs.clip(min=jnp.finfo(log_probs.dtype).min)
        entropy = -(log_probs * jax.nn.softmax(log_probs)).sum(-1)

        # Auxiliary losses are skipped entirely when both coefficients are zero.
        aux_disabled = args.transition_loss_coef == 0. and args.reward_loss_coef == 0.
        if aux_disabled:
            transition_loss = jnp.array(0.0)
            reward_loss = jnp.array(0.0)
            gradient_penalty = jnp.array(0.0)
        else:
            transition_loss, reward_loss, gradient_penalty, key = compute_auxiliary_losses(
                params, obs, z, action, reward, z_prime, next_done, key)

        value = critic.apply(agent_params.critic_params, value_features).squeeze()
        return logprob, entropy, value, transition_loss, reward_loss, gradient_penalty, key

    @staticmethod
    def _compute_gae_once(carry, inp, gamma, gae_lambda):
        """One backward step of the GAE recursion, in `jax.lax.scan` form.

        `carry` is the advantage computed for the following time step and `inp`
        is the tuple `(nextdone, nextvalues, curvalues, reward)`. The updated
        advantage is returned twice: as the next carry and as the step output.
        """
        following_advantage = carry
        nextdone, next_value, current_value, reward = inp
        # Zero wherever the next step is flagged done: no bootstrapping there.
        not_done = 1.0 - nextdone

        td_error = reward + gamma * next_value * not_done - current_value
        advantage = td_error + gamma * gae_lambda * not_done * following_advantage
        return advantage, advantage

    @staticmethod
    def _compute_gae(
            agent_state: TrainState,
            next_obs: np.ndarray,
            next_done: np.ndarray,
            storage: Storage,
            args: Args,
            actor_conv: NetworkConv,
            actor_fc: NetworkFCOutput,
            critic: Critic,
            critic_conv: NetworkConv,
            critic_fc: NetworkFCOutput,
            compute_gae_once: Callable,
            latent_obs: bool = False,
    ):
        """Fill `storage.advantages` and `storage.returns` for a rollout.

        The value of `next_obs` is computed with the critic encoder when
        `args.decoupled_repr` is set and with the actor encoder otherwise
        (`next_obs` is used directly as the latent when `latent_obs` is True).
        It is appended to the stored values, and `compute_gae_once` is scanned
        backwards over the time axis.

        Returns:
            A copy of `storage` with `advantages` and `returns` replaced, where
            `returns = advantages + storage.values`.
        """
        agent_params = agent_state.params.agent
        actor_conv_params, actor_fc_params = agent_params.actor_network_params
        critic_conv_params, critic_fc_params = agent_params.critic_network_params

        # Pick the encoder stack that feeds the critic head.
        if args.decoupled_repr:
            conv, conv_params = critic_conv, critic_conv_params
            fc, fc_params = critic_fc, critic_fc_params
        else:
            conv, conv_params = actor_conv, actor_conv_params
            fc, fc_params = actor_fc, actor_fc_params

        latent = next_obs if latent_obs else conv.apply(conv_params, next_obs)
        features = fc.apply(fc_params, latent)
        bootstrap_value = critic.apply(agent_params.critic_params, features).squeeze()

        # Extend the time axis by one step with the bootstrap entries.
        all_dones = jnp.concatenate([storage.dones, next_done[None, :]], axis=0)
        all_values = jnp.concatenate([storage.values, bootstrap_value[None, :]], axis=0)

        initial_advantage = jnp.zeros((jnp.shape(all_values)[1],))
        scan_inputs = (all_dones[1:], all_values[1:], all_values[:-1], storage.rewards)
        _, advantages = jax.lax.scan(compute_gae_once, initial_advantage, scan_inputs, reverse=True)

        return storage.replace(
            advantages=advantages,
            returns=advantages + storage.values,
        )

    @staticmethod
    def _ppo_loss(
            params,
            obs,
            a,
            logp,
            mb_advantages,
            mb_returns,
            reward,
            done,
            next_obs,
            next_done,
            key,
            hist_logprobs,  # used by the V-trace weights and by the off-policy correction of the drift formulation
            args: Args,
            get_action_and_value2: Callable,
            off_policy_correction: bool = False,
            use_v_trace: bool = False
    ):
        """Compute the PPO objective (optionally in drift form) with auxiliary losses.

        Args:
            params: parameters forwarded to `get_action_and_value2`.
            obs, a, reward, done, next_obs, next_done: minibatch of transitions.
            logp: log-probabilities of `a` recorded with the minibatch.
            mb_advantages, mb_returns: advantage and return targets.
            key: PRNG key threaded through `get_action_and_value2`.
            hist_logprobs: log-probabilities used as the reference for the
                V-trace weights (`use_v_trace`) and for the importance ratios of
                the drift formulation (`off_policy_correction`).
            args: experiment configuration.
            get_action_and_value2: callable returning `(newlogprob, entropy,
                newvalue, transition_loss, reward_loss, gradient_penalty, key)`.
            off_policy_correction: apply importance ratios in the drift
                formulation (ignored when `use_v_trace` is True).
            use_v_trace: weight the auxiliary losses by clipped V-trace ratios.

        Returns:
            `(loss, aux)` where `aux` is `(pg_loss, v_loss, entropy_loss,
            approx_kl, drift_penalty_mean, transition_loss, reward_loss,
            gradient_penalty, key)`.
        """
        newlogprob, entropy, newvalue, transition_loss, reward_loss, gradient_penalty, key = get_action_and_value2(
            params, obs, a, reward, done, next_obs, next_done, key)

        # A disabled auxiliary loss can come back as a plain Python float
        # (e.g. `0.`), which has no `.mean()`; normalise both to arrays.
        reward_loss = jnp.asarray(reward_loss)
        transition_loss = jnp.asarray(transition_loss)

        logratio = newlogprob - logp
        ratio = jnp.exp(logratio)
        approx_kl = ((ratio - 1) - logratio).mean()

        if not args.piecewise_auxiliary_ratio:
            reward_loss = reward_loss.mean()
            transition_loss = transition_loss.mean()
        if use_v_trace:
            # V-trace style correction
            v_trace_log_clip = getattr(args, "vtrace_log_ratio_clip", 3.)
            rho = jnp.exp(jnp.clip(logp - hist_logprobs, -v_trace_log_clip, v_trace_log_clip))
            reward_loss = rho * reward_loss
            transition_loss = rho * transition_loss

        auxiliary_loss = args.reward_loss_coef * reward_loss + args.transition_loss_coef * transition_loss

        if args.deep_mdp:
            # don't fold the auxiliary loss into the advantages; add it to the total loss later
            deep_mdp_auxiliary_loss = auxiliary_loss
            auxiliary_loss = 0.
        else:
            deep_mdp_auxiliary_loss = 0.

        if args.norm_adv:
            mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

        clipped_ratio = jnp.clip(ratio, 1.0 - args.clip_coef, 1.0 + args.clip_coef)
        # Advantages with the auxiliary losses subtracted (the utility being maximised).
        penalized_advantages = mb_advantages - auxiliary_loss

        if args.drift_formulation:
            apply_is_correction = off_policy_correction and not use_v_trace
            is_alive = 1.0 - done.astype(jnp.float32)

            def masked_is_ratio(log_is_ratio):
                # Cap the log-ratio at 7 to avoid exploding importance ratios, and
                # force the ratio to 1 where `done` is set (no action nor
                # correction to be applied in a terminal state).
                is_ratio = jnp.exp(jnp.minimum(log_is_ratio, 7.0))
                return is_ratio * is_alive + (1.0 - is_alive)

            if apply_is_correction:
                weighted_advantage = masked_is_ratio(newlogprob - hist_logprobs) * penalized_advantages
            else:
                weighted_advantage = ratio * penalized_advantages

            # Apply drift penalty \mathfrak{D}_{π_n}(π_{n + 1} | s)
            drift_penalty = jax.nn.relu((ratio - clipped_ratio) * penalized_advantages)
            if apply_is_correction:
                drift_penalty = masked_is_ratio(logp - hist_logprobs) * drift_penalty

            # Subtract the drift term from the utility
            pg_loss = (-weighted_advantage + args.drift_coef * drift_penalty).mean()
            drift_penalty_mean = drift_penalty.mean()
        else:
            # Standard clipped PPO
            pg_loss1 = -penalized_advantages * ratio
            pg_loss2 = -penalized_advantages * clipped_ratio
            pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()
            drift_penalty_mean = jnp.array(0.0)

        # Value loss
        if args.use_huber:
            v_loss = jnp.mean(optax.huber_loss(newvalue - mb_returns, delta=1.0))
        else:
            v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()

        # Scalar versions of the auxiliary losses, reported in the aux output.
        reward_loss = reward_loss.mean()
        transition_loss = transition_loss.mean()
        entropy_loss = entropy.mean()
        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + args.lambda_gp * gradient_penalty

        if args.deep_mdp:
            loss += deep_mdp_auxiliary_loss.mean()

        return loss, (
            pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl),
            drift_penalty_mean, transition_loss, reward_loss, gradient_penalty, key)

    @staticmethod
    def _mean_discrepancy(
            params: FullParams,
            obs: np.ndarray,
            action: np.ndarray,
            z: np.ndarray,
            z_prime: np.ndarray,
            z_prime_sampled: np.ndarray,
            # to be fixed
            discriminator: Union[LipschitzDiscriminator, BoundedDiscriminator],
    ):
        """
        Discriminator-based discrepancy between encoded and sampled next latents.

        The discriminator scores `(obs, action, z, z_prime)` and
        `(obs, action, z, z_prime_sampled)`; the value returned is the negated
        batch mean of the score difference (first minus second).
        With a 1-Lipschitz discriminator whose parameters are maximized, the mean
        difference is the Wasserstein distance; with a 1/2-bounded one, it is the
        total variation distance.
        """
        def score(next_latent):
            return discriminator.apply(params.discriminator, (obs, action, z, next_latent))

        score_gap = score(z_prime) - score(z_prime_sampled)
        return -jnp.mean(score_gap)

    @staticmethod
    def _update_ppo(
            agent_state: TrainState,
            storage: Storage,
            key: jax.random.PRNGKey,
            args: Args,
            ppo_loss_grad_fn: Callable,
    ):
        """Run `args.update_epochs` epochs of minibatch updates over a rollout.

        Each epoch draws a fresh permutation of the flattened rollout, splits it
        into `args.num_minibatches` minibatches and applies one gradient step per
        minibatch.

        Args:
            agent_state: train state to update.
            storage: rollout data; the two leading axes of every leaf are merged.
            key: PRNG key, split once per epoch for the shuffle and threaded
                through `ppo_loss_grad_fn`.
            args: experiment configuration.
            ppo_loss_grad_fn: callable returning `((loss, aux), grads)`, where the
                last element of `aux` is the updated PRNG key.

        Returns:
            `(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl,
            drift_penalty_mean, transition_loss, reward_loss, gradient_penalty,
            grads, key)`; every metric (and `grads`) is stacked over epochs and
            minibatches by the two nested scans.
        """
        def flatten(x):
            return x.reshape((-1,) + x.shape[2:])

        # Flattening does not depend on the epoch, so do it once up front.
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        flatten_storage = jax.tree_util.tree_map(flatten, storage)

        def update_minibatch(carry, minibatch):
            agent_state, key = carry
            (loss, aux), grads = ppo_loss_grad_fn(
                agent_state.params,
                minibatch.obs,
                minibatch.actions,
                minibatch.logprobs,
                minibatch.advantages,
                minibatch.returns,
                minibatch.rewards,
                minibatch.dones,
                minibatch.next_obs,
                minibatch.next_dones,
                key,
            )
            (pg_loss, v_loss, entropy_loss, approx_kl, drift_penalty_mean,
             transition_loss, reward_loss, gradient_penalty, key) = aux
            agent_state = agent_state.apply_gradients(grads=grads)
            metrics = (
                loss, pg_loss, v_loss, entropy_loss, approx_kl, drift_penalty_mean,
                transition_loss, reward_loss, gradient_penalty, grads)
            return (agent_state, key), metrics

        def update_epoch(carry, unused_inp):
            agent_state, key = carry
            key, subkey = jax.random.split(key)

            # taken from: https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py
            def convert_data(x: jnp.ndarray):
                # The same subkey is used for every leaf, so all fields are
                # shuffled with the same permutation.
                x = jax.random.permutation(subkey, x)
                x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:])
                return x

            shuffled_storage = jax.tree_util.tree_map(convert_data, flatten_storage)
            return jax.lax.scan(update_minibatch, (agent_state, key), shuffled_storage)

        (agent_state, key), metrics = jax.lax.scan(
            update_epoch, (agent_state, key), (), length=args.update_epochs
        )
        return (agent_state, *metrics, key)

    @staticmethod
    def _step_once(carry, step, env_step_fn, get_action_and_value_fn):
        """Advance the environments by one step; scan body for the rollout.

        Unpacks `carry` as `(agent_state, episode_stats, obs, done, key, handle)`,
        selects an action with `get_action_and_value_fn`, steps with
        `env_step_fn` and returns the updated carry together with a `Storage`
        record of the transition. `returns` and `advantages` are zero-filled
        placeholders shaped like the reward.
        """
        # based on https://github.dev/google/evojax/blob/0625d875262011d8e1b6aa32566b236f44b4da66/evojax/sim_mgr.py
        agent_state, episode_stats, obs, done, key, handle = carry
        action, logprob, value, key = get_action_and_value_fn(agent_state, obs, key)

        episode_stats, handle, env_output = env_step_fn(episode_stats, handle, action)
        next_obs, reward, next_done, _ = env_output

        zeros = jnp.zeros_like(reward)
        transition = Storage(
            obs=obs,
            actions=action,
            logprobs=logprob,
            dones=done,
            values=value,
            rewards=reward,
            returns=zeros,
            advantages=zeros,
            next_obs=next_obs,
            next_dones=next_done,
        )
        next_carry = (agent_state, episode_stats, next_obs, next_done, key, handle)
        return next_carry, transition

    @staticmethod
    def _rollout(agent_state, episode_stats, next_obs, next_done, key, handle, step_once_fn, max_steps):
        """Scan `step_once_fn` for `max_steps` steps and collect its outputs.

        Returns `(agent_state, episode_stats, next_obs, next_done, storage, key,
        handle)`, where `storage` holds the per-step outputs stacked by the scan
        and the other entries come from the final carry.
        """
        initial_carry = (agent_state, episode_stats, next_obs, next_done, key, handle)
        final_carry, storage = jax.lax.scan(step_once_fn, initial_carry, xs=(), length=max_steps)
        agent_state, episode_stats, next_obs, next_done, key, handle = final_carry
        return agent_state, episode_stats, next_obs, next_done, storage, key, handle

    @staticmethod
    def _build_networks(
            args: Args,
            envs: gym.vector.VectorEnv,
            transition_key: jax.random.PRNGKey,
            reward_key: jax.random.PRNGKey,
            network_key: jax.random.PRNGKey,
            actor_key: jax.random.PRNGKey,
            critic_key: jax.random.PRNGKey,
            use_done_predictor: bool = False,
            semi_coupled_repr: bool = False,
    ) -> Dict:
        """Create, initialise and JIT-wrap every network used by the agent.

        Args:
            args: experiment configuration (architecture switches, optimizer
                hyper-parameters, ...).
            envs: vectorised environment; only its observation and action spaces
                are read.
            transition_key: PRNG key for the transition network; split to also
                seed the discriminator when `args.wasserstein_discriminator` is set.
            reward_key: PRNG key for the reward network; split to also seed the
                done predictor when `use_done_predictor` is True.
            network_key: PRNG key for the convolutional encoder(s).
            actor_key: PRNG key for the actor FC encoder and the actor head.
            critic_key: PRNG key for the critic head, and for the critic FC
                encoder when `args.decoupled_repr` is set.
            use_done_predictor: also build and initialise a `DonePredictor`.
            semi_coupled_repr: with `args.decoupled_repr`, make the critic reuse
                the actor's conv encoder (module and parameters) and keep only
                its FC encoder separate.

        Returns:
            Dict with keys "actor", "actor_fc", "actor_conv", "critic",
            "critic_fc", "critic_conv", "transition_network", "reward_network",
            "done_predictor", "discriminator" (the last two may be None) and
            "train_state".
        """
        linear_schedule_ = partial(linear_schedule, args=args)
        dummy_input = np.array([envs.single_observation_space.sample()])
        activation = args.activation
        num_actions = envs.action_space.n
        is_categorical = args.transition_density == TransitionDensity.CATEGORICAL

        # ------------------------------------------------------------------
        # Encoder factories
        # ------------------------------------------------------------------
        def make_plain_conv():
            return NetworkConv(
                use_layer_norm=args.use_layer_norm,
                layer_norm_output=args.layer_norm_cnn_output,
                activation=activation)

        def make_fc():
            if args.dreamer_architecture:
                # same as dreamer-v2 config
                return NetworkFCOutput(
                    layers=(400, 400, 400, 400),
                    activation=activation,
                    use_layer_norm=args.use_layer_norm,
                    hadamard_representation=args.hadamard_representation)
            return NetworkFCOutput(
                use_layer_norm=args.use_layer_norm,
                activation=activation,
                hadamard_representation=args.hadamard_representation)

        if is_categorical:
            def actor_conv_encoder_cls():
                return flax.linen.Sequential([
                    NetworkConv(use_layer_norm=args.use_layer_norm, activation=activation),
                    CategoricalEncoder(n_cat=args.num_categories, n_cls=args.num_classes),
                ])

            if args.use_attention:
                assert args.num_categories == args.num_classes == 32, \
                    "Attention-based encoder currently only supports 32 categories and 32 classes."
                actor_fc_encoder_cls = NetworkAttentionOutput
            else:
                actor_fc_encoder_cls = make_fc

            if args.decoupled_repr:
                conv_encoder_cls = actor_conv_encoder_cls if semi_coupled_repr else make_plain_conv
                fc_encoder_cls = make_fc
            else:
                conv_encoder_cls = actor_conv_encoder_cls
                fc_encoder_cls = actor_fc_encoder_cls
        else:
            actor_conv_encoder_cls = conv_encoder_cls = make_plain_conv
            actor_fc_encoder_cls = fc_encoder_cls = make_fc

        # ------------------------------------------------------------------
        # World model (transition + reward networks)
        # ------------------------------------------------------------------
        if args.lipschitz_nets:
            transition_network = LipDiscreteActionTransitionCNN(
                num_actions=num_actions,
                density=TransitionDensity(args.transition_density),
                feature_group=args.use_feature_group)
            reward_network = LipDiscreteActionRewardNetwork(num_actions=num_actions)
        elif is_categorical and args.auxiliary_task_net_type == NetType.SOFTMOE:
            transition_network = DiscreteActionTransitionNetworkSoftMoE(
                num_actions=num_actions,
                num_experts=2 * num_actions,
                gumbel_softmax=args.use_gumbel_softmax)
            reward_network = DiscreteActionRewardNetworkSoftMoE(
                num_actions=num_actions,
                num_experts=2 * num_actions)
        elif is_categorical and args.auxiliary_task_net_type == NetType.FC:
            hidden = 400 if args.dreamer_architecture else 512
            layers = 4 if args.dreamer_architecture else 1
            transition_network = DiscreteActionTransitionNetwork(
                num_actions=num_actions,
                n_cat=args.num_categories,
                n_cls=args.num_classes,
                hidden=hidden,
                layers=layers,
                density=TransitionDensity(args.transition_density),
                gumbel_softmax=args.use_gumbel_softmax,
                use_layer_norm=args.use_layer_norm,
                activation=activation)
            reward_network = DiscreteActionRewardNetwork(
                hidden=hidden,
                layers=layers,
                num_actions=num_actions,
                use_layer_norm=args.use_layer_norm,
                clip_rewards=args.clip_reward_net_output,
                activation=activation)
        elif is_categorical and args.auxiliary_task_net_type == NetType.TRANSFORMER:
            transition_network = AutoregressiveDiscreteActionTransitionTransformer(
                num_actions=num_actions,
                num_classes=args.num_classes,
                num_categories=args.num_categories,
                density=TransitionDensity(args.transition_density),
                gumbel_softmax=args.use_gumbel_softmax)
            reward_network = DiscreteActionRewardNetwork(num_actions=num_actions)
        else:
            transition_network = DiscreteActionTransitionCNN(
                num_actions=num_actions,
                density=TransitionDensity(args.transition_density),
                gumbel_softmax=args.use_gumbel_softmax,
                feature_group=args.use_feature_group,
                activation=activation)
            reward_network = DiscreteActionRewardNetwork(
                num_actions=num_actions,
                use_embedding=False,
                activation=activation)

        # ------------------------------------------------------------------
        # Optional discriminator and done predictor
        # ------------------------------------------------------------------
        if args.wasserstein_discriminator:
            transition_key, discr_key = jax.random.split(transition_key)
            discriminator_cls = BoundedDiscriminator if is_categorical else LipschitzDiscriminator
            discriminator = discriminator_cls(
                num_actions=num_actions,
                use_cnn=args.discriminator_obs_encoder)
        else:
            discriminator = discr_key = None

        if use_done_predictor:
            reward_key, done_key = jax.random.split(reward_key)
            done_predictor = DonePredictor(activation=activation, use_layer_norm=args.use_layer_norm)
        else:
            done_predictor = done_key = None

        def _dummy_actions(conv_out):
            # One integer action per dummy latent, like the actions seen at run time.
            return jnp.zeros((conv_out.shape[0],), dtype=jnp.int32)

        def _initialize_world_model_params(conv_out):
            dummy_transition = (conv_out, _dummy_actions(conv_out))
            done_predictor_params = done_predictor.init(done_key, conv_out) if use_done_predictor else None
            return WorldModelParams(
                transition_network_params=transition_network.init(transition_key, dummy_transition),
                reward_network_params=reward_network.init(reward_key, dummy_transition),
                done_predictor_params=done_predictor_params,
            )

        def _initialize_discriminator(conv_out):
            if not args.wasserstein_discriminator:
                return None
            return discriminator.init(
                discr_key,
                (dummy_input, _dummy_actions(conv_out), conv_out, conv_out),
            )

        def _create_train_state(agent_params, world_model_params, discriminator_params):
            # AdamW for the transformer transition model, Adam otherwise.
            use_adamw = is_categorical and args.auxiliary_task_net_type == NetType.TRANSFORMER
            optimizer = optax.adamw if use_adamw else optax.adam
            return TrainState.create(
                apply_fn=None,
                params=FullParams(
                    agent=agent_params, world_model=world_model_params, discriminator=discriminator_params),
                tx=optax.chain(
                    optax.clip_by_global_norm(args.max_grad_norm),
                    optax.inject_hyperparams(optimizer)(
                        learning_rate=linear_schedule_ if args.anneal_lr else args.learning_rate, eps=1e-5)))

        # ------------------------------------------------------------------
        # Actor / critic encoders and heads
        # ------------------------------------------------------------------
        actor = Actor(action_dim=envs.single_action_space.n)
        critic = Critic()

        # Without `decoupled_repr` the actor factories are the very same objects
        # as the shared ones, so this also builds the shared encoder.
        actor_conv = actor_conv_encoder_cls()
        actor_fc = actor_fc_encoder_cls()
        actor_conv_params = actor_conv.init(network_key, dummy_input)
        actor_conv_out = actor_conv.apply(actor_conv_params, dummy_input)
        actor_fc_params = actor_fc.init(actor_key, actor_conv_out)
        actor_features = actor_fc.apply(actor_fc_params, actor_conv_out)

        if args.decoupled_repr:
            if semi_coupled_repr:
                # Semi-coupled representation: use actor conv output as input to critic
                # useful when planning in the latent space
                # (the obs for the output *is* the actor conv output = latent state)
                critic_conv = actor_conv
                critic_conv_params = actor_conv_params
                critic_conv_out = actor_conv_out
            else:
                critic_conv = conv_encoder_cls()
                # NOTE(review): `network_key` is reused here, so when both conv
                # encoders share an architecture they start from identical
                # weights; likewise `actor_key` / `critic_key` each seed an FC
                # encoder and a head. Confirm this is intended.
                critic_conv_params = critic_conv.init(network_key, dummy_input)
                critic_conv_out = critic_conv.apply(critic_conv_params, dummy_input)
            critic_fc = fc_encoder_cls()
            critic_fc_params = critic_fc.init(critic_key, critic_conv_out)
            critic_features = critic_fc.apply(critic_fc_params, critic_conv_out)
        else:
            # Shared encoder: the critic entries alias the actor ones
            # (placeholders, but needed for shape).
            critic_conv, critic_fc = actor_conv, actor_fc
            critic_conv_params, critic_fc_params = actor_conv_params, actor_fc_params
            critic_features = actor_features

        world_model_params = _initialize_world_model_params(actor_conv_out)
        discriminator_params = _initialize_discriminator(actor_conv_out)
        agent_params = AgentParams(
            actor_network_params=(actor_conv_params, actor_fc_params),
            critic_network_params=(critic_conv_params, critic_fc_params),
            actor_params=actor.init(actor_key, actor_features),
            critic_params=critic.init(critic_key, critic_features),
        )
        agent_state = _create_train_state(agent_params, world_model_params, discriminator_params)

        # ------------------------------------------------------------------
        # JIT all applies (each distinct module exactly once)
        # ------------------------------------------------------------------
        actor_conv.apply = jax.jit(actor_conv.apply)
        actor_fc.apply = jax.jit(actor_fc.apply)
        if critic_conv is not actor_conv:
            critic_conv.apply = jax.jit(critic_conv.apply)
        if critic_fc is not actor_fc:
            critic_fc.apply = jax.jit(critic_fc.apply)

        transition_network.apply = jax.jit(transition_network.apply)
        reward_network.apply = jax.jit(reward_network.apply)
        if use_done_predictor:
            done_predictor.apply = jax.jit(done_predictor.apply)
        if args.wasserstein_discriminator:
            discriminator.apply = jax.jit(discriminator.apply)

        return {
            "actor": actor,
            "actor_fc": actor_fc,
            "actor_conv": actor_conv,
            "critic": critic,
            "critic_fc": critic_fc,
            "critic_conv": critic_conv,
            "transition_network": transition_network,
            "reward_network": reward_network,
            "done_predictor": done_predictor,
            "discriminator": discriminator,
            "train_state": agent_state,
        }


def train(
        agent: DeepSPIAgent,
        envs: gym.vector.VectorEnv,
        handle,
        args: Args,
        writer: SummaryWriter,
        key: jax.random.PRNGKey
) -> DeepSPIAgent:
    """Run the training loop and return the agent carrying its final train state.

    Each of the ``args.num_iterations`` iterations calls, in order,
    ``agent.rollout``, ``agent.compute_gae`` and ``agent.update_ppo``, then
    reports the resulting metrics through ``log``. ``global_step`` advances by
    ``args.num_steps * args.num_envs`` per iteration.

    When ``args.toy_confounding_env`` is set, every iteration additionally:
      * evaluates the policy on the full observation space (actor_conv ->
        actor_fc -> actor -> softmax) and passes it to ``envs.log``;
      * writes ``charts/init_value`` and ``charts/repr_dist`` to ``writer``;
      * logs a value histogram/vector and a policy image to wandb, if wandb
        could be imported and ``args.track`` is set.

    Args:
        agent: Agent exposing ``rollout``, ``compute_gae``, ``update_ppo``,
            ``replace`` and the initial ``train_state``.
        envs: Vectorized environment; it is reset once at the start. In toy
            mode it must also provide ``log`` and ``spec.critical_states``.
        handle: Environment handle threaded through ``agent.rollout``.
        args: Experiment configuration.
        writer: TensorBoard writer receiving the scalar logs.
        key: PRNG key threaded through ``agent.rollout`` and
            ``agent.update_ppo``.

    Returns:
        ``agent.replace(train_state=...)`` with the train state obtained after
        the last iteration.
    """
    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs = envs.reset()
    next_done = jnp.zeros(args.num_envs, dtype=jax.numpy.bool_)

    # wandb is optional here: any failure to import it simply disables the
    # wandb-only diagnostics further down (deliberate best-effort).
    try:
        import wandb
    except Exception:
        wandb = None

    full_obs_space = get_full_obs_space_confounding_env(envs) if args.toy_confounding_env else None

    # The policy-matrix function does not depend on the iteration, so it is
    # built once up front instead of lazily inside the loop.
    policy_net_to_matrix = None
    if args.toy_confounding_env:
        actor = agent.actor

        def _policy_net_to_matrix(params: FullParams) -> jnp.ndarray:
            z = agent.actor_conv.apply(params.agent.actor_network_params[0], full_obs_space)  # (S, cat, cls)
            actor_hidden = agent.actor_fc.apply(params.agent.actor_network_params[1], z)  # (S, H)
            logits = actor.apply(params.agent.actor_params, actor_hidden)  # (S, A)
            logprob = jax.nn.log_softmax(logits)  # (S, A)
            return jnp.exp(logprob)  # (S, A)

        policy_net_to_matrix = jax.jit(_policy_net_to_matrix)

    episode_stats = EpisodeStatistics(
        episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32),
        episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32),
        returned_episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32),
        returned_episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32),
    )

    # Get the cleanRL ppo scores (only the mean is forwarded to `log`).
    csv_path = args.compare_scores_csv
    score_col = args.compare_scores_column
    scores_mean, _ = load_env_mean_std(csv_path, score_col)
    # `np.inf`, not the `np.infty` alias: the alias was removed in NumPy 2.0.
    max_avg_return = -np.inf
    agent_state = agent.train_state

    for iteration in range(1, args.num_iterations + 1):
        iteration_time_start = time.time()
        agent_state, episode_stats, next_obs, next_done, storage, key, handle = agent.rollout(
            agent_state, episode_stats, next_obs, next_done, key, handle
        )
        global_step += args.num_steps * args.num_envs
        storage = agent.compute_gae(agent_state, next_obs, next_done, storage)
        (
            agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, drift_penalty_mean,
            transition_loss, reward_loss, gradient_penalty, grads, key
        ) = agent.update_ppo(
            agent_state,
            storage,
            key,
        )
        avg_episodic_return = np.mean(jax.device_get(episode_stats.returned_episode_returns))
        max_avg_return = np.max([max_avg_return, avg_episodic_return])

        log(
            global_step, avg_episodic_return, max_avg_return, agent_state,
            v_loss, pg_loss, entropy_loss, approx_kl, drift_penalty_mean,
            transition_loss, reward_loss, scores_mean, gradient_penalty,
            start_time, iteration_time_start, episode_stats, loss, args, writer, grads)

        if args.toy_confounding_env:
            params = agent_state.params
            pi = policy_net_to_matrix(params)  # (S, A)
            extra_logs = envs.log(pi, use_value_iteration=True, gamma=args.gamma, num_iters=1000)
            writer.add_scalar("charts/init_value", extra_logs["value_s_init"], global_step)
            print('value at s_init:', extra_logs["value_s_init"])

            # repr collapse: mean squared distance between the encodings of the
            # two critical states exposed by the environment spec.
            s1, s2 = envs.spec.critical_states
            z = agent.actor_conv.apply(params.agent.actor_network_params[0], full_obs_space)
            z1, z2 = z[s1], z[s2]
            repr_dist = jnp.mean(jnp.square(z1 - z2))
            writer.add_scalar("charts/repr_dist", np.asarray(repr_dist), global_step)

            if wandb is not None and args.track:
                values = extra_logs["value_full"]
                # Host copy of the policy matrix; only the wandb image needs it.
                pi_np = np.asarray(pi)
                wandb.log(
                    {
                        "charts/value_full_hist": wandb.Histogram(values),
                        # optional: log the raw vector as well (nice in tables)
                        "charts/value_full_vector": values,
                        # "charts/policy_matrix": pi_np,
                        "charts/policy_image": wandb.Image(pi_np, caption="Policy π(a|s)"),
                    },
                    step=global_step,
                )

    # update the agent's training state to its final state
    return agent.replace(train_state=agent_state)


def get_full_obs_space_confounding_env(env: gym.vector.VectorEnv) -> jnp.ndarray:
    """
    Render one observation per state of the confounding grid and stack them.

    The number of states is read from the first dimension of ``env.spec.P``.
    Each state index ``s`` is rendered through ``confounding_grid.render_obs``
    from an ``EnvState`` with ``done=False`` and ``step=0``.

    Returns:
        full_obs_space: JAX array with the states stacked along axis 0.
            NOTE(review): presumably (S, 1, 84, 84) uint8 going by the earlier
            inline notes on this function — confirm against ``render_obs``.
    """
    spec = env.spec
    num_states = spec.P.shape[0]

    def _render_state(state_index: int):
        # Fresh, non-terminal state at step 0 for the given state index.
        state = confounding_grid.EnvState(
            s=state_index, done=jnp.array(False), step=jnp.int32(0)
        )
        return confounding_grid.render_obs(spec, state)

    rendered = [_render_state(state_index) for state_index in range(num_states)]
    return jnp.stack(rendered, axis=0)


# Script entry point: parse the CLI arguments, set up experiment tracking and
# seeding, build the environments and the agent, then run training.
if __name__ == "__main__":
    args = tyro.cli(Args)
    check_and_process_args(args)
    if args.wasserstein_discriminator:
        raise NotImplementedError(
            "Wasserstein discriminator is not implemented for pure representation learning agents yet. "
            "Please set --no_wasserstein_discriminator to use DeepSPI.")
    if args.toy_confounding_env:
        args.env_id = "toy_confounding_env"
        # Applied here, before anything is logged, so that the hyperparameters
        # recorded by wandb / TensorBoard match the configuration actually used.
        args.reward_clip = False

    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
            tags=args.wandb_tags
        )
        wandb.config.update(vars(args), allow_val_change=True)

        if args.hp_tuning_mode:
            # Log the final args, including any changes made by your code
            for k, v in vars(args).items():
                wandb.run.summary[f"hyperparam/{k}"] = v

    writer = SummaryWriter(f"runs/{run_name}")
    try:
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )

        # TRY NOT TO MODIFY: seeding
        random.seed(args.seed)
        np.random.seed(args.seed)
        key = jax.random.PRNGKey(args.seed)

        # env setup
        if args.toy_confounding_env:
            envs = ConfoundingGridEnvPoolLike(
                num_envs=args.num_envs,
                seed=args.seed,
                n_path=args.toy_confounding_env_n_path,
                epsilon=args.toy_confounding_env_epsilon,
                gamma=args.gamma,
            )
        else:
            envs = make_env(args.env_id, args.seed, args.num_envs, args.reward_clip, args.stochastic_env)()

        try:
            handle, recv, send, step_env = envs.xla()
            assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

            # NOTE(review): the same `key` is handed to both `DeepSPIAgent.create`
            # and `train`. JAX advises against reusing a PRNG key; confirm whether
            # `create` splits it internally before changing this (doing so would
            # alter the random streams of existing experiments).
            agent = DeepSPIAgent.create(
                args=args,
                envs=envs,
                raw_step_env=step_env,
                key=key)

            _ = train(agent, envs, handle, args, writer, key)
        finally:
            # Release the environments even if setup or training raised.
            envs.close()
    finally:
        # Close the writer so pending events are flushed even on failure.
        writer.close()
