#include "police/verifiers/ic3/sat_based/ic3_engine.hpp"

#include "police/arguments.hpp"
#include "police/defaults.hpp"
#include "police/global_arguments.hpp"
#include "police/option.hpp"
#include "police/option_parser.hpp"
#include "police/action.hpp"
#include "police/addtree_policy.hpp"
#include "police/linear_condition.hpp"
#include "police/model.hpp"
#include "police/nn_policy.hpp"
#include "police/expressions/expression.hpp"
#include "police/storage/state_registry.hpp"
#include "police/successor_generator/determinize_outcomes_adapator.hpp"
#include "police/successor_generator/prune_successors.hpp"
#include "police/successor_generator/successor_generators.hpp"
#include "police/utils/plan_validator.hpp"
#include "police/verifiers/ic3/sat_based/generalizer/greedy_generalizer.hpp"
#include "police/verifiers/ic3/sat_based/generalizer/unsat_core_generalizer.hpp"
#include "police/verifiers/ic3/sat_based/goal_checker.hpp"
#include "police/verifiers/ic3/sat_based/ic3_sat.hpp"
#include "police/verifiers/ic3/sat_based/pic3sat_label.hpp"
#include "police/verifiers/ic3/sat_based/pic3sat_singleton.hpp"

#include <algorithm>
#include <memory>
#include <ranges>
#include <type_traits>

