import datetime
from zoneinfo import ZoneInfo
from pulp import (
    LpVariable,
    lpSum,
)


def cap_up_and_down(val, lowBound=None, upBound=None):
    """Clamp ``val`` into ``[lowBound, upBound]``; a ``None`` bound is skipped.

    The lower bound is applied before the upper bound, so if
    ``lowBound > upBound`` the result is ``upBound``.
    """
    clamped = val if lowBound is None else max(val, lowBound)
    return clamped if upBound is None else min(clamped, upBound)


def tuple_update_by_index(tup, index, val):
    """Return a new tuple equal to ``tup`` except that position ``index`` is ``val``.

    Negative indices work as for lists; an out-of-range index raises IndexError.
    """
    items = [*tup]
    items[index] = val
    return tuple(items)


def get_timestamp():
    """Return the current US Eastern wall-clock time as a filename-safe string.

    The format is e.g. ``"Jan_05_14hr_03min_09sec"``. ``%b`` uses the
    current locale's abbreviated month name.
    """
    # "America/New_York" is the canonical IANA zone. The old "US/Eastern"
    # alias lives only in tzdata's backward-compatibility links, which some
    # distributions ship separately (e.g. Debian's tzdata-legacy). Without
    # them, ZoneInfo("US/Eastern") raises ZoneInfoNotFoundError.
    return datetime.datetime.now(ZoneInfo("America/New_York")).strftime(
        "%b_%d_%Hhr_%Mmin_%Ssec"
    )


def add_profiles(profiles, new_profiles, tolerance=0.01):
    """Add each profile in ``new_profiles``, in order, via ``add_single_profile``.

    Every addition evicts existing profiles within ``tolerance`` of the one
    being added. When ``new_profiles`` is empty, ``profiles`` itself is
    returned unchanged.
    """
    merged = profiles
    for candidate in new_profiles:
        merged = add_single_profile(merged, candidate, tolerance)
    return merged


def add_single_profile(profiles, new_profile, tolerance=0.01):
    """Return a new list: ``profiles`` without near-duplicates of ``new_profile``, plus it.

    A profile counts as a near-duplicate when its L1 distance to
    ``new_profile`` is ``<= tolerance``. The survivors keep their order and
    ``new_profile`` is appended last. The input list is not mutated.
    """

    def l1_gap(a, b):
        # Iterates over a's length: extra trailing entries in b are ignored,
        # and a shorter b raises IndexError.
        return sum(abs(a[k] - b[k]) for k in range(len(a)))

    kept = [p for p in profiles if l1_gap(p, new_profile) > tolerance]
    kept.append(new_profile)
    return kept


