from typing import List, Optional

import torch
from tensordict import TensorDictBase

from src.async_rl.module.noise_control.sampler import sample_signals
from src.async_rl.module.trainer import RLTrainer, RLTrainerHook
from src.async_rl.module.utils import Key


class DummyNoiseControlHook(RLTrainerHook):
	"""No-op stand-in for a noise control hook.

	Choosing this hook disables noise control: it registers itself with
	the trainer under a name but attaches no callbacks of its own. It
	exists so that "no noise control" can be selected through the same
	Hydra configuration mechanism as the real noise control hooks.
	"""

	def register(
		self, trainer: RLTrainer, name: str = "DummyNoiseControlHook"
	) -> None:
		"""Register this hook with `trainer` without attaching callbacks.

		Args:
			trainer (RLTrainer): The trainer that will own this hook.
			name (str, optional): Name under which the hook is
				registered. Defaults to `"DummyNoiseControlHook"`.
		"""
		# Delegate to the base class only; unlike the real noise
		# control hooks, no `trainer.register_hook` calls are made here.
		super(DummyNoiseControlHook, self).register(trainer, name)


class NoiseControlHook(RLTrainerHook):
	"""Hook implementing noise control during training.

	Noise control is a technique used when training reinforcement
	learning agents on partially observable environments. It consists
	of adding noise to the observations of the agent, simulating
	regimes with different levels of noise.

	This class also allows to implement noise control with curriculum
	learning, where the noise level is changed gradually over time:
	every `curriculum_period` optimization steps, `min_noise_ratio` is
	multiplied by `curriculum_factor`.

	Currently, the class only implements noise control for environments
	with asynchronous observations, rather than general partially
	observable environments.

	Args:
		min_noise_ratio (float): Minimum ratio of subsampled signals.
			Each time a batch is subsampled to simulate greater latency,
			the ratio of subsampled signals is drawn between this value
			and `1.0`.
		signal_keys (List[Key]): The keys used to store signals in the
			input tensor dictionary that should be noised. They are
			stored in sorted order.
		curriculum_factor (float, optional): Factor by which to multiply
			`min_noise_ratio` after every `curriculum_period` steps.
			Since `min_noise_ratio` is the lower bound of a ratio drawn
			up to `1.0`, the updated value is clamped to `[0.0, 1.0]`.
			Defaults to `None`, meaning that curriculum learning is not
			used.
		curriculum_period (int, optional): Number of optimization steps
			between two curriculum updates. Required when
			`curriculum_factor` is set; a non-positive value disables
			curriculum learning. Defaults to `None`, meaning that
			curriculum learning is not used.
		per_batch (bool, optional): Whether to sample a different
			ratio of signals for each batch. Defaults to `True`.
		aggregate (bool, optional): Whether to aggregate the sampled
			signals, together with the action history, into a single
			`"observation"` key. Defaults to `False`.

	Raises:
		ValueError: If `curriculum_factor` is set but
			`curriculum_period` is not.
	"""

	def __init__(
		self,
		min_noise_ratio: float,
		signal_keys: List[Key],
		curriculum_factor: Optional[float] = None,
		curriculum_period: Optional[int] = None,
		per_batch: bool = True,
		aggregate: bool = False,
	) -> None:
		self.min_noise_ratio = min_noise_ratio
		self.signal_keys = sorted(signal_keys)
		self.per_batch = per_batch
		self.aggregate = aggregate

		# Neutral defaults: a factor of 1.0 and a non-positive period
		# mean the curriculum is disabled (see `register`).
		self.curriculum_factor = 1.0
		self.curriculum_period = -1
		# Optimization step (aligned to a period boundary) at which the
		# noise level was last updated.
		self.last_updated = 0

		if curriculum_factor is not None:
			# Explicit raise instead of `assert`, which is stripped under
			# `python -O` and would leave `curriculum_period` as None,
			# crashing later in `register` with a TypeError.
			if curriculum_period is None:
				raise ValueError(
					"curriculum_factor is set, but curriculum_period is not."
				)

			self.curriculum_factor = curriculum_factor
			self.curriculum_period = curriculum_period

	def register(
		self, trainer: RLTrainer, name: str = "NoiseControlHook"
	) -> None:
		"""Registers the noise control hook to the trainer.

		Attaches `_add_noise` to the trainer's `"sample"` hook and, when
		curriculum learning is enabled, `_update_noise_level` to the
		`"end_optim_step"` hook.

		Args:
			trainer (RLTrainer): Trainer to register the hook to.
			name (str, optional): Name of the hook. Defaults to
				`"NoiseControlHook"`.
		"""
		super().register(trainer, name)
		trainer.register_hook("sample", self._add_noise)
		if self.curriculum_period > 0:
			trainer.register_hook("end_optim_step", self._update_noise_level)

	def _add_noise(self, batch: TensorDictBase) -> TensorDictBase:
		"""Adds noise to the observation signals in the batch.

		Args:
			batch (TensorDictBase): The batch to add noise to.

		Returns:
			TensorDictBase: The batch with added noise, i.e. with
				asynchronous signals with greater latency.
		"""
		return sample_signals(
			batch,
			self.signal_keys,
			min_ratio=self.min_noise_ratio,
			per_batch=self.per_batch,
			aggregate=self.aggregate,
		)

	def _update_noise_level(self) -> None:
		"""Updates the noise level according to curriculum learning.

		Called at the end of each optimization step. Multiplies
		`min_noise_ratio` by `curriculum_factor` once per full
		`curriculum_period` elapsed since the last update, clamps the
		result to `[0.0, 1.0]`, and logs it if the trainer has a logger.
		"""
		if self.curriculum_period <= 0 or self.trainer is None:
			return

		# `self.trainer` is dereferenced by calling it (presumably a weak
		# reference set by the base class -- TODO confirm); it may yield
		# None once the trainer is gone.
		trainer = self.trainer()
		if trainer is None:
			return

		steps = trainer.metrics["optim_steps"]
		periods_passed = (steps - self.last_updated) // self.curriculum_period
		if periods_passed <= 0:
			return
		# Align to the latest period boundary so that steps beyond a whole
		# period count toward the next update.
		self.last_updated = steps - (steps % self.curriculum_period)

		updated = self.min_noise_ratio * self.curriculum_factor**periods_passed
		# `min_noise_ratio` is the lower bound of a ratio drawn up to 1.0,
		# so keep it within [0.0, 1.0]; without this, a factor > 1 would
		# eventually push it past the upper bound.
		self.min_noise_ratio = min(max(updated, 0.0), 1.0)

		if trainer.logger is not None:
			trainer.logger.log_scalar(
				"noise_level/min_ratio", self.min_noise_ratio, steps
			)


