from typing import Optional, Tuple

import gymnasium as gym
import numpy as np

from src.async_rl.amdp import wrapper as awrap
from src.async_rl.amdp.signal import Signal


def make_halfcheetah_env(
	action_history: bool = False,
	period_range: Optional[Tuple[float, float]] = None,
	gamma_beta_range: Optional[Tuple[float, float]] = None,
	flatten_obs: bool = True,
	remember: bool = True,
) -> gym.Env:
	"""Build an asynchronous version of ``HalfCheetah-v4``.

	The base observation vector is split into seven signals:

	* ``front_tip`` gathers entries ``0:2`` and ``8:11``;
	* ``back_thigh``, ``back_shin``, ``back_foot``, ``front_thigh``,
	  ``front_shin`` and ``front_foot`` each gather entry ``i`` and
	  entry ``i + 9`` for ``i = 2..7`` (presumably the joint angle and
	  its angular velocity in Gymnasium's documented layout).

	Every signal lives in an unbounded ``Box`` and is created with
	``max_interval=4.0``. Signals are intended to renew once per time
	step on average; see ``Signal`` for the exact defaults.

	Args:
		action_history (bool, optional): Whether the action history is
			added to the observation. Defaults to `False`.
		period_range (Optional[Tuple[float, float]], optional): Range
			from which each signal's renewal period is sampled at
			every rollout. Defaults to `None`, in which case each
			signal keeps its default period.
		gamma_beta_range (Optional[Tuple[float, float]], optional):
			Range from which each signal's shape parameter is sampled
			at every rollout. Defaults to `None`, in which case each
			signal keeps its default shape parameter.
		flatten_obs (bool, optional): Whether a flattened aggregation
			of all signals is added to the observation dictionary
			under the `"observation"` key. The per-signal entries are
			not discarded. Defaults to `True`.
		remember (bool, optional): Whether the most recent value of a
			signal that was not seen is kept, together with the time
			elapsed since it was last seen. Defaults to `True`. If
			`False`, unseen signals are reported as `0` and the last
			coordinate of the observation flags whether the signal
			was seen in the last time step.

	Returns:
		gym.Env: The asynchronous HalfCheetah environment.
	"""
	base_env = gym.make("HalfCheetah-v4")

	def unbounded(size: int) -> gym.spaces.Box:
		# Unbounded box of the requested length.
		return gym.spaces.Box(
			low=float("-inf"), high=float("inf"), shape=(size,)
		)

	def joint_signal(
		name: str, position: int, velocity: int
	) -> "Signal[np.ndarray, np.ndarray]":
		# Two-entry signal built from a pair of observation indices.
		return Signal(
			unbounded(2),
			lambda obs: np.array([obs[position], obs[velocity]]),
			max_interval=4.0,
			name=name,
		)

	signals = [
		Signal(
			unbounded(5),
			lambda obs: np.concatenate([obs[:2], obs[8:11]]),
			max_interval=4.0,
			name="front_tip",
		)
	]
	joint_names = (
		"back_thigh",
		"back_shin",
		"back_foot",
		"front_thigh",
		"front_shin",
		"front_foot",
	)
	# Joint k reads entry k and the entry nine slots further on.
	for index, joint_name in enumerate(joint_names, start=2):
		signals.append(joint_signal(joint_name, index, index + 9))

	signaled_env = awrap.SignaledObservation(base_env, signals)
	return awrap.AsynchronousSignals(
		signaled_env,
		action_history=action_history,
		period_range=period_range,
		gamma_beta_range=gamma_beta_range,
		remember=remember,
		include_flattened_obs=flatten_obs,
	)


