import math
from collections import OrderedDict
from typing import Generic, List, Optional, Tuple, TypeVar

import gymnasium as gym
import numpy as np
from gymnasium import ObservationWrapper, Wrapper

from .signal import Signal

SInType = TypeVar("SInType")
SOutType = TypeVar("SOutType", covariant=True)
ActType = TypeVar("ActType")


class SignaledObservation(
	ObservationWrapper[dict, ActType, SInType], Generic[SInType, ActType]
):
	"""Observation wrapper that exposes observations through signals.

	Every signal is registered under its own name. Signals whose
	`name` is `None` are registered as `"signal_0"`, `"signal_1"`, ...
	in the order in which they appear in `signals`. The wrapper's
	observation space is a `gym.spaces.Dict` mapping each registered
	name to the `space_out` of the corresponding signal.

	Args:
		env (gym.Env): The environment to wrap.
		signals (List[Signal[SInType, SOutType]]): The signals applied
			to every observation of the wrapped environment.
	"""

	def __init__(
		self,
		env: gym.Env[SInType, ActType],
		signals: List[Signal[SInType, SOutType]],
	) -> None:
		super().__init__(env)

		# Name -> signal, in registration order. A later signal with the
		# same name replaces an earlier one.
		self._signals = {}
		spaces_by_name = {}
		unnamed_signal_idx = 0
		for sig in signals:
			key = sig.name
			if key is None:
				# Unnamed signals get a positional fallback name.
				key = f"signal_{unnamed_signal_idx}"
				unnamed_signal_idx += 1
			spaces_by_name[key] = sig.space_out
			self._signals[key] = sig

		self.observation_space = gym.spaces.Dict(spaces_by_name)

	def observation(self, observation: SInType) -> dict:
		"""Apply every registered signal to the given observation.

		Args:
			observation (SInType): The observation produced by the
				wrapped environment.

		Returns:
			dict: An ordered dictionary mapping each signal name to the
				output of that signal, in registration order.
		"""
		viewed = OrderedDict()
		for key, sig in self._signals.items():
			viewed[key] = sig(observation)
		return viewed


