from typing import Optional, Sequence, Tuple, Callable
import functools
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp

from src.common import Params
from src.models.common import MLP, default_init


# Short aliases for the JAX-substrate TensorFlow Probability namespaces.
tfd = tfp.distributions
tfb = tfp.bijectors
# Fallback bounds on the log standard deviation; NormalTanhPolicy uses them
# when its own log_std_min / log_std_max fields are not configured.
LOG_STD_MIN = -10.0
LOG_STD_MAX = 2.0


@functools.partial(jax.jit, static_argnames=("apply_fn", "distribution"))
@functools.partial(jax.vmap, in_axes=(0, None, 0, 0, None, None))
def _sample_actions(
    rng: jax.random.PRNGKey,
    apply_fn: Callable,
    actor_params: Params,
    observations: np.ndarray,
    temperature: float = 1.0,
    distribution: str = "log_prob",
) -> Tuple[jax.random.PRNGKey, jax.Array]:
    """Evaluate the policy and produce actions (jitted and vmapped).

    The function is vmapped with ``in_axes=(0, None, 0, 0, None, None)``:
    ``rng``, ``actor_params`` and ``observations`` are mapped over axis 0,
    while ``apply_fn``, ``temperature`` and ``distribution`` are shared by
    every mapped call. ``apply_fn`` and ``distribution`` are static
    arguments of the surrounding ``jax.jit``.

    Args:
        rng: PRNG key(s), mapped over axis 0.
        apply_fn: Callable invoked as
            ``apply_fn(actor_params, observations, temperature)``.
        actor_params: Policy parameters, mapped over axis 0.
        observations: Observations, mapped over axis 0.
        temperature: Forwarded unchanged to ``apply_fn``.
        distribution: ``"det"`` returns the output of ``apply_fn`` as the
            action; any other value treats that output as an object with a
            ``sample(seed=...)`` method and draws from it.

    Returns:
        A ``(rng, actions)`` pair. In ``"det"`` mode ``rng`` is returned
        untouched; otherwise it is the first half of ``jax.random.split``.
    """
    # The policy is evaluated the same way in both modes; only the handling
    # of its output differs.
    policy_out = apply_fn(actor_params, observations, temperature)

    if distribution == "det":
        return rng, policy_out

    next_rng, sample_key = jax.random.split(rng)
    return next_rng, policy_out.sample(seed=sample_key)


def sample_actions(
    rng: jax.random.PRNGKey,
    apply_fn: Callable,
    actor_params: Params,
    observations: np.ndarray,
    temperature: float = 1.0,
    distribution: str = "log_prob",
) -> Tuple[jax.random.PRNGKey, jax.Array]:
    """Public entry point that forwards to the jitted ``_sample_actions``.

    All six arguments are passed on positionally and unchanged; this wrapper
    only supplies the defaults for ``temperature`` and ``distribution``.

    Args:
        rng: PRNG key(s) handed to ``_sample_actions``.
        apply_fn: Policy apply function.
        actor_params: Policy parameters.
        observations: Observations to act on.
        temperature: Value forwarded to ``apply_fn``.
        distribution: ``"det"`` for the raw policy output, anything else to
            sample from the returned distribution.

    Returns:
        The ``(rng, actions)`` pair produced by ``_sample_actions``.
    """
    updated_rng, actions = _sample_actions(
        rng,
        apply_fn,
        actor_params,
        observations,
        temperature,
        distribution,
    )
    return updated_rng, actions


class MSEPolicy(nn.Module):
    """Deterministic policy: MLP features -> Dense head -> tanh.

    Attributes:
        hidden_dims: Layer sizes handed to the ``MLP`` trunk.
        action_dim: Width of the final Dense layer (one unit per action
            dimension).
        dropout_rate: Forwarded to ``MLP``; ``None`` by default.
    """

    hidden_dims: Sequence[int]
    action_dim: int
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(
        self,
        observations: jax.Array,
        temperature: float = 1.0,
        training: bool = False,
    ) -> jax.Array:
        """Return tanh-squashed actions for ``observations``.

        Args:
            observations: Input passed to the MLP trunk.
            temperature: Not used by this module; it is accepted so the call
                signature matches ``apply_fn(params, observations,
                temperature)`` as used by the sampling helpers.
            training: Forwarded to ``MLP`` as its ``training`` flag.

        Returns:
            The Dense head output passed through ``tanh``.
        """
        # Submodules are constructed in the same order (MLP, then Dense) so
        # the automatically assigned parameter names stay the same.
        trunk = MLP(
            self.hidden_dims,
            activate_final=True,
            dropout_rate=self.dropout_rate,
        )
        features = trunk(observations, training=training)

        action_head = nn.Dense(self.action_dim, kernel_init=default_init())
        return nn.tanh(action_head(features))


