import os
import warnings
from typing import Any, Dict, Optional, Union

import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import rlax
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, EventCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization


def get_mc_return_with_entropy_on_reset(bias_eval_env, model, max_ep_len, gamma, n_mc_eval, n_mc_cutoff):
    # Based off of original REDQ analysis function and adapted for JAX.
    # https://github.com/watchernyu/REDQ/blob/45a28d3e0960e222d1864c2cc8dabc1af287839f/redq/utils/bias_utils.py#L6

    # since we want to also compute bias, so we need to
    final_mc_list = np.zeros(0)
    final_mc_entropy_list = np.zeros(0)
    final_obs_list = []
    final_act_list = []
    while final_mc_list.shape[0] < n_mc_eval:
        # we continue if haven't collected enough data
        o = bias_eval_env.reset()
        # temporary lists
        reward_list, log_prob_a_tilda_list, obs_list, act_list = [], [], [], []
        r, d, ep_ret, ep_len = 0, False, 0, 0
        for _ in range(max_ep_len):  # run an episode
            a, log_prob_a_tilda = model.policy.predict_action_with_logprobs(
                o,  # type: ignore[arg-type]
                deterministic=False,
            )

            obs_list.append(o)
            act_list.append(a)
            o, r, d, _ = bias_eval_env.step(a)
            ep_len += 1
            reward_list.append(r)
            log_prob_a_tilda_list.append(log_prob_a_tilda)
            if d or (ep_len == max_ep_len):
                break

        rewards = jnp.concatenate(reward_list)
        log_probs = jnp.concatenate(log_prob_a_tilda_list)
        alpha = model.current_entropy_coeff()

        @jax.jit
        def calc_disc_returns(rewards, alpha, log_probs):
            soft_rewards = rewards - alpha * jnp.concatenate([log_probs[:-1], jnp.zeros(1)])
            gammas = jnp.ones_like(rewards) * gamma
            discounted_return_list = rlax.discounted_returns(rewards, gammas, 0.0)
            discounted_return_with_entropy_list = rlax.discounted_returns(soft_rewards, gammas, 0.0)
            return discounted_return_list, discounted_return_with_entropy_list

        discounted_return_list, discounted_return_with_entropy_list = calc_disc_returns(rewards, alpha, log_probs)

        # now we take the first few of these.
        final_mc_list = jnp.concatenate((final_mc_list, discounted_return_list[:n_mc_cutoff]))
        final_mc_entropy_list = jnp.concatenate(
            (final_mc_entropy_list, discounted_return_with_entropy_list[:n_mc_cutoff])
        )
        final_obs_list += obs_list[:n_mc_cutoff]
        final_act_list += act_list[:n_mc_cutoff]

    final_obs_list = jnp.concatenate(final_obs_list)
    final_act_list = jnp.concatenate(final_act_list)

    q_prediction = jnp.mean(model.predict_critic(final_obs_list, final_act_list), axis=0)[..., 0]

    @jax.jit
    def calc_metrics(q_prediction, final_mc_entropy_list):
        bias = q_prediction - final_mc_entropy_list
        bias_abs = jnp.abs(bias)
        bias_squared = bias**2
        final_mc_entropy_list_normalize_base = jnp.abs(final_mc_entropy_list)
        final_mc_entropy_list_normalize_base = jnp.clip(final_mc_entropy_list_normalize_base, a_min=10.0)
        normalized_bias_per_state = bias / final_mc_entropy_list_normalize_base
        normalized_bias_sqr_per_state = bias_squared / final_mc_entropy_list_normalize_base
        return dict(
            MCDisRet=final_mc_list,
            MCDisRetEnt=final_mc_entropy_list,
            QPred=q_prediction,
            QBias=bias,
            QBiasAbs=bias_abs,
            QBiasSqr=bias_squared,
            NormQBias=normalized_bias_per_state,
            NormQBiasSqr=normalized_bias_sqr_per_state,
        )

    return calc_metrics(q_prediction, final_mc_entropy_list)


