# --------------------------------------------------------------------------
# Source file provided under Apache License, Version 2.0, January 2004,
# http://www.apache.org/licenses/
# (c) Copyright IBM Corp. 2015, 2018
# --------------------------------------------------------------------------

"""The model aims at minimizing the production cost for a number of products
while satisfying customer demand. Each product can be produced either inside
the company or outside, at a higher cost.
The inside production is constrained by the company's resources, while outside
production is considered unlimited.
The model first declares the products and the resources.
The data consists of the description of the products (the demand, the inside
and outside costs, and the resource consumption) and the capacity of the
various resources.
The variables for this problem are the inside and outside production for each
product.
"""

from docplex.mp.model import Model
from docplex.util.environment import get_environment
import numpy as np
import Simulation as sim
import torch


# ----------------------------------------------------------------------------
# Initialize the problem data
# ----------------------------------------------------------------------------


# ----------------------------------------------------------------------------
# Build the model
# ----------------------------------------------------------------------------


def _as_float(value):
    """Return *value* as a plain Python float.

    Accepts Python numbers as well as scalars and one-element arrays/tensors
    that expose ``.item()``, so every coefficient handed to the model is a
    native float regardless of how the weights were stored.
    """
    return float(value.item()) if hasattr(value, "item") else float(value)


def solve_br(Z=np.zeros((50)),
             fC2=np.zeros((5, 8 + 50)), fC3=np.ones((5)), fB2=np.ones((5)),
             fB3=np.ones((5)), lmbda=1, value=0):
    """Find the action that maximizes the output of a small ReLU network.

    The network is encoded as a mixed-integer program:

    * hidden pre-activation  ``o1[n] = fC2[n, :-1] . Z + fC2[n, -1] * u + fB2[n]``
    * hidden activation      ``o2[n] ~= max(o1[n], 0)`` (indicator + big-M
      constraints, exact up to a tolerance of ``1e-7``)
    * network output         ``o3[0] = sum_r fC3[r] * o2[r] + fB3[0]``

    and the output is maximized over the single continuous action ``u`` in
    ``[0, 4]``.

    Args:
        Z: fixed input vector; its length must equal ``fC2.shape[1] - 1``.
            NOTE: the declared default (length 50) does not match the default
            ``fC2`` (57 non-action columns), so ``Z`` has to be passed
            explicitly when ``fC2`` is left at its default.
        fC2: hidden-layer weights, one row per hidden neuron.  The last
            column multiplies the action, the other columns multiply ``Z``.
        fC3: output weights, indexed by hidden neuron (needs at least
            ``fC2.shape[0]`` entries).
        fB2: hidden-layer biases, indexed by hidden neuron.
        fB3: output biases; only ``fB3[0]`` is used.
        lmbda: unused; kept for interface compatibility.
        value: unused; kept for interface compatibility.

    Returns:
        The optimal value of the action variable ``u_0``.

    Raises:
        ValueError: if ``Z`` is not shape-compatible with ``fC2``.
        RuntimeError: if the solver does not return a solution.
    """
    # The last column of fC2 weights the action, the rest weight Z.
    action_weights = fC2[:, -1:]
    z_weights = fC2[:, :-1]
    try:
        # Z is fixed, so its contribution to each hidden neuron is a constant.
        z_contribution = np.dot(z_weights, Z)
    except ValueError as err:
        raise ValueError(
            "Z with shape {0} is not compatible with the {1} non-action "
            "columns of fC2".format(np.shape(Z), z_weights.shape[1])) from err

    mdl = Model(name='BoundedRat')

    # --- decision variables ---
    action_lower = (0.0,)
    action_upper = (4.0,)
    action_ids = range(0, 1)
    params = {i: mdl.continuous_var(lb=action_lower[i], ub=action_upper[i],
                                    name="u_{0}".format(i))
              for i in action_ids}

    # --- ReLU encoding variables ---
    big_m = 100000   # big-M used to switch the ReLU branch constraints off
    eps = .0000001   # numerical slack allowed on each ReLU relation
    hidden = range(0, fC2.shape[0])
    outputs = range(0, fC3.shape[0])

    # is_active[n] == 1 selects the linear branch of neuron n, 0 the flat one.
    is_active = {n: mdl.binary_var(name="i1_{0}".format(n)) for n in hidden}
    o1 = {n: mdl.continuous_var(lb=-1e9, ub=1e8, name="o1_{0}".format(n))
          for n in hidden}
    o2 = {n: mdl.continuous_var(lb=-1e10, ub=1e9, name="o2_{0}".format(n))
          for n in hidden}
    # Only o3[0] is constrained below; the remaining entries stay free.
    o3 = {n: mdl.continuous_var(lb=-1e9, ub=1e8, name="o3_{0}".format(n))
          for n in outputs}

    # --- constraints ---
    # Constraints are added group by group (not neuron by neuron) so that the
    # row order of the model matches the original formulation.

    # Hidden pre-activation as an affine function of the action.
    for n in hidden:
        mdl.add_constraint(
            ct=o1[n] == mdl.sum(_as_float(action_weights[n, u]) * params[u]
                                for u in action_ids)
            + _as_float(z_contribution[n]) + _as_float(fB2[n]),
            ctname="chanceConstraintO1_{0}".format(n))

    # is_active[n] == 1  =>  o1[n] >= -eps
    for n in hidden:
        mdl.add_indicator(is_active[n], -eps <= o1[n],
                          name="chanceConstraintRL1_{0}".format(n),
                          active_value=1)

    # is_active[n] == 0  =>  o1[n] <= eps
    for n in hidden:
        mdl.add_indicator(is_active[n], o1[n] <= eps,
                          name="chanceConstraintRL2_{0}".format(n),
                          active_value=0)

    # Active neuron: o2[n] equals o1[n] up to eps (relaxed by M otherwise).
    for n in hidden:
        mdl.add_constraint(
            ct=o1[n] - (1 - is_active[n]) * big_m - eps <= o2[n],
            ctname="chanceConstraintRL3_{0}".format(n))
    for n in hidden:
        mdl.add_constraint(
            ct=o2[n] <= o1[n] + (1 - is_active[n]) * big_m + eps,
            ctname="chanceConstraintRL4_{0}".format(n))

    # Inactive neuron: o2[n] is zero up to eps (relaxed by M otherwise).
    for n in hidden:
        mdl.add_constraint(
            -1 * (is_active[n]) * big_m - eps <= o2[n],
            ctname="chanceConstraintRL5_{0}".format(n))
    for n in hidden:
        mdl.add_constraint(
            o2[n] <= (is_active[n]) * big_m + eps,
            ctname="chanceConstraintRL6_{0}".format(n))

    # Network output.  Named "O3" so it no longer reuses the name of the
    # hidden-layer constraint "chanceConstraintO1_0".
    for n in range(0, 1):
        mdl.add_constraint(
            ct=o3[n] == mdl.sum(_as_float(fC3[r]) * o2[r] for r in hidden)
            + _as_float(fB3[n]),
            ctname="chanceConstraintO3_{0}".format(n))

    # --- objective ---
    # Same expression that defines o3[0]: maximize the network output.
    objective = (mdl.sum(_as_float(fC3[r]) * o2[r] for r in hidden)
                 + _as_float(fB3[0]))
    mdl.maximize(objective)

    sol = mdl.solve()
    if sol is None:
        # solve() returns None when no solution is available (infeasible
        # model, missing solver, ...); fail with a message instead of a
        # TypeError on the lookup below.
        raise RuntimeError(
            "BoundedRat model could not be solved (status: {0})".format(
                mdl.get_solve_status()))

    solution = sol[params[0]]
    print('Best Action for Info Gain', solution)
    return solution


