import numpy as np
from tqdm import tqdm

np.random.seed(123)

class TargetSet:
	"""Constraint set for the cost vector, exposing an oracle and a violation check.

	Only ``shape="arc"`` is implemented.  The methods below treat it as the set
	{x : x >= 0, ||x||_2 <= r} with ``r = info["r"]``.  Any other shape raises
	NotImplementedError.
	"""

	def __init__(self, shape, n, info):
		self.shape = shape
		self.n = n  # dimension of the vectors this set accepts
		if shape != "arc":
			raise NotImplementedError
		self.r = info["r"]

	def maximize(self, alpha):
		"""Return a maximizer of <alpha, x> over {x >= 0, ||x||_2 <= r}.

		When no entry of ``alpha`` is positive, the origin is returned.
		Otherwise the positive part of ``alpha`` is rescaled to length r.
		"""
		assert len(alpha) == self.n
		if self.shape != "arc":
			raise NotImplementedError
		if np.max(alpha) <= 0:
			return np.zeros_like(alpha)
		positive_part = np.maximum(alpha, 0)
		return positive_part / np.linalg.norm(positive_part) * self.r

	def eval(self, v):
		"""Return max(||v||_2 - r, 0) for a nonnegative vector ``v``.

		For v >= 0 this is the Euclidean distance from v to the set.
		"""
		assert len(v) == self.n
		if self.shape != "arc":
			raise NotImplementedError
		assert np.min(v) >= 0
		excess = np.linalg.norm(v) - self.r
		return np.maximum(excess, 0)

class PrefFunc:
	"""Preference (scalarization) function f applied to an n-dimensional cost vector.

	Supported ``typ`` values and the ``info`` keys each one reads:
	  * "poly": f(x) = sum_i c * x[i] ** p[i]        (keys "c", "p")
	  * "qua":  f(x) = 1/2 x^T A x + b^T x + c       (keys "A", "b", "c")
	  * "comb": f(x) = f1(x[:n//2]) + f2(x[n//2:])   (keys "f1", "f2")
	Any other ``typ`` raises NotImplementedError.
	"""

	def __init__(self, typ, n, info):
		self.typ = typ
		self.n = n
		if typ == "comb":
			self.f1 = info["f1"]
			self.f2 = info["f2"]
		elif typ == "qua":
			self.A = info["A"]
			self.b = info["b"]
			self.c = info["c"]
			# inverted once here because de_conjug needs A^{-1} on every call
			self.A_inv = np.linalg.inv(self.A)
		elif typ == "poly":
			self.c = info["c"]
			self.p = info["p"]
		else:
			raise NotImplementedError

	def eval(self, x):
		"""Return f(x)."""
		if self.typ == "comb":
			half = self.n // 2
			return self.f1.eval(x[0:half]) + self.f2.eval(x[half:])
		if self.typ == "qua":
			return 1/2 * x.T @ self.A @ x + self.b.T @ x + self.c
		if self.typ != "poly":
			raise NotImplementedError
		total = 0
		for idx in range(self.n):
			total += self.c * x[idx] ** self.p[idx]
		return total

	def de_conjug(self, x):
		"""Return the closed-form convex conjugate f*(x) = sup_z <x, z> - f(z).

		For "poly", each term uses the Hölder-conjugate exponent q = p / (p - 1).
		The form presumably targets x >= 0, since negative entries raised to a
		non-integer q give nan.  For "comb", the conjugates of the two halves
		are added.

		NOTE(review): this returns a scalar (the conjugate's value).
		``dual_update`` in this file subtracts it from the vector ``v``, so it is
		broadcast over every coordinate there.  If the gradient of f* was
		intended, this should return a vector.  Confirm against the algorithm.
		"""
		if self.typ == "comb":
			half = self.n // 2
			return self.f1.de_conjug(x[0:half]) + self.f2.de_conjug(x[half:])
		if self.typ == "qua":
			shifted = x - self.b
			return 1/2 * shifted.T @ self.A_inv @ shifted - self.c
		if self.typ != "poly":
			raise NotImplementedError
		total = 0
		for idx in range(self.n):
			p_i = self.p[idx]
			q_i = p_i / (p_i - 1)
			total += x[idx] ** q_i / (self.c ** (q_i - 1) * p_i ** (q_i - 1))/ q_i
		return total

