from networks import DeterministicPolicy, SquashedGaussianPolicy
from networks.actor import ORAACActor, VAEActor

from networks import VectorQuantileCritic, VectorCritic, Lagrangian, Constant
from rl.policies.rl_train_state import (RLTrainState, TrainState, VAETrainState, VAERLTrainState)
from misc.rng_modules import PRNGSequence
from functools import partial
from abc import ABCMeta, abstractmethod
import jax
import numpy as np
import flax
from flax import struct
import flax.training.orbax_utils
import optax
import orbax.checkpoint
from typing import Type, Callable
import orbax


class BaseActorCriticPolicy(metaclass=ABCMeta):
    """Shared scaffolding for actor-critic policies.

    The constructor records the optimiser settings and the update callables,
    seeds a ``PRNGSequence`` and then delegates network construction to
    :meth:`build`. The class also offers orbax-based save/restore of the
    actor and critic train states.
    """

    actor: DeterministicPolicy
    actor_train_state: RLTrainState

    critic: VectorCritic
    critic_train_state: RLTrainState

    def __init__(self,
                 observation_ph, action_ph,
                 actor_update_fn, critic_update_fn,
                 opt_class: Type[optax.adam] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 seed: int = 42,
                 ):
        """Store the configuration and build the networks.

        Args:
            observation_ph: Example observation passed on to :meth:`build`.
            action_ph: Example action passed on to :meth:`build`.
            actor_update_fn: Callable stored for the actor update step.
            critic_update_fn: Callable stored for the critic update step.
            opt_class: Optimiser factory; subclasses call it with
                ``learning_rate`` to obtain a gradient transformation.
            learning_rate: Learning rate handed to ``opt_class``.
            n_critics: Number of critics requested from the critic module.
            seed: Seed for the internal ``PRNGSequence``.
        """
        # Update callables.
        self.actor_update_fn = actor_update_fn
        self.critic_update_fn = critic_update_fn
        # Optimiser / architecture settings read by ``build``.
        self.opt_class = opt_class
        self.learning_rate = learning_rate
        self.n_critics = n_critics
        self.rng = PRNGSequence(seed)
        # Everything above must be set before ``build`` runs.
        self.build(observation_ph, action_ph)

    @abstractmethod
    def build(self, observation_ph, action_ph):
        """Create the networks and their train states (subclass hook)."""

    @abstractmethod
    def predict(self, observation):
        """Return an action for ``observation`` (subclass hook)."""

    @property
    def checkpoint(self):
        """Dictionary holding the actor and critic train states."""
        return {"actor": self.actor_train_state, "critic": self.critic_train_state}

    @staticmethod
    def _new_checkpointer():
        """Return a fresh orbax ``Checkpointer`` backed by a PyTree handler."""
        return orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())

    def _write_pytree(self, path, pytree, overwrite):
        """Save ``pytree`` at ``path``; ``overwrite`` is passed to orbax as ``force``."""
        checkpointer = self._new_checkpointer()
        save_args = flax.training.orbax_utils.save_args_from_target(pytree)
        checkpointer.save(path, pytree, save_args=save_args, force=overwrite)

    def save_checkpoint(self, path, overwrite: bool = True):
        """Save the full :attr:`checkpoint` dictionary at ``path``."""
        self._write_pytree(path, self.checkpoint, overwrite)

    def save_actor(self, path, overwrite: bool = True):
        """Save only the actor train state (under the key ``"actor"``) at ``path``."""
        self._write_pytree(path, {"actor": self.actor_train_state}, overwrite)

    def _load_checkpoint(self, path: str):
        """Restore and return a checkpoint, using the current states as the template."""
        template = self.checkpoint
        checkpointer = self._new_checkpointer()
        restore_args = flax.training.orbax_utils.restore_args_from_target(template, mesh=None)
        return checkpointer.restore(path, item=template, restore_args=restore_args)

    def load_checkpoint(self, path):
        """Replace the actor and critic train states with those stored at ``path``."""
        loaded = self._load_checkpoint(path)
        self.actor_train_state = loaded['actor']
        self.critic_train_state = loaded['critic']


