from typing import Callable
from functools import partial

import jax

from jax import Array, random, numpy as jnp
from flax.training.train_state import TrainState
from distrax import Chain, MultivariateNormalDiag, Transformed
from distrax import Distribution

from ppomdp.core import (
    PRNGKey,
    Carry,
    Parameters,
    TransitionModel,
    ObservationModel,
)
from ppomdp.utils import custom_split

from baselines.slac.arch import PolicyNetwork


def policy_sample_and_log_prob(
    rng_key: PRNGKey,
    carry: list[Carry],
    observation: Array,
    params: Parameters,
    network: PolicyNetwork,
    bijector: Chain,
) -> tuple[list[Carry], Array, Array, Array]:
    """Draw an action from the recurrent policy and report its log-probability.

    The network is applied to the current carry and observation to obtain the
    updated carry together with the mean and log standard deviation of a
    diagonal Gaussian. That Gaussian is pushed through `bijector` to form the
    action distribution.

    Args:
        rng_key: PRNG key used as the sampling seed.
        carry: Recurrent carry passed to the network.
        observation: Observation passed to the network.
        params: Network parameters, supplied under the ``"params"`` collection.
        network: Policy network whose `apply` returns
            ``(carry, mean, log_std)``.
        bijector: Bijector applied on top of the Gaussian base distribution.

    Returns:
        A tuple ``(carry, action, log_prob, transformed_mean)`` holding the
        updated carry, the sampled action, its log-probability under the
        transformed distribution, and `bijector.forward` applied to the
        Gaussian mean.
    """
    variables = {"params": params}
    next_carry, loc, log_scale = network.apply(variables, carry, observation)

    # Diagonal Gaussian in the pre-bijector space.
    scale_diag = jnp.exp(log_scale)
    gaussian = MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)
    policy_dist = Transformed(distribution=gaussian, bijector=bijector)

    sampled_action, sampled_log_prob = policy_dist.sample_and_log_prob(seed=rng_key)

    # Noise-free action: the base mean mapped through the bijector.
    transformed_mean = bijector.forward(loc)
    return next_carry, sampled_action, sampled_log_prob, transformed_mean


@partial(
    jax.jit,
    static_argnames=(
        "num_time_steps",
        "num_trajectory_samples",
        "num_belief_particles",
        "init_dist",
        "belief_prior",
        "policy_network",
        "trans_model",
        "obs_model",
        "reward_fn",
    )
)
def policy_evaluation(
    rng_key: PRNGKey,
    num_time_steps: int,
    num_trajectory_samples: int,
    num_belief_particles: int,
    init_dist: Distribution,
    belief_prior: Distribution,
    policy_state: TrainState,
    policy_network: PolicyNetwork,
    trans_model: TransitionModel,
    obs_model: ObservationModel,
    reward_fn: Callable,
):
    """Roll out the policy on simulated trajectories while tracking beliefs.

    A batch of `num_trajectory_samples` trajectories is simulated in parallel.
    At every scan iteration the policy produces actions from the current
    observations, rewards are computed, the states are advanced with
    `trans_model`, new observations are drawn from `obs_model`, and each
    trajectory's belief is updated with `update_belief`.

    Args:
        rng_key: PRNG key from which all randomness of the rollout is derived.
        num_time_steps: The scan is run for ``num_time_steps + 1`` iterations.
        num_trajectory_samples: Number of trajectories simulated in parallel.
        num_belief_particles: Passed to `initialize_belief` for every
            trajectory.
        init_dist: Distribution the initial states are sampled from.
        belief_prior: Prior handed to `initialize_belief`.
        policy_state: Train state providing `apply_fn` and `params` of the
            policy.
        policy_network: Only used to create the initial recurrent carry via
            `reset(num_trajectory_samples)`.
        trans_model: Transition model; `sample(key, state, action)` is
            vmapped over trajectories.
        obs_model: Observation model; `sample(key, state)` is vmapped over
            trajectories.
        reward_fn: Called as ``reward_fn(state, action, time_idx)`` and
            vmapped over the trajectory axis of states and actions.

    Returns:
        A tuple ``(rewards, states, actions, beliefs)``. `rewards` and
        `actions` are stacked over the scan iterations. `states` and
        `beliefs` additionally have their initial values prepended along the
        leading axis, so they contain one more entry than `rewards`.
    """

    # Imported locally in the original code; presumably to avoid a circular
    # import between packages — TODO confirm.
    from ppomdp.smc.utils import initialize_belief, update_belief

    def body(val, key):
        states, carry, observations, beliefs, time_idx = val

        # Query the policy. The fourth output of `apply_fn` is used as the
        # action; the sampled action and its log-probability are discarded.
        # NOTE(review): presumably `apply_fn` is `policy_sample_and_log_prob`
        # with network and bijector bound, making this the bijector-transformed
        # mean rather than a random sample — confirm against the caller.
        key, action_key = random.split(key)
        carry, _, _, actions = policy_state.apply_fn(
            rng_key=action_key,
            carry=carry,
            observation=observations,
            params=policy_state.params
        )

        # Compute rewards for the current states and chosen actions.
        rewards = jax.vmap(reward_fn, (0, 0, None))(states, actions, time_idx)

        # Sample next states.
        key, state_keys = custom_split(key, num_trajectory_samples + 1)
        states = jax.vmap(trans_model.sample)(state_keys, states, actions)

        # Sample observations. The key is consumed here so that the belief
        # update below does not reuse the randomness of the observation draw.
        key, obs_keys = custom_split(key, num_trajectory_samples + 1)
        observations = jax.vmap(obs_model.sample)(obs_keys, states)

        # Update beliefs with keys independent of the observation keys.
        belief_keys = random.split(key, num_trajectory_samples)
        beliefs = jax.vmap(update_belief, (0, None, None, 0, 0, 0))(
            belief_keys, trans_model, obs_model, beliefs, observations, actions
        )

        return (states, carry, observations, beliefs, time_idx + 1), (states, actions, beliefs, rewards)

    # Initialize states, observations, recurrent carry and beliefs.
    key, state_key = random.split(rng_key)
    init_states = init_dist.sample(seed=state_key, sample_shape=num_trajectory_samples)

    key, obs_keys = custom_split(key, num_trajectory_samples + 1)
    init_observations = jax.vmap(obs_model.sample)(obs_keys, init_states)
    init_carry = policy_network.reset(num_trajectory_samples)

    key, belief_keys = custom_split(key, num_trajectory_samples + 1)
    init_beliefs = jax.vmap(initialize_belief, in_axes=(0, None, None, 0, None))(
        belief_keys, belief_prior, obs_model, init_observations, num_belief_particles
    )

    # NOTE(review): one key per iteration, so the scan performs
    # `num_time_steps + 1` steps — confirm the extra step is intended.
    _, (states, actions, beliefs, rewards) = jax.lax.scan(
        f=body,
        init=(init_states, init_carry, init_observations, init_beliefs, 0),
        xs=random.split(key, num_time_steps + 1)
    )

    def concat_trees(first, rest):
        # Prepend `first` (given a new leading axis) to every leaf of `rest`.
        return jax.tree.map(
            lambda head, tail: jnp.concatenate([head[None, ...], tail]), first, rest
        )

    states = concat_trees(init_states, states)
    beliefs = concat_trees(init_beliefs, beliefs)
    return rewards, states, actions, beliefs
