import jax
from flax import nnx
from basics.layers import FourierFeatureNetwork
from basics.layers import create_mlp, IQNHead
from typing import Sequence
import jax.numpy as jnp

class IQNVmap(nnx.Vmap):
    """A bank of ``num_rewards`` independent ``IQNHead`` modules run via ``nnx.Vmap``.

    Every head receives the same inputs (nothing is mapped on the way in),
    holds its own parameters, and the per-head outputs are stacked along the
    second-to-last axis of the result.

    Args:
        features_dim: First positional init argument forwarded to ``IQNHead``
            (presumably its input feature size — confirm against IQNHead).
        net_arch: Second positional init argument forwarded to ``IQNHead``.
        num_rewards: Number of heads, i.e. the size of the vmapped axis.
        rngs: Forwarded to ``IQNHead`` as the ``rngs`` keyword.
    """

    def __init__(self, features_dim,
                 net_arch: Sequence[int] = (64, 64),
                 num_rewards: int = 2, *, rngs):
        head_init_args = (features_dim, net_arch)
        head_init_kwargs = {"rngs": rngs}
        super().__init__(
            IQNHead,
            axis_size=num_rewards,
            # Inputs are broadcast: every head sees identical arguments.
            in_axes=None,
            # nnx.Param state is split along axis 0 -> one parameter set per head.
            state_axes={nnx.Param: 0},
            # Per-head results are stacked at axis -2 of the output.
            out_axes=-2,
            module_init_args=head_init_args,
            module_init_kwargs=head_init_kwargs,
        )


class ContinuousCritic(nnx.Module):
    """Preference-conditioned IQN critic for continuous actions.

    The observation, the action and the reward-preference ``weight`` are each
    encoded to ``hidden_dim`` features, concatenated, mixed by an MLP, and the
    result is passed (together with ``taus``) to one ``IQNHead`` per reward
    component through ``IQNVmap``.

    Args:
        features_dim: Size of the last axis of the observation input.
        actions_dim: Size of the last axis of the action input.
        num_rewards: Number of reward components; also the required size of
            the last axis of ``weight``.
        rngs: ``nnx.Rngs`` used to initialise every submodule.
        hidden_dim: Width of the embeddings and hidden layers. Defaults to 64,
            the value that was previously hard-coded.
    """

    def __init__(self,
                 features_dim: int,
                 actions_dim: int,
                 num_rewards: int,
                 *,
                 rngs,
                 hidden_dim: int = 64,
                 ):
        self.num_rewards = num_rewards
        # One learned row per reward component; `weight_encoding` mixes them.
        # NOTE: submodules are created in the same order as before so that the
        # RNG stream (and therefore the initial parameters) is unchanged.
        self.task_encoding = nnx.Embed(
            num_rewards, hidden_dim, rngs=rngs
        )
        # Presumably each branch maps its input to `hidden_dim` features
        # (FourierFeatureNetwork + MLP); the merge input size below relies on it.
        self.feature = nnx.Sequential(
            FourierFeatureNetwork(features_dim, hidden_dim, rngs=rngs),
            *create_mlp(hidden_dim, hidden_dim, net_arch=(hidden_dim,),
                        rngs=rngs))
        self.action_emb = nnx.Sequential(
            FourierFeatureNetwork(actions_dim, hidden_dim, rngs=rngs),
            *create_mlp(hidden_dim, hidden_dim, net_arch=(hidden_dim,),
                        rngs=rngs))
        # Input is [feature_emb | act_emb | weight_emb] -> 3 * hidden_dim.
        self.merge = nnx.Sequential(
            *create_mlp(3 * hidden_dim, hidden_dim,
                        net_arch=(hidden_dim, hidden_dim), rngs=rngs))

        self.iqn_head = IQNVmap(hidden_dim, num_rewards=num_rewards, rngs=rngs)

    def weight_encoding(self, weight):
        """Encode a preference vector as a weighted sum of task embeddings.

        Args:
            weight: Array whose last axis has size ``num_rewards``; any leading
                axes are carried through as batch axes.

        Returns:
            Array of shape ``weight.shape[:-1] + (hidden_dim,)``.
        """
        # (R, hidden_dim): the embedding row of every reward component.
        task_emb = self.task_encoding(jnp.arange(self.num_rewards))
        # Contract the reward axis: out[..., j] = sum_i weight[..., i] * task_emb[i, j].
        return jnp.einsum('...i,ij->...j', weight, task_emb)

    def encode_all(self, feature, actions, weight):
        """Encode observation, action and preference into one feature vector.

        The three encodings are concatenated along the last axis, so their
        leading (batch) shapes must match exactly — ``jnp.concatenate`` does
        not broadcast.

        Returns:
            Output of the ``merge`` MLP applied to the concatenation.
        """
        weight_emb = self.weight_encoding(weight)
        feature_emb = self.feature(feature)
        act_emb = self.action_emb(actions)
        joint = jnp.concatenate([feature_emb, act_emb, weight_emb], axis=-1)
        return self.merge(joint)

    def __call__(self, feature, actions, taus, weight):
        """Evaluate the per-reward IQN heads.

        Args:
            feature: Observation; last axis has size ``features_dim``
                (e.g. ``(B, obs_dim)`` or a single unbatched observation).
            actions: Action; last axis has size ``actions_dim``.
            taus: Quantile fractions, passed unchanged to every IQN head
                (``IQNVmap`` uses ``in_axes=None``). The expected shape is
                defined by ``IQNHead`` — TODO confirm.
            weight: Preference vector; last axis has size ``num_rewards``.

        Returns:
            The IQN heads' output with the per-reward axis stacked at position
            -2 (presumably ``(B, num_rewards, num_quantiles)`` — confirm
            against ``IQNHead``).
        """
        return self.iqn_head(self.encode_all(feature, actions, weight), taus)


