"""
Generates mean field sequences (rollouts) from a mean field policy.
"""
import jax
import jax.numpy as jnp

from mfax.envs.sample.base import SampleMFSequence

def mf_sequence(
    rng,
    env,
    agent,
    num_envs,
    max_steps_in_episode,
) -> SampleMFSequence:
    """
    Generates a mean field sequence from a (recurrent) mean field policy.

    rng: JAX PRNG key used for the environment resets and every env step.
    env: Mean-Field environment - i.e. steps entire Mean Field forward.
        Must provide ``mf_reset`` and ``mf_step``; for recurrent agents also
        ``params.n_agents`` (used to size the initial hidden state).
    agent: Mean-Field policy. i.e. policy must be wrapped in MFActorWrapper or
        MFRecurrentActorWrapper. An agent exposing ``init_hidden`` is treated
        as recurrent and is called with its hidden state and a done mask.
    num_envs: number of mean-field environments rolled out in parallel.
    max_steps_in_episode: number of steps collected (length of the scan).

    Returns a ``SampleMFSequence`` whose leaves are stacked along a leading
    time axis of length ``max_steps_in_episode``. At step t, ``global_s``,
    ``global_terminated`` and ``global_truncated`` describe the state *before*
    the step, while ``vec_a`` / ``vec_r`` are the action taken and the reward
    received at that step. Done flags are sticky (once set they stay set);
    rewards are NOT masked here, so consumers must apply the done flags.
    """

    use_recurrent = hasattr(agent, "init_hidden")
    print(f"Using recurrent policy: {use_recurrent}")

    def _env_step(rng, last_vec_local_s, last_global_s, last_global_terminated, last_global_truncated, vec_a):
        """Step every environment once and build the transition for this step."""
        rng, _rng = jax.random.split(rng)
        rng_step = jax.random.split(_rng, num_envs)
        vec_local_obs, _, vec_local_s, _, global_s, _, vec_r, global_terminated, global_truncated, _ = jax.vmap(
            env.mf_step, in_axes=(0, 0, 0, 0)
        )(rng_step, last_vec_local_s, last_global_s, vec_a)

        # --- make done flags sticky: once an env has terminated/truncated it stays so ---
        # (rewards are not masked here; the recorded done flags are for the consumer)
        global_terminated = global_terminated | last_global_terminated
        global_truncated = global_truncated | last_global_truncated

        # --- transition: pre-step state/flags paired with this step's action and reward ---
        transition = SampleMFSequence(
            global_s=last_global_s,
            global_terminated=last_global_terminated,
            global_truncated=last_global_truncated,
            vec_a=vec_a,
            vec_r=vec_r,
        )
        return rng, vec_local_s, vec_local_obs, global_s, global_terminated, global_truncated, transition

    if use_recurrent:

        @jax.jit
        def _select_action(local_s, local_obs, hidden_state, done_mask):
            vec_a, next_hidden = agent(local_s, local_obs, hidden_state, done=done_mask)
            return vec_a, next_hidden

        @jax.jit
        def _policy_and_env_step(runner_state, _):
            last_vec_local_s, last_vec_local_obs, last_global_s, last_global_terminated, last_global_truncated, last_actor_hidden, rng = runner_state

            # --- select action ---
            # the done mask is forwarded so the agent can handle episode boundaries
            # (presumably by resetting its hidden state - see the agent wrapper)
            last_done = jnp.logical_or(last_global_terminated, last_global_truncated)
            vec_a, next_actor_hidden = jax.vmap(_select_action, in_axes=(0, 0, 0, 0))(
                last_vec_local_s.state, last_vec_local_obs, last_actor_hidden, last_done
            )

            # --- step environment ---
            rng, vec_local_s, vec_local_obs, global_s, global_terminated, global_truncated, transition = _env_step(
                rng, last_vec_local_s, last_global_s, last_global_terminated, last_global_truncated, vec_a
            )
            runner_state = (vec_local_s, vec_local_obs, global_s, global_terminated, global_truncated, next_actor_hidden, rng)
            return runner_state, transition

    else:

        @jax.jit
        def _select_action(local_s, local_obs):
            vec_a = agent(local_s, local_obs)
            return vec_a

        @jax.jit
        def _policy_and_env_step(runner_state, _):
            last_vec_local_s, last_vec_local_obs, last_global_s, last_global_terminated, last_global_truncated, rng = runner_state

            # --- select action ---
            vec_a = jax.vmap(_select_action, in_axes=(0, 0))(last_vec_local_s.state, last_vec_local_obs)

            # --- step environment ---
            rng, vec_local_s, vec_local_obs, global_s, global_terminated, global_truncated, transition = _env_step(
                rng, last_vec_local_s, last_global_s, last_global_terminated, last_global_truncated, vec_a
            )
            runner_state = (vec_local_s, vec_local_obs, global_s, global_terminated, global_truncated, rng)
            return runner_state, transition

    # --- initialise environment ---
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, num_envs)
    init_vec_local_obs, init_vec_local_s, init_global_s = jax.vmap(env.mf_reset, in_axes=(0,))(reset_rng)
    init_global_terminated = jnp.zeros((num_envs,), dtype=int)
    init_global_truncated = jnp.zeros((num_envs,), dtype=int)

    if use_recurrent:
        # one hidden state per (env, agent) pair
        init_actor_hidden = agent.init_hidden(num_envs * env.params.n_agents).reshape((num_envs, env.params.n_agents, -1))
        runner_state = (
            init_vec_local_s,
            init_vec_local_obs,
            init_global_s,
            init_global_terminated,
            init_global_truncated,
            init_actor_hidden,
            rng,
        )
    else:
        runner_state = (init_vec_local_s, init_vec_local_obs, init_global_s, init_global_terminated, init_global_truncated, rng)

    # --- collect trajectories ---
    _, traj_batch = jax.lax.scan(_policy_and_env_step, runner_state, None, int(max_steps_in_episode))

    return traj_batch


