#include "verifier.h"

#include <fstream>
#include <iostream>

#include "properties.h"
#include "util.h"
#include "z3_util.h"

// Records the two-way mapping between state names and their position in
// `states`:
//   name_value_mapping[name]                  -> index
//   value_name_mapping[std::to_string(index)] -> name
// The decimal-string keys let numerically encoded states (as read from a
// transitions file) be translated back to names.
void SMVerifier::build_state_mapping(const std::vector<std::string>& states) {
    int index = 0;
    for (const std::string& state_name : states) {
        name_value_mapping[state_name] = index;
        value_name_mapping[std::to_string(index)] = state_name;
        ++index;
    }
}

// Registers, per protocol name, the property function to check and (where one
// exists) a function producing extra protocol-specific constraints.
// "distributed_locking" has no extra constraint function.
void SMVerifier::build_protocol_mapping() {
    // Properties; run_protocol() asserts their negation.
    property_fn_mapping["distributed_locking"] = distributed_locking;
    property_fn_mapping["atomic_commit"] = atomic_commit;
    property_fn_mapping["primary_backup"] = primary_backup;

    // Additional protocol constraints.
    protocol_constraint_fn_mapping["atomic_commit"] = atomic_commit_protocol_constraints;
    protocol_constraint_fn_mapping["primary_backup"] = primary_backup_protocol_constraints;
}

// Builds a verifier for one protocol.
//
// @param states       Names of every protocol state. They become the values of
//                     the Z3 enumeration sort "State", and a name <-> index
//                     mapping is recorded (build_state_mapping).
// @param protocol     Protocol name; used as the key for the property and
//                     protocol-constraint lookups registered in
//                     build_protocol_mapping().
// @param num_node     Number of nodes to model.
// @param num_rounds   Number of rounds to unroll in run_protocol().
// @param encode_id    Whether transitions carry a node id.
// @param encode_round Whether transitions carry a round number.
SMVerifier::SMVerifier(const std::vector<std::string>& states, std::string& protocol, int num_node, int num_rounds,
                       bool encode_id, bool encode_round)
    : protocol_name(protocol),
      num_nodes(num_node),
      num_rounds(num_rounds),
      encode_id(encode_id),
      encode_round(encode_round),
      // NOTE(review): members are initialized in declaration order, not in
      // the order written here. solver, state_sort and the expr_vectors below
      // all use `context`, so this relies on verifier.h declaring `context`
      // before them — confirm.
      context(),
      solver(z3::solver(context)),
      state_sort(z3::sort(context)),
      all_types(z3::expr_vector(context)),
      constrains(z3::expr_vector(context)),
      start_types(z3::expr_vector(context)),
      lost_types(z3::expr_vector(context)),
      decision_types(z3::expr_vector(context)) {
    // make_enum_sort (z3_util) presumably fills all_types with one constant per
    // state, in `states` order; get_type() indexes all_types by that position.
    state_sort = make_enum_sort(states, context, all_types, "State");

    build_state_mapping(states);

    build_protocol_mapping();
}

std::vector<TransitionPtr> SMVerifier::read_from_transitions_file(const std::string& path) {
    std::vector<TransitionPtr> transitions;
    std::ifstream file(path);
    std::string str;
    std::cout << encode_id << std::endl;
    while (std::getline(file, str)) {
        std::vector<std::string> edge = split_string(str, ':');
        std::vector<std::string> inputs = split_string(edge[0], ',');
        std::vector<std::string> comb;
        int id = -1;
        int round = -1;
        if (encode_id && encode_round) {
            round = std::stoi(inputs[0]);
            id = std::stoi(inputs[1]);
            for (int i = 2; i < inputs.size(); ++i) {
                comb.push_back(value_name_mapping[inputs[i]]);
            }
        } else if (encode_id) {
            id = std::stoi(inputs[0]);
            for (int i = 1; i < inputs.size(); ++i) {
                comb.push_back(value_name_mapping[inputs[i]]);
            }
        } else if (encode_round) {
            round = std::stoi(inputs[0]);
            for (int i = 1; i < inputs.size(); ++i) {
                comb.push_back(value_name_mapping[inputs[i]]);
            }
        } else {
            for (const auto& input : inputs) {
                comb.push_back(value_name_mapping[input]);
            }
        }
        std::string action = value_name_mapping[edge[1]];

        TransitionPtr transition = std::make_shared<Transition>(comb, action, id, round);
        transitions.push_back(transition);
    }
    return transitions;
}

