from functools import partial
from typing import Union, List

import jax

from jax import Array, random, numpy as jnp
from flax.training.train_state import TrainState
from distrax import (
    Bijector,
    MultivariateNormalDiag,
    Transformed,
)

from ppomdp.core import (
    PRNGKey,
    Carry,
    Parameters,
    RecurrentPolicy,
)
from ppomdp.policy.arch import (
    RecurrentEncoder,
    NeuralGaussDecoder
)


def create_recurrent_neural_gauss_policy(
    encoder: RecurrentEncoder,
    decoder: NeuralGaussDecoder,
    bijector: Bijector
) -> RecurrentPolicy:
    r"""Creates a squashed neural Gaussian policy that conforms to the RecurrentPolicy interface.

    The policy uses a recurrent encoder to process observations and a decoder to output
    action distributions. Actions are transformed through a bijector to enforce bounds.

    At every step the encoder consumes ``(carry, observations, actions)`` and returns an
    updated carry plus an encoding. The decoder maps the encoding to ``(mean, log_std)``
    of a diagonal Gaussian, and that Gaussian is pushed through ``bijector`` to obtain
    the action distribution.

    Args:
        encoder (RecurrentEncoder): The recurrent encoder network (e.g. an LSTM or GRU encoder).
        decoder (NeuralGaussDecoder): The network mapping encodings to Gaussian parameters.
        bijector (Bijector): Policy bijector to enforce action limits.

    Returns:
        RecurrentPolicy: A policy object implementing the RecurrentPolicy interface with
            methods for sampling actions and computing probabilities.
    """

    def _action_distribution(
        carry: list[Carry],
        actions: Array,
        observations: Array,
        params: Parameters,
    ) -> tuple[list[Carry], Transformed, Array]:
        """Runs one encoder/decoder step and builds the bijected action distribution.

        Returns:
            ``(next_carry, dist, mean)``:
            - ``next_carry`` is the encoder's updated carry.
            - ``dist`` is the Gaussian transformed by ``bijector``.
            - ``mean`` is the Gaussian mean before the bijector is applied.
        """
        next_carry, encodings = encoder.apply(
            {"params": params["encoder"]}, carry, observations, actions
        )
        mean, log_std = decoder.apply({"params": params["decoder"]}, encodings)
        base = MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std))
        dist = Transformed(distribution=base, bijector=bijector)
        return next_carry, dist, mean

    def sample(
        rng_key: PRNGKey,
        carry: list[Carry],
        actions: Array,
        observations: Array,
        params: Parameters,
    ) -> tuple[list[Carry], Array, Array]:
        """Samples the next actions and advances the recurrent state.

        Returns:
            ``(next_carry, next_actions, bijector.forward(mean))``. The last element is the
            Gaussian mean mapped through the bijector. It does not depend on ``rng_key``.
        """
        next_carry, dist, mean = _action_distribution(carry, actions, observations, params)
        next_actions = dist.sample(seed=rng_key)
        return next_carry, next_actions, bijector.forward(mean)

    def log_prob(
        next_actions: Array,
        carry: list[Carry],
        actions: Array,
        observations: Array,
        params: Parameters,
    ) -> Array:
        """Log-probability of ``next_actions`` under the policy; the updated carry is discarded."""
        _, dist, _ = _action_distribution(carry, actions, observations, params)
        return dist.log_prob(next_actions)

    def sample_and_log_prob(
        rng_key: PRNGKey,
        carry: list[Carry],
        actions: Array,
        observations: Array,
        params: Parameters,
    ) -> tuple[list[Carry], Array, Array, Array]:
        """Samples next actions together with their log-probabilities.

        Returns:
            ``(next_carry, next_actions, log_prob, bijector.forward(mean))``.
        """
        next_carry, dist, mean = _action_distribution(carry, actions, observations, params)
        next_actions, log_prob = dist.sample_and_log_prob(seed=rng_key)
        return next_carry, next_actions, log_prob, bijector.forward(mean)

    def carry_and_log_prob(
        next_actions: Array,
        carry: list[Carry],
        actions: Array,
        observations: Array,
        params: Parameters,
    ) -> tuple[list[Carry], Array]:
        """Like ``log_prob`` but also returns the encoder's updated carry."""
        next_carry, dist, _ = _action_distribution(carry, actions, observations, params)
        return next_carry, dist.log_prob(next_actions)

    @jax.jit
    def pathwise_carry(
        actions: Array,
        observations: Array,
        params: Parameters,
    ):
        """Replays the encoder over a whole trajectory and returns every intermediate carry.

        Inputs are assumed time-major: ``jax.lax.scan`` iterates over the leading axis,
        and the batch size is read from axis 1 of ``observations``, which must be rank 3.

        Returns:
            A pytree shaped like the encoder carry. Each leaf gains a leading axis of
            length ``observations.shape[0] + 1``. Index 0 is the freshly reset carry, and
            index ``t + 1`` is the carry after consuming step ``t``.
        """
        def step(carry, inputs):
            obs, act = inputs
            next_carry, _ = encoder.apply({"params": params["encoder"]}, carry, obs, act)
            return next_carry, next_carry

        _, batch_size, _ = observations.shape
        init_carry = encoder.reset(batch_size)
        _, stacked_carry = jax.lax.scan(step, init_carry, (observations, actions))

        # Prepend the initial carry so every time step, plus the start, has an entry.
        return jax.tree.map(
            lambda first, rest: jnp.concatenate([first[None, ...], rest]),
            init_carry,
            stacked_carry,
        )

    @jax.jit
    def pathwise_log_prob(
        actions: Array,
        observations: Array,
        params: Parameters
    ) -> Array:
        """Sums the per-step log-probabilities of a trajectory's actions.

        At scan step ``t``, the encoder consumes ``(observations[t], actions[t])``, and the
        resulting distribution scores ``actions[t + 1]``. Consequently, ``actions[0]`` is
        used only as encoder input, and the last observation is never consumed. Inputs are
        assumed time-major; the batch size is read from axis 1 of ``actions``.

        Returns:
            The per-step log-probabilities summed over the leading (time) axis.
        """
        def step(carry, inputs):
            next_actions_t, observations_t, actions_t = inputs
            return carry_and_log_prob(
                next_actions=next_actions_t,
                carry=carry,
                actions=actions_t,
                observations=observations_t,
                params=params,
            )

        _, batch_size, _ = actions.shape
        init_carry = encoder.reset(batch_size)

        _, log_probs = jax.lax.scan(
            f=step,
            init=init_carry,
            xs=(
                actions[1:, ...],
                observations[:-1, ...],
                actions[:-1, ...],
            ),
        )
        return jnp.sum(log_probs, axis=0)

    def entropy(params: Parameters) -> Array:
        """Entropy computed by ``decoder.entropy`` from the decoder's ``log_std`` parameter.

        The bijector is not involved here. Presumably this is the entropy of the
        untransformed Gaussian — verify against ``NeuralGaussDecoder.entropy``.
        """
        log_std = params["decoder"]["log_std"]
        return decoder.entropy(log_std)

    def reset(batch_size: int) -> list[Carry]:
        """Returns a fresh encoder carry for ``batch_size`` sequences."""
        return encoder.reset(batch_size)

    def init(
        rng_key: PRNGKey,
        obs_dim: int,
        action_dim: int,
        batch_dim: int,
    ) -> Parameters:
        """Initializes encoder and decoder parameters by tracing them on random dummy inputs.

        Returns:
            A dict ``{"encoder": encoder_params, "decoder": decoder_params}``.
        """
        # The dummy action and dummy observation get separate keys so that no PRNG key is reused.
        action_key, encoder_key, decoder_key, observation_key = random.split(rng_key, 4)

        # initialize encoder network
        dummy_carry = encoder.reset(batch_dim)
        dummy_action = random.normal(action_key, (batch_dim, action_dim))
        dummy_observation = random.normal(observation_key, (batch_dim, obs_dim))
        encoder_params = encoder.init(
            encoder_key, dummy_carry, dummy_observation, dummy_action
        )["params"]

        # initialize decoder network on an encoding produced by the freshly initialized encoder
        _, dummy_encoding = encoder.apply(
            {"params": encoder_params}, dummy_carry, dummy_observation, dummy_action
        )
        decoder_params = decoder.init(decoder_key, dummy_encoding)["params"]

        return {"encoder": encoder_params, "decoder": decoder_params}

    return RecurrentPolicy(
        dim=decoder.dim,
        init=init,
        reset=reset,
        sample=sample,
        log_prob=log_prob,
        pathwise_carry=pathwise_carry,
        pathwise_log_prob=pathwise_log_prob,
        sample_and_log_prob=sample_and_log_prob,
        carry_and_log_prob=carry_and_log_prob,
        entropy=entropy,
    )


