''' Fair division of [0,1] among n buyers with linear valuation functions. '''

import numpy as np
# import cvxpy as cp
# np.random.seed(888)

# def compute_eq_utilities(B, c, d) # c[i] * theta + d[i] >= 0 for 0 <= theta <= 1

def compute_b_from_x(x, v, B):
    ''' Bids induced by allocation x under proportional response.

    Buyer i splits her budget B[i] across goods in proportion to the utility
    v[i, j] * x[i, j] she receives from each good.

    Args:
        x: (n, m) allocation.
        v: (n, m) valuations.
        B: (n,) budgets.
    Returns:
        (n, m) bids; row i sums to B[i] (undefined if buyer i gets zero utility).
    '''
    utility_parts = v * x
    per_buyer_total = utility_parts.sum(axis=1, keepdims=True)
    return utility_parts / per_buyer_total * np.reshape(B, (-1, 1))

def eg_primal_obj_val(x, v, B):
    ''' Eisenberg-Gale primal objective (to be maximized).

    Returns sum_i B[i] * log(u_i), where u_i = sum_j v[i, j] * x[i, j].
    '''
    utilities = np.sum(v * x, axis=1)
    weighted_log_utilities = B * np.log(utilities)
    return np.sum(weighted_log_utilities)

def eg_dual_obj_val(beta, v, B):
    ''' Eisenberg-Gale dual objective value for dual variables beta.

    With prices p_j = max_i beta[i] * v[i, j] and every good's supply taken
    as 1/m, returns
        sum_j p_j / m - sum_i B[i] log(beta[i]) + sum_i B[i] log(B[i]) - sum_i B[i].
    '''
    num_goods = v.shape[1]
    prices = np.max(beta * v.T, axis=1)  # one price per good
    price_term = np.sum(prices) / num_goods
    beta_term = np.sum(B * np.log(beta))
    entropy_term = np.sum(B * np.log(B))
    return price_term - beta_term + entropy_term - np.sum(B)

def compute_me_fin_dim(v, B = None, accu = 1e-3, max_iter = 1000):
    ''' Approximate a market equilibrium with proportional response (PR) dynamics.

    Assumes sum(v[i]) == m for all i and sum(B) == 1. The dynamics run with
    unit supply per good; the result is rescaled to supply 1/m on return.

    Args:
        v: (n, m) valuations.
        B: (n,) budgets; defaults to uniform 1/n. Lists are accepted.
        accu: currently unused (no early stopping is done; exactly max_iter
            PR iterations run). Kept for API compatibility.
        max_iter: number of PR iterations; must be >= 1.
    Returns:
        (m * p, x / m): p is the length-m price vector and x the (n, m)
        allocation from the last PR iteration.
    Raises:
        ValueError: if max_iter < 1 (no prices would ever be computed).
    '''
    if max_iter < 1:
        raise ValueError('max_iter must be at least 1, got {}'.format(max_iter))
    n, m = v.shape
    if B is None:
        B = np.ones(n)/n
    else:
        B = np.asarray(B, dtype=float)
    # initial allocation: buyer i receives B[i]/sum(B) of every good
    x = np.outer(B / np.sum(B), np.ones(m))
    b = v * x  # initial bids
    for _ in range(max_iter):
        p = np.sum(b, axis=0)  # price of good j = total bids on it
        x = b / p  # each buyer gets a share of the (unit) supply proportional to her bid
        b = compute_b_from_x(x, v, B)  # re-bid proportionally to realized utility
    return m*p, x/m

##################################### call Mosek directly ###################################
import scipy as sp
from scipy import sparse
import scipy.linalg as spla
from time import time
from mosek import *
from mosek.fusion import *