class CriticBiasCallback(EvalCallback):
    def __init__(
        self,
        eval_env: Union[gym.Env, VecEnv],
        jax_random_key_for_seeds: int,
        callback_on_new_best: Optional[BaseCallback] = None,
        callback_after_eval: Optional[BaseCallback] = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: Optional[str] = None,
        best_model_save_path: Optional[str] = None,
        deterministic: bool = True,
        render: bool = False,
        verbose: int = 1,
        warn: bool = True,
    ):
        super().__init__(
            eval_env,
            callback_on_new_best=callback_on_new_best,
            callback_after_eval=callback_after_eval,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            log_path=log_path,
            best_model_save_path=best_model_save_path,
            deterministic=deterministic,
            render=render,
            verbose=verbose,
            warn=warn,
        )

        self.final_mc_list = []
        self.final_mc_entropy_list = []
        self.q_prediction = []
        self.bias = []
        self.bias_abs = []
        self.bias_squared = []
        self.normalized_bias_per_state = []
        self.normalized_bias_sqr_per_stat = []

        # generate a list of 1M random integers, using a jax random key supplied in the args
        seed_list = jax.random.randint(jax.random.PRNGKey(jax_random_key_for_seeds), (10000000,), 0, 2**30 - 1)
        # cast to numpy
        self.seed_list = np.array(seed_list)

    def _on_step(self) -> bool:
        continue_training = True

        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            # reset the env with a new seed at the current timestep (reproducibilty)
            self.eval_env.seed(int(self.seed_list[self.n_calls]))
            # Sync training and eval env if there is VecNormalize
            if self.model.get_vec_normalize_env() is not None:
                try:
                    sync_envs_normalization(self.training_env, self.eval_env)
                except AttributeError as e:
                    raise AssertionError(
                        "Training and eval env are not wrapped the same way, "
                        "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
                        "and warning above."
                    ) from e

            # Reset success rate buffer
            self._is_success_buffer = []

            key, noise_key = self.model.policy.key, self.model.policy.noise_key
            metrics = get_mc_return_with_entropy_on_reset(
                self.eval_env, self.model, 1000, 0.99, n_mc_eval=1000, n_mc_cutoff=350
            )
            # Reset random key, to before evaluation.
            # This way we make sure that the randomness in the main training is not
            # influenced by the evaluation, how it should be.
            self.model.policy.key, self.model.policy.noise_key = key, noise_key

            if self.log_path is not None:
                kwargs = {}
                # Save success log if present
                if len(self._is_success_buffer) > 0:
                    self.evaluations_successes.append(self._is_success_buffer)
                    kwargs = dict(successes=self.evaluations_successes)

                np.savez(
                    self.log_path,
                    timesteps=self.evaluations_timesteps,
                    **metrics,
                    **kwargs,
                )

            for k, v in metrics.items():
                self.logger.record(f"Q/{k}_mean", jnp.mean(v).item())
                self.logger.record(f"Q/{k}_std", jnp.std(v).item())

            if len(self._is_success_buffer) > 0:
                success_rate = np.mean(self._is_success_buffer)
                if self.verbose >= 1:
                    print(f"Success rate: {100 * success_rate:.2f}%")
                self.logger.record("eval/success_rate", success_rate)

            # Dump log so the evaluation results are printed with the correct timestep
            self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
            self.logger.dump(self.num_timesteps)

            # Trigger callback after every evaluation, if needed
            if self.callback is not None:
                continue_training = continue_training and self._on_event()

        return continue_training


