import numpy as np

from risk_morl.mo_sac.policy import MOSACPolicy
from flax import nnx
import jax.numpy as jnp
import jax
from basics.autoregressive_iqn import KRIQN, KRIQNAblation, KRIQNAblationPositionalEncoding, MarginalIQN
from basics.layers import Lagrangian, SmoothLagrangian
from risk_morl.architecture.actor import Actor
from risk_morl.utils.network_utils import copy_param, polyak_update
from functools import partial
import optax
from risk_morl.buffer import ReplayBufferSamplesWeight
from basics.weight_sampler import PreferenceSampler
from typing import Callable, Literal, Optional
from risk_morl.utils.adamblief_w import adabeliefw


class KRCriticPolicy(MOSACPolicy):
    """Preference-conditioned multi-objective SAC policy with a ``KRIQN`` critic.

    The critic is evaluated at per-objective quantile levels ``taus`` for a
    preference weight vector ``w``; the actor maximises the weight-scalarised
    critic output minus an entropy term. Modules are kept as ``nnx``
    ``(GraphDef, State)`` pairs so the update functions can be ``jax.jit``-ed.
    """

    critic: KRIQN
    preference_sampler: PreferenceSampler
    s_sampler: nnx.State
    g_sampler: nnx.GraphDef

    def __init__(self,
                 env,
                 reward_dim: int,
                 n_env: int = 1,
                 gamma: float = 0.99,
                 soft_update_ratio: float = 5e-3,
                 critic_lr: float = 3e-4,
                 actor_lr: float = 3e-4,
                 risk_measure: Callable = lambda x: x * 0.1,  # TV@R 10%
                 ent_coef: float | Literal['auto'] = 'auto',
                 target_entropy: float | Literal['auto'] = 'auto',
                 truncation_upper: int = 1,
                 truncation_lower: int = 0,
                 proj_tqc: int = 1,
                 *,
                 optimizer_kwargs: Optional[dict] = None,
                 opt_class: Literal['adam', 'sgd', 'adabelief'] = 'adam',
                 actor_marginal_risk: bool = False,
                 comonotone: bool = False,
                 discrete_weights: bool = False,
                 num_grids: int = 20,
                 seed: int = 42,
                 ):
        """Store the KR-specific options, then delegate to ``MOSACPolicy``.

        Args:
            env, reward_dim, n_env, gamma, soft_update_ratio, critic_lr,
            actor_lr, ent_coef, target_entropy, discrete_weights, num_grids,
            seed: forwarded unchanged to ``MOSACPolicy.__init__``.
            risk_measure: also forwarded; maps uniform samples in ``[0, 1)``
                to the quantile levels used by the actor loss. The default
                ``x * 0.1`` keeps them in ``[0, 0.1)``.
            truncation_upper: for each objective in turn, number of largest
                target samples dropped from the critic TD target.
            truncation_lower: for each objective in turn, number of smallest
                target samples dropped from the critic TD target.
            proj_tqc: ``proj_tqc * 3`` largest weight-scalarised target
                samples are dropped from the TD target (0 disables this).
            optimizer_kwargs: stored on the instance; not read in this class.
            opt_class: optimizer family used for both actor and critic.
            actor_marginal_risk: evaluate the actor loss through
                ``critic.marginals_of`` instead of ``critic(...)``.
            comonotone: together with ``actor_marginal_risk``, share one draw
                of quantile levels across all objectives.
        """
        # Set before super().__init__ -- presumably the base constructor calls
        # the build_* hooks below, which read these attributes (TODO confirm).
        self.truncation_upper = truncation_upper
        self.truncation_lower = truncation_lower
        self.actor_marginal = actor_marginal_risk
        self.proj_tqc = proj_tqc
        self.comonotone = comonotone
        if optimizer_kwargs is None:
            optimizer_kwargs = { }
        self.optimizer_kwargs = optimizer_kwargs
        self.opt_class = opt_class
        super().__init__(
            env=env,
            reward_dim=reward_dim,
            n_env=n_env,
            gamma=gamma,
            soft_update_ratio=soft_update_ratio, critic_lr=critic_lr,
            actor_lr=actor_lr, risk_measure=risk_measure, ent_coef=ent_coef,
            target_entropy=target_entropy,
            discrete_weights=discrete_weights, num_grids=num_grids, seed=seed
        )

    def build_actor(self):
        """Create the actor, its optimizer and metrics, and the jitted actor update."""
        self.actor = Actor(self.obs_dim, self.n_action, self.reward_dim, rngs=self.rng)
        # The actor's adaptive optimizers use b1 = b2 = 0.5 (the critic uses
        # optax defaults); 'sgd' falls back to plain optax.sgd.
        if self.opt_class == 'adam':
            opt_ = partial(optax.adam, b1=0.5, b2=0.5)
        elif self.opt_class == 'adabelief':
            opt_ = partial(optax.adabelief, b1=0.5, b2=0.5)
        else:
            opt_ = optax.sgd
        self.opt_actor = nnx.Optimizer(self.actor, optax.chain(
            opt_(self.actor_lr)))

        # One running-average metric per objective, named qf_1 .. qf_<reward_dim>.
        q_infos = { f"qf_{i}": nnx.metrics.Average(f'qf_{i}') for i in range(1, self.reward_dim + 1) }
        self.metric_actor = nnx.MultiMetric(
            pi_loss=nnx.metrics.Average('pi_loss'),
            corr=nnx.metrics.Average('corr'),
            ent_=nnx.metrics.Average('ent'),
            **q_infos
        )

        self.g_actor, self.s_actor = nnx.split(
            (self.actor, self.opt_actor, self.metric_actor)
        )
        self.actor_update_fn = self.build_actor_update_fn()

    def build_critic(self):
        """Create the critic, optimizer, metrics, target params and preference sampler."""
        self.critic = KRIQN(self.obs_dim, self.n_action, self.reward_dim,
                            n_critics=3,
                            rngs=self.rng)
        if self.opt_class == 'adam':
            opt_ = optax.adam
        elif self.opt_class == 'adabelief':
            opt_ = optax.adabelief
        else:
            opt_ = optax.sgd
        self.opt_critic = nnx.Optimizer(self.critic,
                                        optax.chain(opt_(self.critic_lr, ), )
                                        )

        self.metric_critic = nnx.MultiMetric(
            q_loss=nnx.metrics.Average('q_loss'),
        )
        self.g_critic, self.s_critic = nnx.split(
            (self.critic, self.opt_critic, self.metric_critic)
        )
        self.target_param = copy_param(self.critic)
        self.critic_update_fn = self.build_critic_update_fn()
        self.preference_sampler = PreferenceSampler(self.reward_dim, rngs=self.rng)
        self.g_sampler, self.s_sampler = nnx.split(self.preference_sampler)

    def predict(self, observation, weight, *, langevin: bool = False, deterministic: bool = False):
        """Select an action for ``observation`` under preference ``weight``.

        A 1-D ``observation`` / ``weight`` is promoted to a batch of one.

        Args:
            observation: one observation (1-D) or a batch of observations.
            weight: one preference vector (1-D) or a batch of them.
            langevin: refine the actor's action with ``predict_langevin``
                (noisy gradient ascent on the scalarised critic).
            deterministic: without ``langevin``, use ``actor.deterministic``
                instead of the base-class stochastic ``predict``.

        Returns:
            ``(action, weight)`` with ``weight`` exactly as passed in. On the
            Langevin and deterministic paths the batch axis is removed from
            ``action`` again only when a single 1-D observation was given.
        """
        squeeze = False
        if len(observation.shape) == 1:
            squeeze = True
            observation = observation[None]
        if len(weight.shape) == 1:
            _weight = weight[None]
        else:
            _weight = weight
        if langevin:
            # 32 evenly spaced quantile levels in [0, 1], one row per objective:
            # shape (1, reward_dim, 32).
            taus = np.linspace(0, 1, 32)[None, None]
            taus = np.repeat(taus, axis=1, repeats=self.reward_dim)

            action, self.s_actor = self.predict_langevin(self.g_actor, self.s_actor,
                                                         self.g_critic, self.s_critic,
                                                         self.g_ent, self.s_ent,
                                                         observation, _weight, taus, self.rng(),
                                                         )
            action = np.asarray(action)
            # Only undo the batch axis we added; squeezing a real batch of
            # size > 1 would raise.
            if squeeze:
                action = action.squeeze(axis=0)
            return action.copy(), weight
        if not deterministic:
            return super().predict(observation, _weight)
        action, self.s_actor = self.deterministic_predict(self.g_actor, self.s_actor,
                                                          observation, _weight)
        if squeeze:
            return np.asarray(action.squeeze(axis=0)).copy(), weight
        return np.asarray(action).copy(), weight

    @staticmethod
    @jax.jit
    def deterministic_predict(g_actor, s_actor, obs, weight):
        """Return ``actor.deterministic(obs, weight)`` and the re-split actor state."""
        actor, o, m = nnx.merge(g_actor, s_actor)
        new_action = actor.deterministic(obs, weight)
        _, actor_state = nnx.split((actor, o, m))
        return new_action, actor_state

    @staticmethod
    @jax.jit
    def predict_langevin(g_actor, s_actor,
                         g_critic, s_critic,
                         g_ent_coef, s_ent_coef,
                         obs, weight, taus, key
                         ):
        """Refine the actor's action by noisy ascent on the scalarised critic.

        Starting from ``actor(obs, weight)``, runs ``T = 10`` steps of Adam
        (learning rate ``1 / T``) on ``-0.5 * dQ/da``, adds Gaussian noise
        scaled by ``sqrt(ent_coef * dt)`` after each step, and clips actions
        to ``[-1, 1]``. Q is the weight-scalarised critic output averaged over
        the quantile and last axes and summed over the batch.

        Returns:
            ``(refined_action, actor_state)``.
        """
        actor, *others = nnx.merge(g_actor, s_actor)
        critic, _, _ = nnx.merge(g_critic, s_critic)
        _ent, _, _ = nnx.merge(g_ent_coef, s_ent_coef)
        ent_coef = _ent()

        # NOTE(review): result is unused; kept in case the call advances actor
        # state -- confirm and remove if it does not.
        distribution = actor.distribution(obs, weight)

        action = actor(obs, weight)

        def q(action):
            qf = critic(obs, action, taus, weight)
            qf = qf * weight[..., None, None]
            qf = qf.sum(axis=-3).mean(axis=-2).mean(axis=-1)

            return qf.sum()

        T = 10
        dt = 1 / T
        grad_q = jax.grad(q)
        solver = optax.adam(dt)
        opt_state = solver.init(action)

        def loop(carry, i):
            a = carry['action']
            opt_state = carry['state']
            # Split first so the key driving this step's noise is never reused
            # to derive the next step's key.
            noise_key, next_key = jax.random.split(carry['key'], 2)

            da = grad_q(a)
            dWt = jax.random.normal(noise_key, shape=a.shape) * jnp.sqrt(dt)
            # Negated because optax minimises; this ascends Q.
            grad = -(0.5 * da)
            updates, opt_state = solver.update(grad, opt_state, a)
            a = optax.apply_updates(a, updates)
            a = a + jnp.sqrt(ent_coef) * dWt
            a = a.clip(-1., 1.)
            next_carry = { "action": a, "key": next_key,
                           "state": opt_state
                           }
            return next_carry, a

        last_carry, _ = jax.lax.scan(loop, { "action": action, "key": key, "state": opt_state }, jnp.arange(T))
        _, s_actor = nnx.split((actor, *others))

        return last_carry['action'], s_actor

    def build_critic_update_fn(self,
                               ):
        """Return a jitted function performing one critic gradient step.

        The returned function takes ``(g_critic, s_critic,
        target_critic_params, g_actor, s_actor, g_ent_coef, s_ent_coef,
        g_sampler, s_sampler, batch, key)`` and returns
        ``(critic_state, sampler_state)``.
        """
        upper = int(self.truncation_upper)
        lower = int(self.truncation_lower)
        proj_tqc = int(self.proj_tqc)

        def critic_update_fn(g_critic, s_critic, target_critic_params,
                             g_actor, s_actor,
                             g_ent_coef, s_ent_coef,
                             g_sampler, s_sampler,
                             batch: ReplayBufferSamplesWeight,
                             key):
            N_TAUS = 16
            # keys[0]: next-state taus, keys[1]: critic loss, keys[2]: weights.
            # Each random draw gets its own key; sharing one correlates them.
            keys = jax.random.split(key, 3)
            critic, opt_critic, metric_critic = nnx.merge(g_critic, s_critic)
            # Target network: online graph and non-Param state combined with
            # the separately tracked (Polyak-averaged) target params.
            graph, param, *others = nnx.split(critic, nnx.Param, ...)
            target_q_network = nnx.merge(graph, target_critic_params, *others)
            preference_sampler = nnx.merge(g_sampler, s_sampler)

            actor, _, _ = nnx.merge(g_actor, s_actor)
            __ent_coef, _, _ = nnx.merge(g_ent_coef, s_ent_coef)
            ent_coef = __ent_coef()

            w = self.random_weight(keys[-1], batch.observations)

            B = batch.observations.shape[0]
            next_taus = jax.random.uniform(keys[0], shape=(B, self.reward_dim, N_TAUS))
            next_actions, next_log_prob = actor.sample_and_log_prob(batch.next_observations, w)

            # Flatten every non-(batch, objective) axis into one sample axis:
            # (B, reward_dim, M).
            next_quantile = target_q_network(batch.next_observations, next_actions, next_taus, w)
            next_quantile = next_quantile.reshape(w.shape[0], self.reward_dim, -1)

            # For each objective d in turn, sort by that objective and drop the
            # `lower` smallest / `upper` largest samples. The index is applied to
            # all objectives (broadcast over axis 1), so joint samples stay
            # aligned; reward_dim * (lower + upper) samples are dropped in total.
            for d in range(self.reward_dim):
                if upper > 0:
                    _index = jnp.argsort(next_quantile[:, d, :], axis=-1)[..., lower:-upper]
                else:
                    _index = jnp.argsort(next_quantile[:, d, :], axis=-1)[..., lower:]
                _index = _index[..., None, :]
                next_quantile = jnp.take_along_axis(next_quantile, _index, axis=-1)

            # Weight-scalarised samples: (B, M').
            scalarized = ((w[..., None] * next_quantile).sum(axis=1))
            # Drop the proj_tqc * 3 largest scalarised samples (the factor 3
            # appears to match n_critics=3 in build_critic).
            if proj_tqc > 0:
                scalarized = scalarized.argsort(axis=-1)[..., :-proj_tqc * 3]
            else:
                scalarized = scalarized.argsort(axis=-1)
            scalarized = scalarized[:, None, :]
            next_quantile = jnp.take_along_axis(next_quantile, indices=scalarized, axis=-1)
            # Entropy-regularised soft target, broadcast over objectives and samples.
            next_quantile = next_quantile - ent_coef * next_log_prob.reshape(B, 1, 1)
            reward = batch.rewards.reshape(w.shape[0], self.reward_dim, 1)
            non_terminal = 1 - batch.dones.reshape(B, 1, 1)
            # (B, reward_dim, n_target_samples)
            td_target = reward + self.gamma * non_terminal * next_quantile

            # Feed per-objective mean targets to the preference sampler; only
            # its updated state is kept.
            preference_sampler(td_target.mean(axis=-1).reshape((batch.observations.shape[0], self.reward_dim)))

            def loss_fn(model):
                loss = model.loss_fn(batch.observations, batch.actions, td_target, w, keys[1])
                loss = loss.sum(axis=(-1,)).mean()
                return loss

            loss, grads = nnx.value_and_grad(loss_fn)(critic)
            metric_critic.update(q_loss=loss)
            opt_critic.update(grads)
            _, critic_state = nnx.split((critic, opt_critic, metric_critic))
            _, s_sampler = nnx.split(preference_sampler)
            return critic_state, s_sampler

        return jax.jit(critic_update_fn)

    def build_actor_update_fn(self):
        """Return a jitted function performing one actor gradient step.

        The returned function takes ``(g_critic, s_critic, g_actor, s_actor,
        g_ent_coef, s_ent_coef, g_sampler, s_sampler, batch, key)`` and
        returns ``(actor_state, mean_log_prob, sampler_state)``.

        Quantile levels are 8 uniform draws per sample and objective mapped
        through ``self.risk_measure``. With ``actor_marginal_risk`` the critic
        is queried via ``critic.marginals_of``; if ``comonotone`` is also set,
        the same 8 draws are shared by all objectives.
        """
        use_marginals = self.actor_marginal

        if self.actor_marginal and self.comonotone:
            def tau_mapping(key, shape_ph):
                # One draw per sample, repeated across the objective axis.
                x = jax.random.uniform(key, shape=(shape_ph.shape[0], 8))
                return self.risk_measure(jnp.repeat(x[:, None], axis=1, repeats=self.reward_dim))
        else:
            def tau_mapping(key, shape_ph):
                x = jax.random.uniform(key, shape=(shape_ph.shape[0], self.reward_dim, 8))
                return self.risk_measure(x)

        def actor_update_fn(g_critic, s_critic,
                            g_actor, s_actor,
                            g_ent_coef, s_ent_coef,
                            g_sampler, s_sampler,
                            batch: ReplayBufferSamplesWeight,
                            key: jax.Array):
            keys = jax.random.split(key, 3)

            critic, _, _ = nnx.merge(g_critic, s_critic)
            critic: KRIQN

            actor, opt_actor, metric_actor = nnx.merge(g_actor, s_actor)
            __ent_coef, _, _ = nnx.merge(g_ent_coef, s_ent_coef)
            ent_coef = __ent_coef()
            preference_sampler = nnx.merge(g_sampler, s_sampler)

            w = self.random_weight(keys[0], batch.observations)
            # Separate key so the quantile levels are independent of w.
            taus = tau_mapping(keys[1], batch.observations)

            def loss_fn(model: Actor):
                action, log_prob = model.sample_and_log_prob(batch.observations, w)
                if use_marginals:
                    raw_q = critic.marginals_of(batch.observations, action, taus, w)
                else:
                    raw_q = critic(batch.observations, action, taus, w)
                # Per-objective diagnostic: mean over quantiles, min over the
                # last axis, mean over the batch -> one value per objective.
                q_info = raw_q.mean(axis=-2).min(axis=-1).mean(axis=0)

                # Per-objective correlation between the (stopped) value and
                # the preference weight.
                corr = jax.vmap(lambda x, y: jnp.corrcoef(x, y)[0, 1], in_axes=(1, 1), out_axes=0)(
                    jax.lax.stop_gradient(raw_q.mean(axis=-2).min(axis=-1)), w)
                qf = raw_q * w[..., None, None]
                qf = qf.sum(axis=1).mean(axis=-2).min(axis=-1, keepdims=True)

                kl_loss = (ent_coef * log_prob - qf).mean(axis=-1).mean()
                loss = kl_loss

                return loss, (kl_loss, log_prob, corr.mean(), q_info)

            grads, (kl_loss, log_prob, corr, q_info) = nnx.grad(loss_fn, has_aux=True)(actor)

            opt_actor.update(grads)
            log_prob = log_prob.squeeze(axis=-1).mean()
            q_info_kwargs = { f"qf_{i}": q_info[i - 1] for i in range(1, self.reward_dim + 1) }
            metric_actor.update(pi_loss=kl_loss, ent=-log_prob, corr=corr, **q_info_kwargs)
            _, s_actor = nnx.split((actor, opt_actor, metric_actor))
            _, s_sampler = nnx.split(preference_sampler)
            return s_actor, log_prob, s_sampler

        return jax.jit(actor_update_fn)

    def train_step(self, batch: ReplayBufferSamplesWeight):
        """Run one critic -> target -> actor -> entropy-coefficient update.

        Replaces ``s_critic``, ``target_param``, ``s_actor``, ``s_ent`` and
        ``s_sampler`` on ``self`` with the outputs of ``self.train_step_fn``
        (presumably the function returned by ``build_train_step`` -- set up
        outside this class).
        """
        self.s_critic, self.target_param, self.s_actor, self.s_ent, self.s_sampler = self.train_step_fn(
            batch, self.g_critic, self.s_critic, self.target_param,
            self.g_actor, self.s_actor,
            self.g_ent, self.s_ent,
            self.g_sampler, self.s_sampler,
            self.rng()
        )

    @staticmethod
    @jax.jit
    def _sample_preference(g_preference, s_preference, place_holder):
        """Sample weights from a preference sampler and return its updated state."""
        sampler = nnx.merge(g_preference, s_preference)
        w = sampler.sample_weight(place_holder)
        # Re-split so the sampler's advanced RNG state is returned to the caller.
        _, s_preference = nnx.split(sampler)
        return w, s_preference

    def sample_preference(self, nums: int = 1):
        """Return ``nums`` preference weights drawn with ``self.random_weight``."""
        w = self.random_weight(self.rng(), np.ones(nums, ))

        return w

    def build_train_step(self, ):
        """Compose the critic, Polyak, actor and entropy updates into one jitted step.

        The returned function takes ``(batch, g_critic, s_critic, target_param,
        g_actor, s_actor, g_ent, s_ent, g_sampler, s_sampler, key)`` and returns
        ``(s_critic, target_param, s_actor, s_ent, s_sampler)``.
        """
        critic_update_fn = self.critic_update_fn
        actor_update_fn = self.actor_update_fn
        ent_coef_update_fn = self.ent_coef_update_fn
        polyak_update_fn = jax.jit(partial(polyak_update, soft_update_ratio=self.soft_update_ratio))

        def update_fn(batch: ReplayBufferSamplesWeight,
                      g_critic: nnx.GraphDef, s_critic: nnx.State, target_param: nnx.Param,
                      g_actor: nnx.GraphDef, s_actor: nnx.State,
                      g_ent: nnx.GraphDef, s_ent: nnx.State,
                      g_sampler: nnx.GraphDef, s_sampler: nnx.State,
                      key: jax.Array
                      ):
            keys = jax.random.split(key, 2)
            s_critic, s_sampler, = critic_update_fn(g_critic, s_critic, target_param,
                                                    g_actor, s_actor,
                                                    g_ent, s_ent,
                                                    g_sampler, s_sampler,
                                                    batch, keys[0]
                                                    )

            # Target params track the freshly updated critic.
            target_param = polyak_update_fn(
                g_critic, s_critic, target_param
            )

            # The actor is trained against the already-updated critic.
            s_actor, log_prob, s_sampler = actor_update_fn(
                g_critic, s_critic,
                g_actor, s_actor,
                g_ent, s_ent,
                g_sampler, s_sampler,
                batch,
                keys[1]

            )
            s_ent = ent_coef_update_fn(
                g_ent, s_ent,
                log_prob
            )
            return s_critic, target_param, s_actor, s_ent, s_sampler

        return jax.jit(update_fn)


