#include "police/command_line_options.hpp"
#include "police/command_line_parser.hpp"
#include "police/global_arguments.hpp"
#include "police/jani/model.hpp"
#include "police/macros.hpp"
#include "police/verification_property.hpp"

#include <algorithm>
#include <cassert>
#include <fstream>
#include <iostream>
#include <ostream>
#include <string>
#include <string_view>

using namespace police;

namespace {
/// Counts the model variables whose name begins with "location_load".
///
/// @param model JANI model whose variable table is scanned.
/// @return Number of variables with a matching name prefix.
unsigned get_num_locations(const jani::Model& model)
{
    unsigned count = 0;
    for (unsigned idx = 0; idx < model.variables.size(); ++idx) {
        if (model.variables.get_name(idx).starts_with("location_load")) {
            ++count;
        }
    }
    return count;
}

/// Returns the upper bound of the first variable whose name begins with
/// "location_load"; callers in this file use it as the package count.
///
/// @param model JANI model whose variable table is scanned.
/// @return The upper bound of that variable's type, converted via int_t.
///
/// Falls through to POLICE_UNREACHABLE() when no such variable exists.
unsigned get_num_packages(const jani::Model& model)
{
    unsigned idx = 0;
    while (idx < model.variables.size()) {
        if (model.variables.get_name(idx).starts_with("location_load")) {
            const auto& type = model.variables.get_type(idx);
            return static_cast<int_t>(type.get_upper_bound());
        }
        ++idx;
    }
    POLICE_UNREACHABLE();
}

/// Writes the operator section of the interface file to `out`.
///
/// Output layout: the number of actions on its own line, a "begin-operators"
/// line, one line per action ("<id> <operator> <arguments>"), and a closing
/// "end-operators" line, after which the stream is flushed.
///
/// Action names are translated by prefix:
///  - "pickup" / "drop": everything after the prefix and one separator
///    character is emitted as the location suffix.
///  - "drive": the text between the first and second '_' is the source, the
///    text after the second '_' is the destination; the operator is
///    "drive-forward" when source < destination numerically, otherwise
///    "drive-backward". Each translated drive action is also echoed to stdout.
///
/// Any other name, or a "drive" name without two '_' separators, ends up in
/// POLICE_UNREACHABLE(). std::stoi throws if source or destination is not
/// numeric; substr throws if a pickup/drop name is shorter than the offset.
///
/// @param model JANI model providing `action_names`.
/// @param out   Stream receiving the operator section.
void dump_actions(const jani::Model& model, std::ostream& out)
{
    out << model.action_names.size() << "\n";
    out << "begin-operators" << "\n";
    for (unsigned action_id = 0; action_id < model.action_names.size();
         ++action_id) {
        const auto& name = model.action_names[action_id];
        if (name.starts_with("pickup")) {
            // Offset 7 skips "pickup" plus the separator that follows it.
            const std::string loc = name.substr(7);
            out << action_id << " pickup l" << loc << "\n";
        } else if (name.starts_with("drop")) {
            // Offset 5 skips "drop" plus the separator that follows it.
            const std::string loc = name.substr(5);
            out << action_id << " drop l" << loc << "\n";
        } else if (name.starts_with("drive")) {
            const auto first_sep = std::find(name.begin(), name.end(), '_');
            // Only advance past the first separator if it actually exists;
            // incrementing the end iterator would be undefined behaviour.
            const auto second_sep =
                first_sep == name.end()
                    ? name.end()
                    : std::find(first_sep + 1, name.end(), '_');
            const bool well_formed = second_sep != name.end();
            assert(well_formed);
            if (!well_formed) {
                // Handled like an unrecognized action name below, so that
                // builds without assertions never slice with bad iterators.
                POLICE_UNREACHABLE();
            }
            const std::string_view src(first_sep + 1, second_sep);
            const std::string_view dest(second_sep + 1, name.end());
            // Parse once and reuse for both the trace and the file output.
            const char* const direction =
                std::stoi(std::string(src)) < std::stoi(std::string(dest))
                    ? "forward"
                    : "backward";
            std::cout << name << " (" << action_id << ") => drive-"
                      << direction << " l" << src << " l" << dest << "\n";
            out << action_id << " drive-" << direction << " l" << src << " l"
                << dest << "\n";
        } else {
            POLICE_UNREACHABLE();
        }
    }
    out << "end-operators" << "\n" << std::flush;
}

/// Writes the variable section of the interface file to `out`.
///
/// Output layout: the value (number of locations + 2) on its own line, a
/// "begin-variables" line, one line per (variable, value) pair, and a closing
/// "end-variables" line, after which the stream is flushed.
///
/// Only three kinds of model variables are emitted, each receiving the next
/// consecutive interface index in model order:
///  - "truck_0":          "<idx> <loc> at l<loc>" for every location,
///  - "truck_load_0":     "<idx> <k> loaded n<k>" for k in [0, packages],
///  - "location_load...": "<idx> <k> packages l<suffix> n<k>" for k in
///    [0, packages], where <suffix> is the name from offset 14 onwards.
/// All other variables are skipped without consuming an index.
///
/// @param model JANI model providing the variable table.
/// @param out   Stream receiving the variable section.
void dump_state_interface(const jani::Model& model, std::ostream& out)
{
    const auto location_count = get_num_locations(model);
    const auto package_count = get_num_packages(model);

    out << (location_count + 2) << '\n';
    out << "begin-variables" << '\n';

    unsigned next_index = 0;
    for (unsigned id = 0; id < model.variables.size(); ++id) {
        const auto& name = model.variables.get_name(id);
        const bool is_truck_position = (name == "truck_0");
        const bool is_truck_load = (name == "truck_load_0");
        const bool is_location_load = name.starts_with("location_load");

        // Variables outside the interface do not consume an index.
        if (!is_truck_position && !is_truck_load && !is_location_load) {
            continue;
        }

        if (is_truck_position) {
            for (unsigned value = 0; value < location_count; ++value) {
                out << next_index << ' ' << value << " at l" << value << '\n';
            }
        } else if (is_truck_load) {
            for (unsigned value = 0; value <= package_count; ++value) {
                out << next_index << ' ' << value << " loaded n" << value
                    << '\n';
            }
        } else {
            // Offset 14 skips "location_load" plus the following separator.
            const auto suffix = name.substr(14);
            for (unsigned value = 0; value <= package_count; ++value) {
                out << next_index << ' ' << value << " packages l" << suffix
                    << " n" << value << '\n';
            }
        }
        ++next_index;
    }

    out << "end-variables" << '\n' << std::flush;
}

/// Writes a PDDL problem for the "linetrack" domain to `out`.
///
/// The problem is derived solely from the number of locations (L) and
/// packages (P) read from `model`; the verification property is accepted but
/// not used. The emitted text consists of:
///  - objects l0..l(L-1) of type location and n1..nP of type integer,
///  - init facts: (agent-choice), (at l0), (loaded n0), all P packages at l0
///    and n0 packages everywhere else, CONNECTED between consecutive
///    locations, NEXT between consecutive integers n0..nP, LEQ n<i> n<j> for
///    1 <= j < P and 0 <= i <= j, and CAPACITY n1 for the last location and
///    nP for every other one,
///  - the goal (agent-choice) with P packages at the last location.
///
/// @param model JANI model the counts are read from.
/// @param out   Stream receiving the PDDL text (not flushed here).
void generate_pddl_problem(
    const jani::Model& model,
    const VerificationProperty&,
    std::ostream& out)
{
    const unsigned location_count = get_num_locations(model);
    const unsigned package_count = get_num_packages(model);

    // Every generated :init fact starts on a fresh line, indented by four.
    const char* const fact_prefix = "\n    (";

    // Problem header and object declarations.
    out << "(define (problem linetrack-l" << location_count << "-p"
        << package_count << ")\n";
    out << "  (:domain linetrack)\n";
    out << "  (:objects\n    ";
    for (unsigned loc = 0; loc < location_count; ++loc) {
        out << " l" << loc;
    }
    out << " - location\n    ";
    for (unsigned count = 1; count <= package_count; ++count) {
        out << " n" << count;
    }
    out << " - integer\n";
    out << "  )\n";

    // Fixed part of the initial state.
    out << "  (:init\n";
    out << "    (agent-choice)\n";
    out << "    (at l0)\n";
    out << "    (loaded n0)\n";
    out << "    (packages l0 n" << package_count << ")";

    // All locations but the first start without packages.
    for (unsigned loc = 1; loc < location_count; ++loc) {
        out << fact_prefix << "packages l" << loc << " n0)";
    }
    // Chain of consecutive locations.
    for (unsigned next = 1; next < location_count; ++next) {
        out << fact_prefix << "CONNECTED l" << (next - 1) << " l" << next
            << ")";
    }
    // Successor relation over the integer objects.
    for (unsigned value = 0; value < package_count; ++value) {
        out << fact_prefix << "NEXT n" << value << " n" << (value + 1) << ")";
    }
    // Ordering relation for upper values 1..P-1.
    for (unsigned upper = 1; upper < package_count; ++upper) {
        for (unsigned lower = 0; lower <= upper; ++lower) {
            out << fact_prefix << "LEQ n" << lower << " n" << upper << ")";
        }
    }
    // The last location gets capacity 1, every other one capacity P.
    for (unsigned loc = 0; loc < location_count; ++loc) {
        const bool is_last = (loc + 1 == location_count);
        const unsigned capacity = is_last ? 1u : package_count;
        out << fact_prefix << "CAPACITY l" << loc << " n" << capacity << ")";
    }

    out << "\n  )\n";
    out << "  (:goal (and (agent-choice) (packages l" << (location_count - 1)
        << " n" << package_count << ")))\n";
    out << ")\n";
}

/// Writes a JSON policy-interface description to `out`.
///
/// The object has empty "elements" and "file" entries, "filter" set to false,
/// an "input" array with one {"automaton": null, "name": <variable>} entry
/// per model variable except "last_capacity_diff", and an "output" array
/// listing every action name in model order. Names are written verbatim,
/// without JSON escaping.
///
/// @param model JANI model providing variable and action names.
/// @param out   Stream receiving the JSON text (not flushed here).
void generate_policy_interface(const jani::Model& model, std::ostream& out)
{
    out << "{\n";
    out << "  \"elements\": [],\n";
    out << "  \"file\": \"\",\n";
    out << "  \"filter\": false,\n";
    out << "  \"input\": [\n";

    // Empty before the first emitted entry, a comma + newline afterwards.
    const char* separator = "";
    for (unsigned id = 0; id < model.variables.size(); ++id) {
        const auto& name = model.variables.get_name(id);
        if (name == "last_capacity_diff") {
            continue;
        }
        out << separator;
        out << "    {\n";
        out << "      \"automaton\": null,\n";
        out << "      \"name\": \"" << name << "\"\n    }";
        separator = ",\n";
    }

    out << "\n  ],\n";
    out << "  \"output\": [\n";
    for (unsigned id = 0; id < model.action_names.size(); ++id) {
        const bool has_successor = id + 1 < model.action_names.size();
        out << "    \"" << model.action_names[id] << "\"";
        if (has_successor) {
            out << ",";
        }
        out << "\n";
    }
    out << "  ]\n}\n";
}

} // namespace