def construct_relu_mlp_outputs(
    mlp,
    inputs,
    prob,
    variables,
    unique_label,
    big_m=1000,
    weight_tol=0.000001,
):
    """Encode a ReLU MLP's forward pass as MIP constraints on a PuLP problem.

    Each node of every layer in ``mlp.shapes`` gets PuLP variables. The ReLU
    between hidden layers uses a big-M formulation with one binary variable
    per hidden node. The created variables are stored in ``variables``
    under these keys:

    * ``(unique_label, i, j, "nonneg")``: post-activation value of node ``j``
      in layer ``i``, for every layer except the last. Layer 0 is pinned to
      ``inputs``.
    * ``(unique_label, i, j, "binary")``: ReLU indicator for hidden layers.
      1 forces the post-activation to 0 and the pre-activation to be <= 0.
      0 forces post-activation == pre-activation >= 0.
    * ``(unique_label, i, j)``: pre-activation value for every layer except
      the input layer.

    Args:
        mlp: Model exposing ``shapes`` (layer widths, input layer first) and
            ``linear_relu_stack``, whose even-indexed entries carry
            ``weight`` and ``bias`` tensors. Presumably a torch
            ``nn.Sequential`` alternating Linear/ReLU. TODO confirm.
        inputs: Sequence of length ``mlp.shapes[0]``. Each input-layer node
            is constrained to equal the corresponding entry.
        prob: PuLP problem; constraints are added to it in place with ``+=``.
        variables: Dict that receives the created variables (mutated).
        unique_label: Prefix that keeps keys and variable names distinct when
            several networks are encoded into the same problem.
        big_m: Big-M constant of the ReLU encoding. Every hidden
            pre-activation value must satisfy ``|z| <= big_m``, or feasible
            points are cut off. A needlessly large value hurts the solver's
            numerical accuracy.
        weight_tol: Weights whose absolute value is ``<= weight_tol`` are
            dropped from the affine constraints.

    Returns:
        List of the output layer's pre-activation ``LpVariable`` objects, in
        node order.

    Raises:
        ValueError: If ``inputs`` or the extracted weights/biases do not
            match ``mlp.shapes``.
    """
    shapes = mlp.shapes
    n_layers = len(shapes)
    last = n_layers - 1

    # Explicit raises instead of asserts, so validation survives `python -O`.
    if len(inputs) != shapes[0]:
        raise ValueError(f"expected {shapes[0]} inputs, got {len(inputs)}")

    # Even-indexed stack entries are the affine layers; odd entries are
    # assumed to be parameter-free ReLU activations.
    weights = []
    biases = []
    for i in range(0, len(mlp.linear_relu_stack), 2):
        layer = mlp.linear_relu_stack[i]
        weights.append(layer.weight.data.cpu().numpy())
        biases.append(layer.bias.data.cpu().numpy())

    if len(weights) != last:
        raise ValueError(
            f"expected {last} affine layers for shapes {shapes}, got {len(weights)}"
        )
    for i, (weight, bias) in enumerate(zip(weights, biases)):
        if weight.shape[0] != shapes[i + 1] or weight.shape[1] != shapes[i]:
            raise ValueError(
                f"layer {i} weight has shape {tuple(weight.shape)}, "
                f"expected ({shapes[i + 1]}, {shapes[i]})"
            )
        if bias.shape[0] != shapes[i + 1]:
            raise ValueError(
                f"layer {i} bias has length {bias.shape[0]}, expected {shapes[i + 1]}"
            )

    # Create variables layer by layer, node by node.
    for i, size in enumerate(shapes):
        for j in range(size):
            # Post-activation (nonnegative after ReLU) for all but the output layer.
            if i < last:
                variables[(unique_label, i, j, "nonneg")] = LpVariable(
                    f"{unique_label},{i},{j},nonneg"
                )

            # Input layer: no ReLU; the node is pinned to the given input.
            # The original author notes the inputs (bids) are already nonnegative.
            if i == 0:
                prob += inputs[j] == variables[(unique_label, i, j, "nonneg")]

            # Hidden layers need a binary ReLU indicator.
            if 0 < i < last:
                variables[(unique_label, i, j, "binary")] = LpVariable(
                    f"{unique_label},{i},{j},binary", cat="Binary"
                )

            # Pre-activation values (may be negative) for all but the input layer.
            if i > 0:
                variables[(unique_label, i, j)] = LpVariable(f"{unique_label},{i},{j}")

    # Big-M ReLU for hidden nodes: y = max(z, 0), with binary b selecting the branch.
    for i in range(1, last):
        for j in range(shapes[i]):
            y = variables[(unique_label, i, j, "nonneg")]
            z = variables[(unique_label, i, j)]
            b = variables[(unique_label, i, j, "binary")]
            prob += y >= z
            prob += y >= 0
            prob += y <= z + big_m * b
            prob += y <= big_m * (1 - b)

    # Affine maps: z_i = W_{i-1} @ y_{i-1} + b_{i-1}, skipping near-zero weights.
    for i in range(1, n_layers):
        w = weights[i - 1]
        bias = biases[i - 1]
        for j in range(shapes[i]):
            prob += (
                variables[(unique_label, i, j)]
                == lpSum(
                    [
                        variables[(unique_label, i - 1, k, "nonneg")] * w[j][k]
                        for k in range(shapes[i - 1])
                        if abs(w[j][k]) > weight_tol
                    ]
                )
                + bias[j]
            )

    return [variables[(unique_label, last, j)] for j in range(shapes[-1])]