def compute_qlme_mosek(v, s = None, B = None):
    ''' Solve the (primal) QLEG convex program with MOSEK's Fusion API.

    Program (maximization):
        max  sum_i B[i] * q[i] - sum_i delta[i]
        s.t. u[i] <= <v[i], x[i]> + delta[i]        (constraint 'utility i')
             sum_i x[i, j] <= s[j]                  (constraint 'supply')
             (u[i], 1, q[i]) in the primal exp cone, i.e. q[i] <= log(u[i])
             x, u, delta >= 0
    (QL presumably means quasi-linear utilities -- TODO confirm.)

    Args:
        v: (n, m) valuations.
        s: (m,) supplies; defaults to uniform 1/m.
        B: (n,) budgets; defaults to uniform 1/n.
    Returns:
        x: (n, m) allocation, clipped at 0 (presumably to drop tiny negative
           solver round-off).
        u: (n,) optimal u.
        delta: (n,) optimal delta.
        beta: (n,) dual values of the 'utility i' constraints.
        p: (m,) prices p[j] = max_i beta[i] * v[i, j].
    '''
    n, m = v.shape
    if s is None: s = np.ones(m) / m
    if B is None: B = np.ones(n) / n
    # The Model owns native solver memory; the with-block disposes it on exit
    # (all solution values are read before leaving the block).
    with Model('QLEG') as model:
        x = model.variable('x', [n, m], Domain.greaterThan(0))
        u = model.variable('u', n, Domain.greaterThan(0))
        q = model.variable('q', n, Domain.unbounded())
        delta = model.variable('delta', n, Domain.greaterThan(0))
        # supply constraints
        model.constraint('supply', Expr.sum(x, 0), Domain.lessThan(s)) # old: Domain.lessThan(1/m)
        # buyer-wise QL utility constraints: u[i] - (<v[i], x[i]> + delta[i]) <= 0
        for i in range(n):
            rhs = Expr.add(Expr.dot(v[i], x.slice([i,0], [i+1,m])), delta.index(i))
            model.constraint('utility {}'.format(i), Expr.sub(u.index(i), rhs), Domain.lessThan(0))

        # objective, with q[i] <= log(u[i]) enforced via the exponential cone
        obj_expr = Expr.sub(Expr.dot(B, q), Expr.sum(delta, 0))
        for i in range(n):
            model.constraint('exp-cone {}'.format(i), Expr.vstack(u.index(i), 1, q.index(i)), Domain.inPExpCone())

        model.objective('maximize-obj', ObjectiveSense.Maximize, obj_expr)
        model.solve()

        # get optimal solutions (primal & dual)
        x_opt, u_opt, delta_opt = x.level().reshape((n,m)), u.level(), delta.level()
        beta_opt = np.array([model.getConstraint('utility {}'.format(i)).dual()[0] for i in range(n)])
        # alternative: model.getConstraint('supply').dual()
        p_opt = np.max(beta_opt * v.T, 1)

        return np.maximum(x_opt, 0), u_opt, delta_opt, beta_opt, p_opt



def _me_mosek_constraint_triplets(v):
    ''' (row, col, value) triplets of the equality-constraint matrix used by compute_me_mosek.

    Variable order (n*m + 3*n entries): x[i, j] at i*m + j, then u[i], s[i], t[i].
    Row order (2*n + m rows):
        rows 0 .. n-1       : sum_j v[i, j] * x[i, j] - u[i]   (fixed to 0)
        rows n .. n+m-1     : sum_i x[i, j]                    (fixed to 1)
        rows n+m .. 2n+m-1  : s[i]                             (fixed to 1)
    The triplets are built directly, so rows, cols and values stay aligned
    entry by entry and nothing depends on scipy's block_diag input handling.
    Zero coefficients are dropped.
    '''
    n, m = v.shape
    nm = n * m
    buyer_of_x = np.repeat(np.arange(n), m)  # i for each x[i, j]
    good_of_x = np.tile(np.arange(m), n)     # j for each x[i, j]
    x_cols = np.arange(nm)
    buyers = np.arange(n)
    rows = np.concatenate([
        buyer_of_x,          # utility rows: coefficient v[i, j] on x[i, j]
        buyers,              # utility rows: coefficient -1 on u[i]
        n + good_of_x,       # supply rows: coefficient 1 on x[i, j]
        n + m + buyers,      # s rows: coefficient 1 on s[i]
    ])
    cols = np.concatenate([
        x_cols,
        nm + buyers,
        x_cols,
        nm + n + buyers,
    ])
    vals = np.concatenate([
        np.ravel(v),
        -np.ones(n),
        np.ones(nm),
        np.ones(n),
    ])
    keep = vals != 0
    return rows[keep].astype(np.int32), cols[keep].astype(np.int32), vals[keep]


