#include "linex.hpp"
#include "numbers.hpp"

#include <algorithm>

#include <catch2/catch.hpp>

// Lets the helpers below use names such as real_t, IsAtMost, IsAtLeast and
// PRECISION without qualification (presumably declared in namespace police
// via linex.hpp / numbers.hpp — confirm).
namespace {
using namespace police;
}

/// Returns whether `model` satisfies the linear constraint `constraint`,
/// i.e. whether sum(coef * value(var)) relates to `constraint.rhs` as
/// required by `constraint.type`. Reports nothing to Catch.
///
/// Uses the same matchers (IsAtMost / IsAtLeast / WithinAbs with PRECISION)
/// as check_constraint(), so both helpers agree on boundary cases.
bool is_satisfied(
    const police::LinearConstraint& constraint,
    const police::SATModel& model)
{
    // Evaluate the left-hand side under the model. Bind the (var, coef)
    // entries by const reference so they are not copied per iteration.
    real_t lhs = 0.;
    for (const auto& [var, coef] : constraint) {
        lhs += static_cast<real_t>(model.get_value(var)) * coef;
    }
    switch (constraint.type) {
    case (police::LinearConstraint::LESS_EQUAL):
        return IsAtMost(constraint.rhs).match(lhs);
    case (police::LinearConstraint::GREATER_EQUAL):
        return IsAtLeast(constraint.rhs).match(lhs);
    case (police::LinearConstraint::EQUAL):
        return Catch::Matchers::WithinAbs(constraint.rhs, PRECISION).match(lhs);
    }
    // Only reached if `type` is none of the cases handled above.
    POLICE_UNREACHABLE();
}

/// Returns whether `model` satisfies the ReLU constraint y = max(x, 0),
/// with the equality checked up to an absolute tolerance of PRECISION.
/// Reports nothing to Catch.
bool is_satisfied(
    const police::ReluConstraint& constraint,
    const police::SATModel& model)
{
    const real_t x = static_cast<real_t>(model.get_value(constraint.x));
    const real_t y = static_cast<real_t>(model.get_value(constraint.y));
    // Spell zero as real_t so std::max deduces one type regardless of what
    // real_t is; std::max(x, 0.) only compiles when real_t is double.
    const real_t expected = std::max(x, static_cast<real_t>(0));
    return Catch::Matchers::WithinAbs(expected, PRECISION).match(y);
}

/// Returns whether at least one disjunct of `constraint` is satisfied by
/// `model`. An empty disjunction is never satisfied, because std::any_of on
/// an empty range yields false. Reports nothing to Catch.
bool is_satisfied(
    const police::LinearConstraintDisjunction& constraint,
    const police::SATModel& model)
{
    // The lambda parameter is named `disjunct` so it does not shadow the
    // function parameter `constraint`.
    return std::any_of(
        constraint.begin(),
        constraint.end(),
        [&model](const auto& disjunct) {
            return is_satisfied(disjunct, model);
        });
}

/// Catch assertion counterpart of is_satisfied(const LinearConstraint&, ...).
/// If `model` violates `constraint`, it records a non-fatal CHECK_THAT
/// failure; the test case keeps running and the report shows the computed
/// left-hand side.
void check_constraint(
    const police::LinearConstraint& constraint,
    const police::SATModel& model)
{
    // Evaluate the left-hand side under the model. Bind the (var, coef)
    // entries by const reference so they are not copied per iteration.
    real_t lhs = 0.;
    for (const auto& [var, coef] : constraint) {
        lhs += static_cast<real_t>(model.get_value(var)) * coef;
    }
    switch (constraint.type) {
    case (police::LinearConstraint::LESS_EQUAL):
        CHECK_THAT(lhs, IsAtMost(constraint.rhs));
        break;
    case (police::LinearConstraint::GREATER_EQUAL):
        CHECK_THAT(lhs, IsAtLeast(constraint.rhs));
        break;
    case (police::LinearConstraint::EQUAL):
        CHECK_THAT(lhs, Catch::Matchers::WithinAbs(constraint.rhs, PRECISION));
        break;
    }
}

/// Catch assertion counterpart of is_satisfied(const ReluConstraint&, ...).
/// Records a non-fatal CHECK_THAT failure unless y equals max(x, 0) within
/// PRECISION under `model`.
void check_constraint(
    const police::ReluConstraint& constraint,
    const police::SATModel& model)
{
    const real_t x = static_cast<real_t>(model.get_value(constraint.x));
    const real_t y = static_cast<real_t>(model.get_value(constraint.y));
    // Spell zero as real_t so std::max deduces one type regardless of what
    // real_t is; std::max(x, 0.) only compiles when real_t is double.
    const real_t expected = std::max(x, static_cast<real_t>(0));
    CHECK_THAT(y, Catch::Matchers::WithinAbs(expected, PRECISION));
}

/// Catch assertion counterpart of
/// is_satisfied(const LinearConstraintDisjunction&, ...). Records a
/// non-fatal CHECK failure if no disjunct of `constraint` is satisfied by
/// `model`. This includes the case of an empty disjunction.
void check_constraint(
    const police::LinearConstraintDisjunction& constraint,
    const police::SATModel& model)
{
    // Delegate to the is_satisfied overload so that both helpers share a
    // single definition of disjunction satisfaction.
    CHECK(is_satisfied(constraint, model));
}
