from pulp import (
    LpVariable,
    LpMaximize,
    LpProblem,
    lpSum,
    GUROBI_CMD,
    value,
    LpStatus,
)
from itertools import product

# Grid resolution: densities are indexed by (x, y) in range(H) x range(H),
# and level t enters the revenue expressions as the price t / H.
H: int = 5
# Big-M constant for the max linearisation in construct_max; it must be at
# least the largest gap between an input and the max it is compared with.
M: int = 1_000


def construct_max(inputs, prob, variables, output_variable_name, big_m=None):
    """Add constraints to ``prob`` forcing a new variable to equal ``max(inputs)``.

    With a single input the output is simply constrained equal to it. With
    several inputs the maximum is linearised with the big-M trick: the output
    is at least every input, and one binary indicator per input (exactly one
    of which is 1) selects the input the output may not exceed.

    Args:
        inputs: Non-empty sequence (any iterable is accepted) of PuLP
            variables/expressions to take the maximum of.
        prob: PuLP problem that receives the constraints.
        variables: Name -> variable registry. The output variable and any
            auxiliary binaries are added to it under their names.
        output_variable_name: Name of the new output variable. Auxiliary
            binaries are named ``f"{output_variable_name}-{idx}"``.
        big_m: Big-M constant; defaults to the module-level ``M``. The
            linearisation is only exact when ``big_m`` is at least the largest
            gap ``max(inputs) - input`` that can occur.

    Returns:
        The new (continuous, unbounded) output variable.

    Raises:
        ValueError: if ``inputs`` is empty, or if the output name or an
            auxiliary name is already in ``variables``. Nothing is added to
            ``variables`` or ``prob`` in that case.
    """
    inputs = list(inputs)
    if not inputs:
        # An empty max would otherwise add the infeasible constraint 0 == 1.
        raise ValueError(
            f"construct_max({output_variable_name!r}) needs at least one input"
        )

    aux_names = (
        [f"{output_variable_name}-{idx}" for idx in range(len(inputs))]
        if len(inputs) > 1
        else []
    )
    # Validate every name up front so a failure leaves the registry untouched.
    clashes = [
        name for name in [output_variable_name, *aux_names] if name in variables
    ]
    if clashes:
        raise ValueError(f"variable name(s) already in use: {clashes}")

    out = variables[output_variable_name] = LpVariable(output_variable_name)
    if len(inputs) == 1:
        prob += out == inputs[0]
        return out

    if big_m is None:
        big_m = M
    selectors = []
    for aux_name, v in zip(aux_names, inputs):
        selector = variables[aux_name] = LpVariable(aux_name, cat="Binary")
        prob += out >= v
        # Only binding when this input is the selected one (selector == 1).
        prob += out <= v + big_m * (1 - selector)
        selectors.append(selector)
    prob += lpSum(selectors) == 1
    return out


def _print_grid(variables, template):
    """Print solved values of an H x H family of variables, top row first.

    ``template`` is a ``str.format`` pattern with ``{x}`` and ``{y}`` fields
    (e.g. ``"density-{x}-{y}"``). One list is printed per ``y`` from ``H - 1``
    down to ``0``, holding the values for ``x = 0 .. H - 1``.
    """
    for y in reversed(range(H)):
        print([value(variables[template.format(x=x, y=y)]) for x in range(H)])


def _add_axis_marginal_revenue(prob, variables, zero, x, y, axis):
    """Create ``mv-{x}-{y}-{axis}`` and, where needed, its ``-above`` helper.

    ``axis`` 0 walks along ``x`` with ``y`` fixed; ``axis`` 1 walks along ``y``
    with ``x`` fixed. Writing ``t`` for the coordinate along the axis and
    ``revenue(t) = t / H * sum(density at t' for t' >= t)``:

    * ``t == 0``: the marginal revenue is fixed to 0 (no ``-above`` variable).
    * ``t == H - 1``: it equals ``(H - 1) / H * density(H - 1)``, and the
      ``-above`` variable is fixed to 0.
    * otherwise: ``-above`` is ``max(next cell's -above, revenue(t + 1))``,
      i.e. the largest of 0 and ``revenue(t')`` for ``t' > t``. The marginal
      revenue is ``max(revenue(t) - above, 0)``.

    The cell at ``t + 1`` must be processed first, so that its ``-above``
    variable already exists.
    """
    pos = x if axis == 0 else y

    def cell(t):
        # (x, y) coordinates of level t along this axis.
        return (t, y) if axis == 0 else (x, t)

    def density(t):
        tx, ty = cell(t)
        return variables[f"density-{tx}-{ty}"]

    def revenue(t):
        return t / H * lpSum(density(tt) for tt in range(t, H))

    mv_name = f"mv-{x}-{y}-{axis}"
    if pos == 0:
        construct_max([zero], prob, variables, mv_name)
    elif pos == H - 1:
        construct_max([(H - 1) / H * density(pos)], prob, variables, mv_name)
        construct_max([zero], prob, variables, f"{mv_name}-above")
    else:
        next_x, next_y = cell(pos + 1)
        best_above = construct_max(
            [variables[f"mv-{next_x}-{next_y}-{axis}-above"], revenue(pos + 1)],
            prob,
            variables,
            f"{mv_name}-above",
        )
        construct_max([revenue(pos) - best_above, zero], prob, variables, mv_name)


