import optax
from basics.layers import create_mlp, FourierFeatureNetwork, PReLU, HyperNetwork
import jax
import jax.numpy as jnp
from flax import nnx
from typing import Optional
from functools import partial


class MLP(nnx.Module):
    """Residual feed-forward block.

    Applies Linear -> LayerNorm -> PReLU -> Linear and adds the block input
    back onto the result, so input and output share the last dimension
    ``d_model``.
    """

    def __init__(self, d_model: int, d_ff: int, *, rngs):
        """
        :param d_model: size of the input/output feature dimension
        :param d_ff: size of the hidden feature dimension
        :param rngs: random number streams handed to the nnx layers
        """
        self.fc1 = nnx.Linear(d_model, d_ff, rngs=rngs)
        self.ln = nnx.LayerNorm(d_ff, rngs=rngs)
        self.fc2 = nnx.Linear(d_ff, d_model, rngs=rngs)
        self.prelu = PReLU()

    def __call__(self, x):
        """
        :param x: array whose last dimension is ``d_model``
        :return: transformed array plus the skip connection from ``x``
        """
        hidden = self.prelu(self.ln(self.fc1(x)))
        return self.fc2(hidden) + x


class DecoderLayer(nnx.Module):
    """Attention block followed by a residual MLP, each wrapped in add-and-norm."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, *, rngs):
        """
        :param d_model: feature size used for queries, keys, values and outputs
        :param n_heads: number of attention heads
        :param d_ff: hidden size of the feed-forward block
        :param rngs: random number streams handed to the nnx layers
        """
        self.attn = nnx.MultiHeadAttention(
            num_heads=n_heads,
            in_features=d_model,
            qkv_features=d_model,
            out_features=d_model,
            use_bias=False,
            rngs=rngs,
        )
        self.norm1 = nnx.LayerNorm(d_model, rngs=rngs)
        self.norm2 = nnx.LayerNorm(d_model, rngs=rngs)
        self.ff = MLP(d_model, d_ff, rngs=rngs)

    def __call__(self, q, k, v, mask: Optional[jnp.ndarray] = None):
        """
        :param q: query array; its second-to-last axis sizes the default mask
        :param k: key array
        :param v: value array
        :param mask: optional attention mask; a lower-triangular boolean mask is used when omitted
        :return: normalised output of the attention + feed-forward stack
        """
        if mask is None:
            # Default mask is square and sized by the query length only.
            n_query = q.shape[-2]
            mask = jnp.tril(jnp.ones((n_query, n_query), dtype=jnp.bool))

        attended = self.attn(q, k, v, mask=mask, decode=False)
        hidden = self.norm1(q + attended)
        return self.norm2(hidden + self.ff(hidden))


class TransformerDecoder(nnx.Module):
    """Small decoder stack: one cross-attention layer, then self-attention layers.

    Both inputs are projected from ``in_features`` to ``d_model``; the result
    is projected back to ``in_features`` and added to the raw ``tokens``.
    """

    def __init__(self,
                 in_features,
                 num_layers: int, d_model: int, n_heads: int, d_ff: int, *, rngs):
        """
        :param in_features: feature size of ``tokens``/``previous`` and of the output
        :param num_layers: number of self-attention decoder layers
        :param d_model: internal feature size
        :param n_heads: number of attention heads per layer
        :param d_ff: hidden size of each layer's feed-forward block
        :param rngs: random number streams handed to the nnx layers
        """
        self.in_layer = nnx.Linear(in_features, d_model, rngs=rngs)

        self.proj_previous = nnx.Linear(in_features, d_model, rngs=rngs)
        self.cross_attention = DecoderLayer(d_model, n_heads, d_ff, rngs=rngs)
        self.layers = [DecoderLayer(d_model, n_heads, d_ff, rngs=rngs)
                       for _ in range(num_layers)]
        self.norm_2 = nnx.LayerNorm(d_model, rngs=rngs)
        self.norm = nnx.LayerNorm(d_model, rngs=rngs)
        self.fc_out = nnx.Linear(d_model, in_features, rngs=rngs)

    def __call__(self, tokens: jax.Array, previous: jax.Array):
        """
        :param tokens: array with last dimension ``in_features``; used as keys/values
        :param previous: array with last dimension ``in_features``; used as the query
        :return: ``fc_out`` projection of the decoded sequence plus ``tokens``
        """
        hidden = self.in_layer(tokens)
        query = self.proj_previous(previous)
        # `previous` queries the embedded tokens.
        hidden = self.cross_attention(query, hidden, hidden)
        for block in self.layers:
            attended = block(hidden, hidden, hidden)
            hidden = self.norm_2(self.norm(attended) + hidden)
        return self.fc_out(hidden) + tokens


class PositionalEncoding(nnx.Module):
    """Adds a fixed sinusoidal position table to a ``(batch, seq, d_model)`` input."""

    def __init__(self, d_model, max_len=5000):
        """
        :param d_model: feature size of the table (even or odd)
        :param max_len: number of positions stored in the table
        """
        position = jnp.arange(max_len)[:, None]
        div_term = jnp.exp(jnp.arange(0, d_model, 2) * (-jnp.log(10000.0) / d_model))
        # (max_len, ceil(d_model / 2)): one frequency per even column.
        angles = position * div_term
        pe = jnp.zeros((max_len, d_model))
        pe = pe.at[:, 0::2].set(jnp.sin(angles))
        # There are only floor(d_model / 2) odd columns; slicing keeps the write
        # valid for odd d_model and is a no-op for even d_model.
        pe = pe.at[:, 1::2].set(jnp.cos(angles[:, : d_model // 2]))
        self.pe = nnx.Variable(pe[None, :, :])  # shape: (1, max_len, d_model)

    def __call__(self, x: jax.Array):
        """
        :param x: array indexed as (batch, seq, d_model); axis 1 selects how many positions are added
        :return: ``x`` plus the first ``seq`` rows of the position table
        """
        seq_len = x.shape[1]
        return x + self.pe[:, :seq_len, :]


class TransformerDecoderPositionalEncoding(nnx.Module):
    """Decoder stack that adds sinusoidal positions to the embedded tokens.

    One cross-attention layer is followed by ``num_layers`` self-attention
    layers; the result is projected back to ``in_features`` and added to the
    raw ``tokens``.
    """

    def __init__(self,
                 in_features,
                 num_layers: int, d_model: int, n_heads: int, d_ff: int, *,
                 max_len: int = 100, rngs):
        """
        :param in_features: feature size of ``tokens``/``previous`` and of the output
        :param num_layers: number of self-attention decoder layers
        :param d_model: internal feature size
        :param n_heads: number of attention heads per layer
        :param d_ff: hidden size of each layer's feed-forward block
        :param max_len: number of positions in the positional table (defaults to 100)
        :param rngs: random number streams handed to the nnx layers
        """
        self.in_layer = nnx.Linear(in_features, d_model, rngs=rngs)

        self.proj_previous = nnx.Linear(in_features, d_model, rngs=rngs)
        self.cross_attention = DecoderLayer(d_model, n_heads, d_ff, rngs=rngs)
        self.layers = [DecoderLayer(d_model, n_heads, d_ff, rngs=rngs)
                       for _ in range(num_layers)]
        self.norm_2 = nnx.LayerNorm(d_model, rngs=rngs)
        self.norm = nnx.LayerNorm(d_model, rngs=rngs)
        self.fc_out = nnx.Linear(d_model, in_features, rngs=rngs)
        self.pos_encoding = PositionalEncoding(d_model, max_len=max_len)

    def __call__(self, tokens: jax.Array, previous: jax.Array):
        """
        :param tokens: array with last dimension ``in_features``; used as keys/values
        :param previous: array with last dimension ``in_features``; used as the query
        :return: ``fc_out`` projection of the decoded sequence plus ``tokens``
        """
        x = self.in_layer(tokens)
        # Only the token stream receives positions; `previous` is projected as-is.
        x = self.pos_encoding(x)
        previous = self.proj_previous(previous)
        x = self.cross_attention(previous, x, x)
        for layer in self.layers:
            x = self.norm_2(self.norm(layer(x, x, x)) + x)
        outs = self.fc_out(x) + tokens
        return outs


def generate_random_permutation(x, key):
    """
    Draw a random ordering of the sequence positions of ``x``.

    :param x: array whose second-to-last axis is the sequence axis, e.g. (sequence, emb);
        only its shape is used
    :param key: PRNG key
    :return: 1-D array holding a random permutation of ``0 .. sequence - 1``
    """
    seq_len = x.shape[-2]
    return jax.random.permutation(key, jnp.arange(seq_len))


def permute_x(x, permutation):
    """
    Reorder ``x`` along its sequence axis.

    :param x: (*batch, sequence, emb); the batch axes are optional
    :param permutation: (*batch, sequence) integer indices, e.g. the output of
        generate_random_permutation; the batch axes are optional
    :return: ``x`` gathered along axis -2 according to ``permutation``
    """
    # A trailing singleton axis lets the indices broadcast over the emb axis.
    gather_idx = jnp.expand_dims(permutation, axis=-1)
    return jnp.take_along_axis(x, gather_idx, axis=-2)


def inverse_permute(y, permutation):
    """
    Undo a reordering along the sequence axis.

    :param y: (*batch, sequence, emb) array that was reordered with ``permutation``
    :param permutation: (*batch, sequence) integer indices that produced ``y``
    :return: array in the original order, so that
        ``x == inverse_permute(permute_x(x, perm), perm)``
    """
    # argsort of a permutation is its inverse.
    restore_idx = jnp.argsort(permutation, axis=-1)[..., None]
    return jnp.take_along_axis(y, restore_idx, axis=-2)


class MarginalIQNHead(nnx.Module):
    """Quantile head with one output MLP per reward dimension.

    A shared residual MLP refines the feature, a shared network embeds the
    quantile fractions, and each reward dimension has its own MLP that maps
    the product of the two to a scalar.
    """

    def __init__(self, features_dim: int,
                 n_rewards: int,
                 *, rngs):
        """
        :param features_dim: size of the incoming feature vector and of the tau embedding
        :param n_rewards: number of reward dimensions (one output MLP each)
        :param rngs: random number streams handed to the nnx layers
        """
        self.features_dim = features_dim
        self.n_rewards = n_rewards
        self.layers = nnx.Sequential(
            *create_mlp(features_dim, features_dim, net_arch=(128,),
                        activation_fn=nnx.relu,
                        rngs=rngs)
        )

        self.taus_embedding = nnx.Sequential(
            FourierFeatureNetwork(1, 256, stddev=1e-3, rngs=rngs),
            PReLU(),
            nnx.Linear(256, self.features_dim, rngs=rngs),
            PReLU(),
        )
        # nnx.List would be the natural container; nnx.Dict keyed by the stringified
        # index is used instead for version compatibility.
        self.keys = [str(i) for i in range(self.n_rewards)]
        self.mlps = nnx.Dict({
            name: nnx.Sequential(
                *create_mlp(features_dim, 1,
                            net_arch=(256,),
                            activation_fn=nnx.relu,
                            rngs=rngs)
            )
            for name in self.keys
        })

    def __call__(self, feature, taus):
        """
        :param feature: (batch, emb)
        :param taus: (batch, n_rewards, n_taus) quantile fractions
        :return: per-reward outputs stacked on axis -2, i.e. (batch, n_rewards, n_taus)
        """
        feature = self.layers(feature) + feature
        # (batch, n_rewards, n_taus) -> (batch, n_rewards, n_taus, emb)
        tau_emb = self.taus_embedding(taus[..., None])
        # Singleton axis so the feature broadcasts over n_taus.
        shared = feature[..., None, :]

        per_reward = []
        for idx, name in enumerate(self.keys):
            modulated = shared * tau_emb[..., idx, :, :]
            per_reward.append(self.mlps[name](modulated).squeeze(axis=-1))
        return jnp.stack(per_reward, axis=-2)


class KnotheRosenblattesIQNHead(nnx.Module):
    """Quantile head conditioned on an embedding of the previous rewards.

    The feature is refined by a residual MLP, multiplied element-wise by the
    tau embedding and by ``rewards``, then mapped to a scalar by a shared MLP.
    """

    def __init__(self,
                 features_dim,
                 reward_dim,
                 *,
                 rngs
                 ):
        """
        :param features_dim: size of the feature vector and of the tau embedding
        :param reward_dim: number of reward dimensions (stored, not used by the layers)
        :param rngs: random number streams handed to the nnx layers
        """
        self.features_dim = features_dim
        self.reward_dim = reward_dim
        self.layers = nnx.Sequential(
            *create_mlp(features_dim, features_dim, net_arch=(128,),
                        activation_fn=nnx.relu,
                        rngs=rngs)
        )
        self.taus_embedding = nnx.Sequential(
            FourierFeatureNetwork(1, 256, stddev=1e-3, rngs=rngs),
            PReLU(),
            nnx.Linear(256, self.features_dim, rngs=rngs),
            PReLU()
        )

        self.mlp = nnx.Sequential(*create_mlp(features_dim, 1,
                                              net_arch=(256,), activation_fn=nnx.relu,
                                              rngs=rngs))

    def __call__(self, feature, taus, rewards):
        """
        :param feature: array with last dimension ``features_dim``; presumably already
            multiplied by the task embedding by the caller
        :param taus: quantile fractions; a trailing axis is appended before embedding
        :param rewards: conditioning on earlier rewards; it must broadcast against
            ``feature * tau_embedding`` (NOTE(review): confirm the exact shape against the caller)
        :return: MLP output with its trailing singleton axis removed
        """
        residual = self.layers(feature) + feature
        tau_emb = self.taus_embedding(taus[..., None])
        conditioned = residual * tau_emb * rewards
        return self.mlp(conditioned).squeeze(axis=-1)


class RewardProj(nnx.Module):
    """Embeds a scalar reward as the sum of a Fourier-feature path and a linear path."""

    def __init__(self, out_dim: int,
                 *, rngs):
        """
        :param out_dim: size of the produced embedding
        :param rngs: random number streams handed to the nnx layers
        """
        self.reward_proj = nnx.Sequential(
            nnx.LayerNorm(1, rngs=rngs),
            FourierFeatureNetwork(1, 64, rngs=rngs),

            nnx.relu,
            nnx.Linear(64, out_dim, use_bias=False, rngs=rngs),
        )
        self.proj = nnx.Linear(1, out_dim, rngs=rngs)

    def __call__(self, r):
        """
        :param r: array whose last dimension has size 1
        :return: array whose last dimension has size ``out_dim``
        """
        nonlinear = self.reward_proj(r)
        linear = self.proj(r)
        return nonlinear + linear


class KRIQN(nnx.Module):
    """Autoregressive multi-reward quantile critic.

    Observation, action and preference weight are merged into one feature,
    which is multiplied by a learned per-task embedding. Each reward
    dimension's quantiles are then predicted by an ensemble of
    ``KnotheRosenblattesIQNHead`` modules, conditioned on an embedding of the
    earlier rewards produced by ``reward_extractor``.
    """

    def __init__(self,
                 obs_dim: int,
                 action_dim: int,
                 reward_dim: int,
                 n_critics: int = 2,
                 *,
                 ff_action: bool = True,
                 rngs
                 ):
        """
        :param obs_dim: size of the observation vector
        :param action_dim: size of the action vector
        :param reward_dim: number of reward dimensions (tasks)
        :param n_critics: number of ensemble heads
        :param ff_action: if True the action goes through a FourierFeatureNetwork first,
            otherwise through a plain Linear layer
        :param rngs: random number streams handed to the nnx layers
        """
        self.reward_dim = reward_dim
        self.obs_extractor = nnx.Sequential(
            FourierFeatureNetwork(obs_dim, 256, stddev=1e-2, rngs=rngs),
            nnx.Linear(256, 64,
                       rngs=rngs),
            nnx.LayerNorm(64, rngs=rngs),
            nnx.relu,
        )
        self.ff_action = ff_action
        if self.ff_action:
            self.action_extractor = nnx.Sequential(
                FourierFeatureNetwork(action_dim, 64, stddev=1e-3, rngs=rngs),
                nnx.Linear(64, 64,
                           rngs=rngs),
                nnx.LayerNorm(64, rngs=rngs),
                nnx.relu,
            )
        else:
            self.action_extractor = nnx.Sequential(
                nnx.Linear(action_dim, 64, rngs=rngs),
                nnx.LayerNorm(64, rngs=rngs),
                nnx.relu,
            )
        self.weight_extractor = nnx.Sequential(
            FourierFeatureNetwork(self.reward_dim, 64, rngs=rngs),
            nnx.Linear(64, 32, rngs=rngs)
        )
        # 128 + 32 = obs (64) + action (64) + weight (32) features concatenated.
        self.merge = nnx.Sequential(
            *create_mlp(128 + 32, 64, net_arch=(256,), rngs=rngs),
        )

        self.ln = nnx.LayerNorm(64, rngs=rngs)
        # One embedding row per reward dimension plus one extra row; the row at
        # index `reward_dim` is used as the "null task" that precedes the first reward.
        self.tasks = nnx.Embed(reward_dim + 1, 64, rngs=rngs, embedding_init=nnx.initializers.orthogonal())
        self.reward_proj = RewardProj(64, rngs=rngs)

        self.reward_extractor = TransformerDecoder(
            in_features=64,
            num_layers=1,
            d_model=128,
            n_heads=8,
            d_ff=128,
            rngs=rngs,
        )
        self.n_critics = n_critics
        # Ensemble of heads whose state is stacked along axis 0.
        self.iqn_heads = nnx.Vmap(KnotheRosenblattesIQNHead,
                                  state_axes={ ...: 0 }, out_axes=-1,
                                  module_init_args=(64, reward_dim,), module_init_kwargs={ "rngs": rngs },
                                  axis_size=n_critics)
        self.rngs = rngs

    def __call__(self, observations, actions, taus, weight):
        """
        Predict quantiles one reward dimension at a time, feeding earlier
        predictions back in as conditioning.

        :param observations: (b, n_dim_obs)
        :param actions: (b, n_dim_action)
        :param taus: (b, reward_dim, n_taus) quantile fractions
        :param weight: (b, reward_dim) task weights
        :return: predictions with the dummy first entry removed; presumably
            (b, reward_dim, n_taus, n_critics) — TODO confirm
        """
        predictions = jnp.ones_like(taus[:, :1, :])  # dummy "previous reward" for the first step
        predictions = jnp.repeat(predictions[..., None], axis=-1, repeats=self.n_critics)

        # Broadcast the full embedding table (including the null row) over the batch.
        tasks = self.tasks.embedding.value
        tasks = jnp.repeat(tasks[None], axis=0, repeats=observations.shape[0])

        feature = self.feature_extraction(observations, actions, weight, tasks)
        # The sequence of previous tasks starts with the null task only.
        previous_task = jnp.repeat(self.tasks.embedding.value[self.reward_dim][None, None], axis=0,
                                   repeats=observations.shape[0])

        for i in range(1, self.reward_dim + 1):
            # Re-run the heads on the first i reward dimensions and keep only the
            # newest entry ([:, -1:]), which is appended to the running predictions.
            predictions = jnp.concatenate([predictions,
                                           self.heading(feature[:, :i, :], taus[:, :i, :], predictions,
                                                        previous_task)[:, -1:, :]
                                           ], axis=1)
            previous_task = jnp.concatenate([previous_task, tasks[:, [(i - 1)]]], axis=1)

        # Drop the dummy entry inserted before the loop.
        predictions = predictions[:, 1:]

        return predictions

    def loss_fn(self, observations, actions, td_target, weight, key):
        """
        Pinball loss under a random per-sample ordering of the reward dimensions.

        :param observations: (b, n_dim_obs)
        :param actions: (b, n_dim_action)
        :param td_target: (b, reward_dim, n_td_target)
        :param weight: (b, reward_dim) task weights; permuted together with the tasks
        :param key: PRNGKey
        :return: un-reduced loss from `_quantile_regression_loss`, computed per critic
            and stacked on the last axis
        """
        key = jax.random.split(key, 2)
        # One tau per target sample.
        taus = jax.random.uniform(key[0], shape=td_target.shape)

        # (b, n_perm): an independent permutation of the reward axis per batch element
        permutation = jax.vmap(generate_random_permutation, in_axes=(0, 0),
                               out_axes=0)(td_target, jax.random.split(key[1], td_target.shape[0]))
        # permute td target
        tasks = self.tasks.embedding.value

        tasks = jnp.repeat(tasks[None], axis=0, repeats=observations.shape[0])
        # The permutation only ranges over [0, ..., w-1]; therefore the null task, whose index is w, is never included.
        tasks = jnp.take_along_axis(tasks, axis=1, indices=permutation[..., None])
        td_target = jnp.take_along_axis(td_target, axis=1, indices=permutation[..., None])
        weight = jnp.take_along_axis(weight, axis=1, indices=permutation)

        previous_rewards = jnp.ones_like(taus[:, :1, :])  # dummy "previous reward" for the first step
        previous_rewards = jnp.repeat(previous_rewards[..., None], axis=-1, repeats=self.n_critics)
        # (b, w + 1, n) -> (b, w, n): shift the targets right by one for auto-regressive (teacher-forced) learning
        label = td_target
        td_target = jnp.repeat(td_target[..., None], axis=-1, repeats=self.n_critics)
        previous_rewards = jnp.concatenate([previous_rewards, td_target], axis=1)[:, :-1, ...]

        mult_feature = self.feature_extraction(observations, actions, weight, tasks)
        null_task = jnp.repeat(self.tasks.embedding.value[self.reward_dim][None, None],
                               axis=0, repeats=observations.shape[0])
        # Task sequence shifted right by one, starting with the null task.
        previous_task = jnp.concatenate([null_task, tasks[:, :-1]], axis=1)
        prediction = self.heading(mult_feature, taus, previous_rewards, previous_task)
        # Map over the critic axis (last axis of `prediction`); label and taus are shared.
        return jax.vmap(self._quantile_regression_loss, in_axes=(-1, None, None), out_axes=-1)(prediction, label, taus)

    def heading(self, feature, taus, target, previous_task):
        """
        Embed the previous rewards with the transformer decoder and evaluate
        every ensemble head on each (reward dimension, tau) pair.

        :param feature: per-reward features, mapped over axis 1
        :param taus: quantile fractions, mapped over axis 1 and then over the last axis
        :param target: previous rewards, (b, w, n, c)
        :param previous_task: embeddings of the previous tasks; used as the decoder query input
        :return: predictions; presumably (b, w, n, c) — TODO confirm
        """
        # target: (b, w, n, c) or (a, b, w, n, c)
        proj_r = self.reward_proj(target[..., None])

        # (b, w, n, c, e) -> (b, w, n, e, c)
        proj_r = proj_r.swapaxes(-1, -2)
        # inner vmap runs over the last axis (critics, c); outer vmap runs over axis -3 (n)
        reward_emb = jax.vmap(jax.vmap(self.reward_extractor, in_axes=(-1, None), out_axes=-1
                                       ), in_axes=(-3, None), out_axes=-3)(proj_r, previous_task)

        # Call the ensemble: Param state is mapped over axis 0, feature and taus are
        # shared by all critics, the reward embedding is mapped over its last axis.
        fn = partial(nnx.vmap(KnotheRosenblattesIQNHead.__call__,
                              in_axes=(nnx.StateAxes({ nnx.Param: 0 }), None, None, -1), out_axes=-1),
                     self.iqn_heads._submodule)

        # Outer vmap: axis 1 of every argument. Inner vmap: last axis of taus and
        # axis 1 of the remaining reward embedding; the feature is not mapped.
        prediction = jax.vmap(jax.vmap(fn, in_axes=(None, -1, 1), out_axes=-1),
                              in_axes=(1, 1, 1), out_axes=1,
                              )(feature, taus, reward_emb)
        # Swap the last two axes of the stacked output.
        prediction = prediction.swapaxes(-1, -2)
        return prediction

    @staticmethod
    @jax.jit
    def _quantile_regression_loss(predictions, td_target, taus):
        """
        Element-wise pinball (quantile regression) loss; no reduction is applied.

        :param predictions: predicted quantile values
        :param td_target: target samples, same shape as (or broadcastable to) predictions
        :param taus: quantile fractions paired element-wise with predictions
        :return: array of per-element losses

        """
        delta = (td_target - predictions)
        abs_delta = jnp.abs(delta)

        # Over-estimates (delta < 0) are weighted by (1 - tau), under-estimates by tau.
        return jnp.where(delta < 0, (1 - taus) * abs_delta, taus * abs_delta)

    def feature_extraction(self, obs, actions, weight, task_embedding):
        """
        :param obs: observation (b, n_dim_obs)
        :param actions: actions (b, n_dim_action)
        :param weight: task weight (b, n_rewards), in [0, 1]. must be shuffled when task embedding is shuffled
        :param task_embedding: task embedding (b, w, emb)
        :return: feature vector (b, w, emb): layer-normalised product of the merged
            feature and each task embedding
        """
        obs = self.obs_extractor(obs)
        act = self.action_extractor(actions)
        w_bar = self.weight_extractor(weight)
        feature = self.merge(jnp.concatenate([obs, act, w_bar], axis=-1))
        # feature: (batch, emb) -> (batch, 1, emb), broadcast against
        # task_embedding: (batch, w, emb)
        feature = feature[..., None, :] * task_embedding
        return self.ln(feature)

    def marginals_of(self, observations, actions, taus, weight):
        """
        Evaluate every reward dimension independently, each conditioned only on
        the dummy previous reward and the null task.

        :param observations: (b, n_dim_obs)
        :param actions: (b, n_dim_action)
        :param taus: (b, reward_dim, n_taus) quantile fractions
        :param weight: (b, reward_dim) task weights
        :return: predictions; presumably (b, reward_dim, n_taus, n_critics) — TODO confirm
        """
        predictions = jnp.ones_like(taus[:, :1, :])  # dummy "previous reward"
        predictions = jnp.repeat(predictions[..., None], axis=-1, repeats=self.n_critics)

        tasks = self.tasks.embedding.value
        tasks = jnp.repeat(tasks[None], axis=0, repeats=observations.shape[0])
        # The null-task row is excluded: (b, w, emb)
        feature = self.feature_extraction(observations, actions, weight, tasks[:, :-1])
        # (b, w, emb) -> (b, 1, w, emb)
        feature = jnp.expand_dims(feature, axis=1)
        # null task only: (b, 1, emb)
        previous_task = jnp.repeat(self.tasks.embedding.value[self.reward_dim][None, None], axis=0,
                                   repeats=observations.shape[0])
        # (b, w, n) -> (b, 1, w, n)
        taus = jnp.expand_dims(taus, axis=1)

        # Map `heading` over the reward axis (axis 2 of feature and taus); the dummy
        # predictions and the null task are shared by every reward dimension.
        fn = jax.vmap(self.heading, in_axes=(2, 2, None, None), out_axes=1)
        predictions = fn(feature, taus, predictions, previous_task)
        # Remove the length-1 sequence axis that each `heading` call produced.
        predictions = predictions[:, :, 0]
        return predictions


class KRIQNAblation(KRIQN):
    """Ablation of KRIQN: the loss uses a fixed reward order instead of a random one."""

    def loss_fn(self, observations, actions, td_target, weight, key):
        """
        Pinball loss with the reward dimensions kept in their natural order.

        :param observations: (b, n_dim_obs)
        :param actions: (b, n_dim_action)
        :param td_target: (b, reward_dim, n_td_target)
        :param weight: (b, reward_dim) task weights
        :param key: PRNGKey
        :return: un-reduced loss computed per critic and stacked on the last axis
        """
        batch_size = observations.shape[0]
        key = jax.random.split(key, 2)
        taus = jax.random.uniform(key[0], shape=td_target.shape)

        # NO SHUFFLE: order forcing, the "permutation" is the identity for every sample.
        permutation = jnp.repeat(jnp.arange(self.reward_dim)[None], axis=0, repeats=batch_size)
        gather_idx = permutation[..., None]

        embedding_table = self.tasks.embedding.value
        tasks = jnp.repeat(embedding_table[None], axis=0, repeats=batch_size)
        # The permutation only ranges over [0, ..., w-1], so the null task (index w) is never included.
        tasks = jnp.take_along_axis(tasks, axis=1, indices=gather_idx)
        td_target = jnp.take_along_axis(td_target, axis=1, indices=gather_idx)
        weight = jnp.take_along_axis(weight, axis=1, indices=permutation)

        label = td_target
        td_per_critic = jnp.repeat(td_target[..., None], axis=-1, repeats=self.n_critics)
        # Dummy "previous reward" in front of the targets, then drop the last target:
        # (b, w + 1, n) -> (b, w, n), for auto-regressive learning.
        dummy_first = jnp.repeat(jnp.ones_like(taus[:, :1, :])[..., None], axis=-1, repeats=self.n_critics)
        previous_rewards = jnp.concatenate([dummy_first, td_per_critic], axis=1)[:, :-1, ...]

        mult_feature = self.feature_extraction(observations, actions, weight, tasks)
        null_task = jnp.repeat(embedding_table[self.reward_dim][None, None],
                               axis=0, repeats=batch_size)
        previous_task = jnp.concatenate([null_task, tasks[:, :-1]], axis=1)
        prediction = self.heading(mult_feature, taus, previous_rewards, previous_task)

        # Map the loss over the critic axis; label and taus are shared by all critics.
        per_critic_loss = jax.vmap(self._quantile_regression_loss,
                                   in_axes=(-1, None, None), out_axes=-1)
        return per_critic_loss(prediction, label, taus)


class MarginalIQN(nnx.Module):
    """Quantile critic that models every reward dimension independently.

    Observation, action and preference weight are merged into one feature,
    which is passed to an ensemble of ``MarginalIQNHead`` modules.
    """

    def __init__(self,
                 obs_dim: int,
                 action_dim: int,
                 reward_dim: int,
                 n_critics: int = 2,
                 *,
                 ff_action: bool = True,
                 rngs
                 ):
        """
        :param obs_dim: size of the observation vector
        :param action_dim: size of the action vector
        :param reward_dim: number of reward dimensions
        :param n_critics: number of ensemble heads
        :param ff_action: if True the action goes through a FourierFeatureNetwork first,
            otherwise through a plain Linear layer
        :param rngs: random number streams handed to the nnx layers
        """
        self.reward_dim = reward_dim
        self.n_critics = n_critics
        self.ff_action = ff_action

        self.obs_extractor = nnx.Sequential(
            FourierFeatureNetwork(obs_dim, 256, stddev=1e-2, rngs=rngs),
            nnx.Linear(256, 64, rngs=rngs),
            nnx.LayerNorm(64, rngs=rngs),
            nnx.relu,
        )

        if self.ff_action:
            self.action_extractor = nnx.Sequential(
                FourierFeatureNetwork(action_dim, 64, stddev=1e-3, rngs=rngs),
                nnx.Linear(64, 64, rngs=rngs),
                nnx.LayerNorm(64, rngs=rngs),
                nnx.relu,
            )
        else:
            self.action_extractor = nnx.Sequential(
                nnx.Linear(action_dim, 64, rngs=rngs),
                nnx.LayerNorm(64, rngs=rngs),
                nnx.relu,
            )

        self.weight_extractor = nnx.Sequential(
            FourierFeatureNetwork(self.reward_dim, 64, rngs=rngs),
            nnx.Linear(64, 32, rngs=rngs),
        )

        # 128 + 32 = obs (64) + action (64) + weight (32) features concatenated.
        self.merge = nnx.Sequential(
            *create_mlp(128 + 32, 64, net_arch=(256,), rngs=rngs),
        )

        self.ln = nnx.LayerNorm(64, rngs=rngs)

        # Ensemble of heads whose state is stacked along axis 0.
        self.iqn_heads = nnx.Vmap(
            MarginalIQNHead,
            state_axes={...: 0},
            out_axes=-1,
            module_init_args=(64, ),
            module_init_kwargs={"rngs": rngs, "n_rewards": self.reward_dim},
            axis_size=n_critics,
        )
        self.rngs = rngs

    def __call__(self, observations, actions, taus, weight):
        """
        :param observations: (b, n_dim_obs)
        :param actions: (b, n_dim_action)
        :param taus: (b, reward_dim, n_taus) quantile fractions
        :param weight: (b, reward_dim) task weights
        :return: (b, reward_dim, n_taus, n_critics) quantile predictions
        """
        obs = self.obs_extractor(observations)            # (b, 64)
        act = self.action_extractor(actions)              # (b, 64)
        w_bar = self.weight_extractor(weight)             # (b, 32)

        feature = self.merge(jnp.concatenate([obs, act, w_bar], axis=-1))  # (b, 64)
        feature = self.ln(feature)
        # Param state is mapped over axis 0; feature and taus are shared by all
        # critics, whose outputs are stacked on the last axis.
        fn = partial(nnx.vmap(MarginalIQNHead.__call__,
                              in_axes=(nnx.StateAxes({ nnx.Param: 0 }), None, None), out_axes=-1),
                     self.iqn_heads._submodule)
        return fn(feature, taus)

    def loss_fn(self, observations, actions, td_target, weight, key):
        """
        Mean pairwise pinball loss for every (critic, reward dimension) pair.

        :param observations: (b, n_dim_obs)
        :param actions: (b, n_dim_action)
        :param td_target: (b, reward_dim, n_td_target)
        :param weight: (b, reward_dim) task weights
        :param key: PRNGKey
        :return: one mean loss per critic and reward dimension; with the out_axes used
            below the layout is expected to be (n_critics, reward_dim)
        """
        key = jax.random.split(key, 2)
        taus = jax.random.uniform(key[0], shape=td_target.shape)
        # (batch, n_rewards, n_taus, n_critics)
        prediction = self(observations, actions, taus, weight)

        def pinball(pred, target, tau):
            # `_quantile_regression_loss` is declared as (target, predict, taus).
            # Passing the prediction first flips the sign of the residual and pairs
            # each tau with a target sample instead of with the quantile it produced.
            return self._quantile_regression_loss(target, pred, tau)

        # inner vmap: over critics (last axis of the prediction only)
        # outer vmap: over reward dimensions (axis 1 of every argument)
        loss_fn = jax.vmap(jax.vmap(pinball, in_axes=(-1, None, None), out_axes=-1),
                           in_axes=(1, 1, 1), out_axes=1)
        return loss_fn(prediction, td_target, taus)

    @staticmethod
    @jax.jit
    def _quantile_regression_loss(target, predict, taus):
        """
        Mean pinball loss over all (prediction, target) pairs.

        :param target: (..., n_target) target samples
        :param predict: (..., n_predict) predicted quantile values
        :param taus: (..., n_predict) quantile fractions belonging to ``predict``
        :return: scalar mean of the pairwise losses
        """
        # (..., n_predict, n_target): every target compared with every prediction.
        pairwise_delta = target[..., None, :] - predict[..., None]
        abs_pairwise_delta = jnp.abs(pairwise_delta)
        taus = taus[..., None]
        # Over-estimates (delta < 0) are weighted by (1 - tau), under-estimates by tau.
        loss = jnp.where(pairwise_delta < 0, (1 - taus) * abs_pairwise_delta, taus * abs_pairwise_delta)
        return loss.mean()



class KRIQNAblationPositionalEncoding(KRIQN):
    """KRIQN variant whose reward decoder adds sinusoidal positional encodings."""

    def __init__(self,
                 obs_dim: int,
                 action_dim: int,
                 reward_dim: int,
                 n_critics: int = 2,
                 *,
                 ff_action: bool = True,
                 rngs
                 ):
        """
        :param obs_dim: size of the observation vector
        :param action_dim: size of the action vector
        :param reward_dim: number of reward dimensions
        :param n_critics: number of ensemble heads
        :param ff_action: forwarded to KRIQN
        :param rngs: random number streams handed to the nnx layers
        """
        super().__init__(obs_dim, action_dim, reward_dim, n_critics,
                         ff_action=ff_action, rngs=rngs)
        # Discard the decoder built by the parent before installing the
        # positional-encoding variant with the same hyper-parameters.
        del self.reward_extractor

        decoder_config = dict(in_features=64, num_layers=1, d_model=128, n_heads=8, d_ff=128)
        self.reward_extractor = TransformerDecoderPositionalEncoding(**decoder_config, rngs=rngs)


if __name__ == '__main__':
    import os
    # Smoke test: build a MarginalIQN and print the shape of its loss on random data.
    test = MarginalIQN(6, 2, 3, 5, rngs=nnx.Rngs(42))
    # Each random draw gets its own subkey; reusing one key for every draw
    # would make the placeholder tensors correlated.
    obs_key, action_key, weight_key, target_key, loss_key = jax.random.split(jax.random.PRNGKey(42), 5)
    obs_ph = jax.random.normal(obs_key, shape=(256, 6))
    action_ph = jax.random.normal(action_key, shape=(256, 2))
    w = jax.random.uniform(weight_key, shape=(256, 3))
    w = w / w.sum(axis=-1, keepdims=True)  # normalise the weights to sum to 1
    td_target = jax.random.normal(target_key, shape=(256, 3, 32))
    # loss_fn samples its own taus from the key it is given.
    out = test.loss_fn(obs_ph, action_ph, td_target, w, loss_key)
    print(out.shape)

    # Disabled experiment (swiss-roll fitting and plotting), kept for reference.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    iqn = ARIQN(4, 2, 2, rngs=nnx.Rngs(42))
    keys = jax.random.split(jax.random.PRNGKey(42), 1000)
    obs_ph = jax.random.normal(keys[0], shape=(256, 4))
    action_ph = jax.random.normal(keys[1], shape=(256, 2))
    taus_ph = jax.random.uniform(keys[2], shape=(256, 2, 32))
    td_target = jnp.ones(shape=(256, 3, 32), )
    # out_call = iqn(obs_ph, action_ph, taus_ph)
    # out = iqn.loss_fn(obs_ph, action_ph, td_target, jax.random.PRNGKey(32))
    opt = nnx.Optimizer(iqn, optax.chain(
        optax.adamw(3e-4, 0.5, 0.9),
    )
                        )
    graph, state = nnx.split((iqn, opt))
    from sklearn.datasets import make_swiss_roll
    import numpy as np
    import jax_dataloader as jdl
    import matplotlib.pyplot as plt
    from tqdm import trange

    N_PARTICLE = 64
    data, _ = make_swiss_roll(n_samples=25600, random_state=42)
    data = np.stack([data[:, 0], data[:, -1]], axis=-1)
    dataset = jdl.ArrayDataset(data)
    loader = jdl.DataLoader(dataset, backend='jax', batch_size=256 * N_PARTICLE,
                            shuffle=True,
                            drop_last=True)

    dummy_obs = jnp.zeros(shape=(256, 4))
    dummy_action = jnp.zeros(shape=(256, 2))
    dummy_taus = jax.random.uniform(jax.random.PRNGKey(42), shape=(256, 2, 4))
    import cloudpickle as pickle

    with open("iqn_plotting_model.pt", 'rb') as f:
        state = pickle.load(f)
 
    @jax.jit
    def update_fn(graph, state, x, key):
        x = x.reshape(-1, N_PARTICLE, 2).swapaxes(-2, -1)  # (256, 2, N_particle)

        model, opt = nnx.merge(graph, state)
        dummy_obs = jnp.zeros(shape=(x.shape[0], 4))
        dummy_action = jnp.zeros(shape=(x.shape[0], 2))

        def loss_fn(model: ARIQN):
            # observations, actions, td_target, key
            loss = model.loss_fn(dummy_obs, dummy_action, x, jnp.zeros(shape=dummy_action.shape), key)
            return loss.sum(axis=(-1, -2, -3)).mean()

        loss, grad = nnx.value_and_grad(loss_fn)(model)
        opt.update(grad)
        _, new_state = nnx.split((model, opt))
        new_key = jax.random.split(key, 2)[-1]
        return loss, new_state, new_key


    key = jax.random.PRNGKey(32)
    for _ in trange(2000 * N_PARTICLE):
        for x, in loader:
            loss, state, key = update_fn(graph, state, x, key)
 

    iqn, opt = nnx.merge(graph, state)
    taus = dummy_taus  # dummy_taus.at[:, 0].set(dummy_taus[:, 0] * 0.5)
    out = iqn(dummy_obs, dummy_action, taus, dummy_action)[..., 0]
    # (256, 2, 4) -> (256, 4, 2) -> (1024, 2)
    t = dummy_taus.swapaxes(1, -1).reshape(-1, 2)

    colors = np.stack([t[:, 0], np.zeros_like(t[:, 0]), t[:, 1]], axis=1)
    out = out.swapaxes(1, -1).reshape(-1, 2)
    plt.rcParams['font.family'] = 'Times New Roman'

    plt.scatter(out[:, 0], out[:, 1], c=colors)
    plt.tight_layout()
    plt.show()  # 'for_fig.svg')
    """