def make_hopper_env(
	action_history: bool = False,
	period_range: Optional[Tuple[float, float]] = None,
	gamma_beta_range: Optional[Tuple[float, float]] = None,
	flatten_obs: bool = True,
	remember: bool = True,
) -> gym.Env:
	"""Build an asynchronous version of ``Hopper-v4``.

	The base environment is created with
	``terminate_when_unhealthy=False``. Its observation vector is split
	into four signals:

	* ``torso`` gathers entries ``0:2`` and ``5:8``;
	* ``thigh``, ``leg`` and ``foot`` each gather entry ``i`` and entry
	  ``i + 6`` for ``i = 2..4`` (presumably the joint angle and its
	  angular velocity in Gymnasium's documented layout).

	Every signal lives in an unbounded ``Box`` and is created with
	``max_interval=4.0``. Signals are intended to renew once per time
	step on average; see ``Signal`` for the exact defaults.

	Args:
		action_history (bool, optional): Whether the action history is
			added to the observation. Defaults to `False`.
		period_range (Optional[Tuple[float, float]], optional): Range
			from which each signal's renewal period is sampled at
			every rollout. Defaults to `None`, in which case each
			signal keeps its default period.
		gamma_beta_range (Optional[Tuple[float, float]], optional):
			Range from which each signal's shape parameter is sampled
			at every rollout. Defaults to `None`, in which case each
			signal keeps its default shape parameter.
		flatten_obs (bool, optional): Whether a flattened aggregation
			of all signals is added to the observation dictionary
			under the `"observation"` key. The per-signal entries are
			not discarded. Defaults to `True`.
		remember (bool, optional): Whether the most recent value of a
			signal that was not seen is kept, together with the time
			elapsed since it was last seen. Defaults to `True`. If
			`False`, unseen signals are reported as `0` and the last
			coordinate of the observation flags whether the signal
			was seen in the last time step.

	Returns:
		gym.Env: The asynchronous Hopper environment.
	"""
	base_env = gym.make("Hopper-v4", terminate_when_unhealthy=False)

	def unbounded(size: int) -> gym.spaces.Box:
		# Unbounded box of the requested length.
		return gym.spaces.Box(
			low=float("-inf"), high=float("inf"), shape=(size,)
		)

	def joint_signal(
		name: str, position: int, velocity: int
	) -> "Signal[np.ndarray, np.ndarray]":
		# Two-entry signal built from a pair of observation indices.
		return Signal(
			unbounded(2),
			lambda obs: np.array([obs[position], obs[velocity]]),
			max_interval=4.0,
			name=name,
		)

	signals = [
		Signal(
			unbounded(5),
			lambda obs: np.concatenate([obs[:2], obs[5:8]]),
			max_interval=4.0,
			name="torso",
		)
	]
	# Joint k reads entry k and the entry six slots further on.
	for index, joint_name in enumerate(("thigh", "leg", "foot"), start=2):
		signals.append(joint_signal(joint_name, index, index + 6))

	signaled_env = awrap.SignaledObservation(base_env, signals)
	return awrap.AsynchronousSignals(
		signaled_env,
		action_history=action_history,
		period_range=period_range,
		gamma_beta_range=gamma_beta_range,
		remember=remember,
		include_flattened_obs=flatten_obs,
	)


