#include "police/command_line_options.hpp"
#include "police/cg_policy.hpp"
#include "police/command_line_parser.hpp"
#include "police/compute_graph_factory.hpp"
#include "police/execution_unit.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/variable.hpp"
#include "police/global_arguments.hpp"
#include "police/jani/model.hpp"
#include "police/jani/parser/jani2nn.hpp"
#include "police/jani/parser/parser.hpp"
#include "police/macros.hpp"
#include "police/nn_policy.hpp"
#include "police/sas/parser.hpp"
#include "police/smt_z3.hpp"
#include "police/static_variable_remover.hpp"
#include "police/utils/stopwatch.hpp"
#include "police/verification_property.hpp"

#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>

namespace police {

namespace {

// Names of the dependency groups used to order option processing. The angle
// brackets keep them distinct from real option names.

// Global settings (e.g. --seed).
constexpr char BASE_GROUP[]{"<BASE>"};
// JANI model loading and its property files.
constexpr char JANI_GROUP[]{"<JANI>"};
// SAS model loading and its PDDL initial-state formula.
constexpr char SAS_GROUP[]{"<SAS>"};
// Policy adapter / policy loading.
constexpr char POLICY_GROUP[]{"<POLICY>"};

// Model/policy transformations applied after loading.
constexpr char POST_PROCESS_GROUP[]{"<POST_PROCESS>"};

// Aggregate of all of the above; the --run option depends on it.
constexpr char CORE_GROUP[]{"<CORE>"};

} // namespace

/// Registers the general options (currently only --seed) and groups them
/// under BASE_GROUP.
void add_base_options(CommandLineParser& parser)
{
    // --seed <int>: forwarded verbatim to GlobalArguments::set_rng_seed.
    auto apply_seed = [](GlobalArguments& global_args, int seed) {
        global_args.set_rng_seed(seed);
    };
    parser.add_argument<int>("seed", apply_seed, "RNG seed");

    parser.create_dependency_group(BASE_GROUP, {"seed"});
}

void add_jani_options(CommandLineParser& parser)
{
    parser.add_raw_argument(
        "jani",
        [](GlobalArguments& result, std::string_view arg) {
            std::cout << "reading jani model..." << std::endl;
            ScopedStopWatch timer("jani model read");
            auto model = police::jani::Model::from_file(arg);
            timer.destroy();
            std::cout << "normalizing jani model..." << std::endl;
            police::ScopedStopWatch w("model normalized");
            auto normalized_model = model.normalize();
            w.destroy();
            result.model = std::make_shared<Model>(std::move(normalized_model));
            result.jani_model = std::make_shared<jani::Model>(std::move(model));
            result.model_path = arg;
        },
        "Path to JANI model.",
        {BASE_GROUP});
    parser.add_raw_argument(
        "jani-additional-properties",
        [](GlobalArguments& result, std::string_view arg) {
            auto timer = ScopedStopWatch("additional properties read");
            std::cout << "reading additional properties..." << std::endl;
            auto props =
                jani::parse_verification_property(arg, *result.jani_model);
            result.property =
                std::make_shared<VerificationProperty>(std::move(props));
        },
        "JANI file containing additional property specifications.",
        {"jani"});
    parser.create_dependency_group(
        JANI_GROUP,
        {"jani", "additional-properties"});
    parser.add_raw_argument(
        "start",
        [](GlobalArguments& result, std::string_view arg) {
            expressions::Expression start;
            size_t begin = 0;
            size_t var = 0;
            for (auto end = arg.find(','); end != std::string_view::npos;
                 begin = end + 1, end = arg.find(',', end + 1), ++var) {
                int val =
                    std::stoi(std::string(arg.substr(begin, end - begin)));
                auto expr = expressions::equal(
                    expressions::Variable(var),
                    expressions::Constant(Value(val)));
                if (var == 0) {
                    start = std::move(expr);
                } else {
                    start = start && expr;
                }
            }
            int val = std::stoi(std::string(arg.substr(begin)));
            auto expr = expressions::equal(
                expressions::Variable(var),
                expressions::Constant(Value(val)));
            if (var == 0) {
                start = std::move(expr);
            } else {
                start = start && expr;
            }
            if (var + 1 != result.model->variables.size()) {
                POLICE_EXIT_INVALID_INPUT(
                    "expected " << result.model->variables.size()
                                << " values but got only " << (var + 1));
            }
            result.property->start = std::move(start);
        },
        "Specify start state (comma separated list of values)",
        {JANI_GROUP});
}

/// Registers the SAS-specific options: --sas (model + property) and
/// --pddl-init (replacement start condition), grouped as SAS_GROUP.
void add_sas_options(CommandLineParser& parser)
{
    // --sas: sas::parse yields both the model and its verification property.
    auto load_sas_task = [](GlobalArguments& result, std::string_view path) {
        std::cout << "reading sas model..." << std::endl;
        ScopedStopWatch read_watch("sas model read");
        auto [model, property] = police::sas::parse(path);
        read_watch.destroy();
        result.model = std::make_shared<Model>(std::move(model));
        result.model_path = path;
        result.property =
            std::make_shared<VerificationProperty>(std::move(property));
    };

    // --pddl-init: rebuilds the property with a new start condition, moving
    // over the reach/avoid parts of the property loaded by --sas.
    auto load_pddl_init = [](GlobalArguments& result, std::string_view path) {
        auto read_watch = ScopedStopWatch("initial state formula read");
        std::cout << "reading initial state formula..." << std::endl;
        auto init_condition =
            police::sas::parse_pddl_expression(*result.model, path);
        auto& previous = *result.property;
        result.property = std::make_shared<VerificationProperty>(
            std::move(init_condition),
            std::move(previous.reach),
            std::move(previous.avoid));
    };

    parser.add_raw_argument(
        "sas", load_sas_task, "Path to output.sas file.", {BASE_GROUP});
    parser.add_raw_argument(
        "pddl-init",
        load_pddl_init,
        "PDDL file containing initial state formula.",
        {"sas"});
    parser.create_dependency_group(SAS_GROUP, {"sas", "pddl-init"});
}

/// Registers the policy options: --policy-adapter, --policy (.nnet, .json or
/// .cg file) and --applicability-mask, grouped as POLICY_GROUP.
void add_policy_options(CommandLineParser& parser)
{
    parser.add_raw_argument(
        "policy-adapter",
        [](GlobalArguments& result, std::string_view arg) {
            result.policy_adapter_path = arg;
        },
        "Path to adapter specification connecting policy and model.");
    parser.add_raw_argument(
        "policy",
        [](GlobalArguments& result, std::string_view arg) {
            // The .nnet and .json loaders take the JANI model by reference;
            // reject them when no JANI model is loaded (e.g. with --sas)
            // instead of dereferencing a null pointer.
            const bool needs_jani_model =
                arg.ends_with(".nnet") || arg.ends_with(".json");
            if (needs_jani_model && result.jani_model == nullptr) {
                POLICE_INVALID_ARGUMENT(
                    "--policy",
                    "policy file " << arg
                                   << " can only be loaded for a JANI model");
                return;
            }
            ScopedStopWatch sw("policy loaded");
            if (arg.ends_with(".nnet")) {
                std::cout << "loading nnet policy..." << std::endl;
                auto policy = police::jani::parse_policy(
                    arg,
                    result.policy_adapter_path,
                    *result.jani_model);
                result.nn_policy =
                    std::make_shared<NeuralNetworkPolicy>(std::move(policy));
            } else if (arg.ends_with(".json")) {
                std::cout << "loading tree-ensemble policy..." << std::endl;
                auto pi = police::jani::parse_addtree_policy(
                    arg,
                    result.policy_adapter_path,
                    *result.jani_model);
                result.tree_policy =
                    std::make_shared<AddTreePolicy>(std::move(pi));
            } else if (arg.ends_with(".cg")) {
                std::cout << "loading cg policy..." << std::endl;
                if (result.jani_model != nullptr) {
                    result.cg_policy = std::make_shared<CGPolicy>(
                        police::jani::parse_cg_policy(
                            arg,
                            result.policy_adapter_path,
                            *result.jani_model));
                } else {
                    // No JANI model: map the adapter's input variables and
                    // output labels onto ids of the generic model instead.
                    auto interface =
                        jani::parse_policy_adapter(result.policy_adapter_path);
                    vector<size_t> input;
                    vector<size_t> output;
                    std::transform(
                        interface.input.begin(),
                        interface.input.end(),
                        std::back_inserter(input),
                        [&](const jani::parser::AutomatonVariable& ref) {
                            return result.model->variables.get_variable_id(
                                ref.variable);
                        });
                    std::transform(
                        interface.output.begin(),
                        interface.output.end(),
                        std::back_inserter(output),
                        [&](const auto& label) {
                            return result.model->get_label_id(label);
                        });
                    auto cg_node = cg::parse_nnet_file(arg);
                    auto post_processed = cg::post_process(
                        cg_node,
                        input,
                        result.model->variables);
                    result.cg_policy = std::make_shared<CGPolicy>(
                        std::move(post_processed.cg),
                        std::move(post_processed.inputs),
                        std::move(output),
                        result.model->labels.size());
                }
            } else {
                POLICE_INVALID_ARGUMENT(
                    "--policy",
                    "could not determine type of policy file " << arg);
            }
            result.policy_path = arg;
        },
        "Path to .nnet, .json, or .cg policy file.",
        // SAS_GROUP is listed because the non-JANI .cg path reads
        // result.model, which --sas populates. This assumes dependencies
        // order option processing (as "jani-additional-properties" ->
        // "jani" does); confirm against CommandLineParser.
        {"policy-adapter", JANI_GROUP, SAS_GROUP});
    parser.add_flag(
        "applicability-mask",
        &GlobalArguments::applicability_masking,
        "Mask inapplicable actions when rolling out policy.");
    parser.create_dependency_group(
        POLICY_GROUP,
        {"policy-adapter", "policy", "applicability-mask"});
}

/// Registers the post-processing option --keep-static-variables (grouped as
/// POST_PROCESS_GROUP, ordered after POLICY_GROUP).
void add_postprocessing_options(CommandLineParser& parser)
{
    // Callback attached to the flag: runs a StaticVariableRemover (Z3-backed)
    // seeded with the property's start condition, and applies it to the
    // model, the property and whichever policies are loaded.
    // NOTE(review): given the flag's name, this presumably runs when
    // --keep-static-variables is NOT set; confirm against
    // CommandLineParser::add_flag. The help text is currently empty.
    auto strip_static_variables = [](GlobalArguments& args) {
        ScopedStopWatch sw("Static variable analysis");
        std::cout << "Looking for static variables..." << std::endl;
        StaticVariableRemover remover(
            *args.model,
            args.property->start,
            std::make_shared<Z3SMTFactory>());
        std::cout << "Found " << remover.num_static_variables()
                  << " static variables" << std::endl;
        if (!remover.has_static_variables()) {
            return;
        }
        remover.apply(*args.model);
        remover.apply(*args.property);
        if (args.tree_policy) {
            remover.apply(args.tree_policy);
        }
        if (args.cg_policy) {
            remover.apply(args.cg_policy);
        }
        if (args.nn_policy) {
            remover.apply(args.nn_policy);
        }
    };

    parser.add_flag(
        "keep-static-variables",
        nullptr,
        "",
        {POLICY_GROUP},
        strip_static_variables);

    parser.create_dependency_group(
        POST_PROCESS_GROUP,
        {"keep-static-variables"});
}

/// Registers --run, which stores the selected execution unit in
/// GlobalArguments::unit and depends on CORE_GROUP.
void add_engine_option(CommandLineParser& parser)
{
    // NOTE(review): the trailing `true` is presumably the "required" flag of
    // CommandLineParser::add_argument; confirm.
    parser.add_argument<std::shared_ptr<ExecutionUnit>>(
        "run", &GlobalArguments::unit, "Execution unit to run.", {CORE_GROUP}, true);
}

/// Builds the parser with every option of the tool registered.
/// @return the configured CommandLineParser.
CommandLineParser get_command_line_options()
{
    CommandLineParser parser;

    // Stage options; each helper also creates its own dependency group.
    add_base_options(parser);
    add_jani_options(parser);
    add_sas_options(parser);
    add_policy_options(parser);
    add_postprocessing_options(parser);

    // CORE_GROUP bundles all stage groups. It is created before
    // add_engine_option, whose --run option depends on it.
    parser.create_dependency_group(
        CORE_GROUP,
        {BASE_GROUP,
         JANI_GROUP,
         SAS_GROUP,
         POLICY_GROUP,
         POST_PROCESS_GROUP});

    add_engine_option(parser);

    return parser;
}

} // namespace police
