import flax.linen as nn
import jax.numpy as jnp
import jax

class RSObservationNorm(nn.Module):
    """Standardize inputs with running statistics stored in a variable collection.

    On first use this creates three variables in the ``running_obs_stats``
    collection, all requested as float64: ``mean`` (zeros, one per feature),
    ``var`` (ones, one per feature) and ``count`` (scalar 1.0). It returns
    ``(x - mean) / sqrt(var + 1e-8)`` cast to float32, with gradients stopped
    through the statistics. The module only reads the statistics and never
    writes them, so any updating has to happen outside this module.

    NOTE(review): float64 is only honoured when ``jax_enable_x64`` is enabled;
    otherwise JAX falls back to float32.
    """

    @nn.compact
    def __call__(self, x) -> jax.Array:
        feature_shape = (x.shape[-1],)
        collection = "running_obs_stats"
        running_mean = self.variable(collection, "mean", lambda: jnp.zeros(feature_shape, dtype=jnp.float64))
        running_var = self.variable(collection, "var", lambda: jnp.ones(feature_shape, dtype=jnp.float64))
        # `count` is not read here; it is created so the collection carries it for external updates.
        self.variable(collection, "count", lambda: jnp.array(1.0, dtype=jnp.float64))

        mu = jax.lax.stop_gradient(running_mean.value)
        sigma = jnp.sqrt(jax.lax.stop_gradient(running_var.value) + 1e-8)
        return ((x - mu) / sigma).astype(jnp.float32)

