import copy
import z3
import SM

# Names of every member of the project's SM.State enum, in enum definition order.
enum_states = [member.name for member in SM.State]
# z3 enumeration sort mirroring SM.State, plus one z3 constant per name (same order).
State, verifier_states = z3.EnumSort("State", enum_states)
class SMVerifier:
    """
    Bounded checker for a phase-based state-machine protocol, built on z3.

    Usage as exercised by the methods below: construct the verifier, call
    add_transitions(), then assert_protocol(), then verify(). verify() returns
    True when z3 reports that no assignment violates the supplied properties.
    """

    # Possible transitions through the state machine: a list of
    # (list of z3 input states, z3 output state) pairs built by add_transitions().
    transitions = []
    transitions_file = None

    # z3 enumeration sort for node states, and a mapping from Enum member names
    # to the corresponding z3 constants.
    state_type = None
    type_mapping = {}

    # Number of nodes participating in the protocol
    num_nodes = 0

    # z3 states used for lost messages, indexed by message-exchange number.
    lost_types = None

    # z3 states a node may start in. Declared at class level so that
    # get_start_types() returns None (rather than raising AttributeError)
    # before assert_protocol() has been called.
    start_types = None
    # Stored by assert_protocol(); not otherwise read in this class.
    early_decision_types = None

    # Number of phases; set by assert_protocol() and read by dump_variables().
    num_phases = 0

    # Per-instance solver, created in __init__ (not at import time).
    solver = None
    constraints = []
    pc = None

    # 0 = node independent (default), 1 = node dependent
    mode = 0

    def __init__(self, states, num_nodes, mode="node_independent", threads=1):
        """
        Creates a verifier instance.

        states: intended to be the python Enum of node states.
            NOTE(review): this argument is not read; the z3 sort is always the
            module-level ``State`` built from ``SM.State``.

        num_nodes is the number of nodes participating in the protocol. This
        parameter will have the greatest effect on runtime.

        mode: "node_independent" (default) matches each transition pattern
            against a node's received messages only. Any other value selects
            node-dependent mode, where the node's own entry is prepended to the
            received messages before matching.

        threads is not yet implemented!
        """
        self.constraints = []
        self.transitions = []
        self.num_nodes = num_nodes

        self.state_type = State
        self.type_mapping = dict(zip(enum_states, verifier_states))

        self.lost_types = None
        self.start_types = None
        self.early_decision_types = None
        self.num_phases = 0
        self.solver = z3.Solver()
        self.pc = None

        self.mode = 0
        if mode != "node_independent":
            self.mode = 1
            print("using node dependent mode")

    def __repr__(self):
        return f"num_nodes: {self.num_nodes}\nstate_type: {self.state_type}\ntransitions: {self.transitions}"

    def dump_variables(self, variables):
        """
        Print a counterexample phase by phase.

        variables maps z3 constant names ("s_n{i}", "d{p}n{i}", "m{p}n{recv}s{send}",
        "f_n{i}") to their values, as built in verify(). A z3 model does not
        necessarily assign every constant, so missing names are printed as None
        instead of raising KeyError.
        """
        nodes = range(self.num_nodes)
        starting = [variables.get(f"s_n{i}") for i in nodes]
        ending = [variables.get(f"f_n{i}") for i in nodes]
        decisions = [[variables.get(f"d{p}n{n}") for n in nodes] for p in range(self.num_phases - 1)]
        messages = [[[variables.get(f"m{p}n{n}s{s}") for s in nodes] for n in nodes] for p in range(self.num_phases)]

        print("")
        print("Starting states:")
        print(starting)
        print("")

        for i in range(self.num_phases):
            print("===============================")
            print(f"        phase {i+1}")
            print("===============================")
            print("message exchange:")
            print(messages[i])
            if i < len(decisions):
                print("decision:")
                print(decisions[i])
            else:
                print("Final decision:")
                print(ending)
            print("")

    def add_transitions(self, transitions):
        """
        Adds the transitions list into the verifier.

        transitions is of the form:

        [
            [["LocalCommit", ..., "LocalCommit"], "Commit"],
            [["LocalCommit", ..., "LocalAbort"], "Abort"],
            ...
        ]

        Where the state names are the names of the enum values originally provided
        to the verifier. Each entry is converted to a (list of z3 states, z3 state)
        pair; an unknown name raises KeyError.
        """
        t = copy.deepcopy(transitions)
        self.transitions = [
            ([self.type_mapping[name] for name in entry[0]], self.type_mapping[entry[1]])
            for entry in t
        ]

    def get_type(self, name):
        """
        Get the internal z3 type representing the state with name

        name is the string representation of the Enum value you created the verifier with.
        """
        return self.type_mapping[name]

    def get_lost_types(self):
        """
        Get the internal types that correspond to lost messages (None until
        assert_protocol() is called).
        """
        return self.lost_types

    def get_start_types(self):
        """
        Get the internal types that correspond to the possible starting states
        (None until assert_protocol() is called).
        """
        return self.start_types

    def protocol_constraints(self, pc):
        """
        Pass in any extra constraints to assert behavior of the protocol.

        NOTE(review): the value is only stored on ``self.pc``; assert_protocol()
        uses its own ``protocol_constraints`` argument and does not read it.
        """
        self.pc = pc

    def _decision_constraints(self, in_states, out_states):
        """
        Constrain each node's next state using the transition table.

        in_states[i] is the list of messages node i received (one per sender);
        out_states[i] is the z3 constant for node i's resulting state. For every
        transition, "input pattern matches" implies "output equals that state".
        """
        for i, (in_state, out_state) in enumerate(zip(in_states, out_states)):
            # Node-dependent mode prepends the node's own entry so a transition
            # pattern can also depend on the deciding node's own state.
            key = in_state if self.mode == 0 else [in_state[i]] + in_state
            for t in self.transitions:
                self.constraints.append(
                    z3.Implies(
                        z3.And([u == v for (u, v) in zip(t[0], key)]),
                        z3.And([out_state == t[1]])
                    )
                )

    def _message_constraints(self, exchange, sender_states):
        """
        Create and constrain the message variables for one exchange.

        Returns a num_nodes x num_nodes matrix where entry [j][i] is the z3
        constant "m{exchange}n{j}s{i}": the message receiver j got from sender i.
        A node always receives its own state; a message from a peer is either the
        sender's state or the lost type for this exchange.
        """
        out_states = [
            [z3.Const(f"m{exchange}n{j}s{i}", self.state_type) for i in range(self.num_nodes)]
            for j in range(self.num_nodes)
        ]
        for n_i in range(self.num_nodes):
            for n_j in range(self.num_nodes):
                if n_i != n_j:
                    self.constraints.append(
                        z3.Or(
                            out_states[n_j][n_i] == sender_states[n_i],
                            out_states[n_j][n_i] == self.get_lost_types()[exchange]
                        )
                    )
                else:
                    self.constraints.append(
                        out_states[n_j][n_i] == sender_states[n_i],
                    )
        return out_states

    def phase(self, in_states, n):
        """
        A phase of transitions. For internal use.

        Adds constraints for transitioning the state machine: every node decides
        from the messages it received, then the decisions are exchanged (peer
        messages may be lost, using lost_types[n+1]).

        n identifies which phase this is (decision constants are "d{n}n{i}",
        the following messages are "m{n+1}n{j}s{i}").

        returns:
            (messages after the exchange, the nodes' decision states)
        """
        decision_states = [z3.Const(f"d{n}n{i}", self.state_type) for i in range(self.num_nodes)]
        self._decision_constraints(in_states, decision_states)
        out_states = self._message_constraints(n + 1, decision_states)
        return (out_states, decision_states)

    def assert_protocol(self, properties, phases, start_types, lost_types, early_decision_types, protocol_constraints=None):
        """
        Build the protocol model and assert the negation of its correctness property.

        properties: callable invoked as
            properties(start, end_states, all_decisions, all_messages, verifier)
            and returning a z3 condition that must hold for the protocol to be
            correct. Its negation is added to the solver.

        phases: how many phases to use, 2pc would be 2

        start_types: iterable of z3 states a node may start in.

        lost_types: sequence indexed by message exchange (0 .. phases-1); entry k
            is the z3 state a lost message takes in exchange k.

        early_decision_types: stored on the instance; not otherwise used here.

        protocol_constraints: optional callable invoked as
            protocol_constraints(all_decisions, all_messages, verifier) whose z3
            result is added to the solver.

        returns:
            None. Constraints are added to self.solver; calling this more than
            once adds them again.
        """
        self.lost_types = lost_types
        self.start_types = start_types
        self.early_decision_types = early_decision_types
        self.num_phases = phases

        start = [z3.Const(f"s_n{i}", self.state_type) for i in range(self.num_nodes)]
        self.constraints.append(
            z3.And([z3.Or([s == t for t in self.get_start_types()]) for s in start])
        )

        # Exchange 0: every node broadcasts its starting state.
        curr_messages = self._message_constraints(0, start)
        curr_decisions = start
        all_messages = [curr_messages]
        all_decisions = []

        for pn in range(phases - 1):
            all_decisions.append(curr_decisions)
            curr_messages, curr_decisions = self.phase(curr_messages, pn)
            all_messages.append(curr_messages)

        all_decisions.append(curr_decisions)
        # The final output state of each node based on the states in the final phase
        end_states = [z3.Const(f"f_n{i}", self.state_type) for i in range(self.num_nodes)]
        self._decision_constraints(curr_messages, end_states)
        all_decisions.append(end_states)

        self.solver.add(z3.And(self.constraints))
        if protocol_constraints is not None:
            self.solver.add(protocol_constraints(all_decisions, all_messages, self))
        self.solver.add(z3.Not(properties(start, end_states, all_decisions, all_messages, self)))

    def _print_statistics(self):
        """Print the solver statistics under a fixed header."""
        print("Solver statistics:")
        print(self.solver.statistics())

    def verify(self):
        """
        Check the model created using z3.

        This should be the last function that is called. It depends on the
        constraints added by the other methods of this class. On a
        counterexample, the model is printed via dump_variables().

        returns:
            True if the model is correct,
            False if the model is incorrect or z3 returned unknown
        """
        print("Checking model! This might take awhile...")
        result = self.solver.check()
        if result == z3.sat:
            print("system is sat, so there ARE inputs that violate correctness:")
            m = self.solver.model()
            ks = {d.name(): m[d] for d in m.decls()}
            self.dump_variables(ks)
            print("protocol is INCORRECT")
            self._print_statistics()
            return False
        if result == z3.unsat:
            print("system is unsat, so there are no inputs that violate correctness")
            print("protocol is CORRECT")
            self._print_statistics()
            return True
        print("unknown solution: couldn't identify whether there was a solution")
        self._print_statistics()
        return False

    def verify_nolog(self):
        """
        Like verify(), but prints only when z3 returns unknown.

        returns:
            True on unsat (correct), False on sat or unknown.
        """
        result = self.solver.check()
        if result == z3.sat:
            return False
        if result == z3.unsat:
            return True
        print("unknown solution: couldn't identify whether there was a solution")
        self._print_statistics()
        return False
