#include "police/verifiers/ic3/syntactic/synic3_unit.hpp"

#include "police/action.hpp"
#include "police/addtree_policy.hpp"
#include "police/cg_policy.hpp"
#include "police/defaults.hpp"
#include "police/execution_unit.hpp"
#include "police/global_arguments.hpp"
#include "police/linear_condition.hpp"
#include "police/macros.hpp"
#include "police/masked_policy.hpp"
#include "police/model.hpp"
#include "police/nn_policy.hpp"
#include "police/option.hpp"
#include "police/smt_factory.hpp"
#include "police/storage/variable_space.hpp"
#include "police/successor_generator/applicable_actions_generator.hpp"
#include "police/successor_generator/determinize_outcomes_adapator.hpp"
#include "police/successor_generator/prune_successors.hpp"
#include "police/successor_generator/successor_generators.hpp"
#include "police/variable_order_chooser.hpp"
#include "police/verifiers/ic3/generic_ic3.hpp"
#include "police/verifiers/ic3/start_generator.hpp"
#include "police/verifiers/ic3/syntactic/abstraction.hpp"
#include "police/verifiers/ic3/syntactic/frame_refiner.hpp"
#include "police/verifiers/ic3/syntactic/frames_storage.hpp"
#include "police/verifiers/ic3/syntactic/policy_path_checker.hpp"
#include "police/verifiers/ic3/syntactic/policy_reasoner.hpp"
#include "police/verifiers/ic3/syntactic/policy_reasons.hpp"
#include "police/verifiers/ic3/syntactic/pruned_successor_generator.hpp"
#include "police/verifiers/ic3/syntactic/start_avoid_checker.hpp"

#include <memory>
#include <utility>