def make_walker2d_env(
	action_history: bool = False,
	period_range: Optional[Tuple[float, float]] = None,
	gamma_beta_range: Optional[Tuple[float, float]] = None,
	flatten_obs: bool = True,
	remember: bool = True,
) -> gym.Env:
	"""Build an asynchronous version of ``Walker2d-v4``.

	The base environment is created with
	``terminate_when_unhealthy=False``. Its observation vector is split
	into seven signals:

	* ``torso`` gathers entries ``0:2`` and ``8:11``;
	* ``right_thigh``, ``right_leg``, ``right_foot`` gather entry ``i``
	  and entry ``i + 9`` for ``i = 2..4``;
	* ``left_thigh``, ``left_leg``, ``left_foot`` gather entry ``i``
	  and entry ``i + 9`` for ``i = 5..7``.

	Each pair is presumably a joint angle and its angular velocity in
	Gymnasium's documented layout. The signals are handed to the
	wrapper in the order torso, left leg, right leg.

	Every signal lives in an unbounded ``Box`` and is created with
	``max_interval=4.0``. Signals are intended to renew once per time
	step on average; see ``Signal`` for the exact defaults.

	Args:
		action_history (bool, optional): Whether the action history is
			added to the observation. Defaults to `False`.
		period_range (Optional[Tuple[float, float]], optional): Range
			from which each signal's renewal period is sampled at
			every rollout. Defaults to `None`, in which case each
			signal keeps its default period.
		gamma_beta_range (Optional[Tuple[float, float]], optional):
			Range from which each signal's shape parameter is sampled
			at every rollout. Defaults to `None`, in which case each
			signal keeps its default shape parameter.
		flatten_obs (bool, optional): Whether a flattened aggregation
			of all signals is added to the observation dictionary
			under the `"observation"` key. The per-signal entries are
			not discarded. Defaults to `True`.
		remember (bool, optional): Whether the most recent value of a
			signal that was not seen is kept, together with the time
			elapsed since it was last seen. Defaults to `True`. If
			`False`, unseen signals are reported as `0` and the last
			coordinate of the observation flags whether the signal
			was seen in the last time step.

	Returns:
		gym.Env: The asynchronous Walker2D environment.
	"""
	base_env = gym.make("Walker2d-v4", terminate_when_unhealthy=False)

	def unbounded(size: int) -> gym.spaces.Box:
		# Unbounded box of the requested length.
		return gym.spaces.Box(
			low=float("-inf"), high=float("inf"), shape=(size,)
		)

	def joint_signal(
		name: str, position: int, velocity: int
	) -> "Signal[np.ndarray, np.ndarray]":
		# Two-entry signal built from a pair of observation indices.
		return Signal(
			unbounded(2),
			lambda obs: np.array([obs[position], obs[velocity]]),
			max_interval=4.0,
			name=name,
		)

	torso = Signal(
		unbounded(5),
		lambda obs: np.concatenate([obs[:2], obs[8:11]]),
		max_interval=4.0,
		name="torso",
	)
	# The right-hand signals are instantiated before the left-hand ones,
	# even though the left-hand ones come first in the wrapper's list.
	right_side = [
		joint_signal(joint_name, index, index + 9)
		for index, joint_name in enumerate(
			("right_thigh", "right_leg", "right_foot"), start=2
		)
	]
	left_side = [
		joint_signal(joint_name, index, index + 9)
		for index, joint_name in enumerate(
			("left_thigh", "left_leg", "left_foot"), start=5
		)
	]

	signaled_env = awrap.SignaledObservation(
		base_env,
		[torso, *left_side, *right_side],
	)
	return awrap.AsynchronousSignals(
		signaled_env,
		action_history=action_history,
		period_range=period_range,
		gamma_beta_range=gamma_beta_range,
		remember=remember,
		include_flattened_obs=flatten_obs,
	)


def make_reacher_env(
	action_history: bool = False,
	period_range: Optional[Tuple[float, float]] = None,
	gamma_beta_range: Optional[Tuple[float, float]] = None,
	flatten_obs: bool = True,
	remember: bool = True,
) -> gym.Env:
	"""Build an asynchronous version of ``Reacher-v4``.

	The base observation vector is split into four signals, each
	picking individual entries:

	* ``first_arm``: entries ``0``, ``2`` and ``6``;
	* ``second_arm``: entries ``1``, ``3`` and ``7``;
	* ``target``: entries ``4`` and ``5``;
	* ``diff``: entries ``8`` and ``9``.

	Presumably, following Gymnasium's documented layout, the arm
	signals hold the cosine, sine and angular velocity of each arm,
	``target`` the target coordinates and ``diff`` the
	fingertip-to-target offset. No signal reads any entry past
	index ``9``.

	Every signal lives in an unbounded ``Box`` and is created with
	``max_interval=4.0``. Signals are intended to renew once per time
	step on average; see ``Signal`` for the exact defaults.

	Args:
		action_history (bool, optional): Whether the action history is
			added to the observation. Defaults to `False`.
		period_range (Optional[Tuple[float, float]], optional): Range
			from which each signal's renewal period is sampled at
			every rollout. Defaults to `None`, in which case each
			signal keeps its default period.
		gamma_beta_range (Optional[Tuple[float, float]], optional):
			Range from which each signal's shape parameter is sampled
			at every rollout. Defaults to `None`, in which case each
			signal keeps its default shape parameter.
		flatten_obs (bool, optional): Whether a flattened aggregation
			of all signals is added to the observation dictionary
			under the `"observation"` key. The per-signal entries are
			not discarded. Defaults to `True`.
		remember (bool, optional): Whether the most recent value of a
			signal that was not seen is kept, together with the time
			elapsed since it was last seen. Defaults to `True`. If
			`False`, unseen signals are reported as `0` and the last
			coordinate of the observation flags whether the signal
			was seen in the last time step.

	Returns:
		gym.Env: The asynchronous Reacher environment.
	"""
	base_env = gym.make("Reacher-v4")

	def picked_signal(
		name: str, indices: Tuple[int, ...]
	) -> "Signal[np.ndarray, np.ndarray]":
		# Signal whose value is the listed observation entries, in order.
		return Signal(
			gym.spaces.Box(
				low=float("-inf"),
				high=float("inf"),
				shape=(len(indices),),
			),
			lambda obs: np.array([obs[i] for i in indices]),
			max_interval=4.0,
			name=name,
		)

	layout = (
		("first_arm", (0, 2, 6)),
		("second_arm", (1, 3, 7)),
		("target", (4, 5)),
		("diff", (8, 9)),
	)
	signals = [picked_signal(name, indices) for name, indices in layout]

	signaled_env = awrap.SignaledObservation(base_env, signals)
	return awrap.AsynchronousSignals(
		signaled_env,
		action_history=action_history,
		period_range=period_range,
		gamma_beta_range=gamma_beta_range,
		remember=remember,
		include_flattened_obs=flatten_obs,
	)


