#include "../linex.hpp"
#include "../numbers.hpp"
#include "../sat/model_utils.hpp"
#include "./nnlp_factories.hpp"

#include "police/nnlp_marabou.hpp"
#include "police/sat_model.hpp"
#include "police/storage/variable_space.hpp"

#include <catch2/catch.hpp>
#include <cmath>

namespace {
using namespace police;

void check_bounds(const NNLP& lp, police::size_t var, const BoundedIntType& t)
{
    CHECK_THAT(
        lp.get_variable_lower_bound(var),
        Catch::Matchers::WithinAbs(t.lower_bound, PRECISION));

    CHECK_THAT(
        lp.get_variable_upper_bound(var),
        Catch::Matchers::WithinAbs(t.upper_bound, PRECISION));
}

void check_bounds(const NNLP& lp, police::size_t var, const BoundedRealType& t)
{
    CHECK_THAT(
        lp.get_variable_lower_bound(var),
        Catch::Matchers::WithinAbs(t.lower_bound, PRECISION));

    CHECK_THAT(
        lp.get_variable_upper_bound(var),
        Catch::Matchers::WithinAbs(t.upper_bound, PRECISION));
}

/// Common fixture for the NNLP tests over bounded real variables.
///
/// Construction fills `vspace` with five bounded real variables (x, y, z, h,
/// o) and builds a pool of constraints over them. Invoking the fixture on a
/// solver registers the variables and then runs the derived test's `prepare`
/// and `check_result` hooks, in that order.
///
/// Constraint pool (a right-hand side that is not assigned below keeps
/// whatever value LinearConstraint's constructor gives it — presumably 0,
/// TODO confirm):
///   c0:   x +  y -  z      <= -1
///   c1:   x +  y           >=  5
///   eq:  2x +  y - 2z - h   = rhs (not assigned)
///   relu: ReluConstraint(h, o)
///   c2:  2x - z            >= rhs (not assigned)
///   c3:  2y - z            >= rhs (not assigned)
///   c2_or_c3: disjunction holding c2 and c3
///   c4:   z + o            <= rhs (not assigned)
class TestCaseBase {
public:
    explicit TestCaseBase()
        : c0(LinearConstraint::LESS_EQUAL)
        , c1(LinearConstraint::GREATER_EQUAL)
        , eq(LinearConstraint::EQUAL)
        , relu(0, 0) // placeholder; replaced in setup_constraints()
        , c2(LinearConstraint::GREATER_EQUAL)
        , c3(LinearConstraint::GREATER_EQUAL)
        , c4(LinearConstraint::LESS_EQUAL)
    {
        setup_variables();
        setup_constraints();
    }

    // This class has virtual hooks, so give it a virtual destructor: deleting
    // a derived test through a base pointer must stay well defined.
    virtual ~TestCaseBase() = default;

    /// Runs the test against `lp`: adds the variables of `vspace`, lets the
    /// derived class post its constraints, then lets it verify the outcome.
    void operator()(NNLP& lp)
    {
        lp.add_variables(vspace);
        prepare(lp);
        check_result(lp);
    }

    VariableSpace vspace;

    // Indices returned by vspace.add_variable() for the five variables.
    police::size_t x;
    police::size_t y;
    police::size_t z;
    police::size_t h;
    police::size_t o;

    LinearConstraint c0;
    LinearConstraint c1;
    LinearConstraint eq;
    ReluConstraint relu;
    LinearConstraint c2;
    LinearConstraint c3;
    LinearConstraintDisjunction c2_or_c3;
    LinearConstraint c4;

protected:
    /// Runs the domain and type checks of `model` against `vspace`.
    void check_variables(const SATModel& model)
    {
        check_variable_domains(model, vspace);
        check_variable_types(model, vspace);
    }

private:
    /// Registers the five bounded real variables and records their indices.
    void setup_variables()
    {
        x = vspace.add_variable("x", BoundedRealType(0, 10));
        y = vspace.add_variable("y", BoundedRealType(0, 10));
        z = vspace.add_variable("z", BoundedRealType(-10, 10));
        h = vspace.add_variable("h", BoundedRealType(-100, 100));
        o = vspace.add_variable("o", BoundedRealType(-100, 100));
    }

