#include "properties.h"

#include <iostream>

#include "verifier.h"

z3::expr_vector node_crash_per_round(std::vector<z3::expr_vector> messages, SMVerifier& ver, std::string lost_str);

// Constraints for no recovery
// Builds the "no recovery" constraint: for every pair of consecutive rounds, a node that is seen as
// crashed in the earlier round must have all of its messages to the other nodes lost in the later round.
//
// all_decisions is not read by this function.
// all_messages is indexed as [round][j][i]; following the comment in node_crash, j is taken to be the
// receiving node and i the sending node.
// Returns the conjunction of one constraint per round transition.
z3::expr no_recovery_constraints(std::vector<z3::expr_vector> all_decisions,
                                 std::vector<std::vector<z3::expr_vector>> all_messages, SMVerifier& ver) {
    z3::context& ctx = ver.get_context();
    const int node_count = ver.get_num_nodes();
    z3::expr_vector per_transition(ctx);

    // TODO: Consider moving this part out of this function
    // Name of the "lost message" type for each round: atomic_commit uses one name per round,
    // every other protocol uses "Lost" for all rounds.
    const std::vector<std::string> lost_names = ver.get_protocol_name() == "atomic_commit"
                                                    ? std::vector<std::string>{"Lost_R1", "Lost_R2"}
                                                    : std::vector<std::string>(ver.get_num_rounds(), "Lost");

    // `later` is the second round of each consecutive pair (earlier, later).
    for (int later = 1; later < ver.get_num_rounds(); ++later) {
        const int earlier = later - 1;

        // crashed_earlier[s]: some node (including s itself) got a lost message from s in `earlier`.
        z3::expr_vector crashed_earlier(ctx);
        for (int sender = 0; sender < node_count; ++sender) {
            z3::expr_vector lost_anywhere(ctx);
            for (int receiver = 0; receiver < node_count; ++receiver) {
                lost_anywhere.push_back(all_messages[earlier][receiver][sender] ==
                                        ver.get_type(lost_names[earlier]));
            }
            crashed_earlier.push_back(z3::mk_or(lost_anywhere));
        }

        // silent_later[s]: every *other* node gets a lost message from s in `later`.
        z3::expr_vector silent_later(ctx);
        for (int sender = 0; sender < node_count; ++sender) {
            z3::expr_vector lost_everywhere(ctx);
            for (int receiver = 0; receiver < node_count; ++receiver) {
                if (receiver == sender) {
                    continue;
                }
                lost_everywhere.push_back(all_messages[later][receiver][sender] == ver.get_type(lost_names[later]));
            }
            silent_later.push_back(z3::mk_and(lost_everywhere));
        }

        // A crash in `earlier` forces silence in `later`, for every node.
        z3::expr_vector stays_down(ctx);
        for (int node = 0; node < node_count; ++node) {
            stays_down.push_back(z3::implies(crashed_earlier[node], silent_later[node]));
        }
        per_transition.push_back(z3::mk_and(stays_down));
    }

    return z3::mk_and(per_transition);
}

// Requires that, in every round, at least one node is not crashed.
//
// A node's crash flag for a round is taken from node_crash_per_round, using that round's
// "lost message" type name. all_decisions is not read by this function.
// Returns the conjunction of one "someone is alive" expression per round.
z3::expr minimum_alive_constraints(std::vector<z3::expr_vector> all_decisions,
                                   std::vector<std::vector<z3::expr_vector>> all_messages, SMVerifier& ver) {
    z3::context& ctx = ver.get_context();
    const int node_count = ver.get_num_nodes();
    z3::expr_vector per_round(ctx);

    // TODO: Consider moving this part out of this function
    // Name of the "lost message" type for each round: atomic_commit uses one name per round,
    // every other protocol uses "Lost" for all rounds.
    const std::vector<std::string> lost_names = ver.get_protocol_name() == "atomic_commit"
                                                    ? std::vector<std::string>{"Lost_R1", "Lost_R2"}
                                                    : std::vector<std::string>(ver.get_num_rounds(), "Lost");

    for (int round = 0; round < ver.get_num_rounds(); ++round) {
        z3::expr_vector crashed = node_crash_per_round(all_messages[round], ver, lost_names[round]);

        // OR together "node is not crashed" over all nodes, starting from false.
        z3::expr someone_alive = ctx.bool_val(false);
        for (int node = 0; node < node_count; ++node) {
            someone_alive = someone_alive || !crashed[node];
        }
        per_round.push_back(someone_alive);
    }

    return z3::mk_and(per_round);
}

