# --------------------------------------------------------------------------
# Source file provided under Apache License, Version 2.0, January 2004,
# http://www.apache.org/licenses/
# (c) Copyright IBM Corp. 2015, 2018
# --------------------------------------------------------------------------

"""Find the input that maximizes the output of a small ReLU network by solving
a mixed-integer program with IBM docplex.

The network has one hidden layer of ReLU units followed by a single linear
output. Each hidden ReLU is encoded with a binary indicator variable together
with indicator and big-M constraints. The network input is restricted to a box
of lower/upper bounds. Maximizing the output variable yields the best input
("action") for the network together with the maximal output value.

The file skeleton (license header, section banners) comes from an IBM docplex
example; the production-planning description of that example does not apply
here.
"""

from docplex.mp.model import Model
from docplex.util.environment import get_environment
import numpy as np
import torch



# ----------------------------------------------------------------------------
# Initialize the problem data
# ----------------------------------------------------------------------------


# ----------------------------------------------------------------------------
# Build the model
# ----------------------------------------------------------------------------




def solve_br(fC2=None, fC3=None, fB1=None, fB2=None,
             p_lower=(0.0,), p_upper=(4.0,), big_m=100):
    """Find the input that maximizes a one-hidden-layer ReLU network.

    The network ``o3 = fC3[0] . relu(fC2[:, :k] @ u + fB1) + fB2[0]`` is encoded
    as a mixed-integer program and solved with docplex. Only the first
    ``k = len(p_upper)`` input columns of ``fC2`` are used, and only the first
    output row of ``fC3`` is maximized.

    Each ReLU ``o2 = max(o1, 0)`` uses a binary indicator ``i``:
    * Indicator constraints fix the sign of ``o1``.
    * Big-M constraints force ``o2 == o1`` when ``i == 1`` and ``o2 == 0`` when
      ``i == 0``.

    Weights and biases may be any array-like supporting ``a[n, p]`` / ``a[n]``
    indexing whose elements convert with ``float()`` (e.g. numpy arrays or
    torch tensors).

    Args:
        fC2: hidden-layer weights indexed ``fC2[n, p]``, one row per hidden
            neuron. Defaults to ``zeros((5, 58))``.
        fC3: output-layer weights indexed ``fC3[0, n]``. A 1-D array is treated
            as a single output row. Defaults to ones, shape ``(1, neurons)``.
        fB1: hidden-layer biases, one scalar per neuron. Defaults to zeros.
        fB2: output bias; only ``fB2[0]`` is used. Defaults to ``zeros(1)``.
        p_lower: per-input lower bounds. Together with ``p_upper``, its length
            sets the number of inputs ``k``.
        p_upper: per-input upper bounds; must have the same length as
            ``p_lower``.
        big_m: big-M constant of the ReLU encoding. With ``i == 1`` the
            constraints force ``o2 == o1 <= big_m``. Active pre-activations
            are therefore capped at ``big_m``, so it must exceed the largest
            reachable ``o1``.

    Returns:
        ``(solution, maxVal)``: the list of optimal input values and the
        maximized network output.

    Raises:
        ValueError: if ``p_lower`` and ``p_upper`` differ in length.
        RuntimeError: if docplex returns no solution.
    """
    # Build fresh, shape-consistent defaults on every call instead of sharing
    # default arrays. fC3 is indexed as fC3[0, n]; the biases are used as
    # scalars fB1[n] and fB2[0].
    if fC2 is None:
        fC2 = np.zeros((5, 8 + 50))
    neuron_count = fC2.shape[0]
    if fC3 is None:
        fC3 = np.ones((1, neuron_count))
    elif fC3.ndim == 1:
        # A single output layer given as a flat weight vector.
        fC3 = fC3.reshape(1, -1)
    if fB1 is None:
        fB1 = np.zeros(neuron_count)
    if fB2 is None:
        fB2 = np.zeros(1)
    if len(p_lower) != len(p_upper):
        raise ValueError("p_lower and p_upper must have the same length")

    set_params = range(len(p_upper))
    set_Neurons = range(neuron_count)

    mdl = Model(name='BoundedRat')

    # --- decision variables ---
    # u: network inputs, boxed by the given bounds.
    params = {i: mdl.continuous_var(lb=p_lower[i], ub=p_upper[i],
                                    name="u_{0}".format(i))
              for i in set_params}
    # o2: ReLU outputs; i: ReLU on/off indicators; o1: pre-activations.
    o2 = {n: mdl.continuous_var(lb=-10000000000, ub=1000000000,
                                name="o2_{0}".format(n))
          for n in set_Neurons}
    i_vars = {n: mdl.binary_var(name="i_{0}".format(n)) for n in set_Neurons}
    o1 = {n: mdl.continuous_var(lb=-1000000000, ub=100000000,
                                name="o1_{0}".format(n))
          for n in set_Neurons}
    # o3: the single network output.
    o3 = mdl.continuous_var(lb=-1000000000, ub=100000000, name="o3_0")

    # --- hidden layer + ReLU constraints ---
    for n in set_Neurons:
        # Pre-activation o1 = fC2[n, :k] . u + fB1[n]. The bias is added once,
        # not once per input. float() accepts numpy scalars and 0-d tensors.
        mdl.add_constraint(
            o1[n] == mdl.sum(float(fC2[n, p]) * params[p] for p in set_params)
            + float(fB1[n]),
            ctname="chanceConstraintO1_{0}".format(n))
        # Sign of the pre-activation: i == 1 => o1 >= 0, i == 0 => o1 <= 0.
        mdl.add_indicator(i_vars[n], 0 <= o1[n],
                          name="chanceConstraintRL1_{0}".format(n),
                          active_value=1)
        mdl.add_indicator(i_vars[n], o1[n] <= 0,
                          name="chanceConstraintRL2_{0}".format(n),
                          active_value=0)
        # Big-M pair forcing o2 == o1 when i == 1 (relaxed by big_m when i == 0).
        mdl.add_constraint(o1[n] - (1 - i_vars[n]) * big_m <= o2[n],
                           ctname="chanceConstraintRL3_{0}".format(n))
        mdl.add_constraint(o2[n] <= o1[n] + (1 - i_vars[n]) * big_m,
                           ctname="chanceConstraintRL4_{0}".format(n))
        # Big-M pair forcing o2 == 0 when i == 0 (|o2| <= big_m when i == 1).
        mdl.add_constraint(-1 * i_vars[n] * big_m <= o2[n],
                           ctname="chanceConstraintRL5_{0}".format(n))
        mdl.add_constraint(o2[n] <= i_vars[n] * big_m,
                           ctname="chanceConstraintRL6_{0}".format(n))

    # --- output layer ---
    # o3 = fC3[0, :] . o2 + fB2[0]. The constraint gets its own name so it does
    # not collide with the hidden neuron 0 constraint "chanceConstraintO1_0".
    mdl.add_constraint(
        o3 == mdl.sum(float(fC3[0, r]) * o2[r] for r in set_Neurons)
        + float(fB2[0]),
        ctname="outputConstraintO3_0")

    mdl.maximize(o3)
    sol = mdl.solve()
    if sol is None:
        # Without this check the lookups below fail with an opaque TypeError.
        raise RuntimeError(
            "solve_br: no solution found (status: {0})".format(
                mdl.get_solve_status()))

    solution = [sol[params[u]] for u in set_params]
    print('Global Best Action', solution)
    maxVal = sol[o3]
    return solution, maxVal





# ----------------------------------------------------------------------------
# Solve the model and display the result
# ----------------------------------------------------------------------------


if __name__ == '__main__':
    # Smoke run on a small network: 5 hidden ReLU units with zero first-layer
    # weights, unit output weights and zero biases. The arrays are passed
    # explicitly with the shapes solve_br indexes (fC3[n, r], fB1[n], fB2[0]).
    # solve_br returns a (solution, maxVal) tuple, not a docplex Model, so
    # unpack it instead of calling Model methods on it.
    best_action, best_value = solve_br(
        fC2=np.zeros((5, 8 + 50)),
        fC3=np.ones((1, 5)),
        fB1=np.zeros(5),
        fB2=np.zeros(1),
    )
    print('Maximum network output', best_value)