class EvalCallback(EventCallback):
    """
    Callback for evaluating an agent.

    .. warning::

      When using multiple environments, each call to  ``env.step()``
      will effectively correspond to ``n_envs`` steps.
      To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)``

    :param eval_env: The environment used for initialization
    :param callback_on_new_best: Callback to trigger
        when there is a new best model according to the ``mean_reward``
    :param callback_after_eval: Callback to trigger after every evaluation
    :param n_eval_episodes: The number of episodes to test the agent
    :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback.
    :param log_path: Path to a folder where the evaluations (``evaluations.npz``)
        will be saved. It will be updated at each evaluation.
    :param best_model_save_path: Path to a folder where the best model
        according to performance on the eval env will be saved.
    :param deterministic: Whether the evaluation should
        use a stochastic or deterministic actions.
    :param render: Whether to render or not the environment during evaluation
    :param verbose: Verbosity level: 0 for no output, 1 for indicating information about evaluation results
    :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
        wrapped with a Monitor wrapper)
    """

    def __init__(
        self,
        eval_env: Union[gym.Env, VecEnv],
        jax_random_key_for_seeds: int,
        callback_on_new_best: Optional[BaseCallback] = None,
        callback_after_eval: Optional[BaseCallback] = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: Optional[str] = None,
        best_model_save_path: Optional[str] = None,
        deterministic: bool = True,
        render: bool = False,
        verbose: int = 1,
        warn: bool = True,
    ):
        super().__init__(callback_after_eval, verbose=verbose)

        self.callback_on_new_best = callback_on_new_best
        if self.callback_on_new_best is not None:
            # Give access to the parent
            self.callback_on_new_best.parent = self

        self.n_eval_episodes = n_eval_episodes
        self.eval_freq = eval_freq
        self.best_mean_reward = -np.inf
        self.last_mean_reward = -np.inf
        self.deterministic = deterministic
        self.render = render
        self.warn = warn

        # Convert to VecEnv for consistency
        if not isinstance(eval_env, VecEnv):
            eval_env = DummyVecEnv([lambda: eval_env])

        self.eval_env = eval_env
        self.best_model_save_path = best_model_save_path
        # Logs will be written in ``evaluations.npz``
        if log_path is not None:
            log_path = os.path.join(log_path, "evaluations")
        self.log_path = log_path
        self.evaluations_results = []
        self.evaluations_timesteps = []
        self.evaluations_length = []
        # For computing success rate
        self._is_success_buffer = []
        self.evaluations_successes = []

        # generate a list of 1M random integers, using a jax random key supplied in the args
        seed_list = jax.random.randint(jax.random.PRNGKey(jax_random_key_for_seeds), (10000000,), 0, 2**30 - 1)
        # cast to numpy
        self.seed_list = np.array(seed_list)

    def _init_callback(self) -> None:
        # Does not work in some corner cases, where the wrapper is not the same
        if not isinstance(self.training_env, type(self.eval_env)):
            warnings.warn("Training and eval env are not of the same type" f"{self.training_env} != {self.eval_env}")

        # Create folders if needed
        if self.best_model_save_path is not None:
            os.makedirs(self.best_model_save_path, exist_ok=True)
        if self.log_path is not None:
            os.makedirs(os.path.dirname(self.log_path), exist_ok=True)

        # Init callback called on new best model
        if self.callback_on_new_best is not None:
            self.callback_on_new_best.init_callback(self.model)

    def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        """
        Callback passed to the  ``evaluate_policy`` function
        in order to log the success rate (when applicable),
        for instance when using HER.

        :param locals_:
        :param globals_:
        """
        info = locals_["info"]

        if locals_["done"]:
            maybe_is_success = info.get("is_success")
            if maybe_is_success is not None:
                self._is_success_buffer.append(maybe_is_success)

    def _on_step(self) -> bool:
        continue_training = True

        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            # reset the env with a new seed at the current timestep (reproducibilty)
            self.eval_env.seed(int(self.seed_list[self.n_calls]))

            # Sync training and eval env if there is VecNormalize
            if self.model.get_vec_normalize_env() is not None:
                try:
                    sync_envs_normalization(self.training_env, self.eval_env)
                except AttributeError as e:
                    raise AssertionError(
                        "Training and eval env are not wrapped the same way, "
                        "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
                        "and warning above."
                    ) from e

            # Reset success rate buffer
            self._is_success_buffer = []

            episode_rewards, episode_lengths = evaluate_policy(
                self.model,
                self.eval_env,
                n_eval_episodes=self.n_eval_episodes,
                render=self.render,
                deterministic=self.deterministic,
                return_episode_rewards=True,
                warn=self.warn,
                callback=self._log_success_callback,
            )

            if self.log_path is not None:
                self.evaluations_timesteps.append(self.num_timesteps)
                self.evaluations_results.append(episode_rewards)
                self.evaluations_length.append(episode_lengths)

                kwargs = {}
                # Save success log if present
                if len(self._is_success_buffer) > 0:
                    self.evaluations_successes.append(self._is_success_buffer)
                    kwargs = dict(successes=self.evaluations_successes)

                np.savez(
                    self.log_path,
                    timesteps=self.evaluations_timesteps,
                    results=self.evaluations_results,
                    ep_lengths=self.evaluations_length,
                    **kwargs,
                )

            mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
            mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
            self.last_mean_reward = mean_reward

            if self.verbose >= 1:
                print(
                    f"Eval num_timesteps={self.num_timesteps}, "
                    f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}"
                )
                print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
            # Add to current Logger
            self.logger.record("eval/mean_reward", float(mean_reward))
            self.logger.record("eval/mean_ep_length", mean_ep_length)

            if len(self._is_success_buffer) > 0:
                success_rate = np.mean(self._is_success_buffer)
                if self.verbose >= 1:
                    print(f"Success rate: {100 * success_rate:.2f}%")
                self.logger.record("eval/success_rate", success_rate)

            # Dump log so the evaluation results are printed with the correct timestep
            self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
            self.logger.dump(self.num_timesteps)

            if mean_reward > self.best_mean_reward:
                if self.verbose >= 1:
                    print("New best mean reward!")
                if self.best_model_save_path is not None:
                    self.model.save(os.path.join(self.best_model_save_path, "best_model"))
                self.best_mean_reward = mean_reward
                # Trigger callback on new best model, if needed
                if self.callback_on_new_best is not None:
                    continue_training = self.callback_on_new_best.on_step()

            # Trigger callback after every evaluation, if needed
            if self.callback is not None:
                continue_training = continue_training and self._on_event()

        return continue_training

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Update the references to the local variables.

        :param locals_: the local variables during rollout collection
        """
        if self.callback:
            self.callback.update_locals(locals_)