class SimbaV2CriticNet(nn.Module):
    """SimbaV2-style critic over (state, action) with running reward statistics.

    ``__call__`` returns raw outputs reshaped to ``(..., n_heads, n_bins)``;
    presumably these are categorical logits over the fixed support ``bins``
    (TODO confirm against the loss code). ``normalize_reward`` rescales rewards
    using the ``running_reward_stats`` collection created in ``setup``.
    """

    # float64 should be used for normalizing the states and reward; otherwise float32 should be used.
    # NOTE(review): float64 arrays are only created when `jax_enable_x64` is enabled; otherwise JAX
    # falls back to float32.
    # `hidden_dim`, `n_bins` and `bins` are plain class attributes (no annotation), not dataclass
    # fields, so they cannot be set through the constructor.
    hidden_dim = 512
    batch_norm: bool  # when True, a BatchNorm is applied before every Dense layer
    n_heads: int = 1  # number of output heads (second-to-last axis of the output)
    n_bins = 101
    # NOTE(review): the bin count 101 is repeated literally here instead of using `n_bins`; keep in sync.
    bins = jnp.linspace(-5, 5, 101, dtype=jnp.float32)

    def setup(self):
        """Create the scalar reward statistics in the ``running_reward_stats`` collection."""
        # These are only read inside this class (by `normalize_reward`), never written. Presumably they
        # are updated externally, e.g. with `update_mean_var_stats` (mean/var/count) and
        # `update_g_g_max_stats` (G/G_max) -- TODO confirm with the training loop.
        self.mean = self.variable("running_reward_stats", "mean", lambda: jnp.array(0.0, dtype=jnp.float64))
        self.var = self.variable("running_reward_stats", "var", lambda: jnp.array(1.0, dtype=jnp.float64))
        self.count = self.variable("running_reward_stats", "count", lambda: jnp.array(1.0, dtype=jnp.float64))
        self.g = self.variable("running_reward_stats", "G", lambda: jnp.array(0.0, dtype=jnp.float64))
        self.g_max = self.variable("running_reward_stats", "G_max", lambda: jnp.array(0.0, dtype=jnp.float64))

    @nn.compact
    def __call__(self, state: jnp.ndarray, action: jnp.ndarray, use_running_average=False) -> jax.Array:
        """Map (state, action) to an output of shape ``(..., n_heads, n_bins)``.

        ``use_running_average`` is forwarded to every BatchNorm and only matters when
        ``batch_norm`` is True. Submodules are auto-named in creation order (Dense_0,
        BatchNorm_0, ...), so reordering layers changes the parameter tree.
        """
        # NOTE(review): jnp.squeeze removes *every* size-1 axis, including a batch axis of size 1,
        # which would make the concatenate with `action` fail -- confirm the expected state shape.
        x = RSObservationNorm()(jnp.squeeze(state))
        x = jnp.concatenate([x, action.astype(jnp.float32)], -1)

        # -- Embedder -- #
        # Shift: (x.shape[:-1] + (1,)) = (batch_size, 1) OR (1,)
        # Append a constant 3.0 channel, then project onto the unit L2 sphere.
        x = jnp.append(x, jnp.ones((x.shape[:-1] + (1,)), dtype=jnp.float32) * 3.0, axis=-1)
        x = x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), 1e-8)
        if self.batch_norm:
            x = nn.BatchNorm(use_running_average)(x)
        x = nn.Dense(features=self.hidden_dim, kernel_init=nn.initializers.orthogonal(column_axis=0), use_bias=False)(x)
        # Learnable per-unit scale, initialised to sqrt(2 / hidden_dim).
        scaler = self.param(
            "scaler_embedder",
            nn.initializers.constant(jnp.sqrt(2 / self.hidden_dim), dtype=jnp.float32),
            self.hidden_dim,
        )
        x = scaler * x
        x = x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), 1e-8)

        # -- Encoder -- #
        # Two residual blocks: expand to 4 * hidden_dim, scale, ReLU, project back, L2-normalize,
        # then move from `res` toward the block output by a learnable per-unit step.
        for idx_block in range(2):
            res = x

            if self.batch_norm:
                x = nn.BatchNorm(use_running_average)(x)
            x = nn.Dense(
                features=self.hidden_dim * 4, kernel_init=nn.initializers.orthogonal(column_axis=0), use_bias=False
            )(x)
            scaler = self.param(
                f"scaler_{idx_block}",
                nn.initializers.constant(jnp.sqrt(2 / (self.hidden_dim * 4)), dtype=jnp.float32),
                self.hidden_dim * 4,
            )
            x = scaler * x
            # The +1e-8 keeps activations strictly positive (presumably to avoid exact zeros).
            x = nn.relu(x) + 1e-8
            if self.batch_norm:
                x = nn.BatchNorm(use_running_average)(x)
            x = nn.Dense(
                features=self.hidden_dim, kernel_init=nn.initializers.orthogonal(column_axis=0), use_bias=False
            )(x)
            x = x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), 1e-8)

            # `scaler` starts at 1/sqrt(hidden_dim) and `scale` is (1/3) * sqrt(hidden_dim), so the
            # interpolation step `scale * scaler` equals 1/3 at initialisation.
            scaler = self.param(
                f"scaler_interp_{idx_block}",
                nn.initializers.constant(1 / jnp.sqrt(self.hidden_dim), dtype=jnp.float32),
                self.hidden_dim,
            )
            scale = (1 / 3) / (1 / jnp.sqrt(self.hidden_dim))
            x = res + scale * scaler * (x - res)
            x = x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), 1e-8)

        # -- Predictor -- #
        # Dense + per-unit scaler, then a biased output layer. The explicit name "last_layer" is
        # matched by `weight_normalization`.
        if self.batch_norm:
            x = nn.BatchNorm(use_running_average)(x)
        x = nn.Dense(features=self.hidden_dim, kernel_init=nn.initializers.orthogonal(column_axis=0), use_bias=False)(x)
        scaler = self.param("scaler_predictor", nn.initializers.constant(1.0, dtype=jnp.float32), self.hidden_dim)
        x = scaler * x
        if self.batch_norm:
            x = nn.BatchNorm(use_running_average)(x)
        x = nn.Dense(
            name="last_layer",
            features=self.n_heads * self.n_bins,
            kernel_init=nn.initializers.orthogonal(column_axis=0),
        )(x)
        # Reshape: (x.shape[:-1] + (n_heads, n_bins)) = (batch_size, n_heads, n_bins) OR (n_heads, n_bins)
        x = x.reshape((x.shape[:-1] + (self.n_heads, self.n_bins)))
        return x

    def normalize_reward(self, reward):
        """Scale ``reward`` by ``max(sqrt(var + 1e-8), G_max / bins[-1])`` and cast to float32.

        Reads ``var`` and ``G_max`` from the ``running_reward_stats`` collection. ``bins[-1]``
        is 5.0, the upper end of ``bins``; presumably the G_max term keeps normalized returns
        inside the value support -- TODO confirm.
        """
        norm64 = reward / jnp.maximum(jnp.sqrt(self.var.value + 1e-8), self.g_max.value / self.bins[-1])
        return norm64.astype(jnp.float32)


