#include "../numbers.hpp"
#include "police/expressions/constants.hpp"
#include "police/expressions/variable.hpp"
#include "police/linear_condition.hpp"
#include "police/model.hpp"
#include "police/nnlp_smt.hpp"
#include "police/sat_model.hpp"
#include "police/smt_z3.hpp"
#include "police/storage/ffnn.hpp"
#include "police/storage/flat_state.hpp"
#include "police/storage/variable_space.hpp"
#include "police/variable_substitution.hpp"
#include "police/verifiers/epic3/engines/sat_based/encoding_information.hpp"
#include "police/verifiers/epic3/engines/sat_based/ic3_sat.hpp"
#include "police/verifiers/epic3/engines/sat_based/pic3sat_label.hpp"
#include "police/verifiers/epic3/engines/sat_based/smt_transition_encoding.hpp"

#include <catch2/catch.hpp>

namespace {
using namespace police;

template <typename Value, typename... Values>
void match_model(
    police::size_t var,
    const SATModel& model,
    Value expected_value,
    Values... values)
{
    CHECK_THAT(
        static_cast<real_t>(model.get_value(var)),
        Catch::Matchers::WithinAbs(expected_value, PRECISION));
    if constexpr (sizeof...(values) > 0u) {
        match_model(var + 1, model, values...);
    }
}

template <typename... Values>
void match_model(const SATModel& model, Values... values)
{
    if constexpr (sizeof...(values) > 0u) {
        match_model(0, model, values...);
    }
}

} // namespace

#if POLICE_Z3
// Encodes a single action guarded by x == 0 && y == 0 and checks the model
// returned by the SMT solver. The match_model calls below read the model in
// the order (x, x', y, y') -- presumably each state variable followed by its
// successor copy created by add_state_variables; confirm if that changes.
TEST_CASE("Test single action", "[ic3][smt][encoding]")
{
    VariableSpace vspace;
    const auto x = vspace.add_variable("x", BoundedIntType{0, 1});
    // Each state variable is registered under its own name.
    const auto y = vspace.add_variable("y", BoundedIntType{0, 1});

    // guard: x == 0 && y == 0
    LinearConstraintConjunction guard;
    {
        LinearConstraint c(LinearConstraint::EQUAL);
        c.insert(x, 1.0);
        c.rhs = 0.;
        guard &= c;
    }
    {
        LinearConstraint c(LinearConstraint::EQUAL);
        c.insert(y, 1.0);
        c.rhs = 0.;
        guard &= c;
    }

    // Constant expression `1`.
    LinearExpression one;
    one.bias = 1.;

    // One outcome assigning x := 1 and y := 1.
    vector<Action> single_outcome(1);
    single_outcome[0].label = 0;
    single_outcome[0].guard = guard;
    single_outcome[0].outcomes.resize(1);
    single_outcome[0].outcomes.back().assignments.emplace_back(x, one);
    single_outcome[0].outcomes.back().assignments.emplace_back(y, one);

    // Two outcomes: the first assigns x := 1, the second y := 1.
    vector<Action> two_outcomes(1);
    two_outcomes[0].label = 0;
    two_outcomes[0].guard = guard;
    two_outcomes[0].outcomes.resize(2);
    two_outcomes[0].outcomes.front().assignments.emplace_back(x, one);
    two_outcomes[0].outcomes.back().assignments.emplace_back(y, one);

    Z3SMT smt;

    SECTION("Test single outcome without prevail")
    {
        epic3::EncodingInformation info;
        epic3::add_state_variables(smt, info, vspace);
        epic3::encode_transition(
            smt,
            info,
            single_outcome.begin(),
            single_outcome.end());

        auto status = smt.solve();
        REQUIRE(status == SMT::Status::SAT);

        auto model = smt.get_model();
        match_model(model, 0, 1, 0, 1);
    }

    SECTION("Test single outcome with prevail")
    {
        // Drop the y assignment so y has to keep its value.
        single_outcome[0].outcomes.back().assignments.pop_back();

        epic3::EncodingInformation info;
        epic3::add_state_variables(smt, info, vspace);
        epic3::encode_transition(
            smt,
            info,
            single_outcome.begin(),
            single_outcome.end());

        auto status = smt.solve();
        REQUIRE(status == SMT::Status::SAT);

        auto model = smt.get_model();
        match_model(model, 0, 1, 0, 0);
    }

    SECTION("Test two outcomes")
    {
        epic3::EncodingInformation info;
        epic3::add_state_variables(smt, info, vspace);
        epic3::encode_transition(
            smt,
            info,
            two_outcomes.begin(),
            two_outcomes.end());

        // Require y' == 0, which only the x := 1 outcome satisfies.
        smt.add_constraint(
            substitute_vars(guard[1], info.out_vars).as_expression());

        auto status = smt.solve();
        REQUIRE(status == SMT::Status::SAT);

        auto model = smt.get_model();
        match_model(model, 0, 1, 0, 0);
    }
}
#endif

