import jax.nn
import jax.numpy as jnp
import jax
from networks.base import MLP
import flax.linen as nn
from typing import Callable
from distrax import MultivariateNormalDiag, Tanh, Block, Transformed



class SquashedGaussianPolicy(nn.Module):
    """Stochastic policy: a diagonal Gaussian pushed through ``tanh``.

    A single network head emits ``2 * action_dim`` values that are split in
    half along the last axis into a mean and a log standard deviation. The
    log standard deviation is clipped to ``[CLIP_MIN, CLIP_MAX]`` before being
    exponentiated.
    """
    action_dim: int
    activation_fn: Callable = nn.PReLU()
    # Bounds applied to the predicted log standard deviation.
    CLIP_MIN: float = -20.
    CLIP_MAX: float = 3.

    def setup(self):
        head = MLP(output_size=self.action_dim * 2,
                   activation_fn=self.activation_fn)
        self.mlp = nn.Sequential([head])

    def _mean_and_std(self, observations):
        """Return ``(mean, std)`` of the pre-squash Gaussian for ``observations``."""
        head_out: jax.Array = self.mlp(observations)
        mean, log_std = jnp.split(head_out, 2, axis=-1)
        clipped_log_std = jnp.clip(log_std, self.CLIP_MIN, self.CLIP_MAX)
        return mean, jnp.exp(clipped_log_std)

    def __call__(self, observations):
        """Sample an action and return it with its log-probability.

        The sample is drawn with a key from the ``'rng_stream'`` RNG
        collection. The returned log-probability carries an extra trailing
        axis of size one.
        """
        mean, std = self._mean_and_std(observations)
        base = MultivariateNormalDiag(mean, std)
        squashed = Transformed(base, Block(Tanh(), 1))
        sample_key = self.make_rng('rng_stream')
        action, log_prob = squashed.sample_and_log_prob(seed=sample_key)
        return action, log_prob[..., None]

    def deterministic(self, observations):
        """Return ``tanh`` of the predicted mean (no sampling, no RNG)."""
        head_out: jax.Array = self.mlp(observations)
        mean = jnp.split(head_out, 2, axis=-1)[0]
        return jnp.tanh(mean)

    def log_prob_of(self, observation, action):
        """Return the log-probability of ``action`` under the policy.

        ``action`` is clipped slightly inside ``(-1, 1)`` first so that the
        inverse of ``tanh`` stays finite. Unlike ``__call__``, no trailing
        axis is added to the result.
        """
        mean, std = self._mean_and_std(observation)
        squashed = Transformed(MultivariateNormalDiag(mean, std), Block(Tanh(), 1))
        safe_action = action.clip(-1 + 1e-6, 1 - 1e-6)
        return squashed.log_prob(safe_action)


class DeterministicPolicy(nn.Module):
    """Deterministic policy whose output is squashed into ``[-1, 1]`` by ``tanh``.

    When ``ff_feature`` is set, the observation is first augmented with the
    sine and cosine of a learned 250-unit linear projection (scaled by
    ``2 * pi``), concatenated in front of the raw observation.
    """
    action_dim: int
    layer_norm: bool = True
    activation_fn: Callable = nn.PReLU()
    ff_feature: bool = False

    @nn.compact
    def __call__(self, observation):
        """Map ``observation`` to an action of size ``action_dim``."""
        net_input = observation
        if self.ff_feature:
            phase = 2 * jnp.pi * nn.Dense(250)(observation)
            net_input = jnp.concatenate(
                [jnp.sin(phase), jnp.cos(phase), observation], axis=-1)

        backbone = MLP(output_size=self.action_dim,
                       hidden_sizes=(256, 256),
                       activation_fn=self.activation_fn,
                       layer_norm=self.layer_norm)
        return jax.nn.tanh(backbone(net_input))


class FFFeatureNet(nn.Module):
    """Augment an input with sin/cos features of a learned linear projection."""

    @nn.compact
    def __call__(self, feature):
        """Return ``concat(sin(p), cos(p), feature)`` along the last axis.

        ``p`` is ``2 * pi`` times a 250-unit dense projection of ``feature``.
        """
        phase = 2 * jnp.pi * nn.Dense(250)(feature)
        return jnp.concatenate([jnp.sin(phase), jnp.cos(phase), feature], axis=-1)