class RandomMaskHook(RLTrainerHook):
	"""Hook implementing random masking of observations' values.

	As opposed to the `NoiseControlHook`, this hook is signal-agnostic
	and simply zeroes out a random subset of the values stored under
	`observation_key` (by default `"observation"`) in the input tensor
	dictionary.

	Moreover, it does not draw a new random ratio for each sequence, but
	rather uses a fixed ratio for the entire training process: each value
	is independently masked with probability `mask_ratio`.

	Args:
		mask_ratio (float): Probability with which each value is masked.
			Must lie in `[0.0, 1.0]`.
		observation_key (str, optional): The key in the tensor
			dictionary that contains the observations to mask. Defaults
			to `"observation"`.

	Raises:
		ValueError: If `mask_ratio` is outside `[0.0, 1.0]`.
	"""

	def __init__(
		self, mask_ratio: float, observation_key: str = "observation"
	) -> None:
		# Fail at construction time rather than at the first sampled
		# batch, where `bernoulli_` would reject the probability.
		if not 0.0 <= mask_ratio <= 1.0:
			raise ValueError(
				f"mask_ratio must be in [0.0, 1.0], got {mask_ratio}."
			)
		self.mask_ratio = mask_ratio
		self.observation_key = observation_key

	def register(
		self, trainer: RLTrainer, name: str = "RandomMaskHook"
	) -> None:
		"""Registers the random mask hook to the trainer.

		Attaches `_add_mask` to the trainer's `"sample"` hook.

		Args:
			trainer (RLTrainer): Trainer to register the hook to.
			name (str, optional): Name of the hook. Defaults to
				`"RandomMaskHook"`.
		"""
		super().register(trainer, name)
		trainer.register_hook("sample", self._add_mask)

	def _add_mask(self, batch: TensorDictBase) -> TensorDictBase:
		"""Zeroes out a random subset of the observation values in the batch.

		Args:
			batch (TensorDictBase): The batch to mask.

		Returns:
			TensorDictBase: The batch whose observations have each value
				independently set to zero with probability `mask_ratio`.
				The batch is returned unchanged if it has no observations.
		"""
		# Explicit `None` default: without it, some tensordict versions
		# raise KeyError (or warn) for a missing key instead of
		# returning None.
		obs = batch.get(self.observation_key, None)
		if obs is None:
			return batch

		mask = obs.new_ones(obs.shape).bernoulli_(self.mask_ratio).bool()
		# `masked_fill` is out of place, so the sampled tensor is not
		# mutated in case it aliases other memory (e.g. buffer storage)
		# or is part of an autograd graph.
		batch.set(self.observation_key, obs.masked_fill(mask, 0.0))

		return batch


