#ifndef VERIFIER_H
#define VERIFIER_H

#include <z3++.h>

#include <string>
#include <unordered_map>
#include <vector>

#include "properties.h"
#include "states.h"
#include "transition.h"

// Verifier for a state-machine protocol, built on top of the Z3 C++ API.
// It owns a z3 context and solver, tables mapping state names to indices,
// and per-round records of exchanged messages and decisions.
class SMVerifier {
   public:
    // Also reachable through get_context().
    // Keep this declared before the z3 members below: class members are initialized in
    // declaration order, and those z3 objects are presumably constructed from this
    // context -- confirm against the constructor's initializer list.
    z3::context context;

   private:
    int num_nodes;
    int num_rounds;
    // Flags indicating whether the node id / the round number are encoded in the input.
    bool encode_id = false;
    bool encode_round = false;
    // Flag indicating whether the history is encoded in the input.
    // There is no command line argument for it so far; it defaults to false.
    bool is_history_encoded = false;
    std::string protocol_name;
    std::vector<TransitionZ3Ptr> my_transitions;

    z3::solver solver;
    z3::sort state_sort;
    z3::expr_vector all_types;
    // Spelling of the identifier ("constrains") is kept as-is because it may be
    // referenced from the implementation file.
    z3::expr_vector constrains;
    z3::expr_vector start_types;
    z3::expr_vector lost_types;
    z3::expr_vector decision_types;
    // key: name of state; value: index of the state in *all_types*
    // e.g. "Commit": 1
    std::unordered_map<std::string, int> name_value_mapping;
    // key: index of the state, as a string; value: name of state
    // e.g. "1": "Commit"
    std::unordered_map<std::string, std::string> value_name_mapping;

    /*
        All messages that have been exchanged.
        (std::vector<z3::expr_vector>) all_messages[i]: all messages of round i
        (z3::expr_vector) all_messages[i][j]: messages of node j in round i, j = [0, n-1]
        (z3::expr) all_messages[i][j][k]: in round i, the message node j received from node k
    */
    std::vector<std::vector<z3::expr_vector>> all_messages;

    /*
        All decisions that have been made.
        (z3::expr_vector) all_decisions[i]: all decisions of round i
        (z3::expr) all_decisions[i][j]: decision at round i of node j, j = [0, n-1]
    */
    std::vector<z3::expr_vector> all_decisions;

    /*
        Name -> function tables used to support different protocols.
    */
    std::unordered_map<std::string, PropertyFn> property_fn_mapping;
    std::unordered_map<std::string, ProtocolConstraintFn> protocol_constraint_fn_mapping;

   public:
    // states:       names of the protocol states
    // protocol:     protocol name
    // num_node:     number of nodes
    // num_rounds:   number of rounds
    // encode_id:    whether the node id is encoded in the input
    // encode_round: whether the round number is encoded in the input
    //
    // NOTE(review): `protocol` is taken by non-const lvalue reference, so callers cannot
    // pass a temporary or a const string. The signature is left unchanged here because
    // the out-of-line definition must match it.
    SMVerifier(const std::vector<std::string>& states, std::string& protocol, int num_node, int num_rounds,
               bool encode_id, bool encode_round);
    ~SMVerifier() = default;

    // Reads transitions from the file at `path`.
    // NOTE(review): the previous comment here said "mode = 1: Encode node id at the
    // beginning of input", but this function takes no `mode` parameter; presumably that
    // behaviour is now controlled by the `encode_id` constructor flag -- confirm.
    std::vector<TransitionPtr> read_from_transitions_file(const std::string& path);

    void add_transitions(const std::vector<TransitionPtr>& transitions);

    void initialize_states(State state);

    void remove_transitions();

    // Simulate the whole protocol
    void run_protocol(State state);

    // Run solver
    // NOTE(review): the meaning of the returned bool is defined by the implementation,
    // which is not visible from this header -- confirm before relying on it.
    bool verify();

    // Output multiple counter examples
    // cnt: number of counter examples
    bool verify_with_multiple_cs(int cnt);

    // Add counter example constraints into solver
    void add_cs_constrains(const z3::model& m);

    // Mutable access to the z3 context owned by this verifier.
    z3::context& get_context() { return context; }

    // The read-only accessors below are const-qualified so they can be used through a
    // const SMVerifier reference; they return copies of the stored values.
    int get_num_nodes() const { return num_nodes; }

    z3::expr get_type(const std::string& name);

    z3::expr_vector get_lost_types() const { return lost_types; }

    int get_num_rounds() const { return num_rounds; }

    std::string get_protocol_name() const { return protocol_name; }

   private:
    void build_state_mapping(const std::vector<std::string>& states);

    void build_protocol_mapping();

    PropertyFn get_property_fn(const std::string& funcName);

    // Spelling of the identifier ("contraint") is kept as-is to match its definition.
    ProtocolConstraintFn get_protocol_contraint_fn(const std::string& funcName);

    z3::expr_vector start();

    // Simulate message exchange for one round
    std::vector<z3::expr_vector> message_exchange(int round, const z3::expr_vector& prev_states);

    // Take actions based on current state and transition lists
    z3::expr_vector take_action(const std::vector<z3::expr_vector>& curr_states,
                                const std::vector<z3::expr_vector>& prev_states, int round);

    z3::expr_vector create_states(std::string prefix_name);
};

#endif