class VAEActor(nn.Module):
    """Conditional variational auto-encoder over actions.

    The encoder maps ``concat(observation, action)`` to the mean and log
    standard deviation of a diagonal Gaussian over a latent of size
    ``2 * action_dim``. The decoder maps ``concat(z, observation)`` to a
    ``tanh``-bounded action. All randomness is drawn from the
    ``'rng_stream'`` RNG collection.
    """
    action_dim: int
    # Number of candidate actions produced per observation by the bulk samplers.
    n_bulk_sample: int = 100
    # NOTE(review): declared but never forwarded to the MLPs built in ``setup``.
    layer_norm: bool = True
    activation_fn: Callable = nn.PReLU()
    # When True, encoder and decoder inputs first pass through ``FFFeatureNet``.
    ff_feature: bool = False

    def setup(self) -> None:
        """Build the encoder and decoder networks."""
        self.latent_dim = self.action_dim * 2
        if self.ff_feature:
            self.encoder = nn.Sequential([
                FFFeatureNet(),
                MLP(output_size=self.latent_dim * 2, hidden_sizes=(750, 750),
                    activation_fn=self.activation_fn)])
            self.decoder = nn.Sequential([FFFeatureNet(), MLP(output_size=self.action_dim, hidden_sizes=(750, 750),
                                                              activation_fn=self.activation_fn), nn.tanh])

        else:
            self.encoder = MLP(output_size=self.latent_dim * 2, hidden_sizes=(750, 750),
                               activation_fn=self.activation_fn)
            self.decoder = nn.Sequential([MLP(output_size=self.action_dim, hidden_sizes=(750, 750),
                                              activation_fn=self.activation_fn), nn.tanh])

    @nn.nowrap
    def normal(self, x):
        """Draw standard-normal latent noise shaped like ``x`` with its last axis
        replaced by ``latent_dim``."""
        key = self.make_rng('rng_stream')
        return jax.random.normal(key, shape=x.shape[:-1] + (self.latent_dim,))

    def encode_param(self, observation, action):
        """Return ``(mu, log_sigma)`` of the approximate posterior.

        The encoder output is split in half along the last axis.
        """
        mu_log_sigma = self.encoder(jnp.concatenate([observation, action], axis=-1))
        mu, log_sigma = jnp.split(mu_log_sigma, axis=-1, indices_or_sections=2)
        return mu, log_sigma

    def loss_fn(self, observation, action, beta: float = 0.5):
        """Compute the VAE objective ``bc_loss + beta * kl_loss``.

        Args:
            observation: conditioning input for encoder and decoder.
            action: target action to reconstruct.
            beta: weight of the KL term.

        Returns:
            A tuple ``(total_loss, info)`` where ``info`` holds the individual
            ``"bc_loss"`` and ``"kl_loss"`` terms.
        """
        mu, log_sigma = self.encode_param(observation, action)
        sigma = jnp.exp(log_sigma)
        # Reparameterisation trick: ``sigma`` is used as a standard deviation.
        z = mu + self.normal(sigma) * sigma
        action_hat = self.decoder(jnp.concatenate([z, observation], axis=-1))
        bc_loss = ((action_hat - action) ** 2).mean()
        # Closed-form KL(N(mu, sigma^2) || N(0, I)). Because ``log_sigma`` is
        # the log *standard deviation* (see the sampling line above), the
        # log-variance is ``2 * log_sigma`` and the variance is ``sigma ** 2``.
        # The earlier expression used ``log_sigma`` and ``sigma`` directly,
        # which is only correct for a log-variance parameterisation.
        kl_loss = (-0.5 * (1 + 2 * log_sigma - mu ** 2 - sigma ** 2)).mean()
        return bc_loss + beta * kl_loss, {"bc_loss": bc_loss, "kl_loss": kl_loss}

    def bulk_sample(self, observation):
        """Decode ``n_bulk_sample`` actions per observation from clipped noise.

        A new axis of length ``n_bulk_sample`` is inserted before the last
        axis of ``observation``; the latent noise is clipped to ``[-0.5, 0.5]``.
        """
        key = self.make_rng('rng_stream')
        normal = jax.random.normal(key, shape=observation.shape[:-1] + (self.n_bulk_sample, self.latent_dim))
        z = normal.clip(-0.5, 0.5)  # Fujimoto's implementation
        # Tile the observation so every latent sample is paired with it.
        observation = observation[..., None, :]
        observation = jnp.repeat(observation, axis=-2, repeats=self.n_bulk_sample)
        return self.decoder(jnp.concatenate([z, observation], axis=-1))

    def non_clipping_bulk_sample(self, observation):
        """Same as ``bulk_sample`` but with unclipped standard-normal noise."""
        key = self.make_rng('rng_stream')
        z = jax.random.normal(key, shape=observation.shape[:-1] + (self.n_bulk_sample, self.latent_dim))

        # Tile the observation so every latent sample is paired with it.
        observation = observation[..., None, :]
        observation = jnp.repeat(observation, axis=-2, repeats=self.n_bulk_sample)
        return self.decoder(jnp.concatenate([z, observation], axis=-1))

    def __call__(self, observation, action):
        """Return ``loss_fn(observation, action)`` with the default ``beta``."""
        return self.loss_fn(observation, action)

    def sample(self, observation):
        """Decode one action per observation from noise clipped to ``[-0.5, 0.5]``."""
        z = self.normal(observation).clip(-0.5, 0.5)  # Fujimoto's implementation
        return self.decoder(jnp.concatenate([z, observation], axis=-1))

    def non_clipping_noise_sample(self, observation):
        """Decode one action per observation from unclipped standard-normal noise."""
        z = self.normal(observation)
        return self.decoder(jnp.concatenate([z, observation], axis=-1))


