import optax
from flax.training.train_state import TrainState
from flax import core
from flax import struct
from functools import partial
import jax
from typing import Any, Callable


class RLTrainState(TrainState):
    """Train state that carries a second parameter pytree, ``target_params``.

    Field declaration order is part of the generated constructor signature,
    so it must not be rearranged.
    """

    # Parameters that ``apply_target_update`` blends towards ``params``.
    target_params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    # Step size handed to ``optax.incremental_update``.
    soft_update_rate: float = 0.005
    # Integer defaulting to 0; nothing in this class modifies it.
    update_counter: int = 0

    @jax.jit
    def apply_target_update(self):
        """Return a copy of this state with ``target_params`` moved towards ``params``.

        The blend is delegated to ``optax.incremental_update`` with
        ``soft_update_rate`` as the step size; every other field is kept.
        """
        blended = optax.incremental_update(
            new_tensors=self.params,
            old_tensors=self.target_params,
            step_size=self.soft_update_rate,
        )
        return self.replace(target_params=blended)


class RLTrainStateWithSpuriousHandling(TrainState):
    """Train state holding one static callable plus a second parameter pytree.

    NOTE(review): unlike ``RLTrainState``, no ``apply_target_update`` method is
    defined here, so the module-level ``soft_update`` helper cannot be applied
    to this state -- confirm that this is intentional.
    """
    # Callable stored with pytree_node=False, i.e. kept out of the pytree
    # leaves. Its role is not visible in this file.
    quantile_and_probing: Callable = struct.field(pytree_node=False)
    # Second set of parameters, registered as a pytree node.
    target_params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    # Presumably the target blend rate, as in RLTrainState -- TODO confirm;
    # this class itself never reads it.
    soft_update_rate: float = 0.005
    # Integer defaulting to 0; nothing in this class modifies it.
    update_counter: int = 0


class DDPMTrainState(TrainState):
    """Train state with a ``target_params`` pytree and a static ``bc_loss_fn``.

    Field declaration order is part of the generated constructor signature,
    so it must not be rearranged.
    """

    # Parameters that ``apply_target_update`` blends towards ``params``.
    target_params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    # Callable kept out of the pytree leaves (pytree_node=False).
    bc_loss_fn: Callable = struct.field(pytree_node=False)
    # Step size handed to ``optax.incremental_update``.
    soft_update_rate: float = 0.005
    # Integer defaulting to 0; nothing in this class modifies it.
    update_counter: int = 0

    @jax.jit
    def apply_target_update(self):
        """Return a copy of this state with ``target_params`` moved towards ``params``.

        The blend is delegated to ``optax.incremental_update`` with
        ``soft_update_rate`` as the step size; every other field is kept.
        """
        blended = optax.incremental_update(
            new_tensors=self.params,
            old_tensors=self.target_params,
            step_size=self.soft_update_rate,
        )
        return self.replace(target_params=blended)


class SeperatedCriticTrainState(TrainState):
    """Train state with a ``target_params`` pytree and three static callables.

    Field declaration order is part of the generated constructor signature,
    so it must not be rearranged.
    """

    # Parameters that ``apply_target_update`` blends towards ``params``.
    target_params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    # Three callables kept out of the pytree leaves (pytree_node=False);
    # what each one computes is not visible in this file.
    mean: Callable = struct.field(pytree_node=False)
    quantiles: Callable = struct.field(pytree_node=False)
    both: Callable = struct.field(pytree_node=False)
    # Step size handed to ``optax.incremental_update``.
    soft_update_rate: float = 0.005
    # Integer defaulting to 0; nothing in this class modifies it.
    update_counter: int = 0

    @jax.jit
    def apply_target_update(self):
        """Return a copy of this state with ``target_params`` moved towards ``params``.

        The blend is delegated to ``optax.incremental_update`` with
        ``soft_update_rate`` as the step size; every other field is kept.
        """
        blended = optax.incremental_update(
            new_tensors=self.params,
            old_tensors=self.target_params,
            step_size=self.soft_update_rate,
        )
        return self.replace(target_params=blended)



class DiffusionTrainState(TrainState):
    """Train state extended with a single static callable, ``bc_update_fn``."""
    # Stored with pytree_node=False, i.e. kept out of the pytree leaves.
    # NOTE(review): presumably a behaviour-cloning update step -- confirm
    # against the code that constructs this state.
    bc_update_fn: Callable = struct.field(pytree_node=False)


