import numpy as np


class SimulatedMDP:
	"""Toy tabular MDP used for robust Q-learning experiments."""

	def __init__(self, seed=None):
		self.nS = 5
		self.nA = 5
		self.nH = 3
		self._state = 0
		if seed is not None:
			np.random.seed(seed)

	def reset(self):
		self._state = 0
		return self._state

	def step(self, action, perturb=0.0):
		"""Takes an action and returns (next_state, reward)."""
		prob = [0] * 5
		match self._state:
			case 0:
				prob = [
					0,
					0.4 + action / 10,
					0,
					0.1 + perturb * (0.5 - action / 10),
					(1 - perturb) * (0.5 - action / 10),
				]
			case 1:
				prob = [0, 0, action / 10, 1 - action / 10, 0]
			case 2:
				prob = [0, 0, 0, 1 - action / 10, action / 10]
			case 3:
				prob = [0, 0, 0, 1, 0]
			case 4:
				prob = [0, 0, 0, 0, 1]

		reward = action / 20 if self._state not in [3, 4] else (self._state == 4)
		self._state = np.random.choice(range(0, 5), p=prob)
		return self._state, reward