@partial(jax.jit, static_argnames="policy")
def train_recurrent_neural_gauss_policy_pathwise(
    learner: TrainState,
    policy: RecurrentPolicy,
    actions: Array,
    observations: Array,
) -> tuple[TrainState, Array]:
    """
    Performs one gradient step of maximum-likelihood training on whole trajectories.

    The loss is the negative mean, over the batch, of the trajectory log-likelihoods
    returned by ``policy.pathwise_log_prob``. A single optimizer update is applied
    through ``learner.apply_gradients``.

    Args:
        learner (TrainState): Training state holding the parameters and optimizer state.
        policy (RecurrentPolicy): Policy providing ``pathwise_log_prob``. It is a static
            argument of the jitted function, so it must be hashable.
        actions (Array): Action trajectories to be scored.
        observations (Array): Observation trajectories aligned with ``actions``.

    Returns:
        tuple[TrainState, Array]: The updated training state and the loss value
        evaluated at the parameters before the update.
    """
    def negative_log_likelihood(params):
        trajectory_log_probs = policy.pathwise_log_prob(actions, observations, params)
        return -jnp.mean(trajectory_log_probs)

    objective = jax.value_and_grad(negative_log_likelihood)
    loss, grads = objective(learner.params)
    return learner.apply_gradients(grads=grads), loss
