import jax

from basics.layers import create_mlp, FourierFeatureNetwork, HyperNetwork
from flax import nnx
import distrax
import jax.numpy as jnp


class ActorBackbone(nnx.Module):
    """MLP trunk with a parallel linear shortcut.

    Both branches map an ``embedding_dim``-sized input to ``2 * actions_dim``
    outputs and their results are added together.
    """

    def __init__(self,
                 embedding_dim,
                 actions_dim,
                 *,
                 rngs
                 ):
        """Build the MLP branch and the linear shortcut branch.

        Args:
            embedding_dim: Size of the last axis of the input embedding.
            actions_dim: Number of action dimensions; each branch emits
                twice this many values.
            rngs: RNG container forwarded to ``create_mlp`` and ``nnx.Linear``.
        """
        output_dim = actions_dim * 2

        # The MLP is built before the shortcut so that ``rngs`` is consumed
        # in the same order as before (this affects parameter initialisation).
        hidden_layers = create_mlp(
            embedding_dim,
            output_dim,
            net_arch=(256, 256, 256,),
            rngs=rngs,
        )
        self.mlp = nnx.Sequential(*hidden_layers)
        self.res = nnx.Linear(embedding_dim, output_dim, rngs=rngs)

    def __call__(self, embedding):
        """Return ``mlp(embedding) + res(embedding)``."""
        deep_branch = self.mlp(embedding)
        shortcut_branch = self.res(embedding)
        return deep_branch + shortcut_branch


