#include "police/verifiers/search/search_unit.hpp"

#include "police/action.hpp"
#include "police/addtree_policy.hpp"
#include "police/arguments.hpp"
#include "police/cg_policy.hpp"
#include "police/expressions/expression_evaluator.hpp"
#include "police/model.hpp"
#include "police/nn_policy.hpp"
#include "police/option.hpp"
#include "police/option_parser.hpp"
#include "police/smt_factory.hpp"
#include "police/smt_model_enumerator.hpp"
#include "police/storage/state_registry.hpp"
#include "police/successor_generator/compress_state_adaptor.hpp"
#include "police/successor_generator/determinize_outcomes_adapator.hpp"
#include "police/successor_generator/prune_successors.hpp"
#include "police/successor_generator/successor_generators.hpp"
#include "police/utils/graphviz.hpp"
#include "police/utils/ranges.hpp"
#include "police/verification_property.hpp"
#include "police/verifiers/search/best_first_search.hpp"
#include "police/verifiers/search/novelty_queue.hpp"
#include "police/verifiers/search/state_novelty.hpp"

#include <cassert>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>

namespace police::search {

namespace {

/// Per-state callback that deliberately does nothing; handed to the search
/// factories in `run()` where a state hook is required but unused.
struct Noop {
    void operator()([[maybe_unused]] const CompressedState& state) const
    {
        // intentionally empty
    }
};

} // namespace

template <typename _SuccessorGenerator>
SearchUnit<_SuccessorGenerator>::SearchUnit(
    SearchConfig config,
    const VariableSpace* variables,
    const vector<identifier_name_t>* labels,
    const VerificationProperty* property,
    std::shared_ptr<StateRegistry> state_registry,
    _SuccessorGenerator succ_gen)
    : variables_(variables)
    , labels_(labels)
    , property_(property)
    , state_registry_(std::move(state_registry))
    , succ_gen_(std::move(succ_gen))
    , conf_(std::move(config))
{
}

/// Returns the fixed identifier of this execution unit.
template <typename _SuccessorGenerator>
std::string_view SearchUnit<_SuccessorGenerator>::name() const
{
    constexpr std::string_view unit_name{"search"};
    return unit_name;
}

namespace {

/// Enumerates the models of `smt` with SMTModelEnumerator, registers each one
/// as a state in `state_registry` and invokes `fun` on every state that was
/// not registered before.
///
/// The bool returned by `fun` is handed back to the enumerator; models whose
/// state is already registered yield `false` without calling `fun`.
template <typename Fun>
void foreach_state(
    SMT& smt,
    StateRegistry& state_registry,
    const VariableSpace& variables,
    Fun fun)
{
    SMTModelEnumerator enumerate_models;
    enumerate_models(&smt, [&](const SMT::model_type& model) {
        // The flat state only lives inside this immediately-invoked lambda,
        // so it is released before `fun` runs.
        const auto registration = [&]() {
            FlatState flat(model.size());
            // Fill variables from the highest index down to 0.
            for (size_t idx = model.size(); idx-- > 0;) {
                flat[idx] = Value(
                    model.get_value(idx),
                    variables[idx].type.value_type());
            }
            return state_registry.insert(flat);
        }();
        if (!registration.second) {
            return false;
        }
        const size_t id = registration.first;
        CompressedState compressed{id, state_registry[id]};
        return fun(compressed);
    });
}

template <typename Search>
struct SearchWrapper {

    bool operator()(const CompressedState& state)
    {
        auto res = search->operator()(state);
        if (res.has_value()) {
            *id = state.get_state_id();
            *ce_found = true;
            *ce = *res;
            return true;
        }
        return false;
    }

    SearchWrapper(Search* search, search::Path* ce, bool* ce_found, size_t* id)
        : search(std::move(search))
        , ce(ce)
        , ce_found(ce_found)
        , id(id)
    {
    }

