import numpy as np


class Differential_Q_Learning(object):
	"""Tabular Differential Q-learning agent driven by a fixed behaviour policy.

	Each call to ``step`` samples one transition from the model ``P``/``r``
	and applies the differential TD update to both the action-value table
	``q`` and the reward-rate estimate ``g``.

	Args:
		P: Transition probabilities indexed ``P[s][a]`` -> probability vector
			over next states; ``P.shape[2]`` is taken as the number of states.
		r: Rewards indexed ``r[s, a]``.
		q: Action-value table indexed ``q[s, a]``. It is updated in place (the
			caller's array is modified), so it should have a float dtype.
		g: Initial reward-rate estimate; it is subtracted from every reward
			in the TD error.
		lr: Step size for the action-value update.
		b_policy: Behaviour policy; ``b_policy[s]`` is the vector of action
			probabilities in state ``s``. Its length sets how many actions
			are sampled from.
		eta: Multiplier on ``lr`` for the reward-rate update (the
			``eta * alpha`` step size of Differential Q-learning). The default
			1.0 makes ``g`` use the same step size as ``q``.
	"""

	def __init__(self, P, r, q, g, lr, b_policy, eta=1.0):
		self.P = P
		self.r = r
		self.q = q
		self.g = g
		self.lr = lr
		self.eta = eta
		# h and f are not read by step(); kept so existing callers that
		# access these attributes keep working.
		self.h = 0
		self.f = np.zeros_like(q)
		self.curr_s = 0  # every run starts in state 0
		self.t = 0  # number of step() calls so far
		self.b_policy = b_policy

	def step(self):
		"""Sample one transition, update ``q`` and ``g``, and return the reward.

		Returns:
			The reward ``r[s, a]`` of the sampled state-action pair.
		"""
		s = self.curr_s
		action_probs = self.b_policy[s]
		# Sample over as many actions as the policy row defines instead of a
		# hard-coded pair {0, 1}.
		a = np.random.choice(len(action_probs), p=action_probs)
		next_s = np.random.choice(self.P.shape[2], p=self.P[s][a])
		# Differential TD error: the reward is centred by the current
		# reward-rate estimate g; there is no discount factor.
		delta = self.r[s, a] - self.g + np.max(self.q[next_s]) - self.q[s, a]
		self.q[s, a] = self.q[s, a] + self.lr * delta
		self.g = self.g + self.eta * self.lr * delta
		self.curr_s = next_s
		self.t += 1
		return self.r[s, a]


class RVI_Q_Learning(object):
	"""Tabular RVI Q-learning agent driven by a fixed behaviour policy.

	Each call to ``step`` samples one transition from the model ``P``/``r``
	and updates the action-value table ``q``. Instead of a learned reward-rate
	estimate, the scalar ``f(q)`` is subtracted in every TD error.

	Args:
		P: Transition probabilities indexed ``P[s][a]`` -> probability vector
			over next states; ``P.shape[2]`` is taken as the number of states.
		r: Rewards indexed ``r[s, a]``.
		q: Action-value table indexed ``q[s, a]``. It is updated in place (the
			caller's array is modified), so it should have a float dtype.
		lr: Step size for the action-value update.
		f: Callable invoked as ``f(q)`` on the whole table each step; its
			scalar result is subtracted in the TD error.
		b_policy: Behaviour policy; ``b_policy[s]`` is the vector of action
			probabilities in state ``s``. Its length sets how many actions
			are sampled from.
	"""

	def __init__(self, P, r, q, lr, f, b_policy):
		self.P = P
		self.r = r
		self.q = q
		self.lr = lr
		self.f = f
		self.curr_s = 0  # every run starts in state 0
		self.t = 0  # number of step() calls so far
		self.b_policy = b_policy

	def step(self):
		"""Sample one transition, update ``q``, and return the reward.

		Returns:
			The reward ``r[s, a]`` of the sampled state-action pair.
		"""
		s = self.curr_s
		action_probs = self.b_policy[s]
		# Sample over as many actions as the policy row defines instead of a
		# hard-coded pair {0, 1}.
		a = np.random.choice(len(action_probs), p=action_probs)
		next_s = np.random.choice(self.P.shape[2], p=self.P[s][a])
		# f(q) is evaluated on the table as it stands before this update.
		delta = self.r[s, a] + np.max(self.q[next_s]) - self.q[s, a] - self.f(self.q)
		self.q[s, a] = self.q[s, a] + self.lr * delta
		self.curr_s = next_s
		self.t += 1
		return self.r[s, a]