namespace police::ic3::syntactic {

namespace {
template <typename _SuccessorGenerator>
auto construct_successor_generator(
    _SuccessorGenerator base_generator,
    const Arguments& context)
{
    auto value_getter = [](const FlatState& state) {
        return [&state](size_t var) { return (state[var]); };
    };
    auto condition_evaluator =
        [&value_getter](const expressions::Expression& cond) {
            return [&value_getter, &cond](const FlatState& state) {
                return static_cast<bool>(
                    expressions::evaluate(cond, value_getter(state)));
            };
        };
    auto term_evaluator = condition_evaluator(context.get_property().reach);
    return successor_generator::determinize_outcomes_adapator(
        successor_generator::PruneSuccessors(
            std::move(base_generator),
            term_evaluator));
}

/// Registers the options shared by the engines defined in this file:
/// "smt_solver", "reschedule", "tiebreaking" and "store_frames".
/// "policy_reasoner" is registered in addition when @p globals reports an
/// NN, add-tree or CG policy.
///
/// @param globals global arguments queried for the presence of a policy.
/// @param defs    argument definition set that receives the options.
void add_common_parameters(
    const GlobalArguments& globals,
    ArgumentsDefinition& defs)
{
    defs.add_ptr_argument<SMTFactory>("smt_solver", "", DEFAULT_SMT_SOLVER);
    defs.add_argument<bool>("reschedule", "", "false");
    defs.add_ptr_argument<VariableOrderChooser>("tiebreaking", "", "default");
    defs.add_argument<std::string>("store_frames", "", "");

    const bool policy_given = globals.has_nn_policy() ||
                              globals.has_addtree_policy() ||
                              globals.has_cg_policy();
    if (!policy_given) {
        return;
    }

    defs.add_ptr_argument<PolicyReasoner>(
        "policy_reasoner",
        "Reason extraction for the policy");
}

/// Registers the options of the "hybric3" engine: everything from
/// add_common_parameters plus the integer option "perimeter", whose default
/// is given as the string "inf".
///
/// @param globals forwarded to add_common_parameters.
/// @param defs    argument definition set that receives the options.
void add_hybric3_parameters(
    const GlobalArguments& globals,
    ArgumentsDefinition& defs)
{
    // Shared options first, then the engine-specific one.
    add_common_parameters(globals, defs);

    defs.add_argument<int>(
        "perimeter", "Perimeter around policy to search in.", "inf");
}

/// Builds a StartGenerator backed by a fresh SMT instance that contains the
/// model variables and the start condition.
///
/// @param smt_factory factory used to create the SMT instance.
/// @param vars        variables added to the SMT instance.
/// @param start       constraint added to the SMT instance.
/// @param depth       when positive, one auxiliary (index, value) pair
///                    `(vars.size(), Value(depth))` is passed along and the
///                    size given to the generator grows by one.
/// @return the constructed StartGenerator.
StartGenerator create_start_generator(
    SMTFactory& smt_factory,
    const VariableSpace& vars,
    const expressions::Expression& start,
    int_t depth = 0)
{
    auto start_smt = smt_factory.make_shared();
    start_smt->add_variables(vars);
    start_smt->add_constraint(start);

    vector<std::pair<size_t, Value>> aux;
    if (depth > 0) {
        aux.emplace_back(vars.size(), Value(depth));
    }

    // Compute the size BEFORE the constructor call. Function arguments are
    // initialised in an unspecified order, so reading aux.size() in the same
    // argument list that hands over std::move(aux) could observe an already
    // moved-from vector if the constructor takes the vector by value.
    const auto total_size = vars.size() + aux.size();
    return StartGenerator(
        total_size,
        std::move(start_smt),
        std::move(aux));
}

template <typename SuccessorGenerator, typename PathChecker>
std::shared_ptr<ExecutionUnit> create_engine(
    const Arguments& args,
    SuccessorGenerator successor_generator,
    PathChecker path_val,
    std::unique_ptr<PolicyReasons> policy_reasons)
{
    auto smt_factory = args.get_ptr<SMTFactory>("smt_solver");
    const auto& model = args.get_model();
    const auto& prop = args.get_property();
    auto get_order = args.get_ptr<VariableOrderChooser>("tiebreaking");

    std::unique_ptr<FramesStorage> frames =
        std::make_unique<FramesStorage>(model.variables.size());
    std::unique_ptr<SyntacticAbstraction> abstraction =
        std::make_unique<SyntacticAbstraction>(model.variables.size());
    std::unique_ptr<size_t> cur_frame = std::make_unique<size_t>(1);

    StartAvoidChecker start_avoid_checker(*smt_factory, model.variables, prop);

    StartGenerator start_generator = create_start_generator(
        *smt_factory,
        model.variables,
        prop.start,
        args.has("perimeter") ? args.get<int>("perimeter") : 0);

    SyntacticFrameRefiner refiner(
        frames.get(),
        std::make_shared<LinearCondition>(
            LinearCondition::from_expression(prop.reach)),
        std::make_shared<LinearCondition>(
            LinearCondition::from_expression(prop.avoid)),
        &model,
        get_order->get_variable_order(),
        abstraction.get(),
        args.has("policy_reasoner")
            ? args.get_ptr<PolicyReasoner>("policy_reasoner")
            : nullptr,
        policy_reasons.get());

    return std::make_shared<IC3Engine<SuccessorGenerator, PathChecker>>(
        &model,
        &prop,
        std::move(start_generator),
        std::move(successor_generator),
        std::move(refiner),
        std::move(path_val),
        std::move(start_avoid_checker),
        std::move(cur_frame),
        std::move(frames),
        std::move(abstraction),
        std::move(policy_reasons),
        args.get<std::string>("store_frames"));
}

/// Registers the "synic3" engine.
///
/// The factory picks a successor generator according to the supplied policy
/// (NN, CG, add-tree, or none) and to `args.applicability_masking`, wraps it
/// with determinize_outcomes_adapator and builds the engine with an
/// AllPathsAcceptor as path checker.
PointerOption<ExecutionUnit> _ic3(
    "synic3",
    [](const Arguments& args) -> std::shared_ptr<ExecutionUnit> {
        const Model& model = args.get_model();
        _detail::AllPathsAcceptor val;
        auto policy_reasons = std::make_unique<PolicyReasons>(
            &model.variables,
            model.labels.size());

        // Consumes `val` and `policy_reasons`; invoked exactly once per
        // factory call.
        auto build_engine = [&](auto generator) {
            return create_engine(
                args,
                successor_generator::determinize_outcomes_adapator(
                    std::move(generator)),
                std::move(val),
                std::move(policy_reasons));
        };

        // Chooses the masked or the plain policy generator for any policy
        // type.
        auto build_for_policy =
            [&](const auto& policy) -> std::shared_ptr<ExecutionUnit> {
            if (args.applicability_masking) {
                return build_engine(
                    successor_generator::MaskedPolicySuccessorGenerator(
                        &model.actions,
                        &policy));
            }
            return build_engine(
                successor_generator::PolicySuccessorGenerator(
                    &model.actions,
                    &policy));
        };

        if (args.has_nn_policy()) {
            const NeuralNetworkPolicy& policy = args.get_nn_policy();
            return build_for_policy(policy);
        }
        if (args.has_cg_policy()) {
            const CGPolicy& policy = args.get_cg_policy();
            return build_for_policy(policy);
        }
        if (args.has_addtree_policy()) {
            const AddTreePolicy& policy = args.get_addtree_policy();
            return build_for_policy(policy);
        }
        // No policy supplied: use the base successor generator.
        return build_engine(
            successor_generator::BaseSuccessorGenerator(&model.actions));
    },
    add_common_parameters);

/// Registers the "hybric3" engine.
///
/// The factory requires a policy (NN, CG or add-tree). It builds a pruning
/// successor generator and a forward PolicyPathChecker around that policy,
/// taking `args.applicability_masking` into account, and passes both to
/// create_engine. The configured perimeter is printed to stdout.
PointerOption<ExecutionUnit> _hybric3(
    "hybric3",
    [](const Arguments& args) -> std::shared_ptr<ExecutionUnit> {
        // All three policy kinds are dispatched on below, so only the absence
        // of every one of them is an error. (A CG policy used to be rejected
        // here although it is handled further down.)
        if (!args.has_nn_policy() && !args.has_cg_policy() &&
            !args.has_addtree_policy()) {
            POLICE_INVALID_ARGUMENT(
                "--engine",
                "the hybric3 engine supports only policy verification, but no "
                "policy was provided");
        }
        const Model& model = args.get_model();
        std::unique_ptr<PolicyReasons> policy_reasons =
            std::make_unique<PolicyReasons>(
                &model.variables,
                model.labels.size());
        std::shared_ptr<PolicyReasoner> policy_reasoner =
            args.get_ptr<PolicyReasoner>("policy_reasoner");
        successor_generator::ApplicableActionsGenerator aops_gen(
            &model.actions);

        // Final step: wires the generator chain and creates the engine.
        // Consumes `aops_gen` and `policy_reasons`, so it must run at most
        // once per factory call.
        auto construct_for_checker = [&](auto policy_wrapper, auto checker) {
            PrunedAopsGenerator aops_gen_wrapper(
                policy_reasons.get(),
                &model,
                std::move(aops_gen),
                std::move(policy_wrapper));
            PruningSuccessorGenerator succ_gen(
                &model.actions,
                std::move(aops_gen_wrapper));
            return create_engine(
                args,
                std::move(succ_gen),
                std::move(checker),
                std::move(policy_reasons));
        };

        // Builds the path checker: with applicability masking the policy is
        // wrapped in MaskedPolicy, otherwise it is called directly on the
        // state.
        auto construct_for_wrapper = [&](auto policy, auto policy_wrapper) {
            if (args.applicability_masking) {
                PolicyPathChecker checker(
                    PolicyPathValidationDirection::FORWARD,
                    MaskedPolicy(policy, &model.actions),
                    policy_reasons.get(),
                    &model,
                    policy_reasoner);
                return construct_for_checker(
                    std::move(policy_wrapper),
                    std::move(checker));
            } else {
                PolicyPathChecker checker(
                    PolicyPathValidationDirection::FORWARD,
                    [policy](const flat_state& state) {
                        return (*policy)(state);
                    },
                    policy_reasons.get(),
                    &model,
                    policy_reasoner);
                return construct_for_checker(
                    std::move(policy_wrapper),
                    std::move(checker));
            }
        };

        // Builds the (state, labels) -> action callback used by the pruned
        // applicable-actions generator. `policy` is a pointer to the policy.
        auto construct_for_policy = [&](auto policy) {
            if (args.applicability_masking) {
                return construct_for_wrapper(
                    policy,
                    [policy](
                        const flat_state& state,
                        const vector<size_t>& labels) {
                        // The policy selects among the given labels; an
                        // end iterator is mapped to SILENT_ACTION.
                        auto res =
                            (*policy)(state, labels.begin(), labels.end());
                        return res == labels.end() ? SILENT_ACTION : *res;
                    });
            } else {
                return construct_for_wrapper(
                    policy,
                    [policy](const flat_state& state, const vector<size_t>&) {
                        return (*policy)(state);
                    });
            }
        };

        // NOTE(review): "perimeter" is registered with add_argument<int> but
        // read here as int_t -- confirm that both types agree.
        std::cout << "HybrIC3 perimeter: " << args.get<int_t>("perimeter")
                  << std::endl;

        if (args.has_nn_policy()) {
            const auto& policy = args.get_nn_policy();
            return construct_for_policy(&policy);
        }
        if (args.has_cg_policy()) {
            const auto& policy = args.get_cg_policy();
            return construct_for_policy(&policy);
        }
        // The guard above guarantees that an add-tree policy is present here.
        const auto& policy = args.get_addtree_policy();
        return construct_for_policy(&policy);
    },
    add_hybric3_parameters);

} // namespace

} // namespace police::ic3::syntactic
