from typing import Sequence
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

# Temporary patch because tfp still uses the older JAX API: expose `jax.core.pytype_aval_mappings` under its old `jax.interpreters.xla` location.
import jax.core
jax.interpreters.xla.pytype_aval_mappings = jax.core.pytype_aval_mappings

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

from irl_baselines.algorithms.sac.flax_full_jit.tanh_transformed_distribution import TanhTransformedDistribution

from rl_x.environments.action_space_type import ActionSpaceType
from rl_x.environments.observation_space_type import ObservationSpaceType


def get_policy(config, env):
    """Build the policy module and its action post-processing function for ``env``.

    Args:
        config: Experiment config; ``config.algorithm.log_std_min`` and
            ``config.algorithm.log_std_max`` are the clip bounds for the
            policy's predicted log standard deviation.
        env: Environment exposing ``general_properties.action_space_type``,
            ``general_properties.observation_space_type``,
            ``single_action_space`` (``shape``, ``low``, ``high``) and
            ``single_observation_space.shape``. An optional
            ``policy_observation_indices`` attribute selects which observation
            features the policy reads; when it is absent (or ``None``) every
            feature along the first observation dimension is used.

    Returns:
        A ``(policy, get_processed_action)`` tuple: the ``Policy`` module and a
        jitted function that clips actions to [-1, 1] and rescales them to the
        environment's ``[low, high]`` action bounds.

    Raises:
        NotImplementedError: If the environment is not a continuous-action,
            flat-value-observation environment (the only combination this
            policy supports). Previously this case silently returned ``None``.
    """
    action_space_type = env.general_properties.action_space_type
    observation_space_type = env.general_properties.observation_space_type

    if not (action_space_type == ActionSpaceType.CONTINUOUS
            and observation_space_type == ObservationSpaceType.FLAT_VALUES):
        raise NotImplementedError(
            f"Unsupported combination of action space type {action_space_type} "
            f"and observation space type {observation_space_type}; only continuous "
            f"actions with flat-value observations are supported."
        )

    # Only build the default "use every feature" index range when the env does
    # not provide its own selection.
    policy_observation_indices = getattr(env, "policy_observation_indices", None)
    if policy_observation_indices is None:
        policy_observation_indices = jnp.arange(env.single_observation_space.shape[0])

    policy = Policy(
        env.single_action_space.shape,
        config.algorithm.log_std_min,
        config.algorithm.log_std_max,
        policy_observation_indices,
    )
    get_processed_action = get_processed_action_function(
        jnp.array(env.single_action_space.low),
        jnp.array(env.single_action_space.high),
    )
    return policy, get_processed_action



class Policy(nn.Module):
    """Stochastic actor network producing a tanh-transformed diagonal Gaussian.

    Only the observation features selected by ``policy_observation_indices``
    (indexed on the last axis) are fed to an MLP:
    Dense(512) -> LayerNorm -> ELU -> Dense(256) -> ELU -> Dense(128) -> ELU.
    Two linear heads then predict a mean and a clipped log standard deviation
    with ``prod(as_shape)`` outputs each. The Gaussian is wrapped in
    ``TanhTransformedDistribution``. That wrapper presumably squashes samples
    with tanh; confirm in its module.
    """

    as_shape: Sequence[int]  # action-space shape; each head emits prod(as_shape) values
    log_std_min: float  # lower clip bound for the predicted log std
    log_std_max: float  # upper clip bound for the predicted log std
    policy_observation_indices: Sequence[int]  # last-axis indices of x that the network reads

    @nn.compact
    def __call__(self, x):
        """Return the action distribution for the observations ``x``."""
        # Flax derives submodule names from creation order per type
        # (Dense_0..Dense_4, LayerNorm_0). Keep this order fixed so parameter
        # names, and therefore checkpoints, stay compatible.
        num_action_dims = np.prod(self.as_shape).item()

        hidden = nn.Dense(512)(x[..., self.policy_observation_indices])
        hidden = nn.elu(nn.LayerNorm()(hidden))
        for width in (256, 128):
            hidden = nn.elu(nn.Dense(width)(hidden))

        mean = nn.Dense(num_action_dims)(hidden)
        raw_log_std = nn.Dense(num_action_dims)(hidden)
        # Hard clip keeps the std within [exp(log_std_min), exp(log_std_max)].
        log_std = jnp.clip(raw_log_std, self.log_std_min, self.log_std_max)

        gaussian = tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std))
        return TanhTransformedDistribution(gaussian)


def get_processed_action_function(env_as_low, env_as_high):
    """Build a jitted function mapping actions in [-1, 1] onto ``[env_as_low, env_as_high]``.

    The returned function first clips its input to [-1, 1]. It then maps -1 to
    ``env_as_low`` and 1 to ``env_as_high`` with an elementwise affine map.
    The bounds are bound as default arguments, so by default they are baked
    into the traced computation, but callers may still override them per call.

    Args:
        env_as_low: Lower action bounds of the environment (array-like).
        env_as_high: Upper action bounds of the environment (array-like).

    Returns:
        A ``jax.jit``-compiled callable ``(action, env_as_low=..., env_as_high=...)``.
    """
    def get_clipped_and_scaled_action(action, env_as_low=env_as_low, env_as_high=env_as_high):
        # Saturate first so out-of-range actions land exactly on the bounds.
        unit_action = jnp.clip(action, -1, 1)
        span = env_as_high - env_as_low
        offset = 0.5 * (unit_action + 1.0) * span
        return env_as_low + offset

    processed_action_fn = jax.jit(get_clipped_and_scaled_action)
    return processed_action_fn