class Model:
	"""Random finite-horizon tabular MDP with vector-valued, deterministic costs.

	Transitions ``p[h, s, a, s_next]`` and costs ``c[i, h, s, a]`` are drawn
	from the global NumPy RNG.  The pair (card_s - 1, card_a - 1) is then made
	absorbing and cost-free, and the start state is set to card_s - 1.  As a
	result, always taking action card_a - 1 incurs zero cost.
	"""

	def __init__(self, hor, card_s, card_a, dim_c):
		self.hor = hor
		self.card_s = card_s
		self.card_a = card_a
		self.dim_c = dim_c

		# Overwritten below.  Kept only so the global RNG stream, and therefore
		# every later random draw in the program, stays unchanged.
		self.init_s = np.random.choice(np.arange(card_s))

		self.p = np.random.rand(hor, card_s, card_a, card_s)
		# normalize every next-state distribution p[h, s, a, :] to sum to 1
		self.p /= self.p.sum(axis=-1, keepdims=True)

		self.c = np.random.rand(dim_c, hor, card_s, card_a) # deterministic for simplicity

		# optimal:  always take (opt_s, opt_a)
		opt_s = card_s - 1
		opt_a = card_a - 1
		self.init_s = opt_s
		self.p[:, opt_s, opt_a] = 0
		self.p[:, opt_s, opt_a, opt_s] = 1
		self.c[:, :, opt_s, opt_a] = 0

	def eval(self, pi):
		"""Return the exact expected cumulative cost vector of policy ``pi``.

		Args:
			pi: array of shape (hor, card_s, card_a).  pi[h, s, a] is the
				probability of taking action a in state s at step h.

		Returns:
			Array of shape (dim_c,): the sum over steps of the expected cost vector.
		"""
		prob_s = np.zeros(self.card_s)
		prob_s[self.init_s] = 1.0
		sum_c = np.zeros(self.dim_c)
		for h in range(self.hor):
			# occ[s, a] = P(s_h = s, a_h = a)
			occ = prob_s[:, None] * pi[h]
			sum_c += np.einsum("sa,isa->i", occ, self.c[:, h])
			prob_s = np.einsum("sa,sat->t", occ, self.p[h])
		return sum_c

	def sample(self, policy, num_data):
		"""Roll out ``num_data`` trajectories of length ``hor`` from ``init_s``.

		Only ``policy=None`` is supported; it picks actions uniformly at random.
		Any other value raises NotImplementedError.

		Returns:
			A list of trajectories.  Each trajectory is a list of
			(s, a, c, s_next) tuples, where c is the cost vector self.c[:, h, s, a].
		"""
		if policy is not None:
			raise NotImplementedError
		states = np.arange(self.card_s)
		dataset = []
		for _ in tqdm(range(num_data)):
			s = self.init_s
			traj = []
			for h in range(self.hor):
				a = np.random.randint(self.card_a)
				c = self.c[:, h, s, a]
				s_next = np.random.choice(states, p=self.p[h, s, a])
				traj.append((s, a, c, s_next))
				s = s_next
			dataset.append(traj)
		return dataset