class GaussianNoiseHook(RLTrainerHook):
	"""Hook implementing addition of Gaussian noise to observations.

	This hook adds random Gaussian noise to the observation values in
	the input tensor dictionary, simulating noisy sensors or measurement
	error in the environment.

	Args:
		mean (float, optional): Mean of the Gaussian noise distribution.
			Defaults to `0.0` for zero-centered noise.
		std (float, optional): Standard deviation of the Gaussian noise
			distribution. Controls the magnitude of the noise. Defaults
			to `0.1`.
		observation_key (str, optional): The key in the tensor
			dictionary that contains the observations to add noise to.
			Defaults to `"observation"`.
	"""

	def __init__(
		self,
		mean: float = 0.0,
		std: float = 0.1,
		observation_key: str = "observation",
	) -> None:
		self.mean = mean
		self.std = std
		self.observation_key = observation_key

	def register(
		self, trainer: RLTrainer, name: str = "GaussianNoiseHook"
	) -> None:
		"""Registers the Gaussian noise hook to the trainer.

		Attaches `_add_gaussian_noise` to the trainer's `"sample"` hook.

		Args:
			trainer (RLTrainer): Trainer to register the hook to.
			name (str, optional): Name of the hook. Defaults to
				`"GaussianNoiseHook"`.
		"""
		super().register(trainer, name)
		trainer.register_hook("sample", self._add_gaussian_noise)

	def _add_gaussian_noise(self, batch: TensorDictBase) -> TensorDictBase:
		"""Adds Gaussian noise to the observations in the batch.

		Args:
			batch (TensorDictBase): The batch to add noise to.

		Returns:
			TensorDictBase: The batch with `N(mean, std**2)` noise added
				element-wise to its observations, or the batch unchanged
				if it has no observations.
		"""
		# Explicit `None` default: without it, some tensordict versions
		# raise KeyError (or warn) for a missing key instead of
		# returning None.
		obs = batch.get(self.observation_key, None)
		if obs is None:
			return batch

		noise = torch.randn_like(obs) * self.std + self.mean

		# Out-of-place addition: the original tensor is left untouched.
		noisy_obs = obs + noise
		batch.set(self.observation_key, noisy_obs)

		return batch