// Converts parsed transitions into their Z3 form and appends them to
// my_transitions.
//
// Each input state name and the action name are resolved with get_type().
// The id and round fields are copied only when encode_id / encode_round is
// enabled.
//
// Iterates by const reference. The original copied each shared_ptr (an
// atomic refcount round-trip) and each state-name string on every iteration.
void SMVerifier::add_transitions(const std::vector<TransitionPtr>& transitions) {
    for (const auto& tran : transitions) {
        std::shared_ptr<TransitionZ3> tran_z3 = std::make_shared<TransitionZ3>(context);
        for (const auto& s : tran->combs) {
            tran_z3->combs.push_back(get_type(s));
        }
        tran_z3->action = get_type(tran->edge);
        if (encode_id) {
            tran_z3->id = tran->id;
        }
        if (encode_round) {
            tran_z3->round = tran->round;
        }
        my_transitions.push_back(tran_z3);
    }
}

bool SMVerifier::verify() {
    auto result = solver.check();
    if (result == z3::sat) {
        std::cout << "system is sat, so there ARE inputs that violate correctness:" << std::endl;
        std::cout << "protocol is INCORRECT" << std::endl;
        z3::model model = solver.get_model();
        print_model(model, num_nodes, num_rounds, context);
        std::cout << "Solver statistics:" << std::endl;
        std::cout << solver.statistics() << std::endl;
        return false;
    } else if (result == z3::unsat) {
        std::cout << "system is unsat, so there are no inputs that violate correctness" << std::endl;
        std::cout << "protocol is CORRECT" << std::endl;
        return true;
    } else {
        std::cout << "unknown solution: couldn't identify whether there was a solution" << std::endl;
        return false;
    }
}

// Enumerates up to `cnt` counterexamples to the protocol property.
//
// Each satisfying model is printed and then blocked with add_cs_constrains(),
// so the next check must produce a different assignment.
//
// @param cnt  Maximum number of counterexamples to enumerate.
// @return true only if the very first check is unsat, meaning no violating
//         input exists. Returns false in these cases:
//           - at least one counterexample was found, even if blocking them
//             later made the solver unsat;
//           - the solver answered unknown;
//           - cnt <= 0, so no check was performed.
bool SMVerifier::verify_with_multiple_cs(int cnt) {
    int cs_cnt = 0;
    // Initialized so that cnt <= 0 (loop never runs) does not read an
    // indeterminate value.
    z3::check_result result = z3::unknown;
    while (cs_cnt < cnt && (result = solver.check()) == z3::sat) {
        z3::model model = solver.get_model();
        print_model(model, num_nodes, num_rounds, context);
        add_cs_constrains(model);
        cs_cnt++;
    }
    if (result == z3::unsat && cs_cnt == 0) {
        std::cout << "system is unsat, so there are no inputs that violate correctness" << std::endl;
        std::cout << "protocol is CORRECT" << std::endl;
        return true;
    }
    if (result == z3::unsat) {
        // Unsat only because earlier counterexamples were blocked: the
        // enumeration is exhausted, but violating inputs DO exist.
        std::cout << "no further counterexamples after " << cs_cnt << " found" << std::endl;
        std::cout << "protocol is INCORRECT" << std::endl;
    }
    return false;
}

// Adds a blocking clause that excludes model `m`: at least one model constant
// must take a value different from its interpretation in `m`.
//
// Only 0-ary declarations (constants) are blocked. Function interpretations
// have no single value to compare against, and calling `v()` /
// get_const_interp on them fails. If `m` has no constants, nothing is added.
// An empty disjunction would make the solver unsat, and the caller would then
// misreport that as "no counterexample".
void SMVerifier::add_cs_constrains(const z3::model& m) {
    z3::expr_vector new_constrains(context);
    for (unsigned i = 0; i < m.num_consts(); ++i) {
        z3::func_decl v = m.get_const_decl(i);
        new_constrains.push_back(v() != m.get_const_interp(v));
    }
    if (new_constrains.size() == 0) {
        return;
    }
    solver.add(z3::mk_or(new_constrains));
}

