#include "police/base_types.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/variable.hpp"
#include "police/storage/unordered_set.hpp"

#include <catch2/catch.hpp>

#ifdef POLICE_Z3

#include "police/smt_model_enumerator.hpp"
#include "police/smt_z3.hpp"
#include "police/storage/variable_space.hpp"

namespace /* file-local */ {
// Make the police API (types such as RealType, Value, SMT, ...) visible
// unqualified throughout this translation unit without exporting anything.
using namespace ::police;
} // namespace

TEST_CASE("Test Z3 construction and destruction", "[sat][smt][z3]")
{
    // Each section creates a fresh solver holding a single variable, checks
    // that the variable was registered and that the (unconstrained) problem
    // is satisfiable. The solver is destroyed at the end of each section.

    SECTION("Single real variable")
    {
        police::Z3SMT smt;
        smt.add_variable("x", RealType());
        REQUIRE(smt.num_variables() == 1u);
        REQUIRE(smt.solve() == SMT::Status::SAT);
    }

    SECTION("Single int variable")
    {
        police::Z3SMT smt;
        smt.add_variable("x", IntegerType());
        REQUIRE(smt.num_variables() == 1u);
        REQUIRE(smt.solve() == SMT::Status::SAT);
    }

    SECTION("Single bounded variable")
    {
        police::Z3SMT smt;
        const auto x = police::expressions::Variable(
            smt.add_variable("x", BoundedIntType(0, 1)));
        REQUIRE(smt.num_variables() == 1u);
        const auto status = smt.solve();
        REQUIRE(status == SMT::Status::SAT);

        // Satisfiability alone does not show the bounds were encoded; the
        // model's value for x must respect the declared range [0, 1].
        const auto model = smt.get_model();
        const auto x_val = model.get_value(x.var_id);
        REQUIRE((x_val >= Value(0) && x_val <= Value(1)));
    }
}

TEST_CASE("Test Z3: pure linear real arithmetic", "[sat][smt][z3]")
{
    // Problem: x, y, z >= 0 (reals) with
    //   x + y >= 10,   y + z >= 10,   x + z <= 5.
    // It is feasible (e.g. y = 10, x = z = 0). Adding y <= 5 forces
    // x >= 5 and z >= 5, i.e. x + z >= 10, contradicting x + z <= 5.
    police::Z3SMT smt;

    const auto x =
        police::expressions::Variable(smt.add_variable("x", police::RealType{}));
    const auto y =
        police::expressions::Variable(smt.add_variable("y", police::RealType{}));
    const auto z =
        police::expressions::Variable(smt.add_variable("z", police::RealType{}));

    // Non-negativity.
    smt.add_constraint(expressions::greater_equal(x, Value(0.)));
    smt.add_constraint(expressions::greater_equal(y, Value(0.)));
    smt.add_constraint(expressions::greater_equal(z, Value(0.)));
    // Pairwise sums.
    smt.add_constraint(expressions::greater_equal(x + y, Value(10.)));
    smt.add_constraint(expressions::greater_equal(y + z, Value(10.)));
    smt.add_constraint(expressions::less_equal(x + z, Value(5.)));

    // Re-evaluates every constraint posted above on a concrete model.
    auto expect_feasible = [&](const SMT::model_type& m) {
        const auto xv = m.get_value(x.var_id);
        const auto yv = m.get_value(y.var_id);
        const auto zv = m.get_value(z.var_id);

        REQUIRE(xv >= Value(0.));
        REQUIRE(yv >= Value(0.));
        REQUIRE(zv >= Value(0.));

        REQUIRE(xv + yv >= Value(10.));
        REQUIRE(yv + zv >= Value(10.));
        REQUIRE(xv + zv <= Value(5.));
    };

    REQUIRE(smt.solve() == SMT::Status::SAT);

    SECTION("Extract model")
    {
        expect_feasible(smt.get_model());
    }

    SECTION("Make unsolvable")
    {
        smt.add_constraint(expressions::less_equal(y, Value(5.)));
        REQUIRE(smt.solve() == SMT::Status::UNSAT);
    }

    SECTION("Push & pop snapshot")
    {
        // The extra constraint lives only inside the snapshot; popping it
        // must restore satisfiability.
        smt.push_snapshot();
        smt.add_constraint(expressions::less_equal(y, Value(5.)));
        REQUIRE(smt.solve() == SMT::Status::UNSAT);
        smt.pop_snapshot();
        REQUIRE(smt.solve() == SMT::Status::SAT);
    }

    SECTION("Enumerate models")
    {
        // The callback returns true to stop; ask for exactly three models.
        SMTModelEnumerator enumerate;
        int seen = 0;
        enumerate(&smt, [&](const SMT::model_type& m) {
            expect_feasible(m);
            return ++seen == 3;
        });
        REQUIRE(seen == 3);
    }
}

