from typing import Callable, Generic, Optional, TypeVar

import numpy as np
from gymnasium.spaces import Space

from async_rl.utils import gamma_renewal_prob

# Type of the environment state a signal is applied to (input of `f`).
SInType = TypeVar("SInType")
# Type of the value a signal produces (output of `f`). Declared covariant so
# that, for type checkers, a signal producing a subtype can be used where a
# signal producing the supertype is expected.
SOutType = TypeVar("SOutType", covariant=True)


class Signal(Generic[SInType, SOutType]):
	"""A signal is a specific view of the environment state.

	In the context of Asynchronous MDPs, a signal represents a view, or
	projection, of the environment state from the perspective of a
	specific observer, like a sensor. Whether a fresh value is observed
	at a given step is decided stochastically by `renew`.

	Args:
		space_out (Space): The space of possible states outputted by the
			signal.
		f (Callable[[SInType], SOutType]): The function to apply to the
			input state to obtain the output signal.
		period_beta (tuple[float, float]): A tuple containing the
			default period and shape parameter of the signal renewal
			time. By default, `(1, 1)`. Notice that these can be
			overridden by the `renew` method.
		max_interval (Optional[float]): The maximum interval between
			signal renewals. By default, `None`, which means the
			interval is unbounded.
		name (Optional[str]): The name of the signal. By default,
			`None`.
		rng (Optional[np.random.Generator]): Source of randomness used
			by `renew`. Anything accepted by `np.random.default_rng`
			(e.g. an integer seed) also works, which makes renewals
			reproducible. By default, `None`, which creates a fresh
			generator seeded from OS entropy.
	"""

	def __init__(
		self,
		space_out: Space[SOutType],
		f: Callable[[SInType], SOutType],
		*,
		period_beta: tuple[float, float] = (1, 1),
		max_interval: Optional[float] = None,
		name: Optional[str] = None,
		rng: Optional[np.random.Generator] = None,
	) -> None:
		self.space_out = space_out
		self._f = f

		self._period = period_beta[0]
		self._gamma_beta = period_beta[1]
		# `None` means "no bound"; storing `inf` lets `renew` use a plain
		# comparison without a special case.
		self._max_interval = (
			max_interval if max_interval is not None else float("inf")
		)

		self._name = name

		# Build the generator once instead of on every `renew` call.
		# `default_rng` returns an existing Generator unchanged, seeds
		# from an int, and draws OS entropy when given `None`.
		self._rng = np.random.default_rng(rng)

	def __call__(self, state: SInType) -> SOutType:
		"""Returns the output signal for the given state.

		Args:
			state (SInType): The state to observe.

		Returns:
			SOutType: The output signal.
		"""
		return self._f(state)

	def renew(
		self,
		last_seen: int,
		period: Optional[float] = None,
		gamma_beta: Optional[float] = None,
	) -> bool:
		"""Computes whether the signal has been renewed.

		The signal is always renewed once `last_seen` reaches the
		maximum interval. Otherwise it is renewed with the probability
		returned by `gamma_renewal_prob` for the chosen period, shape
		parameter and `last_seen`.

		Args:
			last_seen (int): The time since the last observation.
			period (Optional[float]): The period of the renewal
				distribution. By default, `None`, which means the
				default period is used.
			gamma_beta (Optional[float]): The shape parameter of the
				renewal distribution. By default, `None`, which means
				the default shape parameter is used.

		Returns:
			bool: Whether the signal has been renewed.
		"""
		if last_seen >= self._max_interval:
			return True
		# Test against `None` explicitly: a truthiness check would
		# silently replace an explicit `0`/`0.0` override with the
		# default, contradicting the documented contract.
		prob = gamma_renewal_prob(
			self._period if period is None else period,
			self._gamma_beta if gamma_beta is None else gamma_beta,
			last_seen,
		)
		return self._rng.random() < prob

	@property
	def max_interval(self) -> float:
		"""float: The maximum interval between renewals (`inf` if unbounded)."""
		return self._max_interval

	@property
	def name(self) -> Optional[str]:
		"""Optional[str]: The name of the signal."""
		return self._name

	def __repr__(self) -> str:
		return (
			f"{type(self).__name__}(name={self._name!r}, "
			f"period={self._period!r}, gamma_beta={self._gamma_beta!r}, "
			f"max_interval={self._max_interval!r})"
		)
