import warnings
from typing import Any, Dict, Optional, Type, Union, Tuple

import gym
import numpy as np
import torch as th
from stable_baselines3 import HerReplayBuffer
from stable_baselines3.common.buffers import ReplayBuffer, DictReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from torch.nn import functional as F

from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, polyak_update

from rl.ooqn.policies import OOQNPolicy


class OOQN(OffPolicyAlgorithm):
    def __init__(
            self,
            policy: Union[str, Type[OOQNPolicy], OOQNPolicy],
            env: Union[GymEnv, str],
            learning_rate: Union[float, Schedule] = 1e-4,
            n_steps: int = 5,
            batch_size: int = 32,
            gradient_steps: int = 1,
            tau: float = 1.0,
            gamma: float = 0.99,
            replay_buffer_class: Optional[ReplayBuffer] = None,
            replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
            optimize_memory_usage: bool = False,
            target_update_interval: int = 10000,
            exploration_fraction: float = 0.1,
            exploration_initial_eps: float = 1.0,
            exploration_final_eps: float = 0.05,
            max_grad_norm: Optional[float] = 10,
            tensorboard_log: Optional[str] = None,
            create_eval_env: bool = False,
            policy_kwargs: Optional[Dict[str, Any]] = None,
            verbose: int = 0,
            seed: Optional[int] = None,
            device: Union[th.device, str] = "auto",
            _init_setup_model: bool = True,
            transition_loss_coef: float = 1.0,
            reward_loss_coef: float = 1.0,
    ):
        """
        Build an OOQN agent on top of ``OffPolicyAlgorithm``.

        :param policy: The policy to train. This constructor calls ``policy.to(device)``,
            so an already instantiated policy module is expected in practice.
            NOTE(review): the ``str`` / class forms allowed by the annotation would fail
            at that call -- confirm the intended usage.
        :param env: Environment (or environment id) forwarded to the parent constructor.
            An environment without a ``num_envs`` attribute is counted as a single
            environment when sizing the replay buffer.
        :param learning_rate: Learning rate (constant or schedule), forwarded to the parent.
        :param n_steps: Forwarded to the parent as the training frequency; the replay
            buffer size is ``n_envs * n_steps``.
        :param batch_size: Mini-batch size, forwarded to the parent.
        :param gradient_steps: Number of gradient steps, forwarded to the parent.
        :param tau: Coefficient handed to ``polyak_update`` when the target network is refreshed.
        :param gamma: Discount factor used in the TD target.
        :param replay_buffer_class: Replay buffer class; when ``None`` one is chosen in ``_setup_model``.
        :param replay_buffer_kwargs: Extra keyword arguments for the replay buffer.
        :param optimize_memory_usage: Forwarded to the parent and to the replay buffer.
        :param target_update_interval: Number of ``_on_step`` calls between target network updates
            (rescaled by the number of environments in ``_setup_model``).
        :param exploration_fraction: Third argument of the linear exploration schedule.
        :param exploration_initial_eps: Start value of the linear exploration schedule.
        :param exploration_final_eps: End value of the linear exploration schedule.
        :param max_grad_norm: Maximum norm used for gradient clipping in ``train``;
            ``None`` disables clipping.
        :param tensorboard_log: Forwarded to the parent.
        :param create_eval_env: Forwarded to the parent.
        :param policy_kwargs: Forwarded to the parent.
        :param verbose: Forwarded to the parent.
        :param seed: Forwarded to the parent.
        :param device: Forwarded to the parent; the policy is moved to the resulting device.
        :param _init_setup_model: Whether to call ``_setup_model`` at the end of the constructor.
        :param transition_loss_coef: Weight of the transition loss in the world-model loss.
        :param reward_loss_coef: Weight of the reward loss in the world-model loss.
        """
        # The environment is only wrapped into a vectorized environment inside the parent
        # constructor, so at this point it may be an id string, a plain gym environment or
        # None. None of those expose `num_envs`; the parent wraps them as a single environment.
        n_envs = getattr(env, "num_envs", 1)

        super(OOQN, self).__init__(
            policy,
            env,
            OOQNPolicy,
            learning_rate,
            n_envs * n_steps,  # buffer_size
            0,  # learning_starts
            batch_size,
            tau,
            gamma,
            n_steps,  # train_freq
            gradient_steps,
            action_noise=None,  # No action noise
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            sde_support=False,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(gym.spaces.Discrete,),
            support_multi_env=True,
        )

        self.n_steps = n_steps
        self.max_grad_norm = max_grad_norm
        self.tau = tau
        self.target_update_interval = target_update_interval
        self.exploration_fraction = exploration_fraction
        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.transition_loss_coef = transition_loss_coef
        self.reward_loss_coef = reward_loss_coef
        # Number of `_on_step` calls so far; drives the target network update cadence
        self._n_calls = 0
        # Both are (re)assigned later: the schedule in `_setup_model`, the rate in `_on_step`
        self.exploration_rate = 0.0
        self.exploration_schedule = None

        if self.env is not None:
            # Sanity-check the buffer size `n_steps * n_envs` against the mini-batch size
            buffer_size = self.env.num_envs * self.n_steps
            assert (
                    buffer_size > 1
            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
            # Warn when the buffer size is not a multiple of the mini-batch size.
            # NOTE: the message below says `RolloutBuffer`; the buffer concerned here is the
            # replay buffer of size `n_steps * n_envs`.
            untruncated_batches = buffer_size // batch_size
            if buffer_size % batch_size > 0:
                warnings.warn(
                    f"You have specified a mini-batch size of {batch_size},"
                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
                    f" after every {untruncated_batches} untruncated mini-batches,"
                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
                )
        # The policy is supplied by the caller rather than built in `_setup_model`;
        # move it to the selected device.
        self.policy = policy.to(self.device)

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Prepare everything the agent needs before training: learning-rate schedule,
        random seed, replay buffer, training frequency, network aliases,
        exploration schedule and the per-step target update interval.
        """
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        requested_buffer_cls = self.replay_buffer_class
        if requested_buffer_cls is None:
            # No class requested: pick the buffer matching the observation space type
            has_dict_obs = isinstance(self.observation_space, gym.spaces.Dict)
            self.replay_buffer_class = DictReplayBuffer if has_dict_obs else ReplayBuffer

        elif requested_buffer_cls == HerReplayBuffer:
            assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`"

            # Offline sampling additionally requires a classic replay buffer
            offline_buffer = None
            if not self.replay_buffer_kwargs.get("online_sampling", True):
                offline_buffer = DictReplayBuffer(
                    self.buffer_size,
                    self.observation_space,
                    self.action_space,
                    device=self.device,
                    optimize_memory_usage=self.optimize_memory_usage,
                )

            self.replay_buffer = HerReplayBuffer(
                self.env,
                self.buffer_size,
                device=self.device,
                replay_buffer=offline_buffer,
                **self.replay_buffer_kwargs,
            )

        # Build the buffer unless one already exists (e.g. the HER buffer created above)
        if self.replay_buffer is None:
            self.replay_buffer = self.replay_buffer_class(
                self.buffer_size,
                self.observation_space,
                self.action_space,
                device=self.device,
                n_envs=self.n_envs,
                optimize_memory_usage=self.optimize_memory_usage,
                **self.replay_buffer_kwargs,
            )

        # Turn the train freq parameter into a TrainFreq object
        self._convert_train_freq()
        self._create_aliases()
        self.exploration_schedule = get_linear_fn(
            self.exploration_initial_eps,
            self.exploration_final_eps,
            self.exploration_fraction,
        )

        # With a single environment the update interval needs no adjustment
        if self.n_envs <= 1:
            return

        # Every call to step() yields n_envs transitions, so the interval
        # (counted in calls) is scaled down accordingly.
        if self.n_envs > self.target_update_interval:
            warnings.warn(
                "The number of environments used is greater than the target network "
                f"update interval ({self.n_envs} > {self.target_update_interval}), "
                "therefore the target network will be updated after each call to env.step() "
                f"which corresponds to {self.n_envs} steps."
            )

        self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)

    def _create_aliases(self) -> None:
        """Expose the policy's online and target Q-networks directly on the algorithm."""
        policy = self.policy
        self.q_net = policy.q_net
        self.q_net_target = policy.q_net_target

    def _on_step(self) -> None:
        """
        Per-step bookkeeping: refresh the target network every
        ``target_update_interval`` calls, then update and log the exploration rate.
        """
        self._n_calls += 1
        target_update_due = self._n_calls % self.target_update_interval == 0
        if target_update_due:
            polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)

        # The exploration rate follows the schedule evaluated at the remaining training progress
        self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
        self.logger.record("rollout/exploration_rate", self.exploration_rate)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Run ``gradient_steps`` updates on mini-batches sampled from the replay buffer.

        Each update computes a 1-step TD loss for the Q-values and a world-model loss
        (weighted transition + reward prediction errors). When the policy sets
        ``use_wm_optimizer`` the value model and the world model are stepped with their
        own optimizers; otherwise the summed loss is stepped with a single optimizer.

        :param gradient_steps: Number of mini-batch updates to perform.
        :param batch_size: Number of transitions sampled per update.
        """
        # Training mode affects batch norm / dropout layers
        self.policy.set_training_mode(True)

        split_optimizers = self.policy.use_wm_optimizer
        # Refresh the learning rate of every optimizer in use from the schedule
        self._update_learning_rate(
            [self.q_net.optimizer_wm, self.q_net.optimizer_value] if split_optimizers else [self.q_net.optimizer]
        )

        total_hist, q_hist, wm_hist, transition_hist, reward_hist = [], [], [], [], []
        wm_norm_hist, value_norm_hist, q_net_norm_hist = [], [], []

        for _ in range(gradient_steps):
            batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with th.no_grad():
                # Features are extracted without gradients, so none of the losses below
                # back-propagate into the feature extraction.
                feats = self.policy.extract_features(batch.observations)
                next_feats = self.policy.extract_features(batch.next_observations, prev_slots=feats)
                # Greedy value of the next state according to the target network,
                # reshaped to a column to avoid broadcasting issues
                best_next_q = self.q_net_target(next_feats).max(dim=1)[0].reshape(-1, 1)
                # 1-step TD target
                td_target = batch.rewards + (1 - batch.dones) * self.gamma * best_next_q

            # Q-value estimates of the actions stored in the replay buffer
            chosen_q = th.gather(self.q_net(feats), dim=1, index=batch.actions.long())
            q_loss = F.mse_loss(chosen_q, td_target)
            q_hist.append(q_loss.item())

            if split_optimizers:
                # Value update happens before the world-model forward pass
                self.q_net.optimizer_value.zero_grad()
                q_loss.backward()
                if self.max_grad_norm is not None:
                    value_norm = th.nn.utils.clip_grad_norm_(
                        self.q_net.value_model.parameters(), self.max_grad_norm
                    )
                    value_norm_hist.append(value_norm.item())
                self.q_net.optimizer_value.step()

            flat_actions = batch.actions.squeeze(dim=1)

            predicted_next_feats = self.q_net.transition_model(feats, flat_actions)
            transition_loss = F.mse_loss(predicted_next_feats, next_feats)
            transition_hist.append(transition_loss.item())

            predicted_rewards = self.q_net.reward_model(feats, flat_actions)
            reward_loss = F.mse_loss(predicted_rewards, batch.rewards.squeeze(dim=1))
            reward_hist.append(reward_loss.item())

            wm_loss = self.transition_loss_coef * transition_loss + self.reward_loss_coef * reward_loss
            wm_hist.append(wm_loss.item())

            # The summed loss is always logged; it is only back-propagated in the
            # single-optimizer configuration.
            total_loss = q_loss + wm_loss
            total_hist.append(total_loss.item())

            if split_optimizers:
                self.q_net.optimizer_wm.zero_grad()
                wm_loss.backward()
                if self.max_grad_norm is not None:
                    wm_params = [
                        *self.q_net.transition_model.parameters(),
                        *self.q_net.reward_model.parameters(),
                    ]
                    wm_norm = th.nn.utils.clip_grad_norm_(wm_params, self.max_grad_norm)
                    wm_norm_hist.append(wm_norm.item())
                self.q_net.optimizer_wm.step()
            else:
                self.q_net.optimizer.zero_grad()
                total_loss.backward()
                if self.max_grad_norm is not None:
                    q_net_norm = th.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
                    q_net_norm_hist.append(q_net_norm.item())
                self.q_net.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        always_logged = (
            ("train/loss", total_hist),
            ("train/q_loss", q_hist),
            ("train/wm_loss", wm_hist),
            ("train/transition_loss", transition_hist),
            ("train/reward_loss", reward_hist),
        )
        for tag, values in always_logged:
            self.logger.record(tag, np.mean(values))

        # Gradient norms exist only for the optimizers that ran with clipping enabled
        logged_if_present = (
            ("train/wm_grad_norm", wm_norm_hist),
            ("train/value_grad_norm", value_norm_hist),
            ("train/q_net_grad_norm", q_net_norm_hist),
        )
        for tag, values in logged_if_present:
            if values:
                self.logger.record(tag, np.mean(values))

    def learn(
            self,
            total_timesteps: int,
            callback: MaybeCallback = None,
            log_interval: int = 4,
            eval_env: Optional[GymEnv] = None,
            eval_freq: int = -1,
            n_eval_episodes: int = 5,
            tb_log_name: str = "DQN",
            eval_log_path: Optional[str] = None,
            reset_num_timesteps: bool = True,
    ) -> OffPolicyAlgorithm:
        """
        Train the agent by delegating to ``OffPolicyAlgorithm.learn``.

        Every argument is forwarded unchanged by keyword; this override only
        supplies the default values listed in the signature.

        :return: Whatever the parent ``learn`` returns.
        """
        forwarded = dict(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )
        return super(OOQN, self).learn(**forwarded)

    def _sample_action(
            self,
            learning_starts: int,
            action_noise: Optional[ActionNoise] = None,
            n_envs: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Pick the action(s) for the last observation using the non-deterministic
        ``predict`` path.

        ``learning_starts``, ``action_noise`` and ``n_envs`` are accepted for
        signature compatibility but are not used here.

        :return: The same action array twice: once as the action to execute and
            once as the action to store in the replay buffer.
        """
        chosen_action = self.predict(self._last_obs, deterministic=False)[0]
        env_action = buffer_action = chosen_action
        return env_action, buffer_action

    def predict(
            self,
            observation: Union[np.ndarray, Dict[str, np.ndarray]],
            state: Optional[Tuple[np.ndarray, ...]] = None,
            episode_start: Optional[np.ndarray] = None,
            deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Epsilon-greedy action selection.

        With probability ``exploration_rate`` (and only when ``deterministic`` is
        false) a random action is sampled from the action space for every
        observation in the batch; otherwise the policy's own ``predict`` is used.
        A single random draw decides for the whole batch.

        :param observation: Batched observation(s): an array, or a dict of arrays
            for dict observation spaces.
        :param state: Passed through to the policy and returned unchanged when exploring.
        :param episode_start: Passed through to the policy.
        :param deterministic: When true, never explore (no random number is drawn).
        :return: The selected action(s) and the (possibly updated) state.
        """
        # Short-circuit keeps the random stream untouched for deterministic calls
        explore = not deterministic and np.random.rand() < self.exploration_rate
        if explore:
            if isinstance(observation, dict):
                # Dict observations have no `.shape`; read the batch size from the first
                # entry (assumes all entries share the same leading batch dimension).
                n_batch = next(iter(observation.values())).shape[0]
            else:
                n_batch = observation.shape[0]
            action = np.array([self.action_space.sample() for _ in range(n_batch)])
        else:
            action, state = self.policy.predict(observation, state, episode_start, deterministic)
        return action, state