// Correctness property for the distributed locking protocol, expressed over the nodes' start and
// end states. It is the conjunction of three parts:
//   1. no two distinct nodes both end in "Enter";
//   2. a node that starts in "NoNeed" does not end in "Enter";
//   3. if any node starts in "Need", some node ends in "Enter".
// all_decisions and all_messages are not read by this function.
z3::expr distributed_locking(z3::expr_vector start_states, z3::expr_vector end_states,
                             std::vector<z3::expr_vector> all_decisions,
                             std::vector<std::vector<z3::expr_vector>> all_messages, SMVerifier& ver) {
    z3::context& ctx = ver.get_context();
    const int node_count = ver.get_num_nodes();
    z3::expr_vector properties(ctx);

    // There are not 2 nodes enter CS at the same time
    z3::expr_vector mutual_exclusion(ctx);
    for (int a = 0; a < node_count; ++a) {
        for (int b = 0; b < node_count; ++b) {
            if (a == b) {
                continue;
            }
            mutual_exclusion.push_back(
                z3::implies(end_states[a] == ver.get_type("Enter"), end_states[b] != ver.get_type("Enter")));
        }
    }
    properties.push_back(z3::mk_and(mutual_exclusion));

    // The nodes that don't need lock should not enter CS.
    z3::expr_vector no_spurious_enter(ctx);
    for (int node = 0; node < node_count; ++node) {
        no_spurious_enter.push_back(
            z3::implies(start_states[node] == ver.get_type("NoNeed"), end_states[node] != ver.get_type("Enter")));
    }
    properties.push_back(z3::mk_and(no_spurious_enter));

    // If lock is needed, lock will be acquired by any one
    z3::expr_vector needs_lock(ctx);
    z3::expr_vector enters_cs(ctx);
    for (int node = 0; node < node_count; ++node) {
        needs_lock.push_back(start_states[node] == ver.get_type("Need"));
        enters_cs.push_back(end_states[node] == ver.get_type("Enter"));
    }
    properties.push_back(z3::implies(z3::mk_or(needs_lock), z3::mk_or(enters_cs)));

    return z3::mk_and(properties);
}

// Correctness property for the atomic commit protocol: the conjunction of Rule1..Rule6 below.
//
// "Reaching" Commit/Abort is computed by reach_state_with_failure from the per-round decisions and
// the per-round crash flags. end_states is not read by this function.
z3::expr atomic_commit(z3::expr_vector start_states, z3::expr_vector end_states,
                       std::vector<z3::expr_vector> all_decisions,
                       std::vector<std::vector<z3::expr_vector>> all_messages, SMVerifier& ver) {
    z3::context& ctx = ver.get_context();
    const int node_count = ver.get_num_nodes();
    z3::expr_vector rules(ctx);

    // any_lost: some message, in any round, equals one of the verifier's lost types.
    z3::expr_vector lost_hits(ctx);
    for (auto& round_msgs : all_messages) {
        for (auto& row : round_msgs) {
            for (auto msg : row) {
                for (auto lost_type : ver.get_lost_types()) {
                    lost_hits.push_back(msg == lost_type);
                }
            }
        }
    }
    z3::expr any_lost = z3::mk_or(lost_hits);

    // Per-round crash flags, one lost-type name per round.
    // NOTE(review): only two names are listed, so this assumes get_num_rounds() <= 2 -- confirm.
    std::vector<std::string> lost_names = {"Lost_R1", "Lost_R2"};
    std::vector<z3::expr_vector> crashed;
    for (int round = 0; round < ver.get_num_rounds(); ++round) {
        crashed.push_back(node_crash_per_round(all_messages[round], ver, lost_names[round]));
    }
    z3::expr_vector commits =
        reach_state_with_failure(all_decisions, crashed, ver.get_type("Commit"), node_count, ctx);
    z3::expr_vector aborts =
        reach_state_with_failure(all_decisions, crashed, ver.get_type("Abort"), node_count, ctx);

    // Rule1: If there is NOT a lost message, then all local commits should result in all commits
    // Any COMMIT at both rounds
    z3::expr everyone_local_commit = all_state(start_states, ver.get_type("LocalCommit"), ctx);
    rules.push_back(z3::implies(!any_lost, z3::implies(everyone_local_commit, z3::mk_and(commits))));

    // Rule2:
    // If there is NOT a lost message, then a single local abort should result in abort at all nodes
    z3::expr someone_local_abort = any_state(start_states, ver.get_type("LocalAbort"), ctx);
    rules.push_back(z3::implies(!any_lost, z3::implies(someone_local_abort, z3::mk_and(aborts))));

    // Rule3:
    // Different nodes should not reach conflict action (COMMIT/ABORT)
    rules.push_back(!(z3::mk_or(commits) && z3::mk_or(aborts)));

    // Rule4:
    // If the start state is LocalAbort, none of the decision should be Commit
    z3::expr_vector abort_blocks_commit(ctx);
    for (int node = 0; node < node_count; ++node) {
        abort_blocks_commit.push_back(z3::implies(start_states[node] == ver.get_type("LocalAbort"), !commits[node]));
    }
    rules.push_back(z3::mk_and(abort_blocks_commit));

    // Rule5:
    // Cannot change decisions
    z3::expr_vector single_decision(ctx);
    for (int node = 0; node < node_count; ++node) {
        single_decision.push_back(!(commits[node] && aborts[node]));
    }
    rules.push_back(z3::mk_and(single_decision));

    // Rule6:
    // Decision conflict
    // NOTE(review): this is the same expression as Rule3; kept so the resulting formula is unchanged.
    rules.push_back(!(z3::mk_or(commits) && z3::mk_or(aborts)));

    return z3::mk_and(rules);
}