class TD3Policy(BaseActorCriticPolicy):
    """Deterministic-actor policy with a vector critic and delayed actor updates.

    Besides the actor and critic, ``build`` creates a ``Lagrangian``
    coefficient (``bc_weight_coef``) together with its train state and a
    jitted update function for it. The call that would run that update
    inside :meth:`update` is currently commented out.
    """

    bc_weight_coef: Lagrangian
    bc_weight_state: TrainState
    bc_weight_update_fn: Callable

    bc_tolerance: float = 0.25

    def build_bc_update_loss_fn(self):
        """Return a jitted step that updates the BC-weight coefficient.

        The returned function takes ``(lag_state, errors)`` and gives back
        the updated state plus an info dict containing ``"bc_coef"``.
        """

        def _step(lag_state: TrainState, errors):
            def _objective(params):
                log_coef = lag_state.apply_fn(params)
                # Gradients flow only through ``log_coef``; the margin is constant.
                margin = jax.lax.stop_gradient(self.bc_tolerance - errors)
                return (log_coef * margin).mean(), {"bc_coef": jax.numpy.exp(log_coef)}

            grads, infos = jax.grad(_objective, has_aux=True)(lag_state.params)
            return lag_state.apply_gradients(grads=grads), infos

        return jax.jit(_step)

    def build(self, observation_ph, action_ph):
        """Create actor, critic and BC-weight modules with their train states.

        Target parameters start out as a copy of the freshly initialised
        online parameters.
        """
        # --- actor ---
        self.actor = DeterministicPolicy(action_ph.shape[-1])
        actor_vars = self.actor.init(next(self.rng), observation_ph)
        self.actor_train_state = RLTrainState.create(
            apply_fn=self.actor.apply,
            params=actor_vars['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=actor_vars.copy()['params'],
        )

        # --- critic ---
        self.critic = VectorCritic(n_critics=self.n_critics)
        critic_vars = self.critic.init(next(self.rng), observation_ph, action_ph)
        self.critic_train_state = RLTrainState.create(
            apply_fn=self.critic.apply,
            params=critic_vars['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=critic_vars.copy()['params'],
        )

        # --- behaviour-cloning weight ---
        self.bc_weight_coef = Lagrangian()
        bc_vars = self.bc_weight_coef.init(next(self.rng))
        self.bc_weight_state = TrainState.create(
            apply_fn=self.bc_weight_coef.apply,
            params=bc_vars,
            tx=self.opt_class(self.learning_rate),
        )
        self.bc_weight_update_fn = self.build_bc_update_loss_fn()

    @partial(jax.jit, static_argnums=(0,))
    def _predict(self, params, observation):
        """Apply the actor with ``params``; the output is wrapped in ``stop_gradient``."""
        return jax.lax.stop_gradient(self.actor.apply({'params': params}, observation))

    def predict(self, observation):
        """Return the actor's action for ``observation`` as a NumPy array."""
        params = self.actor_train_state.params
        return np.asarray(self._predict(params, observation))

    @partial(jax.jit, static_argnums=(0,))
    def _explore(self, params, observation, key):
        """Actor output plus clipped Gaussian noise, clipped to ``[-1, 1]``."""
        greedy = self._predict(params, observation)
        gaussian = jax.random.normal(key, shape=greedy.shape)
        perturbation = (0.3 * gaussian).clip(-0.5, 0.5)
        return (greedy + perturbation).clip(-1., 1.)

    def explore(self, observation, key):
        """Return a noisy action drawn with ``key`` using the current actor parameters."""
        return self._explore(self.actor_train_state.params, observation, key)

    def update(self, batch, key, cnt, delay):
        """Run one critic step and, when ``cnt % delay == 0``, one actor step.

        Returns:
            The critic's info dict, extended with the actor's info dict on
            steps where the actor is updated.
        """
        critic_key, actor_key = jax.random.split(key, 2)
        self.critic_train_state, loss_info = self.critic_update_fn(
            self.critic_train_state, self.actor_train_state, batch, critic_key)

        if cnt % delay != 0:
            return loss_info

        self.actor_train_state, actor_loss_info = self.actor_update_fn(
            self.critic_train_state,
            self.actor_train_state,
            # self.bc_weight_state,
            batch, actor_key,
        )
        # Adaptive BC-weight update, currently disabled:
        # err = actor_loss_info['bc_loss']
        # self.bc_weight_state, coef_loss_info = self.bc_weight_update_fn(self.bc_weight_state, err)
        # loss_info.update(coef_loss_info)
        loss_info.update(actor_loss_info)
        return loss_info


class SACPolicy(TD3Policy):
    """Stochastic (squashed Gaussian) actor with an entropy coefficient.

    The entropy coefficient is a ``Lagrangian`` module when
    ``learn_ent_coef`` is true and a ``Constant(ent_coef)`` otherwise; in
    both cases a ``TrainState`` with an optimiser is created for it.
    """

    entropy_coefficient: Lagrangian | Constant
    ent_coef_train_state: TrainState
    deterministic_pred: Callable

    def __init__(self,
                 observation_ph, action_ph,
                 actor_update_fn, critic_update_fn, ent_coef_update_fn,
                 learn_ent_coef: bool = True,
                 ent_coef: float | None = 0.01,
                 opt_class: Type[optax.adam] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 seed: int = 42,
                 ):
        """Record the entropy settings, then defer to the base constructor.

        Args:
            ent_coef_update_fn: Callable stored for the entropy-coefficient step.
            learn_ent_coef: Selects ``Lagrangian`` (True) or ``Constant`` (False).
            ent_coef: Value given to ``Constant`` when the coefficient is fixed.

        The remaining arguments are forwarded unchanged to the parent.
        """
        # These must exist before the parent constructor runs, because it
        # calls ``build``, which reads them.
        self.ent_coef = ent_coef
        self.learn_ent_coef = learn_ent_coef
        self.ent_coef_update_fn = ent_coef_update_fn

        super().__init__(
            observation_ph=observation_ph,
            action_ph=action_ph,
            actor_update_fn=actor_update_fn,
            critic_update_fn=critic_update_fn,
            opt_class=opt_class,
            learning_rate=learning_rate,
            n_critics=n_critics,
            seed=seed,
        )

    def build(self, observation_ph, action_ph):
        """Create actor, critic and entropy-coefficient modules with their states."""
        # --- actor ---
        self.actor = SquashedGaussianPolicy(action_ph.shape[-1])
        actor_vars = self.actor.init(next(self.rng), observation_ph)
        self.actor_train_state = RLTrainState.create(
            apply_fn=self.actor.apply,
            params=actor_vars['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=actor_vars.copy()['params'],
        )
        self.deterministic_pred = jax.jit(self.actor.deterministic)

        # --- critic ---
        self.critic = VectorCritic(n_critics=self.n_critics)
        critic_vars = self.critic.init(next(self.rng), observation_ph, action_ph)
        self.critic_train_state = RLTrainState.create(
            apply_fn=self.critic.apply,
            params=critic_vars['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=critic_vars.copy()['params'],
        )

        # --- entropy coefficient ---
        self.entropy_coefficient = (
            Lagrangian() if self.learn_ent_coef else Constant(self.ent_coef)
        )
        coef_vars = self.entropy_coefficient.init(next(self.rng))
        self.ent_coef_train_state = TrainState.create(
            apply_fn=self.entropy_coefficient.apply,
            params=coef_vars,
            tx=self.opt_class(self.learning_rate),
        )

    @partial(jax.jit, static_argnums=(0,))
    def _deterministic_predict(self, params, observation):
        """Call ``deterministic_pred`` with ``{'params': params}``; return its first output."""
        outputs = self.deterministic_pred({'params': params}, observation)
        action, _ = jax.lax.stop_gradient(outputs)
        return action

    @partial(jax.jit, static_argnums=(0,))
    def _predict(self, params, observation, key):
        """Apply the actor with ``key`` as the ``"rng_stream"`` RNG; return its first output."""
        outputs = self.actor.apply({'params': params}, observation,
                                   rngs={"rng_stream": key})
        action, _ = jax.lax.stop_gradient(outputs)
        return action

    def predict(self, observation, deterministic: bool = False):
        """Return an action as a NumPy array.

        With ``deterministic=False`` a key is drawn from the internal RNG
        sequence and the action comes from :meth:`_explore`.
        """
        params = self.actor_train_state.params
        if deterministic:
            # NOTE(review): this passes the raw params and converts the whole
            # return value, whereas ``_deterministic_predict`` wraps them as
            # ``{'params': params}`` and unpacks a pair. One of the two call
            # conventions looks wrong -- confirm against
            # ``SquashedGaussianPolicy.deterministic``.
            return np.asarray(self.deterministic_pred(params, observation))
        return np.asarray(self._explore(params, observation, key=next(self.rng)))

    @partial(jax.jit, static_argnums=(0,))
    def _explore(self, params, observation, key):
        """Jitted alias of :meth:`_predict`."""
        return self._predict(params, observation, key)

    def explore(self, observation, key):
        """Return an action drawn with ``key`` using the current actor parameters."""
        return self._explore(self.actor_train_state.params, observation, key)

    def update(self, batch, key, cnt, delay):
        """Run a critic step; every ``delay`` calls also update actor and entropy coefficient.

        The actor's info dict must contain ``'log_probs'``; that entry is
        removed from it and passed to ``ent_coef_update_fn``.
        """
        critic_key, actor_key = jax.random.split(key, 2)
        self.critic_train_state, loss_info = self.critic_update_fn(
            self.critic_train_state,
            self.actor_train_state,
            self.ent_coef_train_state,
            batch, critic_key)

        if cnt % delay != 0:
            return loss_info

        self.actor_train_state, actor_loss_info = self.actor_update_fn(
            self.critic_train_state,
            self.actor_train_state,
            self.ent_coef_train_state,
            batch, actor_key)
        log_probs = actor_loss_info.pop('log_probs')
        self.ent_coef_train_state, ent_coef_info = self.ent_coef_update_fn(
            self.ent_coef_train_state, log_probs)
        loss_info.update(actor_loss_info)
        loss_info.update(ent_coef_info)
        return loss_info


class CQLPolicy(SACPolicy):
    """SAC-style policy with an additional ``lagrangian`` coefficient.

    ``build`` extends the parent's with a ``Lagrangian`` (when
    ``auto_adjust_kl`` is true) or a ``Constant(kl_difference)`` module and
    its train state, which is handed to the critic update.
    """

    lagrangian: Lagrangian | Constant
    lag_train_state: TrainState

    def __init__(self,
                 observation_ph, action_ph,
                 actor_update_fn, critic_update_fn, ent_coef_update_fn,
                 lag_update_fn,
                 learn_ent_coef: bool = True,
                 ent_coef: float | None = 0.01,
                 auto_adjust_kl: bool = True,
                 kl_difference: float | None = 10,
                 opt_class: Type[optax.adam] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 seed: int = 42,
                 ):
        """Record the Lagrangian settings, then defer to ``SACPolicy.__init__``.

        Args:
            lag_update_fn: Callable stored for the Lagrangian update step.
            auto_adjust_kl: Selects ``Lagrangian`` (True) or ``Constant`` (False).
            kl_difference: Value given to ``Constant`` when not auto-adjusting.

        The remaining arguments are forwarded unchanged to the parent.
        """
        # Set before the parent constructor, which triggers ``build``.
        self.kl_difference = kl_difference
        self.auto_adjust_kl = auto_adjust_kl
        self.lag_update_fn = lag_update_fn

        super().__init__(
            observation_ph=observation_ph,
            action_ph=action_ph,
            actor_update_fn=actor_update_fn,
            critic_update_fn=critic_update_fn,
            ent_coef_update_fn=ent_coef_update_fn,
            learn_ent_coef=learn_ent_coef,
            ent_coef=ent_coef,
            opt_class=opt_class,
            learning_rate=learning_rate,
            n_critics=n_critics,
            seed=seed,
        )

    def build(self, observation_ph, action_ph):
        """Build the SAC components, then add the Lagrangian module and its state."""
        super().build(observation_ph, action_ph)

        self.lagrangian = (
            Lagrangian() if self.auto_adjust_kl else Constant(self.kl_difference)
        )
        lag_vars = self.lagrangian.init(next(self.rng))
        self.lag_train_state = TrainState.create(
            apply_fn=self.lagrangian.apply,
            params=lag_vars,
            tx=self.opt_class(self.learning_rate),
        )

    def update(self, batch, key, cnt, delay):
        """Update critic, Lagrangian, actor and entropy coefficient, in that order.

        ``cnt`` and ``delay`` are accepted for interface compatibility but
        are not used: every call performs all four updates. The critic's
        info dict must contain ``'kl_potential'`` and the actor's must
        contain ``'log_probs'`` (which is removed before merging).
        """
        critic_key, actor_key = jax.random.split(key, 2)

        self.critic_train_state, loss_info = self.critic_update_fn(
            self.critic_train_state,
            self.actor_train_state,
            self.ent_coef_train_state,
            self.lag_train_state,
            batch, critic_key)

        self.lag_train_state, lag_loss_info = self.lag_update_fn(
            self.lag_train_state, loss_info['kl_potential'])

        self.actor_train_state, actor_loss_info = self.actor_update_fn(
            self.critic_train_state,
            self.actor_train_state,
            self.ent_coef_train_state,
            batch, actor_key)

        self.ent_coef_train_state, ent_coef_info = self.ent_coef_update_fn(
            self.ent_coef_train_state, actor_loss_info.pop('log_probs'))

        # Merge order matters if keys collide: actor, entropy, then Lagrangian.
        for extra in (actor_loss_info, ent_coef_info, lag_loss_info):
            loss_info.update(extra)
        return loss_info


class CODACPolicy(CQLPolicy):
    """CQL-style policy whose critic is a ``VectorQuantileCritic``.

    ``build`` is re-implemented in full (it does not call the parent's)
    because the critic is initialised with an extra ``taus`` placeholder.
    """

    lagrangian: Lagrangian | Constant
    lag_train_state: TrainState

    def __init__(self,
                 observation_ph, action_ph,
                 actor_update_fn, critic_update_fn, ent_coef_update_fn,
                 lag_update_fn,
                 learn_ent_coef: bool = True,
                 ent_coef: float | None = 0.01,
                 auto_adjust_kl: bool = True,
                 kl_difference: float | None = 10,
                 opt_class: Type[optax.adam] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 smooth_critic: bool = False,
                 seed: int = 42,
                 ):
        """Record ``smooth_critic`` and forward everything else to ``CQLPolicy``.

        Args:
            smooth_critic: Passed to ``VectorQuantileCritic`` as ``smooth``.
        """
        # Needed by ``build``, which the parent constructor invokes.
        self.smooth_critic = smooth_critic

        super().__init__(
            observation_ph=observation_ph,
            action_ph=action_ph,
            actor_update_fn=actor_update_fn,
            critic_update_fn=critic_update_fn,
            ent_coef_update_fn=ent_coef_update_fn,
            lag_update_fn=lag_update_fn,
            learn_ent_coef=learn_ent_coef,
            ent_coef=ent_coef,
            auto_adjust_kl=auto_adjust_kl,
            kl_difference=kl_difference,
            opt_class=opt_class,
            learning_rate=learning_rate,
            n_critics=n_critics,
            seed=seed,
        )

    def build(self, observation_ph, action_ph):
        """Create actor, quantile critic, entropy coefficient and Lagrangian states."""
        # Zero placeholder with 32 columns, used only to initialise the critic
        # (presumably quantile fractions -- confirm against VectorQuantileCritic).
        taus_ph = jax.numpy.zeros(shape=(observation_ph.shape[0], 32))

        # --- actor ---
        self.actor = SquashedGaussianPolicy(action_ph.shape[-1])
        actor_vars = self.actor.init(next(self.rng), observation_ph)
        self.actor_train_state = RLTrainState.create(
            apply_fn=self.actor.apply,
            params=actor_vars['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=actor_vars.copy()['params'],
        )
        self.deterministic_pred = jax.jit(self.actor.deterministic)

        # --- quantile critic ---
        self.critic = VectorQuantileCritic(n_critics=self.n_critics,
                                           smooth=self.smooth_critic)
        critic_vars = self.critic.init(next(self.rng), observation_ph, action_ph, taus_ph)
        self.critic_train_state = RLTrainState.create(
            apply_fn=self.critic.apply,
            params=critic_vars['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=critic_vars.copy()['params'],
        )

        # --- entropy coefficient ---
        self.entropy_coefficient = (
            Lagrangian() if self.learn_ent_coef else Constant(self.ent_coef)
        )
        ent_vars = self.entropy_coefficient.init(next(self.rng))
        self.ent_coef_train_state = TrainState.create(
            apply_fn=self.entropy_coefficient.apply,
            params=ent_vars,
            tx=self.opt_class(self.learning_rate),
        )

        # --- Lagrangian ---
        self.lagrangian = (
            Lagrangian() if self.auto_adjust_kl else Constant(self.kl_difference)
        )
        lag_vars = self.lagrangian.init(next(self.rng))
        self.lag_train_state = TrainState.create(
            apply_fn=self.lagrangian.apply,
            params=lag_vars,
            tx=self.opt_class(self.learning_rate),
        )


class BCActorTrainState(TrainState):
    """``TrainState`` extended with a ``log_prob_of`` callable (used for the BC actor)."""
    # Declared with ``pytree_node=False``: stored on the state as a
    # non-pytree field rather than as a pytree leaf.
    log_prob_of: Callable = struct.field(pytree_node=False)


class OnlineStochasticTQCPolicy(SACPolicy):
    """SAC-style policy with a smooth, non-monotone ``VectorQuantileCritic``.

    Only ``build`` is overridden; prediction and ``update`` come from
    ``SACPolicy``.
    """

    critic: VectorQuantileCritic
    actor: SquashedGaussianPolicy

    def build(self, observation_ph, action_ph):
        """Create entropy coefficient, actor, quantile critic and ``lagrangian`` states.

        The ``lagrangian`` module is selected by ``learn_ent_coef`` (there is
        no separate flag) and, when constant, holds ``log(ent_coef)``. Its
        train state is not consumed by the inherited ``SACPolicy.update``.
        """
        # --- entropy coefficient ---
        self.entropy_coefficient = (
            Lagrangian() if self.learn_ent_coef else Constant(self.ent_coef)
        )
        ent_vars = self.entropy_coefficient.init(next(self.rng))
        self.ent_coef_train_state = TrainState.create(
            apply_fn=self.entropy_coefficient.apply,
            params=ent_vars,
            tx=self.opt_class(self.learning_rate),
        )

        # Zero placeholder with 32 columns, used only to initialise the critic
        # (presumably quantile fractions -- confirm against VectorQuantileCritic).
        taus_ph = jax.numpy.zeros(shape=(observation_ph.shape[0], 32))

        # --- actor ---
        self.actor = SquashedGaussianPolicy(action_ph.shape[-1])
        actor_vars = self.actor.init(next(self.rng), observation_ph)
        self.actor_train_state = RLTrainState.create(
            apply_fn=self.actor.apply,
            params=actor_vars['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=actor_vars.copy()['params'],
        )
        self.deterministic_pred = jax.jit(self.actor.deterministic)

        # --- quantile critic ---
        self.critic = VectorQuantileCritic(n_critics=self.n_critics,
                                           smooth=True, monotone=False)
        critic_vars = self.critic.init(next(self.rng), observation_ph, action_ph, taus_ph)
        self.critic_train_state = RLTrainState.create(
            apply_fn=self.critic.apply,
            params=critic_vars['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=critic_vars.copy()['params'],
        )

        # --- lagrangian (gated on ``learn_ent_coef``) ---
        self.lagrangian = (
            Lagrangian() if self.learn_ent_coef else Constant(np.log(self.ent_coef))
        )
        lag_vars = self.lagrangian.init(next(self.rng))
        self.lag_train_state = TrainState.create(
            apply_fn=self.lagrangian.apply,
            params=lag_vars,
            tx=self.opt_class(self.learning_rate),
        )


class StochasticTQCPolicy(SACPolicy):
    """SAC-style policy with a quantile critic and an auxiliary BC actor.

    A second ``SquashedGaussianPolicy`` (``bc_actor``) is kept in a
    ``BCActorTrainState`` and is passed to both the critic and the actor
    update callables.
    """

    critic: VectorQuantileCritic
    bc_actor: SquashedGaussianPolicy
    bc_actor_train_state: BCActorTrainState
    actor: SquashedGaussianPolicy

    def __init__(self,
                 observation_ph, action_ph,
                 actor_update_fn, critic_update_fn, ent_coef_update_fn, bc_update_fn,
                 learn_ent_coef: bool = True,
                 ent_coef: float | None = 0.01,
                 smooth=True,
                 opt_class: Type[optax.adam] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 seed: int = 42,
                 ):
        """Record the BC settings, then defer to ``SACPolicy.__init__``.

        Args:
            bc_update_fn: Callable stored for the BC-actor update step.
            smooth: Passed to ``VectorQuantileCritic`` as ``smooth``.

        The remaining arguments are forwarded unchanged to the parent.
        """
        # Needed by ``build``, which the parent constructor invokes.
        self.bc_update_fn = bc_update_fn
        self.smooth = smooth

        super().__init__(
            observation_ph=observation_ph,
            action_ph=action_ph,
            actor_update_fn=actor_update_fn,
            critic_update_fn=critic_update_fn,
            ent_coef_update_fn=ent_coef_update_fn,
            learn_ent_coef=learn_ent_coef,
            ent_coef=ent_coef,
            opt_class=opt_class,
            learning_rate=learning_rate,
            n_critics=n_critics,
            seed=seed,
        )

    def build(self, observation_ph, action_ph):
        """Create entropy coefficient, actor, BC actor, quantile critic and ``lagrangian`` states.

        The BC actor starts from a copy of the main actor's initial
        parameters. The ``lagrangian`` module is selected by
        ``learn_ent_coef`` and, when constant, holds ``log(ent_coef)``.
        """
        # --- entropy coefficient ---
        self.entropy_coefficient = (
            Lagrangian() if self.learn_ent_coef else Constant(self.ent_coef)
        )
        ent_vars = self.entropy_coefficient.init(next(self.rng))
        self.ent_coef_train_state = TrainState.create(
            apply_fn=self.entropy_coefficient.apply,
            params=ent_vars,
            tx=self.opt_class(self.learning_rate),
        )

        # Zero placeholder with 32 columns, used only to initialise the critic
        # (presumably quantile fractions -- confirm against VectorQuantileCritic).
        taus_ph = jax.numpy.zeros(shape=(observation_ph.shape[0], 32))

        action_dim = action_ph.shape[-1]
        self.actor = SquashedGaussianPolicy(action_dim)
        self.bc_actor = SquashedGaussianPolicy(action_dim)

        # --- actor ---
        actor_vars = self.actor.init(next(self.rng), observation_ph)
        self.actor_train_state = RLTrainState.create(
            apply_fn=self.actor.apply,
            params=actor_vars['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=actor_vars.copy()['params'],
        )

        # --- BC actor: both callables evaluate ``log_prob_of`` on ``bc_actor`` ---
        bc_apply_fn = flax.linen.apply(SquashedGaussianPolicy.log_prob_of, self.bc_actor)
        bc_params = actor_vars.copy()['params']
        bc_optimizer = self.opt_class(self.learning_rate)
        bc_log_prob_fn = flax.linen.apply(SquashedGaussianPolicy.log_prob_of, self.bc_actor)
        self.bc_actor_train_state = BCActorTrainState.create(
            apply_fn=bc_apply_fn,
            params=bc_params,
            tx=bc_optimizer,
            log_prob_of=bc_log_prob_fn,
        )

        self.deterministic_pred = jax.jit(self.actor.deterministic)

        # --- quantile critic ---
        self.critic = VectorQuantileCritic(n_critics=self.n_critics,
                                           smooth=self.smooth, monotone=False)
        critic_vars = self.critic.init(next(self.rng), observation_ph, action_ph, taus_ph)
        self.critic_train_state = RLTrainState.create(
            apply_fn=self.critic.apply,
            params=critic_vars['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=critic_vars.copy()['params'],
        )

        # --- lagrangian (gated on ``learn_ent_coef``) ---
        self.lagrangian = (
            Lagrangian() if self.learn_ent_coef else Constant(np.log(self.ent_coef))
        )
        lag_vars = self.lagrangian.init(next(self.rng))
        self.lag_train_state = TrainState.create(
            apply_fn=self.lagrangian.apply,
            params=lag_vars,
            tx=self.opt_class(self.learning_rate),
        )

    def update(self, batch, key, cnt, delay):
        """Update critic, BC actor, actor and entropy coefficient, in that order.

        ``cnt`` and ``delay`` are accepted for interface compatibility but
        are not used. ``'log_probs'`` is read from the actor's info dict
        without being removed, so it is also present in the returned dict.
        """
        critic_key, actor_key = jax.random.split(key, 2)

        self.critic_train_state, loss_info = self.critic_update_fn(
            self.critic_train_state,
            self.actor_train_state,
            self.bc_actor_train_state,
            self.ent_coef_train_state,
            batch, critic_key)

        self.bc_actor_train_state, bc_loss_info = self.bc_update_fn(
            self.bc_actor_train_state, batch)

        self.actor_train_state, actor_loss_info = self.actor_update_fn(
            self.critic_train_state,
            self.actor_train_state,
            self.bc_actor_train_state,
            self.ent_coef_train_state,
            batch, actor_key)

        self.ent_coef_train_state, ent_coef_info = self.ent_coef_update_fn(
            self.ent_coef_train_state, actor_loss_info['log_probs'])

        # Merge order matters if keys collide: BC, actor, then entropy.
        for extra in (bc_loss_info, actor_loss_info, ent_coef_info):
            loss_info.update(extra)
        return loss_info


class DeterministicTQCPolicy(TD3Policy):
    """TD3-style policy whose critic is a non-monotone ``VectorQuantileCritic``.

    ``build`` is re-implemented in full and, unlike ``TD3Policy.build``,
    does not create the BC-weight coefficient or its train state.
    """

    critic: VectorQuantileCritic

    def __init__(self,
                 observation_ph, action_ph,
                 actor_update_fn, critic_update_fn,
                 opt_class: Type[optax.adam] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 smooth: bool = True,
                 ff_feature: bool = False,
                 seed: int = 42,
                 ):
        """Record the critic/actor options, then defer to the parent constructor.

        Args:
            smooth: Passed to ``VectorQuantileCritic`` as ``smooth``.
            ff_feature: Passed as ``ff_feature`` to both the actor and the critic.

        The remaining arguments are forwarded unchanged to the parent.
        """
        # Needed by ``build``, which the parent constructor invokes.
        self.smooth = smooth
        self.ff_feature = ff_feature
        super().__init__(
            observation_ph=observation_ph,
            action_ph=action_ph,
            actor_update_fn=actor_update_fn,
            critic_update_fn=critic_update_fn,
            opt_class=opt_class,
            learning_rate=learning_rate,
            n_critics=n_critics,
            seed=seed,
        )

    def build(self, observation_ph, action_ph):
        """Create the deterministic actor and the quantile critic with their train states.

        Target parameters start out as a copy of the freshly initialised
        online parameters.
        """
        # Zero placeholder with 32 columns, used only to initialise the critic
        # (presumably quantile fractions -- confirm against VectorQuantileCritic).
        taus_ph = jax.numpy.zeros(shape=(observation_ph.shape[0], 32))

        # --- actor ---
        self.actor = DeterministicPolicy(action_ph.shape[-1], ff_feature=self.ff_feature)
        actor_param = self.actor.init(next(self.rng), observation_ph)
        actor_target_params = actor_param.copy()['params']
        self.actor_train_state = RLTrainState.create(apply_fn=self.actor.apply,
                                                     params=actor_param['params'],
                                                     tx=self.opt_class(self.learning_rate),
                                                     target_params=actor_target_params)

        # --- quantile critic ---
        self.critic = VectorQuantileCritic(n_critics=self.n_critics, smooth=self.smooth,
                                           monotone=False, ff_feature=self.ff_feature)
        critic_param = self.critic.init(next(self.rng), observation_ph, action_ph, taus_ph)
        critic_target_params = critic_param.copy()['params']
        self.critic_train_state = RLTrainState.create(apply_fn=self.critic.apply,
                                                      params=critic_param['params'],
                                                      tx=self.opt_class(self.learning_rate),
                                                      target_params=critic_target_params)

    def update(self, batch, key, cnt, delay):
        """Run one critic step and, when ``cnt % delay == 0``, one actor step.

        Args:
            batch: Training batch forwarded to both update callables.
            key: JAX PRNG key; split once into independent critic/actor keys.
            cnt: Update counter used for the actor delay.
            delay: The actor is updated only when ``cnt % delay == 0``.

        Returns:
            The critic's info dict, extended with the actor's info dict on
            steps where the actor is updated.
        """
        # Split before any use so the critic and the actor never share a key.
        # Previously the critic consumed ``key`` and the actor key was then
        # derived from that same, already-used key.
        critic_key, actor_key = jax.random.split(key, 2)

        self.critic_train_state, loss_info = self.critic_update_fn(self.critic_train_state,
                                                                   self.actor_train_state,
                                                                   batch, critic_key)

        if cnt % delay == 0:
            self.actor_train_state, actor_loss_info = self.actor_update_fn(self.critic_train_state,
                                                                           self.actor_train_state,
                                                                           batch, actor_key)
            loss_info.update(actor_loss_info)
        return loss_info


class VAETD3Policy(TD3Policy):
    """TD3-style policy with a VAE actor.

    Acting samples several candidate actions from the VAE and returns the one
    the critic scores highest.
    """
    actor: VAEActor
    actor_train_state: VAERLTrainState

    def build(self, observation_ph, action_ph):
        """Instantiate the VAE actor, the vector critic and both train states.

        :param observation_ph: placeholder observation used to initialise parameters
        :param action_ph: placeholder action; its last dimension is passed to ``VAEActor``
        """
        # --- actor -------------------------------------------------------
        self.actor = VAEActor(action_ph.shape[-1])
        actor_variables = self.actor.init(next(self.rng), observation_ph, action_ph)
        # The target network starts from the same parameters as the online one.
        initial_actor_target = actor_variables.copy()['params']
        self.actor_train_state = VAERLTrainState.create(
            apply_fn=flax.linen.apply(VAEActor.non_clipping_noise_sample, self.actor),
            params=actor_variables['params'],
            tx=self.opt_class(self.learning_rate),
            loss_fn=flax.linen.apply(VAEActor.loss_fn, self.actor),
            bulk_sample=flax.linen.apply(VAEActor.non_clipping_bulk_sample, self.actor),
            target_params=initial_actor_target,
        )

        # --- critic ------------------------------------------------------
        self.critic = VectorCritic(n_critics=self.n_critics)
        critic_variables = self.critic.init(next(self.rng), observation_ph, action_ph)
        initial_critic_target = critic_variables.copy()['params']
        self.critic_train_state = RLTrainState.create(
            apply_fn=self.critic.apply,
            params=critic_variables['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=initial_critic_target,
        )

    def predict(self, observation):
        """Return the selected action for ``observation`` as a NumPy array.

        Draws a fresh key from ``self.rng`` for the candidate sampling.
        """
        chosen = self._predict(
            self.actor_train_state.params,
            self.critic_train_state.params,
            observation,
            next(self.rng),
        )
        return np.asarray(chosen)

    @partial(jax.jit, static_argnums=(0,))
    def _predict(self, params_actor, params_critic, observation, key):
        """Jitted action selection: sample candidates, score them, keep the best.

        :param params_actor: actor parameters passed to ``bulk_sample``
        :param params_critic: critic parameters used for scoring
        :param observation: observation forwarded to the actor and the critic
        :param key: PRNG key supplied under the ``"rng_stream"`` collection
        :return: the candidate action with the largest mean critic output
        """
        candidates = self.actor_train_state.bulk_sample(
            {"params": params_actor}, observation, rngs={"rng_stream": key})

        def score(obs, act):
            # Collapse the whole critic output for one candidate into a scalar.
            return self.critic.apply({'params': params_critic}, obs, act).mean()

        # Candidates are mapped along axis 1; the observation is shared.
        # NOTE(review): presumably candidates is (batch, n_samples, action_dim) - confirm.
        scores = jax.vmap(score, in_axes=(None, 1), out_axes=0)(observation, candidates)
        best = jax.numpy.argmax(scores.reshape(-1))
        flat_candidates = candidates.reshape(-1, candidates.shape[-1])
        return flat_candidates[best]


class QLDLikeTQCPolicy(DeterministicTQCPolicy):
    """Abstract quantile-critic policy that acts by best-of-N action selection.

    ``_predict`` repeats the observation, lets the actor produce one action per
    copy, scores each action with the quantile critic at risk-adjusted taus and
    returns the highest-scoring action. Subclasses must implement ``build``.
    """
    critic: VectorQuantileCritic

    def __init__(self,
                 observation_ph, action_ph,
                 actor_update_fn, critic_update_fn,
                 risk_measure_fn,
                 smooth: bool = True,
                 opt_class: Type[optax.adam] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 seed: int = 42,
                 ):
        """
        :param observation_ph: placeholder observation, forwarded to the parent
        :param action_ph: placeholder action, forwarded to the parent
        :param actor_update_fn: actor update function, forwarded to the parent
        :param critic_update_fn: critic update function, forwarded to the parent
        :param risk_measure_fn: callable applied to the tau grid before it is
            handed to the critic in ``_predict``
        :param smooth: forwarded to the parent
        :param opt_class: optimizer factory, forwarded to the parent
        :param learning_rate: forwarded to the parent
        :param n_critics: forwarded to the parent
        :param seed: forwarded to the parent
        """
        super().__init__(
            observation_ph=observation_ph, action_ph=action_ph,
            actor_update_fn=actor_update_fn, critic_update_fn=critic_update_fn,
            opt_class=opt_class,
            learning_rate=learning_rate, smooth=smooth,
            n_critics=n_critics,
            seed=seed,
        )
        # These are assigned only after the parent constructor has returned, so
        # anything the parent constructor triggers cannot rely on them.
        self.risk_measure_fn = risk_measure_fn
        self.taus = jax.numpy.linspace(0, 1, 32)

    @abstractmethod
    def build(self, observation_ph, action_ph):
        """Create the networks and train states; must be provided by subclasses."""
        pass

    def predict(self, observation):
        """Return the selected action for ``observation``, using a fresh key from ``self.rng``."""
        return self._predict(self.actor_train_state.params, self.critic_train_state.params,
                             observation, key=next(self.rng))

    @partial(jax.jit, static_argnums=(0,))
    def _predict(self, params_actor, params_critic, observation, key):
        """Jitted best-of-N action selection.

        :param params_actor: actor parameters
        :param params_critic: critic parameters
        :param observation: observation; repeated along axis 0 once per candidate
        :param key: PRNG key handed to the actor
        :return: the candidate action with the highest mean critic value
        """
        n_candidates = 50

        observation = jax.numpy.repeat(observation, axis=0, repeats=n_candidates)
        # NOTE(review): other policies in this file pass the key under
        # "rng_stream" (singular); confirm which collection name the actor reads.
        actions = self.actor.apply({"params": params_actor}, observation,
                                   rngs={"rng_streams": key})

        taus = jax.numpy.repeat(self.taus[None], axis=0, repeats=n_candidates)
        taus = self.risk_measure_fn(taus)

        actions_q = self.critic.apply({"params": params_critic}, observation, actions, taus)
        # Reduce the two trailing axes so one score per leading-axis entry remains.
        # NOTE(review): assumes the critic output is (candidate, critic, tau) - confirm.
        actions_q = actions_q.mean(axis=(-1, -2))

        index = jax.numpy.argmax(actions_q)
        # Return the winning *action*, not its score.
        return actions[index]


class ORAACPolicy(QLDLikeTQCPolicy):
    """Policy made of a behaviour-cloning VAE, an ``ORAACActor`` and a quantile critic.

    Acting samples candidate actions from the VAE, passes each through the
    actor together with the observation, scores the results with the critic at
    risk-adjusted taus and returns the best one.
    """
    bc_actor: VAEActor
    bc_actor_train_state: VAETrainState
    actor: ORAACActor
    actor_train_state: RLTrainState
    critic: VectorQuantileCritic

    def __init__(self,
                 observation_ph, action_ph,
                 actor_update_fn, critic_update_fn, bc_actor_update_fn,
                 risk_measure_fn,
                 opt_class: Type[optax.adam] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 phi: float = 0.05,
                 seed: int = 42,
                 ):
        """
        :param observation_ph: placeholder observation used to initialise parameters
        :param action_ph: placeholder action used to initialise parameters
        :param actor_update_fn: update function for the ``ORAACActor``
        :param critic_update_fn: update function for the critic
        :param bc_actor_update_fn: update function for the behaviour-cloning VAE
        :param risk_measure_fn: callable applied to the tau grid in ``_predict``
        :param opt_class: optimizer factory
        :param learning_rate: learning rate given to every optimizer
        :param n_critics: number of critics in the quantile critic
        :param phi: forwarded to ``ORAACActor``
        :param seed: forwarded to the parent constructor
        """
        # Set before the parent constructor runs: build() reads self.phi.
        self.bc_actor_update_fn = bc_actor_update_fn
        self.phi = phi

        super().__init__(observation_ph, action_ph,
                         actor_update_fn=actor_update_fn,
                         critic_update_fn=critic_update_fn,
                         risk_measure_fn=risk_measure_fn,
                         opt_class=opt_class,
                         learning_rate=learning_rate,
                         n_critics=n_critics,
                         seed=seed
                         )

    @property
    def checkpoint(self):
        """Train states to persist: actor, critic and the behaviour-cloning VAE."""
        return {"actor": self.actor_train_state, "critic": self.critic_train_state,
                "vae_actor": self.bc_actor_train_state}

    @partial(jax.jit, static_argnums=(0,))
    def _predict(self,
                 params_actor,
                 params_bc_actor,
                 params_critic,
                 observation,
                 key
                 ):
        """Jitted action selection.

        :param params_actor: parameters of the ``ORAACActor``
        :param params_bc_actor: parameters of the behaviour-cloning VAE
        :param params_critic: parameters of the quantile critic
        :param observation: observation shared by every candidate
        :param key: PRNG key supplied to the VAE under ``"rng_stream"``
        :return: the candidate action with the largest mean critic output
        """
        bc_actions = self.bc_actor_train_state.bulk_sample({"params": params_bc_actor},
                                                           observation,
                                                           rngs={"rng_stream": key}
                                                           )

        def sample_actions(obs, bc):
            return self.actor.apply({'params': params_actor}, obs, bc)

        # Map the actor over axis 1 of the VAE samples; the observation is shared.
        actions = jax.vmap(sample_actions, in_axes=(None, 1), out_axes=1)(observation, bc_actions)

        taus = self.taus[None]
        taus = self.risk_measure_fn(taus)

        def evaluate_actions(obs, a):
            # One scalar score per candidate: mean over the whole critic output.
            return self.critic.apply({'params': params_critic}, obs, a, taus).mean()

        q_a = jax.vmap(evaluate_actions, in_axes=(None, 1), out_axes=0)(observation, actions)
        q_a = q_a.reshape(-1, )
        actions = actions.reshape(-1, actions.shape[-1])
        index = jax.numpy.argmax(q_a, )

        return actions[index]

    def predict(self, observation):
        """Return the selected action for ``observation`` as a NumPy array."""
        return np.asarray(self._predict(self.actor_train_state.params,
                                        self.bc_actor_train_state.params,
                                        self.critic_train_state.params,
                                        observation, next(self.rng)))

    def build(self, observation_ph, action_ph):
        """Create the actor, the behaviour-cloning VAE, the critic and their train states.

        :param observation_ph: placeholder observation used to initialise parameters
        :param action_ph: placeholder action; its last dimension sizes both actors
        """
        self.actor = ORAACActor(action_ph.shape[-1], phi=self.phi)
        actor_param = self.actor.init(next(self.rng), observation_ph, action_ph)
        actor_target_params = actor_param.copy()['params']
        self.actor_train_state = RLTrainState.create(apply_fn=self.actor.apply,
                                                     params=actor_param['params'],
                                                     tx=self.opt_class(self.learning_rate),
                                                     target_params=actor_target_params)

        self.bc_actor = VAEActor(action_ph.shape[-1])
        bc_actor_param = self.bc_actor.init(next(self.rng), observation_ph, action_ph)
        self.bc_actor_train_state = VAETrainState.create(
            apply_fn=flax.linen.apply(VAEActor.sample, self.bc_actor),
            loss_fn=flax.linen.apply(VAEActor.loss_fn, self.bc_actor, ),
            bulk_sample=flax.linen.apply(VAEActor.bulk_sample, self.bc_actor),
            params=bc_actor_param['params'],
            tx=self.opt_class(self.learning_rate),
        )

        self.critic = VectorQuantileCritic(n_critics=self.n_critics, smooth=False)
        # NOTE(review): the tau placeholder takes the action's shape here, while
        # _predict feeds a 32-point tau grid; confirm the critic's parameters do
        # not depend on the number of taus.
        taus_ph = jax.numpy.zeros(shape=action_ph.shape)
        critic_param = self.critic.init(next(self.rng), observation_ph, action_ph, taus_ph)
        critic_target_params = critic_param.copy()['params']

        self.critic_train_state = RLTrainState.create(apply_fn=self.critic.apply,
                                                      params=critic_param['params'],
                                                      tx=self.opt_class(self.learning_rate),
                                                      target_params=critic_target_params
                                                      )

    def update(self, batch, key, cnt, delay, learn_bc):
        """Run one training step over the VAE (optional), the critic and the actor.

        :param batch: training batch, forwarded unchanged to every update function
        :param key: PRNG key; split into one independent key per sub-update
        :param cnt: update counter; the actor is stepped only when ``cnt % delay == 0``
        :param delay: policy-delay period
        :param learn_bc: when true the behaviour-cloning VAE is updated as well
            (intended to be switched off by an early-stop callback)
        :return: dict merging the loss info of every sub-update that ran
        """
        # Each consumer gets its own key: feeding one key to several update
        # functions would make their random draws correlated.
        bc_key, critic_key, actor_key = jax.random.split(key, 3)

        loss_info = {}
        if learn_bc:
            self.bc_actor_train_state, bc_loss_info = self.bc_actor_update_fn(
                self.bc_actor_train_state, batch, bc_key
            )
            loss_info.update(bc_loss_info)

        # The critic step is given the VAE state (not the ORAACActor state).
        self.critic_train_state, critic_loss_info = self.critic_update_fn(self.critic_train_state,
                                                                          self.bc_actor_train_state,
                                                                          batch, critic_key)
        loss_info.update(critic_loss_info)

        if cnt % delay == 0:
            self.actor_train_state, actor_loss_info = self.actor_update_fn(self.critic_train_state,
                                                                           self.actor_train_state,
                                                                           self.bc_actor_train_state,
                                                                           batch, actor_key)
            loss_info.update(actor_loss_info)
        return loss_info


class MRDeterministicTQCPolicy(DeterministicTQCPolicy):
    """Deterministic TQC policy that also exposes quantile values of the predicted action.

    NOTE(review): nothing in this class assigns ``auxiliary_critic`` or
    ``auxiliary_critic_train_state``; ``aux_q_values`` needs them to be set
    elsewhere - confirm.
    """
    auxiliary_critic: VectorQuantileCritic
    auxiliary_critic_train_state: TrainState  # DiffusionTrainState

    def __init__(self,
                 observation_ph, action_ph,
                 actor_update_fn, critic_update_fn,
                 opt_class: Type[optax.adam] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 smooth: bool = True,
                 ff_feature: bool = False,
                 seed: int = 42,
                 ):
        """Forward every argument, by keyword, to the parent constructor."""
        parent_kwargs = dict(
            observation_ph=observation_ph,
            action_ph=action_ph,
            actor_update_fn=actor_update_fn,
            critic_update_fn=critic_update_fn,
            opt_class=opt_class,
            learning_rate=learning_rate,
            n_critics=n_critics,
            smooth=smooth,
            ff_feature=ff_feature,
            seed=seed,
        )
        super().__init__(**parent_kwargs)

    def q_values(self, observation):
        """Return the main critic's output for the predicted action, as a NumPy array.

        :param observation: observations; the predicted action is reshaped to
            ``(observation.shape[0], -1)`` before it is passed to the critic
        """
        action = self.predict(observation).reshape(observation.shape[0], -1)
        critic_params = self.critic_train_state.params
        return np.asarray(self._q_value(critic_params, observation, action))

    def aux_q_values(self, observation):
        """Same as ``q_values`` but evaluated with the auxiliary critic."""
        action = self.predict(observation).reshape(observation.shape[0], -1)
        aux_params = self.auxiliary_critic_train_state.params
        return np.asarray(self._aux_value(aux_params, observation, action))

    @partial(jax.jit, static_argnums=(0,))
    def _aux_value(self, params_critic, observation, actions):
        """Apply the auxiliary critic on a 100-point tau grid over [0, 1]."""
        tau_grid = jax.numpy.linspace(0, 1, 100)[None]
        return self.auxiliary_critic.apply({"params": params_critic}, observation, actions, tau_grid)

    @partial(jax.jit, static_argnums=(0,))
    def _q_value(self, params_critic, observation, actions):
        """Apply the main critic on a 100-point tau grid over [0, 1]."""
        tau_grid = jax.numpy.linspace(0, 1, 100)[None]
        return self.critic.apply({"params": params_critic}, observation, actions, tau_grid)

    def update(self, batch, key, cnt, delay):
        """Run one critic step and, on every ``delay``-th call, one actor step.

        :param batch: training batch, forwarded unchanged to both update functions
        :param key: PRNG key; split three ways, the second and third keys go to
            the critic and actor steps respectively (the first is unused)
        :param cnt: update counter; the actor is stepped only when ``cnt % delay == 0``
        :param delay: policy-delay period
        :return: dict with the critic's loss info plus, when it ran, the actor's
        """
        _, critic_key, actor_key = jax.random.split(key, 3)

        new_critic_state, critic_metrics = self.critic_update_fn(
            self.critic_train_state, self.actor_train_state, batch, critic_key)
        self.critic_train_state = new_critic_state
        metrics = dict(critic_metrics)

        if cnt % delay != 0:
            return metrics

        # NOTE(review): the actor state is passed first here, whereas the other
        # policies in this file pass the critic state first - confirm this
        # matches the signature of the injected actor_update_fn.
        new_actor_state, actor_metrics = self.actor_update_fn(
            self.actor_train_state, self.critic_train_state, batch, actor_key)
        self.actor_train_state = new_actor_state
        metrics.update(actor_metrics)
        return metrics


class VAETQCPolicy(DeterministicTQCPolicy):
    """Quantile-critic policy with a VAE actor and risk-adjusted action selection."""
    critic: VectorQuantileCritic
    actor: VAEActor
    actor_train_state: VAERLTrainState

    def __init__(self,
                 observation_ph, action_ph,
                 actor_update_fn, critic_update_fn,
                 risk_measure_fn, risk_eta,
                 opt_class: Type[optax.adam] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 smooth: bool = True,
                 ff_feature: bool = False,
                 seed: int = 42,
                 ):
        """
        :param observation_ph: placeholder observation used to initialise parameters
        :param action_ph: placeholder action used to initialise parameters
        :param actor_update_fn: update function for the actor
        :param critic_update_fn: update function for the critic
        :param risk_measure_fn: callable invoked as ``risk_measure_fn(taus, risk_eta)``
            in ``_predict``
        :param risk_eta: second argument given to ``risk_measure_fn``
        :param opt_class: optimizer factory
        :param learning_rate: learning rate given to every optimizer
        :param n_critics: number of critics in the quantile critic
        :param smooth: forwarded to the parent constructor
        :param ff_feature: forwarded to the parent constructor
        :param seed: forwarded to the parent constructor
        """
        # Risk settings are stored before delegating to the parent constructor.
        self.taus = jax.numpy.linspace(0, 1, 32)
        self.risk_measure_fn = risk_measure_fn
        self.risk_eta = risk_eta

        super().__init__(
            observation_ph, action_ph,
            actor_update_fn=actor_update_fn,
            critic_update_fn=critic_update_fn,
            opt_class=opt_class,
            learning_rate=learning_rate,
            n_critics=n_critics,
            ff_feature=ff_feature,
            smooth=smooth,
            seed=seed,
        )

    def build(self, observation_ph, action_ph):
        """Create the VAE actor, the quantile critic and their train states.

        :param observation_ph: placeholder observation used to initialise parameters
        :param action_ph: placeholder action; its last dimension is passed to ``VAEActor``
        """
        # --- actor -------------------------------------------------------
        self.actor = VAEActor(action_ph.shape[-1], ff_feature=self.ff_feature)
        actor_variables = self.actor.init(next(self.rng), observation_ph, action_ph)
        # The target network starts from the same parameters as the online one.
        initial_actor_target = actor_variables.copy()['params']
        self.actor_train_state = VAERLTrainState.create(
            apply_fn=flax.linen.apply(VAEActor.non_clipping_noise_sample, self.actor),
            params=actor_variables['params'],
            tx=self.opt_class(self.learning_rate),
            loss_fn=flax.linen.apply(VAEActor.loss_fn, self.actor),
            bulk_sample=flax.linen.apply(VAEActor.non_clipping_bulk_sample, self.actor),
            target_params=initial_actor_target,
        )

        # --- critic ------------------------------------------------------
        self.critic = VectorQuantileCritic(n_critics=self.n_critics, smooth=self.smooth,
                                           monotone=False, ff_feature=self.ff_feature)
        # Tau placeholder: one row per placeholder observation, 32 columns.
        tau_placeholder = jax.numpy.zeros(shape=(observation_ph.shape[0], 32))
        critic_variables = self.critic.init(next(self.rng), observation_ph, action_ph, tau_placeholder, )
        initial_critic_target = critic_variables.copy()['params']
        self.critic_train_state = RLTrainState.create(
            apply_fn=self.critic.apply,
            params=critic_variables['params'],
            tx=self.opt_class(self.learning_rate),
            target_params=initial_critic_target,
        )

    def predict(self, observation):
        """Return the selected action for ``observation`` as a NumPy array.

        Draws a fresh key from ``self.rng`` for the candidate sampling.
        """
        chosen = self._predict(
            self.actor_train_state.params,
            self.critic_train_state.params,
            observation,
            next(self.rng),
        )
        return np.asarray(chosen)

    @partial(jax.jit, static_argnums=(0,))
    def _predict(self, params_actor, params_critic, observation, key):
        """Jitted action selection: sample candidates, score them, keep the best.

        :param params_actor: actor parameters passed to ``bulk_sample``
        :param params_critic: critic parameters used for scoring
        :param observation: observation forwarded to the actor and the critic
        :param key: PRNG key supplied under the ``"rng_stream"`` collection
        :return: the candidate action with the largest mean critic output at the
            risk-adjusted taus
        """
        candidates = self.actor_train_state.bulk_sample(
            {"params": params_actor}, observation, rngs={"rng_stream": key})
        risk_taus = self.risk_measure_fn(self.taus[None], self.risk_eta)

        def score(obs, act):
            # Collapse the whole critic output for one candidate into a scalar.
            return self.critic.apply({'params': params_critic}, obs, act, risk_taus).mean()

        # Candidates are mapped along axis 1; the observation is shared.
        # NOTE(review): presumably candidates is (batch, n_samples, action_dim) - confirm.
        scores = jax.vmap(score, in_axes=(None, 1), out_axes=0)(observation, candidates)
        best = jax.numpy.argmax(scores.reshape(-1))
        flat_candidates = candidates.reshape(-1, candidates.shape[-1])
        return flat_candidates[best]


class VAEMRTQCPolicy(VAETQCPolicy):
    """``VAETQCPolicy`` extended with an auxiliary quantile critic.

    The auxiliary critic is built and kept in its own train state, but both its
    update and its use for action selection are currently disabled: ``update``
    only steps the main critic and the actor, and ``_predict`` returns the
    actor's sample directly.
    """
    aux_critic: VectorQuantileCritic
    aux_critic_train_state: TrainState

    def __init__(self,
                 observation_ph, action_ph,
                 actor_update_fn, critic_update_fn, aux_critic_update_fn,
                 risk_measure_fn, risk_eta,
                 opt_class: Type[optax.adam] = optax.adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 smooth: bool = True,
                 ff_feature: bool = False,
                 seed: int = 42,
                 ):
        """
        :param observation_ph: placeholder observation used to initialise parameters
        :param action_ph: placeholder action used to initialise parameters
        :param actor_update_fn: update function for the actor
        :param critic_update_fn: update function for the main critic
        :param aux_critic_update_fn: update function for the auxiliary critic;
            stored as ``self.auxiliary_update_fn`` (not called by ``update`` at present)
        :param risk_measure_fn: forwarded to ``VAETQCPolicy``
        :param risk_eta: forwarded to ``VAETQCPolicy``
        :param opt_class: optimizer factory
        :param learning_rate: learning rate given to every optimizer
        :param n_critics: number of critics in each quantile critic
        :param smooth: forwarded to ``VAETQCPolicy``
        :param ff_feature: forwarded to ``VAETQCPolicy``
        :param seed: forwarded to ``VAETQCPolicy``
        """
        self.auxiliary_update_fn = aux_critic_update_fn

        super().__init__(observation_ph, action_ph,
                         actor_update_fn=actor_update_fn,
                         critic_update_fn=critic_update_fn,
                         risk_measure_fn=risk_measure_fn,
                         risk_eta=risk_eta,
                         opt_class=opt_class,
                         learning_rate=learning_rate,
                         ff_feature=ff_feature,
                         n_critics=n_critics,
                         smooth=smooth,
                         seed=seed
                         )

    def build(self, observation_ph, action_ph):
        """Build the parent's networks, then the auxiliary critic and its train state.

        :param observation_ph: placeholder observation used to initialise parameters
        :param action_ph: placeholder action used to initialise parameters
        """
        super().build(observation_ph, action_ph)
        self.aux_critic = VectorQuantileCritic(n_critics=self.n_critics, smooth=self.smooth,
                                               monotone=False, ff_feature=self.ff_feature)

        taus_ph = jax.numpy.zeros(shape=(observation_ph.shape[0], 32))
        aux_critic_param = self.aux_critic.init(next(self.rng), observation_ph, action_ph, taus_ph, )
        # apply_fn must come from the module whose parameters this state holds,
        # i.e. the auxiliary critic rather than the main critic.
        self.aux_critic_train_state = TrainState.create(apply_fn=self.aux_critic.apply,
                                                        params=aux_critic_param['params'],
                                                        tx=self.opt_class(self.learning_rate),
                                                        )

    def predict(self, observation):
        """Return the actor's sampled action(s) for ``observation`` as a NumPy array.

        A 1-D observation is given a leading axis before it is passed on.
        """
        if len(observation.shape) == 1:
            observation = observation[None]
        return np.asarray(self._predict(self.actor_train_state.params,
                                        self.aux_critic_train_state.params,
                                        observation, next(self.rng)))

    @partial(jax.jit, static_argnums=(0,))
    def _predict(self,
                 params_actor,
                 param_aux_critic,
                 observation,
                 key
                 ):
        """Jitted sampling from the actor.

        :param params_actor: actor parameters
        :param param_aux_critic: auxiliary-critic parameters; currently unused
            because critic-based selection among candidates is disabled
        :param observation: observation forwarded to the actor
        :param key: PRNG key supplied under the ``"rng_stream"`` collection
        :return: whatever the actor train state's ``apply_fn`` returns
        """
        return self.actor_train_state.apply_fn({"params": params_actor}, observation,
                                               rngs={"rng_stream": key})

    def update(self, batch, key, cnt, delay):
        """Run one critic step and, on every ``delay``-th call, one actor step.

        The auxiliary-critic update (``self.auxiliary_update_fn``) is not invoked.

        :param batch: training batch, forwarded unchanged to both update functions
        :param key: PRNG key; split three ways, the second and third keys go to
            the critic and actor steps respectively (the first is unused)
        :param cnt: update counter; the actor is stepped only when ``cnt % delay == 0``
        :param delay: policy-delay period
        :return: dict with the critic's loss info plus, when it ran, the actor's
        """
        # Keep the three-way split so the critic/actor keys stay what they were
        # when a first key was reserved for the (disabled) auxiliary update.
        _, critic_key, actor_key = jax.random.split(key, 3)

        loss_info = {}
        self.critic_train_state, critic_loss_info = self.critic_update_fn(
            self.critic_train_state,
            self.actor_train_state,
            batch, critic_key)
        loss_info.update(critic_loss_info)

        if cnt % delay == 0:
            self.actor_train_state, actor_loss_info = self.actor_update_fn(self.critic_train_state,
                                                                           self.actor_train_state,
                                                                           batch, actor_key)
            loss_info.update(actor_loss_info)
        return loss_info

