import time
import warnings
from typing import Any, Dict, Optional, Union, Tuple, List

import gym
import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import DictRolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import VecEnv
from torch.nn import functional as F

from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean

from rl.ppo.buffers import RolloutBuffer
from rl.ppo.policies import PPOCustomPolicy


class PPO(OnPolicyAlgorithm):
    """Proximal Policy Optimization (clip variant) acting on learned observation embeddings.

    Differences from the stock Stable-Baselines3 PPO that are visible in this class:

    * Observations are passed through ``policy.extract_features`` (with the previous
      step's embedding as ``prev_slots``) and the rollout buffer stores these embeddings
      together with the embedding of the following step.
    * A world model (``policy.critic.transition_models[0]`` and
      ``policy.critic.reward_models[0]``) is regressed onto the next embedding and the
      reward. Its loss is added to the PPO loss when neither ``policy.v_critic`` nor
      ``policy.use_wm_optimizer`` is set. When ``policy.use_wm_optimizer`` is set, it is
      minimised separately with ``policy.optimizer_wm``.
    """

    def __init__(
        self,
        policy: PPOCustomPolicy,
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        rollout_buffer=None,
        transition_loss_coef: float = 1.0,
        reward_loss_coef: float = 1.0,
        _init_setup_model: bool = True,
    ):
        """
        :param policy: Policy *instance* (not a class); it is moved to ``self.device`` in
            ``_setup_model``.
        :param rollout_buffer: Optional pre-built rollout buffer. When ``None``, one is
            created in ``_setup_model``.
        :param transition_loss_coef: Weight of the next-embedding prediction loss in the
            world-model loss.
        :param reward_loss_coef: Weight of the reward prediction loss in the world-model loss.

        All other arguments have the same meaning as in Stable-Baselines3 ``PPO``.
        """
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
        )

        # Sanity check, otherwise it will lead to noisy gradient and NaN
        # because of the advantage normalization
        if normalize_advantage:
            assert (
                batch_size > 1
            ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

        if self.env is not None:
            # Check that `n_steps * n_envs > 1` to avoid NaN
            # when doing advantage normalization
            buffer_size = self.env.num_envs * self.n_steps
            assert (
                buffer_size > 1
            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
            # Check that the rollout buffer size is a multiple of the mini-batch size
            untruncated_batches = buffer_size // batch_size
            if buffer_size % batch_size > 0:
                warnings.warn(
                    f"You have specified a mini-batch size of {batch_size},"
                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
                    f" after every {untruncated_batches} untruncated mini-batches,"
                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
                )
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl
        self.policy = policy
        self.rollout_buffer = rollout_buffer
        # Embedding (numpy) of the most recent observation; filled in `_setup_learn`
        # and advanced every step in `collect_rollouts`.
        self._last_embedding = None
        self.transition_loss_coef = transition_loss_coef
        self.reward_loss_coef = reward_loss_coef

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Create the LR schedule, seed RNGs, build the rollout buffer if needed, move the
        policy to ``self.device`` and turn the clip ranges into schedules."""
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        if self.rollout_buffer is None:
            buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RolloutBuffer

            self.rollout_buffer = buffer_cls(
                self.n_steps,
                self.observation_space,
                self.action_space,
                device=self.device,
                gamma=self.gamma,
                gae_lambda=self.gae_lambda,
                n_envs=self.n_envs,
            )
        self.policy = self.policy.to(self.device)

        # Initialize schedules for policy/value clipping
        self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """Step ``env`` ``n_rollout_steps`` times and fill ``rollout_buffer``.

        Each step stores the current embedding, the embedding of the next observation,
        actions, rewards, episode starts, values and log-probabilities. For envs that were
        truncated by a time limit, the reward is bootstrapped with the value of the
        terminal observation.

        :return: ``False`` if a callback asked to stop training, ``True`` otherwise.
        """
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                last_embedding = obs_as_tensor(self._last_embedding, self.device)
                actions, values, log_probs = self.policy(last_embedding)
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.action_space, gym.spaces.Box):
                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if callback.on_step() is False:
                return False

            self._update_info_buffer(infos, dones)
            n_steps += 1

            if isinstance(self.action_space, gym.spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Embeddings are only consumed as numpy data, so no autograd graph is needed.
            # Inputs go to `self.device`, where the policy and `last_embedding` live.
            with th.no_grad():
                new_obs_tensor = th.as_tensor(new_obs, dtype=th.float32, device=self.device)
                new_embedding = self.policy.extract_features(new_obs_tensor, prev_slots=last_embedding)
            next_embedding = new_embedding.detach().cpu().numpy()

            # Handle timeout by bootstraping with value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    with th.no_grad():
                        terminal_obs = th.as_tensor(
                            infos[idx]["terminal_observation"], dtype=th.float32, device=self.device
                        )
                        # NOTE(review): the terminal observation continues the finished episode,
                        # yet it is embedded without `prev_slots`; confirm this is intended.
                        terminal_embedding = self.policy.extract_features(terminal_obs.unsqueeze(0))
                        terminal_value = self.policy.predict_values(terminal_embedding)[0].cpu().numpy()[0]
                    next_embedding[idx] = terminal_embedding.detach().cpu().numpy()[0]
                    rewards[idx] += self.gamma * terminal_value

            if np.any(dones):
                # Observations of finished envs start a new episode: re-embed them without
                # conditioning on the previous episode's slots.
                done_mask = th.as_tensor(dones, dtype=th.bool, device=self.device)
                with th.no_grad():
                    new_embedding[done_mask] = self.policy.extract_features(
                        new_obs_tensor[done_mask], prev_slots=None
                    )
            rollout_buffer.add(self._last_embedding, next_embedding, actions, rewards, self._last_episode_starts, values, log_probs)
            self._last_obs = new_obs
            self._last_embedding = new_embedding.detach().cpu().numpy()
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            values = self.policy.predict_values(new_embedding)[0]

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.on_rollout_end()

        return True

    def _setup_learn(
        self,
        total_timesteps: int,
        eval_env: Optional[GymEnv],
        callback: MaybeCallback = None,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        tb_log_name: str = "run",
    ) -> Tuple[int, BaseCallback]:
        """Run the base setup, then (re)compute the embedding of ``self._last_obs``.

        The embedding is refreshed when ``reset_num_timesteps`` is set. The base class
        (in SB3 1.x) resets the environment, and therefore ``self._last_obs``, in that case,
        and a cached embedding from a previous ``learn()`` call would then be stale.
        """
        total_timesteps, callback = super()._setup_learn(
            total_timesteps,
            eval_env,
            callback,
            eval_freq,
            n_eval_episodes,
            log_path,
            reset_num_timesteps,
            tb_log_name
        )

        if self.policy.features_extractor is not None and (reset_num_timesteps or self._last_embedding is None):
            with th.no_grad():
                last_obs = th.as_tensor(self._last_obs, dtype=th.float32, device=self.device)
                self._last_embedding = self.policy.extract_features(obs=last_obs).detach().cpu().numpy()

        return total_timesteps, callback

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "PPO",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> "PPO":
        """Alternate rollout collection and ``train()`` until ``total_timesteps`` is reached.

        Every ``log_interval`` iterations, episode and timing statistics are dumped
        (before that iteration's ``train()`` call).

        :return: ``self``.
        """
        iteration = 0

        total_timesteps, callback = self._setup_learn(
            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
        )

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:

            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)

            if continue_training is False:
                break

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos
            if log_interval is not None and iteration % log_interval == 0:
                fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
                self.logger.record("time/iterations", iteration, exclude="tensorboard")
                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                if len(self.ep_success_buffer) > 0:
                    self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
                self.logger.record("time/fps", fps)
                self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
                self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                self.logger.dump(step=self.num_timesteps)

            self.train()

        callback.on_training_end()

        return self

    def _world_model_losses(self, rollout_data, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Compute the world-model losses for one mini-batch.

        The first transition model and the first reward model of ``self.policy.critic``
        predict the next embedding and the reward from the stored embedding and action.
        Both predictions are regressed onto the buffer targets with MSE.

        :return: ``(transition_loss, reward_loss, wm_loss)``. ``wm_loss`` is the
            coefficient-weighted sum of the other two.
        """
        next_observations_prediction = self.policy.critic.transition_models[0](
            rollout_data.observations,
            actions
        )
        transition_loss = F.mse_loss(next_observations_prediction, rollout_data.next_observations)

        rewards_prediction = self.policy.critic.reward_models[0](
            rollout_data.observations,
            actions
        )
        reward_loss = F.mse_loss(rewards_prediction, rollout_data.rewards)

        wm_loss = self.transition_loss_coef * transition_loss + self.reward_loss_coef * reward_loss
        return transition_loss, reward_loss, wm_loss

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.

        Runs up to ``n_epochs`` passes of mini-batch updates with these terms:

        * the clipped surrogate loss;
        * the value loss, with optional value clipping;
        * the entropy term;
        * the world-model objective (see the class docstring).

        Training stops early once the approximate KL exceeds ``1.5 * target_kl``, and
        training statistics are then recorded to the logger.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        if self.policy.use_wm_optimizer:
            self._update_learning_rate(self.policy.optimizer_wm)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []
        grad_norms = []
        grad_norms_wm = []
        transition_losses, reward_losses, wm_losses = [], [], []
        first_epoch_transition_losses, first_epoch_reward_losses, first_epoch_wm_losses = [], [], []

        # Value estimates from the first epoch, used to log the averaged-vs-plain value gap
        rollout_values = []
        rollout_averaged_values = []

        def track_wm_losses(epoch_idx: int, transition_loss: th.Tensor, reward_loss: th.Tensor, wm_loss: th.Tensor) -> None:
            """Store scalar copies of one mini-batch's world-model losses for logging."""
            transition_losses.append(transition_loss.item())
            reward_losses.append(reward_loss.item())
            wm_losses.append(wm_loss.item())
            if epoch_idx == 0:
                first_epoch_transition_losses.append(transition_loss.item())
                first_epoch_reward_losses.append(reward_loss.item())
                first_epoch_wm_losses.append(wm_loss.item())

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy, averaged_values = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()
                averaged_values = averaged_values.flatten()
                if epoch == 0:
                    rollout_values.append(values.detach())
                    rollout_averaged_values.append(averaged_values.detach())
                # Normalize advantage
                advantages = rollout_data.advantages
                if self.normalize_advantage:
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the different between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
                if not self.policy.v_critic and not self.policy.use_wm_optimizer:
                    # Joint optimisation: the world-model loss shares the policy optimizer step
                    transition_loss, reward_loss, wm_loss = self._world_model_losses(rollout_data, actions)
                    track_wm_losses(epoch, transition_loss, reward_loss, wm_loss)
                    loss += wm_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                grad_norm = th.nn.utils.clip_grad_norm_(self.policy._get_parameters(), self.max_grad_norm)
                grad_norms.append(grad_norm.item())
                self.policy.optimizer.step()

                if self.policy.use_wm_optimizer:
                    # Separate optimisation: evaluated after the policy step, with its own optimizer
                    transition_loss, reward_loss, wm_loss = self._world_model_losses(rollout_data, actions)
                    track_wm_losses(epoch, transition_loss, reward_loss, wm_loss)

                    self.policy.optimizer_wm.zero_grad()
                    wm_loss.backward()
                    # Clip grad norm
                    grad_norm_wm = th.nn.utils.clip_grad_norm_(self.policy._get_parameters_wm(), self.max_grad_norm)
                    grad_norms_wm.append(grad_norm_wm.item())
                    self.policy.optimizer_wm.step()

            if not continue_training:
                break

        self._n_updates += self.n_epochs
        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
        rollout_values = th.cat(rollout_values, dim=0)
        rollout_averaged_values = th.cat(rollout_averaged_values, dim=0)
        v_diff_mean = th.mean(rollout_averaged_values - rollout_values).item()
        v_diff_std = th.std(rollout_averaged_values - rollout_values).item()

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        # Note: only the KL values of the last (possibly interrupted) epoch are averaged
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record("train/grad_norm", np.mean(grad_norms))
        # World-model metrics exist only when a world-model loss was computed; skipping
        # them otherwise avoids logging NaN from `np.mean([])`.
        if wm_losses:
            self.logger.record("train/transition_loss", np.mean(transition_losses))
            self.logger.record("train/reward_loss", np.mean(reward_losses))
            self.logger.record("train/wm_loss", np.mean(wm_losses))
        self.logger.record("train/mean(averaged_q-v)", v_diff_mean)
        self.logger.record("train/std(averaged_q-v)", v_diff_std)
        if first_epoch_wm_losses:
            self.logger.record("train/first_epoch_transition_loss", np.mean(first_epoch_transition_losses))
            self.logger.record("train/first_epoch_reward_loss", np.mean(first_epoch_reward_losses))
            self.logger.record("train/first_epoch_wm_loss", np.mean(first_epoch_wm_losses))
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        if self.policy.use_wm_optimizer:
            self.logger.record("train/grad_norm_wm", np.mean(grad_norms_wm))

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def predict(
            self,
            observation: np.ndarray,
            state: Optional[Tuple[np.ndarray, ...]] = None,
            episode_start: Optional[np.ndarray] = None,
            deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """Embed ``observation`` (when a features extractor exists) and delegate to
        ``self.policy.actor.predict``.

        The embedding is computed without ``prev_slots`` and without building an
        autograd graph.
        """
        with th.no_grad():
            embedding = th.as_tensor(observation, dtype=th.float32, device=self.device)
            if self.policy.features_extractor is not None:
                embedding = self.policy.extract_features(obs=embedding)

        return self.policy.actor.predict(embedding, state, episode_start, deterministic)

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Name the torch objects to save: the policy, its optimizer and, when enabled,
        the separate world-model optimizer."""
        state_dicts = ["policy", "policy.optimizer"]
        if self.policy.use_wm_optimizer:
            state_dicts.append("policy.optimizer_wm")

        return state_dicts, []