TEST_CASE("Test Z3: linear integer arithmetic", "[sat][smt][z3]")
{
    // Problem: x, y in [0, 10], z in [-5, 5] (integers) with
    //   x + y + z <= 15   and   x + y >= 19.
    // Hence x + y is 19 or 20:
    //   x + y = 19 -> (x, y) in {(9, 10), (10, 9)}, z in {-5, -4}: 4 models
    //   x + y = 20 -> (x, y) = (10, 10),            z = -5:        1 model
    // giving exactly 5 integer models.
    constexpr auto expected_models = 5u;

    police::Z3SMT smt;

    const auto x = police::expressions::Variable(
        smt.add_variable("x", police::BoundedIntType{0, 10}));
    const auto y = police::expressions::Variable(
        smt.add_variable("y", police::BoundedIntType{0, 10}));
    const auto z = police::expressions::Variable(
        smt.add_variable("z", police::BoundedIntType{-5, 5}));

    smt.add_constraint(expressions::less_equal(x + y + z, Value(15)));
    smt.add_constraint(expressions::greater_equal(x + y, Value(19)));

    // Checks bounds and constraints on a model, then returns its (x, y, z)
    // assignment as plain integers so models can be compared for identity.
    auto decode_model =
        [&](const SMT::model_type& m) -> std::tuple<int_t, int_t, int_t> {
        const auto xv = m.get_value(x.var_id);
        const auto yv = m.get_value(y.var_id);
        const auto zv = m.get_value(z.var_id);

        REQUIRE((xv >= Value(0) && xv <= Value(10)));
        REQUIRE((yv >= Value(0) && yv <= Value(10)));
        REQUIRE((zv >= Value(-5) && zv <= Value(5)));

        REQUIRE(xv + yv + zv <= Value(15));
        REQUIRE(xv + yv >= Value(19));

        return std::make_tuple(
            static_cast<int_t>(xv),
            static_cast<int_t>(yv),
            static_cast<int_t>(zv));
    };

    REQUIRE(smt.solve() == SMT::Status::SAT);

    SECTION("Extract model")
    {
        decode_model(smt.get_model());
    }

    SECTION("Make unsolvable")
    {
        // Every model has z <= -4 and x <= 10, so x + 2z <= 2 < 3.
        smt.add_constraint(
            expressions::greater_equal(x + Value(2) * z, Value(3)));
        REQUIRE(smt.solve() == SMT::Status::UNSAT);
    }

    SECTION("Enumerate models")
    {
        SMTModelEnumerator enumerate;
        unordered_set<std::tuple<int_t, int_t, int_t>> seen;
        enumerate(&smt, [&](const SMT::model_type& m) {
            // Every enumerated model must be new.
            const bool fresh = seen.insert(decode_model(m)).second;
            REQUIRE(fresh);
            // Stop (return true) only if the enumerator overshoots; normally
            // enumeration ends on its own once all models are exhausted.
            return seen.size() > expected_models;
        });
        REQUIRE(seen.size() == expected_models);
    }
}

#endif
