from rl.iqn_td3_plusbc import QuantileTD3PlusBC, QLearningBatch
from rl.policies.policy import MRDeterministicTQCPolicy, RLTrainState
from rl.utils.model_risk import calculate_model_risk
import jax
import jax.numpy as jnp
import numpy as np
import optax


class ModelRiskTD3PlusBC(QuantileTD3PlusBC):
    """Quantile TD3+BC agent whose actor objective is a model-risk-adjusted critic value.

    Everything not overridden here (placeholders, critic update, hyperparameters
    such as ``risk_measure``/``risk_eta``/``densities``) is inherited from
    ``QuantileTD3PlusBC``. This class swaps in ``MRDeterministicTQCPolicy`` and
    an actor update whose Q-term comes from ``calculate_model_risk``.
    """

    policy: MRDeterministicTQCPolicy
    # Declared for debugging output; not read or written anywhere in this class.
    debug: np.ndarray

    def build(self):
        """Create ``self.policy`` wired to this agent's actor and critic update functions."""
        observation_ph, action_ph = self.make_placeholder()
        self.policy = MRDeterministicTQCPolicy(observation_ph, action_ph,
                                               opt_class=self.opt_class, learning_rate=self.learning_rate,
                                               seed=next(self.hk_rng), n_critics=self.n_critics,
                                               ff_feature=self.fourier_feature,
                                               actor_update_fn=self.build_actor_update_fn(),
                                               critic_update_fn=self.build_critic_update_fn(),
                                               smooth=self.smooth,
                                               )

    def build_actor_update_fn(self):
        """Return a jitted actor update step.

        The returned ``update_fn(actor_train_state, critic_train_state, batch, key)``
        does three things. It takes one gradient step on the actor parameters. It
        moves the actor target parameters toward the new parameters via
        ``optax.incremental_update`` with step size ``target_update_rate``. It
        returns ``(new_actor_state, info)``, where ``info`` holds the batch-mean
        risk term (``"risk"``) and the behaviour-cloning loss (``"bc_loss"``).

        Actor loss::

            q_learning_scale * mean(risk / stop_gradient(mean|risk|))
            + mean((pi(s) - a_data) ** 2)
        """
        # Read every hyperparameter once, so the jitted closure captures plain
        # values fixed at build time rather than reaching back into ``self``.
        target_update_rate = self.target_update_rate
        risk_measure = self.risk_measure
        risk_eta = self.risk_eta
        density = self.densities[self.risk_type]
        actor_sampling_quantiles = self.actor_sampling_quantiles
        q_learning_scale = self.q_learning_scale

        @jax.jit
        def update_fn(actor_train_state: RLTrainState,
                      critic_train_state: RLTrainState,
                      batch: QLearningBatch,
                      key
                      ):
            def loss_fn(params):
                # One sorted row of uniform quantile levels per batch element.
                taus = jax.random.uniform(key, shape=(batch.observations.shape[0],
                                                      actor_sampling_quantiles)).sort(axis=-1)
                predicted_actions = actor_train_state.apply_fn({'params': params}, batch.observations)
                # Quantile levels transformed by the configured risk measure.
                risk_taus = risk_measure(taus, risk_eta)

                # Critic parameters are used as-is; gradients flow only to the actor.
                critic_variables = {'params': critic_train_state.params}
                risk_qf = critic_train_state.apply_fn(critic_variables,
                                                      batch.observations, predicted_actions,
                                                      risk_taus
                                                      )
                critic_estim = critic_train_state.apply_fn(critic_variables,
                                                           batch.observations, predicted_actions, taus
                                                           )
                # Spread over axis -2, then max over the last axis.
                # NOTE(review): axis -2 is presumably the critic-ensemble axis — confirm.
                std = critic_estim.std(axis=-2).max(axis=-1)

                neg_risk, _interpolation = calculate_model_risk(density, risk_eta, std, critic_estim,
                                                                risk_qf, taus, risk_taus)
                risk = -neg_risk
                # Normalise by the gradient-free mean magnitude. This keeps the
                # Q-term's weight relative to the BC term independent of the
                # value scale.
                scale = jax.lax.stop_gradient(jnp.abs(risk)).mean()
                policy_loss = (risk / scale).mean()
                bc_loss = ((predicted_actions - batch.actions) ** 2).mean()
                # q_learning_scale weights the Q-term exactly once.
                loss = q_learning_scale * policy_loss + bc_loss

                return loss, {"risk": risk.mean(), "bc_loss": bc_loss}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(actor_train_state.params)
            state = actor_train_state.apply_gradients(grads=grads)
            # Already inside jit, so call incremental_update directly. Result:
            # target_update_rate * new + (1 - target_update_rate) * old.
            target_param = optax.incremental_update(state.params,
                                                    actor_train_state.target_params,
                                                    target_update_rate)
            state = state.replace(target_params=target_param)
            return state, loss_info

        return update_fn

    def config(self):
        """Return the parent configuration unchanged."""
        return super().config()

    def __str__(self):
        return "MR_IQN_TD3PlusBC"

    def q_value(self, observation):
        """Return ``self.policy.q_values`` for one observation, after adding a leading axis of size 1."""
        return self.policy.q_values(observation[None])