    /// Populates the constraint pool; must run after setup_variables()
    /// because it uses the variable indices assigned there.
    void setup_constraints()
    {
        c0.insert(x, 1);
        c0.insert(y, 1);
        c0.insert(z, -1);
        c0.rhs = -1.;

        c1.insert(x, 1);
        c1.insert(y, 1);
        c1.rhs = 5.;

        eq.insert(x, 2.);
        eq.insert(y, 1.);
        eq.insert(z, -2.);
        eq.insert(h, -1.);

        relu = ReluConstraint(h, o);

        c2.insert(x, 2.);
        c2.insert(z, -1);

        c3.insert(y, 2.);
        c3.insert(z, -1);

        c2_or_c3.push_back(c2);
        c2_or_c3.push_back(c3);

        c4.insert(z, 1.0);
        c4.insert(o, 1.0);
    }

    /// Hook: post constraints / bounds on `lp` before solving.
    virtual void prepare(NNLP& lp) = 0;

    /// Hook: solve `lp` and assert on the outcome.
    virtual void check_result(NNLP& lp) = 0;
};

/// Real-valued scenario in which only the inequalities c0 and c1 are posted.
class TestCase1 final : public TestCaseBase {
public:
    using TestCaseBase::TestCaseBase;

    /// Builds the Catch2 test title; `what` is appended verbatim.
    static constexpr std::string name(const char* what)
    {
        return std::string(
                   "Program with two linear constraints is solvable: ") +
               what;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
    }

    void check_result(NNLP& lp) override
    {
        REQUIRE(lp.solve() == NNLP::Status::SAT);

        // Validate the reported model against the variable space and against
        // every constraint that was posted.
        const auto solution = lp.get_model();
        check_variables(solution);
        check_constraints(solution, c0, c1);
    }
};

/// Real-valued scenario: inequalities c0 and c1 plus the equality `eq`.
class TestCase2 final : public TestCaseBase {
public:
    using TestCaseBase::TestCaseBase;

    /// Builds the Catch2 test title; `what` is appended verbatim.
    static constexpr std::string name(const char* what)
    {
        return std::string("Program with two linear and one equality "
                           "constraints is solvable: ") +
               what;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
        lp.add_constraint(eq);
    }

    void check_result(NNLP& lp) override
    {
        REQUIRE(lp.solve() == NNLP::Status::SAT);

        // Validate the reported model against the variable space and against
        // every constraint that was posted.
        const auto solution = lp.get_model();
        check_variables(solution);
        check_constraints(solution, c0, c1, eq);
    }
};

/// Real-valued scenario: c0, c1, the equality `eq` and the ReLU constraint.
class TestCase3 final : public TestCaseBase {
public:
    using TestCaseBase::TestCaseBase;

    /// Builds the Catch2 test title; `what` is appended verbatim.
    static constexpr std::string name(const char* what)
    {
        return std::string(
                   "Program with two linear, one equality, and one Relu "
                   "constraints is solvable: ") +
               what;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
        lp.add_constraint(eq);
        lp.add_constraint(relu);
    }

    void check_result(NNLP& lp) override
    {
        REQUIRE(lp.solve() == NNLP::Status::SAT);

        // Validate the reported model against the variable space and against
        // every constraint that was posted.
        const auto solution = lp.get_model();
        check_variables(solution);
        check_constraints(solution, c0, c1, eq, relu);
    }
};

/// Real-valued scenario: all four base constraints are posted and the bounds
/// of x are overwritten with [-10, -1] before the first solve.
class TestCase4 final : public TestCaseBase {
public:
    using TestCaseBase::TestCaseBase;

    /// Builds the Catch2 test title; `what` is appended verbatim.
    static constexpr std::string name(const char* what)
    {
        return std::string(
                   "Program is solvable after setting variable bounds: ") +
               what;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
        lp.add_constraint(eq);
        lp.add_constraint(relu);

        // Mirror the new bounds in vspace as well: check_variables() hands
        // vspace to the domain/type checks.
        vspace[x].type = BoundedRealType(-10, -1);
        lp.set_variable_lower_bound(x, -10);
        lp.set_variable_upper_bound(x, -1);
    }

    void check_result(NNLP& lp) override
    {
        REQUIRE(lp.solve() == NNLP::Status::SAT);

        const auto solution = lp.get_model();
        check_variables(solution);
        check_constraints(solution, c0, c1, eq, relu);
    }
};

/// Real-valued scenario: solve once, overwrite the bounds of x with
/// [-10, -1], then solve again; both rounds must be SAT.
class TestCase5 final : public TestCaseBase {
public:
    using TestCaseBase::TestCaseBase;