class AsynchronousSignals(
	Wrapper[dict | gym.spaces.Box, ActType, dict, ActType], Generic[ActType]
):
	"""Wrapper to turn signals into asynchronous signals.

	This wrapper takes an environment that is already wrapped with
	`SignaledObservation` and turns the signals into asynchronous,
	following the Asynchronous MDPs framework.

	Every signal observation is extended by one trailing coordinate.
	With `remember=True` that coordinate counts the steps since the
	signal was last renewed (`0.0` right after a renewal).

	Args:
		env (SignaledObservation): The environment already wrapped with
			`SignaledObservation`.
		action_history (bool, optional): Whether to include the action
			history in the observation. Defaults to `False`. If `True`,
			the last `n` actions are included in the observation, where
			`n` is the maximum renewal period of the signals (rounded
			up). The history is stored oldest action first.
		period_range (Optional[Tuple[float, float]], optional): The
			range from which to sample signals' renewal periods at each
			rollout. Defaults to `None`, which means the default period
			of each signal is used.
		gamma_beta_range (Optional[Tuple[float, float]], optional): The
			range from which to sample signals' shape parameter at each
			rollout. Defaults to `None`, which means the default shape
			parameter of each signal is used.
		remember (bool, optional): Whether to remember the most recent
			observations of unseen signals, together with the time since
			they were last seen. Defaults to `True`. If `False`, the
			observations of signals that were not renewed at the current
			step are set to `0` and the last coordinate becomes a flag:
			`0.0` if the signal was renewed at this step, `1.0`
			otherwise.
		include_flattened_obs (bool, optional): Whether to include
			flattened observations in the observation. Defaults to
			`False`. If `True`, the signal observations (in sorted key
			order), followed by the action history when it is enabled,
			are concatenated into a single array and included in the
			dictionary space under the key `"observation"`.

	Raises:
		ValueError: If `action_history` is `True` and any signal has an
			infinite maximum renewal interval.
	"""

	def __init__(
		self,
		env: SignaledObservation,
		action_history: bool = False,
		period_range: Optional[Tuple[float, float]] = None,
		gamma_beta_range: Optional[Tuple[float, float]] = None,
		remember: bool = True,
		include_flattened_obs: bool = False,
	) -> None:
		super().__init__(env)
		# Internal memory: latest delivered value of each signal (with
		# its age as last coordinate) plus the action history if enabled.
		self._last_observation: Optional[dict] = None

		self._action_history = action_history
		max_interval = max(
			signal.max_interval for signal in env._signals.values()
		)
		if math.isinf(max_interval):
			# `math.ceil` raises OverflowError on infinities, so the
			# check must happen before rounding.
			if action_history:
				raise ValueError(
					"Cannot include action history when signals have infinite "
					"max intervals."
				)
			self._n = max_interval
		else:
			self._n = math.ceil(max_interval)

		if action_history:
			self._action_dim = env.action_space.shape[0]  # type: ignore
		else:
			# The action dimension is only needed for the history; do not
			# fail on action spaces without a (non-empty) shape.
			action_shape = getattr(env.action_space, "shape", None)
			self._action_dim = action_shape[0] if action_shape else 0

		self.keys: List[str] = sorted(env.observation_space.keys())  # type: ignore
		space_dict = {}
		for key, space in env.observation_space.items():  # type: ignore
			# One extra coordinate per signal: age (or "not seen" flag).
			space_dict[key] = gym.spaces.Box(
				low=np.concatenate([space.low, [0.0]]),
				high=np.concatenate([space.high, [float("inf")]]),
				shape=(space.shape[0] + 1,),
				dtype=space.dtype,
			)

		# Keys making up the flattened observation, in concatenation order.
		self._flat_keys: List[str] = list(self.keys)
		if action_history:
			# The history is a sequence of whole actions, so the bounds
			# must be tiled (a0, a1, ..., a0, a1, ...), not repeated
			# element-wise.
			space_dict["action_history"] = gym.spaces.Box(
				low=np.tile(env.action_space.low, self._n),  # type: ignore
				high=np.tile(env.action_space.high, self._n),  # type: ignore
				shape=(self._action_dim * self._n,),
			)
			self._flat_keys.append("action_history")
		if include_flattened_obs:
			space_dict["observation"] = gym.spaces.Box(
				low=np.concatenate(
					[space_dict[key].low for key in self._flat_keys]
				),
				high=np.concatenate(
					[space_dict[key].high for key in self._flat_keys]
				),
				shape=(
					sum(space_dict[key].shape[0] for key in self._flat_keys),
				),
			)
		self.observation_space = gym.spaces.Dict(space_dict)

		self.period_range = period_range
		self.gamma_beta_range = gamma_beta_range
		self.curr_period: Optional[float] = None
		self.curr_gamma_beta: Optional[float] = None
		self.remember = remember
		self.include_flattened_obs = include_flattened_obs

	def _initial_memory(self, obs: dict) -> dict:
		"""Build the internal memory from a freshly observed `obs`.

		Every signal is stored with age `0.0`; the action history, when
		enabled, starts filled with zeros.
		"""
		memory = {
			key: np.concatenate([obs[key], [0.0]]) for key in self.keys
		}
		if self._action_history:
			memory["action_history"] = np.zeros(
				(self._action_dim * self._n,),
				dtype=self.env.action_space.dtype,  # type: ignore
			)
		return memory

	def _with_flattened(self, obs: dict) -> dict:
		"""Add the `"observation"` entry to `obs` when it is enabled."""
		if self.include_flattened_obs:
			obs["observation"] = np.concatenate(
				[obs[key] for key in self._flat_keys]
			)
		return obs

	def step(self, action: ActType) -> tuple:
		"""Step the environment.

		Args:
			action (ActType): The action to take in the environment.

		Returns:
			tuple: The new observation, the reward, whether the episode
				terminated, whether it was truncated, and additional
				information.
		"""
		obs, reward, term, trunc, info = self.env.step(action)

		if self._last_observation is None:
			# `step` called before `reset`: treat every signal as freshly
			# seen and initialise the memory the same way `reset` does.
			self._last_observation = self._initial_memory(obs)
			to_return_obs = {
				key: self._last_observation[key] for key in self.keys
			}
		else:
			to_return_obs = {}
			for name, signal in self.env._signals.items():  # type: ignore
				last_obs = self._last_observation[name]
				vec, last_seen = last_obs[:-1], last_obs[-1]
				renew = signal.renew(
					last_seen,
					period=self.curr_period,
					gamma_beta=self.curr_gamma_beta,
				)
				if renew:
					# Fresh value, age reset to zero.
					to_return_obs[name] = np.concatenate([obs[name], [0.0]])
				else:
					# Stale value, one step older.
					to_return_obs[name] = np.concatenate(
						[vec, [last_seen + 1.0]]
					)
				self._last_observation[name] = to_return_obs[name]

		if self._action_history:
			# Drop the oldest action and append the newest one.
			to_return_obs["action_history"] = np.concatenate(
				[
					self._last_observation["action_history"][
						self._action_dim :
					],
					np.array(action),
				]
			)
			self._last_observation["action_history"] = to_return_obs[
				"action_history"
			]

		if not self.remember:
			# Only the returned observation is masked; the memory keeps
			# the real values and ages so renewal keeps working.
			for name in self.env._signals.keys():  # type: ignore
				values = to_return_obs[name][:-1]
				seen = to_return_obs[name][-1] < 0.5
				to_return_obs[name] = np.concatenate(
					[
						values if seen else np.zeros_like(values),
						# Wrapped in a list: `np.concatenate` rejects
						# zero-dimensional inputs.
						[0.0 if seen else 1.0],
					]
				)

		return self._with_flattened(to_return_obs), reward, term, trunc, info

	def reset(
		self,
		*,
		seed: Optional[int] = None,
		options: Optional[dict] = None,
	) -> tuple:
		"""Reset the environment.

		Args:
			seed (Optional[int]): The seed to use for the environment.
			options (Optional[dict]): The options to pass to the
				environment.

		Returns:
			tuple: The new observation and additional information.
		"""
		obs, info = self.env.reset(seed=seed, options=options)

		self._last_observation = self._initial_memory(obs)

		# NOTE(review): these draws use the global NumPy RNG, so `seed`
		# does not make the sampled period / shape parameter reproducible.
		if self.period_range is not None:
			self.curr_period = np.random.uniform(*self.period_range)
		if self.gamma_beta_range is not None:
			self.curr_gamma_beta = np.random.uniform(*self.gamma_beta_range)

		# Return a shallow copy so later updates of the internal memory
		# do not alter the dictionary handed to the caller.
		return self._with_flattened(dict(self._last_observation)), info
