from gurobipy import *
import numpy as np 
import sys
import time



def lp_to_compute_index(T, R, C, B, start_state, a_index, lambda_lim=None, gamma=0.95):
	"""Compute, for every arm, the index (multiplier) of action ``a_index``.

	A single LP is built that contains one independent sub-problem per arm.
	Arm p has value variables L[p, s] and a multiplier ``index_p >= 0``.
	The LP minimizes, summed over arms,

	    sum_s mu[p, s] * L[p, s] + index_p * B / (1 - gamma)

	where mu[p] is the indicator of ``start_state[p]``. It is subject to

	    L[p, s] >= R[p][s] - index_p * c[p][a] + gamma * T[p, s, a] . L[p]

	for every state s and action a. It is also subject to an indifference
	condition at the start state between action ``a_index`` and action 0.
	No constraint links two arms, so this is equivalent to solving one LP
	per arm.

	Args:
		T: array indexed ``T[p, s, a]``; each entry must be a length-NSTATES
			coefficient vector (presumably next-state probabilities).
		R: rewards indexed ``R[p][s]``.
		C: action costs. Either one cost per action shared by all arms
			(``C[a]``), or one row of costs per arm (``C[p][a]``).
		B: scales the multiplier term of the objective (presumably the
			budget; confirm against callers).
		start_state: current state of each arm (ints, or floats holding
			integral values).
		a_index: action whose index is computed against action 0.
		lambda_lim: optional upper bound for every multiplier; unbounded
			when None.
		gamma: discount factor.

	Returns:
		(L_vals, index_solved_values): float arrays of shape
		(NPROCS, NSTATES) and (NPROCS,). They hold the solved value
		functions and the per-arm indices.
	"""
	NPROCS = T.shape[0]  # number of arms
	NSTATES = T.shape[1]  # number of states
	NACTIONS = T.shape[2]  # number of actions

	m = Model("LP for Computing multi-action indices")
	m.setParam('OutputFlag', False)

	# Cast once: start states may arrive as floats. They are used as array
	# indices both for mu and in the indifference constraints below.
	s0 = [int(start_state[p]) for p in range(NPROCS)]

	# Indicator of the current state of each arm (objective weights).
	mu = np.zeros((NPROCS, NSTATES))
	for p in range(NPROCS):
		mu[p, s0[p]] = 1

	# Broadcast a shared per-action cost vector to every arm. Testing the
	# dimensionality of C[0] (instead of isinstance(..., list)) also
	# recognizes per-arm costs passed as a 2-D numpy array or as tuples.
	if np.ndim(C[0]) == 0:
		c = [C for _ in range(NPROCS)]
	else:
		c = C

	# Multipliers are non-negative and optionally capped by lambda_lim.
	ub = GRB.INFINITY if lambda_lim is None else lambda_lim

	# One multiplier per arm: indices are computed in a decoupled manner.
	index_variables = [
		m.addVar(vtype=GRB.CONTINUOUS, lb=0, ub=ub, name='index_%s' % p)
		for p in range(NPROCS)
	]

	L = np.zeros((NPROCS, NSTATES), dtype=object)
	for p in range(NPROCS):
		for s in range(NSTATES):
			L[p, s] = m.addVar(vtype=GRB.CONTINUOUS, name='L_%s_%s' % (p, s))

	m.modelSense = GRB.MINIMIZE

	# Every arm contributes its own start-state value and its own multiplier
	# term. Previously the multiplier term indexed index_variables with a
	# loop variable leaked from an earlier loop (NSTATES - 1). That raised
	# IndexError when NSTATES > NPROCS and otherwise penalized only one arm.
	m.setObjectiveN(
		sum(L[p].dot(mu[p]) + index_variables[p] * B * ((1 - gamma) ** -1)
			for p in range(NPROCS)),
		0, 1)

	# Bellman inequalities: L[p] dominates the value of every action.
	for p in range(NPROCS):
		for s in range(NSTATES):
			for a in range(NACTIONS):
				m.addConstr(L[p][s] >= R[p][s] - index_variables[p] * c[p][a] + gamma * LinExpr(T[p, s, a], L[p]))

	# Indifference at the start state between action a_index and action 0.
	# The multiplier that satisfies it is the arm's index.
	for p in range(NPROCS):
		s = s0[p]
		m.addConstr(
			R[p][s] - index_variables[p] * c[p][a_index] + gamma * LinExpr(T[p, s, a_index], L[p])
			== R[p][s] + gamma * LinExpr(T[p, s, 0], L[p]))

	m.optimize()

	# Read solutions straight from the variable handles instead of parsing
	# variable names.
	L_vals = np.zeros((NPROCS, NSTATES))
	for p in range(NPROCS):
		for s in range(NSTATES):
			L_vals[p, s] = L[p, s].x
	index_solved_values = np.array([v.x for v in index_variables], dtype=float)

	return L_vals, index_solved_values





# Knapsack over actions: pick one action per arm to maximize total value subject to a budget on total action cost.
def action_knapsack(values, C, B, exact_knapsack=True):
	"""Choose exactly one action per arm to maximize total value under a budget.

	Solves the binary program

	    max   sum_{i,j} values[i, j] * x[i, j]
	    s.t.  sum_{i,j} cost[i, j] * x[i, j] == B   (<= B if not exact_knapsack)
	          sum_j x[i, j] == 1                     for every arm i
	          x[i, j] in {0, 1}

	Args:
		values: 2-D array; ``values[i, j]`` is the value of giving arm i
			action j.
		C: action costs. Either one cost per action shared by all arms
			(length ``values.shape[1]``), or an array of per-arm costs with
			the same shape as ``values``.
		B: budget on the total cost of the chosen actions.
		exact_knapsack: if True the total cost must equal B exactly;
			otherwise it may not exceed B.

	Returns:
		Float array shaped like ``values``, with a 1 at each arm's chosen
		action and 0 elsewhere.

	NOTE(review): with exact_knapsack=True the program is infeasible when no
	combination of actions costs exactly B. Reading ``.x`` then fails inside
	gurobipy; consider checking the model status before reading the
	solution.
	"""
	m = Model("Knapsack")
	m.setParam('OutputFlag', False)

	values = np.asarray(values)
	costs = np.asarray(C)
	n_arms, n_actions = values.shape

	# x[i, j] == 1 iff arm i takes action j.
	x = np.zeros(values.shape, dtype=object)
	for i in range(n_arms):
		for j in range(n_actions):
			x[i, j] = m.addVar(vtype=GRB.BINARY, name='x_%i_%i' % (i, j))

	m.modelSense = GRB.MAXIMIZE
	# Maximize the total value of the chosen actions.
	m.setObjectiveN((x * values).sum(), 0, 1)

	# The elementwise product broadcasts a shared 1-D cost vector across the
	# arms (same total as the former x.dot(C).sum()). It also handles a
	# per-arm cost array of shape values.shape, for which x.dot(C) was a
	# matrix product that mis-summed the costs.
	total_cost = (x * costs).sum()
	if exact_knapsack:
		m.addConstr(total_cost == B)
	else:
		m.addConstr(total_cost <= B)

	# Every arm takes exactly one action.
	for i in range(n_arms):
		m.addConstr(x[i].sum() == 1)

	m.optimize()

	# Binary variables are only integral up to the solver's integrality
	# tolerance (e.g. 0.9999999), so round to give exact 0/1 indicators.
	x_out = np.zeros(x.shape)
	for i in range(n_arms):
		for j in range(n_actions):
			x_out[i, j] = round(x[i, j].x)

	return x_out