from gym.spaces import Box, Discrete
import numpy as np
from typing import Optional, TYPE_CHECKING, Union

from env.base_env import BaseEnv
from models.action_dist import ActionDistribution
from models.modelv2 import ModelV2
from models.torch.torch_action_dist import (
    TorchCategorical,
    TorchDeterministic,
)
from policy.sample_batch import SampleBatch
from utils.annotations import override
from utils.exploration.exploration import Exploration
from utils.framework import get_variable, try_import_torch
from utils.from_config import from_config
from utils.numpy import softmax, SMALL_NUMBER
from utils.typing import TensorType

if TYPE_CHECKING:
    from policy.policy import Policy

torch, _ = try_import_torch()


class ParameterNoise(Exploration):
    """An exploration that changes a Model's parameters.

    Implemented based on:
    [1] https://blog.openai.com/better-exploration-with-parameter-noise/
    [2] https://arxiv.org/pdf/1706.01905.pdf

    At the beginning of an episode, Gaussian noise is added to all weights
    of the model. At the end of the episode, the noise is undone and an action
    diff (pi-delta) is calculated, from which we determine the changes in the
    noise's stddev for the next episode.
    """

    def __init__(
        self,
        action_space,
        *,
        framework: str,
        policy_config: dict,
        model: ModelV2,
        initial_stddev: float = 1.0,
        random_timesteps: int = 10000,
        sub_exploration: Optional[dict] = None,
        **kwargs,
    ):
        """Initializes a ParameterNoise Exploration object.

        Args:
            action_space: The action space to explore in. Only `Discrete`
                and `Box` are supported when `sub_exploration` is None.
            framework: The framework specifier; must not be None.
            policy_config: The policy's config dict (its "explore" key is
                used as the default for `explore` in
                `before_compute_actions`).
            model: The model whose trainable variables get perturbed.
            initial_stddev: The initial stddev to use for the noise.
            random_timesteps: The number of timesteps to act completely
                randomly (see [1]).
            sub_exploration: Optional sub-exploration config.
                None for auto-detection/setup.
            **kwargs: Forwarded to the parent constructor and to the
                sub-exploration's constructor.

        Raises:
            NotImplementedError: If `sub_exploration` is None and the action
                space is neither `Discrete` nor `Box`.
        """
        assert framework is not None
        super().__init__(
            action_space,
            policy_config=policy_config,
            model=model,
            framework=framework,
            **kwargs,
        )

        # `self.stddev` is the framework-side holder of the current stddev;
        # `self.stddev_val` always mirrors it as a plain Python number.
        self.stddev = get_variable(initial_stddev, framework=self.framework)
        self.stddev_val = initial_stddev

        # The weight variables of the Model where noise should be applied to.
        # This excludes any variable whose name contains "LayerNorm"
        # (normalization parameters, which should not be perturbed).
        self.model_variables = [
            v
            for k, v in self.model.trainable_variables(as_dict=True).items()
            if "LayerNorm" not in k
        ]
        # Our noise to be added to the weights. Each item in `self.noise`
        # corresponds to one Model variable and holds the Gaussian noise to
        # be added to that variable (weight). Starts out as all-zeros.
        self.noise = [
            get_variable(
                np.zeros(var.shape, dtype=np.float32),
                framework=self.framework,
                torch_tensor=True,
                device=self.device,
            )
            for var in self.model_variables
        ]

        # Whether the Model's weights currently have noise added or not.
        self.weights_are_currently_noisy = False

        # Auto-detection of underlying exploration functionality.
        if sub_exploration is None:
            # For discrete action spaces, use an underlying EpsilonGreedy with
            # a special schedule.
            if isinstance(self.action_space, Discrete):
                sub_exploration = {
                    "type": "EpsilonGreedy",
                    "epsilon_schedule": {
                        "type": "PiecewiseSchedule",
                        # Step function (see [2]).
                        "endpoints": [
                            (0, 1.0),
                            (random_timesteps + 1, 1.0),
                            (random_timesteps + 2, 0.01),
                        ],
                        "outside_value": 0.01,
                    },
                }
            elif isinstance(self.action_space, Box):
                sub_exploration = {
                    "type": "OrnsteinUhlenbeckNoise",
                    "random_timesteps": random_timesteps,
                }
            # TODO(sven): Implement for any action space.
            else:
                raise NotImplementedError(
                    "ParameterNoise can only auto-configure a sub-exploration "
                    f"for Discrete or Box action spaces, got {self.action_space}."
                )

        self.sub_exploration = from_config(
            Exploration,
            sub_exploration,
            framework=self.framework,
            action_space=self.action_space,
            policy_config=self.policy_config,
            model=self.model,
            **kwargs,
        )

        # Whether we need to call `self._delayed_on_episode_start` before
        # the forward pass.
        self.episode_started = False

    @override(Exploration)
    def before_compute_actions(
        self,
        *,
        timestep: Optional[int] = None,
        explore: Optional[bool] = None,
    ):
        """Brings the model weights into the state matching `explore`.

        Args:
            timestep: Unused here.
            explore: Whether to explore (noisy weights) or not (clean
                weights). None falls back to `policy_config["explore"]`.
        """
        if explore is None:
            explore = self.policy_config["explore"]

        # Is this the first forward pass in the new episode? If yes, do the
        # noise re-sampling and add to weights.
        if self.episode_started:
            self._delayed_on_episode_start(explore)

        # Add noise if necessary.
        if explore and not self.weights_are_currently_noisy:
            self._add_stored_noise()
        # Remove noise if necessary.
        elif not explore and self.weights_are_currently_noisy:
            self._remove_noise()

    @override(Exploration)
    def get_exploration_action(
        self,
        *,
        action_distribution: ActionDistribution,
        timestep: Union[TensorType, int],
        explore: Union[TensorType, bool],
    ):
        """Delegates the final action selection to the sub-exploration.

        Args:
            action_distribution: The action distribution to pick from.
            timestep: The current timestep.
            explore: Whether to explore.

        Returns:
            Whatever `self.sub_exploration.get_exploration_action` returns.
        """
        # Use our sub-exploration object to handle the final exploration
        # action (depends on the algo-type/action-space/etc..).
        return self.sub_exploration.get_exploration_action(
            action_distribution=action_distribution, timestep=timestep, explore=explore
        )

    @override(Exploration)
    def on_episode_start(
        self,
        policy: "Policy",
        *,
        environment: Optional[BaseEnv] = None,
        episode: Optional[int] = None,
    ):
        """Marks that a new episode began; noise is re-sampled lazily.

        Args:
            policy: Unused here.
            environment: Unused here.
            episode: Unused here.
        """
        # We have to delay the noise-adding step by one forward call.
        # This is due to the fact that the optimizer does its step right
        # after the episode was reset (and hence the noise was already added!).
        # We don't want to update into a noisy net.
        self.episode_started = True

    def _delayed_on_episode_start(self, explore):
        """Re-samples the noise for the new episode.

        Args:
            explore: If truthy, the fresh noise is also applied to the
                weights (replacing any currently applied noise); otherwise
                it is only stored.
        """
        # Sample fresh noise and add to weights.
        if explore:
            self._sample_new_noise_and_add(override=True)
        # Only sample, don't apply anything to the weights.
        else:
            self._sample_new_noise()
        self.episode_started = False

    @override(Exploration)
    def on_episode_end(self, policy, *, environment=None, episode=None):
        """Restores noise-free weights at the end of an episode.

        Args:
            policy: Unused here.
            environment: Unused here.
            episode: Unused here.
        """
        # Remove stored noise from weights (only if currently noisy).
        if self.weights_are_currently_noisy:
            self._remove_noise()

    @staticmethod
    def _action_dist_from_fetches(policy, fetches):
        """Extracts the comparable action-dist values from `fetches`.

        Args:
            policy: The policy whose `dist_class` selects the conversion.
            fetches: The extra-fetches dict returned by the policy's
                action computation.

        Returns:
            The softmax of the dist inputs (categorical case) or the raw
            dist inputs (deterministic case).

        Raises:
            NotImplementedError: For any other `policy.dist_class`.
        """
        # Categorical case (e.g. DQN).
        if policy.dist_class in (TorchCategorical,):
            return softmax(fetches[SampleBatch.ACTION_DIST_INPUTS])
        # Deterministic (Gaussian actions, e.g. DDPG).
        if policy.dist_class in (TorchDeterministic,):
            return fetches[SampleBatch.ACTION_DIST_INPUTS]
        # TODO(sven): Other action-dist cases.
        raise NotImplementedError(
            f"ParameterNoise does not support dist_class {policy.dist_class}."
        )

    @override(Exploration)
    def postprocess_trajectory(
        self,
        policy: "Policy",
        sample_batch: SampleBatch,
    ):
        """Adapts the noise stddev based on the noisy/clean action distance.

        Runs the policy on `sample_batch` once with and once without
        exploration, measures the distance between the two resulting action
        distributions and grows (distance <= threshold) or shrinks
        (otherwise) `self.stddev_val` by a factor of 1.01. See [1] and [2].

        Args:
            policy: The policy used to re-compute the action dist inputs.
            sample_batch: The batch to evaluate; it is not modified here.

        Returns:
            The unchanged `sample_batch`.

        Raises:
            NotImplementedError: If `policy.dist_class` is neither
                `TorchCategorical` nor `TorchDeterministic`.
        """
        noisy_action_dist = noise_free_action_dist = None
        # Adjust the stddev depending on the action (pi)-distance.
        # Also see [1] for details.
        # TODO(sven): Find out whether this can be scrapped by simply using
        #  the `sample_batch` to get the noisy/noise-free action dist.

        # First pass: evaluate with the weights in their current state.
        _, _, fetches = policy.compute_actions_from_input_dict(
            input_dict=sample_batch, explore=self.weights_are_currently_noisy
        )
        action_dist = self._action_dist_from_fetches(policy, fetches)

        if self.weights_are_currently_noisy:
            noisy_action_dist = action_dist
        else:
            noise_free_action_dist = action_dist

        # Second pass: evaluate with the opposite exploration setting.
        # NOTE(review): presumably the policy calls `before_compute_actions`
        # during this call, flipping the weights' noisy state -- confirm.
        _, _, fetches = policy.compute_actions_from_input_dict(
            input_dict=sample_batch, explore=not self.weights_are_currently_noisy
        )
        action_dist = self._action_dist_from_fetches(policy, fetches)

        # Whichever slot was not filled by the first pass gets this result.
        if noisy_action_dist is None:
            noisy_action_dist = action_dist
        else:
            noise_free_action_dist = action_dist

        delta = distance = None
        # Categorical case (e.g. DQN).
        if policy.dist_class in (TorchCategorical,):
            # Calculate KL-divergence (DKL(clean||noisy)) according to [2].
            # TODO(sven): Allow KL-divergence to be calculated by our
            #  Distribution classes (don't support off-graph/numpy yet).
            distance = np.nanmean(
                np.sum(
                    noise_free_action_dist
                    * np.log(
                        noise_free_action_dist / (noisy_action_dist + SMALL_NUMBER)
                    ),
                    1,
                )
            )
            current_epsilon = self.sub_exploration.get_state()["cur_epsilon"]
            delta = -np.log(1 - current_epsilon + current_epsilon / self.action_space.n)
        elif policy.dist_class in (TorchDeterministic,):
            # Calculate the root-mean-square distance between noisy and
            # non-noisy output (see [2]).
            distance = np.sqrt(
                np.mean(np.square(noise_free_action_dist - noisy_action_dist))
            )
            current_scale = self.sub_exploration.get_state()["cur_scale"]
            delta = getattr(self.sub_exploration, "ou_sigma", 0.2) * current_scale

        # Adjust stddev according to the calculated action-distance.
        if distance <= delta:
            self.stddev_val *= 1.01
        else:
            self.stddev_val /= 1.01

        # Update our state (self.stddev and self.stddev_val).
        self.set_state(self.get_state())

        return sample_batch

    def _sample_new_noise(self):
        """Samples new noise and stores it in `self.noise`.

        Each stored noise tensor is replaced by a zero-mean Gaussian sample
        of the same size, using `self.stddev` as the standard deviation.
        The weights themselves are not touched.
        """
        for idx, old_noise in enumerate(self.noise):
            self.noise[idx] = torch.normal(
                mean=torch.zeros(old_noise.size()), std=self.stddev
            ).to(self.device)

    def _sample_new_noise_and_add(self, *, override=False):
        """Samples fresh noise and applies it to the model's parameters.

        Args:
            override: If True, undo any currently applied noise first.
                If False, the weights must currently be noise-free
                (asserted by `_add_stored_noise`).
        """
        if override and self.weights_are_currently_noisy:
            self._remove_noise()
        self._sample_new_noise()
        # `_add_stored_noise` also flags the weights as noisy.
        self._add_stored_noise()

    def _add_stored_noise(self):
        """Adds the stored `self.noise` to the model's parameters.

        Note: No new sampling of noise here. The weights must currently be
        noise-free.
        """
        # Make sure we only add noise to currently noise-free weights.
        assert self.weights_are_currently_noisy is False

        # Add stored noise to the model's parameters in-place. `no_grad`
        # keeps autograd from tracking the update and, unlike toggling
        # `requires_grad`, leaves each parameter's flag untouched even if
        # the update raises.
        with torch.no_grad():
            for var, noise in zip(self.model_variables, self.noise):
                var.add_(noise)

        self.weights_are_currently_noisy = True

    def _remove_noise(self):
        """Removes the stored `self.noise` from the model's parameters.

        The weights must currently be noisy.
        """
        # Make sure we only remove noise iff currently noisy.
        assert self.weights_are_currently_noisy is True

        # Subtract the stored noise from the model's parameters in-place,
        # without autograd tracking (see `_add_stored_noise`).
        with torch.no_grad():
            for var, noise in zip(self.model_variables, self.noise):
                var.add_(-noise)

        self.weights_are_currently_noisy = False

    @override(Exploration)
    def get_state(self):
        """Returns the current noise stddev as `{"cur_stddev": value}`."""
        return {"cur_stddev": self.stddev_val}

    @override(Exploration)
    def set_state(self, state: dict) -> None:
        """Sets the noise stddev from a state dict.

        Args:
            state: Dict with the new stddev under the key "cur_stddev".
        """
        self.stddev_val = state["cur_stddev"]
        # Update `self.stddev` in place when it is a variable-like holder
        # exposing `assign`; plain numbers (float, int, ...) and any other
        # holder without `assign` are simply rebound to the new value.
        assign = getattr(self.stddev, "assign", None)
        if callable(assign):
            assign(self.stddev_val)
        else:
            self.stddev = self.stddev_val