class MarginalIQNPolicy(KRCriticPolicy):
    """``KRCriticPolicy`` variant whose critic is a ``MarginalIQN``.

    The critic TD target is built exactly as in the parent. The difference is
    that the preference sampler is not fed the TD targets; its state is passed
    through unchanged.
    """

    def build_critic(self):
        """Create the critic, optimizer, metrics, target params and preference sampler."""
        self.critic = MarginalIQN(self.obs_dim, self.n_action, self.reward_dim,
                                  n_critics=3,
                                  rngs=self.rng)
        if self.opt_class == 'adam':
            opt_ = optax.adam
        elif self.opt_class == 'adabelief':
            opt_ = optax.adabelief
        else:
            opt_ = optax.sgd
        self.opt_critic = nnx.Optimizer(self.critic,
                                        optax.chain(opt_(self.critic_lr, ), )
                                        )

        self.metric_critic = nnx.MultiMetric(
            q_loss=nnx.metrics.Average('q_loss'),
        )
        self.g_critic, self.s_critic = nnx.split(
            (self.critic, self.opt_critic, self.metric_critic)
        )
        self.target_param = copy_param(self.critic)
        self.critic_update_fn = self.build_critic_update_fn()
        self.preference_sampler = PreferenceSampler(self.reward_dim, rngs=self.rng)
        self.g_sampler, self.s_sampler = nnx.split(self.preference_sampler)

    def build_critic_update_fn(self,
                               ):
        """Return a jitted function performing one critic gradient step.

        The returned function takes ``(g_critic, s_critic,
        target_critic_params, g_actor, s_actor, g_ent_coef, s_ent_coef,
        g_sampler, s_sampler, batch, key)`` and returns
        ``(critic_state, sampler_state)``.
        """
        upper = int(self.truncation_upper)
        lower = int(self.truncation_lower)
        proj_tqc = int(self.proj_tqc)

        def critic_update_fn(g_critic, s_critic, target_critic_params,
                             g_actor, s_actor,
                             g_ent_coef, s_ent_coef,
                             g_sampler, s_sampler,
                             batch: ReplayBufferSamplesWeight,
                             key):
            N_TAUS = 16
            # keys[0]: next-state taus, keys[1]: critic loss, keys[2]: weights.
            # Each random draw gets its own key; sharing one correlates them.
            keys = jax.random.split(key, 3)
            critic, opt_critic, metric_critic = nnx.merge(g_critic, s_critic)
            # Target network: online graph and non-Param state combined with
            # the separately tracked target params.
            graph, param, *others = nnx.split(critic, nnx.Param, ...)
            target_q_network = nnx.merge(graph, target_critic_params, *others)
            preference_sampler = nnx.merge(g_sampler, s_sampler)

            actor, _, _ = nnx.merge(g_actor, s_actor)
            __ent_coef, _, _ = nnx.merge(g_ent_coef, s_ent_coef)
            ent_coef = __ent_coef()

            w = self.random_weight(keys[-1], batch.observations)

            B = batch.observations.shape[0]
            next_taus = jax.random.uniform(keys[0], shape=(B, self.reward_dim, N_TAUS))
            next_actions, next_log_prob = actor.sample_and_log_prob(batch.next_observations, w)

            # Flatten every non-(batch, objective) axis: (B, reward_dim, M).
            next_quantile = target_q_network(batch.next_observations, next_actions, next_taus, w)
            next_quantile = next_quantile.reshape(w.shape[0], self.reward_dim, -1)

            # For each objective d in turn, sort by that objective and drop the
            # `lower` smallest / `upper` largest samples; the index is applied
            # to all objectives so joint samples stay aligned.
            for d in range(self.reward_dim):
                if upper > 0:
                    _index = jnp.argsort(next_quantile[:, d, :], axis=-1)[..., lower:-upper]
                else:
                    _index = jnp.argsort(next_quantile[:, d, :], axis=-1)[..., lower:]
                _index = _index[..., None, :]
                next_quantile = jnp.take_along_axis(next_quantile, _index, axis=-1)

            # Weight-scalarised samples: (B, M').
            scalarized = ((w[..., None] * next_quantile).sum(axis=1))
            # Drop the proj_tqc * 3 largest scalarised samples (the factor 3
            # appears to match n_critics=3 in build_critic).
            if proj_tqc > 0:
                scalarized = scalarized.argsort(axis=-1)[..., :-proj_tqc * 3]
            else:
                scalarized = scalarized.argsort(axis=-1)
            scalarized = scalarized[:, None, :]
            next_quantile = jnp.take_along_axis(next_quantile, indices=scalarized, axis=-1)
            # Entropy-regularised soft target.
            next_quantile = next_quantile - ent_coef * next_log_prob.reshape(B, 1, 1)
            reward = batch.rewards.reshape(w.shape[0], self.reward_dim, 1)
            non_terminal = 1 - batch.dones.reshape(B, 1, 1)
            # (B, reward_dim, n_target_samples)
            td_target = reward + self.gamma * non_terminal * next_quantile

            def loss_fn(model):
                loss = model.loss_fn(batch.observations, batch.actions, td_target, w, keys[1])
                loss = loss.sum(axis=(-1,)).mean()
                return loss

            loss, grads = nnx.value_and_grad(loss_fn)(critic)
            metric_critic.update(q_loss=loss)
            opt_critic.update(grads)
            _, critic_state = nnx.split((critic, opt_critic, metric_critic))
            _, s_sampler = nnx.split(preference_sampler)
            return critic_state, s_sampler

        return jax.jit(critic_update_fn)