    /// Builds the Catch2 test title; `what` is appended verbatim.
    static constexpr std::string name(const char* what)
    {
        return std::string("Program remains solvable after changing variable "
                           "bounds: ") +
               what;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
        lp.add_constraint(eq);
        lp.add_constraint(relu);
    }

    void check_result(NNLP& lp) override
    {
        // Model validation shared by both solve rounds.
        const auto verify_model = [this, &lp] {
            const auto solution = lp.get_model();
            check_variables(solution);
            check_constraints(solution, c0, c1, eq, relu);
        };

        REQUIRE(lp.solve() == NNLP::Status::SAT);
        verify_model();

        // Keep vspace in step with the solver so the second validation runs
        // against the new domain of x.
        vspace[x].type = BoundedRealType(-10, -1);
        lp.set_variable_lower_bound(x, -10);
        lp.set_variable_upper_bound(x, -1);

        REQUIRE(lp.solve() == NNLP::Status::SAT);
        verify_model();
    }
};

/// Real-valued scenario: the four base constraints plus the disjunction
/// `c2_or_c3`.
class TestCase6 final : public TestCaseBase {
public:
    using TestCaseBase::TestCaseBase;

    /// Builds the Catch2 test title; `what` is appended verbatim.
    static constexpr std::string name(const char* what)
    {
        return std::string(
                   "Program with disjunctive constraints is solvable: ") +
               what;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
        lp.add_constraint(eq);
        lp.add_constraint(relu);
        lp.add_constraint(c2_or_c3);
    }

    void check_result(NNLP& lp) override
    {
        REQUIRE(lp.solve() == NNLP::Status::SAT);

        const auto solution = lp.get_model();
        check_variables(solution);
        // NOTE(review): the disjunction c2_or_c3 is posted in prepare() but is
        // not part of this check — confirm whether check_constraints can take
        // a LinearConstraintDisjunction and, if so, add it here.
        check_constraints(solution, c0, c1, eq, relu);
    }
};

/// Real-valued scenario: the four base constraints plus c4; the solver is
/// expected to report UNSAT.
class TestCase7 final : public TestCaseBase {
public:
    using TestCaseBase::TestCaseBase;

    /// Builds the Catch2 test title; `what` is appended verbatim.
    static constexpr std::string name(const char* what)
    {
        return std::string("Program is unsolvable: ") + what;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
        lp.add_constraint(eq);
        lp.add_constraint(relu);
        lp.add_constraint(c4);
    }

    // No model is inspected: only the solver status matters here.
    void check_result(NNLP& lp) override
    {
        REQUIRE(lp.solve() == NNLP::Status::UNSAT);
    }
};

/// Real-valued scenario: c4 is added after a snapshot; the program must be
/// UNSAT with it and SAT again once the snapshot is popped.
class TestCase8 final : public TestCaseBase {
public:
    using TestCaseBase::TestCaseBase;

    /// Builds the Catch2 test title; `what` is appended verbatim.
    static constexpr std::string name(const char* what)
    {
        return std::string("Reverting to snapshot makes program solvable: ") +
               what;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
        lp.add_constraint(eq);
        lp.add_constraint(relu);

        // Everything posted from here on is expected to vanish on pop.
        lp.push_snapshot();

        lp.add_constraint(c4);
    }

    void check_result(NNLP& lp) override
    {
        REQUIRE(lp.solve() == NNLP::Status::UNSAT);

        lp.pop_snapshot();
        REQUIRE(lp.solve() == NNLP::Status::SAT);

        // c4 is deliberately absent from the constraint check below.
        const auto solution = lp.get_model();
        check_variables(solution);
        check_constraints(solution, c0, c1, eq, relu);
    }
};

class TestCase9 final : public TestCaseBase {
public:
    using TestCaseBase::TestCaseBase;

    constexpr static std::string name(const char* what)
    {
        return std::string("Bound reset through popping snapshot leaves "
                           "program solvable: ") +
               what;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
        lp.add_constraint(eq);
        lp.add_constraint(relu);
        lp.push_snapshot();
    }