class NormalTanhPolicy(nn.Module):
    """Gaussian policy with bounded log-std and optional tanh squashing.

    Attributes:
        hidden_dims: Layer sizes handed to the ``MLP`` trunk.
        action_dim: Dimensionality of the action (size of the mean and
            log-std outputs).
        state_dependent_std: If True, log-stds come from a Dense head on the
            trunk features; otherwise they are a learned ``"log_stds"``
            parameter of shape ``(action_dim,)`` initialised to zeros.
        dropout_rate: Forwarded to ``MLP``; ``None`` by default.
        log_std_scale: Passed to ``default_init`` for the log-std head's
            kernel initialiser.
        log_std_min: Lower bound on the log-std; ``None`` selects the
            module-level ``LOG_STD_MIN``.
        log_std_max: Upper bound on the log-std; ``None`` selects the
            module-level ``LOG_STD_MAX``.
        tanh_squash_distribution: If True, the returned distribution is the
            diagonal Normal transformed by a ``Tanh`` bijector. If False,
            ``tanh`` is applied to the means and the untransformed Normal is
            returned.
    """

    hidden_dims: Sequence[int]
    action_dim: int
    state_dependent_std: bool = True
    dropout_rate: Optional[float] = None
    log_std_scale: float = 1.0
    log_std_min: Optional[float] = None
    log_std_max: Optional[float] = None
    tanh_squash_distribution: bool = True

    @nn.compact
    def __call__(
        self,
        observations: jax.Array,
        temperature: float = 1.0,
        training: bool = False,
    ) -> tfd.Distribution:
        """Build the action distribution for ``observations``.

        Args:
            observations: Input passed to the MLP trunk.
            temperature: Multiplier applied to the standard deviation of the
                base Normal.
            training: Forwarded to ``MLP`` as its ``training`` flag.

        Returns:
            A ``tfd.TransformedDistribution`` (Normal followed by ``Tanh``)
            when ``tanh_squash_distribution`` is True, otherwise the
            ``tfd.MultivariateNormalDiag`` itself.
        """
        outputs = MLP(
            self.hidden_dims,
            activate_final=True,
            dropout_rate=self.dropout_rate,
        )(observations, training=training)

        means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)

        if self.state_dependent_std:
            log_stds = nn.Dense(
                self.action_dim, kernel_init=default_init(self.log_std_scale)
            )(outputs)
        else:
            log_stds = self.param(
                "log_stds", nn.initializers.zeros, (self.action_dim,)
            )

        # Compare against None explicitly: with `or`, a configured bound of
        # 0.0 is falsy and would be silently replaced by the default.
        log_std_min = (
            LOG_STD_MIN if self.log_std_min is None else self.log_std_min
        )
        log_std_max = (
            LOG_STD_MAX if self.log_std_max is None else self.log_std_max
        )

        # Map the raw log-stds smoothly into [log_std_min, log_std_max] with
        # a rescaled tanh instead of a hard jnp.clip (suggested by Ilya for
        # stability).
        log_stds = log_std_min + (log_std_max - log_std_min) * 0.5 * (
            1 + nn.tanh(log_stds)
        )

        # Without a squashing bijector, bound the means directly instead.
        if not self.tanh_squash_distribution:
            means = nn.tanh(means)

        base_dist = tfd.MultivariateNormalDiag(
            loc=means, scale_diag=jnp.exp(log_stds) * temperature
        )
        if self.tanh_squash_distribution:
            return tfd.TransformedDistribution(
                distribution=base_dist, bijector=tfb.Tanh()
            )
        return base_dist