// Constraints under which the atomic commit protocol is checked: the conjunction of
// no_recovery_constraints and minimum_alive_constraints for the given decisions and messages.
z3::expr atomic_commit_protocol_constraints(std::vector<z3::expr_vector> all_decisions,
                                            std::vector<std::vector<z3::expr_vector>> all_messages, SMVerifier& ver) {
    z3::expr_vector final_ret(ver.get_context());

    // No Recovery
    final_ret.push_back(no_recovery_constraints(all_decisions, all_messages, ver));

    // At least one node should be alive in each round
    final_ret.push_back(minimum_alive_constraints(all_decisions, all_messages, ver));

    return z3::mk_and(final_ret);
}

// Correctness property for the primary-backup protocol: the conjunction of Rule1..Rule4 below.
//
// Crash flags come from node_crash; "reaching" One/Zero is computed by reach_state_with_failure
// from the per-round decisions and those crash flags. end_states is not read by this function.
z3::expr primary_backup(z3::expr_vector start_states, z3::expr_vector end_states,
                        std::vector<z3::expr_vector> all_decisions,
                        std::vector<std::vector<z3::expr_vector>> all_messages, SMVerifier& ver) {
    z3::context& ctx = ver.get_context();
    const int node_count = ver.get_num_nodes();
    z3::expr_vector rules(ctx);

    std::vector<z3::expr_vector> crashed = node_crash(all_messages, ver.get_num_rounds(), node_count, ctx, ver);
    z3::expr_vector decided_one =
        reach_state_with_failure(all_decisions, crashed, ver.get_type("One"), node_count, ctx);
    z3::expr_vector decided_zero =
        reach_state_with_failure(all_decisions, crashed, ver.get_type("Zero"), node_count, ctx);

    // Rule1:
    // Cannot change decisions
    z3::expr_vector single_decision(ctx);
    for (int node = 0; node < node_count; ++node) {
        single_decision.push_back(!(decided_one[node] && decided_zero[node]));
    }
    rules.push_back(z3::mk_and(single_decision));

    // Rule2:
    // The decision should exist in the initial states
    z3::expr everyone_local_one = all_state(start_states, ver.get_type("LocalOne"), ctx);
    z3::expr everyone_local_zero = all_state(start_states, ver.get_type("LocalZero"), ctx);
    z3::expr_vector decision_from_input(ctx);
    for (int node = 0; node < node_count; ++node) {
        decision_from_input.push_back(z3::implies(everyone_local_one, !decided_zero[node]));
        decision_from_input.push_back(z3::implies(everyone_local_zero, !decided_one[node]));
    }
    rules.push_back(z3::mk_and(decision_from_input));

    // Rule3:
    // There has to be a decision at the end
    z3::expr_vector must_decide(ctx);
    for (int node = 0; node < node_count; ++node) {
        // If it's crashed (in the last round), we don't need to check whether it reached a decision or not.
        must_decide.push_back(
            z3::implies(!crashed[ver.get_num_rounds() - 1][node], (decided_one[node] || decided_zero[node])));
    }
    rules.push_back(z3::mk_and(must_decide));

    // Rule4:
    // Decision conflict
    rules.push_back(!(z3::mk_or(decided_one) && z3::mk_or(decided_zero)));

    return z3::mk_and(rules);
}

// Constraints under which the primary-backup protocol is checked: the conjunction of
// no_recovery_constraints and minimum_alive_constraints for the given decisions and messages.
z3::expr primary_backup_protocol_constraints(std::vector<z3::expr_vector> all_decisions,
                                             std::vector<std::vector<z3::expr_vector>> all_messages, SMVerifier& ver) {
    z3::expr_vector final_ret(ver.get_context());

    // No Recovery
    final_ret.push_back(no_recovery_constraints(all_decisions, all_messages, ver));

    // TODO: at least one node should be alive in each round (Need to use imply)
    final_ret.push_back(minimum_alive_constraints(all_decisions, all_messages, ver));

    return z3::mk_and(final_ret);
}

