import gurobipy as gb
from gurobipy import GRB

import numpy as np

def wasserstein(demands, x_test, X_train, params, rho, alpha = 0.95):
    """Solve the Wasserstein-penalised ordering problem and return ``w``.

    Builds a Gurobi quadratically-constrained program minimising

        rho * lambda + theta + sum_k mu_k / (N * alpha)

    subject to, for every training sample ``k`` and candidate demand ``d_i``,

        mu_k + theta >= c_lin*w + c_quad*w^2
                        + h_lin*y_hl_i + h_quad*y_hq_i^2
                        + b_lin*y_bl_i + b_quad*y_bq_i^2
                        - lambda * (|d_i - demands[k]| + ||x_test - X_train[k]||_1)

    where the ``y_b*`` variables bound ``max(0, d_i - w)`` and the ``y_h*``
    variables bound ``max(0, w - d_i)``.

    Args:
        demands: Observed demand of each of the N training samples.
        x_test: Feature vector of the test point.
        X_train: Training features indexed as ``X_train[k, :]`` (presumably a
            2-D numpy array with one row per entry of ``demands`` -- confirm).
        params: Dict providing ``c_lin``, ``c_quad``, ``b_lin``, ``b_quad``,
            ``h_lin``, ``h_quad`` and ``d`` (the S candidate demand values).
        rho: Coefficient of ``lambda`` in the objective.
        alpha: The ``mu`` terms are weighted by ``1 / (N * alpha)``.

    Returns:
        The optimal value of the decision variable ``w``.
    """
    c_lin = params['c_lin']
    c_quad = params['c_quad']
    b_lin = params['b_lin']
    b_quad = params['b_quad']
    h_lin = params['h_lin']
    h_quad = params['h_quad']
    d = params['d']

    S = len(d)
    N = len(demands)

    model = gb.Model("model")

    # gurobipy's addVars defaults to lb=0.0; the bounds are spelled out so the
    # sign restrictions are visible instead of being added as extra constraints.
    w = model.addVars(1, lb=0.0)
    # theta is the threshold of the theta + sum(mu)/(N*alpha) term and must be
    # free: with the default lb=0 it was silently restricted to theta >= 0,
    # which changes the optimum whenever the best threshold is negative.
    theta = model.addVars(1, lb=-gb.GRB.INFINITY)
    lambda_ = model.addVars(1, lb=0.0)
    mu = model.addVars(N, lb=0.0)

    # Epigraph variables for the positive parts max(0, d_i - w) (y_b*) and
    # max(0, w - d_i) (y_h*), separately for the linear and quadratic terms.
    y_bl = model.addVars(S, lb=0.0)
    y_bq = model.addVars(S, lb=0.0)
    y_hl = model.addVars(S, lb=0.0)
    y_hq = model.addVars(S, lb=0.0)

    for i in range(S):
        model.addConstr(y_bl[i] >= d[i] - w[0])
        model.addConstr(y_bq[i] >= d[i] - w[0])
        model.addConstr(y_hl[i] >= w[0] - d[i])
        model.addConstr(y_hq[i] >= w[0] - d[i])

    model.setObjective(
        lambda_[0] * rho + theta[0] + (1 / (N * alpha)) * mu.sum(),
        gb.GRB.MINIMIZE,
    )

    # Cost term that depends only on w; built once and reused (``+`` below
    # creates a new expression, so this one is never mutated).
    order_cost = c_lin * w[0] + c_quad * w[0] * w[0]

    for k in range(N):
        # Feature part of the transport distance does not depend on i.
        feature_dist = np.sum(np.abs(x_test - X_train[k, :]))
        for i in range(S):
            transport = np.abs(d[i] - demands[k]) + feature_dist
            model.addConstr(
                mu[k] + theta[0] >= order_cost
                + h_lin * y_hl[i] + b_lin * y_bl[i]
                + h_quad * y_hq[i] * y_hq[i] + b_quad * y_bq[i] * y_bq[i]
                - lambda_[0] * transport
            )

    model.optimize()
    return w[0].X