class CEPTrainState(TrainState):
    """Train state with three static callables and a second parameter pytree.

    NOTE(review): unlike ``RLTrainState``, no ``apply_target_update`` method is
    defined here, so the module-level ``soft_update`` helper cannot be applied
    to this state -- confirm that this is intentional.
    """
    # Callables stored with pytree_node=False, i.e. kept out of the pytree
    # leaves. What each one computes is not visible in this file.
    bc_loss_fn: Callable = struct.field(pytree_node=False)
    cep_loss_fn: Callable = struct.field(pytree_node=False)
    behavior_clone: Callable = struct.field(pytree_node=False)
    # Second set of parameters, registered as a pytree node.
    target_params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    # Presumably the target blend rate, as in RLTrainState -- TODO confirm;
    # this class itself never reads it.
    soft_update_rate: float = 0.005
    # Integer defaulting to 0; nothing in this class modifies it.
    update_counter: int = 0


class CARDTrainState(TrainState):
    """Train state with a static ``loss_fn`` and a second parameter pytree."""
    # Callable stored with pytree_node=False, i.e. kept out of the pytree leaves.
    loss_fn: Callable = struct.field(pytree_node=False)
    # Second set of parameters, registered as a pytree node. No update rule
    # for it is defined in this class.
    target_params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)


class VAETrainState(TrainState):
    """Train state extended with two static callables."""
    # Both callables are stored with pytree_node=False, i.e. kept out of the
    # pytree leaves. Their signatures are not visible in this file.
    loss_fn: Callable = struct.field(pytree_node=False)
    bulk_sample: Callable = struct.field(pytree_node=False)


class VAERLTrainState(TrainState):
    """Train state with a second parameter pytree and two static callables."""
    # Second set of parameters, registered as a pytree node. No update rule
    # for it is defined in this class.
    target_params: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    # Both callables are stored with pytree_node=False, i.e. kept out of the
    # pytree leaves. Their signatures are not visible in this file.
    loss_fn: Callable = struct.field(pytree_node=False)
    bulk_sample: Callable = struct.field(pytree_node=False)


@partial(jax.jit, static_argnums=(0,))
def update(loss_fn: Callable, state: TrainState, *args, **kwargs):
    """Take one gradient step on ``state`` using ``loss_fn``.

    Args:
        loss_fn: Static jit argument (so it must be hashable). It is invoked
            as ``loss_fn(state.params, state, *args, **kwargs)`` and must
            return ``(loss, aux)``.
        state: Train state; gradients are taken w.r.t. ``state.params`` only.
        *args: Extra positional arguments forwarded to ``loss_fn``.
        **kwargs: Extra keyword arguments forwarded to ``loss_fn``.

    Returns:
        ``(new_state, aux)`` where ``new_state`` is the result of
        ``state.apply_gradients`` and ``aux`` is the auxiliary output of
        ``loss_fn``.
    """
    grad_fn = jax.grad(loss_fn, has_aux=True)
    param_grads, extras = grad_fn(state.params, state, *args, **kwargs)
    return state.apply_gradients(grads=param_grads), extras


@partial(jax.jit, static_argnums=(0,))
def nan_to_num_update(loss_fn: Callable, state: TrainState, *args, **kwargs):
    """Take one gradient step, sanitising non-finite gradients first.

    Identical to ``update`` except that every gradient leaf is passed through
    ``jax.numpy.nan_to_num`` before being applied, so NaN entries become 0 and
    infinities become the largest finite values of the dtype.

    Args:
        loss_fn: Static jit argument (so it must be hashable). It is invoked
            as ``loss_fn(state.params, state, *args, **kwargs)`` and must
            return ``(loss, aux)``.
        state: Train state; gradients are taken w.r.t. ``state.params`` only.
        *args: Extra positional arguments forwarded to ``loss_fn``.
        **kwargs: Extra keyword arguments forwarded to ``loss_fn``.

    Returns:
        ``(new_state, aux)`` where ``new_state`` is the result of
        ``state.apply_gradients`` on the sanitised gradients and ``aux`` is
        the auxiliary output of ``loss_fn``.
    """
    grads, aux = jax.grad(loss_fn, has_aux=True)(state.params, state, *args, **kwargs)
    # ``jax.tree_map`` is a deprecated alias that has been removed from recent
    # JAX releases; ``jax.tree_util.tree_map`` is the stable spelling.
    grads = jax.tree_util.tree_map(jax.numpy.nan_to_num, grads)
    state = state.apply_gradients(grads=grads)
    return state, aux


@jax.jit
def soft_update(state: RLTrainState):
    """Return ``state`` after one call to its ``apply_target_update`` method.

    The state is passed to ``jax.jit`` as an ordinary (traced) pytree argument.
    It must not be listed in ``static_argnums``: static arguments have to be
    hashable, which a train state holding arrays is not, and even a hashable
    state would force a recompilation for every distinct value.

    Args:
        state: Any train state that defines ``apply_target_update``.

    Returns:
        Whatever ``state.apply_target_update()`` returns.
    """
    return state.apply_target_update()