    void check_result(NNLP& lp) override
    {
        REQUIRE(lp.solve() == NNLP::Status::SAT);
        check_model(lp);

        const auto t = std::get<BoundedRealType>(vspace[x].type);
        vspace[x].type = BoundedRealType(-10, -1);
        lp.set_variable_lower_bound(x, -10);
        lp.set_variable_upper_bound(x, -1);
        REQUIRE(lp.solve() == NNLP::Status::SAT);
        check_model(lp);

        lp.pop_snapshot();
        vspace[x].type = t;

        CHECK_THAT(
            lp.get_variable_lower_bound(x),
            Catch::Matchers::WithinAbs(t.lower_bound, PRECISION));
        CHECK_THAT(
            lp.get_variable_upper_bound(x),
            Catch::Matchers::WithinAbs(t.upper_bound, PRECISION));

        REQUIRE(lp.solve() == NNLP::Status::SAT);
        check_model(lp);
    }

    void check_model(NNLP& lp)
    {
        auto model = lp.get_model();
        check_variables(model);
        check_constraints(model, c0, c1, eq, relu);
    }
};

class TestCase10 final : public TestCaseBase {
public:
    using TestCaseBase::TestCaseBase;

    constexpr static std::string name(const char* what)
    {
        return std::string("Pushing and popping series of snapshot while "
                           "chaning variable bounds: ") +
               what;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
        lp.add_constraint(eq);
        lp.add_constraint(relu);
    }

    void check_result(NNLP& lp) override
    {
        const auto t0 = std::get<BoundedRealType>(vspace[x].type);
        lp.push_snapshot();

        const auto t1 = BoundedRealType(-10, -1);
        vspace[x].type = t1;
        lp.set_variable_lower_bound(x, -10);
        lp.set_variable_upper_bound(x, -1);
        REQUIRE(lp.solve() == NNLP::Status::SAT);
        check_model(lp);

        const auto ty = std::get<BoundedRealType>(vspace[y].type);
        lp.push_snapshot();

        vspace[y].type = BoundedRealType(0, 5);
        lp.set_variable_lower_bound(y, 0);
        lp.set_variable_upper_bound(y, 5);
        REQUIRE(lp.solve() == NNLP::Status::UNSAT);

        lp.pop_snapshot(); // t1 and ty
        check_bounds(lp, x, t1);
        check_bounds(lp, y, ty);
        vspace[y].type = ty;
        REQUIRE(lp.solve() == NNLP::Status::SAT);
        check_model(lp);

        lp.pop_snapshot(); // t0
        check_bounds(lp, x, t0);
        vspace[x].type = t0;
        REQUIRE(lp.solve() == NNLP::Status::SAT);
        check_model(lp);
    }

    void check_model(NNLP& lp)
    {
        auto model = lp.get_model();
        check_variables(model);
        check_constraints(model, c0, c1, eq, relu);
    }
};

/// Common fixture for the NNLP tests that involve integer variables.
///
/// Construction fills `vspace` with three bounded integer variables (x, y, z)
/// and two bounded real variables (h, o), and builds a pool of constraints
/// over them. Invoking the fixture on a solver registers the variables and
/// then runs the derived test's `prepare` and `check_result` hooks.
///
/// Constraint pool (a right-hand side that is not assigned below keeps
/// whatever value LinearConstraint's constructor gives it — presumably 0,
/// TODO confirm):
///   c0:   x + 2y           >= 5
///   c1:   x +  y -  z      <= rhs (not assigned)
///   c2:   x +  y +  z - h   = rhs (not assigned)
///   c2_:  x +  y - 3z - h   = rhs (not assigned)
///   relu: ReluConstraint(h, o)
///   c3:   o                <= 10
///   c3_: 2o                <= 0
class ITestCaseBase {
public:
    ITestCaseBase()
        : c0(LinearConstraint::GREATER_EQUAL)
        , c1(LinearConstraint::LESS_EQUAL)
        , c2(LinearConstraint::EQUAL)
        , c2_(LinearConstraint::EQUAL)
        , relu(0, 0) // placeholder; replaced once h and o are known
        , c3(LinearConstraint::LESS_EQUAL)
        , c3_(LinearConstraint::LESS_EQUAL)
    {
        x = vspace.add_variable("x", BoundedIntType(0, 5));
        y = vspace.add_variable("y", BoundedIntType(0, 5));
        z = vspace.add_variable("z", BoundedIntType(-10, 10));
        h = vspace.add_variable("h", BoundedRealType(-100, 100));
        o = vspace.add_variable("o", BoundedRealType(-100, 100));

        c0.insert(x, 1.);
        c0.insert(y, 2.);
        c0.rhs = 5;

        c1.insert(x, 1.);
        c1.insert(y, 1.);
        c1.insert(z, -1.);

        c2.insert(x, 1.);
        c2.insert(y, 1.);
        c2.insert(z, 1.);
        c2.insert(h, -1.);

        c2_.insert(x, 1.);
        c2_.insert(y, 1.);
        c2_.insert(z, -3.);
        c2_.insert(h, -1.);

        relu = ReluConstraint(h, o);

        c3.insert(o, 1.);
        c3.rhs = 10;

        c3_.insert(o, 2.);
        c3_.rhs = 0;
    }

