#include "helpers.hpp"
#include "police/expressions/binary_function_call.hpp"
#include "police/expressions/boolean_combination.hpp"
#include "police/expressions/comparison.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/function_call.hpp"
#include "police/expressions/ifthenelse.hpp"
#include "police/expressions/negation.hpp"
#include "police/expressions/numeric_operation.hpp"
#include "police/jani/parser/expression.hpp"
#include "police/jani/parser/language.hpp"

#include <catch2/catch.hpp>
#include <nlohmann/json.hpp>
#include <string_view>

namespace {

// Short, file-local aliases used by the test cases in this file.
namespace ex = police::expressions;
namespace lang = police::jani::parser::lang;
using json = nlohmann::json;

// Shared test utilities; presumably this is where test_constant,
// make_operation, test_expr_parser, require_is_* and cast come from
// (helpers.hpp) — confirm against the header.
using namespace police::test;

} // namespace

// Integer literal case: the actual parse-and-compare logic is delegated to
// the test_constant helper (NOTE(review): defined outside this file,
// presumably in helpers.hpp — confirm what it asserts).
TEST_CASE("Parse constant integer expression", "[jani][parser][expression]")
{
    test_constant(123);
}

// Boolean literal case, delegated to the test_constant helper (defined
// outside this file).
TEST_CASE("Parse constant bool expression", "[jani][parser][expression]")
{
    test_constant(true);
}

// Real-valued (double) literal case, delegated to the test_constant helper
// (defined outside this file).
TEST_CASE("Parse constant real expression", "[jani][parser][expression]")
{
    test_constant(2.0);
}

// A bare JSON string must parse successfully and come back as an identifier
// expression carrying that same name.
TEST_CASE("Parse identifier expression", "[jani][parser][expression]")
{
    const std::string var_name = "varname";
    json input = var_name;

    // NOTE(review): the meaning of the `true` flag is defined by
    // expression_schema's call operator — confirm there.
    auto parsed = police::jani::parser::expression_schema()(input, true);

    REQUIRE(parsed.has_value());
    require_is_identifier(parsed.value(), var_name);
}

// A binary AND over two identifiers must parse into an ex::Conjunction whose
// two children are those identifiers, in order.
TEST_CASE("Parse conjunction expression", "[jani][parser][expression]")
{
    // Use the file-local `lang` alias, like the other test cases.
    json js = make_operation(lang::AND, "x", "y");
    auto conj = test_expr_parser<ex::Conjunction>(js);
    REQUIRE(conj->children.size() == 2u);
    require_is_identifier(conj->children[0], "x");
    require_is_identifier(conj->children[1], "y");
}

// A unary NEGATE over an identifier must parse into an ex::Negation wrapping
// that identifier.
TEST_CASE("Parse negation expression", "[jani][parser][expression]")
{
    json js = make_unary_operation(lang::NEGATE, "x");
    auto neg = test_expr_parser<ex::Negation>(js);
    require_is_identifier(neg->expr, "x");
}

// An ITE operation object (op/if/then/else keys) must parse into an
// ex::IfThenElse with condition, consequence and alternative mapped from the
// "if", "then" and "else" entries respectively.
TEST_CASE("Parse if-then-else expression", "[jani][parser][expression]")
{
    // Build the JSON object by hand since it has three named operands rather
    // than the left/right pair that make_operation produces.
    json js;
    js[lang::OP] = lang::ITE;
    js[lang::IF] = "x";
    js[lang::THEN] = "y";
    js[lang::ELSE] = "z";
    auto expr = test_expr_parser<ex::IfThenElse>(js);
    require_is_identifier(expr->condition, "x");
    require_is_identifier(expr->consequence, "y");
    require_is_identifier(expr->alternative, "z");
}

// LESS_EQUAL over two identifiers must parse into an ex::Comparison with the
// LESS_EQUAL operator and the operands kept in left/right order.
TEST_CASE("Parse comparison", "[jani][parser][expression]")
{
    json js = make_operation(lang::LESS_EQUAL, "x", "y");
    auto expr = test_expr_parser<ex::Comparison>(js);
    REQUIRE(expr->op == ex::Comparison::Operator::LESS_EQUAL);
    require_is_identifier(expr->left, "x");
    require_is_identifier(expr->right, "y");
}

// PLUS over two identifiers must parse into an ex::NumericOperation with the
// ADD operand and the operands kept in left/right order.
TEST_CASE("Parse sum", "[jani][parser][expression]")
{
    json js = make_operation(lang::PLUS, "x", "y");
    auto expr = test_expr_parser<ex::NumericOperation>(js);
    REQUIRE(expr->operand == ex::NumericOperation::Operand::ADD);
    require_is_identifier(expr->left, "x");
    require_is_identifier(expr->right, "y");
}

// Parses (x + (y * 2)) != 10 and checks that the nested tree shape survives:
// the comparison's left side is a numeric ADD whose right child is itself a
// numeric operation over y and 2, and its right side is the constant 10.
TEST_CASE("Parse nested linear constraint", "[jani][parser][expression]")
{
    json js = make_operation(
        lang::NOT_EQUAL,
        make_operation(lang::PLUS, "x", make_operation(lang::TIMES, "y", 2)),
        10);
    auto expr = test_expr_parser<ex::Comparison>(js);
    // TODO(review): also assert expr->op is the not-equal operator once the
    // ex::Comparison::Operator enumerator name is confirmed.

    auto linexp = cast<ex::NumericOperation>(expr->left.base());
    // Without this, a parser that mapped "+" to the wrong operand would
    // still pass, because only the leaves were checked.
    REQUIRE(linexp->operand == ex::NumericOperation::Operand::ADD);
    require_is_identifier(linexp->left, "x");

    auto linexp2 = cast<ex::NumericOperation>(linexp->right.base());
    // TODO(review): also assert linexp2->operand is the multiplication
    // operand once its ex::NumericOperation::Operand enumerator name is
    // confirmed.
    require_is_identifier(linexp2->left, "y");
    require_is_constant(linexp2->right, 2);

    require_is_constant(expr->right, 10);
}

// A unary CEIL over an identifier must parse into an ex::FunctionCall tagged
// with the CEIL function and wrapping that identifier.
TEST_CASE("Parse ceil", "[jani][parser][expression]")
{
    json js = make_unary_operation(lang::CEIL, "x");
    auto expr = test_expr_parser<ex::FunctionCall>(js);
    REQUIRE(expr->function == ex::FunctionCall::Function::CEIL);
    require_is_identifier(expr->expr, "x");
}

// A binary POW must parse into an ex::BinaryFunctionCall tagged with POW,
// with base and exponent kept in left/right order.
TEST_CASE("Parse pow", "[jani][parser][expression]")
{
    json js = make_operation(lang::POW, "x", 2);
    auto expr = test_expr_parser<ex::BinaryFunctionCall>(js);
    REQUIRE(expr->function == ex::BinaryFunctionCall::Function::POW);
    require_is_identifier(expr->left, "x");
    require_is_constant(expr->right, 2);
}
