import numpy as np
from scipy.special import logsumexp
from scipy.special import expit
from copy import copy

class EpsGreedyPolicy():
	"""Epsilon-greedy policy over options driven by a tabular Q_Omega.

	With probability ``epsilon`` a uniformly random option index is drawn;
	otherwise the option with the highest Q_Omega value in the given state is
	returned (ties resolved by ``np.argmax``, i.e. lowest index first).
	"""

	def __init__(self, rng, nstates, noptions, epsilon):
		# rng must provide uniform() and randint(n), e.g. np.random.RandomState.
		self.rng = rng
		self.nstates = nstates
		self.noptions = noptions
		self.epsilon = epsilon
		# Q_Omega_table[s, o]: value estimate for option o in state s.
		self.Q_Omega_table = np.zeros((nstates, noptions))

	def Q_Omega(self, state, option=None):
		"""Return the whole row Q_Omega[state, :] or the entry Q_Omega[state, option]."""
		if option is not None:
			return self.Q_Omega_table[state, option]
		return self.Q_Omega_table[state, :]

	def greedy_option(self, state):
		"""Return the index (as a Python int) of the highest-valued option in ``state``."""
		return int(np.argmax(self.Q_Omega(state)))

	def sample(self, state):
		"""Draw an option: uniformly random with probability epsilon, greedy otherwise."""
		# The uniform draw happens on every call, before either branch.
		explore = self.rng.uniform() < self.epsilon
		if not explore:
			return self.greedy_option(state)
		return int(self.rng.randint(self.noptions))


class SoftmaxPolicy():
	"""Tabular softmax policy over primitive actions.

	``weights[s, a]`` are action preferences; the action distribution in
	state s is softmax(weights[s, :] / temperature).
	"""

	def __init__(self, rng, lr, nstates, nactions, temperature=1.0):
		# rng must provide choice(n, p=...), e.g. np.random.RandomState.
		self.rng = rng
		self.lr = lr
		self.nstates = nstates
		self.nactions = nactions
		self.temperature = temperature
		self.weights = np.zeros((nstates, nactions))

	def Q_U(self, state, action=None):
		"""Return the preference row for ``state``, or the single entry for ``action``."""
		if action is None:
			return self.weights[state,:]
		else:
			return self.weights[state, action]

	# probability mass function (using softmax)
	def pmf(self, state):
		"""Return the softmax action probabilities for ``state`` (1-D, sums to 1)."""
		exponent = self.Q_U(state) / self.temperature
		# Subtracting logsumexp keeps exp() in range; the division removes
		# residual rounding so the vector sums to 1.
		prob = np.exp(exponent - logsumexp(exponent))
		prob /= prob.sum()
		return prob

	# get P(s, a_prob)
	def pmf_(self):
		"""Return the (nstates, nactions) matrix of per-state softmax probabilities.

		Each row is normalised with its own logsumexp. A single logsumexp over
		the whole table would make rows whose preferences lie far below the
		table's maximum underflow to all zeros, and the row normalisation
		would then produce NaN.
		"""
		exponent = self.weights / self.temperature
		prob = np.exp(exponent - logsumexp(exponent, axis=1, keepdims=True))
		prob /= np.sum(prob, axis=1, keepdims=True)
		return prob

	def sample(self, state):
		"""Draw an action index (Python int) from pmf(state) using self.rng."""
		return int(self.rng.choice(self.nactions, p=self.pmf(state)))

	def gradient(self):
		"""Placeholder: not implemented, returns None."""
		pass

	def update(self, state, action, Q_U, KL_grid=0):
		"""Adjust the preferences of ``state`` in place.

		The first two steps add lr * Q_U * (onehot(action) - pmf(state)) to
		weights[state, :]. For temperature 1 this equals
		lr * Q_U * d log pi(action|state) / d weights[state, :]. The
		temperature is not applied to the step. ``KL_grid`` (a scalar or
		anything broadcastable to a row) is then added to the whole row,
		scaled by lr.
		"""
		# Probabilities are taken before any weight is modified.
		actions_pmf = self.pmf(state)
		self.weights[state, :] -= self.lr * actions_pmf * Q_U
		self.weights[state, action] += self.lr * Q_U
		self.weights[state, :] += self.lr * KL_grid