class SimbaV2ActorNet(nn.Module):
    """SimbaV2-style tanh-squashed Gaussian policy.

    With a PRNG key, ``__call__`` samples ``action = tanh(mean + std * eps)`` and
    returns it together with its log-density summed over the last axis. With
    ``noise_key=None`` it returns ``(means, 1)`` instead (see the NOTE in ``__call__``).
    """

    action_dim: int
    # The following are plain class attributes (no annotation), not dataclass fields, so they
    # cannot be set through the constructor.
    hidden_dim = 128
    # Bounds of the interval the predicted log standard deviation is mapped into.
    min_log_stds = -10
    max_log_stds = 2

    @nn.compact
    def __call__(self, state: jnp.ndarray, noise_key) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(action, log_prob)``, or ``(means, 1)`` when ``noise_key`` is None.

        Submodules are auto-named in creation order (Dense_0, ...), so reordering layers
        changes the parameter tree.
        """
        # NOTE(review): jnp.squeeze removes *every* size-1 axis, including a batch axis of size 1.
        x = RSObservationNorm()(jnp.squeeze(state))
        # -- Embedder -- #
        # Shift: (x.shape[:-1] + (1,)) = (batch_size, 1) OR (1,)
        # Append a constant 3.0 channel, then project onto the unit L2 sphere.
        x = jnp.append(x, jnp.ones((x.shape[:-1] + (1,)), dtype=jnp.float32) * 3.0, axis=-1)
        x = x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), 1e-8)
        x = nn.Dense(features=self.hidden_dim, kernel_init=nn.initializers.orthogonal(column_axis=0), use_bias=False)(x)
        # Learnable per-unit scale, initialised to sqrt(2 / hidden_dim).
        scaler = self.param(
            "scaler_embedder",
            nn.initializers.constant(jnp.sqrt(2 / self.hidden_dim), dtype=jnp.float32),
            self.hidden_dim,
        )
        x = scaler * x
        x = x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), 1e-8)

        # -- Encoder -- #
        # One residual block: expand to 4 * hidden_dim, scale, ReLU, project back, L2-normalize,
        # then move from `res` toward the block output by a learnable per-unit step.
        res = x
        x = nn.Dense(
            features=self.hidden_dim * 4, kernel_init=nn.initializers.orthogonal(column_axis=0), use_bias=False
        )(x)
        scaler = self.param(
            "scaler",
            nn.initializers.constant(jnp.sqrt(2 / (self.hidden_dim * 4)), dtype=jnp.float32),
            self.hidden_dim * 4,
        )
        x = scaler * x
        # The +1e-8 keeps activations strictly positive (presumably to avoid exact zeros).
        x = nn.relu(x) + 1e-8
        x = nn.Dense(features=self.hidden_dim, kernel_init=nn.initializers.orthogonal(column_axis=0), use_bias=False)(x)
        x = x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), 1e-8)

        # `scaler` starts at 1/sqrt(hidden_dim) and `scale` is (1/2) * sqrt(hidden_dim), so the
        # interpolation step equals 1/2 at initialisation.
        scaler = self.param(
            "scaler_interp", nn.initializers.constant(1 / jnp.sqrt(self.hidden_dim), dtype=jnp.float32), self.hidden_dim
        )
        scale = (1 / 2) / (1 / jnp.sqrt(self.hidden_dim))
        x = res + scale * scaler * (x - res)
        x = x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), 1e-8)

        # -- Predictor -- #
        # Mean head: Dense + per-unit scaler, then a biased projection to action_dim.
        means = nn.Dense(
            features=self.hidden_dim, kernel_init=nn.initializers.orthogonal(column_axis=0), use_bias=False
        )(x)
        scaler = self.param("scaler_mean", nn.initializers.constant(1.0, dtype=jnp.float32), self.hidden_dim)
        means = scaler * means
        means = nn.Dense(features=self.action_dim, kernel_init=nn.initializers.orthogonal(column_axis=0))(means)

        # NOTE(review): the deterministic branch returns the *pre-tanh* means (the stochastic branch
        # returns a tanh-squashed action) and the int 1 as a log-prob placeholder, which does not match
        # the annotated return type. Confirm callers squash the means themselves.
        # NOTE(review): the std head below (its Dense layers and "scaler_std") is only built on this
        # branch, so the module must be initialised with a real key or those params will be missing.
        if noise_key is None:  # deterministic
            return means, 1
        else:
            stds = nn.Dense(
                features=self.hidden_dim, kernel_init=nn.initializers.orthogonal(column_axis=0), use_bias=False
            )(x)
            scaler = self.param("scaler_std", nn.initializers.constant(1.0, dtype=jnp.float32), self.hidden_dim)
            stds = scaler * stds
            log_stds_unclipped = nn.Dense(
                features=self.action_dim, kernel_init=nn.initializers.orthogonal(column_axis=0)
            )(stds)

            # Apply tanh to prediction (output in [-1, 1]) and then scale (output in [min_log_stds, max_log_stds])
            log_stds = self.min_log_stds + (self.max_log_stds - self.min_log_stds) / 2 * (
                1 + nn.tanh(log_stds_unclipped)
            )
            stds = jnp.exp(log_stds)

            # Reparameterised sample: mean + std * standard-normal noise, then squash into (-1, 1).
            action_pre_tanh = means + stds * jax.random.normal(noise_key, shape=stds.shape, dtype=jnp.float32)
            action = jnp.tanh(action_pre_tanh)

            # Gaussian log-prob: -1/2 ((x - mean) / std)^2 -1/2 log(2 pi) -log(sigma)
            log_prob_uncorrected = (
                -0.5 * jnp.square(action_pre_tanh / stds - means / stds) - 0.5 * jnp.log(2 * jnp.pi) - jnp.log(stds)
            )
            # d tanh^{-1}(y) / dy = 1 / (1 - y^2)
            # Change of variables for the tanh squash; 1e-6 keeps the log finite as |action| -> 1.
            log_prob = log_prob_uncorrected - jnp.log(1 - action**2 + 1e-6)

            # Per-dimension log-probs are summed over the last (action) axis.
            return action, jnp.sum(log_prob, axis=-1)


@jax.jit
def update_g_g_max_stats(reward, episode_end, gamma, stats):
    """Advance the running discounted return ``G`` and its running maximum ``|G|``.

    Computes ``G <- gamma * (1 - episode_end) * G + reward``. When ``episode_end``
    is 1, the previous ``G`` is discarded and the new ``G`` equals ``reward``.
    ``episode_end`` is used rather than a termination-only flag so that ``G``
    restarts for every episode. ``G_max`` keeps the largest ``|G|`` seen so far.

    Args:
        reward: Reward of the current step.
        episode_end: 1 (or True) at an episode boundary, else 0.
        gamma: Discount factor.
        stats: Mapping holding the previous ``"G"`` and ``"G_max"``.

    Returns:
        A new dict with only the keys ``"G"`` and ``"G_max"``.
    """
    continuation = gamma * (1 - episode_end)
    g_next = continuation * stats["G"] + reward
    g_max_next = jnp.maximum(stats["G_max"], jnp.abs(g_next))
    return {"G": g_next, "G_max": g_max_next}


def weight_normalization(params):
    """L2-normalize every Dense kernel in a parameter dict, in place, and return it.

    For each top-level entry whose name starts with ``"Dense"`` or
    ``"last_layer"``, every leaf stored under the key ``"kernel"`` is divided by
    its L2 norm along axis ``-2``, with the norm clamped below at 1e-8. Axis
    ``-2`` is the input-feature axis of a Flax Dense kernel. This covers plain
    kernels of shape ``(in, out)`` as well as kernels stacked along leading
    ensemble axes, e.g. ``(E, in, out)`` for a vmapped double-Q critic. Biases
    and all other parameters (such as the ``scaler_*`` vectors) are left
    untouched.

    Args:
        params: Mutable mapping from top-level module name to that module's
            parameters (e.g. the ``"params"`` collection of one of the networks
            in this file). Matching entries are replaced in place.

    Returns:
        The same ``params`` object.
    """

    def _normalize_kernel(path, leaf):
        # Select kernels by key, not by shape. A shape heuristic cannot tell a (2, out) kernel
        # with two input features from two stacked (out,) biases, and it misses ensembles
        # whose size is not 2.
        is_kernel = len(path) > 0 and getattr(path[-1], "key", None) == "kernel"
        if not is_kernel or leaf.ndim < 2:
            return leaf
        norms = jnp.linalg.norm(leaf, axis=-2, keepdims=True)
        return leaf / jnp.maximum(norms, 1e-8)

    for layer_name in list(params.keys()):
        if layer_name.startswith(("Dense", "last_layer")):
            # tree_map_with_path preserves the container type of the layer's params.
            params[layer_name] = jax.tree_util.tree_map_with_path(_normalize_kernel, params[layer_name])
    return params


def normalize_matrix(w):
    """Return ``w`` with unit L2 norm along its input-feature axis, or ``w`` unchanged.

    A leading dimension of size 2 is taken to be a double-Q ensemble axis. In that
    case a 3-D array is normalized along axis 1; otherwise a 2-D array is
    normalized along axis 0. Any array whose rank does not match (e.g. a bias
    vector, or a stacked ``(2, out)`` bias) is returned as-is. Norms are clamped
    below at 1e-8 to avoid division by zero.

    NOTE(review): this shape heuristic is ambiguous. An un-ensembled kernel with
    exactly 2 input features, shape ``(2, out)``, is treated like a stacked bias
    and returned unnormalized. Ensembles whose size is not 2 are never normalized.
    Prefer selecting kernels by parameter name when possible.
    """
    has_double_q_axis = w.shape[0] == 2
    feature_axis = 1 if has_double_q_axis else 0
    if w.ndim != feature_axis + 2:
        # Not a (possibly double-Q stacked) weight matrix: leave it alone (biases land here).
        return w
    col_norms = jnp.linalg.norm(w, axis=feature_axis, keepdims=True)
    return w / jnp.maximum(col_norms, 1e-8)

@jax.jit
def update_mean_var_stats(x, stats):
    return {
        # \mu_t = \mu_{t-1} + 1 / t * (o_t - \mu_{t-1})
        "mean": stats["mean"] + 1 / stats["count"] * (x - stats["mean"]),
        # \sigma_t = (t - 1) / t * [\sigma_{t-1} + 1 / t * (o_t - \mu_{t-1})^2]
        "var": (stats["count"] - 1) / stats["count"] * (stats["var"] + jnp.square(x - stats["mean"]) / stats["count"]),
        "count": stats["count"] + 1,
    }
