#pragma once

#include "police/linear_condition.hpp"
#include "police/nn_policy.hpp"
#include "police/nnlp.hpp"
#include "police/nnlp_factory.hpp"
#include "police/verifiers/ic3/cube.hpp"
#include "police/verifiers/ic3/syntactic/applicability_conditioner.hpp"
#include "police/verifiers/ic3/syntactic/minimization_limits.hpp"
#include "police/verifiers/ic3/syntactic/policy_reasoner.hpp"
#include "police/verifiers/ic3/syntactic/sufficient_condition.hpp"
#include "police/verifiers/ic3/syntactic/variable_classification.hpp"
#include <type_traits>

namespace police::ic3::syntactic {

/// Common base class of the NNLP-backed policy reasoners declared in this
/// header.
///
/// It bundles the data members shared by the concrete reasoners and declares
/// the templated `compute_reason` helper they can call. Instances can only be
/// created through derived classes because the constructor is protected.
class PolicyReasonerNNLPBase : public PolicyReasoner {
public:
    /// Reports whether `action` can be selected with respect to `cube`.
    ///
    /// NOTE(review): the precise semantics live in the out-of-line
    /// definition, which is not visible from this header.
    [[nodiscard]] bool can_be_selected(const Cube& cube, size_t action);

protected:
    /// Constructor reserved for derived reasoners.
    ///
    /// @param nnlp_factory          factory handed in by the caller
    ///                              (presumably used to populate `nnlps_` --
    ///                              confirm against the definition).
    /// @param model                 model pointer (presumably stored in
    ///                              `model_` -- confirm).
    /// @param policy                policy pointer (presumably stored in
    ///                              `policy_` -- confirm).
    /// @param var_order             variable ordering supplied by the caller.
    /// @param total_time_limit      overall time limit (presumably feeds
    ///                              `limits_`; unit not visible here).
    /// @param individual_time_limit per-step time limit (presumably feeds
    ///                              `limits_`; unit not visible here).
    /// @param applicability_filter  flag chosen by the concrete subclass.
    PolicyReasonerNNLPBase(
        NNLPFactory& nnlp_factory,
        const Model* model,
        const NeuralNetworkPolicy* policy,
        const vector<size_t>& var_order,
        double total_time_limit,
        double individual_time_limit,
        bool applicability_filter);

    /// Shared implementation hook for the subclasses' reason computation.
    ///
    /// The first, unnamed parameter is a compile-time tag: its type selects
    /// the `MaskedApplicable` instantiation (`std::bool_constant<B>` is the
    /// standard alias for `std::integral_constant<bool, B>`).
    ///
    /// @param state      state for which a reason is requested.
    /// @param guard      conjunction of linear constraints for the action.
    /// @param action     index of the action under consideration.
    /// @param applicable optional conditioner; defaults to `nullptr`.
    /// @return the computed alternatives of sufficient conditions.
    template <bool MaskedApplicable>
    SuffCondAlternatives compute_reason(
        std::bool_constant<MaskedApplicable>,
        const flat_state& state,
        const LinearConstraintConjunction& guard,
        size_t action,
        ApplicabilityConditioner* applicable = nullptr);

    // NOTE: the declaration order of the data members below is significant
    // (it fixes their initialization order) and must not be shuffled.
    vector<size_t> input_vars_;
    vector<size_t> relaxation_vars_;
    vector<std::shared_ptr<NNLP>> nnlps_;
    vector<VariableCategory> var_class_;

    MinimizationLimits limits_;

#ifndef POLICE_NO_STATISTICS
    // Only present when statistics are compiled in; starts out empty.
    std::shared_ptr<std::ostream> stats_file_{nullptr};
#endif

    // Raw pointers to the model and the policy; nothing in this header
    // releases them.
    const Model* model_;
    const NeuralNetworkPolicy* policy_;
};

/// Concrete reasoner that constructs the shared base with the applicability
/// filter switched off.
class PolicyReasonerNNLP final : public PolicyReasonerNNLPBase {
public:
    /// Forwards every argument unchanged to the base class and passes `false`
    /// for its `applicability_filter` parameter.
    PolicyReasonerNNLP(
        NNLPFactory& nnlp_factory,
        const Model* model,
        const NeuralNetworkPolicy* policy,
        const vector<size_t>& var_order,
        double total_time_limit,
        double individual_time_limit)
        : PolicyReasonerNNLPBase(
              nnlp_factory, model, policy, var_order,
              total_time_limit, individual_time_limit,
              /*applicability_filter=*/false)
    {
    }

    /// Computes the sufficient-condition alternatives for `action` in `state`
    /// under `guard`. Defined out of line.
    [[nodiscard]] SuffCondAlternatives get_reason(
        const flat_state& state,
        const LinearConstraintConjunction& guard,
        size_t action) override;
};

class PolicyReasonerNNLPMasked final : public PolicyReasonerNNLPBase {
public:
    PolicyReasonerNNLPMasked(
        NNLPFactory& nnlp_factory,
        const Model* model,
        const NeuralNetworkPolicy* policy,
        const vector<size_t>& var_order,
        double total_time_limit,
        double individual_time_limit)
        : PolicyReasonerNNLPBase(
              nnlp_factory,
              model,
              policy,
              var_order,
              total_time_limit,
              individual_time_limit,
              true)
    {
    }

    void prepare(const flat_state& state) override
    {
        info_ = ApplicabilityInformation(*model_, state);
    }

    [[nodiscard]]
    SuffCondAlternatives get_reason(
        const flat_state& state,
        const LinearConstraintConjunction& guard,
        size_t action) override;

private:
    ApplicabilityInformation info_;
};

} // namespace police::ic3::syntactic