// Builds the bounded-round encoding of the protocol into `solver`.
//
// Steps:
//  1. Register the start/lost/decision state types from `state`.
//  2. Create the per-node start states.
//  3. For each of num_rounds rounds, create the message-exchange states and
//     the per-node decisions, recording them in all_messages and
//     all_decisions.
//  4. Assert all accumulated structural constraints as one conjunction.
//  5. Assert the NEGATION of the property registered for protocol_name, if
//     any, so that sat means a violation.
//  6. Assert the protocol-specific constraints, if any.
void SMVerifier::run_protocol(State state) {
    std::cout << "Initialize states" << std::endl;
    initialize_states(state);

    std::cout << "Create start states" << std::endl;
    z3::expr_vector start_states = start();

    std::cout << "Simulate the whole protocol" << std::endl;
    z3::expr_vector current = start_states;
    std::vector<z3::expr_vector> last_messages;  // empty before the first round
    for (int r = 0; r < num_rounds; ++r) {
        std::vector<z3::expr_vector> messages = message_exchange(r, current);
        all_messages.push_back(messages);
        z3::expr_vector decisions = take_action(messages, last_messages, r);
        all_decisions.push_back(decisions);
        current = decisions;
        last_messages = messages;
    }
    // Final-round decisions, or the start states when num_rounds == 0.
    z3::expr_vector end_states = current;

    solver.add(z3::mk_and(constrains));

    std::cout << "Add protocol constraints" << std::endl;
    if (PropertyFn property = get_property_fn(protocol_name)) {
        solver.add(!property(start_states, end_states, all_decisions, all_messages, *this));
    }
    if (ProtocolConstraintFn extra = get_protocol_contraint_fn(protocol_name)) {
        solver.add(extra(all_decisions, all_messages, *this));
    }
}

// Resolves the state names listed in `state` to their Z3 enum constants (via
// get_type) and appends them, in order, to the matching member vectors:
//   state.initial  -> start_types
//   state.lost     -> lost_types
//   state.decision -> decision_types
void SMVerifier::initialize_states(State state) {
    for (const auto& name : state.initial) {
        start_types.push_back(get_type(name));
    }
    for (const auto& name : state.lost) {
        lost_types.push_back(get_type(name));
    }
    for (const auto& name : state.decision) {
        decision_types.push_back(get_type(name));
    }
}

// Creates one start-state constant per node ("s_n0", "s_n1", ...). It then
// appends to `constrains` the requirement that each of them equals some entry
// of start_types.
//
// @return The start-state constants, indexed by node.
//
// NOTE(review): if start_types is empty, every per-node disjunction is empty.
// That presumably makes the whole encoding unsatisfiable, so verification
// would pass vacuously. Confirm that initial states are always configured.
z3::expr_vector SMVerifier::start() {
    z3::expr_vector starts = create_states("s_n");

    z3::expr_vector per_node(context);
    for (unsigned node = 0; node < starts.size(); ++node) {
        z3::expr_vector allowed(context);
        for (unsigned t = 0; t < start_types.size(); ++t) {
            allowed.push_back(starts[node] == start_types[t]);
        }
        per_node.push_back(z3::mk_or(allowed));
    }
    constrains.push_back(z3::mk_and(per_node));
    return starts;
}

// Creates the message states for one round and constrains them.
//
// out_states[dst] is a vector of constants named "m<round>n<dst>s<src>".
// out_states[dst][src] is constrained to equal prev_actions[src], which reads
// as the copy of node src's action that node dst observes. When lost_types is
// non-empty and src != dst, that state may instead equal a lost marker:
// lost_types[round] if several lost types are configured, otherwise
// lost_types[0].
//
// NOTE(review): with more than one lost type, lost_types[round] assumes that
// at least num_rounds lost types exist. Confirm this with the callers.
//
// @return out_states, indexed by observing node.
std::vector<z3::expr_vector> SMVerifier::message_exchange(int round, const z3::expr_vector& prev_actions) {
    std::vector<z3::expr_vector> out_states;
    for (int dst = 0; dst < num_nodes; ++dst) {
        std::stringstream prefix;
        prefix << "m" << round << "n" << dst << "s";
        out_states.push_back(create_states(prefix.str()));
    }

    const bool has_lost_types = lost_types.size() > 0;
    for (int src = 0; src < num_nodes; ++src) {
        for (int dst = 0; dst < num_nodes; ++dst) {
            z3::expr observed = out_states[dst][src];
            z3::expr_vector options(context);
            options.push_back(observed == prev_actions[src]);
            if (src != dst && has_lost_types) {
                z3::expr lost = lost_types.size() > 1 ? lost_types[round] : lost_types[0];
                options.push_back(observed == lost);
            }
            constrains.push_back(z3::mk_or(options));
        }
    }

    return out_states;
}