class VMapContinuousQNet(nnx.Vmap):
    """An ensemble of ``n_critics`` independent ``ContinuousCritic`` networks.

    Built with ``nnx.Vmap``:
    - every critic is called with the same inputs (``in_axes=None``);
    - each critic owns its own ``nnx.Param`` state (split along axis 0);
    - the per-critic outputs are stacked on the last axis of the result.

    Args:
        features_dim: Observation size forwarded to ``ContinuousCritic``.
        actions_dim: Action size forwarded to ``ContinuousCritic``.
        num_rewards: Number of reward components forwarded to ``ContinuousCritic``.
        n_critics: Ensemble size, i.e. the size of the vmapped axis.
        rngs: Forwarded to ``ContinuousCritic`` as the ``rngs`` keyword.
    """

    def __init__(self, features_dim, actions_dim,
                 num_rewards: int,
                 n_critics: int = 2, *, rngs):
        critic_init_args = (features_dim, actions_dim, num_rewards)
        critic_init_kwargs = {"rngs": rngs}
        super().__init__(
            ContinuousCritic,
            axis_size=n_critics,
            in_axes=None,               # shared inputs for every critic
            state_axes={nnx.Param: 0},  # separate parameters per critic
            out_axes=-1,                # critic index becomes the last output axis
            module_init_args=critic_init_args,
            module_init_kwargs=critic_init_kwargs,
        )


if __name__ == '__main__':
    # Smoke test: build the critic ensemble for the MO LunarLander continuous
    # environment and print the output shape of a single forward pass.
    import mo_gymnasium
    from risk_morl.utils.env_util import reward_dim

    env = mo_gymnasium.make('mo-lunar-lander-continuous-v3')
    try:
        # Query once and reuse for both the network and the weight vector.
        n_rewards = reward_dim(env)
        q_net = VMapContinuousQNet(
            env.observation_space.shape[0],
            env.action_space.shape[0],
            n_rewards,
            rngs=nnx.Rngs(42),
        )
        obs = env.reset()[0]
        action = env.action_space.sample()
        taus = jnp.linspace(0, 1, 32)
        # Softmax of equal logits -> uniform preference over objectives.
        weight = jax.nn.softmax(jnp.ones(shape=(n_rewards,)))

        print(q_net(obs, action, taus, weight).shape)
    finally:
        # Release the environment's resources even if the forward pass fails.
        env.close()
