from rl.base_offline import BaseOffline
from rl.sac import SAC
from rl.policies.policy import CQLPolicy, RLTrainState, TrainState
from rl.utils.replay_buffer import ReplayBuffer, QLearningBatch

import jax
import jax.numpy as jnp
import gymnasium as gym
from typing import Callable
from optax import adam
import optax


class CQL(SAC, BaseOffline):
    """Conservative Q-Learning (CQL) agent built on the SAC implementation.

    Compared with SAC, the critic loss gets an extra conservative term,
    ``cql_coef * (logsumexp_a Q(s, a) - Q(s, a_data))``, where the logsumexp
    is estimated from actions sampled by the current policy (at ``s`` and at
    ``s'``) and from uniformly random actions in ``[-1, 1)``. ``cql_coef`` is
    read from a Lagrange-multiplier train state that is either updated by the
    function from ``build_lag_coef_update_fn`` (``auto_adjust_kl=True``) or
    left unchanged.
    """
    policy: CQLPolicy

    def __init__(self,
                 env: gym.Env,
                 buffer: ReplayBuffer,
                 gamma: float = 0.99,
                 batch_size: int = 256,
                 opt_class: Callable = adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 target_update_rate: float = 0.005,
                 target_entropy: float | str = 'auto',
                 learn_ent_coef: bool = False,
                 ent_coef: float = 0.2,
                 target_diff: float = 10,
                 auto_adjust_kl: bool = False,
                 seed: int = 42,
                 n_action_samples: int = 10,
                 ):
        """Create a CQL agent.

        Args:
            env: Environment, forwarded to the base-class constructor.
            buffer: Offline dataset; stored on ``self.buffer``.
            gamma: Discount factor used in the TD target.
            batch_size: Forwarded to the base-class constructor.
            opt_class: Optimizer factory, forwarded to ``CQLPolicy``.
            learning_rate: Optimizer learning rate, forwarded to ``CQLPolicy``.
            n_critics: Number of Q-functions in the critic ensemble.
            target_update_rate: Polyak step size used for both the critic
                and the actor target parameters.
            target_entropy: Forwarded to the base-class constructor.
            learn_ent_coef: Forwarded to ``CQLPolicy``.
            ent_coef: Entropy coefficient, forwarded to ``CQLPolicy``.
            target_diff: Target offset used by the Lagrange-multiplier update
                of the CQL coefficient (see ``build_lag_coef_update_fn``).
            auto_adjust_kl: If True the CQL coefficient takes gradient steps;
                otherwise its train state is returned unchanged.
            seed: Random seed, forwarded to the base-class constructor.
            n_action_samples: Number of actions drawn per state from each of
                the three proposals (policy at s, policy at s', uniform) when
                estimating the logsumexp term. Defaults to 10.

        Raises:
            ValueError: If ``n_action_samples`` is smaller than 1.
        """
        if n_action_samples < 1:
            raise ValueError(f"n_action_samples must be >= 1, got {n_action_samples}")
        # Assigned before super().__init__, presumably because the base
        # constructor calls build(), which reads them. TODO confirm.
        self.auto_adjust_kl: bool = auto_adjust_kl
        self.target_diff = target_diff
        self.n_action_samples: int = n_action_samples
        super().__init__(
            env,
            gamma,
            1,  # NOTE(review): meaning of this positional SAC argument is not visible here — confirm
            batch_size,
            opt_class,
            learning_rate,
            n_critics,
            target_update_rate,
            target_entropy,
            learn_ent_coef,
            ent_coef,
            seed
        )
        self.buffer = buffer

    def build(self, ):
        """Create ``self.policy`` wired with the CQL-specific jitted update functions."""
        observation_ph, action_ph = self.make_placeholder()
        self.policy = CQLPolicy(observation_ph, action_ph,
                                opt_class=self.opt_class, learning_rate=self.learning_rate,
                                seed=next(self.hk_rng),
                                n_critics=self.n_critics,
                                auto_adjust_kl=self.auto_adjust_kl,

                                ent_coef=self.ent_coef,
                                learn_ent_coef=self.learn_ent_coef,

                                critic_update_fn=self.build_critic_update_fn(),
                                actor_update_fn=self.build_actor_update_fn(),
                                # NOTE(review): the entropy-coefficient update is built from
                                # target_diff, not from a target entropy — confirm this is intended.
                                ent_coef_update_fn=self.build_ent_coef_update_fn(self.target_diff),
                                lag_update_fn=self.build_lag_coef_update_fn(self.target_diff),
                                )

    def build_lag_coef_update_fn(self, target_diff):
        """Return a jitted ``(lag_train_state, diff) -> (lag_train_state, info)`` function.

        The train state's ``apply_fn(params)`` yields the log of the CQL
        coefficient. With ``auto_adjust_kl`` the coefficient (clipped to
        ``[0, 1e6]``) takes one gradient step on
        ``-(lag_coef * stop_gradient(diff + target_diff)).mean()``; otherwise
        the state is returned unchanged together with the current coefficient.
        """
        if not self.auto_adjust_kl:
            @jax.jit
            def dummy_fn(lag_train_state: TrainState, diff):
                lag = jnp.exp(lag_train_state.apply_fn(lag_train_state.params))
                return lag_train_state, {"lag_coef": lag}

            return dummy_fn

        @jax.jit
        def update_fn(lag_train_state: TrainState, diff):
            def loss_fn(params):
                log_lag_coef = lag_train_state.apply_fn(params)
                lag_coef = jnp.exp(log_lag_coef).clip(0, 1e+6)
                # NOTE(review): the usual CQL-Lagrange dual loss is -(lag * (gap - target));
                # this uses (diff + target_diff). Whether that is right depends on what the
                # caller passes as `diff` — confirm the sign.
                loss = -(lag_coef * jax.lax.stop_gradient(diff + target_diff)).mean()
                return loss, {"lag_loss": loss, "lag_coef": jnp.exp(log_lag_coef)}

            grads, items = jax.grad(loss_fn, has_aux=True)(lag_train_state.params)
            new_lag_state = lag_train_state.apply_gradients(grads=grads)
            return new_lag_state, items

        return update_fn

    def build_critic_update_fn(self, ):
        """Return a jitted function performing one CQL critic step.

        The returned function takes ``(critic_train_state, actor_train_state,
        ent_coef_state, lag_coef_state, batch, key)`` and returns
        ``(critic_train_state, info)``. It applies one gradient step on
        ``td_error + cql_loss`` and then Polyak-averages the critic target
        parameters with ``target_update_rate``.
        """
        gamma = self.gamma
        target_update_rate = self.target_update_rate
        n_samples = self.n_action_samples

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      actor_train_state: RLTrainState,
                      ent_coef_state: TrainState,
                      lag_coef_state: TrainState,
                      batch: QLearningBatch, key):
            def loss_fn(param_critic):
                batch_size = batch.observations.shape[0]
                current_q = critic_train_state.apply_fn({"params": param_critic},
                                                        batch.observations, batch.actions).reshape(batch_size, -1)
                ent_coef = jax.lax.stop_gradient(jnp.exp(ent_coef_state.apply_fn(ent_coef_state.params)))
                cql_coef = jax.lax.stop_gradient(
                    jnp.exp(lag_coef_state.apply_fn(lag_coef_state.params)).clip(0, 1e+6))

                # One independent sub-key per stochastic draw; the parent key is
                # not reused after being split.
                key_target, key_cur, key_next, key_rand = jax.random.split(key, 4)

                ################ compute td_target ##################
                next_actions, next_log_probs = actor_train_state.apply_fn(
                    {"params": actor_train_state.target_params},
                    batch.next_observations,
                    rngs={"rng_stream": key_target},
                )
                next_q = critic_train_state.apply_fn({"params": critic_train_state.target_params},
                                                     batch.next_observations, next_actions,
                                                     ).reshape(batch_size, -1)
                # Minimum over the critic ensemble, minus the entropy term.
                next_q = next_q.min(axis=-1, keepdims=True) - ent_coef * next_log_probs
                td_target = jax.lax.stop_gradient(batch.rewards + gamma * (1 - batch.dones) * next_q)

                # Squared TD error summed over critics, averaged over the batch.
                td_error = ((td_target - current_q) ** 2).sum(axis=-1).mean()

                ############ conservative penalty (Donsker-Varadhan style) ############
                # Tile every state n_samples times along a new axis 1.
                repeated_observation = jnp.repeat(batch.observations[:, None], axis=1, repeats=n_samples)
                repeated_next_observation = jnp.repeat(batch.next_observations[:, None], axis=1,
                                                       repeats=n_samples)

                def sample_actions(obs, k):
                    # Online-actor samples; stop_gradient keeps them constant in this loss.
                    return jax.lax.stop_gradient(actor_train_state.apply_fn({"params": actor_train_state.params},
                                                                            obs, rngs={"rng_stream": k}))

                # Map over the sample axis (axis 1) with a distinct key per sample.
                batched_sample = jax.vmap(sample_actions, in_axes=(1, 0), out_axes=(1, 1))
                current_action, current_log_pi = batched_sample(repeated_observation,
                                                                jax.random.split(key_cur, n_samples))
                next_action, next_log_pi = batched_sample(repeated_next_observation,
                                                          jax.random.split(key_next, n_samples))
                rand_action = 2 * jax.random.uniform(key_rand, shape=current_action.shape) - 1

                def critic_fn(obs, action):
                    return critic_train_state.apply_fn({"params": param_critic},
                                                       obs, action).reshape(batch_size, -1)

                q_fn = jax.vmap(critic_fn, in_axes=(1, 1), out_axes=1)
                # Each estimate subtracts its proposal's log-density (importance weighting).
                # NOTE(review): actions sampled at s' are evaluated at s — presumably intentional; confirm.
                q_cur = q_fn(repeated_observation, current_action) - current_log_pi
                q_next = q_fn(repeated_observation, next_action) - next_log_pi
                # A uniform density on [-1, 1]^d has log-density d * log(0.5), d = last action axis.
                q_rand = q_fn(repeated_observation, rand_action) - jnp.log(0.5) * rand_action.shape[-1]
                cat_q = jnp.concatenate([q_cur, q_next, q_rand], axis=1)  # (batch, 3 * n_samples, n_critics)
                # logsumexp (not log-mean) over samples: differs from log E_Q[exp g] by the
                # constant log(3 * n_samples); no effect on the critic gradient.
                logsumexp_q = jax.nn.logsumexp(cat_q, axis=1)
                ############## D_{KL}[P, Q] = sup_g E_P[g(X)] - log E_(X~Q)[exp(g(X))] #######
                kl_potential = logsumexp_q.mean(axis=0) - current_q.mean(axis=0)  # per critic
                cql_loss = cql_coef * kl_potential.sum(axis=-1)
                loss = td_error + cql_loss
                return loss, {"td": td_error, "cql_loss": cql_loss, "kl_potential": kl_potential.mean()}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(critic_train_state.params)
            state = critic_train_state.apply_gradients(grads=grads)
            # Polyak update of the critic target; already inside jit, so no nested jit is needed.
            critic_target_params = optax.incremental_update(
                state.params, critic_train_state.target_params, target_update_rate)
            state = state.replace(target_params=critic_target_params)
            return state, loss_info

        return update_fn

    def build_actor_update_fn(self):
        """Return a jitted function performing one SAC-style actor step.

        The returned function takes ``(critic_train_state, actor_train_state,
        ent_coef_state, batch, key)`` and returns ``(actor_train_state,
        info)``. It minimises ``mean(ent_coef * log_pi - min_i Q_i(s, a))`` for
        a ~ pi(.|s), then Polyak-averages the actor target parameters.
        """
        target_update_rate = self.target_update_rate

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      actor_train_state: RLTrainState,
                      ent_coef_state: TrainState,
                      batch: QLearningBatch,
                      key: jax.Array,
                      ):
            def loss_fn(params):
                batch_size = batch.observations.shape[0]
                actions, log_probs = actor_train_state.apply_fn({'params': params}, batch.observations,
                                                                rngs={"rng_stream": key})

                # Same (batch, n_critics) layout the critic update uses; keepdims so the
                # subtraction stays (batch, 1) instead of broadcasting to (batch, batch).
                q_values = critic_train_state.apply_fn({'params': critic_train_state.params},
                                                       batch.observations, actions).reshape(batch_size, -1)
                min_q = q_values.min(axis=-1, keepdims=True)
                ent_coef = jax.lax.stop_gradient(jnp.exp(ent_coef_state.apply_fn(ent_coef_state.params)))
                loss = (ent_coef * log_probs.reshape(batch_size, -1) - min_q).mean()
                return loss, {"pi_loss": loss, "log_probs": log_probs}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(actor_train_state.params)
            state = actor_train_state.apply_gradients(grads=grads)
            # Polyak update of the actor target; already inside jit, so no nested jit is needed.
            actor_target_param = optax.incremental_update(
                state.params, actor_train_state.target_params, target_update_rate)
            state = state.replace(target_params=actor_target_param)
            return state, loss_info

        return update_fn

    def __str__(self):
        """Return the algorithm's display name."""
        return "CQL"