// Returns the disjunction "states[k] == target" over every element of `states`,
// i.e. an expression that holds when at least one state equals `target`.
z3::expr any_state(z3::expr_vector states, z3::expr target, z3::context& context) {
    z3::expr_vector matches(context);
    for (unsigned k = 0; k < states.size(); ++k) {
        matches.push_back(states[k] == target);
    }
    return z3::mk_or(matches);
}

// Returns the conjunction "states[k] == target" over every element of `states`,
// i.e. an expression that holds when every state equals `target`.
z3::expr all_state(z3::expr_vector states, z3::expr target, z3::context& context) {
    z3::expr_vector matches(context);
    for (unsigned k = 0; k < states.size(); ++k) {
        matches.push_back(states[k] == target);
    }
    return z3::mk_and(matches);
}

// For each of the first num_nodes nodes, returns an expression that holds when that node's decision
// equals `target` in at least one entry of all_decisions (all_decisions[j][i] is node i's decision
// in entry j). Crashes are not taken into account here.
z3::expr_vector reach_state(std::vector<z3::expr_vector> all_decisions, z3::expr target, int num_nodes,
                            z3::context& context) {
    z3::expr_vector per_node(context);
    for (int node = 0; node < num_nodes; ++node) {
        z3::expr_vector hits(context);
        for (auto& decisions : all_decisions) {
            hits.push_back(decisions[node] == target);
        }
        per_node.push_back(z3::mk_or(hits));
    }

    return per_node;
}

// Computes crash flags for every round and node using the "Lost" message type.
//
// Returns a vector with num_rounds entries; entry [round][node] is an expression that holds when
// any message read as all_messages[round][k][node] (for k in 0..num_nodes-1) equals "Lost".
std::vector<z3::expr_vector> node_crash(std::vector<std::vector<z3::expr_vector>> all_messages, int num_rounds,
                                        int num_nodes, z3::context& context, SMVerifier& ver) {
    std::vector<z3::expr_vector> crash_by_round;
    for (int round = 0; round < num_rounds; ++round) {
        z3::expr_vector crashed(context);
        for (int sender = 0; sender < num_nodes; ++sender) {
            z3::expr_vector lost_at(context);
            for (int receiver = 0; receiver < num_nodes; ++receiver) {
                // Any node receiving a lost message from `sender` means `sender` is crashed
                lost_at.push_back(all_messages[round][receiver][sender] == ver.get_type("Lost"));
            }
            crashed.push_back(z3::mk_or(lost_at));
        }
        crash_by_round.push_back(crashed);
    }
    return crash_by_round;
}

// Computes crash flags for a single round.
//
// `messages` is one round's messages, read as messages[j][i]; `lost_str` names the verifier type
// that marks a lost message in this round. Entry i of the result holds when messages[j][i] equals
// that type for at least one j in 0..num_nodes-1.
z3::expr_vector node_crash_per_round(std::vector<z3::expr_vector> messages, SMVerifier& ver, std::string lost_str) {
    z3::context& ctx = ver.get_context();
    const int node_count = ver.get_num_nodes();

    z3::expr_vector crashed(ctx);
    for (int node = 0; node < node_count; ++node) {
        z3::expr_vector lost_at(ctx);
        for (int other = 0; other < node_count; ++other) {
            lost_at.push_back(messages[other][node] == ver.get_type(lost_str));
        }
        crashed.push_back(z3::mk_or(lost_at));
    }
    return crashed;
}

// Return, for each node, whether it reaches the target state in some round in which it is not crashed.
// If a node is crashed in a round, its decision in that round is NOT counted as reaching the target state.
//
// all_decisions[j][i] is node i's decision in round j, and crash_info[j][i] is the Bool expression
// that holds when node i is considered crashed in round j. crash_info must have at least as many
// entries as all_decisions; this is not checked here.
// Returns a vector of num_nodes Bool expressions, one per node.
z3::expr_vector reach_state_with_failure(std::vector<z3::expr_vector> all_decisions,
                                         std::vector<z3::expr_vector> crash_info, z3::expr target, int num_nodes,
                                         z3::context& context) {
    // Convert once so the loop below compares int with int.
    const int num_rounds = static_cast<int>(all_decisions.size());

    z3::expr_vector reach(context);
    for (int i = 0; i < num_nodes; ++i) {
        z3::expr_vector or_expr(context);
        for (int j = 0; j < num_rounds; ++j) {
            // crash_info[j][i] is already a Bool, so it is negated directly; wrapping it in
            // ite(c, true, false) yields the same value with a larger term.
            or_expr.push_back(!crash_info[j][i] && all_decisions[j][i] == target);
        }
        reach.push_back(z3::mk_or(or_expr));
    }

    return reach;
}