#include "police/verifiers/sampling/sampler_unit.hpp"

#include "police/addtree_policy.hpp"
#include "police/cg_policy.hpp"
#include "police/expressions/expression_evaluator.hpp"
#include "police/nn_policy.hpp"
#include "police/option.hpp"
#include "police/option_parser.hpp"
#include "police/smt_factory.hpp"
#include "police/successor_generator/determinize_outcomes_adapator.hpp"
#include "police/successor_generator/prune_successors.hpp"
#include "police/successor_generator/successor_generators.hpp"
#include "police/utils/io.hpp"
#include "police/utils/rng.hpp"
#include "police/verifiers/sampling/sampler.hpp"

#include <type_traits>

namespace police::sampler {

// Constructs a sampling execution unit.
//
// model               supplies the variables registered with the solver that
//                     describes the initial states.
// property            supplies the start condition (added as a solver
//                     constraint) and the avoid / reach expressions used by
//                     run().
// smt_factory         creates the solver instance stored in initial_states_.
// successor_generator is moved into the unit.
// num_simulations,
// max_length          are stored and later forwarded to the Sampler in run().
// rng                 is stored and dereferenced in run().
template <typename SuccessorGenerator>
SamplerUnit<SuccessorGenerator>::SamplerUnit(
    const Model& model,
    const VerificationProperty& property,
    const SMTFactory& smt_factory,
    SuccessorGenerator successor_generator,
    size_t num_simulations,
    size_t max_length,
    std::shared_ptr<RNG> rng)
    : successor_generator_(std::move(successor_generator))
    , initial_states_(smt_factory.make_shared())
    , avoid_(property.avoid)
    , goal_(property.reach)
    , rng_(std::move(rng))
    , num_simulations_(num_simulations)
    , max_length_(max_length)
{
    // Register the model variables before adding the start condition that
    // refers to them.
    auto& start_states = *initial_states_;
    start_states.add_variables(model.variables);
    start_states.add_constraint(property.start);
}

// Returns the fixed identifier of this execution unit.
template <typename SuccessorGenerator>
std::string_view SamplerUnit<SuccessorGenerator>::name() const
{
    return std::string_view{"sample"};
}

// Runs the sampler and stores its outcome in result_.
//
// A Sampler is built from the initial-state solver, the successor generator,
// predicates for the avoid and goal expressions, the RNG, and the configured
// number of simulations / maximum length; it is then invoked once.
template <typename SuccessorGenerator>
void SamplerUnit<SuccessorGenerator>::run()
{
    // Turns a state into a "variable index -> value" accessor, which is the
    // form the expression evaluator is called with below. The accessor only
    // references the state, so it must not outlive it.
    auto values_of = [](const flat_state& s) {
        return [&s](size_t var) { return s[var]; };
    };

    auto is_avoid = [this, &values_of](const flat_state& s) {
        // NOTE(review): every state handed to the avoid check is printed to
        // stdout (and flushed) -- looks like tracing output; confirm it is
        // intended.
        std::cout << print_sequence(s) << std::endl;
        return expressions::evaluate(avoid_, values_of(s));
    };

    auto is_goal = [this, &values_of](const flat_state& s) {
        return expressions::evaluate(goal_, values_of(s));
    };

    auto simulator = Sampler(
        initial_states_,
        successor_generator_,
        is_avoid,
        is_goal,
        *rng_,
        num_simulations_,
        max_length_);

    result_ = simulator();
}

// Writes the stored sampling result to stdout and flushes the stream.
template <typename SuccessorGenerator>
void SamplerUnit<SuccessorGenerator>::report_result()
{
    std::cout << result_;
    std::cout.flush();
}

// Statistics reporting hook. This unit emits nothing here.
template <typename SuccessorGenerator>
void SamplerUnit<SuccessorGenerator>::report_statistics()
{
    // No statistics are reported by the sampling unit.
}

namespace {
// Wraps a raw successor generator for use by the sampler.
//
// base_generator is moved into a successor_generator::PruneSuccessors
// together with a predicate that evaluates the property's reach expression
// on a state; the result is passed through
// successor_generator::determinize_outcomes_adapator and returned.
// (Presumably the predicate marks terminal states whose successors are
// pruned -- confirm against PruneSuccessors.)
//
// context provides the verification property via get_property().
//
// NOTE(review): the returned generator keeps a *reference* to the reach
// expression obtained from context.get_property(); that property must
// outlive the generator -- confirm against Arguments.
template <typename BaseGenerator>
auto construct_successor_generator(
    BaseGenerator base_generator,
    const Arguments& context)
{
    // Turns a state into a "variable index -> value" accessor for the
    // expression evaluator.
    auto value_getter = [](const flat_state& state) {
        return [&state](size_t var) { return (state[var]); };
    };
    // value_getter is captured BY VALUE: the predicate built here is stored
    // in the returned generator and therefore outlives this function, so a
    // reference to the local lambda would dangle.
    auto condition_evaluator =
        [value_getter](const expressions::Expression& cond) {
            return [value_getter, &cond](const flat_state& state) {
                return static_cast<bool>(
                    expressions::evaluate(cond, value_getter(state)));
            };
        };
    auto term_evaluator = condition_evaluator(context.get_property().reach);
    return successor_generator::determinize_outcomes_adapator(
        successor_generator::PruneSuccessors(
            std::move(base_generator),
            term_evaluator));
}

// Registers the "sample" execution unit.
//
// The first callback builds a SamplerUnit from the parsed arguments, choosing
// the successor generator according to which policy (neural network,
// add-tree, CG, or none) is present and whether applicability masking is
// enabled. The second callback declares the unit's arguments: "smt_solver"
// (default "z3"), "samples" (default "1000") and "max_length" (default
// "1000").
PointerOption<ExecutionUnit> _option(
    "sample",
    [](const Arguments& args) -> std::shared_ptr<ExecutionUnit> {
        const auto& model = args.get_model();

        // Wraps a raw successor generator and builds the sampler unit
        // around it.
        auto make_unit =
            [&](auto raw_generator) -> std::shared_ptr<ExecutionUnit> {
            auto unit = SamplerUnit(
                args.get_model(),
                args.get_property(),
                *args.get_ptr<SMTFactory>("smt_solver"),
                construct_successor_generator(
                    std::move(raw_generator),
                    args),
                args.get<int>("samples"),
                args.get<int>("max_length"),
                args.get_rng());
            using Unit = std::remove_cv_t<decltype(unit)>;
            return std::make_shared<Unit>(std::move(unit));
        };

        // Shared by all policy kinds: picks the masked or unmasked
        // policy-driven successor generator.
        auto make_for_policy =
            [&](const auto& policy) -> std::shared_ptr<ExecutionUnit> {
            if (args.applicability_masking) {
                return make_unit(
                    successor_generator::MaskedPolicySuccessorGenerator(
                        &model.actions,
                        &policy));
            }
            return make_unit(
                successor_generator::PolicySuccessorGenerator(
                    &model.actions,
                    &policy));
        };

        if (args.has_nn_policy()) {
            const NeuralNetworkPolicy& policy = args.get_nn_policy();
            return make_for_policy(policy);
        }
        if (args.has_addtree_policy()) {
            const AddTreePolicy& policy = args.get_addtree_policy();
            return make_for_policy(policy);
        }
        if (args.has_cg_policy()) {
            const CGPolicy& policy = args.get_cg_policy();
            return make_for_policy(policy);
        }

        // No policy at all: sample over the plain action set.
        assert(!args.has_policy());
        return make_unit(
            successor_generator::BaseSuccessorGenerator(&model.actions));
    },
    [](ArgumentsDefinition& defs) {
        defs.add_ptr_argument<SMTFactory>("smt_solver", "", "z3");
        defs.add_argument<int>("samples", "", "1000");
        defs.add_argument<int>("max_length", "", "1000");
    });
} // namespace

} // namespace police::sampler