def make_ant_env(
	action_history: bool = False,
	period_range: Optional[Tuple[float, float]] = None,
	gamma_beta_range: Optional[Tuple[float, float]] = None,
	flatten_obs: bool = True,
	remember: bool = True,
) -> gym.Env:
	"""Build an asynchronous version of ``Ant-v4``.

	The base environment is created with
	``terminate_when_unhealthy=False``. Its observation vector is split
	into five signals:

	* ``torso`` gathers entries ``0:5`` and ``13:19``;
	* ``front_left_leg``, ``front_right_leg``, ``back_left_leg`` and
	  ``back_right_leg`` each gather two consecutive entries starting
	  at ``5``, ``7``, ``9`` and ``11`` respectively, followed by the
	  two entries fourteen slots further on (presumably the leg's two
	  joint angles and their angular velocities in Gymnasium's
	  documented layout).

	Every signal lives in an unbounded ``Box`` and is created with
	``max_interval=4.0``. Signals are intended to renew once per time
	step on average; see ``Signal`` for the exact defaults.

	Args:
		action_history (bool, optional): Whether the action history is
			added to the observation. Defaults to `False`.
		period_range (Optional[Tuple[float, float]], optional): Range
			from which each signal's renewal period is sampled at
			every rollout. Defaults to `None`, in which case each
			signal keeps its default period.
		gamma_beta_range (Optional[Tuple[float, float]], optional):
			Range from which each signal's shape parameter is sampled
			at every rollout. Defaults to `None`, in which case each
			signal keeps its default shape parameter.
		flatten_obs (bool, optional): Whether a flattened aggregation
			of all signals is added to the observation dictionary
			under the `"observation"` key. The per-signal entries are
			not discarded. Defaults to `True`.
		remember (bool, optional): Whether the most recent value of a
			signal that was not seen is kept, together with the time
			elapsed since it was last seen. Defaults to `True`. If
			`False`, unseen signals are reported as `0` and the last
			coordinate of the observation flags whether the signal
			was seen in the last time step.

	Returns:
		gym.Env: The asynchronous Ant environment.
	"""
	base_env = gym.make("Ant-v4", terminate_when_unhealthy=False)

	def unbounded(size: int) -> gym.spaces.Box:
		# Unbounded box of the requested length.
		return gym.spaces.Box(
			low=float("-inf"), high=float("inf"), shape=(size,)
		)

	def leg_signal(
		name: str, first: int
	) -> "Signal[np.ndarray, np.ndarray]":
		# Four-entry signal: two entries at `first`, two at `first + 14`.
		leading = slice(first, first + 2)
		trailing = slice(first + 14, first + 16)
		return Signal(
			unbounded(4),
			lambda obs: np.concatenate([obs[leading], obs[trailing]]),
			max_interval=4.0,
			name=name,
		)

	signals = [
		Signal(
			unbounded(11),
			lambda obs: np.concatenate([obs[:5], obs[13:19]]),
			max_interval=4.0,
			name="torso",
		)
	]
	leg_names = (
		"front_left_leg",
		"front_right_leg",
		"back_left_leg",
		"back_right_leg",
	)
	# Legs occupy consecutive two-entry windows starting at index 5.
	for rank, leg_name in enumerate(leg_names):
		signals.append(leg_signal(leg_name, 5 + 2 * rank))

	signaled_env = awrap.SignaledObservation(base_env, signals)
	return awrap.AsynchronousSignals(
		signaled_env,
		action_history=action_history,
		period_range=period_range,
		gamma_beta_range=gamma_beta_range,
		remember=remember,
		include_flattened_obs=flatten_obs,
	)