class SigmoidTermination():
	"""Tabular termination function beta(s) = sigmoid(weights[s])."""

	def __init__(self, rng, lr, nstates):
		# rng must provide uniform(), e.g. np.random.RandomState.
		self.rng = rng
		self.lr = lr
		self.nstates = nstates
		# One termination logit per state.
		self.weights = np.zeros((nstates,))

	def pmf(self, state):
		"""Return the probability of terminating in ``state``."""
		logit = self.weights[state]
		return expit(logit)

	def sample(self, state):
		"""Return 1 (terminate) with probability pmf(state), else 0."""
		draw = self.rng.uniform()
		terminate = draw < self.pmf(state)
		return int(terminate)

	def gradient(self, state):
		"""Return (d beta / d logit, index of that logit) for ``state``."""
		beta = self.pmf(state)
		return beta * (1.0 - beta), state

	def update(self, state, advantage):
		"""Shift the logit of ``state`` by -lr * dbeta * advantage.

		A positive advantage lowers the logit (termination becomes less
		likely); a negative advantage raises it.
		"""
		scale, index = self.gradient(state)
		self.weights[index] -= self.lr * scale * advantage


class Critic():
	"""Tabular critic learning Q_Omega(s, o) and Q_U(s, o, a) by one-step TD.

	``Q_Omega_table`` is supplied by the caller and updated in place, so any
	other holder of the same array observes the critic's updates.
	"""

	def __init__(self, lr, gamma, option_policies, delib, discount, Q_Omega_table, nstates, noptions, nactions):
		self.lr = lr
		# gamma only scales the value returned by KL_grid (discount is separate).
		self.gamma = gamma
		# Indexable by option; must support .copy() and .remove() (used in KL_grid).
		self.option_policies = option_policies
		# delib is added to every entry of A_Omega.
		self.delib = delib
		self.discount = discount
		self.Q_Omega_table = Q_Omega_table
		self.Q_U_table = np.zeros((nstates, noptions, nactions))

	def cache(self, state, option, action):
		"""Remember the (state, option, action) that the next update_Qs corrects."""
		self.last_state, self.last_option, self.last_action = state, option, action
		self.last_Q_Omega = self.Q_Omega(state, option)

	def Q_Omega(self, state, option=None):
		"""Return the row Q_Omega[state, :] or the entry Q_Omega[state, option]."""
		if option is not None:
			return self.Q_Omega_table[state, option]
		return self.Q_Omega_table[state, :]

	def Q_U(self, state, option, action):
		"""Return the scalar Q_U[state, option, action]."""
		return self.Q_U_table[state, option, action]

	def A_Omega(self, state, option=None):
		"""Return Q_Omega[state, :] - max(Q_Omega[state, :]) + delib (or one entry of it)."""
		q_row = self.Q_Omega(state)
		advantage = q_row - np.max(q_row) + self.delib
		return advantage if option is None else advantage[option]

	def KL_grid(self,state):
		"""Return gamma * sum over the other option policies of
		-sum_a pi_other(a|s) * (1 - pi_current(a|s)).

		"current" is the policy of self.last_option; the first matching entry
		is removed from a copy of option_policies to form "other". Despite the
		name, the quantity computed is not a KL divergence.
		"""
		current = self.option_policies[self.last_option]
		others = self.option_policies.copy()
		others.remove(current)
		total = sum(-np.sum(other.pmf(state) * (1 - current.pmf(state))) for other in others)
		return self.gamma * total

	def update_Qs(self, state, option, action, reward, done, terminations, new_term_table=None):
		"""One-step TD update of Q_Omega and Q_U for the cached transition.

		The target is ``reward`` when ``done``; otherwise
		reward + discount * ((1 - beta) * Q_Omega[state, last_option]
		                     + beta * max_o Q_Omega[state, o]),
		where beta is new_term_table[last_option, state] if that table is
		given, else terminations[last_option].pmf(state). Afterwards
		(state, option, action) becomes the cached transition;
		last_Q_Omega is refreshed only when not done.
		"""
		target = reward
		if not done:
			if new_term_table is not None:
				beta_omega = new_term_table[self.last_option, state]
			else:
				beta_omega = terminations[self.last_option].pmf(state)
			stay = (1.0 - beta_omega) * self.Q_Omega(state, self.last_option)
			switch = beta_omega * np.max(self.Q_Omega(state))
			target += self.discount * (stay + switch)

		# Move both tables toward the shared target.
		tderror_Q_Omega = target - self.last_Q_Omega
		self.Q_Omega_table[self.last_state, self.last_option] += self.lr * tderror_Q_Omega
		tderror_Q_U = target - self.Q_U(self.last_state, self.last_option, self.last_action)
		self.Q_U_table[self.last_state, self.last_option, self.last_action] += self.lr * tderror_Q_U

		# Roll the cache forward to the new transition.
		self.last_state, self.last_option, self.last_action = state, option, action
		if not done:
			self.last_Q_Omega = self.Q_Omega(state, option)

