from gurobipy import *
import numpy as np 
import sys
import time



# all action types constrained by an upper and lower bound
def hawkins_fairness(T, R, C, B, L, start_state, lambda_lbs=None, lambda_ubs=None, gamma=0.95):
    """Solve the Hawkins Lagrangian-relaxation LP with a lower and an upper budget per action.

    Translation from paper notation to code notation:
    N == NPROCS, S == NSTATES, M == NACTIONS.

    Arrays that handle the budget constraints have shape (2, NACTIONS):
    row 0 relates to the lower bound ``L`` and row 1 to the upper bound ``B``.

    Args:
        T: transition array with shape (NPROCS, NSTATES, NACTIONS, NSTATES).
        R: reward array with shape (NPROCS, NSTATES).
        C: cost array with shape (NPROCS, NACTIONS).
        B: upper budget bound; scales the row-1 multipliers in the objective.
        L: lower budget bound; scales the row-0 multipliers in the objective.
        start_state: length-NPROCS sequence giving the current state of each
            process/arm (each entry is cast to ``int``).
        lambda_lbs: optional (2, NACTIONS) array of lower bounds for the
            multipliers, passed to the LP to shrink the search space.
            When omitted, every lower bound is 0.
        lambda_ubs: optional (2, NACTIONS) array of upper bounds for the
            multipliers. When omitted, every upper bound is 100.
        gamma: discount factor.

    Returns:
        A tuple ``(L_vals, index_solved_values, objective_value)`` where
        ``L_vals`` is a float array of shape (NPROCS, NSTATES) with the solved
        value variables, ``index_solved_values`` is a float array of shape
        (2, NACTIONS) with the solved multipliers, and ``objective_value`` is
        the value of the model objective.

    NOTE(review): the solver status is not checked before the solution is read.
    """
    NPROCS = T.shape[0]
    NSTATES = T.shape[1]
    NACTIONS = T.shape[2]

    # Create a new model
    m = Model("LP for Hawkins Lagrangian relaxation")
    m.setParam('OutputFlag', False)

    # Bounds on the multipliers: 0 and 100 unless the caller narrows them.
    lbs = np.zeros((2, NACTIONS))
    ubs = np.full((2, NACTIONS), 100.0)
    if lambda_lbs is not None:
        lbs[:] = np.asarray(lambda_lbs)[:2, :NACTIONS]
    if lambda_ubs is not None:
        ubs[:] = np.asarray(lambda_ubs)[:2, :NACTIONS]

    # Multiplier variables; created in the order (row 0, row 1) per action.
    index_variables = np.zeros((2, NACTIONS), dtype=object)
    for j in range(NACTIONS):
        index_variables[0, j] = m.addVar(vtype=GRB.CONTINUOUS, lb=lbs[0, j], ub=ubs[0, j], name='index_0_%s'%j)
        index_variables[1, j] = m.addVar(vtype=GRB.CONTINUOUS, lb=lbs[1, j], ub=ubs[1, j], name='index_1_%s'%j)

    # Value variables, one per (process, state), bounded below by 0.
    value_lb = 0
    L_vals = np.zeros((NPROCS, NSTATES), dtype=object)
    for p in range(NPROCS):
        for i in range(NSTATES):
            L_vals[p, i] = m.addVar(vtype=GRB.CONTINUOUS, lb=value_lb, name='Lvals_%s_%s'%(p, i))

    m.modelSense = GRB.MINIMIZE

    # In Hawkins, only the value function of the start state is minimized.
    # The tiny weight on every value variable keeps the remaining ones from
    # floating free in the objective.
    tiny = 1e-6
    start_value = sum(L_vals[p, int(start_state[p])] for p in range(NPROCS))
    # m.setObjectiveN(obj, index, priority) -- larger priority = optimize first
    m.setObjectiveN(
        start_value
        + index_variables[1, :].sum()*B*((1-gamma)**-1)
        - index_variables[0, :].sum()*L*((1-gamma)**-1)
        + tiny*L_vals.sum(),
        0,
        1,
    )

    # One Bellman-style inequality per (process, state, action).
    for p in range(NPROCS):
        for i in range(NSTATES):
            for j in range(NACTIONS):
                m.addConstr(
                    L_vals[p][i] >= R[p][i]
                    + C[p, j]*(index_variables[0, j] - index_variables[1, j])
                    + gamma*LinExpr(T[p, i, j], L_vals[p])
                )

    # Effectively remove constraints from the passive action.
    # This is important to keep the solution bounded. Since the passive action
    # has no cost, it can make the objective arbitrarily small without making
    # L_vals larger.
    m.addConstr(index_variables[0, 0] == 0)
    m.addConstr(index_variables[1, 0] == 0)

    m.optimize()

    # Read the solution straight from the variable handles.
    solved_values = np.zeros((NPROCS, NSTATES))
    for p in range(NPROCS):
        for i in range(NSTATES):
            solved_values[p, i] = L_vals[p, i].x

    index_solved_values = np.zeros((2, NACTIONS))
    for bound in range(2):
        for j in range(NACTIONS):
            index_solved_values[bound, j] = index_variables[bound, j].x

    obj = m.getObjective()

    return solved_values, index_solved_values, obj.getValue()



