import z3

num_nodes = 5  # number of participants the protocol is modelled with

# Every status a node (or a node's view of another node) can take on during
# the protocol, declared as a Z3 algebraic datatype.
State = z3.Datatype("State")
for _ctor in ("LocalCommit", "Abort", "Commit", "DoNothing", "Lost"):
    State.declare(_ctor)
del _ctor
State = State.create()

# A single solver collects the constraints of every protocol phase; it is
# checked once, after the correctness property has been added at the end.
solver = z3.Solver()

# broadcast represents the messaging being done.
# NOTE(review): message loss is not modelled yet -- every message is
# delivered unchanged, and State.Lost is declared but never used.
def broadcast(init_states, phase):
    """Model one all-to-all exchange of node states.

    For every receiving node ``i`` and sending node ``j`` a fresh constant
    named ``b{phase}x{i}y{j}`` is created and constrained to equal
    ``init_states[j]``.

    The received value deliberately has no domain constraint of its own: it
    inherits the sender's domain through the equality. Do NOT pin it to
    {LocalCommit, Abort}. In the second round the senders hold
    DoNothing/Abort, so such a pin forces every phase-1 decision to Abort.
    That makes the all-LocalCommit start unsatisfiable, and the final
    correctness check then passes vacuously.

    Args:
        init_states: indexable of at least ``num_nodes`` Z3 ``State`` terms,
            one per sending node.
        phase: round number; only used to keep constant names unique.

    Returns:
        A list ``states`` of ``num_nodes`` lists, where ``states[i][j]`` is
        node ``i``'s view of node ``j`` after the exchange.
    """
    states = []
    for i in range(num_nodes):
        local_knowledge = []
        for j in range(num_nodes):
            v = z3.Const("b{}x{}y{}".format(phase, i, j), State)
            # Perfect delivery: the receiver sees exactly what j holds.
            solver.add(init_states[j] == v)
            local_knowledge.append(v)
        states.append(local_knowledge)
    return states

# The DoNothing/Abort decision
def phase1(p1_states):
    """Phase 1: every node picks DoNothing or Abort from what it has heard.

    Args:
        p1_states: per-node lists of Z3 ``State`` terms, where
            ``p1_states[node]`` is what ``node`` learned in the preceding
            broadcast.

    Returns:
        One fresh ``State`` constant ``p1x{node}`` per node. Each is
        restricted to DoNothing/Abort, and is forced to DoNothing when every
        heard value is LocalCommit and to Abort when any heard value is Abort.
    """
    decisions = []
    for node in range(num_nodes):
        heard = p1_states[node]
        choice = z3.Const("p1x{}".format(node), State)

        everyone_ready = z3.And([h == State.LocalCommit for h in heard])
        someone_aborted = z3.Or([h == State.Abort for h in heard])

        solver.add(z3.Or(choice == State.DoNothing, choice == State.Abort))
        solver.add(z3.Implies(everyone_ready, choice == State.DoNothing))
        solver.add(z3.Implies(someone_aborted, choice == State.Abort))

        decisions.append(choice)
    return decisions

# The Commit/Abort decision
def phase2(p2_states):
    """Phase 2: every node makes its final Commit/Abort decision.

    Args:
        p2_states: per-node lists of Z3 ``State`` terms, where
            ``p2_states[node]`` is what ``node`` learned about the phase-1
            decisions in the preceding broadcast.

    Returns:
        One fresh ``State`` constant ``p2x{node}`` per node. Each is
        restricted to Commit/Abort, and is forced to Commit when every heard
        value is DoNothing and to Abort when any heard value is Abort.
    """
    outcomes = []
    for node in range(num_nodes):
        heard = p2_states[node]
        outcome = z3.Const("p2x{}".format(node), State)

        unanimous_go = z3.And([h == State.DoNothing for h in heard])
        any_abort = z3.Or([h == State.Abort for h in heard])

        solver.add(z3.Or(outcome == State.Commit, outcome == State.Abort))
        solver.add(z3.Implies(unanimous_go, outcome == State.Commit))
        solver.add(z3.Implies(any_abort, outcome == State.Abort))

        outcomes.append(outcome)
    return outcomes


def sol1_2pc():
    """Build the full two-phase-commit model on the shared solver.

    Each node starts in an unconstrained LocalCommit/Abort state. The model
    then runs broadcast -> phase1 -> broadcast -> phase2.

    Returns:
        ``(start, final_states)``: the initial ``p0x{i}`` constants and the
        final per-node decisions returned by ``phase2``.
    """
    start = []
    for node in range(num_nodes):
        initial = z3.Const("p0x{}".format(node), State)
        solver.add(z3.Or(initial == State.LocalCommit, initial == State.Abort))
        start.append(initial)

    first_round = broadcast(start, 1)
    votes = phase1(first_round)
    second_round = broadcast(votes, 2)
    return start, phase2(second_round)



start, final_states = sol1_2pc()

# The properties of a correct atomic commit protocol.
#
# Both of the following must hold:
#
# * at least one node locally aborting implies every node aborts, AND
# * every node locally committing implies every node commits.
#
# To prove the protocol correct we assert the NEGATION of these properties
# and show that no assignment satisfies it (unsat), i.e. a proof by
# contradiction.
all_local_commit = z3.And([u == State.LocalCommit for u in start])
any_local_abort = z3.Or([u == State.Abort for u in start])

# An unsat verdict is only meaningful if the protocol constraints still admit
# both scenarios the properties talk about; otherwise the implications hold
# trivially. Probe each scenario in a temporary solver scope (push/pop
# restores the assertions afterwards).
vacuous_scenarios = []
for description, scenario in (
    ("every node locally commits", all_local_commit),
    ("some node locally aborts", any_local_abort),
):
    solver.push()
    solver.add(scenario)
    if solver.check() != z3.sat:
        vacuous_scenarios.append(description)
    solver.pop()

solver.add(
    z3.Not(
        z3.And(
            z3.Implies(all_local_commit, z3.And([u == State.Commit for u in final_states])),
            z3.Implies(any_local_abort, z3.And([u == State.Abort for u in final_states])),
        )
    )
)

result = solver.check()
if result == z3.sat:
    print("system is sat, so there ARE inputs that violate correctness:")
    print(solver.model())
    print("protocol is INCORRECT")
elif result == z3.unsat and not vacuous_scenarios:
    print("system is unsat, so there are no inputs that violate correctness")
    print("protocol is CORRECT")
elif result == z3.unsat:
    # The proof holds only because the model forbids a scenario outright.
    print(
        "system is unsat, but possibly vacuously: no run could be found in which "
        + ", ".join(vacuous_scenarios)
    )
    print("protocol correctness is UNPROVEN")
else:
    # z3.unknown: the solver gave up, so nothing has been proven either way.
    print("solver returned unknown ({}), so correctness is UNPROVEN".format(solver.reason_unknown()))
