import warnings
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import (
    get_linear_fn,
    get_parameters_by_name,
    polyak_update,
    constant_fn,
    update_learning_rate,
)
from rl_zoo3.custom_algos.ic_dqn.policies import (
    CnnPolicy,
    DQNPolicy,
    MlpPolicy,
    MultiInputPolicy,
    QNetwork,
)

# Type variable bound to ICDQN; used as the annotated type of ``self`` and the return
# type of ``ICDQN.learn`` so the method is typed as returning the caller's own (sub)class.
SelfICDQN = TypeVar("SelfICDQN", bound="ICDQN")


class ICDQN(OffPolicyAlgorithm):
    """
    Deep Q-Network (DQN) with variable target update intervals

    Original Paper for DQN: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
    Default hyperparameters are taken from the Nature paper,
    except for the optimizer and learning rate that were taken from Stable Baselines defaults.

    The learning rate applied at each training call is
    ``min(learning_rate(x), learning_rate_max(y))``, where ``x`` is selected by
    ``learning_rate_determinant`` / ``learning_reset`` and ``y`` by
    ``learning_rate_max_determinant`` (see below).

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, a constant or a schedule.
        - With ``learning_rate_determinant="progress"`` the schedule receives a progress
          value going from 1 to 0. The value is measured over the current
          target-update interval if ``learning_reset`` is True, otherwise over the whole
          training run.
        - With ``learning_rate_determinant="n_calls"`` it receives a number of
          ``env.step()`` calls. The count is taken since the last target network update
          if ``learning_reset`` is True, otherwise since the start of training.
    :param learning_reset: If True, the learning rate schedule restarts at each target network update
    :param learning_rate_determinant: What the learning rate schedule is evaluated on:
        ``"progress"`` or ``"n_calls"``
    :param learning_rate_max: Upper bound on the learning rate, a constant or a schedule.
        With ``learning_rate_max_determinant="progress"`` the schedule receives the overall
        progress remaining (from 1 to 0); with ``"target_iteration"`` it receives the
        current target iteration number (starting at 1, incremented at every target update).
    :param learning_rate_max_determinant: What the maximum learning rate schedule is
        evaluated on: ``"progress"`` or ``"target_iteration"``
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param initial_target_update_interval: start by updating the target network every
        ``initial_target_update_interval`` environment steps.
    :param geometric_target_update_intervals: If True, the target update interval grows
        geometrically. At each target update, once growth is allowed (see
        ``constant_target_update_interval_fraction`` and ``learning_starts``), the interval
        is multiplied by ``gamma ** (-2 / 3)``.
    :param constant_target_update_interval_fraction: Fraction of total timesteps during which
        the target update interval is kept constant before it starts increasing (only used
        when ``geometric_target_update_intervals`` is True)
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param max_grad_norm: The maximum value for the gradient clipping
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    :raises ValueError: if ``learning_rate_determinant`` or ``learning_rate_max_determinant``
        is not one of the supported values
    :raises TypeError: if ``learning_rate_max`` is neither a number nor a callable
    """

    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    # Accepted values for the schedule "determinant" arguments.
    _LEARNING_RATE_DETERMINANTS: ClassVar[Tuple[str, ...]] = ("progress", "n_calls")
    _LEARNING_RATE_MAX_DETERMINANTS: ClassVar[Tuple[str, ...]] = ("progress", "target_iteration")
    # Linear schedule will be defined in `_setup_model()`
    exploration_schedule: Schedule
    q_net: QNetwork
    q_net_target: QNetwork
    policy: DQNPolicy

    def __init__(
        self,
        policy: Union[str, Type[DQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        learning_reset: bool = True,
        learning_rate_determinant: str = "progress",  # or n_calls
        learning_rate_max: Union[float, Schedule] = 1.0,
        learning_rate_max_determinant: str = "progress",  # or target_iteration
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 50000,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        initial_target_update_interval: int = 10000,
        geometric_target_update_intervals: bool = True,
        constant_target_update_interval_fraction: float = 0.0,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ) -> None:
        # Validate the schedule selectors up front: a typo would otherwise silently
        # select the non-"progress" branch in `_update_learning_rate`.
        if learning_rate_determinant not in self._LEARNING_RATE_DETERMINANTS:
            raise ValueError(
                f"Unknown learning_rate_determinant {learning_rate_determinant!r}, "
                f"expected one of {self._LEARNING_RATE_DETERMINANTS}"
            )
        if learning_rate_max_determinant not in self._LEARNING_RATE_MAX_DETERMINANTS:
            raise ValueError(
                f"Unknown learning_rate_max_determinant {learning_rate_max_determinant!r}, "
                f"expected one of {self._LEARNING_RATE_MAX_DETERMINANTS}"
            )

        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=None,  # No action noise
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            sde_support=False,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(spaces.Discrete,),
            support_multi_env=True,
        )

        self.learning_reset = learning_reset
        self.learning_rate_determinant = learning_rate_determinant
        if isinstance(learning_rate_max, (float, int)):
            self.learning_rate_max_schedule = constant_fn(float(learning_rate_max))
        elif callable(learning_rate_max):
            self.learning_rate_max_schedule = learning_rate_max
        else:
            raise TypeError(
                "learning_rate_max must be a number or a callable schedule, "
                f"got {type(learning_rate_max).__name__}"
            )
        self.learning_rate_max_determinant = learning_rate_max_determinant

        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.exploration_fraction = exploration_fraction

        # Progress remaining (1 -> 0) within the current target-update interval.
        self._current_progress_remaining_for_update_interval = 1.0
        # Overall progress remaining; mirrors the base-class value (kept for compatibility).
        self.progress_remaining = 1.0

        self.current_target_update_interval = initial_target_update_interval
        # Un-truncated interval, so geometric growth does not lose precision to int().
        self.current_target_update_interval_exact = initial_target_update_interval
        self.constant_target_update_interval_fraction = (
            constant_target_update_interval_fraction
        )
        # For updating the target network with multiple envs:
        self._n_calls = 0
        self._n_calls_since_last_target_update = 0
        # Target iteration number, starts at 1 and is incremented at each target update.
        self._n_target_iteration = 1
        self.max_grad_norm = max_grad_norm
        self.geometric_target_update_intervals = geometric_target_update_intervals
        # "epsilon" for the epsilon-greedy exploration
        self.exploration_rate = 0.0

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Run the base-class setup, then create the network aliases and related state.

        The base class creates ``self.policy``, which is used below. This method then
        creates the Q-network aliases, the batch-norm running-stat lists and the linear
        exploration schedule.
        """
        super()._setup_model()
        self._create_aliases()
        # Copy running stats, see GH issue #996
        self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(
            self.q_net_target, ["running_"]
        )
        self.exploration_schedule = get_linear_fn(
            self.exploration_initial_eps,
            self.exploration_final_eps,
            self.exploration_fraction,
        )

        if self.n_envs > 1:
            if self.n_envs > self.current_target_update_interval:
                warnings.warn(
                    "The number of environments used is greater than the initial target network "
                    f"update interval ({self.n_envs} > {self.current_target_update_interval}), "
                    "therefore the target network will be updated after each call to env.step() "
                    f"which corresponds to {self.n_envs} steps."
                )

    def _create_aliases(self) -> None:
        """Expose the policy's online and target Q-networks as attributes of the algorithm."""
        self.q_net = self.policy.q_net
        self.q_net_target = self.policy.q_net_target

    def _update_learning_rate(
        self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]
    ) -> None:
        """
        Set the optimizers' learning rate and log it under ``train/learning_rate``.

        The rate is ``min(lr_schedule(x), learning_rate_max_schedule(y))``.

        ``x`` depends on ``learning_rate_determinant`` and ``learning_reset``:

        - ``"progress"``: progress remaining within the current target-update interval
          (``learning_reset=True``) or over the whole run (``learning_reset=False``);
        - otherwise (``"n_calls"``): number of ``env.step()`` calls since the last target
          update (``learning_reset=True``) or in total (``learning_reset=False``).

        ``y`` is the overall progress remaining when ``learning_rate_max_determinant`` is
        ``"progress"``, otherwise the current target iteration number.

        :param optimizers:
            An optimizer or a list of optimizers.
        """
        if not isinstance(optimizers, list):
            optimizers = [optimizers]

        if self.learning_rate_determinant == "progress":
            lr_input = (
                self._current_progress_remaining_for_update_interval
                if self.learning_reset
                else self._current_progress_remaining
            )
        else:  # "n_calls"
            lr_input = (
                self._n_calls_since_last_target_update
                if self.learning_reset
                else self._n_calls
            )

        if self.learning_rate_max_determinant == "progress":
            lr_max_input = self._current_progress_remaining
        else:  # "target_iteration"
            lr_max_input = self._n_target_iteration

        # Evaluate the schedules once and apply the same value to every optimizer.
        learning_rate = min(
            self.lr_schedule(lr_input),
            self.learning_rate_max_schedule(lr_max_input),
        )
        self.logger.record("train/learning_rate", learning_rate)
        for optimizer in optimizers:
            update_learning_rate(optimizer, learning_rate)

    def _on_step(self) -> None:
        """
        Update the target network (and its update interval) and the exploration rate.

        This method is called in ``collect_rollouts()`` after each step in the environment.
        With multiple environments, each call corresponds to ``n_envs`` transitions.
        """
        self._n_calls += 1
        self._n_calls_since_last_target_update += 1
        # Mirror the base-class progress, which is derived from ``num_timesteps``.
        # Decrementing by ``1 / total_timesteps`` per call under-counted with ``n_envs > 1``
        # and drifted across repeated ``learn()`` calls.
        # NOTE: relies on the base class refreshing ``_current_progress_remaining`` before
        # ``_on_step`` (the exploration schedule below relies on it as well).
        self.progress_remaining = self._current_progress_remaining

        # Account for multiple environments
        # each call to step() corresponds to n_envs transitions
        calls_per_interval = max(self.current_target_update_interval // self.n_envs, 1)
        if self._n_calls_since_last_target_update >= calls_per_interval:
            self._n_target_iteration += 1
            self._n_calls_since_last_target_update = 0
            if (
                self.geometric_target_update_intervals
                and 1 - self._current_progress_remaining
                >= self.constant_target_update_interval_fraction
                # Compare timesteps (not step() calls) so that ``learning_starts`` keeps its
                # meaning when ``n_envs > 1``.
                and self.num_timesteps >= self.learning_starts
            ):
                # Grow the interval geometrically. The exact float value is tracked so that
                # truncation to int does not compound across updates.
                self.current_target_update_interval_exact = (
                    self.current_target_update_interval_exact * self.gamma ** (-2 / 3)
                )
                self.current_target_update_interval = int(
                    self.current_target_update_interval_exact
                )
            polyak_update(
                self.q_net.parameters(), self.q_net_target.parameters(), self.tau
            )
            # Copy running stats, see GH issue #996
            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        # Progress remaining (1 -> 0) within the current target-update interval.
        # It is computed from the call counter, so it is exact and is 1.0 right after a
        # target update. It also spans the whole interval for any ``n_envs``; the old
        # per-call decrement only fell to about ``1 - 1 / n_envs``.
        self._current_progress_remaining_for_update_interval = (
            1.0 - self._n_calls_since_last_target_update / calls_per_interval
        )

        self.exploration_rate = self.exploration_schedule(
            self._current_progress_remaining
        )
        self.logger.record("rollout/exploration_rate", self.exploration_rate)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Perform ``gradient_steps`` Q-learning updates on minibatches from the replay buffer.

        :param gradient_steps: number of gradient steps to perform
        :param batch_size: minibatch size sampled from the replay buffer for each step
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

            with th.no_grad():
                # Compute the next Q-values using the target network
                next_q_values = self.q_net_target(replay_data.next_observations)
                # Follow greedy policy: use the one with the highest value
                next_q_values, _ = next_q_values.max(dim=1)
                # Avoid potential broadcast issue
                next_q_values = next_q_values.reshape(-1, 1)
                # 1-step TD target
                target_q_values = (
                    replay_data.rewards
                    + (1 - replay_data.dones) * self.gamma * next_q_values
                )

            # Get current Q-values estimates
            current_q_values = self.q_net(replay_data.observations)

            # Retrieve the q-values for the actions from the replay buffer
            current_q_values = th.gather(
                current_q_values, dim=1, index=replay_data.actions.long()
            )

            # Compute Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Overrides the base_class predict function to include epsilon-greedy exploration.

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next state
            (used in recurrent policies)
        """
        if not deterministic and np.random.rand() < self.exploration_rate:
            if self.policy.is_vectorized_observation(observation):
                if isinstance(observation, dict):
                    n_batch = observation[next(iter(observation.keys()))].shape[0]
                else:
                    n_batch = observation.shape[0]
                action = np.array([self.action_space.sample() for _ in range(n_batch)])
            else:
                action = np.array(self.action_space.sample())
        else:
            action, state = self.policy.predict(
                observation, state, episode_start, deterministic
            )
        return action, state

    def learn(
        self: SelfICDQN,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "DQN",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfICDQN:
        """Train the model; thin wrapper around the base-class ``learn`` (same arguments)."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> List[str]:
        """Exclude the Q-network aliases; they are recreated from the saved policy."""
        return [*super()._excluded_save_params(), "q_net", "q_net_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Save the policy and its optimizer as torch state dicts; no other torch variables."""
        state_dicts = ["policy", "policy.optimizer"]

        return state_dicts, []