def compute_me_mosek(v, B = None):
    ''' Compute a market equilibrium by solving the EG program with MOSEK's Optimizer API.

    Assumes s[j] == 1/m, sum(v[i]) == m and sum(B) == 1. Internally each v[i]
    is rescaled to sum to 1 and every good gets unit supply; the outputs are
    rescaled back on return.

    Conic model: minimize -sum_i B[i] * t[i] subject to
        u[i] = <v[i], x[i]>,  sum_i x[i, j] = 1,  s[i] = 1,  x >= 0,
        (u[i], s[i], t[i]) in the primal exponential cone (t[i] <= log(u[i])).

    Args:
        v: (n, m) valuations.
        B: (n,) budgets; defaults to uniform 1/n.
    Returns:
        (m * p, x / m): x is the (n, m) interior-point allocation and p[j] the
        total bids on good j, with bids from compute_b_from_x evaluated at x.
    '''
    begin_ref_time = time()
    n, m = v.shape
    if B is None:
        B = np.ones(n)/n
    # first, scale to s[j] = 1, sum(v[i]) == 1 and solve that problem
    v = (v.T / np.sum(v, 1)).T
    rows_A, cols_A, vals_A = _me_mosek_constraint_triplets(v)
    num_vars = n*m + 3*n  # in the order x[i,j], u[i], s[i], t[i]
    num_cons = 2*n + m
    x_all = [0.0] * num_vars
    # Env and Task hold native MOSEK resources; the with-statement releases them.
    with Env() as env, env.Task(0, 1) as task:
        task.putintparam(iparam.log, 10)
        task.putintparam(iparam.intpnt_multi_thread, onoffkey.off)
        # variables: x >= 0; u, s, t free (the upper bounds are ignored for these keys)
        task.appendvars(num_vars)
        task.putvarboundlist(list(range(num_vars)), [boundkey.lo] * (n*m) + [boundkey.fr] * (3*n), [0.0]*num_vars, [1]*num_vars)
        # objective: minimize -sum_i B[i] * t[i]
        c_vec = np.zeros((num_vars,))
        c_vec[-n:] = -B
        task.putclist(np.arange(num_vars), c_vec)
        task.putobjsense(objsense.minimize)
        # equality constraints (see _me_mosek_constraint_triplets for the row layout)
        task.appendcons(num_cons)
        task.putaijlist(rows_A, cols_A, vals_A)
        rhs = list(np.concatenate([np.zeros((n,)), np.ones(n + m)]))
        task.putconboundslice(0, num_cons, [boundkey.fx]*num_cons, rhs, rhs)
        # (u[i], s[i], t[i]) in the primal exponential cone
        for ii in range(n*m, n*m+n):
            task.appendcone(conetype.pexp, 0.0, [ii, ii+n, ii+2*n])
        print('reformulation time = {:.4f}'.format(time() - begin_ref_time))
        begin_msk = time()
        task.optimize()
        task.getxx(soltype.itr, x_all)
        print("Mosek (direct) time = {:.4f}".format(time() - begin_msk))
    x_msk = np.reshape(x_all[:n*m], (n, m))
    b = compute_b_from_x(x_msk, v, B)
    p = np.sum(b, axis = 0)
    return m * p, x_msk / m


if __name__ == '__main__':
    # random instance: each v[i] sums to 1, budgets sum to 1 (unseeded on purpose)
    n, m = 100, 500
    v = np.random.uniform(size=(n, m))
    v = (v.T / np.sum(v, 1)).T
    B = np.random.uniform(size = n) + 0.2
    B = B / np.sum(B)
    p, x = compute_me_fin_dim(v, B, accu=0.001, max_iter=2000)
    p_mosek, x_mosek = compute_me_mosek(v, B)
    # report how closely PR dynamics agree with the Mosek solution
    rel_price_gap = np.linalg.norm(p - p_mosek) / np.linalg.norm(p)
    print('relative price gap (PR vs. Mosek) = {:.3e}'.format(rel_price_gap))