# Integer program that assigns one action to each arm from per-action values and costs,
# keeping the total cost of every non-passive action between LB and HB
def action_IP_fairness(values, C, HB, LB):
    """Pick exactly one action per arm by solving a binary integer program.

    A binary variable is created for every entry of ``values``; the program
    maximizes the total value of the selected entries.

    Args:
        values: 2-D array (nprocs x nactions, as used by the constraints below)
            holding the value of taking each action on each arm.
        C: 2-D cost array indexed as ``C[arm, action]``.
        HB: upper bound on the summed cost of the arms assigned to any one
            action with index >= 1.
        LB: lower bound on the same summed cost.

    Returns:
        A float array with the same shape as ``values`` containing the solver's
        value for each binary variable.

    NOTE(review): the solver status is not checked before the solution is read.
    """
    m = Model("Knapsack")
    m.setParam('OutputFlag', False)

    # One binary decision variable per (arm, action).
    x = np.zeros(values.shape, dtype=object)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            x[i, j] = m.addVar(vtype=GRB.BINARY, name='x_%i_%i'%(i, j))

    m.modelSense = GRB.MAXIMIZE

    # Maximize the total value of the chosen (arm, action) pairs.
    # m.setObjectiveN(obj, index, priority) -- larger priority = optimize first
    m.setObjectiveN((x*values).sum(), 0, 1)

    # Range from 1 onward, so no constraint is imposed on the passive action.
    for a in range(1, C.shape[1]):
        m.addConstr(x[:, a].dot(C[:, a]) <= HB)
        m.addConstr(x[:, a].dot(C[:, a]) >= LB)

    # Every arm receives exactly one action.
    for i in range(values.shape[0]):
        m.addConstr(x[i].sum() == 1)

    m.optimize()

    # Read the solution straight from the variable handles.
    x_out = np.zeros(x.shape)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            x_out[i, j] = x[i, j].x

    return x_out


def get_hawkins_actions(N,T, R, C, HB, LB, current_state, gamma):
    """Choose one action per arm using the Hawkins LP followed by an integer program.

    The LP (``hawkins_fairness``) yields value estimates and multipliers; these
    are turned into a per-arm, per-action index for each arm's current state,
    and ``action_IP_fairness`` then selects one action per arm.

    Args:
        N: number of arms.
        T: transition array indexed as ``T[arm, state, action]`` -> next-state vector.
        R: reward array indexed as ``R[arm, state]``.
        C: cost array indexed as ``C[arm, action]``.
        HB: upper budget bound passed to both optimization steps.
        LB: lower budget bound passed to both optimization steps.
        current_state: array of current states; flattened and cast to int.
        gamma: discount factor.

    Returns:
        Integer array of length N with the chosen action index for each arm.

    Raises:
        ValueError: if an arm is assigned more than one action, or if the
            summed cost of a non-passive action is above ``HB`` or below ``LB``
            (both with a 1e-6 tolerance).
    """
    EPS = 1e-6

    current_state = current_state.reshape(-1).astype(int)

    Q_vals, lambda_vals, _ = hawkins_fairness(T, R, C, HB, LB, current_state, gamma=gamma)

    # Only the row for each arm's current state is ever consumed, so the index
    # is computed for that state alone.
    indexes_per_state = np.zeros((N, T.shape[2]))
    for i in range(N):
        s = current_state[i]
        for a in range(C.shape[1]):
            indexes_per_state[i, a] = R[i,s] + C[i,a]*(lambda_vals[0,a] - lambda_vals[1,a]) + gamma*Q_vals[i].dot(T[i,s,a])

    decision_matrix = action_IP_fairness(indexes_per_state, C, HB, LB)

    actions = np.argmax(decision_matrix, axis=1)

    # The solver returns floats that are only integral up to its tolerance, so
    # count selected actions on rounded values rather than on the raw sums.
    if not (np.rint(decision_matrix).sum(axis=1) <= 1).all(): raise ValueError("More than one action per person")

    # Total cost spent on each action across all arms.
    payment = np.zeros(C.shape[1])
    for arm, action in enumerate(actions):
        payment[action] += C[arm][action]

    # Action 0 (passive) is not budget-constrained in the integer program, so
    # both budget checks skip it.
    if (payment[1:] - EPS > HB).any():
        raise ValueError("Over budget", payment, actions)

    if (payment[1:] + EPS < LB).any():
        raise ValueError("Under budget", payment, actions, LB, EPS, payment[1:] - EPS)

    return actions