from rl.ddpg import DDPG
from rl.policies import TD3Policy, RLTrainState
from rl.utils.replay_buffer import ReplayBuffer, QLearningBatch

import jax
import jax.numpy as jnp
import gymnasium as gym
from typing import Callable, Optional
from optax import adam
from rl.base_offline import BaseOffline
import optax


class TD3PlusBC(DDPG, BaseOffline):
    """TD3 with a behaviour-cloning regulariser (TD3+BC) trained from a fixed buffer.

    The critic is trained with a clipped multi-critic TD target that uses
    target-policy smoothing; the actor minimises a normalised ``-Q`` term plus
    a mean-squared behaviour-cloning term towards the dataset actions.
    """
    policy: TD3Policy

    def __init__(self,
                 env: gym.Env,
                 buffer: ReplayBuffer,
                 normalizer: Optional[Callable] = False,
                 gamma: float = 0.99,
                 batch_size: int = 256,
                 opt_class: Callable = adam,
                 learning_rate: float = 3e-4,
                 q_learning_scale: float = 2.5,
                 n_critics: int = 2,
                 policy_delay: int = 2,
                 seed: int = 42,
                 ):
        """Create the agent.

        Args:
            env: Environment forwarded to the ``DDPG`` constructor.
            buffer: Replay buffer holding the offline dataset; stored on ``self.buffer``.
            normalizer: Optional callable stored on ``self.normalizer``.
                NOTE(review): the default is ``False`` rather than ``None``; it is
                kept as-is because the code consuming it is not visible here.
            gamma: Discount factor forwarded to ``DDPG``.
            batch_size: Batch size forwarded to ``DDPG``.
            opt_class: Optimiser factory forwarded to ``DDPG``.
            learning_rate: Learning rate forwarded to ``DDPG``.
            q_learning_scale: Weight of the normalised Q term in the actor loss.
            n_critics: Number of critics forwarded to ``DDPG``.
            policy_delay: Policy update delay forwarded to ``DDPG``.
            seed: Random seed forwarded to ``DDPG``.
        """
        # Assigned before the base constructors run: presumably they build the
        # update functions, which read this attribute — keep this ordering.
        self.q_learning_scale = q_learning_scale
        super().__init__(env,
                         gamma=gamma,
                         batch_size=batch_size,
                         opt_class=opt_class,
                         learning_rate=learning_rate,
                         n_critics=n_critics,
                         policy_delay=policy_delay,
                         seed=seed,
                         )
        BaseOffline.__init__(self)
        self.buffer = buffer
        self.normalizer = normalizer

    def build_critic_update_fn(self, ):
        """Build the jitted critic update.

        Returns:
            A function ``(critic_train_state, actor_train_state, batch, key)``
            returning the updated critic train state (parameters and
            soft-updated target parameters) and a dict with the keys
            ``"q_loss"``, ``"da"`` and ``"ds"``.
        """
        gamma = self.gamma
        target_update_rate = self.target_update_rate
        # Target-policy smoothing: std of the Gaussian noise and the bound
        # applied to the *scaled* noise.
        policy_noise = 0.2
        noise_clip = 0.5

        @jax.jit
        def update_fn(critic_train_state: RLTrainState, actor_train_state: RLTrainState, batch: QLearningBatch, key):
            def loss_fn(param_critic):

                def fn(s, a):
                    q_values = critic_train_state.apply_fn({'params': param_critic},
                                                           s, a)
                    return q_values.mean(), q_values

                # Gradients of the mean Q w.r.t. the inputs are diagnostics only;
                # the Q values themselves come back through the aux output.
                (g_s, g_a), q_value = jax.grad(fn, has_aux=True, argnums=(0, 1))(batch.observations, batch.actions)
                ds = (g_s ** 2).mean()
                # Reduced to a scalar so that it matches `ds`.
                da = (g_a ** 2).mean()

                next_action = jax.lax.stop_gradient(
                    actor_train_state.apply_fn({'params': actor_train_state.target_params}, batch.next_observations
                                               ))

                # Scale the noise first, then clip it. Clipping the unit normal
                # before scaling would bound the noise to
                # +/- policy_noise * noise_clip instead of +/- noise_clip.
                noise = (policy_noise * jax.random.normal(key, shape=next_action.shape)).clip(-noise_clip, noise_clip)
                next_action = (next_action + noise).clip(-1., 1.)

                next_q_value = critic_train_state.apply_fn({'params': critic_train_state.target_params},
                                                           batch.next_observations, next_action)
                # Minimum over the last axis (presumably the critic axis — TODO confirm).
                next_q_value = jax.lax.stop_gradient(next_q_value.min(axis=-1))
                td_target = jax.lax.stop_gradient(batch.rewards + gamma * (1 - batch.dones) * next_q_value)

                # Drops the singleton axis 1 of the critic output; assumes the
                # critic returns (batch, 1, n_critics) — TODO confirm.
                q_value = q_value.squeeze(axis=1)

                mse = ((td_target - q_value) ** 2)
                # Sum over the last axis, then average over the remaining ones.
                mse = mse.sum(axis=-1).mean()

                loss = mse
                return loss, {"q_loss": mse, "da": da, "ds": ds}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(critic_train_state.params)
            state = critic_train_state.apply_gradients(grads=grads)
            # Already inside a jitted function, so no nested jit is needed here.
            critic_target_params = optax.incremental_update(
                state.params, critic_train_state.target_params, target_update_rate)
            state = state.replace(target_params=critic_target_params)
            return state, loss_info
        return update_fn

    def build_actor_update_fn(self):
        """Build the jitted actor update.

        Returns:
            A function ``(critic_train_state, actor_train_state, batch, key)``
            returning the updated actor train state (parameters and
            soft-updated target parameters) and a dict with the keys
            ``"policy_kl"`` and ``"bc_loss"``. ``key`` is accepted but unused.
        """
        target_update_rate = self.target_update_rate
        q_learning_scale = self.q_learning_scale

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      actor_train_state: RLTrainState,
                      batch: QLearningBatch,
                      key: jax.Array
                      ):
            def loss_fn(params):
                actions = actor_train_state.apply_fn({'params': params},
                                                     batch.observations)

                q_values = critic_train_state.apply_fn({'params': critic_train_state.params},
                                                       batch.observations, actions)

                # Negated Q taken at index 0 of the last axis. It is reported
                # under the historical key "policy_kl".
                neg_q = -q_values[..., 0]
                # Normalise by the mean |Q| without differentiating through the
                # normaliser.
                scale = jax.lax.stop_gradient(jnp.abs(neg_q).mean())
                bc_loss = ((actions - batch.actions) ** 2).mean()
                loss = q_learning_scale * (neg_q / scale).mean() + bc_loss

                return loss, {"policy_kl": neg_q.mean(), "bc_loss": bc_loss}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(actor_train_state.params)
            state = actor_train_state.apply_gradients(grads=grads)
            # Already inside a jitted function, so no nested jit is needed here.
            target_param = optax.incremental_update(state.params,
                                                    actor_train_state.target_params,
                                                    target_update_rate)

            state = state.replace(target_params=target_param)
            return state, loss_info

        return update_fn

    def __str__(self):
        """Return the algorithm name."""
        return "TD3PlusBC"

