import z3
import itertools
import argparse
import SM, verifier


def read_from_transitions_file(path, states=None):
    """Parse a transitions file into ``[input_state_names, output_state_name]`` pairs.

    Each non-blank line has the form ``v1,v2,...:out``, where every ``v`` and
    ``out`` is the ``value`` of a member of the state enum. Values are
    translated to member names, e.g. ``1,2:3`` becomes ``[["A", "B"], "C"]``
    when those values belong to members ``A``, ``B`` and ``C``.

    Args:
        path: path of the transitions file.
        states: Enum class the values refer to; defaults to ``SM.State``.

    Returns:
        A list of ``[list_of_input_state_names, output_state_name]`` pairs,
        in file order.

    Raises:
        KeyError: a value does not correspond to any member of ``states``.
        ValueError: a non-blank line contains no ``:`` separator.
    """
    if states is None:
        states = SM.State
    # Transition files refer to states by their enum value; map value -> name.
    state_mapping = {f"{s.value}": s.name for s in states}

    transitions = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                # Blank lines (e.g. an extra newline at EOF) would otherwise be
                # looked up as the state value "" and raise KeyError.
                continue
            edge = line.split(":")
            if len(edge) < 2:
                raise ValueError(
                    f"{path}:{lineno}: expected '<inputs>:<output>', got {line!r}"
                )
            inputs = [state_mapping[x.strip()] for x in edge[0].split(",")]
            transitions.append([inputs, state_mapping[edge[1].strip()]])

    return transitions

"""
Properties of the atomic commit protocol we want to enforce.
start and end are lists of z3 variables of equal length.

start is the list of z3 variables representing the initial state of the nodes

end is the list of z3 variables representing the final state of the nodes
"""
"""
1. if all initial states are commit, then all the final states must be commit
2. if at least one of the initial states are Abort, then all the final states should be Abort
3. if there is a lost message in one of the phases, then both abort and commit do not show up together in the end states
"""
def atomic_commit(start_states, end_states, all_decisions, all_messages, ver):
    """Return a z3 formula encoding the atomic-commit properties to verify.

    Args:
        start_states: one z3 variable per node holding its initial state (the
            caller in this file allows LocalCommit or LocalAbort).
        end_states: one z3 variable per node for its final state. No property
            below reads it; presumably it is kept because the verifier's
            ``assert_protocol`` callback supplies it — confirm before removing.
        all_decisions: sequence of decision phases; ``d[n]`` is node ``n``'s
            decision variable in phase ``d``.
        all_messages: sequence of message phases; each phase is a sequence of
            per-node sequences of message variables.
        ver: the verifier; supplies ``get_type``, ``get_lost_types`` and
            ``num_nodes``.

    Returns:
        The z3 conjunction of properties prop1..prop6.
    """
    # Loop-invariant lookups, hoisted out of the nested comprehensions below.
    commit = ver.get_type("Commit")
    abort = ver.get_type("Abort")
    local_commit = ver.get_type("LocalCommit")
    local_abort = ver.get_type("LocalAbort")
    # list() so the lost types can be iterated once per message variable.
    lost_types = list(ver.get_lost_types())
    nodes = range(ver.num_nodes)

    def decided(n, decision):
        # True iff node n holds `decision` in at least one decision phase.
        return z3.Or([d[n] == decision for d in all_decisions])

    def anywhere(value):
        # True iff any entry of any decision phase equals `value`.
        return z3.Or([entry == value for phase in all_decisions for entry in phase])

    # Does any message variable, of any node, in any message phase hold a lost type?
    lost_message = z3.Or([
        state == lost
        for phase in all_messages
        for node in phase
        for state in node
        for lost in lost_types
    ])

    # prop1: with no lost message, unanimous LocalCommit must make every node decide Commit.
    prop1 = z3.Implies(
        z3.Not(lost_message),
        z3.Implies(
            z3.And([s == local_commit for s in start_states]),
            z3.And([decided(n, commit) for n in nodes]),
        ),
    )

    # prop2: with no lost message, a single LocalAbort must make every node decide Abort.
    prop2 = z3.Implies(
        z3.Not(lost_message),
        z3.Implies(
            z3.Or([s == local_abort for s in start_states]),
            z3.And([decided(n, abort) for n in nodes]),
        ),
    )

    # prop3: with a lost message, Commit and Abort must not both be decided.
    prop3 = z3.Implies(
        lost_message,
        z3.Not(
            z3.And(
                z3.Or([decided(n, commit) for n in nodes]),
                z3.Or([decided(n, abort) for n in nodes]),
            )
        ),
    )

    # prop4: Commit and Abort never both appear among the decisions, lost message
    # or not. NOTE: when every decision phase has exactly num_nodes entries this
    # already implies prop3; prop3 is kept as the explicit lost-message property.
    prop4 = z3.Not(z3.And(anywhere(commit), anywhere(abort)))

    # prop5: a node that starts in LocalAbort must never decide Commit.
    prop5 = z3.And([
        z3.Implies(start == local_abort, z3.Not(decided(n, commit)))
        for n, start in enumerate(start_states)
    ])

    # prop6: each node reaches at most one decision (Commit or Abort) across all phases.
    prop6 = z3.And([
        z3.Sum([
            z3.If(z3.Or(d[n] == abort, d[n] == commit), 1, 0)
            for d in all_decisions
        ]) < 2
        for n in nodes
    ])

    return z3.And(prop1, prop2, prop3, prop4, prop5, prop6)