class Low_Critic():
	"""Tabular critic with its own Q_Omega/Q_U tables.

	It bootstraps from the option that ``policy_over_options.greedy_option``
	picks and from the termination probabilities in ``terminations``.
	"""
	def __init__(self, lr, gamma, option_policies, delib, discount, policy_over_options,
				terminations, nstates, noptions, nactions):
		self.lr = lr
		# gamma and option_policies are stored but not used within this class.
		self.gamma = gamma
		self.option_policies = option_policies
		# delib is added to every entry of A_Omega.
		self.delib = delib
		self.discount = discount
		self.Q_Omega_table = np.zeros((nstates, noptions))
		self.Q_U_table = np.zeros((nstates, noptions, nactions))
		# Must provide greedy_option(state) -> option index.
		self.policy_over_options = policy_over_options
		# Indexable by option; each entry must provide pmf(state) -> beta.
		self.terminations = terminations

	def cache(self, state, option, action):
		"""Remember the transition (and its current Q values) for the next update_Qs."""
		self.last_state, self.last_option, self.last_action = state, option, action
		self.last_Q_Omega = self.Q_Omega(state, option)
		self.last_Q_U = self.Q_U(state, option, action)

	def Q_Omega(self, state, option=None):
		"""Return the row Q_Omega[state, :] or the entry Q_Omega[state, option]."""
		if option is not None:
			return self.Q_Omega_table[state, option]
		return self.Q_Omega_table[state, :]

	def Q_U(self, state, option, action):
		"""Return the scalar Q_U[state, option, action]."""
		return self.Q_U_table[state, option, action]

	def A_Omega(self, state, option=None):
		"""Return Q_Omega[state, :] - Q_Omega[state, greedy] + delib (or one entry of it)."""
		q_row = self.Q_Omega(state)
		baseline = self.Q_Omega(state, self.policy_over_options.greedy_option(state))
		advantage = q_row - baseline + self.delib
		return advantage if option is None else advantage[option]

	def update_Qs(self, state, option, action, reward, done):
		"""One-step TD update of Q_Omega and Q_U for the cached transition.

		The target is ``reward`` when ``done``; otherwise
		reward + discount * ((1 - beta) * Q_Omega[state, last_option]
		                     + beta * Q_Omega[state, greedy]),
		with beta = terminations[last_option].pmf(state) and greedy chosen by
		policy_over_options. The cached values are refreshed only when not done.
		"""
		target = reward
		if not done:
			beta_omega = self.terminations[self.last_option].pmf(state)
			greedy = self.policy_over_options.greedy_option(state)
			stay = (1.0 - beta_omega) * self.Q_Omega(state, self.last_option)
			switch = beta_omega * self.Q_Omega(state, greedy)
			target += self.discount * (stay + switch)

		# Both tables share the same target; TD errors use the cached values.
		tderror_Q_Omega = target - self.last_Q_Omega
		self.Q_Omega_table[self.last_state, self.last_option] += self.lr * tderror_Q_Omega
		tderror_Q_U = target - self.last_Q_U
		self.Q_U_table[self.last_state, self.last_option, self.last_action] += self.lr * tderror_Q_U

		# Roll the cache forward to the new transition.
		self.last_state, self.last_option, self.last_action = state, option, action
		if not done:
			self.last_Q_Omega = self.Q_Omega(state, option)
			self.last_Q_U = self.Q_U(state, option, action)