class ORAACActor(nn.Module):
    """Bounded residual actor that perturbs a proposed action.

    The network sees ``concat(observation, action)`` and its ``tanh`` output,
    scaled by ``phi``, is added to ``action``. The sum is clipped to
    ``[-1, 1]``.
    """
    actions_dim: int
    # Maximum magnitude of the per-dimension perturbation.
    phi: float = 0.05

    @nn.compact
    def __call__(self, observation, action):
        """Return the perturbed action clipped to ``[-1, 1]``."""
        perturbation_net = MLP(self.actions_dim)
        net_input = jnp.concatenate([observation, action], axis=-1)
        perturbation = self.phi * jnp.tanh(perturbation_net(net_input))
        return jnp.clip(perturbation + action, -1, 1)


# NOTE: the string literal below is an inert, unfinished VAE-style ORAACActor
# draft. It is only evaluated as a bare string expression and defines nothing.
# Its ``vae_loss`` is incomplete: it calls ``distrax.mc_estimate_kl()`` without
# arguments and returns no value.
'''
class ORAACActor(nn.Module):
    action_dim: int
    latent_dim: int = 64
    layer_norm: bool = True
    residual: bool = False
    activation_fn: Callable = nn.silu

    def setup(self) -> None:
        # VAE
        self.encoder = MLP(output_size=self.latent_dim * 2, activation_fn=self.activation_fn,
                           layer_norm=self.layer_norm, residual=self.residual)
        self.decoder = MLP(output_size=self.action_dim, activation_fn=self.activation_fn,
                           layer_norm=self.layer_norm, residual=self.residual)
        self.make_rng('rng_stream')

    def __call__(self, obs):
        latent_variables = self.encoder(obs)
        mu, log_sigma = jnp.split(latent_variables, axis=-1, indices_or_sections=2)
        z = mu + jnp.exp(log_sigma) * jax.random.normal(key=self.make_rng('rng_stream'), shape=log_sigma.shape)
        action = self.decoder(z)
        return action

    def vae_loss(self, obs, batch_action):
        latent_variables = self.encoder(obs)
        mu, log_sigma = jnp.split(latent_variables, axis=-1, indices_or_sections=2)
        z = mu + jnp.exp(log_sigma) * jax.random.normal(key=self.make_rng('rng_stream'), shape=log_sigma.shape)
        action = self.decoder(z)
        bc_loss = ((batch_action - action) ** 2).mean()
        kl_loss = distrax.mc_estimate_kl()
        return
'''