def make_mf_sequence(
    env,
    agent,
    num_envs,
    max_steps_in_episode):
    """
    Builds a jitted function ``_mf_sequence(rng, agent_params)`` that generates
    a mean field sequence from a (recurrent) mean field policy.

    env: Mean-Field environment - i.e. steps entire Mean Field forward.
        Must provide ``mf_reset`` and ``mf_step``; for recurrent agents also
        ``params.n_agents`` (used to size the initial hidden state).
    agent: Mean-Field policy. i.e. policy must be wrapped in MFActorWrapper or
        MFRecurrentActorWrapper. It is called with ``mf_params=agent_params``,
        so parameters are an argument of the returned jitted function rather
        than baked into it. An agent exposing ``init_hidden`` is treated as
        recurrent and is called with its hidden state and a done mask.
    num_envs: number of mean-field environments rolled out in parallel.
    max_steps_in_episode: number of steps collected (length of the scan).

    The returned function yields a ``SampleMFSequence`` whose leaves are
    stacked along a leading time axis of length ``max_steps_in_episode``. At
    step t, ``global_s``, ``global_terminated`` and ``global_truncated``
    describe the state *before* the step, while ``vec_a`` / ``vec_r`` are the
    action taken and the reward received at that step. Done flags are sticky;
    rewards are NOT masked here, so consumers must apply the done flags.
    """

    use_recurrent = hasattr(agent, "init_hidden")
    print(f"Using recurrent policy: {use_recurrent}")

    def _env_step(rng, last_vec_local_s, last_global_s, last_global_terminated, last_global_truncated, vec_a):
        """Step every environment once and build the transition for this step."""
        rng, _rng = jax.random.split(rng)
        rng_step = jax.random.split(_rng, num_envs)
        vec_local_obs, _, vec_local_s, _, global_s, _, vec_r, global_terminated, global_truncated, _ = jax.vmap(
            env.mf_step, in_axes=(0, 0, 0, 0)
        )(rng_step, last_vec_local_s, last_global_s, vec_a)

        # --- make done flags sticky: once an env has terminated/truncated it stays so ---
        # (rewards are not masked here; the recorded done flags are for the consumer)
        global_terminated = global_terminated | last_global_terminated
        global_truncated = global_truncated | last_global_truncated

        # --- transition: pre-step state/flags paired with this step's action and reward ---
        transition = SampleMFSequence(
            global_s=last_global_s,
            global_terminated=last_global_terminated,
            global_truncated=last_global_truncated,
            vec_a=vec_a,
            vec_r=vec_r,
        )
        return rng, vec_local_s, vec_local_obs, global_s, global_terminated, global_truncated, transition

    @jax.jit
    def _mf_sequence(
        rng,
        agent_params,
    ) -> SampleMFSequence:
        """
        Roll out ``num_envs`` mean-field environments for ``max_steps_in_episode``
        steps using the captured ``env``/``agent`` with parameters ``agent_params``.
        """

        if use_recurrent:
            def _select_action(local_s, local_obs, hidden_state, done_mask):
                vec_a, next_hidden = agent(local_s, local_obs, hidden_state, done=done_mask, mf_params=agent_params)
                return vec_a, next_hidden

            def _policy_and_env_step(runner_state, _):
                last_vec_local_s, last_vec_local_obs, last_global_s, last_global_terminated, last_global_truncated, last_actor_hidden, rng = runner_state

                # --- select action ---
                # the done mask is forwarded so the agent can handle episode boundaries
                # (presumably by resetting its hidden state - see the agent wrapper)
                last_done = jnp.logical_or(last_global_terminated, last_global_truncated)
                vec_a, next_actor_hidden = jax.vmap(_select_action, in_axes=(0, 0, 0, 0))(
                    last_vec_local_s.state, last_vec_local_obs, last_actor_hidden, last_done
                )

                # --- step environment ---
                rng, vec_local_s, vec_local_obs, global_s, global_terminated, global_truncated, transition = _env_step(
                    rng, last_vec_local_s, last_global_s, last_global_terminated, last_global_truncated, vec_a
                )
                runner_state = (vec_local_s, vec_local_obs, global_s, global_terminated, global_truncated, next_actor_hidden, rng)
                return runner_state, transition

        else:
            def _select_action(local_s, local_obs):
                vec_a = agent(local_s, local_obs, mf_params=agent_params)
                return vec_a

            def _policy_and_env_step(runner_state, _):
                last_vec_local_s, last_vec_local_obs, last_global_s, last_global_terminated, last_global_truncated, rng = runner_state

                # --- select action ---
                vec_a = jax.vmap(_select_action, in_axes=(0, 0))(last_vec_local_s.state, last_vec_local_obs)

                # --- step environment ---
                rng, vec_local_s, vec_local_obs, global_s, global_terminated, global_truncated, transition = _env_step(
                    rng, last_vec_local_s, last_global_s, last_global_terminated, last_global_truncated, vec_a
                )
                runner_state = (vec_local_s, vec_local_obs, global_s, global_terminated, global_truncated, rng)
                return runner_state, transition

        # --- initialise environment ---
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, num_envs)
        init_vec_local_obs, init_vec_local_s, init_global_s = jax.vmap(env.mf_reset, in_axes=(0,))(reset_rng)
        init_global_terminated = jnp.zeros((num_envs,), dtype=int)
        init_global_truncated = jnp.zeros((num_envs,), dtype=int)

        if use_recurrent:
            # one hidden state per (env, agent) pair
            init_actor_hidden = agent.init_hidden(num_envs * env.params.n_agents).reshape((num_envs, env.params.n_agents, -1))
            runner_state = (
                init_vec_local_s,
                init_vec_local_obs,
                init_global_s,
                init_global_terminated,
                init_global_truncated,
                init_actor_hidden,
                rng,
            )
        else:
            runner_state = (init_vec_local_s, init_vec_local_obs, init_global_s, init_global_terminated, init_global_truncated, rng)

        # --- collect trajectories ---
        _, traj_batch = jax.lax.scan(_policy_and_env_step, runner_state, None, int(max_steps_in_episode))

        return traj_batch

    return _mf_sequence