def marginal_rev_bound(alpha, density_upper_bound, solver=None):
    """Look for an H x H density satisfying the greedy-vs-second-price inequality.

    Builds a feasibility MIP (constant objective 0) over variables
    ``density-{x}-{y}`` in ``[1, density_upper_bound]``. For every cell it
    derives per-axis marginal-revenue variables ``mv-{x}-{y}-0`` and
    ``mv-{x}-{y}-1`` and their maximum ``mv-{x}-{y}-greedy``. It then requires

        sum(greedy) >= alpha * sum(marginal revenue of the cell's winner)

    where the winner is axis 0 when ``x > y`` and axis 1 when ``x < y``. The
    constraint is imposed for both tie-breaks on the diagonal ``x == y``.

    The function prints the solver status, the objective value, and then the
    grids of density, ``mv-*-0``, ``mv-*-1`` and ``mv-*-greedy`` values.

    Args:
        alpha: Required ratio between the greedy sum and the
            second-price-style sum.
        density_upper_bound: Upper bound for every density variable; must be
            >= 1, because the lower bound is 1.
        solver: PuLP solver instance; defaults to ``GUROBI_CMD(msg=True)``.

    Returns:
        The PuLP status string (an ``LpStatus`` value such as ``"Optimal"``
        or ``"Infeasible"``).

    Raises:
        ValueError: if ``H < 2``, if ``density_upper_bound < 1``, or if
            ``density_upper_bound`` is too large for the big-M constant ``M``.
    """
    if H < 2:
        raise ValueError(f"H must be at least 2, got {H}")
    if density_upper_bound < 1:
        raise ValueError(
            f"density_upper_bound must be >= 1 (the density lower bound), "
            f"got {density_upper_bound}"
        )
    # Every max input lies in [-R, R] and every max output in [0, R], where
    # R bounds revenue(t) <= t * (H - t) / H * density_upper_bound. The big-M
    # linearisation in construct_max is therefore exact only if R <= M.
    max_revenue = density_upper_bound * max(t * (H - t) for t in range(H)) / H
    if max_revenue > M:
        raise ValueError(
            f"density_upper_bound={density_upper_bound} allows values up to "
            f"{max_revenue:g}, exceeding the big-M constant M={M}"
        )

    prob = LpProblem("marginal_rev_bound", LpMaximize)
    variables = {}
    zero = variables["zero"] = LpVariable("zero", lowBound=0, upBound=0)
    prob += 0  # constant objective: this is a pure feasibility problem

    cells = list(product(range(H), range(H)))
    for x, y in cells:
        variables[f"density-{x}-{y}"] = LpVariable(
            f"density-{x}-{y}", lowBound=1, upBound=density_upper_bound
        )

    # Visit cells from high to low coordinates so that the "-above" variable
    # of the next cell along each axis exists before it is read.
    for x, y in product(reversed(range(H)), reversed(range(H))):
        for axis in (0, 1):
            _add_axis_marginal_revenue(prob, variables, zero, x, y, axis)
        construct_max(
            [variables[f"mv-{x}-{y}-0"], variables[f"mv-{x}-{y}-1"]],
            prob,
            variables,
            f"mv-{x}-{y}-greedy",
        )

    greedy_res = lpSum(variables[f"mv-{x}-{y}-greedy"] for x, y in cells)
    # Credit each cell to axis 0 or axis 1; the two sums differ only on ties.
    snd_price_res_tiebreak_for_0 = lpSum(
        variables[f"mv-{x}-{y}-{0 if x >= y else 1}"] for x, y in cells
    )
    snd_price_res_tiebreak_for_1 = lpSum(
        variables[f"mv-{x}-{y}-{0 if x > y else 1}"] for x, y in cells
    )

    prob += greedy_res >= alpha * snd_price_res_tiebreak_for_0
    prob += greedy_res >= alpha * snd_price_res_tiebreak_for_1

    prob.solve(solver if solver is not None else GUROBI_CMD(msg=True))
    status = LpStatus[prob.status]
    print(status)
    print(value(prob.objective))
    for template in (
        "density-{x}-{y}",
        "mv-{x}-{y}-0",
        "mv-{x}-{y}-1",
        "mv-{x}-{y}-greedy",
    ):
        print()
        _print_grid(variables, template)
    return status


if __name__ == "__main__":
    # Run the solve only when executed as a script, not on import.
    # An alternative, larger configuration: marginal_rev_bound(1.25, 100)
    marginal_rev_bound(1.11, 10)
