import z3
import itertools
import argparse
import State, verifier


def read_from_transitions_file(path, mode=0, states=None):
    """Parse a transitions file into a list of ``[inputs, target]`` pairs.

    Each non-blank line has the form ``in1,in2,...:out``.  Every field is the
    formatted *value* of a member of the state enum and is translated to that
    member's *name*.  When ``mode`` is non-zero, the first input field is
    instead parsed with ``int()`` and kept as an integer (presumably a node
    index -- confirm against whatever produces the file).

    Args:
        path: Path of the transitions file.
        mode: 0 if every input field is a state value; any other value if the
            first input field is an integer.
        states: Enum class describing the states.  Defaults to ``State.State``.

    Returns:
        A list of ``[comb, target]`` pairs, where ``comb`` is the list of
        translated inputs and ``target`` is the name of the resulting state.

    Raises:
        KeyError: if a field is not the value of any state.
        IndexError: if a non-blank line has no ``:`` separator.
        ValueError: if ``mode`` is non-zero and the first field is not an int.
    """
    if states is None:
        states = State.State
    # The file stores enum values; map their string form to member names.
    state_mapping = {f"{s.value}": s.name for s in states}

    transitions = []
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # A blank line carries no transition; without this skip,
                # edge[1] below would raise IndexError.
                continue
            edge = line.split(":")
            inputs = edge[0].split(",")
            if mode == 0:
                comb = [state_mapping[x] for x in inputs]
            else:
                # First field is an integer rather than a state value.
                comb = [int(inputs[0])] + [state_mapping[x] for x in inputs[1:]]
            transitions.append([comb, state_mapping[edge[1]]])
    return transitions

"""
Properties of the atomic commit protocol we want to enforce.
start and end are lists of z3 variables of equal length.

start is the list of z3 variables representing the initial state of the nodes

end is the list of z3 variables representing the final state of the nodes
"""
"""
1. if more than one node Enters the CS it is incorrect
2. if a node that does not need the CS, Enters, it is incorrect
(Optional) 3. if at least one node needs the CS, but none of them Enter, it is incorrect (no progress)
"""
def distributed_locking(start_states, end_states, all_decisions, all_messages, ver):
    """Build the z3 formula that every correct distributed-locking run must satisfy.

    Args:
        start_states: z3 expressions for each node's initial state, indexed
            ``0 .. ver.num_nodes - 1``.
        end_states: z3 expressions for each node's final state, same indexing.
        all_decisions: not used by this property; accepted because the
            verifier callback presumably passes it (confirm against
            ``SMVerifier.assert_protocol``).
        all_messages: not used by this property; see ``all_decisions``.
        ver: the verifier, providing ``num_nodes`` and ``get_type(name)``.

    Returns:
        A z3 boolean expression that is the conjunction of:
          1. mutual exclusion -- at most one node ends in ``Enter``;
          2. no spurious entry -- a node that started in ``NoNeed`` does not
             end in ``Enter``;
          3. progress -- if any node started in ``Need``, some node ends in
             ``Enter``.
    """
    enter = ver.get_type("Enter")
    no_need = ver.get_type("NoNeed")
    need = ver.get_type("Need")
    nodes = range(ver.num_nodes)

    # 1. Mutual exclusion.  Implies(Enter_i, not Enter_j) is logically the same
    # as Implies(Enter_j, not Enter_i), so each unordered pair is asserted once.
    prop1 = z3.And([
        z3.Implies(end_states[i] == enter, end_states[j] != enter)
        for i, j in itertools.combinations(nodes, 2)
    ])

    # 2. A node that did not need the lock must not have entered.
    prop2 = z3.And([
        z3.Implies(start_states[i] == no_need, end_states[i] != enter)
        for i in nodes
    ])

    # 3. Progress: if anyone needed the lock, someone must have acquired it.
    lock_is_needed = z3.Or([start_states[i] == need for i in nodes])
    lock_is_acquired = z3.Or([end_states[i] == enter for i in nodes])
    prop3 = z3.Implies(lock_is_needed, lock_is_acquired)

    return z3.And(prop1, prop2, prop3)

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Verify the learned model")
    parser.add_argument("-t", default="transitions.txt", type=str, help="path to the transitions list")
    parser.add_argument("-n", default=3, type=int, help="number of nodes the model was trained for")
    parser.add_argument("-r", default=1, type=int, help="how many rounds does the model use")
    args = parser.parse_args()
    if args.n < 1:
        # A protocol run needs at least one participating node.
        parser.error("-n must be a positive integer")

    # Create a verifier instance that uses the State.State enum as the states
    # of the state machine, with args.n participating nodes.
    ver = verifier.SMVerifier(State.State, args.n, mode="node_dependent")
    # Add the transitions to the verifier. mode=1 means the first input field
    # of every line in the file is an integer rather than a state value.
    print("Adding transitions")
    ver.add_transitions(read_from_transitions_file(args.t, mode=1))
    # Add the properties the protocol must satisfy and the number of rounds.
    print("Asserting protocol")
    ver.assert_protocol(
            distributed_locking,  # function that asserts properties of the protocol
            args.r,  # number of rounds
            [ver.get_type("Need"), ver.get_type("NoNeed")],  # possible states for nodes to start in
            None,  # types that represent lost messages (none here)
            # presumably the states nodes may finish in -- confirm against
            # SMVerifier.assert_protocol
            [ver.get_type("Enter"), ver.get_type("NoEnter")]
    )

    # Verify the protocol; this can take a while depending on the complexity
    # of the protocol and the number of nodes.
    print("Verifying")
    ver.verify()