def make_humanoid_env(
	action_history: bool = False,
	period_range: Optional[Tuple[float, float]] = None,
	gamma_beta_range: Optional[Tuple[float, float]] = None,
	flatten_obs: bool = True,
	remember: bool = True,
) -> gym.Env:
	"""Build an asynchronous version of ``Humanoid-v4``.

	The base environment is created with
	``terminate_when_unhealthy=False``. Three signals are carved out of
	the first 45 entries of its observation vector:

	* ``torso`` (11 entries): ``0:5`` then ``22:28``;
	* ``upper_half`` (18 entries): ``5:8``, ``28:31``, ``16:22`` then
	  ``39:45``;
	* ``lower_half`` (16 entries): ``8:16`` then ``31:39``.

	No signal reads entries at index ``45`` or beyond. Presumably,
	following Gymnasium's documented layout, the first half of each
	signal holds positions and the second half the matching
	velocities.

	Every signal lives in an unbounded ``Box`` and is created with
	``max_interval=4.0``. Signals are intended to renew once per time
	step on average; see ``Signal`` for the exact defaults.

	Args:
		action_history (bool, optional): Whether the action history is
			added to the observation. Defaults to `False`.
		period_range (Optional[Tuple[float, float]], optional): Range
			from which each signal's renewal period is sampled at
			every rollout. Defaults to `None`, in which case each
			signal keeps its default period.
		gamma_beta_range (Optional[Tuple[float, float]], optional):
			Range from which each signal's shape parameter is sampled
			at every rollout. Defaults to `None`, in which case each
			signal keeps its default shape parameter.
		flatten_obs (bool, optional): Whether a flattened aggregation
			of all signals is added to the observation dictionary
			under the `"observation"` key. The per-signal entries are
			not discarded. Defaults to `True`.
		remember (bool, optional): Whether the most recent value of a
			signal that was not seen is kept, together with the time
			elapsed since it was last seen. Defaults to `True`. If
			`False`, unseen signals are reported as `0` and the last
			coordinate of the observation flags whether the signal
			was seen in the last time step.

	Returns:
		gym.Env: The asynchronous Humanoid environment.
	"""
	base_env = gym.make("Humanoid-v4", terminate_when_unhealthy=False)

	def spliced_signal(
		name: str, *spans: Tuple[int, int]
	) -> "Signal[np.ndarray, np.ndarray]":
		# Signal made of the given (start, stop) windows, concatenated in
		# the order listed; the box length is the total window width.
		width = sum(stop - start for start, stop in spans)
		windows = [slice(start, stop) for start, stop in spans]
		return Signal(
			gym.spaces.Box(
				low=float("-inf"), high=float("inf"), shape=(width,)
			),
			lambda obs: np.concatenate([obs[w] for w in windows]),
			max_interval=4.0,
			name=name,
		)

	signals = [
		spliced_signal("torso", (0, 5), (22, 28)),
		spliced_signal(
			"upper_half", (5, 8), (28, 31), (16, 22), (39, 45)
		),
		spliced_signal("lower_half", (8, 16), (31, 39)),
	]

	signaled_env = awrap.SignaledObservation(base_env, signals)
	return awrap.AsynchronousSignals(
		signaled_env,
		action_history=action_history,
		period_range=period_range,
		gamma_beta_range=gamma_beta_range,
		remember=remember,
		include_flattened_obs=flatten_obs,
	)
