from random import random
import torch
from pulp import (
    LpVariable,
    LpMaximize,
    LpProblem,
    lpSum,
    GUROBI_CMD,
    value,
    LpStatus,
)

# Module-level size constant; the assert guards against it being edited,
# presumably because other code relies on n == 5 (TODO confirm).
n = 5
assert n == 5

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def uniform_sampling(number_of_samples):
    """Draw points uniformly at random from the unit square [0, 1) x [0, 1).

    Each point is an ``(x, y)`` tuple built from two successive calls to
    ``random()``, x first. A count of zero or less yields an empty list.
    """
    points = []
    remaining = number_of_samples
    while remaining > 0:
        points.append((random(), random()))
        remaining -= 1
    return points

def construct_relu_mlp_outputs(
    mlp,
    inputs,
    prob,
    variables,
    unique_label,
    extra_inputs=False,
    shapes=None,
    network_layers=None,
    big_m=1000,
    weight_tolerance=0.000001,
):
    """Encode a ReLU MLP as mixed-integer linear constraints in a PuLP problem.

    Each layer's pre-activation value ``z`` is tied to the previous layer's
    post-activation values by an equality constraint. Every hidden ReLU
    ``a = max(z, 0)`` is modelled with one binary variable ``b`` and the
    big-M constraints::

        a >= z,  a >= 0,  a <= z + M * b,  a <= M * (1 - b)

    With ``b = 0`` these force ``a = z`` and require ``0 <= z <= M``. With
    ``b = 1`` they force ``a = 0`` and require ``-M <= z <= 0``. The encoding
    is therefore exact only when every hidden pre-activation lies in
    ``[-M, M]``. Points outside that range are cut from the feasible region,
    which is why ``big_m`` affects accuracy.

    Args:
        mlp: Network whose ``shapes`` (layer widths, input layer first) and
            ``linear_relu_stack`` are read when ``extra_inputs`` is falsy.
        inputs: Sequence of ``shapes[0]`` values or PuLP expressions that are
            fed to the first layer.
        prob: PuLP problem that the constraints are added to (mutated).
        variables: Dict that receives the created ``LpVariable`` objects
            (mutated). Pre-activation values use the key
            ``(unique_label, layer, node)``. Post-activation values and ReLU
            indicators use ``(unique_label, layer, node, "nonneg")`` and
            ``(unique_label, layer, node, "binary")``.
        unique_label: Prefix that keeps variable names and dict keys distinct
            when several networks are encoded into the same problem.
        extra_inputs: When truthy, use the explicitly supplied ``shapes`` and
            ``network_layers`` instead of the ones stored on ``mlp``.
        shapes: Layer widths. Used only when ``extra_inputs`` is truthy.
        network_layers: Module stack whose even-indexed entries (0, 2, 4, ...)
            carry ``weight``/``bias``. Used only when ``extra_inputs`` is truthy.
        big_m: Big-M constant of the ReLU encoding. The default 1000 is the
            previously hard-coded value.
        weight_tolerance: Weights whose absolute value is at or below this are
            dropped from the affine constraints.

    Returns:
        List of the output layer's pre-activation ``LpVariable`` objects. No
        ReLU is applied to the last layer.

    Raises:
        AssertionError: If ``inputs``, ``shapes`` and the layer parameters have
            inconsistent sizes. The check runs before ``prob`` or
            ``variables`` is modified. Asserts are stripped under ``python -O``.
    """
    if not extra_inputs:
        shapes = mlp.shapes
        network_layers = mlp.linear_relu_stack

    assert len(inputs) == shapes[0]

    # Affine-layer parameters live at even positions of the stack; the
    # modules in between (activations) carry no parameters and are skipped.
    weights = []
    biases = []
    for idx in range(0, len(network_layers), 2):
        layer = network_layers[idx]
        weights.append(layer.weight.detach().cpu().numpy())
        biases.append(layer.bias.detach().cpu().numpy())

    # Validate every size before touching ``prob``/``variables``, so a
    # mismatch cannot leave a half-built model behind.
    assert len(shapes) >= 2
    assert len(weights) == len(shapes) - 1
    for i, (weight, bias) in enumerate(zip(weights, biases)):
        assert weight.shape[0] == shapes[i + 1]
        assert weight.shape[1] == shapes[i]
        assert bias.shape[0] == shapes[i + 1]

    last = len(shapes) - 1

    for i, size in enumerate(shapes):
        for j in range(size):
            # Every layer except the output has post-ReLU ("nonneg") values.
            if i < last:
                variables[(unique_label, i, j, "nonneg")] = LpVariable(
                    f"{unique_label},{i},{j},nonneg"
                )

            # Input layer: no ReLU; its values are pinned to ``inputs``. No
            # lower bound is set here. The inputs (bids) are presumably
            # already nonnegative (TODO confirm at the caller).
            if i == 0:
                prob += inputs[j] == variables[(unique_label, i, j, "nonneg")]

            # Hidden layers: one binary ReLU indicator per node.
            if 0 < i < last:
                variables[(unique_label, i, j, "binary")] = LpVariable(
                    f"{unique_label},{i},{j},binary", cat="Binary"
                )

            # Every layer except the input: pre-activation, may be negative.
            if i > 0:
                variables[(unique_label, i, j)] = LpVariable(f"{unique_label},{i},{j}")

    # Big-M encoding of a = max(z, 0) for every hidden node (see docstring).
    for i in range(1, last):
        for j in range(shapes[i]):
            z = variables[(unique_label, i, j)]
            a = variables[(unique_label, i, j, "nonneg")]
            b = variables[(unique_label, i, j, "binary")]
            prob += a >= z
            prob += a >= 0
            prob += a <= z + big_m * b
            prob += a <= big_m * (1 - b)

    # Affine maps z_i = W_{i-1} a_{i-1} + b_{i-1}. Near-zero weights are
    # dropped to keep the constraints sparse. numpy scalars are converted to
    # Python floats (same values) so PuLP only sees built-in numbers.
    for i in range(1, len(shapes)):
        weight = weights[i - 1]
        bias = biases[i - 1]
        for j in range(shapes[i]):
            terms = [
                variables[(unique_label, i - 1, k, "nonneg")] * float(weight[j][k])
                for k in range(shapes[i - 1])
                if abs(weight[j][k]) > weight_tolerance
            ]
            prob += variables[(unique_label, i, j)] == lpSum(terms) + float(bias[j])

    return [variables[(unique_label, last, j)] for j in range(shapes[-1])]
