import z3
import itertools

import SM, verifier

# Number of nodes participating in the protocol.
num_nodes = 5

# Create a verifier instance, using the SM.State Enum as the states in the
# state machine. The node count is taken from num_nodes (instead of a second
# hard-coded literal) so the verifier always agrees with the transition model
# built from num_nodes further down.
ver = verifier.SMVerifier(SM.State, num_nodes)


"""
Purely for example.
Defines a set of transitions that is trivially correct for
the atomic commit protocol. This would be replaced with an
exhaustive list of evaluations of the model.
Returns an exhaustive list of possible transitions of the form
[
    [["LocalCommit", ..., "LocalCommit"], "Commit"],
    ...
]
"""
def dummy_correct_model(n, states=None):
    """Return a trivially correct transition table for the atomic commit protocol.

    Purely for example: a real run would replace this with an exhaustive list
    of evaluations of the actual model.

    Every combination of per-node states (the Cartesian product of all state
    names, taken ``n`` times) is mapped to an outcome. The outcome is
    ``Commit`` when every node is in ``Commit`` or ``LocalCommit``, and
    ``Abort`` otherwise.

    Args:
        n: number of nodes.
        states: Enum whose members are the possible node states. It must
            define ``Commit``, ``LocalCommit`` and ``Abort``. Defaults to
            ``SM.State``.

    Returns:
        A list of ``[combination, outcome]`` pairs, where ``combination`` is a
        list of ``n`` state names and ``outcome`` is a state name, e.g.
        ``[["LocalCommit", ..., "LocalCommit"], "Commit"]``. It has
        ``len(states) ** n`` entries.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if states is None:
        states = SM.State
    names = [state.name for state in states]
    # A combination commits only if every node is already (locally) committed.
    commit_like = {states.Commit.name, states.LocalCommit.name}
    commit, abort = states.Commit.name, states.Abort.name
    return [
        [list(combo), commit if commit_like.issuperset(combo) else abort]
        for combo in itertools.product(names, repeat=n)
    ]

def dummy_incorrect_model(n, states=None):
    """Return the inverse of dummy_correct_model: a transition table that is always wrong.

    Every combination of per-node states (the Cartesian product of all state
    names, taken ``n`` times) is mapped to the opposite of the correct
    outcome. The outcome is ``Abort`` when every node is in ``Commit`` or
    ``LocalCommit``, and ``Commit`` otherwise.

    Args:
        n: number of nodes.
        states: Enum whose members are the possible node states. It must
            define ``Commit``, ``LocalCommit`` and ``Abort``. Defaults to
            ``SM.State``.

    Returns:
        A list of ``[combination, outcome]`` pairs, where ``combination`` is a
        list of ``n`` state names and ``outcome`` is a state name. It has
        ``len(states) ** n`` entries.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if states is None:
        states = SM.State
    names = [state.name for state in states]
    commit_like = {states.Commit.name, states.LocalCommit.name}
    commit, abort = states.Commit.name, states.Abort.name
    # Deliberately inverted outcomes: abort where the correct model commits.
    return [
        [list(combo), abort if commit_like.issuperset(combo) else commit]
        for combo in itertools.product(names, repeat=n)
    ]


"""
Properties of the atomic commit protocol we want to enforce.
start_states and end_states hold the z3 variables for the node states.

start_states is a collection of lists of z3 variables; each inner list holds the
nodes' states at one point before the end (presumably one list per earlier phase).

end_states is the list of z3 variables representing the final state of the nodes.
"""
def atomic_commit(start_states, end_states):
    """Build the z3 constraint expressing the atomic commit properties.

    For every node-state list ``start`` in ``start_states``, the returned
    formula requires:

    * if every node in ``start`` is ``LocalCommit``, then every node in
      ``end_states`` is ``Commit``;
    * if any node in ``start`` is ``Abort``, then every node in
      ``end_states`` is ``Abort``.

    The per-``start`` constraints are conjoined with ``z3.And``. State values
    are looked up through the module-level ``ver.get_type``.

    Args:
        start_states: iterable of lists of z3 variables, each list being one
            earlier snapshot of the nodes' states (presumably one per phase;
            confirm against verifier.assert_protocol).
        end_states: list of z3 variables for the nodes' final states.

    Returns:
        A z3 boolean expression.
    """
    per_snapshot = []
    for snapshot in start_states:
        all_local_commit = z3.And([node == ver.get_type("LocalCommit") for node in snapshot])
        all_end_commit = z3.And([node == ver.get_type("Commit") for node in end_states])
        any_abort = z3.Or([node == ver.get_type("Abort") for node in snapshot])
        all_end_abort = z3.And([node == ver.get_type("Abort") for node in end_states])

        # Unanimous local commit forces a global commit; any abort forces a global abort.
        commit_rule = z3.Implies(all_local_commit, all_end_commit)
        abort_rule = z3.Implies(any_abort, all_end_abort)
        per_snapshot.append(z3.And(commit_rule, abort_rule))

    return z3.And(per_snapshot)

# Register the example model's transition table with the verifier.
# (dummy_incorrect_model(num_nodes), defined above, is the deliberately wrong
# counterpart; swapping it in here should presumably make verification fail.)
print("Adding transitions")
ver.add_transitions(dummy_correct_model(num_nodes))

# Add the properties of the protocol we want to ensure are satisfied, and how
# many phases will take place (the second argument, 2). The third argument,
# ver.get_type("Lost"), is passed straight through to the verifier; presumably
# it is the state used for a lost/unresponsive node. Confirm against
# verifier.SMVerifier.assert_protocol.
print("Asserting protocol")
ver.assert_protocol(atomic_commit, 2, ver.get_type("Lost"))

# Verify the protocol. This can take a while depending on the complexity of the
# protocol and the number of nodes: the transition table registered above has
# len(SM.State) ** num_nodes entries.
print("Verifying")
ver.verify()
