import numpy as np


class RobustQAgent:
	"""Model-free robust regularized Q-learning with UCB-Hoeffding (Algorithm 1).

	The agent keeps tabular, per-step estimates ``Q[h, s, a]`` and ``V[h, s]``
	together with visitation counts ``N_h[h, s, a]`` and a deterministic
	policy table ``pi[h, s]``.

	The environment object must expose integer attributes ``nS``, ``nA`` and
	``nH``, a ``reset()`` method returning the initial state index, and a
	``step(a)`` method returning ``(next_state, reward)``.

	Args:
		env: Environment satisfying the interface described above.
		lambda_val: Regularization level; it bounds the value clip
			``min(1 + lambda_val, nH)`` and caps the continuation value
			used in the first-step target.
		c: Scaling constant of the exploration bonus.
		n_episodes: Number of episodes run in *each* of the two training
			phases (random warm-up, then greedy).
		seed: Seed passed to ``np.random.seed``. Note that this reseeds
			NumPy's global legacy generator.
	"""

	def __init__(self, env, lambda_val, c, n_episodes=1000, seed=0):
		self.env = env
		self.lambda_val = lambda_val
		self.c = c  # bonus scaling constant
		self.n_episodes = n_episodes
		# NOTE(review): this attribute is never read in this class -- the
		# warm-up phase of train() runs n_episodes episodes. Kept so external
		# readers of the attribute keep working; confirm the intended length.
		self.warmup_episodes = 100

		self.nS = env.nS
		self.nA = env.nA
		self.nH = env.nH

		np.random.seed(seed)

		# Upper bound applied to every Q and V estimate.
		self.value_clip = min(1 + self.lambda_val, self.nH)

		# Per-step visitation counts and value estimates, all starting at zero.
		self.N_h = np.zeros((self.nH, self.nS, self.nA))
		self.Q = np.zeros((self.nH, self.nS, self.nA))
		# V has one extra row: V[nH] is the terminal value and is never
		# written, so it stays at zero.
		self.V = np.zeros((self.nH + 1, self.nS))

		# Random initial policy. Drawn one scalar at a time so the sequence
		# of calls to the global generator is the same for a given seed.
		self.pi = np.array(
			[
				[np.random.randint(0, self.nA) for _ in range(self.nS)]
				for _ in range(self.nH)
			]
		)

	def _alpha(self, t):
		"""Return the step size ``(H + 1) / (H + t)`` for visit count ``t``.

		``t`` is floored at 1, so a count of zero yields a step size of 1.
		"""
		return (self.nH + 1) / (self.nH + max(t, 1))

	def _bonus(self, t):
		"""Return the exploration bonus ``c / sqrt(t)`` for visit count ``t``.

		``t`` is floored at 1 to avoid dividing by zero.

		NOTE(review): an earlier comment here gave the bonus as
		``c * min{1 + lambda, H} * sqrt(H / t)``. The code has no such
		factor; presumably it is meant to be folded into ``c`` -- confirm.
		"""
		return self.c / np.sqrt(max(t, 1))

	def _update(self, h, s, a, r, s_next):
		"""Apply one optimistic backup for the transition ``(h, s, a, r, s_next)``.

		Increments the visit count, then moves ``Q[h, s, a]`` towards
		``r + V[h + 1, s_next] + bonus`` with the count-dependent step size,
		clips it at ``value_clip`` and refreshes ``V[h, s]``.
		"""
		self.N_h[h][s][a] += 1
		t = int(self.N_h[h][s][a])
		alpha_t = self._alpha(t)
		b_t = self._bonus(t)

		next_value = self.V[h + 1][s_next]
		if h == 0:
			# Only at the first step is the continuation value capped at
			# lambda_val. NOTE(review): presumably this is the regularized
			# robust backup of the algorithm -- confirm against the paper.
			next_value = min(next_value, self.lambda_val)
		target = r + next_value + b_t

		self.Q[h][s][a] = min(
			(1 - alpha_t) * self.Q[h][s][a] + alpha_t * target,
			self.value_clip,
		)
		self.V[h][s] = min(np.max(self.Q[h][s]), self.value_clip)

	def train(self):
		"""Run both training phases and return a copy of the learned policy.

		Phase 1 runs ``n_episodes`` episodes with uniformly random actions.
		Phase 2 runs another ``n_episodes`` episodes following the current
		policy table, making it greedy with respect to ``Q`` at each visited
		``(h, s)`` right after the update.

		Returns:
			An integer array indexed as ``pi[h][s]`` giving the action of
			the final policy.
		"""
		# Phase 1: uniform random exploration.
		for _ in range(self.n_episodes):
			s = self.env.reset()
			for h in range(self.nH):
				a = np.random.randint(self.nA)
				s_next, r = self.env.step(a)
				self._update(h, s, a, r, s_next)
				s = s_next

		# Phase 2: act with the current policy and improve it greedily.
		for _ in range(self.n_episodes):
			s = self.env.reset()
			for h in range(self.nH):
				a = self.pi[h][s]
				s_next, r = self.env.step(a)
				self._update(h, s, a, r, s_next)
				# argmax breaks ties in favour of the lowest action index.
				self.pi[h][s] = np.argmax(self.Q[h][s])
				s = s_next

		return self.pi.copy()


def evaluate_policy(env, pi, perturb=0.0, episodes=500):
	"""Estimate the average episodic return of a fixed policy by Monte Carlo.

	Each episode starts from ``env.reset()`` and lasts ``env.nH`` steps. At
	step ``h`` in state ``s`` the action ``pi[h][s]`` is passed to
	``env.step`` together with ``perturb``.

	Args:
		env: Environment exposing ``nH``, ``reset()`` and
			``step(action, perturb)`` returning ``(next_state, reward)``.
		pi: Policy table indexed as ``pi[h][s]``.
		perturb: Value forwarded unchanged to every ``env.step`` call.
		episodes: Number of rollouts to average over.

	Returns:
		The sum of all collected rewards divided by ``episodes``.
	"""
	cumulative_reward = 0
	for _episode in range(episodes):
		state = env.reset()
		for step_idx in range(env.nH):
			action = pi[step_idx][state]
			state, reward = env.step(action, perturb)
			cumulative_reward += reward
	return cumulative_reward / episodes