    // This class has virtual hooks, so give it a virtual destructor: deleting
    // a derived test through a base pointer must stay well defined.
    virtual ~ITestCaseBase() = default;

    /// Runs the test against `lp`: adds the variables of `vspace`, lets the
    /// derived class post its constraints, then lets it verify the outcome.
    void operator()(NNLP& lp)
    {
        lp.add_variables(vspace);
        prepare(lp);
        check_result(lp);
    }

    VariableSpace vspace;

    // Indices returned by vspace.add_variable() for the five variables.
    police::size_t x;
    police::size_t y;
    police::size_t z;
    police::size_t h;
    police::size_t o;

    LinearConstraint c0;
    LinearConstraint c1;
    LinearConstraint c2;
    LinearConstraint c2_;
    ReluConstraint relu;
    LinearConstraint c3;
    LinearConstraint c3_;

protected:
    /// Runs the domain and type checks of `model` against `vspace`.
    void check_variables(const SATModel& model)
    {
        check_variable_domains(model, vspace);
        check_variable_types(model, vspace);
    }

private:
    /// Hook: post constraints / bounds on `lp` before solving.
    virtual void prepare(NNLP& lp) = 0;
    /// Hook: solve `lp` and assert on the outcome.
    virtual void check_result(NNLP& lp) = 0;
};

class ITestCase1 final : public ITestCaseBase {
public:
    constexpr static std::string name(const char* type)
    {
        return std::string("Has integer variables: ") + type;
    }

private:
    void prepare(NNLP&) override {}

    void check_result(NNLP& lp) override
    {
        REQUIRE(lp.num_variables() == 5u);
        REQUIRE(lp.has_integer_variable() == true);
        check_bounds(lp, x, std::get<BoundedIntType>(vspace[x].type));
        check_bounds(lp, y, std::get<BoundedIntType>(vspace[y].type));
        check_bounds(lp, z, std::get<BoundedIntType>(vspace[z].type));
    }
};

/// Integer scenario: c0, c1, c2, the ReLU constraint and c3 are posted; the
/// solver is expected to report SAT with a model satisfying all of them.
class ITestCase2 final : public ITestCaseBase {
public:
    /// Builds the Catch2 test title; `type` is appended verbatim.
    constexpr static std::string name(const char* type)
    {
        return std::string("Program with integer variables is solvable: ") +
               type;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
        lp.add_constraint(c2);
        lp.add_constraint(relu);
        lp.add_constraint(c3);
    }

    void check_result(NNLP& lp) override
    {
        // Status spelled as in the other test classes of this file.
        REQUIRE(lp.solve() == NNLP::Status::SAT);
        const auto model = lp.get_model();
        check_variables(model);
        check_constraints(model, c0, c1, c2, relu, c3);
    }
};

/// Integer scenario: same as ITestCase2 but with c3_ in place of c3; the
/// solver is expected to report UNSAT.
class ITestCase3 final : public ITestCaseBase {
public:
    /// Builds the Catch2 test title; `type` is appended verbatim.
    constexpr static std::string name(const char* type)
    {
        return std::string("Program with integer variables is unsolvable: ") +
               type;
    }

private:
    void prepare(NNLP& lp) override
    {
        lp.add_constraint(c0);
        lp.add_constraint(c1);
        lp.add_constraint(c2);
        lp.add_constraint(relu);
        lp.add_constraint(c3_);
    }

    // Status spelled as in the other test classes of this file.
    void check_result(NNLP& lp) override
    {
        REQUIRE(lp.solve() == NNLP::Status::UNSAT);
    }
};

} // namespace

