import optax
from flax.training.train_state import TrainState
from flax import core
from flax import struct
from functools import partial
import jax
from typing import Any, Callable


class RLTrainState(TrainState):
    """Train state that additionally carries a set of target parameters.

    Extends ``TrainState`` with a second parameter tree (``target_params``)
    and the step size used when blending the online parameters into it.
    """

    # Target-network parameters; explicitly declared as a pytree node.
    target_params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    # Step size handed to ``optax.incremental_update`` in apply_target_update.
    soft_update_rate: float = 0.005
    # Counter field; it is not read or written anywhere in this class.
    update_counter: int = 0

    @jax.jit
    def apply_target_update(self):
        """Return a copy of this state with ``target_params`` moved toward ``params``.

        The new target is produced by ``optax.incremental_update`` with
        ``self.params`` as the new tensors, ``self.target_params`` as the old
        tensors and ``self.soft_update_rate`` as the step size. All other
        fields are carried over unchanged by ``replace``.
        """
        blended_target = optax.incremental_update(
            self.params, self.target_params, self.soft_update_rate
        )
        return self.replace(target_params=blended_target)


@partial(jax.jit, static_argnums=(0,))
def update(loss_fn: Callable, state: TrainState, *args, **kwargs):
    """Run one gradient step on ``state`` using ``loss_fn``.

    Args:
        loss_fn: Called as ``loss_fn(state.params, state, *args, **kwargs)``
            and differentiated with respect to its first argument. Because
            ``has_aux=True`` is used it must return ``(loss, aux)``. It is a
            static argument of the jitted function.
        state: Train state providing ``params`` and ``apply_gradients``.
        *args: Extra positional arguments forwarded to ``loss_fn``.
        **kwargs: Extra keyword arguments forwarded to ``loss_fn``.

    Returns:
        A ``(new_state, aux)`` tuple: the state returned by
        ``state.apply_gradients`` and the auxiliary output of ``loss_fn``.
    """
    grad_fn = jax.grad(loss_fn, has_aux=True)
    gradients, extras = grad_fn(state.params, state, *args, **kwargs)
    return state.apply_gradients(grads=gradients), extras


@partial(jax.jit, static_argnums=(0,))
def nan_to_num_update(loss_fn: Callable, state: TrainState, *args, **kwargs):
    """Run one gradient step, sanitising non-finite gradients first.

    Identical to a plain gradient step except that every gradient leaf is
    passed through ``jax.numpy.nan_to_num`` before being applied, so a NaN
    gradient does not propagate into the parameters.

    Args:
        loss_fn: Called as ``loss_fn(state.params, state, *args, **kwargs)``
            and differentiated with respect to its first argument. Because
            ``has_aux=True`` is used it must return ``(loss, aux)``. It is a
            static argument of the jitted function.
        state: Train state providing ``params`` and ``apply_gradients``.
        *args: Extra positional arguments forwarded to ``loss_fn``.
        **kwargs: Extra keyword arguments forwarded to ``loss_fn``.

    Returns:
        A ``(new_state, aux)`` tuple: the state returned by
        ``state.apply_gradients`` and the auxiliary output of ``loss_fn``.
    """
    grads, aux = jax.grad(loss_fn, has_aux=True)(state.params, state, *args, **kwargs)
    # Use jax.tree_util.tree_map: the top-level ``jax.tree_map`` alias was
    # deprecated and later removed, while this spelling works on all versions.
    grads = jax.tree_util.tree_map(jax.numpy.nan_to_num, grads)
    state = state.apply_gradients(grads=grads)
    return state, aux


# NOTE: ``state`` must be a traced argument, not a static one. Static jit
# arguments have to be hashable, and a train state holding arrays is not,
# so ``static_argnums=(0,)`` made this function fail when called.
@jax.jit
def soft_update(state: RLTrainState):
    """Return ``state`` with its target parameters soft-updated.

    Thin jitted wrapper that delegates to ``state.apply_target_update()``.

    Args:
        state: Train state exposing ``apply_target_update``.

    Returns:
        The state object returned by ``state.apply_target_update()``.
    """
    return state.apply_target_update()
