from rl.policies.policy import SACPolicy
from rl.ddpg import DDPG
import numpy as np
import jax
import jax.numpy as jnp
from rl.policies import RLTrainState, TrainState
import gymnasium as gym
from optax import adam

from typing import Callable
from rl.utils.replay_buffer import QLearningBatch
import optax


class SAC(DDPG):
    """Soft Actor-Critic agent layered on the DDPG training infrastructure.

    Construction is delegated to ``DDPG``; this class only swaps in a
    ``SACPolicy`` whose critic, actor and entropy-coefficient updates are the
    jitted closures built by the ``build_*_update_fn`` methods below.

    NOTE(review): ``self.make_placeholder``, ``self.hk_rng``, ``self.buffer``,
    ``self.train_cnt``, ``self.gamma``, ``self.opt_class``, ``self.learning_rate``,
    ``self.n_critics`` and ``self.target_update_rate`` are used here but not
    defined in this class; presumably they come from ``DDPG`` — confirm.
    """
    policy: SACPolicy

    def __init__(self,
                 env: gym.Env,
                 gamma: float = 0.99,
                 buffer_capacity: int = 1_000_000,
                 batch_size: int = 256,
                 opt_class: Callable = adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 target_update_rate: float = 0.005,
                 target_entropy: float | str = 'auto',
                 learn_ent_coef: bool = True,
                 ent_coef: float = 1e-2,
                 seed: int = 42,
                 ):
        """Configure the SAC-specific entropy settings, then defer to DDPG.

        Args:
            env: Gymnasium environment; its ``action_space.shape`` is used for
                the ``'auto'`` target entropy.
            gamma: Discount factor used in the TD target.
            buffer_capacity: Forwarded to ``DDPG``.
            batch_size: Number of transitions sampled per ``train_step``.
            opt_class: Optimizer factory forwarded to ``SACPolicy``.
            learning_rate: Learning rate forwarded to ``SACPolicy``.
            n_critics: Number of critic heads forwarded to ``SACPolicy``.
            target_update_rate: Polyak step size for the critic target update.
            target_entropy: ``'auto'`` for ``-prod(action_space.shape)``
                (as float32), or a number used as-is.
            learn_ent_coef: If True, the entropy coefficient is trained;
                otherwise it is held fixed.
            ent_coef: Entropy coefficient passed to ``SACPolicy`` (presumably
                its initial value when learned — confirm in SACPolicy).
            seed: Random seed forwarded to ``DDPG``.

        Raises:
            ValueError: If ``target_entropy`` is a string other than ``'auto'``.
        """
        if isinstance(target_entropy, str):
            if target_entropy != 'auto':
                raise ValueError(
                    f"target_entropy must be a number or 'auto', got {target_entropy!r}")
            # Common SAC heuristic: target entropy = -(action dimensionality).
            target_entropy = -np.prod(env.action_space.shape, dtype=np.float32)
        # Set before super().__init__: build() reads these attributes, and
        # DDPG.__init__ presumably invokes build() — confirm.
        self.target_entropy = target_entropy
        self.learn_ent_coef = learn_ent_coef
        self.ent_coef = ent_coef
        super().__init__(env=env,
                         gamma=gamma,
                         buffer_capacity=buffer_capacity,
                         batch_size=batch_size,
                         opt_class=opt_class,
                         learning_rate=learning_rate,
                         n_critics=n_critics,
                         target_update_rate=target_update_rate,
                         seed=seed
                         )

    def build(self):
        """Instantiate ``self.policy`` as a SACPolicy wired to the jitted update functions."""
        observation_ph, action_ph = self.make_placeholder()
        self.policy = SACPolicy(observation_ph, action_ph,
                                opt_class=self.opt_class, learning_rate=self.learning_rate,
                                seed=next(self.hk_rng),
                                n_critics=self.n_critics,
                                ent_coef=self.ent_coef,
                                learn_ent_coef=self.learn_ent_coef,
                                ent_coef_update_fn=self.build_ent_coef_update_fn(self.target_entropy),
                                critic_update_fn=self.build_critic_update_fn(),
                                actor_update_fn=self.build_actor_update_fn())

    def build_critic_update_fn(self):
        """Return a jitted critic step: one gradient update plus a Polyak target update.

        The returned ``update_fn(critic_train_state, actor_train_state,
        ent_coef_state, batch, key)`` returns ``(new_critic_state, {"q_loss": loss})``.
        """
        gamma = self.gamma
        target_update_rate = self.target_update_rate

        @jax.jit
        def update_fn(critic_train_state: RLTrainState, actor_train_state: RLTrainState,
                      ent_coef_state: TrainState,
                      batch: QLearningBatch, key: jax.Array):
            def loss_fn(param_critic):
                q_value = critic_train_state.apply_fn({'params': param_critic}, batch.observations, batch.actions)

                # Sample next actions from the current policy; the actor is not trained here.
                next_action, next_log_prob = jax.lax.stop_gradient(
                    actor_train_state.apply_fn({'params': actor_train_state.params},
                                               batch.next_observations,
                                               rngs={"rng_stream": key}
                                               ))
                # Evaluate with the *target* critic and take the min over the last axis
                # (presumably the critic-ensemble axis — confirm in SACPolicy).
                next_q_value = critic_train_state.apply_fn({'params': critic_train_state.target_params},
                                                           batch.next_observations, next_action)
                next_q_value = jax.lax.stop_gradient(next_q_value.min(axis=-1))
                # apply_fn yields log(ent_coef); it is exponentiated and treated as a constant.
                ent_coef = jax.lax.stop_gradient(jnp.exp(ent_coef_state.apply_fn(ent_coef_state.params)))

                # Soft value of the next state: min Q(s', a') - alpha * log pi(a'|s').
                next_q_value = next_q_value - ent_coef * next_log_prob

                td_target = jax.lax.stop_gradient(batch.rewards + gamma * (1 - batch.dones) * next_q_value)
                q_value = q_value.squeeze(axis=1)

                # Squared TD error, summed over the last axis, then averaged.
                mse = ((td_target - q_value) ** 2).sum(axis=-1).mean()
                return mse, {"q_loss": mse}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(critic_train_state.params)
            state = critic_train_state.apply_gradients(grads=grads)
            # Polyak averaging: target <- rate * online + (1 - rate) * target.
            # We are already inside jax.jit, so no nested jit is needed here.
            critic_target_params = optax.incremental_update(
                state.params, critic_train_state.target_params, target_update_rate)
            state = state.replace(target_params=critic_target_params)
            return state, loss_info

        return update_fn

    def build_actor_update_fn(self):
        """Return a jitted actor step minimising ``alpha * log_pi - min Q``.

        The returned ``update_fn(critic_train_state, actor_train_state,
        ent_coef_state, batch, key)`` returns ``(new_actor_state, info)``, where
        ``info`` holds ``"pi_loss"`` and the sampled ``"log_probs"``.
        """

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      actor_train_state: RLTrainState,
                      ent_coef_state: TrainState,
                      batch: QLearningBatch,
                      key: jax.Array,
                      ):
            def loss_fn(params):
                actions, log_probs = actor_train_state.apply_fn({'params': params}, batch.observations,
                                                                rngs={"rng_stream": key})

                # Critic params are only read; jax.grad below differentiates actor params only.
                q_values = critic_train_state.apply_fn({'params': critic_train_state.params},
                                                       batch.observations, actions)
                ent_coef = jax.lax.stop_gradient(jnp.exp(ent_coef_state.apply_fn(ent_coef_state.params)))
                loss = (ent_coef * log_probs - q_values.min(axis=-1)).mean()
                # log_probs is surfaced presumably for the entropy-coefficient update — confirm in SACPolicy.
                return loss, {"pi_loss": loss, "log_probs": log_probs}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(actor_train_state.params)
            state = actor_train_state.apply_gradients(grads=grads)
            return state, loss_info

        return update_fn

    def build_ent_coef_update_fn(self, target_entropy):
        """Return a jitted entropy-coefficient update ``fn(ent_coef_state, log_probs)``.

        Both variants return ``(ent_coef_state, info)``, where ``info["ent_coef"]``
        is ``exp`` of the state's output. With ``self.learn_ent_coef`` False,
        the state is returned unchanged and ``log_probs`` is ignored.

        Args:
            target_entropy: Target entropy used in the coefficient loss.
        """
        if not self.learn_ent_coef:
            @jax.jit
            def dummy_fn(ent_coef_state: TrainState, log_probs):
                # Fixed coefficient: only report its current value.
                ent = jnp.exp(ent_coef_state.apply_fn(ent_coef_state.params))
                return ent_coef_state, {"ent_coef": ent}

            return dummy_fn

        @jax.jit
        def update_fn(ent_coef_state: TrainState, log_probs):
            def loss_fn(params):
                log_ent_coef = ent_coef_state.apply_fn(params)
                # Temperature loss: -log(alpha) * (log_pi + target_entropy);
                # only log(alpha) receives gradients.
                loss = (-log_ent_coef * jax.lax.stop_gradient(target_entropy + log_probs)).mean()
                return loss, {"ent_loss": loss, "ent_coef": jnp.exp(log_ent_coef)}

            grads, items = jax.grad(loss_fn, has_aux=True)(ent_coef_state.params)
            new_state = ent_coef_state.apply_gradients(grads=grads)
            return new_state, items

        return update_fn

    def train_step(self) -> dict:
        """Sample one batch, increment the train counter and run a policy update.

        Returns:
            The loss/info dict produced by ``self.policy.update``.
        """
        batch = self.buffer.sample(self.batch_size)
        self.train_cnt += 1
        # NOTE(review): the meaning of the trailing positional `1` is defined by
        # SACPolicy.update — confirm there.
        return self.policy.update(batch, next(self.hk_rng), self.train_cnt, 1)