class High_layer():
	"""Q-learning over options, sharing its value table with ``policy_over_options``.

	``Q_Omega_table`` is the same array object as
	``policy_over_options.Q_Omega_table``, so updates made here directly change
	the values the policy over options acts on.
	"""

	def __init__(self, lr, option_policies, discount, policy_over_options, clip):
		self.lr = lr
		# option_policies and clip are stored but not used within this class.
		self.option_policies = option_policies
		self.discount = discount
		self.Q_Omega_table = policy_over_options.Q_Omega_table
		self.policy_over_options = policy_over_options
		self.clip = clip

	def cache(self, state, option, action):
		"""Remember the (state, option, action) that the next hi_q_update corrects."""
		self.last_state = state
		self.last_option = option
		self.last_action = action
		self.last_Q_Omega = self.Q_Omega(state, option)

	def Q_Omega(self, state, option=None):
		"""Return the row Q_Omega[state, :] or the entry Q_Omega[state, option]."""
		if option is None:
			return self.Q_Omega_table[state, :]
		else:
			return self.Q_Omega_table[state, option]

	def hi_q_update(self, state, option, action, reward, done):
		"""Q-learning update of Q_Omega[last_state, last_option].

		target = reward                                          if done
		target = reward + discount * max_o' Q_Omega[state, o']   otherwise

		``option`` and ``action`` are only cached for the next call.
		last_Q_Omega is refreshed only when not done.
		"""
		target = reward
		if not done:
			# Bootstrap from the best option in the next state: the max is taken
			# over the whole row of options, not the single entry for ``option``.
			target = reward + self.discount * np.max(self.Q_Omega_table[state, :])
		self.Q_Omega_table[self.last_state, self.last_option] += self.lr * (target - self.last_Q_Omega)

		# Cache
		self.last_state = state
		self.last_option = option
		self.last_action = action
		if not done:
			self.last_Q_Omega = self.Q_Omega(state, option)

def copy_term(option_terminations, nstates):
	"""Return sigmoid of every termination's weights as an (n_options, nstates) array.

	Row i is expit(option_terminations[i].weights). The result is a new
	array; the terminations' own weights are not aliased or modified.
	"""
	stacked = np.zeros((len(option_terminations), nstates))
	for row, term in enumerate(option_terminations):
		# Slice assignment copies the values into the fresh buffer.
		stacked[row, :] = term.weights
	return expit(stacked)

def Q_cal_(option_policies, op_num, nstates, nactions):
	"""Return an (nstates, op_num) table of policy-weighted preferences.

	Column o holds, for every state s,
	sum_a pi_o(a|s) * weights_o[s, a], where pi_o = option_policies[o].pmf_()
	and weights_o = option_policies[o].weights.

	``nactions`` is not used; it is kept for interface compatibility.
	"""
	Q_so = np.zeros((nstates, op_num))
	for o, p in enumerate(option_policies):
		# Reduce over actions only (axis=1) so each state keeps its own value;
		# a full-table sum would give every state in the column the same number.
		Q_so[:, o] = np.sum(p.weights * p.pmf_(), axis=1)
	return Q_so