import numpy as np 
import mdptoolbox
import time


def get_index_of_budget_spent(B, C, arm_list):
	"""Return the position in ``arm_list`` of the last arm that fits in budget ``B``.

	Arms are paid for in the order given by ``arm_list``; arm ``arm_list[k]``
	costs ``C[arm_list[k]]``.  Paying stops as soon as the running payment
	reaches ``B``, or when the next arm would push the payment above ``B``
	(that arm is not counted).

	Args:
		B: total budget.
		C: per-arm costs, indexable by the entries of ``arm_list``.
		arm_list: arm ids in the order they should be paid for.

	Returns:
		Index into ``arm_list`` of the last arm paid for, or -1 when no arm
		could be paid for (zero budget, empty list, or first arm too costly).
	"""
	payment = 0
	ind = -1
	for position, arm in enumerate(arm_list):
		# budget already used up exactly: nothing more can be bought
		if payment >= B:
			break
		payment += C[arm]
		# this arm overshoots the budget, so it does not count as played
		if payment > B:
			break
		ind = position

	return ind




def adjust_indexes(T, R, C, current_state, index_list, index_lb=-1, index_ub=1, gamma=0.95, tolerance=1e-1):
	"""Re-estimate every arm's per-action index with the other actions' costs rescaled.

	For each arm and each action ``a_curr`` in ``1..A-1``, the cost of every
	other action ``a_other`` in ``1..A-1`` is multiplied by
	``index_list[arm, current_state[arm], a_other-1]``.  The index of
	``a_curr`` is then recomputed with ``binary_search_one_arm_full_T`` using
	those rescaled costs.  Negative results are clamped to 0.

	Args:
		T: transition array; axis 0 indexes arms and axis 2 indexes actions.
		R: per-arm rewards; ``R[arm]`` is handed to the binary search.
		C: per-arm cost vectors; ``C[arm]`` is indexed by action.
		current_state: current state of each arm.
		index_list: previously computed indexes, indexed as
			``[arm, state, action-1]``.
		index_lb, index_ub, gamma, tolerance: forwarded unchanged to
			``binary_search_one_arm_full_T``.

	Returns:
		Array of shape ``(N, A-1)`` holding the adjusted, non-negative
		indexes, where ``N = T.shape[0]`` and ``A = T.shape[2]``.
	"""
	N = T.shape[0]
	A = T.shape[2]

	adjusted_indexes = np.zeros((N, A-1))

	for arm in range(N):
		arm_state = current_state[arm]

		for a_curr in range(1, A):

			# price every *other* non-zero action at its own index for this state;
			# the cost of a_curr itself (and of action 0) is left untouched
			new_cost = np.copy(C[arm]).astype(float)
			for a_other in range(1, A):
				if a_other != a_curr:
					new_cost[a_other] *= index_list[arm, arm_state, a_other-1]

			# now rerun binary search using these adjusted costs
			adjusted_indexes[arm, a_curr-1] = binary_search_one_arm_full_T(
				T[arm], R[arm], new_cost, arm_state, a_curr,
				index_lb=index_lb, index_ub=index_ub, gamma=gamma, tolerance=tolerance)

	# negative indexes are floored at zero
	adjusted_indexes[adjusted_indexes < 0] = 0.0

	return adjusted_indexes



def binary_search_all_arms(T, R, C, current_state, a_index, index_lb=-1, index_ub=1, gamma=0.95, tolerance=1e-1):
	"""Run ``binary_search_one_arm`` for action ``a_index`` on every arm.

	Arm ``k`` is searched with ``T[k]``, ``R[k]``, ``C[k]`` and
	``current_state[k]``; the remaining keyword arguments are forwarded as-is.

	Returns:
		1-D float array with one index estimate per arm (length ``T.shape[0]``).
	"""
	search_options = dict(index_lb=index_lb, index_ub=index_ub, gamma=gamma, tolerance=tolerance)

	num_arms = T.shape[0]
	estimates = np.zeros(num_arms)

	for k in range(num_arms):
		estimates[k] = binary_search_one_arm(T[k], R[k], C[k], current_state[k], a_index, **search_options)

	return estimates


# the next two functions handle annoyances around 0 w.r.t. 
# the whole "increase/decrease by factor of 2 depending on the bound" idea
def reduce_lb(lb):
	"""Push a lower bound further down.

	Bounds at or below -1 are doubled, bounds at or above 1 are halved, and
	bounds strictly between -1 and 1 are shifted down by 1 (scaling would
	stall near zero).  Returns None if ``lb`` matches none of the ranges.
	"""
	if lb >= 1:
		return lb / 2
	if lb <= -1:
		return lb * 2
	if -1 < lb < 1:
		return lb - 1

def increase_ub(ub):
	"""Push an upper bound further up.

	Bounds at or above 1 are doubled, bounds at or below -1 are halved, and
	bounds strictly between -1 and 1 are shifted up by 1 (scaling would
	stall near zero).  Returns None if ``ub`` matches none of the ranges.
	"""
	if -1 < ub < 1:
		return ub + 1
	if ub >= 1:
		return ub * 2
	if ub <= -1:
		return ub / 2




# this version has the step 0 which searches for upper and lower bounds first
def binary_search_one_arm_full_T(T, R, C, current_state, a_index, index_lb=-1, index_ub=1, gamma=0.95, tolerance=1e-1):
	"""Binary-search the index of action ``a_index`` for one arm, keeping every action.

	The index is the multiplier on ``C[a_index]`` at which the value-iteration
	policy at ``current_state`` switches between choosing ``a_index`` and not
	choosing it.  Unlike ``binary_search_one_arm``, the transition matrix is
	not trimmed to two actions: every other action stays available, with its
	cost ``C[x]`` subtracted from its reward.

	The search has two phases: first the bracket ``[lower, upper]`` is moved
	and widened until the policy flips between its two ends, then the bracket
	is bisected until it is no wider than ``tolerance``.

	Args:
		T: one arm's transitions, ordered S,A,S (swapped to A,S,S internally).
		R: per-state reward, added for every action and every next state.
		C: per-action cost vector for this arm.
		current_state: state whose policy action drives the search.
		a_index: action whose index is being estimated.
		index_lb, index_ub: initial bracket guess; adjusted automatically.
		gamma: discount factor handed to value iteration.
		tolerance: maximum width of the final bracket.

	Returns:
		Midpoint of the final ``[lower, upper]`` bracket.

	Raises:
		ValueError: if a bound exceeds 1e6 in magnitude while bracketing.
	"""
	# Go from S,A,S to A,S,S
	# (the t-matrix is deliberately NOT trimmed to two actions in this version)
	T_i = np.swapaxes(T, 0, 1)

	# rewards need to be A,S,S too, but R is only S (current state):
	# broadcast R over the action and next-state axes
	R_base = np.zeros(T_i.shape) + np.reshape(R, (1, -1, 1))
	for x in range(R_base.shape[0]):
		# subtract off the lambda-adjusted costs of other actions
		if x != a_index:
			R_base[x] -= C[x]

	# this is the reward matrix we will edit
	R_i = np.copy(R_base)

	def takes_action(index_estimate):
		# Charge index_estimate * cost on a_index, solve, and report whether
		# the resulting policy picks a_index in current_state.
		R_i[a_index] = R_base[a_index] - index_estimate*C[a_index]
		mdp = mdptoolbox.mdp.ValueIteration(T_i, R_i, discount=gamma, stop_criterion='fast')
		mdp.run()
		policy = np.array(mdp.policy)
		return bool(policy[current_state] == a_index)

	#####
	# run once to see if we should go up or down
	#####
	upper = index_ub
	lower = index_lb

	index_estimate = (upper + lower)/2
	acting = takes_action(index_estimate)
	if acting:
		# if acting, increase the penalty for acting
		lower = index_estimate
	else:
		# if not acting, reduce the penalty for acting
		upper = index_estimate

	# now loop until we are told to turn around (finds appropriate upper and lower bounds).
	# The flip test compares "chose a_index or not", not raw action ids: with more
	# than two actions the policy can move between two *other* actions, which is
	# not a turn-around and must not end the bracketing phase.
	flipped = False
	while not flipped:
		# if acting we need to go up, otherwise we need to go down
		index_estimate = upper if acting else lower
		now_acting = takes_action(index_estimate)

		if now_acting == acting:
			# no flip yet: slide the bracket further in the same direction
			if acting:
				lower = index_estimate
				upper = increase_ub(upper)
			else:
				upper = lower
				lower = reduce_lb(lower)
		else:
			# flip found: the probed end becomes the matching bracket end
			flipped = True
			if now_acting:
				lower = index_estimate
			else:
				upper = index_estimate

		if abs(lower) > 1e6 or abs(upper) > 1e6:
			raise ValueError("bounds got too big")

	# NOW we should have upper and lower bounds that contain our index value
	while (upper - lower) > tolerance:
		index_estimate = (upper + lower)/2
		if takes_action(index_estimate):
			lower = index_estimate
		else:
			upper = index_estimate

	index_estimate = (upper + lower)/2

	return index_estimate










# this version has the step 0 which searches for upper and lower bounds first
def binary_search_one_arm(T, R, C, current_state, a_index, index_lb=-1, index_ub=1, gamma=0.95, tolerance=1e-1):
	"""Binary-search the index of action ``a_index`` for one arm in a two-action problem.

	The arm is reduced to actions ``0`` and ``a_index``.  The index is the
	multiplier on the cost of ``a_index`` at which the value-iteration policy
	at ``current_state`` switches between the two.

	The search has two phases: first the bracket ``[lower, upper]`` is moved
	and widened until the policy flips between its two ends, then the bracket
	is bisected until it is no wider than ``tolerance``.

	Args:
		T: one arm's transitions, ordered S,A,S (swapped to A,S,S internally).
		R: per-state reward, added for both actions and every next state.
		C: per-action cost array for this arm (indexed with a list, so it
			must support numpy fancy indexing).
		current_state: state whose policy action drives the search.
		a_index: action whose index is being estimated.
		index_lb, index_ub: initial bracket guess; adjusted automatically.
		gamma: discount factor handed to value iteration.
		tolerance: maximum width of the final bracket.

	Returns:
		Midpoint of the final ``[lower, upper]`` bracket.

	Raises:
		ValueError: if a bound exceeds 1e6 in magnitude while bracketing
			(for example when the policy never flips because the action
			cost is zero); without this check the bracketing loop would
			not terminate.
	"""
	# Go from S,A,S to A,S,S
	T_i = np.swapaxes(T, 0, 1)
	T_i = T_i[[0, a_index]] # trim down to a 2-action matrix
	C_i = C[[0, a_index]] # trim down to 2-action cost vector

	# rewards need to be A,S,S too, but R is only S (current state):
	# broadcast R over the action and next-state axes
	R_base = np.zeros(T_i.shape) + np.reshape(R, (1, -1, 1))

	# this is the reward matrix we will edit
	R_i = np.copy(R_base)

	def takes_action(index_estimate):
		# Charge index_estimate * cost on the (trimmed) action 1, solve, and
		# report whether the resulting policy picks it in current_state.
		R_i[1] = R_base[1] - index_estimate*C_i[1]
		mdp = mdptoolbox.mdp.ValueIteration(T_i, R_i, discount=gamma, stop_criterion='fast')
		mdp.run()
		policy = np.array(mdp.policy)
		return bool(policy[current_state] == 1)

	#####
	# run once to see if we should go up or down
	#####
	upper = index_ub
	lower = index_lb

	index_estimate = (upper + lower)/2
	acting = takes_action(index_estimate)
	if acting:
		# if acting, increase the penalty for acting
		lower = index_estimate
	else:
		# if not acting, reduce the penalty for acting
		upper = index_estimate

	# now loop until we are told to turn around (finds appropriate upper and lower bounds)
	flipped = False
	while not flipped:
		# if acting we need to go up, otherwise we need to go down
		index_estimate = upper if acting else lower
		now_acting = takes_action(index_estimate)

		if now_acting == acting:
			# no flip yet: slide the bracket further in the same direction
			if acting:
				lower = index_estimate
				upper = increase_ub(upper)
			else:
				upper = lower
				lower = reduce_lb(lower)
		else:
			# flip found: the probed end becomes the matching bracket end
			flipped = True
			if now_acting:
				lower = index_estimate
			else:
				upper = index_estimate

		# same safety net as binary_search_one_arm_full_T: stop instead of
		# widening forever when the policy never flips
		if abs(lower) > 1e6 or abs(upper) > 1e6:
			raise ValueError("bounds got too big")

	# NOW we should have upper and lower bounds that contain our index value
	while (upper - lower) > tolerance:
		index_estimate = (upper + lower)/2
		if takes_action(index_estimate):
			lower = index_estimate
		else:
			upper = index_estimate

	index_estimate = (upper + lower)/2

	return index_estimate


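# compute the per-arm indexes for every state and every action 1..M via binary search
# T has to be in N,S,A,S order with A >= M+1 actions (action 0 plus actions 1..M)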
def all_per_arm_indexes(N,S,M,T,R,C,index_lb=-1,index_ub=1,gamma=0.95,tolerance=1e-4):
	"""Tabulate the index of every arm, for every state and every action 1..M.

	For each state, all ``N`` arms are placed in that state and
	``binary_search_all_arms`` is called once per action.

	Returns:
		Array of shape ``(N, S, M)``; entry ``[arm, s, a-1]`` is the index of
		action ``a`` for ``arm`` in state ``s``.
	"""
	search_options = dict(index_lb=index_lb, index_ub=index_ub, gamma=gamma, tolerance=tolerance)
	table = np.zeros((N, S, M))

	for state in range(S):
		# every arm is evaluated as if it currently sat in `state`
		everyone_in_state = np.array([state] * N)
		for slot in range(M):
			action = slot + 1
			table[:, state, slot] = binary_search_all_arms(T, R, C, everyone_in_state, action, **search_options)

	return table


# compute the adjusted per-arm indexes for every state (see adjust_indexes)
# T has to be in N,S,A,S order with A = M+1 actions (action 0 plus actions 1..M)
def all_per_arm_adjusted_indexes(N, S, M, T, R, C, per_arm_indexes, index_lb=-1, index_ub=1, gamma=0.95, tolerance=1e-4):
	"""Tabulate the adjusted indexes of every arm for every state.

	For each state, all ``N`` arms are placed in that state and
	``adjust_indexes`` is called once with ``per_arm_indexes`` as the
	reference index table.

	Returns:
		Array of shape ``(N, S, M)``; slice ``[:, s]`` is the result of
		``adjust_indexes`` with every arm in state ``s``.
	"""
	search_options = dict(index_lb=index_lb, index_ub=index_ub, gamma=gamma, tolerance=tolerance)
	adjusted_table = np.zeros((N, S, M))

	for state in range(S):
		# every arm is evaluated as if it currently sat in `state`
		everyone_in_state = np.array([state] * N)
		adjusted_table[:, state] = adjust_indexes(T, R, C, everyone_in_state, per_arm_indexes, **search_options)

	return adjusted_table



