# Adapted from https://github.com/navneet-nmk/pytorch-rl
# according to original implementation from https://github.com/openai/baselines
# and https://github.com/hill-a/stable-baselines
from copy import deepcopy
from typing import List

import torch as th
import torch.nn as nn
from stable_baselines3.common.utils import zip_strict


class ParameterNoise:
    """
    Implementation of Parameter Space Noise for Exploration, Plappert et al., 2018

    Two noisy copies of the actor are maintained:

    - ``perturbed_actor``: re-synchronised with the actor and re-perturbed on every
      call to :meth:`reset`.
    - ``adaptive_perturbed_actor``: a scratch copy used by :meth:`distance_to_actor`
      to measure how far the current noise scale moves the actions.

    :param policy: The policy which actor parameters will be perturbed.
        It must expose an ``actor`` attribute and a ``make_actor()`` method.
        Note: it should include layer normalization.
    :param initial_stddev: the initial value for the standard deviation of the noise
    :param desired_action_stddev: the desired value for the standard deviation of the noise
        (measured in action space, see :meth:`distance_to_actor`)
    :param adoption_coefficient: the update coefficient for the standard deviation of the noise:
        the current standard deviation is multiplied or divided by it at each adaptation step
    """

    def __init__(
        self,
        policy: nn.Module,
        initial_stddev: float = 0.2,
        desired_action_stddev: float = 0.2,
        adoption_coefficient: float = 1.01,
    ):
        self.actor = policy.actor
        self.perturbed_actor = policy.make_actor().to(self.actor.device)
        self.adaptive_perturbed_actor = policy.make_actor().to(self.actor.device)

        self.initial_stddev = initial_stddev
        self.desired_action_stddev = desired_action_stddev
        self.adoption_coefficient = adoption_coefficient
        self.current_stddev = initial_stddev
        # Last distance computed by ``distance_to_actor`` (a 0-dim tensor),
        # ``None`` until that method has been called at least once.
        self.adaptive_policy_distance = None

    def get_perturbable_parameters(self, actor: nn.Module) -> List[nn.Parameter]:
        """
        Select the parameters of the actor that should receive noise.

        Normalization layers are detected with a shape heuristic: a parameter whose
        last name component is ``weight`` and which is 1-dimensional is assumed to
        belong to a layer norm, and is skipped together with its sibling ``bias``.
        Every other parameter is returned, in ``named_parameters()`` order.

        :param actor: The actor object.
        :return: List of parameters to perturb
        """
        named_params = dict(actor.named_parameters())
        parameters = []
        for name, param in named_params.items():
            # Only look at the last component of the dotted name, so that a module
            # path that merely contains "weight"/"bias" is not misclassified.
            prefix, _, leaf = name.rpartition(".")
            if leaf == "weight":
                reference_weight = param
            elif leaf == "bias":
                sibling_name = f"{prefix}.weight" if prefix else "weight"
                # A bias without a sibling weight cannot be matched to a norm layer
                # by this heuristic: keep it perturbable instead of failing.
                reference_weight = named_params.get(sibling_name)
            else:
                reference_weight = None

            # Do not perturb layer norm layers (1-D weight and its bias)
            if reference_weight is not None and reference_weight.dim() == 1:
                continue

            parameters.append(param)

        return parameters

    def set_perturbed_actor_params(self, perturbed_actor: nn.Module) -> None:
        """
        Update the perturbed actor parameters.

        Each perturbable parameter of ``perturbed_actor`` is overwritten with the
        matching parameter of the reference actor plus zero-mean Gaussian noise of
        standard deviation ``current_stddev``. Non-perturbable parameters are left
        untouched.

        :param perturbed_actor: The actor copy that receives the noisy parameters.
            It must have the same architecture as the reference actor.
        """
        actor_perturbable_parameters = self.get_perturbable_parameters(self.actor)
        perturbed_actor_perturbable_parameters = self.get_perturbable_parameters(perturbed_actor)

        # This is a plain parameter overwrite: it must not be recorded by autograd.
        with th.no_grad():
            for params, perturbed_params in zip_strict(actor_perturbable_parameters, perturbed_actor_perturbable_parameters):
                noise = th.normal(mean=th.zeros_like(params), std=self.current_stddev)
                perturbed_params.copy_(params + noise)

    def distance_to_actor(self, observations: th.Tensor) -> float:
        """
        Measure the effect of the current noise scale in action space.

        A freshly perturbed copy of the actor is created and the root mean squared
        difference between its deterministic actions and those of the unperturbed
        actor is computed on ``observations``. The result is also stored (as a
        tensor) in ``adaptive_policy_distance``.

        :param observations: Batch of observations fed to both actors.
        :return: The distance between the actions of the two actors.
        """
        # Configure separate copy for stddev adaptation
        self.adaptive_perturbed_actor.load_state_dict(deepcopy(self.actor.state_dict()))
        # Perturb the adaptive actor weights
        self.set_perturbed_actor_params(self.adaptive_perturbed_actor)
        # Refer to https://arxiv.org/pdf/1706.01905.pdf for details on the distance used specifically for DDPG
        with th.no_grad():
            # tf.sqrt(tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))
            reference_actions = self.actor(observations, deterministic=True)
            perturbed_actions = self.adaptive_perturbed_actor(observations, deterministic=True)
            action_diff = reference_actions - perturbed_actions
            self.adaptive_policy_distance = th.sqrt(th.mean(action_diff ** 2))
        return self.adaptive_policy_distance.item()

    def adapt_param_noise(self, observations: th.Tensor) -> None:
        """
        Update the standard deviation for the parameter noise.

        The noise scale is divided by ``adoption_coefficient`` when the measured
        action-space distance exceeds ``desired_action_stddev`` and multiplied by
        it otherwise.

        :param observations: Batch of observations used to measure the distance.
        """
        # Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
        adaptive_noise_distance = self.distance_to_actor(observations)

        if adaptive_noise_distance > self.desired_action_stddev:
            # Decrease stddev.
            self.current_stddev /= self.adoption_coefficient
        else:
            # Increase stddev.
            self.current_stddev *= self.adoption_coefficient

    def reset(self) -> None:
        """
        Re-synchronise ``perturbed_actor`` with the actor and sample fresh noise for it.
        """
        self.perturbed_actor.load_state_dict(deepcopy(self.actor.state_dict()))
        # Sample new noise after an episode is complete
        self.set_perturbed_actor_params(self.perturbed_actor)