def extra_constraints(all_decisions, all_messages, ver):
    """Return additional z3 constraints restricting which runs are considered.

    Args:
        all_decisions: sequence of decision phases; ``phase[n]`` is node ``n``'s
            decision variable in that phase.
        all_messages: sequence of message phases indexed as
            ``all_messages[phase][m][n]``; only phases 0 and 1 are read here.
        ver: the verifier; supplies ``get_type`` and ``num_nodes``.

    Returns:
        The z3 conjunction of the "no recover" and "impossible state" constraints.
    """
    lost_r1 = ver.get_type("Lost_R1")
    lost_r2 = ver.get_type("Lost_R2")
    commit = ver.get_type("Commit")
    abort = ver.get_type("Abort")
    do_nothing_commit = ver.get_type("DoNothingCommit")
    nodes = range(ver.num_nodes)

    # No recover: if a node loses a message, all messages after it must be lost.
    # Concretely, if any phase-0 message [m][n] for node n is Lost_R1, then every
    # phase-1 message [m][n] with m != n must be Lost_R2.
    # NOTE(review): the antecedent includes m == n but the consequent skips it;
    # confirm that asymmetry is intended.
    no_recover = z3.And([
        z3.Implies(
            z3.Or([all_messages[0][m][n] == lost_r1 for m in nodes]),
            z3.And([all_messages[1][m][n] == lost_r2 for m in nodes if m != n]),
        )
        for n in nodes
    ])

    # Impossible states: decision states and DoNothingCommit cannot be mixed in
    # the first decision round unless some entry there is Lost_R1.
    # Fix: the previous slice all_decisions[1:1] is always empty, which made this
    # constraint vacuously True; [:1] selects the first round as intended.
    impossible = z3.And([
        z3.Implies(
            z3.And(
                z3.Or([z3.Or(phase[n] == commit, phase[n] == abort) for n in nodes]),
                z3.And([phase[n] != lost_r1 for n in nodes]),
            ),
            z3.And([phase[n] != do_nothing_commit for n in nodes]),
        )
        for phase in all_decisions[:1]  # just the first round
    ])

    return z3.And(no_recover, impossible)

if __name__ == "__main__":
    
    parser = argparse.ArgumentParser(description="Verify the learned model")
    parser.add_argument("-t", default="transitions.txt", type=str, help="path to the transitions list")
    parser.add_argument("-n", default=3, type=int, help="number of nodes the model was trained for")
    parser.add_argument("-r", default=2, type=int, help="how many rounds does the model use")
    args = parser.parse_args()

    # Create a verifier instance, using the SM.State Enum
    # as the states in the state machine, with 5 participating
    # nodes
    ver = verifier.SMVerifier(SM.State, args.n)
    # Add the transitions to the verifer
    print("Adding transitions")
    ver.add_transitions(read_from_transitions_file(args.t))
    # Add the properties of the protocol we want to ensure are satisfied, and how many phases will take place
    print("Asserting protocol")
    ver.assert_protocol(
            atomic_commit, # function that asserts properties of the protocol
            args.r, # number of rounds
            [ver.get_type("LocalCommit"), ver.get_type("LocalAbort")], # possible states for nodes to start in
            [ver.get_type("Lost_R1"), ver.get_type("Lost_R2")], # types that represent lost messages
            [ver.get_type("Commit"), ver.get_type("Abort")],
            protocol_constraints=extra_constraints
    )

    # Verify the protocol, this can take awhile depending on the complexity of the protocol and number of nodes
    print("Verifying")
    ver.verify()
