import numpy as np


class ORBITAgent:
	"""Tabular robust Q-learning agent with regularized TV backups (original version).

	Each iteration of :meth:`train` rolls out the current policy in ``env``
	(accumulating visit counts and rewards), rebuilds the empirical model,
	then plans backwards over the horizon with a count-based bonus and a
	TV-regularized robust backup at every step.

	The environment must expose ``nS``, ``nA``, ``nH``, ``reset()`` returning
	an integer state, and ``step(a)`` returning ``(next_state, reward)``.
	"""

	def __init__(self, env, est_type, est_div, lambda_val, step_size, n_episodes=1000, seed=0, q_max=3):
		"""Set up counters, model estimates, value tables and a random initial policy.

		Args:
			env: environment providing ``nS``, ``nA``, ``nH``, ``reset`` and ``step``.
			est_type: estimator type; only ``"regularized"`` is supported.
			est_div: divergence; only ``"TV"`` is supported.
			lambda_val: regularization strength, used as ``rho`` in the robust backup.
			step_size: scale of the exploration bonus ``step_size / sqrt(max(N, 1))``.
			n_episodes: number of explore/estimate/plan iterations run by :meth:`train`.
			seed: seed applied to the global NumPy RNG.
			q_max: upper clip applied to every Q estimate. Defaults to 3, the value
				that was previously hard-coded; presumably it should equal the maximum
				achievable return (e.g. ``nH`` for rewards in [0, 1]) -- confirm per env.
		"""
		assert est_type == "regularized", "only est_type='regularized' is supported"
		assert est_div == "TV", "only est_div='TV' is supported"
		self.env = env
		self.est_type = est_type
		self.est_div = est_div
		self.lambda_val = lambda_val
		self.step_size = step_size
		self.n_episodes = n_episodes
		self.q_max = q_max

		self.nS = env.nS
		self.nA = env.nA
		self.nH = env.nH

		# Global NumPy RNG seeded externally; keep initialization deterministic per seed.
		np.random.seed(seed)

		self.cnt_SA = np.zeros((self.nS, self.nA))  # visit counts N(s, a)
		self.cnt_SAS = np.zeros((self.nS, self.nA, self.nS))  # transition counts N(s, a, s')
		self.Pk = np.zeros((self.nS, self.nA, self.nS))  # empirical transition model

		self.Qk = np.zeros((self.nH, self.nS, self.nA))
		self.Vk = np.zeros((self.nH + 1, self.nS))  # Vk[nH] is the terminal value (0)

		self.Rk = np.zeros((self.nS, self.nA))  # summed observed rewards
		self.rk = np.zeros((self.nS, self.nA))  # mean reward estimates
		# Drawn one scalar at a time (h outer, s inner) so the global RNG stream
		# consumed here matches earlier runs with the same seed.
		self.pi = np.array(
			[
				[np.random.randint(0, self.nA) for _ in range(self.nS)]
				for _ in range(self.nH)
			]
		)

	def _explore_episode(self):
		"""Roll out ``pi`` for ``nH`` steps from ``env.reset()``, updating counts and reward sums."""
		s = self.env.reset()
		for h in range(self.nH):
			a = self.pi[h][s]
			s_next, r = self.env.step(a)
			self.cnt_SA[s][a] += 1
			self.cnt_SAS[s][a][s_next] += 1
			self.Rk[s][a] += r
			s = s_next

	def _calculate_Pk(self):
		"""Rebuild the empirical transition model ``Pk`` and mean rewards ``rk``.

		Unvisited (s, a) pairs fall back to a uniform next-state distribution and
		zero reward. Arrays are updated in place.
		"""
		visited = self.cnt_SA > 0
		# Avoid dividing by zero; the unvisited entries are replaced by the fallback anyway.
		safe_cnt = np.where(visited, self.cnt_SA, 1)
		self.Pk[...] = np.where(
			visited[:, :, None], self.cnt_SAS / safe_cnt[:, :, None], 1 / self.nS
		)
		self.rk[...] = np.where(visited, self.Rk / safe_cnt, 0)

	def _estimate_Qk(self, P, r, V, N, robust=True):
		"""Return an optimistic, optionally robust, Q estimate for one (s, a) pair.

		Args:
			P: next-state distribution for (s, a), one entry per state.
			r: mean reward estimate for (s, a).
			V: next-step state values, aligned with ``P``.
			N: visit count of (s, a); drives the exploration bonus.
			robust: if True use the regularized robust backup, otherwise the
				nominal expectation ``sum(P * V)``.

		Returns:
			``r + backup + bonus`` clipped above at ``self.q_max``.
		"""
		rho = self.lambda_val
		q_max = self.q_max
		bonus = self.step_size / np.sqrt(np.maximum(N, 1))
		# Only next states with positive probability take part in the backup.
		support = np.where(P > 0)
		P, V = P[support], V[support]
		if len(P) == 1:
			return np.minimum(r + V[0] + bonus, q_max)
		if not robust:
			return np.minimum(r + np.sum(P * V) + bonus, q_max)

		v_min = np.min(V)
		# Since min(V, c) = c - max(c - V, 0), this computes
		# r + E_P[min(V, v_min + rho)] + bonus; presumably the dual form of the
		# TV-regularized worst case restricted to supp(P) -- confirm against the paper.
		penalty = np.sum(P * np.maximum(v_min + rho - V, 0))
		Q = r - penalty + v_min + rho + bonus
		return np.minimum(np.maximum(Q, r + bonus), q_max)

	def _plan(self):
		"""Backward induction over the horizon, refreshing ``Qk``, ``Vk`` and greedy ``pi``.

		Ties between maximizing actions are broken uniformly at random with the
		global NumPy RNG.
		"""
		self.Vk[self.nH] = 0
		for h in range(self.nH - 1, -1, -1):
			V_next = self.Vk[h + 1]
			for s in range(self.nS):
				for a in range(self.nA):
					# The robust backup is applied at every step of the horizon,
					# as the class contract promises.
					self.Qk[h][s][a] = self._estimate_Qk(
						self.Pk[s][a], self.rk[s][a], V_next, self.cnt_SA[s][a], robust=True
					)
				best = np.max(self.Qk[h][s])
				idxs = np.flatnonzero(self.Qk[h][s] == best)
				self.pi[h][s] = np.random.choice(idxs)
				self.Vk[h][s] = self.Qk[h][s][self.pi[h][s]]

	def train(self):
		"""Run ``n_episodes`` explore/estimate/plan iterations and return a copy of ``pi``.

		The returned array has shape ``(nH, nS)`` and holds the chosen action per step and state.
		"""
		for _ in range(self.n_episodes):
			self._explore_episode()
			self._calculate_Pk()
			self._plan()
		return self.pi.copy()


def evaluate_policy(env, pi, perturb=0.0, episodes=500):
	"""Estimate the average per-episode return of a fixed policy by Monte Carlo rollouts.

	Args:
		env: environment exposing ``nH``, ``reset()`` and ``step(action, perturb)``
			returning ``(next_state, reward)``.
		pi: policy indexed as ``pi[step][state]`` giving the action to take.
		perturb: forwarded unchanged as the second argument of ``env.step``; its
			meaning is defined by the environment (presumably a perturbation level).
		episodes: number of rollouts to average over.

	Returns:
		The summed reward over all rollouts divided by ``episodes``.
	"""
	accumulated = 0
	for _episode in range(episodes):
		state = env.reset()
		for step in range(env.nH):
			action = pi[step][state]
			state, reward = env.step(action, perturb)
			accumulated += reward
	return accumulated / episodes