def pess_planning(theta, dim_c, p, c, uq_p, uq_c, hor, card_s, card_a):
	"""Run pessimistic (bonus-penalized) backward induction on the estimated model.

	Vector Q-values are scalarized with ``theta``.  In every state the planner
	greedily picks the action with the smallest <theta, Q>.

	Args:
		theta: scalarization weights, shape (dim_c,).
		dim_c: number of cost coordinates.
		p: estimated transitions, shape (hor, card_s, card_a, card_s).
		c: estimated costs, shape (dim_c, hor, card_s, card_a).
		uq_p: transition bonus, shape (hor, card_s, card_a); added to every
			cost coordinate.
		uq_c: cost bonus, shape (dim_c, hor, card_s, card_a).
		hor, card_s, card_a: horizon, number of states, number of actions.

	Returns:
		pi: one-hot deterministic policy, shape (hor, card_s, card_a).
		q: clipped pessimistic Q-vectors, shape (dim_c, hor, card_s, card_a).
		v: value vectors of ``pi`` under ``q``, shape (dim_c, hor, card_s).
	"""
	pi = np.zeros([hor, card_s, card_a])
	q = np.zeros([dim_c, hor, card_s, card_a])
	v = np.zeros([dim_c, hor + 1, card_s])  # v[:, hor] stays 0 (terminal value)
	states = np.arange(card_s)
	for h in range(hor - 1, -1, -1):
		# cost + bonuses + expected next-step value, for every (i, s, a) at once
		q_bar = (
			c[:, h]
			+ uq_p[h][None]
			+ uq_c[:, h]
			+ np.einsum("sat,it->isa", p[h], v[:, h + 1])
		)
		# h is 0-indexed, so steps h..hor-1 remain (hor - h of them).  Assuming
		# per-step costs lie in [0, 1], as Model draws them, no value exceeds
		# hor - h.  The previous bound hor - h + 1 was the 1-indexed formula.
		q[:, h] = np.clip(q_bar, 0, hor - h)

		# scalarize to shape (card_s, card_a); argmin breaks ties toward the lowest action index
		scal_q = np.tensordot(theta, q[:, h], axes=1)
		best_a = np.argmin(scal_q, axis=1)
		pi[h, states, best_a] = 1
		v[:, h] = q[:, h, states, best_a]

	return pi, q, v[:, :-1, :]

def dual_update(alpha, beta, eta, v, target_set, pref_func):
	"""Take one projected dual-ascent step on ``alpha`` and ``beta``.

	Each variable moves by ``eta`` times the gap between ``v`` and its oracle.
	It is then projected back onto the unit Euclidean ball.

	Args:
		alpha: dual variable for the target-set constraint, shape (n,).
		beta: dual variable for the preference function, shape (n,).
		eta: step size.
		v: value/cost vector of the current policy, shape (n,).
		target_set: object providing ``maximize(alpha)``.
		pref_func: object providing ``de_conjug(beta)``.

	Returns:
		(alpha, beta) as new arrays.  The caller's arrays are not modified,
		and integer-dtype inputs are accepted.
	"""
	alpha = alpha + eta * (v - target_set.maximize(alpha))
	alpha_norm = np.linalg.norm(alpha)
	if alpha_norm > 1:
		alpha = alpha / alpha_norm

	# NOTE(review): for the PrefFunc types in this file, de_conjug returns a
	# scalar (the conjugate's value), which is broadcast over v here.  If the
	# gradient of the conjugate was intended, this term should be a vector.
	# Confirm against the algorithm.
	beta = beta + eta * (v - pref_func.de_conjug(beta))
	beta_norm = np.linalg.norm(beta)
	if beta_norm > 1:
		beta = beta / beta_norm

	return alpha, beta

def gen_pos_def_mat(siz, eig_val_lim):
	"""Return a random symmetric PSD matrix whose largest eigenvalue is ``eig_val_lim``.

	B = A A^T for a uniform random A (drawn from the global NumPy RNG).  B is
	positive semidefinite, and almost surely positive definite.  It is rescaled
	so its top eigenvalue equals ``eig_val_lim``.
	"""
	A = np.random.rand(siz, siz)
	B = A @ A.T
	# eigvalsh is the symmetric solver: eigenvalues come back real and sorted
	# ascending.  np.linalg.eig may return a complex array for this input,
	# which would silently make the result complex.
	max_eig_val = np.linalg.eigvalsh(B)[-1]
	return B / max_eig_val * eig_val_lim

# ---- problem size ----
hor, card_s, card_a, dim_c = 5, 5, 5, 6
num_data = 50000  # offline trajectories, collected by m.sample with a uniform policy

# ---- PEDI hyper-parameters ----
num_iter = 100
delta = 0.9   # used inside the log terms of the uncertainty bonuses
eta = 0.01    # dual step size

m = Model(hor, card_s, card_a, dim_c)
target_set = TargetSet(shape="arc", n=dim_c, info={"r": 1})