class AblationKRCriticPolicy(KRCriticPolicy):
    """``KRCriticPolicy`` variant whose critic is ``KRIQNAblation``.

    Everything except the critic architecture is inherited.
    """

    critic: KRIQNAblation

    def build_critic(self):
        """Create the ablation critic, optimizer, metrics, target params and sampler.

        The critic optimizer follows ``opt_class`` (``'adam'``, ``'adabelief'``,
        otherwise SGD), matching ``KRCriticPolicy.build_critic``. This way an
        ablation run uses the same optimizer as the model it is compared with.
        """
        self.critic = KRIQNAblation(self.obs_dim, self.n_action, self.reward_dim,
                                    n_critics=3,
                                    rngs=self.rng)
        if self.opt_class == 'adam':
            opt_ = optax.adam
        elif self.opt_class == 'adabelief':
            opt_ = optax.adabelief
        else:
            opt_ = optax.sgd
        self.opt_critic = nnx.Optimizer(self.critic,
                                        optax.chain(opt_(self.critic_lr, ), )
                                        )

        self.metric_critic = nnx.MultiMetric(
            q_loss=nnx.metrics.Average('q_loss'),
        )
        self.g_critic, self.s_critic = nnx.split(
            (self.critic, self.opt_critic, self.metric_critic)
        )
        self.target_param = copy_param(self.critic)
        self.critic_update_fn = self.build_critic_update_fn()
        self.preference_sampler = PreferenceSampler(self.reward_dim, rngs=self.rng)
        self.g_sampler, self.s_sampler = nnx.split(self.preference_sampler)


