import numpy as np
import jax
from rl.base import OffPolicy
from rl.policies import TD3Policy, RLTrainState
import gymnasium as gym
from optax import adam

from typing import Callable, Optional
from rl.utils.replay_buffer import QLearningBatch
from collections import deque
import optax


class DDPG(OffPolicy):
    """Off-policy actor-critic agent built on ``TD3Policy``.

    Despite the class name, the update rules below are TD3-style:

    * The critic target uses the minimum over the ``n_critics`` target critic
      outputs.
    * Clipped Gaussian noise is added to the target action (target policy
      smoothing).
    * ``policy_delay`` is forwarded to ``TD3Policy.update`` on every training
      step. Presumably this updates the actor less often than the critics;
      confirm in ``TD3Policy``.
    """

    policy: TD3Policy

    def __init__(self,
                 env: gym.Env,
                 gamma: float = 0.99,
                 buffer_capacity: int = 1_000_000,
                 batch_size: int = 256,
                 opt_class: Callable = adam,
                 learning_rate: float = 3e-4,
                 n_critics: int = 2,
                 policy_delay: int = 2,
                 target_update_rate: float = 0.005,
                 seed: int = 42,
                 target_noise: float = 0.3,
                 target_noise_clip: float = 0.5,
                 ):
        """Create the agent and build its policy and jitted update functions.

        Args:
            env: Training environment. It is used with the gymnasium API:
                ``step`` returns a 5-tuple and ``reset`` returns ``(obs, info)``.
            gamma: Discount factor used in the TD target.
            buffer_capacity: Replay buffer capacity, forwarded to ``OffPolicy``.
            batch_size: Number of transitions sampled per training step.
            opt_class: Optimizer factory passed to ``TD3Policy``
                (default: optax ``adam``).
            learning_rate: Learning rate passed to ``TD3Policy``.
            n_critics: Number of critics. Both the TD target and the actor
                loss use the minimum over them.
            policy_delay: Forwarded to ``TD3Policy.update`` on every step.
            target_update_rate: Polyak coefficient ``tau`` for soft target
                updates: ``target <- tau * online + (1 - tau) * target``.
            seed: Random seed, forwarded to ``OffPolicy``.
            target_noise: Standard deviation of the Gaussian noise added to
                the target action when computing the TD target.
            target_noise_clip: The target-action noise is clipped to
                ``[-target_noise_clip, target_noise_clip]``.
        """
        super().__init__(env, gamma, learning_rate,
                         buffer_capacity, batch_size, seed=seed)
        self.learning_rate = learning_rate
        self.opt_class = opt_class
        self.n_critics = n_critics
        self.target_update_rate = target_update_rate
        # These must be set before build(), which closes over them.
        self.target_noise = target_noise
        self.target_noise_clip = target_noise_clip
        self.mean_score = deque(maxlen=100)  # returns of the last 100 training episodes
        self.train_cnt = 0  # number of train_step calls so far
        self.policy_delay = policy_delay
        self.build()

    def build(self):
        """Instantiate ``self.policy`` with freshly built jitted update functions."""
        # NOTE(review): presumably these are dummy inputs used to initialise
        # network parameters; confirm in OffPolicy.make_placeholder.
        observation_ph, action_ph = self.make_placeholder()
        self.policy = TD3Policy(observation_ph, action_ph,
                                opt_class=self.opt_class, learning_rate=self.learning_rate,
                                seed=next(self.hk_rng),
                                n_critics=self.n_critics,
                                critic_update_fn=self.build_critic_update_fn(),
                                actor_update_fn=self.build_actor_update_fn())

    @property
    def params_actor(self):
        """Current (online) actor parameters."""
        return self.policy.actor_train_state.params

    @property
    def params_critic(self):
        """Current (online) critic parameters."""
        return self.policy.critic_train_state.params

    def predict(self, observation, explore: bool = False,
                *args, **kwargs):
        """Return an action for ``observation`` as a NumPy array.

        A single un-batched (1-D) observation gets a leading batch axis before
        the policy is queried. Only that batch axis is removed from the result.
        ``squeeze()`` would also drop a size-1 action dimension, turning a
        ``Box(shape=(1,))`` action into a 0-d array.

        Args:
            observation: One 1-D observation or a batch of observations.
            explore: If True, query ``policy.explore`` with a fresh PRNG key.
                Otherwise use ``policy.predict``.
        """
        is_single = len(observation.shape) == 1
        if is_single:
            observation = observation[None]

        if explore:
            action = self.policy.explore(observation, key=next(self.hk_rng))
        else:
            action = self.policy.predict(observation)
        if is_single:
            action = action[0]  # remove only the batch axis added above
        return np.asarray(action)

    def learn(self, steps: int = 1_000_000,
              learning_start: int = 100,
              eval_env: Optional[gym.Env] = None,
              eval_interval: int = 100,
              log_interval: int = 4,
              ):
        """Interact with ``self.env`` for ``steps`` steps, training along the way.

        On each step the agent:

        1. Acts with exploration.
        2. Stores the transition in the replay buffer.
        3. Runs one ``train_step`` once ``e > learning_start``.

        At episode end the environment is reset and episode statistics are
        recorded. The logger is dumped every ``log_interval`` episodes.

        If ``eval_env`` is given, a 10-episode evaluation without exploration
        is recorded once each time the episode counter is a multiple of
        ``eval_interval``. This includes the very first step, when the counter
        is 0.

        Only true termination is written to the buffer as the done flag. A
        time-limit truncation ends the episode but is not a real terminal
        state, so the critic must keep bootstrapping through it.
        """
        from time import perf_counter  # monotonic clock for FPS measurement

        score = 0
        start_frame = 0
        start_time = perf_counter()
        epi_cnt = 0
        eval_flag = False
        for e in range(steps):
            action = self.predict(self.last_observation, explore=True)
            next_observation, reward, terminated, truncated, _ = self.env.step(action)
            next_observation = next_observation.copy()
            # Store `terminated`, not `terminated or truncated`. The stored flag
            # presumably becomes `batch.dones`, which masks bootstrapping in
            # the TD target.
            self.buffer.add(self.last_observation,
                            action, reward, next_observation, terminated)
            self.last_observation = next_observation
            score += reward
            if e > learning_start:
                losses = self.train_step()
                for k, v in losses.items():
                    self.logger.record_mean("loss/" + k, v.mean())

            if terminated or truncated:
                end_frame = e
                end_time = perf_counter()
                elapsed = end_time - start_time
                fps = (end_frame - start_frame) / elapsed if elapsed > 0 else 0.0
                start_frame = end_frame
                start_time = end_time
                self.last_observation, _ = self.env.reset()
                self.mean_score.append(score)
                self.logger.record("episode/score", score)
                self.logger.record("episode/mean_100score", np.mean(self.mean_score))
                self.logger.record("time/fps", fps)
                self.logger.record("time/step", e)

                score = 0
                epi_cnt += 1
                eval_flag = False  # re-arm evaluation for the new episode count
                if epi_cnt % log_interval == 0:
                    self.logger.dump()

            # Evaluate at most once per qualifying episode count.
            if epi_cnt % eval_interval == 0 and not eval_flag:
                eval_flag = True
                if eval_env is not None:
                    self.logger.record("eval/eval_score", self._evaluate(eval_env))

    def _evaluate(self, eval_env: gym.Env, n_episodes: int = 10) -> float:
        """Run ``n_episodes`` episodes on ``eval_env`` without exploration; return the mean return."""
        scores = []
        for _ in range(n_episodes):
            obs, _ = eval_env.reset()
            episode_return = 0
            done = False
            while not done:
                obs, reward, terminated, truncated, _ = eval_env.step(self.predict(obs))
                episode_return += reward
                done = terminated or truncated
            scores.append(episode_return)
        return np.mean(scores)

    def build_actor_update_fn(self):
        """Build the jitted actor update: one gradient step plus a soft target update.

        The returned function takes ``(critic_train_state, actor_train_state,
        batch)`` and returns ``(new_actor_state, {"pi_loss": loss})``.
        """
        target_update_rate = self.target_update_rate

        @jax.jit
        def update_fn(critic_train_state: RLTrainState,
                      actor_train_state: RLTrainState,
                      batch: QLearningBatch):
            def loss_fn(params):
                actions = actor_train_state.apply_fn({'params': params}, batch.observations)
                q_values = critic_train_state.apply_fn({'params': critic_train_state.params},
                                                       batch.observations, actions)
                # Maximise the pessimistic (minimum over critics) Q estimate.
                loss = -(q_values.min(axis=-1)).mean()
                return loss, {"pi_loss": loss}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(actor_train_state.params)
            state = actor_train_state.apply_gradients(grads=grads)
            # Polyak averaging: tau * online + (1 - tau) * target. We are
            # already inside jax.jit, so no nested jit is needed.
            target_params = optax.incremental_update(
                state.params, actor_train_state.target_params, target_update_rate)
            state = state.replace(target_params=target_params)
            return state, loss_info

        return update_fn

    def build_critic_update_fn(self):
        """Build the jitted critic update (min-over-critics TD regression).

        The returned function takes ``(critic_train_state, actor_train_state,
        batch, key)`` and returns ``(new_critic_state, {"q_loss": loss})``.
        ``key`` is the JAX PRNG key for the target-action noise.
        """
        gamma = self.gamma
        target_update_rate = self.target_update_rate
        target_noise = self.target_noise
        noise_clip = self.target_noise_clip

        @jax.jit
        def update_fn(critic_train_state: RLTrainState, actor_train_state: RLTrainState,
                      batch: QLearningBatch, key):
            def loss_fn(param_critic):
                q_value = critic_train_state.apply_fn({'params': param_critic},
                                                      batch.observations, batch.actions)
                next_action = jax.lax.stop_gradient(
                    actor_train_state.apply_fn({'params': actor_train_state.target_params},
                                               batch.next_observations))

                # Target policy smoothing: scale the noise first, then clip it.
                # Clipping N(0, 1) before scaling would bound the noise at
                # target_noise * noise_clip instead of noise_clip.
                noise = (target_noise * jax.random.normal(key, shape=next_action.shape)
                         ).clip(-noise_clip, noise_clip)
                # NOTE(review): assumes the actor's action range is [-1, 1]; confirm.
                next_action = (next_action + noise).clip(-1., 1.)

                next_q_value = critic_train_state.apply_fn({'params': critic_train_state.target_params},
                                                           batch.next_observations, next_action)
                next_q_value = jax.lax.stop_gradient(next_q_value.min(axis=-1))
                td_target = jax.lax.stop_gradient(
                    batch.rewards + gamma * (1 - batch.dones) * next_q_value)

                # NOTE(review): assumes the critic output has a singleton axis 1,
                # e.g. (batch, 1, n_critics), leaving one column per critic; confirm.
                q_value = q_value.squeeze(axis=1)

                mse = ((td_target - q_value) ** 2)
                # Sum the squared errors over critics, then average over the batch.
                mse = mse.sum(axis=-1).mean()
                return mse, {"q_loss": mse}

            grads, loss_info = jax.grad(loss_fn, has_aux=True)(critic_train_state.params)
            state = critic_train_state.apply_gradients(grads=grads)
            # Polyak averaging: tau * online + (1 - tau) * target.
            critic_target_params = optax.incremental_update(
                state.params, critic_train_state.target_params, target_update_rate)
            state = state.replace(target_params=critic_target_params)
            return state, loss_info

        return update_fn

    def train_step(self) -> dict:
        """Sample one batch and run a single ``TD3Policy.update``; return its loss dict."""
        batch = self.buffer.sample(self.batch_size)
        self.train_cnt += 1
        return self.policy.update(batch, next(self.hk_rng), self.train_cnt, self.policy_delay)