// Creates the per-node decision constants for `round` (named "d<round>n<i>",
// sort State) and links them to the transition table.
//
// For every node i and every transition in my_transitions, adds to
// `constrains` the implication
//     (transition matches node i's inputs)  =>  decision[i] == tran->action
// A transition matches when all of the following hold:
//   - its id equals i (only checked if encode_id);
//   - its round equals `round` (only checked if encode_round);
//   - its input states equal curr_states[i] element-wise.
//
// With is_history_encoded, tran->combs interleaves current and previous
// states:
//   - combs[2k]   is compared with curr_states[i][k];
//   - combs[2k+1] is compared with prev_states[i][k], or with the "Dummy"
//     state when prev_states is empty (first round).
//
// NOTE(review): this function only adds implications, so when no transition
// matches, decision[i] is left unconstrained here. tran->combs is also indexed
// without a size check. This assumes every transition lists one state per
// node (two with history encoding). Confirm.
//
// @return The decision constants, indexed by node.
z3::expr_vector SMVerifier::take_action(const std::vector<z3::expr_vector>& curr_states,
                                        const std::vector<z3::expr_vector>& prev_states, int round) {
    z3::expr_vector decision(context);
    for (int i = 0; i < num_nodes; ++i) {
        std::stringstream name;
        // decision at round `round` of node i
        name << "d" << round << "n" << i;
        decision.push_back(context.constant(name.str().c_str(), state_sort));
    }

    for (int i = 0; i < num_nodes; ++i) {
        for (int j = 0; j < my_transitions.size(); ++j) {
            TransitionZ3Ptr tran = my_transitions[j];

            // Conjuncts describing "this transition applies to node i now".
            z3::expr_vector match_states(context);
            if (encode_id) match_states.push_back(tran->id == context.int_val(i));
            if (encode_round) match_states.push_back(tran->round == context.int_val(round));
            for (int k = 0; k < curr_states[i].size(); ++k) {
                if (is_history_encoded) {
                    // Even slots: current input; odd slots: previous input.
                    match_states.push_back(tran->combs[k * 2] == curr_states[i][k]);
                    if (prev_states.empty()) {
                        // First round: there is no previous input yet.
                        match_states.push_back(tran->combs[k * 2 + 1] == get_type("Dummy"));
                    } else {
                        match_states.push_back(tran->combs[k * 2 + 1] == prev_states[i][k]);
                    }
                } else {
                    match_states.push_back(tran->combs[k] == curr_states[i][k]);
                }
            }
            constrains.push_back(z3::implies(z3::mk_and(match_states), decision[i] == tran->action));
        }
    }

    return decision;
}

// Creates num_nodes fresh constants of sort State, named "<name_prefix>0",
// "<name_prefix>1", ..., one per node, and returns them in node order.
z3::expr_vector SMVerifier::create_states(std::string name_prefix) {
    z3::expr_vector result(context);
    for (int node = 0; node < num_nodes; ++node) {
        std::stringstream label;
        label << name_prefix << node;
        const std::string constant_name = label.str();
        result.push_back(context.constant(constant_name.c_str(), state_sort));
    }
    return result;
}

// Returns the Z3 enum constant for the state called `name`, looked up through
// name_value_mapping into all_types. An unknown name is reported on std::cerr
// and yields an empty (null) z3::expr bound to `context`.
z3::expr SMVerifier::get_type(const std::string& name) {
    auto entry = name_value_mapping.find(name);
    if (entry == name_value_mapping.end()) {
        std::cerr << "Unknown name: " << name << std::endl;
        return z3::expr(context);
    }
    return all_types[entry->second];
}

// Looks up the property function registered under `funcName`.
// Returns nullptr (an empty PropertyFn) when none is registered.
PropertyFn SMVerifier::get_property_fn(const std::string& funcName) {
    auto found = property_fn_mapping.find(funcName);
    if (found == property_fn_mapping.end()) {
        return nullptr;
    }
    return found->second;
}

// Looks up the protocol-constraint function registered under `funcName`.
// Returns nullptr (an empty ProtocolConstraintFn) when none is registered.
ProtocolConstraintFn SMVerifier::get_protocol_contraint_fn(const std::string& funcName) {
    auto found = protocol_constraint_fn_mapping.find(funcName);
    if (found == protocol_constraint_fn_mapping.end()) {
        return nullptr;
    }
    return found->second;
}