#if POLICE_Z3
// Encodes two actions at once (both guarded by x == 0 && y == 0) and checks
// which successor states the SMT solver can reach. The match_model calls read
// the model in the order (x, x', y, y') -- presumably each state variable
// followed by its successor copy; confirm against add_state_variables.
TEST_CASE("Test multiple actions", "[ic3][smt][encoding]")
{
    VariableSpace vspace;
    const auto x = vspace.add_variable("x", BoundedIntType{0, 1});
    // Each state variable is registered under its own name.
    const auto y = vspace.add_variable("y", BoundedIntType{0, 1});

    // guard: x == 0 && y == 0
    LinearConstraintConjunction guard;
    {
        LinearConstraint c(LinearConstraint::EQUAL);
        c.insert(x, 1.0);
        c.rhs = 0.;
        guard &= c;
    }
    {
        LinearConstraint c(LinearConstraint::EQUAL);
        c.insert(y, 1.0);
        c.rhs = 0.;
        guard &= c;
    }

    // Constant expression `1`.
    LinearExpression one;
    one.bias = 1.;

    // Action with one outcome assigning only x := 1.
    vector<Action> single_outcome(1);
    single_outcome[0].label = 0;
    single_outcome[0].guard = guard;
    single_outcome[0].outcomes.resize(1);
    single_outcome[0].outcomes.back().assignments.emplace_back(x, one);

    // Action with two outcomes: x := 1 or y := 1.
    vector<Action> two_outcomes(1);
    two_outcomes[0].label = 0;
    two_outcomes[0].guard = guard;
    two_outcomes[0].outcomes.resize(2);
    two_outcomes[0].outcomes.front().assignments.emplace_back(x, one);
    two_outcomes[0].outcomes.back().assignments.emplace_back(y, one);

    Z3SMT smt;

    SECTION("Test sat x=1")
    {
        vector<Action> a{single_outcome[0], two_outcomes[0]};
        epic3::EncodingInformation info;
        epic3::add_state_variables(smt, info, vspace);
        epic3::encode_transition(smt, info, a.begin(), a.end());

        // Require y' == 0.
        smt.add_constraint(
            substitute_vars(guard[1], info.out_vars).as_expression());

        auto status = smt.solve();
        REQUIRE(status == SMT::Status::SAT);

        auto model = smt.get_model();
        match_model(model, 0, 1, 0, 0);
    }

    SECTION("Test sat y=1")
    {
        vector<Action> a{single_outcome[0], two_outcomes[0]};
        epic3::EncodingInformation info;
        epic3::add_state_variables(smt, info, vspace);
        epic3::encode_transition(smt, info, a.begin(), a.end());

        // Require x' == 0; only the y := 1 outcome satisfies this.
        smt.add_constraint(
            substitute_vars(guard[0], info.out_vars).as_expression());

        auto status = smt.solve();
        REQUIRE(status == SMT::Status::SAT);

        auto model = smt.get_model();
        match_model(model, 0, 0, 0, 1);
    }

    SECTION("Test unsat")
    {
        vector<Action> a{single_outcome[0], two_outcomes[0]};
        epic3::EncodingInformation info;
        epic3::add_state_variables(smt, info, vspace);
        epic3::encode_transition(smt, info, a.begin(), a.end());

        // Require x' == 1 && y' == 1: no outcome sets both variables.
        // The actions were encoded above, so editing `guard` here only
        // affects the extra constraint.
        guard[0].rhs = 1.;
        guard[1].rhs = 1.;
        smt.add_constraint(
            substitute_vars(guard, info.out_vars).as_expression());

        auto status = smt.solve();
        REQUIRE(status == SMT::Status::UNSAT);
    }
}
#endif