def print_production_solution(mdl, products):
    """Print the objective, cost totals and per-product production levels.

    ``mdl`` must expose ``objective_value``, ``total_inside_cost`` and
    ``total_outside_cost`` (each with a ``solution_value``), plus
    ``inside_vars`` and ``outside_vars`` mappings keyed by the entries of
    ``products``.  Each product is indexable and its first item is used as
    the display name.  ``products`` is iterated twice.
    """
    objective_value = mdl.objective_value
    print("* Production model solved with objective: {:g}".format(objective_value))

    # Inside production: total first, then one line per product.
    inside_total = mdl.total_inside_cost.solution_value
    print("* Total inside cost=%g" % inside_total)
    for product in products:
        label = product[0]
        amount = mdl.inside_vars[product].solution_value
        print("Inside production of {product}: {ins_var}".format(
            product=label, ins_var=amount))

    # Outside production: same layout.
    outside_total = mdl.total_outside_cost.solution_value
    print("* Total outside cost=%g" % outside_total)
    for product in products:
        label = product[0]
        amount = mdl.outside_vars[product].solution_value
        print("Outside production of {product}: {out_var}".format(
            product=label, out_var=amount))


# ----------------------------------------------------------------------------
# Solve the model and display the result
# ----------------------------------------------------------------------------


if __name__ == '__main__':
    # solve_br() returns the optimal action value (and prints it), not a
    # docplex Model, so there is no print_information() to call on its result.
    # The default Z (length 50) does not match the 57 non-action columns of
    # the default fC2, so pass a zero vector of the matching length.
    best_action = solve_br(Z=np.zeros(8 + 50 - 1))