class AblationKRCriticPositionalEncodingPolicy(KRCriticPolicy):
    """``KRCriticPolicy`` variant whose critic is ``KRIQNAblationPositionalEncoding``.

    Everything except the critic architecture is inherited.
    """

    critic: KRIQNAblationPositionalEncoding

    def build_critic(self):
        """Create the ablation critic, optimizer, metrics, target params and sampler.

        The critic optimizer follows ``opt_class`` (``'adam'``, ``'adabelief'``,
        otherwise SGD), matching ``KRCriticPolicy.build_critic``. This way an
        ablation run uses the same optimizer as the model it is compared with.
        """
        self.critic = KRIQNAblationPositionalEncoding(self.obs_dim, self.n_action, self.reward_dim,
                                                      n_critics=3,
                                                      rngs=self.rng)
        if self.opt_class == 'adam':
            opt_ = optax.adam
        elif self.opt_class == 'adabelief':
            opt_ = optax.adabelief
        else:
            opt_ = optax.sgd
        self.opt_critic = nnx.Optimizer(self.critic,
                                        optax.chain(opt_(self.critic_lr, ), )
                                        )

        self.metric_critic = nnx.MultiMetric(
            q_loss=nnx.metrics.Average('q_loss'),
        )
        self.g_critic, self.s_critic = nnx.split(
            (self.critic, self.opt_critic, self.metric_critic)
        )
        self.target_param = copy_param(self.critic)
        self.critic_update_fn = self.build_critic_update_fn()
        self.preference_sampler = PreferenceSampler(self.reward_dim, rngs=self.rng)
        self.g_sampler, self.s_sampler = nnx.split(self.preference_sampler)