#if POLICE_Z3
// Checks IC3SatInterface::is_blocked on a one-step counter: a single action
// increments x while x <= 9 and y == 0, starting from x == 0, y == 0.
TEST_CASE("Test ic3sat", "[ic3][encoding]")
{
    VariableSpace vspace;
    const auto x = vspace.add_variable("x", BoundedIntType{0, 10});
    const auto y = vspace.add_variable("y", BoundedIntType{0, 1});

    // Builds the single-variable constraint `1.0 * var <op> rhs`.
    const auto unit_constraint = [](auto op, auto var, double rhs) {
        LinearConstraint constraint(op);
        constraint.insert(var, 1.0);
        constraint.rhs = rhs;
        return constraint;
    };

    // guard: x <= 9 && y == 0
    LinearConstraintConjunction guard;
    guard &= unit_constraint(LinearConstraint::LESS_EQUAL, x, 9.);
    guard &= unit_constraint(LinearConstraint::EQUAL, y, 0.);

    // effect: x := x + 1
    LinearExpression increment;
    increment.insert(x, 1.);
    increment.bias = 1.;

    vector<Action> actions(1);
    {
        Action& step = actions.front();
        step.label = 0;
        step.guard = guard;
        step.outcomes.resize(1);
        step.outcomes.back().assignments.emplace_back(x, increment);
    }

    // terminal: y == 1
    LinearCondition terminal;
    terminal |= unit_constraint(LinearConstraint::EQUAL, y, 1.);

    // avoid: x >= 1 (the second section raises the bound)
    LinearCondition avoid;
    avoid |= unit_constraint(LinearConstraint::GREATER_EQUAL, x, 1.);

    epic3::Cube initial_state;
    initial_state.set(0, epic3::Interval(Value(0)));
    initial_state.set(1, epic3::Interval(Value(0)));

    Z3SMTFactory smt;

    // Builds the interface from the current `avoid`, so a section can edit
    // `avoid` before calling it.
    const auto build_interface = [&]() -> epic3::IC3SatInterface {
        return epic3::IC3SatInterface::create(
            smt,
            vspace,
            actions.begin(),
            actions.end(),
            !terminal,
            avoid);
    };

    SECTION("Single step sat")
    {
        epic3::IC3SatInterface sat = build_interface();
        const auto blocked = sat.is_blocked(initial_state, 1);
        REQUIRE(!blocked.first);
    }

    SECTION("Single step unsat")
    {
        avoid[0].back().rhs = 2;
        epic3::IC3SatInterface sat = build_interface();
        const auto blocked = sat.is_blocked(initial_state, 1);
        REQUIRE(blocked.first);
    }
}
#endif

