from risk_morl.mo_sac.policy import MOSACPolicy
import jax
from flax import nnx
import numpy as np


class LangevinSafeMORL(MOSACPolicy):
    """MOSACPolicy subclass intended to add Langevin-based action prediction.

    NOTE(review): as visible in this file the Langevin path is unfinished:
    ``predict_langevin`` delegates to the inherited ``_predict``, and
    ``_lagevin_predict`` does not yet compute or return an action.
    """

    def predict_langevin(self, observation, weight):
        """Predict an action for ``observation`` given ``weight``.

        ``weight`` is presumably the multi-objective preference vector
        (confirm against callers).

        Currently this calls ``self._predict``, which is not defined in this
        class and is presumably inherited from ``MOSACPolicy``. It passes the
        actor's ``g_actor``/``s_actor`` pair. It does NOT invoke
        ``_lagevin_predict``.

        Side effect: replaces ``self.s_actor`` with the state returned by
        ``_predict``.

        Returns:
            A tuple ``(action, weight)``. ``action`` is a fresh NumPy copy of
            the predicted action. ``weight`` is the argument passed in,
            returned unchanged.
        """
        action, self.s_actor = self._predict(
            self.g_actor, self.s_actor,
            observation, weight)
        # Copy so the caller gets an ndarray that owns its data rather than a
        # (possibly read-only) view of the JAX array.
        return np.asarray(action).copy(), weight

    @staticmethod
    @jax.jit
    def _lagevin_predict(
            g_critic, s_critic,
            g_actor, s_actor, obs, weight, key):
        """Jit-compiled, work-in-progress Langevin action sampler.

        NOTE(review): the name is spelled ``_lagevin`` (missing "n"). Callers
        must use this exact spelling.

        Args:
            g_critic, s_critic: passed to ``nnx.merge`` to rebuild the critic.
                Presumably the GraphDef/State pair from ``nnx.split``.
            g_actor, s_actor: passed to ``nnx.merge`` to rebuild the actor.
                Presumably the GraphDef/State pair from ``nnx.split``.
            obs: observation passed to both the actor and the critic.
            weight: passed to both the actor and the critic.
            key: presumably a JAX PRNG key for the Langevin noise. It is not
                used in the visible code.

        Returns:
            As visible here, nothing. The function has no ``return``
            statement, so it returns ``None``.
        """
        # NOTE(review): assumes nnx.merge yields a 3-element sequence here,
        # e.g. the graph was split from a tuple of three modules. Only the
        # first element is kept. TODO confirm.
        critic, _, _ = nnx.merge(g_critic, s_critic)
        actor, _, _ = nnx.merge(g_actor, s_actor)
        # Initial action from the actor, presumably the Langevin start point.
        # Unused in the visible code.
        action = actor(obs, weight)
        # Presumably the Langevin step size. Unused in the visible code.
        dt = 1/100
        def vf(action):
            # NOTE(review): the critic output is discarded, so vf returns
            # None. A `return` is likely missing.
            critic(obs, action, weight)


        def body(carry, _):
            # Signature looks like a jax.lax.scan step function (carry, x).
            # Presumably intended as one Langevin update. As visible here it
            # only reads the carry dict and returns None.
            key = carry['key']
            a_t = carry['a_t']






