#pragma once

#include "../numbers.hpp"
#include "police/base_types.hpp"
#include "police/sat_model.hpp"
#include "police/storage/variable_space.hpp"

#include <catch2/catch.hpp>
#include <cmath>
#include <type_traits>
#include <variant>

/**
 * Checks that every variable of `vspace` takes a value inside its declared
 * domain in `model`.
 *
 * For each variable index `var` in `[0, vspace.size())`, the value at solution
 * index `get_var_id(var)` is read from `model`, converted to `police::real_t`,
 * and checked against the variable's type:
 *  - `BoolType`: the value must lie in [0, 1], with a tolerance of `PRECISION`;
 *  - `BoundedIntType` / `BoundedRealType`: the value must satisfy whichever of
 *    the lower / upper bounds is set, via the `IsAtLeast` / `IsAtMost`
 *    matchers (not defined in this header — presumably from ../numbers.hpp);
 *  - any other type: not checked.
 * Failures are reported as non-fatal Catch2 `CHECK`s, annotated with the
 * offending variable index and solution index.
 *
 * @param model       solver model whose values are inspected
 * @param vspace      variables whose domains are checked
 * @param get_var_id  callable mapping a variable index to the index of that
 *                    variable's value in `model`
 */
template <typename GetVarId>
void check_variable_domains(
    const police::SATModel& model,
    const police::VariableSpace& vspace,
    GetVarId&& get_var_id)
{
    // Use the project's size type: `0u` deduces `unsigned int`, which may be
    // narrower than the type returned by `vspace.size()`.
    for (police::size_t var = 0; var < vspace.size(); ++var) {
        const police::size_t sol_idx = get_var_id(var);
        const police::real_t value =
            static_cast<police::real_t>(model.get_value(sol_idx));
        // Scoped to this iteration: attached to any failing CHECK below so the
        // report identifies which variable is out of its domain.
        INFO("variable " << var << " (solution index " << sol_idx << ")");
        std::visit(
            [&](auto&& t) {
                using T = std::decay_t<decltype(t)>;
                if constexpr (std::is_same_v<T, police::BoolType>) {
                    CHECK(value + PRECISION >= 0.);
                    CHECK(value - PRECISION <= 1.);
                } else if constexpr (
                    std::is_same_v<T, police::BoundedIntType> ||
                    std::is_same_v<T, police::BoundedRealType>) {
                    if (t.is_lower_bounded()) {
                        CHECK_THAT(value, IsAtLeast(t.lower_bound));
                    }
                    if (t.is_upper_bounded()) {
                        CHECK_THAT(value, IsAtMost(t.upper_bound));
                    }
                }
            },
            vspace[var].type);
    }
}

/**
 * Checks that every integral variable of `vspace` takes an (approximately)
 * integral value in `model`.
 *
 * For each variable index `var` in `[0, vspace.size())`, the value at solution
 * index `get_var_id(var)` is read from `model` and converted to
 * `police::real_t`. For `BoolType`, `BoundedIntType` and `IntegerType`
 * variables, the value must be within `PRECISION` of `std::round(value)`.
 * Variables of any other type are not checked. Failures are reported as
 * non-fatal Catch2 `CHECK_THAT`s, annotated with the offending variable index
 * and solution index.
 *
 * @param model       solver model whose values are inspected
 * @param vspace      variables whose integrality is checked
 * @param get_var_id  callable mapping a variable index to the index of that
 *                    variable's value in `model`
 */
template <typename GetVarId>
void check_variable_types(
    const police::SATModel& model,
    const police::VariableSpace& vspace,
    GetVarId&& get_var_id)
{
    // Use the project's size type: `0u` deduces `unsigned int`, which may be
    // narrower than the type returned by `vspace.size()`.
    for (police::size_t var = 0; var < vspace.size(); ++var) {
        const police::size_t sol_idx = get_var_id(var);
        const police::real_t value =
            static_cast<police::real_t>(model.get_value(sol_idx));
        // Scoped to this iteration: attached to any failing check below so the
        // report identifies which variable has a non-integral value.
        INFO("variable " << var << " (solution index " << sol_idx << ")");
        std::visit(
            [&](auto&& t) {
                using T = std::decay_t<decltype(t)>;
                if constexpr (
                    std::is_same_v<T, police::BoolType> ||
                    std::is_same_v<T, police::BoundedIntType> ||
                    std::is_same_v<T, police::IntegerType>) {
                    const police::real_t rounded = std::round(value);
                    CHECK_THAT(
                        value,
                        Catch::Matchers::WithinAbs(rounded, PRECISION));
                }
            },
            vspace[var].type);
    }
}

/**
 * Overload of check_variable_domains without an index-mapping callable.
 * Defined out of line (not in this header). Presumably it treats each
 * variable index as its own solution index — TODO confirm against the
 * definition.
 */
void check_variable_domains(
    const police::SATModel& model,
    const police::VariableSpace& vspace);

/**
 * Overload of check_variable_types without an index-mapping callable.
 * Defined out of line (not in this header). Presumably it treats each
 * variable index as its own solution index — TODO confirm against the
 * definition.
 */
void check_variable_types(
    const police::SATModel& model,
    const police::VariableSpace& vspace);
