import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from learning_control import Differential_Q_Learning, RVI_Q_Learning

# Global plotting style for every figure this script produces: LaTeX text
# rendering, 24pt font, and no top/right axes spines.
# NOTE(review): with text.usetex enabled, matplotlib's LaTeX path only honours
# generic families (serif, sans-serif, ...); "Times New Roman" may be ignored
# with a fallback to serif — confirm the rendered PDFs use the intended font.
mpl.rcParams.update({
	"font.family": "Times New Roman",
	"font.size": '24',
	"text.usetex": True,
	"axes.spines.right": False,
	"axes.spines.top": False,
})


def exp(P, r, num_runs, file_name, f, b_policy, r_bar, lr, init_q_diff, init_q_rvi, total_steps, plot_interval):
	"""Run Differential and RVI Q-learning side by side and plot their value trajectories.

	For each of ``num_runs`` runs, fresh learners are built from copies of the
	initial action-value tables and stepped ``total_steps`` times. Every
	``plot_interval`` steps the per-state greedy values (``q.max(1)``) are
	recorded. Each run's trajectory is drawn as a dotted curve on a two-panel
	figure (left: Differential Q-learning, right: RVI Q-learning), which is
	saved to ``file_name`` and then closed.

	Args:
		P, r: MDP description, forwarded unchanged to both learners
			(presumably transition probabilities and rewards).
		num_runs: number of runs; one curve per run on each panel.
		file_name: output path for ``savefig``. The exact name
			'communicating_mdp_value_dynamics.pdf' also selects plotting all
			states plus an extra reference segment on the left panel; any other
			name plots only rows 1 onward of the Q table (``q[1:]``).
		f: reference function forwarded to ``RVI_Q_Learning``.
		b_policy: behavior policy forwarded to both learners.
		r_bar: forwarded to ``Differential_Q_Learning`` (callers pass an
			initial reward-rate estimate named ``init_r_bar``).
		lr: step size forwarded to both learners.
		init_q_diff, init_q_rvi: initial action-value tables. They are copied
			before use, so in-place updates by the learners never touch the
			caller's arrays.
		total_steps: learner steps per run; must be >= 1.
		plot_interval: record a snapshot every this many steps; must be >= 1.

	Raises:
		ValueError: if ``total_steps`` or ``plot_interval`` is smaller than 1.
	"""
	if total_steps < 1 or plot_interval < 1:
		# Without at least one snapshot the trajectory arrays would be empty and
		# the summary print / plotting below would fail with an opaque error.
		raise ValueError("total_steps and plot_interval must both be >= 1")

	# The communicating-MDP experiment is identified by its output file name.
	full_state = file_name == 'communicating_mdp_value_dynamics.pdf'
	# Rows of the Q table whose greedy values are tracked.
	rows = slice(None) if full_state else slice(1, None)

	fig, axs = plt.subplots(1, 2, figsize=(18, 8))
	try:
		# Shaded band |y - x| <= 1 over x in [0, 6] on both panels, plus
		# fixed black reference segments.
		axs[0].fill_between(np.array([0, 6]), np.array([1, 7]), np.array([-1, 5]), color='g', alpha=0.1)
		if full_state:
			axs[0].plot(np.array([1, 2]), np.array([2, 1]), 'k', linewidth=2)
		axs[1].fill_between(np.array([0, 6]), np.array([1, 7]), np.array([-1, 5]), color='g', alpha=0.1)
		axs[1].plot(np.array([1, 3]), np.array([2, 2]), 'k', linewidth=2)

		for run in range(num_runs):
			results = []
			results_rvi = []
			cls_diff = Differential_Q_Learning(P, r, init_q_diff.copy(), r_bar, lr, b_policy)
			cls_rvi = RVI_Q_Learning(P, r, init_q_rvi.copy(), lr, f, b_policy)
			for i in range(total_steps):
				if i % plot_interval == 0:
					# .max(1) allocates a new array, so later steps cannot alter
					# snapshots already recorded.
					results.append(cls_diff.q[rows].max(1))
					results_rvi.append(cls_rvi.q[rows].max(1))
				cls_diff.step()
				cls_rvi.step()
			print("run: ", run, " ", results[-1], results_rvi[-1])
			data = np.array(results)
			data_rvi = np.array(results_rvi)
			# Trajectory in the plane of the first two tracked states' greedy values.
			axs[0].plot(data[:, 0], data[:, 1], ":", linewidth=2)
			axs[1].plot(data_rvi[:, 0], data_rvi[:, 1], ":", linewidth=2)

		for ax, title in zip(axs, ('Differential Q-Learning', 'RVI Q-Learning')):
			ax.set_title(title)
			ax.set_xlabel("max Q(1, )")
			ax.set_ylabel("max Q(2, )", rotation='horizontal', ha='right', va="top")
			ax.set_xlim(0, 6)
			ax.set_ylim(0, 6)
		fig.savefig(file_name, bbox_inches='tight')
	finally:
		# Release the figure even on error; otherwise every call leaves one
		# open in pyplot's figure registry.
		plt.close(fig)