/// Entry point: parses the command line, then writes three files derived
/// from the parsed JANI model:
///  - "out-jani2pddl": operator and variable interface
///    (dump_actions + dump_state_interface),
///  - "out-pddl":      PDDL problem (generate_pddl_problem),
///  - "out-jani2nnet": JSON policy interface (generate_policy_interface).
///
/// @return 0 on success; 1 if an output file cannot be opened or written, in
///         which case a message naming the option and path goes to stderr.
int main(int argc, const char** argv)
{
    CommandLineParser cli;
    add_jani_options(cli);

    std::string model_interface;
    std::string pddl;
    std::string jani2nnet;
    cli.add_raw_argument(
        "out-jani2pddl",
        [&model_interface](const auto&, std::string_view path) {
            model_interface = path;
        },
        "",
        {},
        true);
    cli.add_raw_argument(
        "out-pddl",
        [&pddl](const auto&, std::string_view path) { pddl = path; },
        "",
        {},
        true);
    cli.add_raw_argument(
        "out-jani2nnet",
        [&jani2nnet](const auto&, std::string_view path) { jani2nnet = path; },
        "",
        {},
        true);

    auto args = cli.parse(argc, argv);

    // NOTE(review): assumes the parser guarantees both are set — confirm.
    const auto& model = *args.jani_model;
    const auto& prop = *args.property;

    // Reports an output-file problem and yields the failure exit code.
    const auto fail = [](const char* what,
                         const char* option,
                         const std::string& path) {
        std::cerr << "error: " << what << " " << option << " file '" << path
                  << "'\n";
        return 1;
    };

    // Not named `interface`: some platform headers define that as a macro.
    std::ofstream interface_out(model_interface);
    if (!interface_out) {
        return fail("cannot open", "out-jani2pddl", model_interface);
    }
    dump_actions(model, interface_out);
    dump_state_interface(model, interface_out);
    if (!interface_out.flush()) {
        return fail("cannot write", "out-jani2pddl", model_interface);
    }

    std::ofstream problem(pddl);
    if (!problem) {
        return fail("cannot open", "out-pddl", pddl);
    }
    generate_pddl_problem(model, prop, problem);
    if (!problem.flush()) {
        return fail("cannot write", "out-pddl", pddl);
    }

    std::ofstream pol(jani2nnet);
    if (!pol) {
        return fail("cannot open", "out-jani2nnet", jani2nnet);
    }
    generate_policy_interface(model, pol);
    if (!pol.flush()) {
        return fail("cannot write", "out-jani2nnet", jani2nnet);
    }

    return 0;
}
