#include "police/command_line_options.hpp"
#include "police/command_line_parser.hpp"
#include "police/global_arguments.hpp"
#include "police/jani/model.hpp"
#include "police/linear_condition.hpp"
#include "police/macros.hpp"
#include "police/verification_property.hpp"

#include <fstream>
#include <ostream>
#include <string>
#include <string_view>

using namespace police;

namespace {
/// Counts the blocks of the model.
///
/// Every model variable whose name starts with the prefix "block" is counted
/// once (e.g. the `block_<i>` position variables).
unsigned get_num_blocks(const jani::Model& model)
{
    const auto num_vars = model.variables.size();
    unsigned num_blocks = 0;
    for (unsigned i = 0; i < num_vars; ++i) {
        if (model.variables.get_name(i).starts_with("block")) {
            ++num_blocks;
        }
    }
    return num_blocks;
}

#if 0
// Legacy variant for the "jani-blocks" PDDL domain, compiled out by the
// enclosing `#if 0`. In that domain, table operators carry table-counter
// arguments n<k> -> n<k+1> (put) or n<k> -> n<k-1> (pickup).
//
// Writes the number of JANI actions, then, between "begin-operators" and
// "end-operators", one "<action_id> <operator> <args...>" line per ground
// PDDL operator that the JANI action is mapped to.
//
// NOTE(review): `num_blocks - 1` wraps around if the model has no block
// variables. Unknown action names are only caught by the `assert` (compiled
// out under NDEBUG) and would otherwise be parsed as put_on_block.
void dump_action_interface(const jani::Model& model, std::ostream& out)
{
    const unsigned num_blocks = get_num_blocks(model);
    out << model.action_names.size() << "\n";
    out << "begin-operators" << "\n";
    for (unsigned action_id = 0; action_id < model.action_names.size();
         ++action_id) {
        if (model.action_names[action_id] == "choose_table") {
            for (unsigned block = 0; block < num_blocks; ++block) {
                // Counter transitions n<k> -> n<k+1> up to n<num_blocks - 1>
                // use put-on-table ...
                for (unsigned on_table = 0; on_table < num_blocks - 1;
                     ++on_table) {
                    out << action_id << " put-on-table b" << block << " n"
                        << on_table << " n" << (on_table + 1) << "\n";
                }
                // ... while the transition into n<num_blocks> uses the
                // dangerous variant.
                out << action_id << " dangerous-put-on-table b" << block << " n"
                    << (num_blocks - 1) << " n" << (num_blocks) << "\n";
            }
        } else if (model.action_names[action_id].starts_with("choose_block_")) {
            // 13 == strlen("choose_block_")
            const unsigned block_id =
                std::stoi(model.action_names[action_id].substr(13));
            for (unsigned on_table = 1; on_table <= num_blocks; ++on_table) {
                out << action_id << " pickup-from-table b" << block_id << " n"
                    << on_table << " n" << (on_table - 1) << "\n";
            }
            for (unsigned block_id2 = 0; block_id2 < num_blocks; ++block_id2) {
                if (block_id == block_id2) continue;
                out << action_id << " pickup-from-block b" << block_id << " b"
                    << block_id2 << "\n";
            }
            for (unsigned block_id2 = 0; block_id2 < num_blocks; ++block_id2) {
                if (block_id == block_id2) continue;
                out << action_id << " put-on-block b" << block_id2 << " b"
                    << block_id << "\n";
            }
        } else {
            const auto& name = model.action_names[action_id];
            if (name.starts_with("pickup_from_table_b")) {
                const unsigned block_id =
                    std::stoi(model.action_names[action_id].substr(19));
                for (unsigned on_table = 1; on_table <= num_blocks;
                     ++on_table) {
                    out << action_id << " pickup-from-table b" << block_id
                        << " n" << on_table << " n" << (on_table - 1) << "\n";
                }
            } else if (name.starts_with("put_on_table_b")) {
                const unsigned block_id =
                    std::stoi(model.action_names[action_id].substr(14));
                for (unsigned on_table = 0; on_table < num_blocks - 1;
                     ++on_table) {
                    out << action_id << " put-on-table b" << block_id << " n"
                        << on_table << " n" << (on_table + 1) << "\n";
                }
                out << action_id << " dangerous-put-on-table b" << block_id
                    << " n" << (num_blocks - 1) << " n" << (num_blocks) << "\n";
            } else if (name.starts_with("pickup_from_block")) {
                // "pickup_from_block_b<i>_b<j>": <i> lies between the
                // 19-character prefix and the last '_', <j> follows "_b".
                auto pos = name.find_last_of('_');
                const unsigned b0 = std::stoi(name.substr(19, pos - 19));
                const unsigned b1 = std::stoi(name.substr(pos + 2));
                out << action_id << " pickup-from-block b" << b0 << " b" << b1
                    << "\n";
            } else {
                // "put_on_block_b<i>_b<j>" (14-character prefix).
                assert(name.starts_with("put_on_block"));
                auto pos = name.find_last_of('_');
                const unsigned b0 = std::stoi(name.substr(14, pos - 14));
                const unsigned b1 = std::stoi(name.substr(pos + 2));
                out << action_id << " put-on-block b" << b0 << " b" << b1
                    << "\n";
            }
        }
    }
    out << "end-operators" << "\n" << std::flush;
}

// Legacy variant for the "jani-blocks" PDDL domain, compiled out by the
// enclosing `#if 0`.
//
// Writes the variable part of the interface. The output is:
// - a header with the number of exposed variables;
// - "begin-variables";
// - one "<interface_var> <value> <pddl atom>" line per value/atom pair;
// - "end-variables".
// Unlike the active variant, table-counter is exposed, with one value per
// counter object n<k>, and unknown variable names raise a runtime error.
void dump_state_interface(const jani::Model& model, std::ostream& out)
{
    const unsigned num_blocks = get_num_blocks(model);
    // ignore cost variables
    // NOTE(review): this subtracts num_blocks + 1 cost variables, i.e. it
    // assumes exactly one step_cost plus one cost_block_<i> per block; the
    // header is wrong otherwise — TODO confirm against the model generator.
    out << (model.variables.size() - num_blocks - 1) << "\n";
    out << "begin-variables" << "\n";
    for (unsigned var_id = 0, interface_var = 0;
         var_id < model.variables.size();
         ++var_id) {
        const auto& var_name = model.variables.get_name(var_id);
        if (var_name == "hand-empty") {
            out << interface_var << " " << 1 << " " << "hand-empty" << "\n";
        } else if (var_name == "table-counter") {
            for (unsigned on_table = 0; on_table <= num_blocks; ++on_table) {
                out << interface_var << " " << on_table << " "
                    << "table-counter n" << on_table << "\n";
            }
        } else if (var_name.starts_with("block_")) {
            // Values: 0 = holding, 1 = on-table, 2.. = on each other block in
            // increasing index order (the block itself is skipped).
            const unsigned block_id = std::stoi(var_name.substr(6));
            out << interface_var << " " << 0 << " " << "holding b" << block_id
                << "\n";
            out << interface_var << " " << 1 << " " << "on-table b" << block_id
                << "\n";
            for (unsigned block_id2 = 0, value = 2; block_id2 < num_blocks;
                 ++block_id2) {
                if (block_id == block_id2) continue;
                out << interface_var << " " << value << " on b" << block_id
                    << " b" << block_id2 << "\n";
                ++value;
            }
        } else if (var_name.starts_with("clear_")) {
            const unsigned block_id = std::stoi(var_name.substr(6));
            out << interface_var << " " << 1 << " " << "clear b" << block_id
                << "\n";
        } else if (
            var_name == "step_cost" || var_name.starts_with("cost_block_")) {
            // Cost variables are not exposed and do not consume an index.
            continue;
        } else {
            POLICE_RUNTIME_ERROR("unknown variable " << var_name);
        }
        ++interface_var;
    }
    out << "end-variables" << "\n" << std::flush;
}

// Legacy variant for the "jani-blocks" PDDL domain, compiled out by the
// enclosing `#if 0`.
//
// Writes a PDDL problem with the following initial state:
// - the tower b0 on b1 on ... on b<n-1>, with b<n-1> on the table;
// - table-counter n1;
// - the static NEXT facts;
// - the TABLE-CAN-HOLD / TABLE-CANNOT-HOLD facts.
// The goal conjoins the atoms encoded by property.reach with
// (table-within-load). property.reach must be a single conjunction of
// `block_<i> == v` constraints, where v = 0 means holding, 1 means on-table,
// and 2.. means on one of the other blocks.
//
// NOTE(review): `num_blocks - 1` wraps around if the model has no block
// variables, and goal values/block indices are not range-checked.
void generate_pddl_problem(
    const jani::Model& model,
    const VerificationProperty& property,
    std::ostream& out)
{
    const unsigned num_blocks = get_num_blocks(model);
    out << "(define (problem jani-blocks-nblk" << num_blocks << ")\n"
        << "  (:domain jani-blocks)\n"
        << "  (:objects";
    for (unsigned b = 0; b < num_blocks; ++b) {
        out << " b" << b;
    }
    out << " - block\n"
        << "           ";
    for (unsigned n = 0; n <= num_blocks; ++n) {
        out << " n" << n;
    }
    out << " - integer\n"
        << "  )\n"
        << "  (:init\n"
        << "    (on-table b" << (num_blocks - 1) << ")\n";
    for (unsigned b = 0; b + 1 < num_blocks; ++b) {
        out << "    (on b" << b << " b" << (b + 1) << ")\n";
    }
    out << "    (clear b" << (0) << ")\n"
        << "    (table-counter n1)\n"
        << "    (table-within-load)\n"
        << "    (hand-empty)\n";
    for (unsigned n = 1; n <= num_blocks; ++n) {
        out << "    (NEXT n" << (n - 1) << " n" << n << ")\n";
    }
    for (unsigned n = 0; n < num_blocks; ++n) {
        out << "    (TABLE-CAN-HOLD n" << n << ")\n";
    }
    out << "    (TABLE-CANNOT-HOLD n" << (num_blocks) << ")\n"
        << "  )\n"
        << "  (:goal (and\n";
    auto goal = LinearCondition::from_expression(property.reach);
    if (goal.size() != 1ul) {
        POLICE_RUNTIME_ERROR("expected conjunctive goal");
    }
    for (const auto& constraint : goal.front()) {
        if (constraint.size() != 1ul) {
            POLICE_RUNTIME_ERROR(
                "goal constraint references more than one variable");
        }
        const auto coef = constraint.coefs()[0];
        if (coef != 1.) {
            POLICE_RUNTIME_ERROR("invalid goal constraint: coefficient not 1");
        }
        if (constraint.type != LinearConstraint::EQUAL) {
            POLICE_RUNTIME_ERROR("invalid goal constraint: type not equal");
        }
        const auto var_id = constraint.refs()[0];
        const auto& var_name = model.variables.get_name(var_id);
        if (!var_name.starts_with("block_")) {
            POLICE_RUNTIME_ERROR(
                "invalid goal constraint: variable is not block position");
        }
        const int block = std::stoi(var_name.substr(6));
        const int value = constraint.rhs;
        if (value == 0) {
            out << "    (holding b" << block << ")\n";
        } else if (value == 1) {
            out << "    (on-table b" << block << ")\n";
        } else {
            // Values 2.. enumerate the other blocks in increasing order,
            // skipping `block` itself.
            if (value - 2 < block) {
                out << "    (on b" << block << " b" << (value - 2) << ")\n";
            } else {
                out << "    (on b" << block << " b" << (value - 1) << ")\n";
            }
        }
    }
    out << "    (table-within-load)\n"
        << "  ))\n"
        << ")\n";
}

#else

/// Writes the action part of the JANI-to-PDDL interface for the classical
/// `blocks` domain.
///
/// Output: the number of JANI actions, then one line
/// `<action_id> <operator> <args...>` per ground PDDL operator (pick-up,
/// put-down, unstack, stack) the JANI action is mapped to. These lines are
/// enclosed in `begin-operators` / `end-operators`.
///
/// Recognized JANI action names:
///  - `choose_table`: put-down of every block;
///  - `choose_block_<i>`: pick-up b<i>, unstack b<i> from every other block,
///    and stack every other block onto b<i>;
///  - `pickup_from_table_b<i>`, `put_on_table_b<i>`,
///    `pickup_from_block_b<i>_b<j>`, `put_on_block_b<i>_b<j>`: exactly one
///    operator each.
///
/// Any other action name raises a runtime error (POLICE_RUNTIME_ERROR). An
/// `assert` alone would be compiled out under NDEBUG, letting the name be
/// mis-parsed as put_on_block.
void dump_action_interface(const jani::Model& model, std::ostream& out)
{
    const unsigned num_blocks = get_num_blocks(model);
    out << model.action_names.size() << "\n";
    out << "begin-operators" << "\n";
    for (unsigned action_id = 0; action_id < model.action_names.size();
         ++action_id) {
        const auto& name = model.action_names[action_id];
        if (name == "choose_table") {
            for (unsigned block = 0; block < num_blocks; ++block) {
                out << action_id << " put-down b" << block << "\n";
            }
        } else if (name.starts_with("choose_block_")) {
            // 13 == strlen("choose_block_")
            const unsigned block_id = std::stoi(name.substr(13));
            out << action_id << " pick-up b" << block_id << "\n";
            for (unsigned block_id2 = 0; block_id2 < num_blocks; ++block_id2) {
                if (block_id == block_id2) continue;
                out << action_id << " unstack b" << block_id << " b"
                    << block_id2 << "\n";
            }
            for (unsigned block_id2 = 0; block_id2 < num_blocks; ++block_id2) {
                if (block_id == block_id2) continue;
                out << action_id << " stack b" << block_id2 << " b" << block_id
                    << "\n";
            }
        } else if (name.starts_with("pickup_from_table_b")) {
            const unsigned block_id = std::stoi(name.substr(19));
            out << action_id << " pick-up b" << block_id << "\n";
        } else if (name.starts_with("put_on_table_b")) {
            const unsigned block_id = std::stoi(name.substr(14));
            out << action_id << " put-down b" << block_id << "\n";
        } else if (name.starts_with("pickup_from_block")) {
            // "pickup_from_block_b<i>_b<j>": <i> lies between the 19-character
            // prefix and the last '_', <j> follows the trailing "_b".
            const auto pos = name.find_last_of('_');
            const unsigned b0 = std::stoi(name.substr(19, pos - 19));
            const unsigned b1 = std::stoi(name.substr(pos + 2));
            out << action_id << " unstack b" << b0 << " b" << b1 << "\n";
        } else if (name.starts_with("put_on_block")) {
            // "put_on_block_b<i>_b<j>" (14-character prefix).
            const auto pos = name.find_last_of('_');
            const unsigned b0 = std::stoi(name.substr(14, pos - 14));
            const unsigned b1 = std::stoi(name.substr(pos + 2));
            out << action_id << " stack b" << b0 << " b" << b1 << "\n";
        } else {
            POLICE_RUNTIME_ERROR("unknown action " << name);
        }
    }
    out << "end-operators" << "\n" << std::flush;
}

/// Writes the variable part of the JANI-to-PDDL interface for the classical
/// `blocks` domain.
///
/// Only the JANI variables `hand-empty`, `block_<i>` and `clear_<i>` are
/// exposed; every other variable (e.g. cost or counter variables) is skipped
/// and does not consume an interface index.
///
/// Output:
/// - the number of exposed variables;
/// - `begin-variables`;
/// - one line `<interface_var> <value> <pddl atom>` per value/atom pair;
/// - `end-variables`.
///
/// For `block_<i>`, value 0 maps to `holding b<i>`, value 1 to
/// `ontable b<i>`, and values 2.. to `on b<i> b<j>` for each other block j
/// in increasing index order.
void dump_state_interface(const jani::Model& model, std::ostream& out)
{
    const unsigned num_blocks = get_num_blocks(model);

    // The single filter that decides which variables are exposed.
    const auto is_interface_variable = [](const auto& var_name) {
        return var_name == "hand-empty" || var_name.starts_with("block_") ||
               var_name.starts_with("clear_");
    };

    // Count the exposed variables with the same filter used below, so the
    // header always agrees with the emitted entries.
    unsigned num_interface_vars = 0;
    for (unsigned var_id = 0; var_id < model.variables.size(); ++var_id) {
        if (is_interface_variable(model.variables.get_name(var_id))) {
            ++num_interface_vars;
        }
    }

    out << num_interface_vars << "\n";
    out << "begin-variables" << "\n";
    unsigned interface_var = 0;
    for (unsigned var_id = 0; var_id < model.variables.size(); ++var_id) {
        const auto& var_name = model.variables.get_name(var_id);
        if (!is_interface_variable(var_name)) {
            continue;
        }
        if (var_name == "hand-empty") {
            out << interface_var << " " << 1 << " " << "handempty" << "\n";
        } else if (var_name.starts_with("block_")) {
            const unsigned block_id = std::stoi(var_name.substr(6));
            out << interface_var << " " << 0 << " " << "holding b" << block_id
                << "\n";
            out << interface_var << " " << 1 << " " << "ontable b" << block_id
                << "\n";
            for (unsigned block_id2 = 0, value = 2; block_id2 < num_blocks;
                 ++block_id2) {
                if (block_id == block_id2) continue;
                out << interface_var << " " << value << " on b" << block_id
                    << " b" << block_id2 << "\n";
                ++value;
            }
        } else {
            // Remaining case of the filter: clear_<i>.
            const unsigned block_id = std::stoi(var_name.substr(6));
            out << interface_var << " " << 1 << " " << "clear b" << block_id
                << "\n";
        }
        ++interface_var;
    }
    out << "end-variables" << "\n" << std::flush;
}

/// Writes a PDDL problem for the classical `blocks` domain that mirrors the
/// JANI model and the reachability goal of `property`.
///
/// Initial state: a single tower b0 on b1 on ... on b<n-1>, with b<n-1> on
/// the table, b0 clear and the hand empty.
///
/// Goal: built from `property.reach`, which must be one conjunction of
/// constraints `block_<i> == v`. The value v encodes the position of block i:
/// 0 = held, 1 = on the table, 2..n = on one of the other blocks, enumerated
/// in increasing index order with block i itself skipped.
///
/// Raises a runtime error (POLICE_RUNTIME_ERROR) in these cases:
/// - the model has no block variables;
/// - the goal has any other shape;
/// - a constraint refers to an out-of-range block index or position value.
void generate_pddl_problem(
    const jani::Model& model,
    const VerificationProperty& property,
    std::ostream& out)
{
    const unsigned num_blocks = get_num_blocks(model);
    if (num_blocks == 0) {
        // The initial tower refers to b<num_blocks - 1>, which would wrap
        // around to a huge unsigned value for an empty model.
        POLICE_RUNTIME_ERROR("model does not contain any block variables");
    }
    out << "(define (problem blocks-nblk" << num_blocks << ")\n"
        << "  (:domain blocks)\n"
        << "  (:objects";
    for (unsigned b = 0; b < num_blocks; ++b) {
        out << " b" << b;
    }
    out << "\n"
        << "  )\n"
        << "  (:init\n"
        << "    (ontable b" << (num_blocks - 1) << ")\n";
    for (unsigned b = 0; b + 1 < num_blocks; ++b) {
        out << "    (on b" << b << " b" << (b + 1) << ")\n";
    }
    out << "    (clear b" << (0) << ")\n"
        << "    (handempty)\n"
        << "  )\n"
        << "  (:goal (and\n";
    auto goal = LinearCondition::from_expression(property.reach);
    if (goal.size() != 1ul) {
        POLICE_RUNTIME_ERROR("expected conjunctive goal");
    }
    // Both the largest block index + 1 and the largest position value.
    const int max_value = static_cast<int>(num_blocks);
    for (const auto& constraint : goal.front()) {
        if (constraint.size() != 1ul) {
            POLICE_RUNTIME_ERROR(
                "goal constraint references more than one variable");
        }
        const auto coef = constraint.coefs()[0];
        if (coef != 1.) {
            POLICE_RUNTIME_ERROR("invalid goal constraint: coefficient not 1");
        }
        if (constraint.type != LinearConstraint::EQUAL) {
            POLICE_RUNTIME_ERROR("invalid goal constraint: type not equal");
        }
        const auto var_id = constraint.refs()[0];
        const auto& var_name = model.variables.get_name(var_id);
        if (!var_name.starts_with("block_")) {
            POLICE_RUNTIME_ERROR(
                "invalid goal constraint: variable is not block position");
        }
        const int block = std::stoi(var_name.substr(6));
        if (block < 0 || block >= max_value) {
            POLICE_RUNTIME_ERROR(
                "invalid goal constraint: block index of " << var_name
                                                           << " out of range");
        }
        const int value = constraint.rhs;
        if (value < 0 || value > max_value) {
            POLICE_RUNTIME_ERROR(
                "invalid goal constraint: position value "
                << value << " of " << var_name << " out of range");
        }
        if (value == 0) {
            out << "    (holding b" << block << ")\n";
        } else if (value == 1) {
            out << "    (ontable b" << block << ")\n";
        } else {
            // Values 2.. enumerate the other blocks in increasing order,
            // skipping `block` itself.
            const int below = (value - 2 < block) ? value - 2 : value - 1;
            out << "    (on b" << block << " b" << below << ")\n";
        }
    }
    out << "  ))\n"
        << ")\n";
}

#endif

/// Writes `text` to `out` as a JSON string literal, including the enclosing
/// quotes. Quotes, backslashes and control characters are escaped as required
/// by RFC 8259, so arbitrary model identifiers yield valid JSON.
void write_json_string(std::ostream& out, std::string_view text)
{
    static constexpr char hex_digits[] = "0123456789abcdef";
    out << '"';
    for (const char c : text) {
        const auto uc = static_cast<unsigned char>(c);
        if (c == '"' || c == '\\') {
            out << '\\' << c;
        } else if (c == '\n') {
            out << "\\n";
        } else if (c == '\r') {
            out << "\\r";
        } else if (c == '\t') {
            out << "\\t";
        } else if (uc < 0x20) {
            out << "\\u00" << hex_digits[uc >> 4] << hex_digits[uc & 0xF];
        } else {
            out << c;
        }
    }
    out << '"';
}

/// Writes the jani2nnet JSON interface description of the policy.
///
/// "input" lists the names of the policy's input variables: every model
/// variable except `step_cost`, `cost_block_*` and `table-counter`, in model
/// order. "output" lists all action names in action-id order. All names are
/// emitted as properly escaped JSON strings.
void generate_policy_interface(const jani::Model& model, std::ostream& out)
{
    out << "{\n"
        << "  \"elements\": [],\n"
        << "  \"file\": \"\",\n"
        << "  \"filter\": false,\n"
        << "  \"input\": [\n";
    unsigned interface_var = 0;
    for (unsigned var_id = 0; var_id < model.variables.size(); ++var_id) {
        const auto& var_name = model.variables.get_name(var_id);
        if (var_name == "step_cost" || var_name.starts_with("cost_block_") ||
            var_name == "table-counter") {
            continue;
        }
        // Entries are comma-separated; no trailing comma after the last one.
        out << (interface_var > 0 ? ",\n" : "") << "    {\n"
            << "      \"automaton\": null,\n"
            << "      \"name\": ";
        write_json_string(out, var_name);
        out << "\n    }";
        ++interface_var;
    }
    out << "\n  ],\n" << "  \"output\": [\n";
    const auto num_actions = model.action_names.size();
    for (unsigned action_id = 0; action_id < num_actions; ++action_id) {
        out << "    ";
        write_json_string(out, model.action_names[action_id]);
        out << (action_id + 1 < num_actions ? "," : "") << "\n";
    }
    out << "  ]\n}\n";
}

} // namespace

int main(int argc, const char** argv)
{
    CommandLineParser cli;
    add_jani_options(cli);

    std::string model_interface, pddl, jani2nnet;
    cli.add_raw_argument(
        "out-jani2pddl",
        [&model_interface](const auto&, std::string_view path) {
            model_interface = path;
        },
        "",
        {},
        true);
    cli.add_raw_argument(
        "out-pddl",
        [&pddl](const auto&, std::string_view path) { pddl = path; },
        "",
        {},
        true);
    cli.add_raw_argument(
        "out-jani2nnet",
        [&jani2nnet](const auto&, std::string_view path) { jani2nnet = path; },
        "",
        {},
        true);

    auto args = cli.parse(argc, argv);

    const auto& model = *args.jani_model;
    const auto& prop = *args.property;

    std::ofstream interface;
    interface.open(model_interface.c_str());
    dump_action_interface(model, interface);
    dump_state_interface(model, interface);

    std::ofstream problem;
    problem.open(pddl.c_str());
    generate_pddl_problem(model, prop, problem);

    std::ofstream pol;
    pol.open(jani2nnet.c_str());
    generate_policy_interface(model, pol);

    return 0;
}