// Defines one Catch2 TEST_CASE that runs the fixture `TestClass` against a
// program obtained from the expression `Type()`.
//   TestClass  fixture class; must provide a static name(const char*) and be
//              callable with the created program.
//   Ntype      tag string literal appended after "[nnlp]".
//   Type       token used both (stringified) in the test title and as `Type()`
//              to create the program under test.
//   Tags       further tag string literal(s), concatenated after Ntype.
#define CREATE_TEST_FOR(TestClass, Ntype, Type, Tags)                          \
    TEST_CASE(TestClass::name(#Type), "[nnlp]" Ntype Tags)                     \
    {                                                                          \
        TestClass test_case;                                                   \
        auto lp = Type();                                                      \
        test_case(lp);                                                         \
    }

// Instantiates TestCase1 .. TestCase10, each tagged "[real]". The variadic
// arguments are forwarded to CREATE_TEST_FOR as its (Type, Tags) parameters.
#define CREATE_RTESTS_FOR(...)                                                 \
    CREATE_TEST_FOR(TestCase1, "[real]", __VA_ARGS__)                          \
    CREATE_TEST_FOR(TestCase2, "[real]", __VA_ARGS__)                          \
    CREATE_TEST_FOR(TestCase3, "[real]", __VA_ARGS__)                          \
    CREATE_TEST_FOR(TestCase4, "[real]", __VA_ARGS__)                          \
    CREATE_TEST_FOR(TestCase5, "[real]", __VA_ARGS__)                          \
    CREATE_TEST_FOR(TestCase6, "[real]", __VA_ARGS__)                          \
    CREATE_TEST_FOR(TestCase7, "[real]", __VA_ARGS__)                          \
    CREATE_TEST_FOR(TestCase8, "[real]", __VA_ARGS__)                          \
    CREATE_TEST_FOR(TestCase9, "[real]", __VA_ARGS__)                          \
    CREATE_TEST_FOR(TestCase10, "[real]", __VA_ARGS__)

// Instantiates ITestCase1 .. ITestCase3, each tagged "[int]". The variadic
// arguments are forwarded to CREATE_TEST_FOR as its (Type, Tags) parameters.
#define CREATE_ITESTS_FOR(...)                                                 \
    CREATE_TEST_FOR(ITestCase1, "[int]", __VA_ARGS__)                          \
    CREATE_TEST_FOR(ITestCase2, "[int]", __VA_ARGS__)                          \
    CREATE_TEST_FOR(ITestCase3, "[int]", __VA_ARGS__)

// Instantiates both suites (real-valued and integer) for one program type.
#define CREATE_TESTS_FOR(...)                                                  \
    CREATE_RTESTS_FOR(__VA_ARGS__)                                             \
    CREATE_ITESTS_FOR(__VA_ARGS__)

// Test instantiations, one group per back-end enabled at build time. Each
// line expands into one TEST_CASE per fixture class.

// Z3-based programs.
#if POLICE_Z3
CREATE_TESTS_FOR(Z3_NNLP, "[smt][z3]")
CREATE_TESTS_FOR(BnB_NNLP_Z3, "[bnb][smt][z3]")
#endif

// Gurobi-based programs.
#if POLICE_GUROBI
CREATE_TESTS_FOR(Gurobi_NNLP, "[lp][gurobi]")
CREATE_TESTS_FOR(BnB_NNLP_Gurobi, "[bnb][gurobi]")
#endif

// Marabou-based programs. MarabouLP only gets the real-valued suite
// (presumably because it does not handle integer variables on its own —
// TODO confirm); the BnB variant gets both suites.
#if POLICE_MARABOU
CREATE_RTESTS_FOR(MarabouLP, "[marabou]")
CREATE_TESTS_FOR(BnB_NNLP_Marabou, "[bnb][marabou]")
#endif

// Combinations that need Marabou together with a second back-end.
#if POLICE_MARABOU && POLICE_Z3
CREATE_TESTS_FOR(Z3_NNLP_Preprocessing, "[smt][z3][marabou]")
#endif

#if POLICE_MARABOU && POLICE_GUROBI
CREATE_TESTS_FOR(Gurobi_NNLP_Preprocessing, "[lp][gurobi][marabou]")
#endif

// Drop every helper macro defined above so none of them leaks past this
// point of the translation unit.
#undef CREATE_TEST_FOR
#undef CREATE_TESTS_FOR
#undef CREATE_RTESTS_FOR
#undef CREATE_ITESTS_FOR
