import z3
import itertools

num_nodes = 5


def _declare_state_sort(constructor_names):
    """Build the enumerated z3 sort ``State`` from nullary constructor names.

    Constructors are declared in the order given. After ``create()`` each one
    is reachable as an attribute of the returned sort (e.g. ``State.Commit``).
    """
    builder = z3.Datatype("State")
    for ctor in constructor_names:
        builder.declare(ctor)
    return builder.create()


State = _declare_state_sort(
    (
        "Commit",
        "Abort",
        "DoNothingCommit",
        "DoNothingAbort",
        "Lost",
        "LocalAbort",
        "LocalCommit",
        "Dummy",
    )
)

solver = z3.Solver()

# Every value a single node may hold, in declaration order.
states = [
    State.Commit,
    State.Abort,
    State.DoNothingCommit,
    State.DoNothingAbort,
    State.Lost,
    State.LocalAbort,
    State.LocalCommit,
    State.Dummy,
]

# One copy of the state list per node, used to enumerate all joint inputs.
node_states = [states] * num_nodes

# Transition table: one [input_combination, target_state] entry for EVERY
# possible combination of per-node states. A phase moves all nodes to Commit
# only when each node is Commit or LocalCommit; any other combination aborts.
commit_like = (State.Commit, State.LocalCommit)
transitions = [
    [
        list(combo),
        State.Commit
        if all(node in commit_like for node in combo)
        else State.Abort,
    ]
    for combo in itertools.product(*node_states)
]

# Symbolic initial state of each node (z3 constants named i0, i1, ...).
start = [z3.Const(f"i{idx}", State) for idx in range(num_nodes)]


def phase(in_states, n):
    """Encode one protocol phase and return the symbolic post-phase states.

    Creates ``num_nodes`` fresh z3 constants named ``p{n}s{i}``. For every
    entry of the module-level ``transitions`` table, it adds this constraint
    to the module-level ``solver``: "if ``in_states`` matches the entry's
    input combination, then every output state equals the entry's target".
    Because ``transitions`` enumerates every input combination, the outputs
    are fully determined by the inputs.

    Args:
        in_states: per-node z3 ``State`` expressions feeding this phase.
        n: phase index, used only to give the new constants unique names.

    Returns:
        List of the ``num_nodes`` new output constants.
    """
    next_states = [z3.Const(f"p{n}s{idx}", State) for idx in range(num_nodes)]
    for pre, post in transitions:
        guard = z3.And([want == have for (want, have) in zip(pre, in_states)])
        effect = z3.And([node == post for node in next_states])
        solver.add(z3.Implies(guard, effect))
    return next_states
    

# Run two protocol phases back to back: the states produced by one phase
# are the input states of the next, starting from the symbolic `start`.
curr_states = start
for phase_num in (0, 1):
    curr_states = phase(curr_states, phase_num)
final_states = curr_states

"""
The properties of a correct atomic commit protocol

The follow must be true:

* At least one node locally aborting implies every node aborts
    AND
* Every node locally committing implies every node commits

To prove our solution correct, we want to show that it is impossible
to find a solution that violates these properties, ie a proof by contradiction
"""
# Build both correctness properties over the initial and final states, then
# assert their NEGATION. A satisfying model is therefore a counterexample.
# If the result is unsat, no initial assignment violates either property.
all_local_commit = z3.And([u == State.LocalCommit for u in start])
# The stated property is about nodes that *locally* abort. This mirrors the
# LocalCommit premise above; testing State.Abort here would never exercise
# LocalAbort inputs.
any_local_abort = z3.Or([u == State.LocalAbort for u in start])
all_commit = z3.And([u == State.Commit for u in final_states])
all_abort = z3.And([u == State.Abort for u in final_states])

solver.add(
    z3.Not(
        z3.And(
            z3.Implies(all_local_commit, all_commit),
            z3.Implies(any_local_abort, all_abort),
        )
    )
)


print("Checking model! This might take awhile...")

if solver.check() == z3.sat:
    print("system is sat, so there ARE inputs that violate correctness:")
    print(solver.model())
    print("protocol is INCORRECT")
    print("Solver statistics:")
    print(solver.statistics())
else:
    print("system is unsat, so there are no inputs that violate correctness")
    print("protocol is CORRECT")
    print("Solver statistics:")
    print(solver.statistics())