class Actor(nnx.Module):
    """Tanh-squashed Gaussian policy conditioned on features and a weight vector.

    The feature input and the weight input are embedded by two separate
    sub-networks, concatenated, and passed through an ``ActorBackbone`` whose
    output is split into a mean and a log standard deviation.
    """

    def __init__(self,
                 features_dim: int,
                 actions_dim: int,
                 num_rewards: int,
                 embedding_dim: int = 256,
                 *,
                 rngs
                 ):
        """Build the embedding networks and the backbone.

        Args:
            features_dim: Input size given to the feature embedding network.
            actions_dim: Number of action dimensions.
            num_rewards: Input size given to the weight embedding network.
            embedding_dim: Width of the feature embedding.
            rngs: RNG container used for parameter initialisation; it is also
                stored and called to obtain sampling keys.
        """
        # Width of the weight embedding that is concatenated to the features.
        weight_embedding_dim = 32

        self.num_rewards = num_rewards
        self.proj_feature = nnx.Sequential(
            FourierFeatureNetwork(features_dim, embedding_dim,
                                  stddev=1e-4,
                                  rngs=rngs),
            nnx.Linear(embedding_dim, embedding_dim, rngs=rngs),
            nnx.LayerNorm(embedding_dim, rngs=rngs),
            nnx.relu
        )
        self.weight_extractor = nnx.Sequential(
            FourierFeatureNetwork(self.num_rewards, embedding_dim, rngs=rngs),
            nnx.Linear(embedding_dim, weight_embedding_dim, rngs=rngs),
            nnx.LayerNorm(weight_embedding_dim, rngs=rngs),
            nnx.relu
        )

        # NOTE(review): no method in this class uses ``layer_norm``. It is
        # kept so the module's parameter tree and the order in which ``rngs``
        # is consumed stay unchanged.
        self.layer_norm = nnx.LayerNorm(embedding_dim, rngs=rngs)
        self.backbone = ActorBackbone(embedding_dim + weight_embedding_dim,
                                      actions_dim, rngs=rngs)
        self.rng = rngs

    def emb(self, feature, weight):
        """Embed ``feature`` and ``weight`` with their respective networks.

        Returns:
            A ``(feature_embedding, weight_embedding)`` tuple.
        """
        feature = self.proj_feature(feature)
        w_bar = self.weight_extractor(weight)

        return feature, w_bar

    def _mean_and_log_std(self, feature, weight):
        """Run the backbone and split its output into ``(mu, log_std)``."""
        feature, w_bar = self.emb(feature, weight)
        mu_log_std = self.backbone(jnp.concatenate([feature, w_bar], axis=-1))
        # The backbone emits 2 * actions_dim values: first half is the mean,
        # second half the (unclipped) log standard deviation.
        mu, log_std = jnp.split(mu_log_std, axis=-1, indices_or_sections=2)
        return mu, log_std

    def distribution(self, feature, weight):
        """Return the tanh-transformed Normal action distribution."""
        mu, log_std = self._mean_and_log_std(feature, weight)
        # Clip log-std before exponentiating to keep the scale bounded.
        std = jnp.exp(log_std.clip(-20, 3))
        normal_distribution = distrax.Normal(mu, std)
        return distrax.Transformed(normal_distribution, distrax.Tanh())

    def __call__(self, feature, weight):
        """Sample an action using a fresh key from the stored RNG container."""
        return self.distribution(feature, weight).sample(seed=self.rng())

    def sample_and_log_prob(self, feature, weight):
        """Sample an action and return it with its log-probability.

        Returns:
            ``(action, log_prob)`` where ``log_prob`` is summed over the last
            axis with that axis kept (size 1).
        """
        a, log_p = self.distribution(feature, weight).sample_and_log_prob(seed=self.rng())
        return a, log_p.sum(axis=-1, keepdims=True)

    def deterministic(self, feature, weight):
        """Return ``tanh(mu)``, i.e. the squashed mean with no sampling."""
        mu, _ = self._mean_and_log_std(feature, weight)
        return jnp.tanh(mu)

    def score_sample(self, feature, weight):
        """Sample actions and differentiate the action-score w.r.t. ``weight``.

        For every example along the leading axis this samples an action,
        computes ``d log pi(a) / d a`` at the (slightly clipped) action, and
        takes the forward-mode Jacobian of that quantity with respect to
        ``weight``.

        Args:
            feature: Batched feature input; mapped over axis 0.
            weight: Batched weight input; mapped over axis 0.

        Returns:
            ``(dlog_pi_dadw, action, nabla_log_pi_a, log_prob)``: the
            per-example Jacobian, the sampled actions, the action-score and
            the summed log-probability (last axis of size 1).
        """
        eps = 1e-6

        # Draw the randomness outside the transformed function: the RNG
        # container is advanced exactly once and no module state is mutated
        # while tracing. One key per example gives each batch element its own
        # noise; a single key closed over by ``vmap`` would be reused for
        # every example.
        batch_size = jnp.shape(feature)[0]
        keys = jax.random.split(self.rng(), batch_size)

        def weight_grad(feature, weight, key):
            distr = self.distribution(feature, weight)
            action, log_prob = distr.sample_and_log_prob(seed=key)
            log_prob = log_prob.sum(axis=-1, keepdims=True)

            def _log_prob(action):
                return distr.log_prob(action).sum()

            # Keep the action strictly inside (-1, 1) so the tanh inverse
            # used by ``log_prob`` stays finite.
            dlog_pi_da = jax.grad(_log_prob)(action.clip(-1 + eps, 1 - eps))
            # ``jacfwd(has_aux=True)`` allows a single auxiliary output, so
            # the extra per-example values are packed into one array.
            aux = jnp.concatenate([action, dlog_pi_da, log_prob], axis=-1)
            return dlog_pi_da, aux

        per_example = jax.jacfwd(weight_grad, has_aux=True, argnums=1)
        dlog_pi_dadw, action_and_log_prob = jax.vmap(
            per_example, in_axes=(0, 0, 0), out_axes=(0, 0)
        )(feature, weight, keys)

        # Unpack the auxiliary array: [action | dlog_pi_da | log_prob].
        action, nabla_log_pi_a = jnp.split(action_and_log_prob[..., :-1], axis=-1, indices_or_sections=2)
        log_prob = action_and_log_prob[..., -1:]

        return dlog_pi_dadw, action, nabla_log_pi_a, log_prob
