from rl.iqn_td3_plusbc import QuantileTD3PlusBC
from rl.policies.policy import ORAACPolicy, RLTrainState, VAETrainState
from rl.utils.risk_utils import quantile_huber_loss, get_tau
from rl.utils.replay_buffer import QLearningBatch, ReplayBuffer
import gymnasium as gym
from functools import partial
import optax
import jax.numpy as jnp
import jax
from typing import Callable, Optional


class ORAAC(QuantileTD3PlusBC):
    """ORAAC agent implemented on top of ``QuantileTD3PlusBC``.

    The agent owns an ``ORAACPolicy`` and supplies it with three jitted update
    functions: one for the quantile critic, one for the actor, and one for the
    behaviour-cloning (BC) actor whose state is a ``VAETrainState``.
    """
    policy: ORAACPolicy
    # Passed as the weighting argument to the BC actor's ``loss_fn``
    # (presumably the VAE KL weight -- confirm against VAETrainState.loss_fn).
    bc_vae_beta: float = 0.5
    # Number of candidate next actions drawn from the BC actor per transition
    # when building the critic's TD target.
    n_next_action_samples: int = 10

    def __init__(self,
                 env: gym.Env,
                 buffer: ReplayBuffer,
                 normalizer: Optional[Callable] = False,
                 gamma: float = 0.99,
                 batch_size: int = 256,
                 opt_class: Callable = optax.adam,
                 learning_rate: float = 3e-4,
                 risk_type: str = 'cvar',
                 risk_eta: float = 0.5,
                 n_quantiles: int = 16,
                 drop_per_net: int = 3,
                 phi: float = 0.25,
                 n_critics: int = 2,
                 smooth: bool = False,
                 seed: int = 42,
                 ):
        """Store the ORAAC-specific setting and delegate to the parent class.

        Every argument except ``phi`` is forwarded unchanged to
        ``QuantileTD3PlusBC.__init__``; see that class for their meaning.

        Args:
            phi: Forwarded to ``ORAACPolicy`` as its ``phi`` argument in
                :meth:`build`.
        """
        # Set before calling the parent constructor; presumably the parent
        # invokes ``build()``, which reads this attribute -- confirm.
        self.oraac_phi = phi

        super().__init__(env=env,
                         buffer=buffer,
                         normalizer=normalizer,
                         gamma=gamma,
                         batch_size=batch_size,
                         opt_class=opt_class,
                         learning_rate=learning_rate,
                         q_learning_scale=1,  # ignored
                         risk_type=risk_type,
                         risk_eta=risk_eta,
                         n_quantiles=n_quantiles,
                         drop_per_net=drop_per_net,
                         policy_delay=1,
                         n_critics=n_critics,
                         smooth=smooth,
                         seed=seed)

    def build(self, ):
        """Create ``self.policy`` and wire in the jitted update functions."""
        observation_ph, action_ph = self.make_placeholder()
        self.policy = ORAACPolicy(observation_ph, action_ph,
                                  critic_update_fn=self.build_critic_update_fn(),
                                  actor_update_fn=self.build_actor_update_fn(),
                                  bc_actor_update_fn=self.build_bc_actor_update_fn(),
                                  risk_measure_fn=partial(self.risk_measure, eta=self.risk_eta),
                                  opt_class=self.opt_class, learning_rate=self.learning_rate,
                                  n_critics=self.n_critics,
                                  phi=self.oraac_phi,
                                  seed=next(self.hk_rng))

    def build_critic_update_fn(self, ):
        """Return the jitted critic update.

        The returned function takes ``(critic_train_state,
        bc_actor_train_state, batch, key)`` and returns the updated critic
        state together with a dict holding ``"q_loss"`` and ``"ord"``.
        """
        gamma = self.gamma
        target_update_rate = self.target_update_rate
        n_action_samples = self.n_next_action_samples

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      bc_actor_train_state: VAETrainState,
                      batch: QLearningBatch, key):
            def loss_fn(param_critic):
                # Independent keys for each random draw: reusing one key would
                # make ``taus`` and ``next_taus`` identical samples and would
                # correlate them with the action sampling below.
                tau_key, next_tau_key, action_key = jax.random.split(key, 3)
                batch_size = batch.observations.shape[0]
                _, taus, presum_tau = get_tau(tau_key, shape=(batch_size, self.n_quantiles))
                _, next_taus, next_presum_tau = get_tau(next_tau_key, shape=(batch_size, self.n_quantiles))

                q_value = critic_train_state.apply_fn({'params': param_critic}, batch.observations,
                                                      batch.actions, taus)
                # Diagnostic: percentage of adjacent entries along axis -2 that
                # decrease. ``mean`` already normalises by the element count.
                diff = jnp.diff(q_value, axis=-2)
                order = jnp.mean(diff < 0) * 100

                # One copy of each next observation per candidate action.
                repeated_next_observation = jnp.repeat(batch.next_observations[:, None],
                                                       repeats=n_action_samples, axis=1)

                def sample_next(next_obs, sample_key):
                    """Score one BC-sampled candidate action per next observation."""
                    next_bc = bc_actor_train_state.apply_fn({"params": bc_actor_train_state.params},
                                                            next_obs, rngs={"rng_stream": sample_key})

                    # Target-critic output at the sampled quantile fractions.
                    q_values = critic_train_state.apply_fn({"params": critic_train_state.target_params},
                                                           next_obs, next_bc, next_taus
                                                           )

                    # Target-critic output at the risk-distorted fractions; used
                    # only to rank the candidates.
                    risk_q_values = critic_train_state.apply_fn({"params": critic_train_state.target_params},
                                                                next_obs, next_bc,
                                                                self.risk_measure(next_taus, self.risk_eta)
                                                                )
                    risk = jnp.mean(risk_q_values, axis=-2).min(axis=-1)
                    return q_values, risk

                # Map over the candidate axis (axis 1 of the repeated
                # observations), one PRNG key per candidate.
                q, risk = jax.vmap(sample_next, in_axes=(1, 0), out_axes=(1, 1))(
                    repeated_next_observation,
                    jax.random.split(action_key, n_action_samples))

                # Keep, per transition, the candidate with the highest risk score.
                index = jnp.argmax(risk, axis=1, keepdims=True)
                next_q_values = jnp.take_along_axis(q, index[..., None, None], axis=1)
                # Minimum over the last axis (presumably the critic ensemble).
                next_q_values = next_q_values.min(axis=-1)
                next_q_values = next_q_values.reshape(next_q_values.shape[0], -1)
                td_target = jax.lax.stop_gradient(batch.rewards + gamma * (1 - batch.dones) * next_q_values)

                # Apply the loss separately to each slice along axis 2 of q_value.
                loss = jax.vmap(quantile_huber_loss, in_axes=(None, 2, None, None), out_axes=-1)(td_target, q_value,
                                                                                                 taus, next_presum_tau)
                qr_loss = loss.mean()

                return qr_loss, {"q_loss": qr_loss, "ord": order}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(critic_train_state.params)
            state = critic_train_state.apply_gradients(grads=grads)
            # Soft update of the target network. No inner ``jax.jit`` is needed:
            # this code is already traced by the enclosing jit.
            critic_target_params = optax.incremental_update(
                state.params, critic_train_state.target_params, target_update_rate)
            state = state.replace(target_params=critic_target_params)

            return state, loss_info

        return update_fn

    def build_actor_update_fn(self):
        """Return the jitted actor update.

        The returned function takes ``(critic_train_state, actor_train_state,
        bc_actor_train_state, batch, key)`` and returns the updated actor
        state together with a dict holding ``"policy_kl"``.
        """
        target_update_rate = self.target_update_rate
        risk_measure = self.risk_measure
        risk_eta = self.risk_eta

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      actor_train_state: RLTrainState,
                      bc_actor_train_state: VAETrainState,
                      batch: QLearningBatch,
                      key
                      ):
            # Separate keys so the quantile fractions and the BC action
            # sampling do not consume the same random stream.
            tau_key, bc_key = jax.random.split(key)

            def loss_fn(params):
                taus = jax.random.uniform(tau_key, shape=(batch.actions.shape[0], self.actor_sampling_quantiles))
                taus = risk_measure(taus, risk_eta)
                # BC actions are inputs to the actor only; block gradients
                # into the BC model.
                bc_actions = jax.lax.stop_gradient(
                    bc_actor_train_state.apply_fn({"params": bc_actor_train_state.params},
                                                  batch.observations,
                                                  rngs={"rng_stream": bc_key}
                                                  ))

                actions = actor_train_state.apply_fn({'params': params}, batch.observations, bc_actions)
                q_values = critic_train_state.apply_fn({'params': critic_train_state.params, },
                                                       batch.observations, actions, taus)

                # Negated so that minimising the loss maximises the critic
                # output at the risk-distorted fractions (mean over axis -2,
                # min over the last axis, mean over the batch).
                policy_kl = -(q_values.mean(axis=-2).min(axis=-1)).mean()

                return policy_kl, {"policy_kl": policy_kl}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(actor_train_state.params)
            state = actor_train_state.apply_gradients(grads=grads)
            target_param = optax.incremental_update(state.params,
                                                    actor_train_state.target_params, target_update_rate)
            state = state.replace(target_params=target_param)
            return state, loss_info

        return update_fn

    def build_bc_actor_update_fn(self) -> Callable:
        """Return the jitted behaviour-cloning actor update.

        The returned function takes ``(bc_actor_train_state, batch, key)`` and
        returns the updated state together with the info produced by the
        state's ``loss_fn``.
        """
        @jax.jit
        def update_fn(bc_actor_train_state: VAETrainState,
                      batch: QLearningBatch, key
                      ):
            def loss_fn(param_bc_actor):
                loss, loss_info = bc_actor_train_state.loss_fn({"params": param_bc_actor},
                                                               batch.observations, batch.actions,
                                                               self.bc_vae_beta,
                                                               rngs={"rng_stream": key})

                return loss, loss_info

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(bc_actor_train_state.params)
            state = bc_actor_train_state.apply_gradients(grads=grads)
            return state, loss_info

        return update_fn

    def train_step(self, ) -> dict:
        """Sample one batch, run one policy update, and return its loss dict."""
        batch = self.buffer.sample(self.batch_size)
        self.train_cnt += 1
        loss_dict = self.policy.update(batch, next(self.hk_rng), self.train_cnt, 1, True)

        return loss_dict