#if POLICE_Z3
// Checks PIC3SatEdgeIndividual::is_blocked for one-step reachability under
// two neural-network policies over two counters x, y in [0, 10]. Edge 0
// increments x while x <= 9, and edge 1 increments y while y <= 9. The
// expected action choices of both policies are pinned by the
// "Test policy selections" section.
TEST_CASE("Test nn", "[ic3][encoding][nn]")
{
    VariableSpace vspace;
    const auto xid = vspace.add_variable("x", BoundedIntType{0, 10});
    const auto yid = vspace.add_variable("y", BoundedIntType{0, 10});
    const expressions::Variable x(xid);
    const expressions::Variable y(yid);
    auto one_zero = x == Value(1) && y == Value(0);
    auto zero_one = x == Value(0) && y == Value(1);
    auto goal = (x == Value(10) && y == Value(10));
    vector<Action> edges;
    {
        auto guard = x <= Value(9);
        Assignment eff(xid, LinearExpression::from_expression(x + Value(1)));
        edges.push_back(Action(
            0,
            LinearCondition::from_expression(guard).front(),
            {Outcome({eff})}));
    }
    {
        auto guard = y <= Value(9);
        Assignment eff(yid, LinearExpression::from_expression(y + Value(1)));
        edges.push_back(Action(
            1,
            LinearCondition::from_expression(guard).front(),
            {Outcome({eff})}));
    }

    FeedForwardNeuralNetwork<> nn;
    nn.input.emplace_back(0, 10, 5, 10);
    nn.input.emplace_back(0, 10, 5, 10);
    nn.layers.push_back(
        FeedForwardNeuralNetwork<>::Layer({{-1., 1.}, {1., -1.}}, {1., 0.}));
    NeuralNetworkPolicy<> policy(std::move(nn), {0, 1}, {0, 1}, 2);

    // Build the second network from scratch instead of refilling the
    // moved-from `nn`.
    FeedForwardNeuralNetwork<> unsafe_nn;
    unsafe_nn.input.emplace_back(0, 10, 5, 10);
    unsafe_nn.input.emplace_back(0, 10, 5, 10);
    unsafe_nn.layers.push_back(
        FeedForwardNeuralNetwork<>::Layer({{1., -1.}, {-1., 1.}}, {0., 0.}));
    unsafe_nn.layers.push_back(
        FeedForwardNeuralNetwork<>::Layer({{1., 0.}, {0., 1.}}, {0., 0.}));
    NeuralNetworkPolicy<> unsafe_policy(
        std::move(unsafe_nn),
        {0, 1},
        {0, 1},
        2);

    std::shared_ptr<Z3SMTFactory> smt = std::make_shared<Z3SMTFactory>();
    std::shared_ptr<NNLPSMTFactory> nnlp =
        std::make_shared<NNLPSMTFactory>(smt.get(), false);

    // Start cube: x == 0, y == 0 (some sections move y to 1).
    epic3::Cube state;
    state.set(0, epic3::Interval(Value(0)));
    state.set(1, epic3::Interval(Value(0)));

    epic3::PIC3SatParameters sat_params{smt, nnlp, false, true, false, false};
    Model model{vspace, edges, {}};

    SECTION("Test policy selections")
    {
        // Named distinctly so it does not shadow the outer Cube `state`.
        flat_state concrete(2);
        concrete[0] = Value(0);
        concrete[1] = Value(0);
        REQUIRE(policy(concrete) == 0u);
        REQUIRE(unsafe_policy(concrete) == 0u);
        concrete[1] = Value(1);
        REQUIRE(policy(concrete) == 0u);
        REQUIRE(unsafe_policy(concrete) == 1u);
    }

    SECTION("Single step sat")
    {
        epic3::PIC3SatEdgeIndividual sat(
            sat_params,
            model,
            policy,
            {},
            LinearCondition::from_expression(one_zero));
        REQUIRE(!sat.is_blocked(state, 1).first);
    }

    SECTION("Single step unsat")
    {
        epic3::PIC3SatEdgeIndividual sat(
            sat_params,
            model,
            policy,
            !LinearCondition::from_expression(goal),
            LinearCondition::from_expression(zero_one));
        REQUIRE(sat.is_blocked(state, 1).first);
    }

    SECTION("Single step sat relu")
    {
        epic3::PIC3SatEdgeIndividual sat(
            sat_params,
            model,
            unsafe_policy,
            !LinearCondition::from_expression(goal),
            LinearCondition::from_expression(
                y >= expressions::MakeConstant()(2.)));
        state.set(1, epic3::Interval(Value(1)));
        REQUIRE(!sat.is_blocked(state, 1).first);
    }

    SECTION("Single step unsat relu")
    {
        epic3::PIC3SatEdgeIndividual sat(
            sat_params,
            model,
            unsafe_policy,
            {},
            LinearCondition::from_expression(one_zero));
        state.set(1, epic3::Interval(Value(1)));
        REQUIRE(sat.is_blocked(state, 1).first);
    }
}
#endif