namespace police::ic3 {

namespace {

/// Lowers the per-variable "categorial" flags in `result` using `edges`.
///
/// For every assignment of every outcome of every edge, the flag of the
/// assigned variable stays true only if it was already true and the
/// assignment's `value.empty()` holds. Flags are therefore only ever cleared,
/// never set. NOTE(review): what an "empty" value denotes is defined by
/// Assignment; presumably variables only ever given empty values are the
/// categorial ones — confirm.
///
/// @param result per-variable flags indexed by `var_id`; every assigned
///        `var_id` must be a valid index (checked by assert).
/// @param edges  the actions whose outcome assignments are inspected.
void update_categorial_variables(
    vector<bool>& result,
    const vector<Action>& edges)
{
    for (const auto& edge : edges) {
        for (const Outcome& outcome : edge.outcomes) {
            for (const Assignment& assignment : outcome.assignments) {
                const auto var = assignment.var_id;
                assert(var < result.size());
                result[var] = result[var] && assignment.value.empty();
            }
        }
    }
}

/// Returns one "categorial" flag per variable (`num_vars` in total).
///
/// Every flag starts out true and is then lowered by
/// update_categorial_variables according to the assignments in `edges`.
vector<bool>
get_categorial_variables(size_t num_vars, const vector<Action>& edges)
{
    vector<bool> categorial_flags(num_vars, true);
    update_categorial_variables(categorial_flags, edges);
    return categorial_flags;
}

/// Builds IC3's start-state generator.
///
/// A new shared SMT instance is obtained from `factory`, declared over
/// `variables` and constrained by the `start` expression. The generator takes
/// ownership of that instance.
ic3::StartGenerator construct_start_generator(
    const SMTFactory& factory,
    const expressions::Expression& start,
    const VariableSpace& variables)
{
    auto solver = factory.make_shared();
    solver->add_variables(variables);
    solver->add_constraint(start);
    return ic3::StartGenerator(std::move(solver));
}

/// Builds the greedy "minimize reason" generalizer.
///
/// `params.generalizer` must hold a GenMinimizeReason; this is checked only by
/// assert. The generalizer receives:
/// - the model's variable space,
/// - a GoalChecker backed by a fresh SMT instance constrained by
///   `property.avoid`,
/// - the variable order produced by the option's order chooser,
/// - the categorial flags of the model's variables.
ic3::GreedyGeneralizer<GoalChecker> construct_greedy_generalizer(
    const Model* model,
    const VerificationProperty& property,
    const IC3Parameters& params)
{
    const auto minimize_option =
        std::dynamic_pointer_cast<GenMinimizeReason>(params.generalizer);
    assert(minimize_option != nullptr);

    auto variable_order = minimize_option->order->get_variable_order();
    auto categorial_flags =
        get_categorial_variables(model->variables.size(), model->actions);

    auto avoid_solver = params.smt->make_shared();
    avoid_solver->add_variables(model->variables);
    avoid_solver->add_constraint(property.avoid);

    return {
        &model->variables,
        GoalChecker(std::move(avoid_solver)),
        std::move(variable_order),
        std::move(categorial_flags)};
}

/// Builds the unsat-core based generalizer.
///
/// Its GoalChecker is backed by a fresh SMT instance (from `params.smt`)
/// declared over the model's variables and constrained by `property.avoid`.
ic3::UnsatCoreGeneralizer<GoalChecker> construct_base_generalizer(
    const Model* model,
    const VerificationProperty& property,
    const IC3Parameters& params)
{
    const auto& variables = model->variables;
    auto avoid_solver = params.smt->make_shared();
    avoid_solver->add_variables(variables);
    avoid_solver->add_constraint(property.avoid);
    return {&variables, GoalChecker(std::move(avoid_solver))};
}

/// Creates the IC3 instance (`ic3_type`) used by the engine type
/// IC3Engine<remove_cvref_t<SuccessorGenerator>, SatInterface, Generalizer>.
///
/// A start generator is built from `property.start` with `params.smt`. The
/// ic3_type constructor then receives:
/// - the model's variables and `property.avoid`,
/// - that start generator,
/// - the forwarded successor generator, SAT interface and generalizer,
/// - `params.obligation_rescheduling`.
template <
    typename SuccessorGenerator,
    typename SatInterface,
    typename Generalizer>
typename IC3Engine<
    std::remove_cvref_t<SuccessorGenerator>,
    SatInterface,
    Generalizer>::ic3_type
construct_ic3(
    SuccessorGenerator&& successor_gen,
    SatInterface&& sat,
    Generalizer&& generalizer,
    const Model* model,
    const VerificationProperty& property,
    const IC3Parameters& params)
{
    using engine_type = IC3Engine<
        std::remove_cvref_t<SuccessorGenerator>,
        SatInterface,
        Generalizer>;
    using result_type = typename engine_type::ic3_type;

    const auto& variables = model->variables;
    auto start_generator =
        construct_start_generator(*params.smt, property.start, variables);
    return result_type(
        &variables,
        property.avoid,
        std::move(start_generator),
        std::forward<SuccessorGenerator>(successor_gen),
        std::forward<SatInterface>(sat),
        std::forward<Generalizer>(generalizer),
        params.obligation_rescheduling);
}

/// Creates an IC3 instance for the setting without a neural-network policy.
///
/// The SAT interface comes from IC3SatInterface::create. It is built over the
/// model's variables and all of its actions. It is given the negated linear
/// condition of `property.reach` and the linear condition of
/// `property.avoid`. The rest is delegated to the six-argument construct_ic3.
template <typename SuccessorGenerator, typename Generalizer>
auto construct_ic3(
    SuccessorGenerator&& successor_gen,
    Generalizer&& generalizer,
    const Model* model,
    const VerificationProperty& property,
    const IC3Parameters& params)
{
    const auto& actions = model->actions;
    auto sat_interface = IC3SatInterface::create(
        *params.smt,
        model->variables,
        actions.begin(),
        actions.end(),
        !LinearCondition::from_expression(property.reach),
        LinearCondition::from_expression(property.avoid));
    return construct_ic3(
        std::forward<SuccessorGenerator>(successor_gen),
        std::move(sat_interface),
        std::forward<Generalizer>(generalizer),
        model,
        property,
        params);
}

/// Creates an IC3 instance whose SAT interface is a PIC3SatEdgeIndividual.
/// This overload is selected by the PER_LABEL tag.
///
/// `params.sat_interface` must hold an EdgeIndividualSatOption; this is
/// checked only by assert. Its `params` configure the SAT interface. The
/// interface is also given the model, the policy `net`, the negated linear
/// condition of `property.reach` and the linear condition of
/// `property.avoid`.
template <typename SuccessorGenerator, typename Generalizer>
auto construct_ic3(
    std::integral_constant<int, SatInterfaceOption::PER_LABEL>,
    SuccessorGenerator&& successor_gen,
    Generalizer&& generalizer,
    const Model* model,
    const NeuralNetworkPolicy& net,
    const VerificationProperty& property,
    const IC3Parameters& params)
{
    const auto per_label_option =
        std::dynamic_pointer_cast<EdgeIndividualSatOption>(
            params.sat_interface);
    assert(per_label_option != nullptr);
    PIC3SatEdgeIndividual sat_interface(
        per_label_option->params,
        *model,
        net,
        !LinearCondition::from_expression(property.reach),
        LinearCondition::from_expression(property.avoid));
    return construct_ic3(
        std::forward<SuccessorGenerator>(successor_gen),
        std::move(sat_interface),
        std::forward<Generalizer>(generalizer),
        model,
        property,
        params);
}

/// Creates an IC3 instance whose SAT interface is a PIC3SatSingleton.
/// This overload is selected by the SINGLETON tag.
///
/// `params.sat_interface` must hold a SingletonSatOption; this is checked only
/// by assert. Its `params` configure the SAT interface. The interface is also
/// given the model, the policy `net`, the negated linear condition of
/// `property.reach` and the linear condition of `property.avoid`.
template <typename SuccessorGenerator, typename Generalizer>
auto construct_ic3(
    std::integral_constant<int, SatInterfaceOption::SINGLETON>,
    SuccessorGenerator&& successor_gen,
    Generalizer&& generalizer,
    const Model* model,
    const NeuralNetworkPolicy& net,
    const VerificationProperty& property,
    const IC3Parameters& params)
{
    const auto singleton_option =
        std::dynamic_pointer_cast<SingletonSatOption>(params.sat_interface);
    assert(singleton_option != nullptr);
    PIC3SatSingleton sat_interface(
        singleton_option->params,
        *model,
        net,
        !LinearCondition::from_expression(property.reach),
        LinearCondition::from_expression(property.avoid));
    return construct_ic3(
        std::forward<SuccessorGenerator>(successor_gen),
        std::move(sat_interface),
        std::forward<Generalizer>(generalizer),
        model,
        property,
        params);
}

} // namespace

/// Stores the engine's collaborators.
///
/// `model` and `property` are stored exactly as passed; moving a raw pointer
/// is just a copy, so no std::move is used for them. The SMT factory, the
/// successor generator and the IC3 instance are moved into the engine.
template <
    typename SuccessorGenerator,
    typename SatInterface,
    typename Generalizer>
IC3Engine<SuccessorGenerator, SatInterface, Generalizer>::IC3Engine(
    std::shared_ptr<SMTFactory> smt_factory,
    const Model* model,
    const VerificationProperty* property,
    SuccessorGenerator successor_gen,
    ic3_type ic3)
    : model_(model)
    , property_(property)
    , sgen_(std::move(successor_gen))
    , smt_(std::move(smt_factory))
    , ic3_(std::move(ic3))
{
}

/// Name under which this execution unit identifies itself; always "ic3".
template <
    typename SuccessorGenerator,
    typename SatInterface,
    typename Generalizer>
std::string_view
IC3Engine<SuccessorGenerator, SatInterface, Generalizer>::name() const
{
    constexpr std::string_view engine_name = "ic3";
    return engine_name;
}

/// Checks whether some state satisfies both `property_->start` and
/// `property_->avoid`, using a fresh SMT instance from `smt_`.
///
/// If the solver reports SAT, prints "found an unsafe start state" and the
/// values of the satisfying assignment as "[v0, v1, ...]", then returns
/// true. Every other solver status yields false.
template <
    typename SuccessorGenerator,
    typename SatInterface,
    typename Generalizer>
bool IC3Engine<SuccessorGenerator, SatInterface, Generalizer>::
    some_start_is_goal() const
{
    auto solver = smt_->make_unique();
    solver->add_variables(model_->variables);
    solver->add_constraint(property_->start);
    solver->add_constraint(property_->avoid);

    const bool has_unsafe_start = solver->solve() == SMT::Status::SAT;
    if (!has_unsafe_start) {
        return false;
    }

    std::cout << "found an unsafe start state" << std::endl;
    auto witness = solver->get_model();
    std::cout << "[";
    for (size_t var = 0; var < witness.size(); ++var) {
        if (var > 0) {
            std::cout << ", ";
        }
        std::cout << witness.get_value(var);
    }
    std::cout << "]" << std::endl;
    return true;
}

/// Runs IC3 and reports the outcome on stdout.
///
/// First prints the model information. If some start state already satisfies
/// the avoid condition (see some_start_is_goal), nothing else is done.
/// Otherwise IC3 is invoked:
/// - If it returns a path, prints the plan length, the labels of its steps
///   (SILENT_ACTION is printed as "tau") and every state of the path. The
///   state sequence is then handed to PlanValidator.
/// - Otherwise prints "Property is unsolvable".
template <
    typename SuccessorGenerator,
    typename SatInterface,
    typename Generalizer>
void IC3Engine<SuccessorGenerator, SatInterface, Generalizer>::run()
{
    std::cout << "Normalized model:" << "\n";
    model_->report_infos();
    std::cout << "Running IC3..." << std::endl;
    if (some_start_is_goal()) {
        return;
    }
    const auto result = ic3_();
    if (!result.has_value()) {
        std::cout << "Property is unsolvable" << std::endl;
        return;
    }

    const auto& path = result.value();
    // A path of k states has k - 1 steps. Guard the empty case so the
    // subtraction cannot wrap around.
    const size_t num_steps = path.empty() ? 0 : path.size() - 1;

    std::cout << "Plan found.\n";
    std::cout << "Plan length: " << num_steps << "\n";
    for (size_t step = 0; step < num_steps; ++step) {
        const auto& label = path.at(step).label;
        if (label == SILENT_ACTION) {
            std::cout << "tau" << "\n";
        } else {
            std::cout << model_->labels.at(label) << "\n";
        }
    }

    // Print every state of the path, including the final one reached by the
    // last step. The previous loop bound (i + 1 < size) dropped it.
    std::cout << "Full path:\n";
    for (const auto& entry : path) {
        const auto& state = entry.state;
        std::cout << "[";
        for (size_t var = 0; var < state.size(); ++var) {
            std::cout << (var > 0 ? ", " : "") << state[var];
        }
        std::cout << "]\n";
    }
    std::cout << std::flush;

    PlanValidator<SuccessorGenerator>(sgen_, property_)
        .validate_state_sequence(
            path | std::ranges::views::transform(
                       [](const auto& pr) { return pr.state; }));
}

/// Writes the statistics collected by the IC3 instance to stdout.
template <
    typename SuccessorGenerator,
    typename SatInterface,
    typename Generalizer>
void IC3Engine<SuccessorGenerator, SatInterface, Generalizer>::
    report_statistics()
{
    const auto& statistics = ic3_.get_statistics();
    std::cout << statistics;
}

template <typename SuccessorGenerator>
std::shared_ptr<ExecutionUnit> create_ic3_engine(
    const IC3Parameters& params,
    const Model* model,
    const VerificationProperty* property,
    SuccessorGenerator&& successor_gen)
{
    switch (params.generalizer->kind()) {
    case ReasonGeneralizerOption::GREEDY_MINIMIZE: {
        auto gen = construct_greedy_generalizer(model, *property, params);
        auto ic3 = construct_ic3(
            successor_gen,
            std::move(gen),
            model,
            *property,
            params);
        return std::make_shared<IC3Engine<
            SuccessorGenerator,
            typename decltype(ic3)::sat_interface_type,
            typename decltype(ic3)::generalizer_type>>(
            params.smt,
            std::move(model),
            property,
            std::forward<SuccessorGenerator>(successor_gen),
            std::move(ic3));
    }
    case ReasonGeneralizerOption::UNSAT_CORE: {
        auto gen = construct_base_generalizer(model, *property, params);
        auto ic3 = construct_ic3(
            successor_gen,
            std::move(gen),
            model,
            *property,
            params);
        return std::make_shared<IC3Engine<
            SuccessorGenerator,
            typename decltype(ic3)::sat_interface_type,
            typename decltype(ic3)::generalizer_type>>(
            params.smt,
            std::move(model),
            property,
            std::forward<SuccessorGenerator>(successor_gen),
            std::move(ic3));
    }
    }
    POLICE_UNREACHABLE();
}

namespace {
template <typename SuccessorGenerator, typename Generalizer>
std::shared_ptr<ExecutionUnit> create_ic3_engine(
    const IC3Parameters& params,
    const VerificationProperty* property,
    const NeuralNetworkPolicy& net,
    SuccessorGenerator&& successor_gen,
    const Model* model,
    Generalizer&& gen)
{
    assert(params.sat_interface != nullptr);
    const auto create_wrapper = [&](auto&& kind) {
        auto ic3 = construct_ic3(
            kind,
            successor_gen,
            std::forward<Generalizer>(gen),
            model,
            net,
            *property,
            params);
        return std::make_shared<IC3Engine<
            SuccessorGenerator,
            typename decltype(ic3)::sat_interface_type,
            typename decltype(ic3)::generalizer_type>>(
            params.smt,
            std::move(model),
            property,
            std::forward<SuccessorGenerator>(successor_gen),
            std::move(ic3));
    };
    switch (params.sat_interface->kind()) {
    case SatInterfaceOption::PER_LABEL:
        return create_wrapper(
            std::integral_constant<int, SatInterfaceOption::PER_LABEL>());
    case SatInterfaceOption::SINGLETON:
        return create_wrapper(
            std::integral_constant<int, SatInterfaceOption::SINGLETON>());
    }
    POLICE_UNREACHABLE();
}
} // namespace

/// Creates an IC3 engine for the setting with a neural-network policy `net`.
///
/// Builds the generalizer selected by `params.generalizer->kind()`. It then
/// hands off to the internal overload, which selects the SAT interface.
/// `params.sat_interface` must be set; this is checked by assert.
template <typename SuccessorGenerator>
std::shared_ptr<ExecutionUnit> create_ic3_engine(
    const IC3Parameters& params,
    const Model* model,
    const VerificationProperty* property,
    const NeuralNetworkPolicy& net,
    SuccessorGenerator&& successor_gen)
{
    assert(params.sat_interface != nullptr);
    const auto with_generalizer = [&](auto&& generalizer) {
        return create_ic3_engine(
            params,
            property,
            net,
            std::forward<SuccessorGenerator>(successor_gen),
            model,
            std::forward<decltype(generalizer)>(generalizer));
    };
    switch (params.generalizer->kind()) {
    case ReasonGeneralizerOption::GREEDY_MINIMIZE:
        return with_generalizer(
            construct_greedy_generalizer(model, *property, params));
    case ReasonGeneralizerOption::UNSAT_CORE:
        return with_generalizer(
            construct_base_generalizer(model, *property, params));
    }
    POLICE_UNREACHABLE();
}

namespace {

// Returns the address of the model held by `context`.
const Model* get_model(const Arguments& context)
{
    const Model& model = context.get_model();
    return &model;
}

// Returns the address of the verification property held by `context`.
const VerificationProperty* get_property(const Arguments& context)
{
    const VerificationProperty& property = context.get_property();
    return &property;
}

// Declares the command-line arguments of the "ic3" engine.
// "sat_interface" is only offered when a neural-network policy is present.
// The third argument of each registration ("unsat_core", DEFAULT_SMT_SOLVER,
// "false") is presumably its default value; confirm against
// ArgumentsDefinition. Registration order is kept as-is.
void add_parameters(const GlobalArguments& globals, ArgumentsDefinition& defs)
{
    const bool with_policy = globals.has_nn_policy();
    if (with_policy) {
        defs.add_ptr_argument<SatInterfaceOption>(
            "sat_interface",
            "SAT interface used by IC3");
    }
    defs.add_ptr_argument<ReasonGeneralizerOption>("reasons", "", "unsat_core");
    defs.add_ptr_argument<SMTFactory>("smt_solver", "", DEFAULT_SMT_SOLVER);
    defs.add_argument<bool>("reschedule", "", "false");
}

// Collects the arguments declared in add_parameters into an IC3Parameters.
// The SAT interface option stays null when "sat_interface" was not given.
IC3Parameters get_parameters(const Arguments& args)
{
    auto generalizer_option =
        args.get<std::shared_ptr<ReasonGeneralizerOption>>("reasons");

    std::shared_ptr<SatInterfaceOption> sat_option;
    if (args.has("sat_interface")) {
        sat_option =
            args.get<std::shared_ptr<SatInterfaceOption>>("sat_interface");
    }

    auto smt_factory = args.get_ptr<SMTFactory>("smt_solver");
    const bool reschedule_obligations = args.get<bool>("reschedule");
    return {
        std::move(generalizer_option),
        std::move(sat_option),
        std::move(smt_factory),
        reschedule_obligations};
}

/// Wraps `base_generator` so that successors satisfying
/// `context.get_property().reach` are pruned. The pruned generator is then
/// passed through determinize_outcomes_adapator.
///
/// The template parameter used to be named `_SuccessorGenerator`, which is a
/// reserved identifier (underscore followed by an uppercase letter).
template <typename BaseGenerator>
auto construct_successor_generator(
    BaseGenerator base_generator,
    const Arguments& context)
{
    // Turns a state into a variable-id -> value accessor for evaluate().
    // The accessor borrows `state` and is used only within a single
    // evaluate() call.
    auto value_getter = [](const FlatState& state) {
        return [&state](size_t var) { return state[var]; };
    };
    // The predicate built here is stored inside the returned generator and
    // outlives this function. `value_getter` must therefore be captured by
    // value: the previous by-reference capture referred to this function's
    // local and dangled after return. `cond` refers to the property owned by
    // `context`, which is assumed to outlive the generator — TODO confirm.
    auto condition_evaluator =
        [value_getter](const expressions::Expression& cond) {
            return [value_getter, &cond](const FlatState& state) {
                return static_cast<bool>(
                    expressions::evaluate(cond, value_getter(state)));
            };
        };
    auto term_evaluator = condition_evaluator(context.get_property().reach);
    return successor_generator::determinize_outcomes_adapator(
        successor_generator::PruneSuccessors(
            std::move(base_generator),
            term_evaluator));
}

auto get_successor_generator_policy(const Arguments& context)
{
    assert(context.has_nn_policy());
    const Model* model = &context.get_model();
    const NeuralNetworkPolicy* policy = &context.get_nn_policy();
    return construct_successor_generator(
        successor_generator::PolicySuccessorGenerator(&model->actions, policy),
        context);
}

auto get_successor_generator(const Arguments& context)
{
    const Model* model = &context.get_model();
    return construct_successor_generator(
        successor_generator::BaseSuccessorGenerator(&model->actions),
        context);
}

// Reason generalizer option "minimize". It builds a GenMinimizeReason from
// the VariableOrderChooser passed as "order". That argument is registered
// with "default" as its third argument, presumably its default value.
PointerOption<ReasonGeneralizerOption> _minimize(
    "minimize",
    [](const Arguments& args) -> std::shared_ptr<ReasonGeneralizerOption> {
        return std::make_shared<GenMinimizeReason>(
            args.get_ptr<VariableOrderChooser>("order"));
    },
    [](ArgumentsDefinition& defs) {
        defs.add_ptr_argument<VariableOrderChooser>("order", "", "default");
    });

// Reason generalizer option "unsat_core"; it takes no arguments.
PointerOption<ReasonGeneralizerOption> _core(
    "unsat_core",
    [](const Arguments& /*args*/) -> std::shared_ptr<ReasonGeneralizerOption> {
        std::shared_ptr<ReasonGeneralizerOption> option =
            std::make_shared<GenUnsatCore>();
        return option;
    });

// Execution unit "ic3". It dispatches to the policy-aware create_ic3_engine
// overload when a neural-network policy is present, and to the plain one
// otherwise.
PointerOption<ExecutionUnit> _ic3(
    "ic3",
    [](const Arguments& args) -> std::shared_ptr<ExecutionUnit> {
        if (!args.has_nn_policy()) {
            return create_ic3_engine(
                get_parameters(args),
                get_model(args),
                get_property(args),
                get_successor_generator(args));
        }
        const auto& nn_policy = args.get_nn_policy();
        return create_ic3_engine(
            get_parameters(args),
            get_model(args),
            get_property(args),
            nn_policy,
            get_successor_generator_policy(args));
    },
    add_parameters);
} // namespace

} // namespace police::ic3