def ref_action_value(q, ref_s, ref_a):
	"""Return the single entry ``q[ref_s, ref_a]`` of the action-value table.

	This file uses it, via a closure, as the reference function ``f`` handed to
	RVI Q-learning.
	"""
	index = (ref_s, ref_a)
	return q[index]


def communicating_mdp_exp():
	"""Plot value dynamics of both learners on a two-state, two-action MDP.

	Both action-value tables start at zero and the differential learner's
	``r_bar`` starts at -3. RVI Q-learning uses Q(0, 1) as its reference value.
	The figure is written to 'communicating_mdp_value_dynamics.pdf'.
	"""
	n_states, n_actions = 2, 2
	# Presumably indexed [state, action, next_state]; each innermost row sums
	# to 1 — TODO confirm against learning_control.
	P = np.array([
		[[1, 0], [0, 1]],
		[[0, 1], [1, 0]],
	])
	r = np.array([[1, 0], [1, 0]])
	# Same behavior distribution (0.8 / 0.2 over the two actions) in every state.
	b_policy = np.array([[0.8, 0.2]] * n_states)
	init_q = np.zeros((n_states, n_actions))
	init_q_rvi = np.zeros((n_states, n_actions))

	def reference(q):
		# RVI Q-learning's reference value: Q(0, 1).
		return ref_action_value(q, 0, 1)

	exp(
		P=P,
		r=r,
		num_runs=10,
		file_name="communicating_mdp_value_dynamics.pdf",
		f=reference,
		b_policy=b_policy,
		r_bar=-3,
		lr=0.1,
		init_q_diff=init_q,
		init_q_rvi=init_q_rvi,
		total_steps=1001,
		plot_interval=10,
	)


def weakly_communicating_mdp_exp():
	"""Plot value dynamics of both learners on a three-state, two-action MDP.

	Both action-value tables start at zero and the differential learner's
	``r_bar`` starts at -3. RVI Q-learning uses Q(1, 1) as its reference value.
	The figure is written to 'weakly_communicating_mdp_value_dynamics.pdf'.
	"""
	n_states, n_actions = 3, 2
	# Presumably indexed [state, action, next_state]. Under that reading,
	# states 1 and 2 never transition back to state 0 — TODO confirm.
	P = np.array([
		[[0.9, 0.0, 0.1], [0.9, 0.1, 0.0]],
		[[0, 1, 0], [0, 0, 1]],
		[[0, 0, 1], [0, 1, 0]],
	])
	r = np.array([[-5, 5], [1, 0], [1, 0]])
	# Same behavior distribution (0.8 / 0.2 over the two actions) in every state.
	b_policy = np.array([[0.8, 0.2]] * n_states)
	init_q = np.zeros((n_states, n_actions))
	init_q_rvi = np.zeros((n_states, n_actions))

	def reference(q):
		# RVI Q-learning's reference value: Q(1, 1).
		return ref_action_value(q, 1, 1)

	exp(
		P=P,
		r=r,
		num_runs=10,
		file_name="weakly_communicating_mdp_value_dynamics.pdf",
		f=reference,
		b_policy=b_policy,
		r_bar=-3,
		lr=0.1,
		init_q_diff=init_q,
		init_q_rvi=init_q_rvi,
		total_steps=1001,
		plot_interval=10,
	)


if __name__ == '__main__':
	# Seed NumPy's global RNG once, then run both experiments back to back.
	# If the learners draw from that global RNG (not visible here), the second
	# experiment's results also depend on running this order.
	np.random.seed(0)
	for run_experiment in (communicating_mdp_exp, weakly_communicating_mdp_exp):
		run_experiment()