# Preference function.  The commented lines are alternative choices.
# pref_func = PrefFunc(typ="qua", n=dim_c, info={"A":gen_pos_def_mat(dim_c, 1/(2*hor*dim_c**(1/2))), "b":np.zeros(dim_c), "c":0})
pref_func = PrefFunc(
	typ="poly",
	n=dim_c,
	info={"c": 1/(2*hor**(2-1)*dim_c**(1/2)), "p": [2] * dim_c},
)

# pref_func1 = PrefFunc(typ="qua", n=dim_c//2, info={"A":gen_pos_def_mat(dim_c//2, 1/(2*hor*dim_c**(1/2))), "b":np.zeros(dim_c//2), "c":0})
# pref_func2 = PrefFunc(typ="poly", n=dim_c//2, info={"c":1/(2*hor**(2-1)*dim_c**(1/2)), "p":[2 for _ in range(dim_c//2)]})
# pref_func = PrefFunc(typ="comb", n=dim_c, info={"f1":pref_func1, "f2":pref_func2})

print("=== generate dataset ===")

dataset = m.sample(None, num_data)

print("=== start PEDI ===")

# PEDI: random unit-norm initial dual variables and the induced scalarization.
alpha = np.random.randn(dim_c)
alpha /= np.linalg.norm(alpha)
beta = np.random.randn(dim_c)
beta /= np.linalg.norm(beta)
nu = 3
theta = nu * alpha + beta

# Empirical counts from the offline dataset.
tot = np.zeros([hor, card_s, card_a])            # visits of (h, s, a)
vis = np.zeros([hor, card_s, card_a, card_s])    # visits of (h, s, a, s_next)
sum_c = np.zeros([dim_c, hor, card_s, card_a])   # summed cost vectors per (h, s, a)

for traj in dataset:
	for h, tran in enumerate(traj):
		s, a, c, s_next = tran
		vis[h, s, a, s_next] += 1
		tot[h, s, a] += 1
		sum_c[:, h, s, a] += c

# Unvisited (h, s, a) are counted as visited once so every division below is defined.
tot = np.maximum(tot, 1)

# Empirical model and uncertainty bonuses, computed elementwise for all (h, s, a).
# Each entry goes through exactly the same arithmetic as the former per-entry loops.
est_c = sum_c / tot
est_p = vis / tot[..., None]
bonus_c = np.sqrt(np.log(2 * dim_c * hor * card_s * card_a / delta) / 2 / tot)
uq_c = np.repeat(bonus_c[None], dim_c, axis=0)  # same bonus for every cost coordinate
uq_p = hor * card_s * np.sqrt(np.log(2 * hor * card_s * card_a * card_s / delta) / 2 / tot)

# print(f"min cover: {np.min(tot[1:])}, ave cover: {np.mean(tot[1:])}, max cover: {np.max(tot[1:])}")
# print(f"p min uq: {np.min(uq_p[1:])}, ave uq: {np.mean(uq_p[1:])}, max uq: {np.max(uq_p[1:])}")
# print(f"c min uq: {np.min(uq_c[1:])}, ave uq: {np.mean(uq_c[1:])}, max uq: {np.max(uq_c[1:])}")

# Each iteration plans pessimistically under the current scalarization theta.
# It then takes a dual step using the planner's value vector at the initial state.
pi_pool = []
ave_cost = []  # true expected cost vector (m.eval) of each policy in pi_pool

for t in range(num_iter):
	pi, q, v = pess_planning(theta, dim_c, est_p, est_c, uq_p, uq_c, hor, card_s, card_a)
	alpha, beta = dual_update(alpha, beta, eta, v[:, 0, m.init_s], target_set, pref_func)
	theta = nu * alpha + beta
	pi_pool.append(pi)

	print("%3d" % (t+1), end=" ")
	cost = m.eval(pi)
	# cache it: m.eval is deterministic, so the summary below can reuse it
	ave_cost.append(cost)

	print("%.3f %.3f" % (pref_func.eval(cost), target_set.eval(cost)))

print("==========")
# "Cost" averages the preference value over policies; "Violation" is measured
# on the averaged cost vector.
print("Cost:", np.mean([pref_func.eval(cost) for cost in ave_cost]))
print("Violation:", target_set.eval(np.mean(ave_cost, axis=0)))