#if POLICE_Z3 && POLICE_Marabou
// Counterpart of "Test nn" that uses an NNLPBranchNBoundFactory backed by
// Marabou as the NN-LP backend.
// NOTE(review): this case is written against epic3::NormalizedEdge,
// epic3::NormalizedModel, epic3::NeuralNetworkPolicy and raw-pointer
// PIC3SatParameters. "Test nn" instead uses Action, Model,
// NeuralNetworkPolicy and shared_ptr factories. Confirm that this case still
// compiles when POLICE_Marabou is enabled.
TEST_CASE("Test nn b&b", "[ic3][encoding][nn]")
{
    VariableSpace vspace;
    const auto xid = vspace.add_variable("x", BoundedIntType{0, 10});
    const auto yid = vspace.add_variable("y", BoundedIntType{0, 10});
    const expressions::Variable x(xid);
    const expressions::Variable y(yid);
    auto one_zero = x == Value(1) && y == Value(0);
    auto zero_one = x == Value(0) && y == Value(1);
    auto goal = (x == Value(10) && y == Value(10));
    vector<epic3::NormalizedEdge> edges;
    {
        auto guard = x <= Value(9);
        epic3::NormalizedAssignment eff(
            xid,
            expressions::to_linear_expression(x + Value(1)));
        edges.push_back(epic3::NormalizedEdge(
            0,
            expressions::to_linear_condition_normal_form(guard).front(),
            {{eff}}));
    }
    {
        auto guard = y <= Value(9);
        epic3::NormalizedAssignment eff(
            yid,
            expressions::to_linear_expression(y + Value(1)));
        edges.push_back(epic3::NormalizedEdge(
            1,
            expressions::to_linear_condition_normal_form(guard).front(),
            {{eff}}));
    }

    FeedForwardNeuralNetwork<> nn;
    nn.input.emplace_back(0, 10, 5, 10);
    nn.input.emplace_back(0, 10, 5, 10);
    nn.layers.push_back(
        FeedForwardNeuralNetwork<>::Layer({{1., 10.}, {10., 1.}}, {1., 0.}));
    epic3::NeuralNetworkPolicy<> policy(std::move(nn), {0, 1}, {0, 1}, 2);

    // Build the second network from scratch instead of refilling the
    // moved-from `nn`.
    FeedForwardNeuralNetwork<> unsafe_nn;
    unsafe_nn.input.emplace_back(0, 10, 5, 10);
    unsafe_nn.input.emplace_back(0, 10, 5, 10);
    unsafe_nn.layers.push_back(
        FeedForwardNeuralNetwork<>::Layer({{1., -1.}, {-1., 1.}}, {0., 0.}));
    unsafe_nn.layers.push_back(
        FeedForwardNeuralNetwork<>::Layer({{1., 0.}, {0., 1.}}, {0., 0.}));
    epic3::NeuralNetworkPolicy<> unsafe_policy(
        std::move(unsafe_nn),
        {0, 1},
        {0, 1},
        2);

    Z3SMTFactory smt;
    MarabouLPFactory marabou;
    NNLPSMTFactory nnlp_smt(&smt, false);
    NNLPBranchNBoundFactory nnlp(&marabou);

    // Start cube: x == 0, y == 0 (the last section moves y to 1).
    epic3::Cube state(2);
    state[0] = Value(0);
    state[1] = Value(0);

    epic3::PIC3SatParameters
        sat_params{&smt, &nnlp, false, false, false, false};
    epic3::NormalizedModel
        model{vspace, expressions::MakeConstant()(true), {}, edges, nullptr};

    SECTION("Test policy selections")
    {
        // Named distinctly so it does not shadow the outer Cube `state`.
        flat_state concrete(2);
        concrete[0] = Value(0);
        concrete[1] = Value(0);
        REQUIRE(policy(concrete) == 0u);
        REQUIRE(unsafe_policy(concrete) == 0u);
        concrete[1] = Value(1);
        REQUIRE(policy(concrete) == 0u);
        REQUIRE(unsafe_policy(concrete) == 1u);
    }

    SECTION("Single step sat")
    {
        epic3::PIC3SatEdgeIndividual sat(
            sat_params,
            model,
            policy,
            {},
            expressions::to_linear_condition_normal_form(one_zero));
        REQUIRE(!sat.is_blocked(state, 1).first);
    }

    SECTION("Single step unsat")
    {
        epic3::PIC3SatEdgeIndividual sat(
            sat_params,
            model,
            policy,
            !expressions::to_linear_condition_normal_form(goal),
            expressions::to_linear_condition_normal_form(zero_one));
        REQUIRE(sat.is_blocked(state, 1).first);
    }

    SECTION("Single step unsat relu")
    {
        epic3::PIC3SatEdgeIndividual sat(
            sat_params,
            model,
            unsafe_policy,
            !expressions::to_linear_condition_normal_form(goal),
            expressions::to_linear_condition_normal_form(zero_one));
        REQUIRE(sat.is_blocked(state, 1).first);
    }

    // This section expects a reachable state (not blocked), so it is named
    // "sat". It previously duplicated the sibling name "Single step unsat relu".
    SECTION("Single step sat relu")
    {
        epic3::PIC3SatEdgeIndividual sat(
            sat_params,
            model,
            unsafe_policy,
            {},
            expressions::to_linear_condition_normal_form(x + Value(2) <= y));
        state[1] = Value(1);
        REQUIRE(!sat.is_blocked(state, 1).first);
    }
}
#endif
