from typing import Tuple

import chex
import jax
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import FrozenDict
from omegaconf import DictConfig

from stoix.utils.flatten_util import ravel_pytree


def get_diff_gradient(
    latest_params: chex.ArrayTree, behaviour_params: chex.ArrayTree
) -> chex.ArrayTree:
    """Return the leaf-wise difference ``latest_params - behaviour_params``.

    The result is the displacement that points from the behaviour parameters to
    the latest parameters; callers in this module refer to it as the
    "natural gradient".

    Args:
        latest_params: Pytree of the most recent parameters.
        behaviour_params: Pytree with the same structure as ``latest_params``.

    Returns:
        A pytree with the same structure whose leaves are ``latest - behaviour``.
    """

    def _leaf_delta(latest_leaf, behaviour_leaf):
        # Displacement of a single leaf: where we are now minus where we started.
        return latest_leaf - behaviour_leaf

    return jax.tree_util.tree_map(_leaf_delta, latest_params, behaviour_params)


def old_free_step(
    config: DictConfig,
    current_params: FrozenDict,
    behaviour_params: FrozenDict,
    momentum: chex.ArrayTree,
) -> Tuple[FrozenDict, chex.ArrayTree]:
    """Update the momentum of the parameter displacement and optionally step along it.

    The displacement ``current_params - behaviour_params`` (called the natural
    gradient below) is folded into ``momentum`` as an exponential moving average::

        momentum <- ppo_momentum * momentum + (1 - ppo_momentum) * natural_gradient

    An all-zero ``momentum`` is interpreted as "first iteration"; in that case the
    new momentum is the natural gradient itself rather than the moving average.

    Args:
        config: Configuration; reads ``config.system.ppo_momentum``,
            ``config.system.ppo_step_update`` and ``config.system.ppo_step_update_lr``.
        current_params: Parameters after the inner update loop.
        behaviour_params: Parameters the inner loop started from.
        momentum: Previous momentum pytree, same structure as the parameters.

    Returns:
        A ``(params, momentum)`` tuple. ``params`` is
        ``behaviour_params + ppo_step_update_lr * momentum`` when
        ``config.system.ppo_step_update`` is truthy, otherwise ``current_params``
        unchanged. ``momentum`` is the updated momentum in both cases.
    """
    ### MOMENTUM UPDATE / APPLICATION START ###

    # compute the vector from behaviour params to 'current' params
    natural_gradient = get_diff_gradient(current_params, behaviour_params)

    # Read the coefficient once instead of on every leaf inside the lambda.
    momentum_coef = config.system.ppo_momentum

    def _update_momentum(momentum, natural_gradient):
        # if momentum_coef == 0.0 the result is always the natural gradient
        # if momentum_coef > 0 and this is the first iteration it is also the natural gradient
        # if momentum_coef > 0 in subsequent iterations it is
        # momentum_coef * prev_momentum + (1 - momentum_coef) * natural_gradient

        # candidate momentum, used whenever this is not the first iteration
        blended = jax.tree_util.tree_map(
            lambda m, ng: momentum_coef * m + (1 - momentum_coef) * ng,
            momentum,
            natural_gradient,
        )

        # First iteration is detected by the momentum being zero in every leaf.
        # This could also be true later on, but that is highly improbable.
        zero_flags = [jnp.all(leaf == 0) for leaf in jax.tree_util.tree_leaves(momentum)]
        # An empty pytree counts as all-zero, matching jnp.all on an empty array.
        is_first_step = jnp.all(jnp.stack(zero_flags)) if zero_flags else jnp.asarray(True)

        # The scalar predicate broadcasts over each leaf, so the whole tree is
        # switched at once without flattening it into a single vector.
        return jax.tree_util.tree_map(
            lambda ng, upd: jnp.where(is_first_step, ng, upd),
            natural_gradient,
            blended,
        )

    momentum = _update_momentum(momentum, natural_gradient)

    if config.system.ppo_step_update:
        # update the behaviour params using momentum
        params = jax.tree_util.tree_map(
            lambda x, y: x + config.system.ppo_step_update_lr * y,
            behaviour_params,
            momentum,
        )
    else:
        # just use the 'current' params from the inner loop
        # the momentum is still updated in case the next iteration steps with it
        params = current_params

    ### MOMENTUM UPDATE / APPLICATION END ###
    return params, momentum
