import numpy as np
from numpy.core.records import _deprecate_shape_0_as_None
import pandas as pd
from utils import *
import argparse

# For each dataset, build and solve the quasi-linear (QL) equilibrium convex
# program below with MOSEK Fusion, then save the supply duals (`p`), the
# allocation (`x`) and the utility-constraint duals (`beta`) under
# results/<dataset>/offline-eq-ql/.
import os

# dataset name -> (CSV path, extra pandas.read_csv keyword arguments).
# Driving the loop from this table guarantees `v` is always (re)loaded for the
# current dataset. Only the MovieLens CSV is read without a header row.
DATASETS = {
    'MovieLens': (
        "../../data/movielens_1500x1500_lr0.1_wd1e-05_dim20_rmse0.88640326.csv",
        {'header': None},
    ),
    'Household': ("../../data/household_items_understood.csv", {}),
    'Jokes': ("../../data/jokes_7200.csv", {}),
}

for dataset, (csv_path, read_kwargs) in DATASETS.items():
    df = pd.read_csv(csv_path, **read_kwargs)
    # `np.float` was removed in NumPy 1.24; the builtin `float` gives the same
    # float64 dtype.
    v = df.to_numpy(dtype=float)
    print("dataset = {}".format(dataset))

    # Rows are buyers (each gets a budget and a utility constraint below);
    # columns are items (each has unit supply).
    n, m = v.shape

    # Equal budgets B_i = sum(v) / n**2. Budgets and values are then divided by
    # the same factor so that sum(B) == 1 (i.e. every B_i becomes 1/n).
    B = np.ones(n) / n * (np.sum(v) / n)
    totalBudget = np.sum(B)
    B, v = B / totalBudget, v / totalBudget

    # Program solved:
    #   maximize   sum_i B_i * q_i - sum_i delta_i
    #   subject to sum_i x_ij <= 1                   for every item j (supply)
    #              u_i <= <v_i, x_i> + delta_i       for every buyer i (utility)
    #              (u_i, 1, q_i) in primal exp cone  i.e. u_i >= exp(q_i)
    #              x, u, delta >= 0,  q free
    # Because of the cone, q_i <= log(u_i), so the objective equals
    # sum_i B_i log(u_i) - sum_i delta_i at the optimum.
    # `with` disposes the Fusion model (and its native solver task) after each
    # dataset instead of leaking one model per iteration.
    with Model('QLEG') as model:
        x = model.variable('x', [n, m], Domain.greaterThan(0))
        u = model.variable('u', n, Domain.greaterThan(0))
        q = model.variable('q', n, Domain.unbounded())
        delta = model.variable('delta', n, Domain.greaterThan(0))

        # Supply constraints: column sums of x (one per item) are at most 1.
        model.constraint('supply', Expr.sum(x, 0), Domain.lessThan(1))

        # Buyer-wise QL utility constraints: u_i - (<v_i, x_i> + delta_i) <= 0.
        for i in range(n):
            rhs = Expr.add(Expr.dot(v[i], x.slice([i, 0], [i + 1, m])), delta.index(i))
            model.constraint('utility {}'.format(i), Expr.sub(u.index(i), rhs), Domain.lessThan(0))

        # Objective, with log(u_i) modelled through q_i and an exponential cone.
        obj_expr = Expr.sub(Expr.dot(B, q), Expr.sum(delta, 0))
        for i in range(n):
            model.constraint('exp-cone {}'.format(i), Expr.vstack(u.index(i), 1, q.index(i)), Domain.inPExpCone())

        model.objective('maximize-obj', ObjectiveSense.Maximize, obj_expr)
        model.solve()

        # Read primal levels and duals while the model is still alive; they
        # are plain numpy arrays and remain valid after disposal.
        x, u, delta = x.level().reshape((n, m)), u.level(), delta.level()

        # Duals of the per-buyer utility constraints and of the supply constraints.
        utility_constrs = [model.getConstraint('utility {}'.format(i)) for i in range(n)]
        beta = np.array([cc.dual()[0] for cc in utility_constrs])
        p = model.getConstraint('supply').dual()

    print(np.min(beta), np.max(beta))

    # Save the result.
    fpath = os.path.join('results', dataset.lower(), 'offline-eq-ql')
    os.makedirs(fpath, exist_ok=True)
    np.savetxt(os.path.join(fpath, 'p'), p, fmt='%.4e')
    np.savetxt(os.path.join(fpath, 'x'), x, fmt='%.4e')
    np.savetxt(os.path.join(fpath, 'beta'), beta, fmt='%.4e')