    Search* search;
    search::Path* ce;
    bool* ce_found;
    size_t* id;
};

template <typename Search>
SearchWrapper(Search, std::vector<size_t>*, bool*) -> SearchWrapper<Search>;

} // namespace

/// Runs the configured variant starting from every state satisfying
/// `property_->start`.
///
/// BFWS and the default (breadth-first) variant search for a state satisfying
/// `property_->avoid`; the first path found is recorded in counter_example_,
/// counter_example_found_ and state_id_. The DOT variant instead writes the
/// explored graph to "graph.dot" (failures are reported on std::cerr).
template <typename _SuccessorGenerator>
void SearchUnit<_SuccessorGenerator>::run()
{
    // Start states are the models of the SMT encoding of `start`.
    auto smt = conf_.smt_factory->make_unique();
    smt->add_variables(*variables_);
    smt->add_constraint(property_->start);
    auto value_getter = [](const CompressedState& state) {
        return [&state](size_t var) { return (state[var]); };
    };
    auto condition_evaluator =
        [&value_getter](const expressions::Expression& cond) {
            return [&value_getter, &cond](const CompressedState& state) {
                return static_cast<bool>(
                    expressions::evaluate(cond, value_getter(state)));
            };
        };
    // `avoid` is the search goal; `reach` is the condition handed to
    // PruneSuccessors.
    auto goal_evaluator = condition_evaluator(property_->avoid);
    auto term_evaluator = condition_evaluator(property_->reach);
    auto successor_generator =
        successor_generator::determinize_outcomes_adapator(
            successor_generator::PruneSuccessors(succ_gen_, term_evaluator));

    // Shared by the BFWS and breadth-first variants: run `search` from each
    // start state and record the first path it returns.
    auto search_from_start_states = [&](auto& search) {
        foreach_state(
            *smt,
            *state_registry_,
            *variables_,
            SearchWrapper(
                &search,
                &counter_example_,
                &counter_example_found_,
                &state_id_));
    };

    switch (conf_.variant) {
    case SearchConfig::BFWS: {
        auto search = create_best_first_search<CompressedState>(
            StateIdGetter(),
            StateLookup(state_registry_.get()),
            successor_generator,
            goal_evaluator,
            Noop(),
            search::NoveltyQueue<CompressedState, StateIdGetter, StateNovelty>(
                conf_.width,
                StateIdGetter(),
                StateNovelty(*variables_, conf_.width)));
        search_from_start_states(search);
        break;
    }
    case SearchConfig::DOT: {
        auto dot = create_graphviz<CompressedState>(
            successor_generator,
            StateLookup(state_registry_.get()),
            StateIdGetter(),
            goal_evaluator,
            [&](size_t label) {
                return label != SILENT_ACTION ? labels_->at(label)
                                              : identifier_name_t{"tau"};
            },
            conf_.max_states);
        foreach_state(*smt, *state_registry_, *variables_, [&](auto&& s) {
            return !dot(s);
        });
        std::ofstream f("graph.dot");
        f << dot;
        // Previously an unopenable or unwritable file was ignored silently.
        if (!f) {
            std::cerr << "Failed to write graph.dot" << std::endl;
        }
        break;
    }
    default: {
        auto search = create_breadth_first_search<CompressedState>(
            StateIdGetter(),
            StateLookup(state_registry_.get()),
            successor_generator,
            goal_evaluator,
            Noop());
        search_from_start_states(search);
        break;
    }
    }
}

/// Prints how many states the shared state registry currently holds.
template <typename _SuccessorGenerator>
void SearchUnit<_SuccessorGenerator>::report_statistics()
{
    const auto registered = state_registry_->size();
    std::cout << "Registered states: " << registered << std::endl;
}

/// Prints the outcome of `run()`.
///
/// If a path was found, prints its length, the start state and one label per
/// step (SILENT_ACTION is printed as "<tau>"); otherwise prints that the
/// property is unsolvable.
template <typename _SuccessorGenerator>
void SearchUnit<_SuccessorGenerator>::report_result()
{
    if (!counter_example_found_) {
        std::cout << "Property is unsolvable." << std::endl;
        return;
    }
    // The length is one less than the number of path entries; the explicit
    // check keeps an empty path from wrapping around to SIZE_MAX.
    const auto num_entries = counter_example_.size();
    const auto path_length = num_entries == 0 ? num_entries : num_entries - 1;
    std::cout << "Found property satisfying path." << std::endl;
    std::cout << "Path length: " << path_length << std::endl;
    std::cout << "Start state: ";
    ranges::printer()((*state_registry_)[state_id_]) << std::endl;
    // Every entry except the last contributes one label line.
    for (std::size_t i = 0; i + 1 < num_entries; ++i) {
        const auto label = counter_example_[i].label;
        std::cout << (label == SILENT_ACTION ? "<tau>" : labels_->at(label))
                  << std::endl;
    }
}

namespace {

std::shared_ptr<ExecutionUnit>
create(const Arguments& args, SearchConfig config)
{
    const Model& model = args.get_model();
    const VerificationProperty& property = args.get_property();
    std::shared_ptr<StateRegistry> state_registry =
        std::make_shared<StateRegistry>(model.variables);
    if (args.has_nn_policy()) {
        const NeuralNetworkPolicy& policy = args.get_nn_policy();
        if (args.applicability_masking) {
            successor_generator::compress_state_adaptor gen(
                successor_generator::MaskedPolicySuccessorGenerator(
                    &model.actions,
                    &policy),
                state_registry.get());
            return std::make_shared<SearchUnit<decltype(gen)>>(
                std::move(config),
                &model.variables,
                &model.labels,
                &property,
                std::move(state_registry),
                std::move(gen));
        } else {
            successor_generator::compress_state_adaptor gen(
                successor_generator::PolicySuccessorGenerator(
                    &model.actions,
                    &policy),
                state_registry.get());
            return std::make_shared<SearchUnit<decltype(gen)>>(
                std::move(config),
                &model.variables,
                &model.labels,
                &property,
                std::move(state_registry),
                std::move(gen));
        }
    } else if (args.has_addtree_policy()) {
        const AddTreePolicy& policy = args.get_addtree_policy();
        if (args.applicability_masking) {
            successor_generator::compress_state_adaptor gen(
                successor_generator::MaskedPolicySuccessorGenerator(
                    &model.actions,
                    &policy),
                state_registry.get());
            return std::make_shared<SearchUnit<decltype(gen)>>(
                std::move(config),
                &model.variables,
                &model.labels,
                &property,
                std::move(state_registry),
                std::move(gen));
        } else {
            successor_generator::compress_state_adaptor gen(
                successor_generator::PolicySuccessorGenerator(
                    &model.actions,
                    &policy),
                state_registry.get());
            return std::make_shared<SearchUnit<decltype(gen)>>(
                std::move(config),
                &model.variables,
                &model.labels,
                &property,
                std::move(state_registry),
                std::move(gen));
        }
    } else if (args.has_cg_policy()) {
        const CGPolicy& policy = args.get_cg_policy();
        if (args.applicability_masking) {
            successor_generator::compress_state_adaptor gen(
                successor_generator::MaskedPolicySuccessorGenerator(
                    &model.actions,
                    &policy),
                state_registry.get());
            return std::make_shared<SearchUnit<decltype(gen)>>(
                std::move(config),
                &model.variables,
                &model.labels,
                &property,
                std::move(state_registry),
                std::move(gen));
        } else {
            successor_generator::compress_state_adaptor gen(
                successor_generator::PolicySuccessorGenerator(
                    &model.actions,
                    &policy),
                state_registry.get());
            return std::make_shared<SearchUnit<decltype(gen)>>(
                std::move(config),
                &model.variables,
                &model.labels,
                &property,
                std::move(state_registry),
                std::move(gen));
        }
    } else {
        assert(!args.has_policy());
        successor_generator::compress_state_adaptor gen(
            successor_generator::BaseSuccessorGenerator(&model.actions),
            state_registry.get());
        return std::make_shared<SearchUnit<decltype(gen)>>(
            std::move(config),
            &model.variables,
            &model.labels,
            &property,
            std::move(state_registry),
            std::move(gen));
    }
}

/// Declares the arguments common to every search option: the `smt` factory
/// used by `run()` to enumerate start states (default "z3").
void add_search_arguments(ArgumentsDefinition& defs)
{
    constexpr const char* default_smt_backend = "z3";
    defs.add_ptr_argument<SMTFactory>("smt", "", default_smt_backend);
}

// "brfs": plain breadth-first search; the width field is unused and set to 0.
PointerOption<ExecutionUnit> _brfs(
    "brfs",
    [](const Arguments& args) {
        SearchConfig config{
            SearchConfig::BrFS, 0, args.get_ptr<SMTFactory>("smt")};
        return create(args, std::move(config));
    },
    add_search_arguments);

PointerOption<ExecutionUnit> _bfws(
    "bwfs",
    [](const Arguments& args) {
        return create(
            args,
            {SearchConfig::BFWS,
             args.get<int>("width"),
             args.get_ptr<SMTFactory>("smt")});
    },
    [](ArgumentsDefinition& defs) {
        add_search_arguments(defs);
        defs.add_argument<int>("width", "", "2");
    });

// "dot": dumps the explored state graph to graph.dot instead of searching.
PointerOption<ExecutionUnit> _dot(
    "dot",
    [](const Arguments& args) {
        return create(
            args,
            {SearchConfig::DOT,
             0,
             args.get_ptr<SMTFactory>("smt"),
             // `max_states` is declared as an int option; a negative value
             // wraps around when converted to size_t.
             static_cast<size_t>(args.get<int>("max_states"))});
    },
    [](ArgumentsDefinition& defs) {
        add_search_arguments(defs);
        defs.add_argument<int>(
            "max_states",
            "",
            std::to_string(std::numeric_limits<int>::max()));
    });

} // namespace

} // namespace police::search
