#include "police/command_line_options.hpp"
#include "police/command_line_parser.hpp"
#include "police/global_arguments.hpp"
#include "police/jani/model.hpp"
#include "police/linear_condition.hpp"
#include "police/macros.hpp"
#include "police/verification_property.hpp"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iterator>
#include <ostream>
#include <string>
#include <string_view>

using namespace police;

namespace {
/// Counts the tile-position variables of the n-puzzle model.
///
/// A variable counts as a tile when its name starts with "tile_". This is the
/// same prefix dump_variables, dump_pddl_problem and dump_jani2nnet use to
/// recognise tile variables. The count therefore always matches the number of
/// tile variables those functions emit, even if the model contains unrelated
/// variables whose names merely begin with "tile".
unsigned get_num_tiles(const jani::Model& model)
{
    unsigned num_tiles = 0;
    for (unsigned var_id = 0; var_id < model.variables.size(); ++var_id) {
        if (model.variables.get_name(var_id).starts_with("tile_")) {
            ++num_tiles;
        }
    }
    return num_tiles;
}

/// Writes the operator section of the JANI-to-PDDL interface file.
///
/// The section starts with the number of JANI actions. It is followed by lines
/// between "begin-operators" and "end-operators", each of the form
/// "<action id> <pddl operator> <arguments...>". One JANI action may map to
/// several grounded PDDL operators.
///
/// Action names are parsed as a 5-character prefix followed by either
/// "<dir>_<tile>_<pos>" or "<dir>_<tile>", where <dir> is one of
/// left/right/up/down. The prefix is presumably "move_"; TODO confirm against
/// the model generator.
///  - If the full name contains three '_', the source position <pos> is
///    explicit. One operator moves the tile from linear position <pos> to the
///    adjacent linear position on a dimension x dimension board.
///  - Otherwise, one operator is emitted per 1-based (x, y) board coordinate
///    pair for which the move stays on the board.
///
/// Unknown directions, and names too short to contain a direction, are
/// reported via POLICE_INTERNAL_ERROR.
void dump_actions(const jani::Model& model, std::ostream& out)
{
    // Length of the action-name prefix that precedes the direction.
    constexpr std::size_t prefix_len = 5;

    const unsigned num_tiles = get_num_tiles(model);
    // The num_tiles + 1 board cells form a square. Round rather than truncate,
    // so that a floating-point result just below an integer cannot produce a
    // dimension that is too small.
    const unsigned dimension =
        static_cast<unsigned>(std::lround(std::sqrt(num_tiles + 1)));
    // Signed copy so that "pos - dimension" is evaluated in int. Mixing int
    // with unsigned would convert pos to unsigned and wrap for pos < dimension.
    const int signed_dimension = static_cast<int>(dimension);

    out << model.action_names.size() << "\n";
    out << "begin-operators" << "\n";
    for (unsigned action_id = 0; action_id < model.action_names.size();
         ++action_id) {
        const auto& name = model.action_names[action_id];
        if (name.size() <= prefix_len) {
            // name.begin() + prefix_len below would be out of range.
            POLICE_INTERNAL_ERROR("unknown action " << name);
            continue;
        }
        // Direction = characters after the prefix up to the next '_'.
        const auto dir_begin = name.begin() + prefix_len;
        const auto dir = name.substr(
            prefix_len,
            std::distance(dir_begin, std::find(dir_begin, name.end(), '_')));

        if (std::count(name.begin(), name.end(), '_') == 3) {
            // "<prefix><dir>_<tile>_<pos>": explicit linear source position.
            const auto tile_begin = prefix_len + dir.length() + 1;
            const auto last_sep = name.rfind('_');
            const auto tile_str =
                name.substr(tile_begin, last_sep - tile_begin);
            const int pos = std::stoi(name.substr(last_sep + 1));

            const char* op = nullptr;
            int target = 0;
            if (dir == "left") {
                op = "move-left";
                target = pos - 1;
            } else if (dir == "right") {
                op = "move-right";
                target = pos + 1;
            } else if (dir == "up") {
                op = "move-up";
                target = pos - signed_dimension;
            } else if (dir == "down") {
                op = "move-down";
                target = pos + signed_dimension;
            } else {
                POLICE_INTERNAL_ERROR("unknown action " << name);
                continue;
            }
            out << action_id << " " << op << " t_" << tile_str << " p_" << pos
                << " p_" << target << "\n";
        } else {
            // "<prefix><dir>_<tile>": expand over all 1-based board
            // coordinates where the move stays on the board.
            const auto tile_str = name.substr(name.rfind('_') + 1);
            if (dir == "left") {
                for (unsigned y = 1; y <= dimension; ++y) {
                    for (unsigned x = 1; x + 1 <= dimension; ++x) {
                        out << action_id << " move-left t_" << tile_str << " p_"
                            << (x + 1) << " p_" << x << " p_" << y << "\n";
                    }
                }
            } else if (dir == "right") {
                for (unsigned y = 1; y <= dimension; ++y) {
                    for (unsigned x = 1; x + 1 <= dimension; ++x) {
                        out << action_id << " move-right t_" << tile_str
                            << " p_" << x << " p_" << (x + 1) << " p_" << y
                            << "\n";
                    }
                }
            } else if (dir == "up") {
                for (unsigned x = 1; x <= dimension; ++x) {
                    for (unsigned y = 1; y + 1 <= dimension; ++y) {
                        out << action_id << " move-up t_" << tile_str << " p_"
                            << x << " p_" << (y + 1) << " p_" << y << "\n";
                    }
                }
            } else if (dir == "down") {
                for (unsigned x = 1; x <= dimension; ++x) {
                    for (unsigned y = 1; y + 1 <= dimension; ++y) {
                        out << action_id << " move-down t_" << tile_str << " p_"
                            << x << " p_" << y << " p_" << (y + 1) << "\n";
                    }
                }
            } else {
                POLICE_INTERNAL_ERROR("unknown action " << name);
            }
        }
    }
    out << "end-operators" << "\n" << std::flush;
}

/// Writes the variable section of the JANI-to-PDDL interface file.
///
/// The section starts with num_tiles + 1 (one variable per tile plus the
/// blank), then "begin-variables" / "end-variables". Model variables named
/// "empty" or starting with "tile_" receive consecutive interface ids in model
/// order. For each value p in [0, num_tiles], one line is printed:
///   "<id> <p> empty p_<p>"        for the blank, or
///   "<id> <p> at t_<tile> p_<p>"  for a tile (<tile> = name after "tile_").
/// All other model variables are skipped and consume no interface id.
void dump_variables(const jani::Model& model, std::ostream& out)
{
    const int max_pos = get_num_tiles(model);
    out << (max_pos + 1) << "\n";
    out << "begin-variables" << "\n";

    unsigned next_id = 0;
    for (unsigned idx = 0; idx < model.variables.size(); ++idx) {
        const auto& name = model.variables.get_name(idx);
        const bool is_blank = name == "empty";
        if (!is_blank && !name.starts_with("tile_")) {
            continue;
        }
        for (int p = 0; p <= max_pos; ++p) {
            out << next_id << " " << p;
            if (is_blank) {
                out << " empty p_";
            } else {
                out << " at t_" << name.substr(5) << " p_";
            }
            out << p << "\n";
        }
        ++next_id;
    }

    out << "end-variables" << "\n" << std::flush;
}

/// Writes a PDDL problem file for the n-puzzle model.
///
/// - Objects: tiles t_1..t_<num_tiles> and positions p_0..p_<num_tiles>.
/// - Init: the blank is at p_0 and tile i is at p_i. LEFT and ABOVE adjacency
///   facts are added for a dimension x dimension board numbered in row-major
///   order.
/// - Goal: taken from property.reach. It must be a single conjunction of
///   constraints "var == value" with coefficient 1, where var is "empty" or a
///   "tile_*" variable. Any other shape raises POLICE_RUNTIME_ERROR.
void dump_pddl_problem(
    const jani::Model& model,
    const VerificationProperty& property,
    std::ostream& out)
{
    const unsigned num_tiles = get_num_tiles(model);
    // num_tiles + 1 cells form a square board; round rather than truncate.
    const unsigned dimension =
        static_cast<unsigned>(std::lround(std::sqrt(num_tiles + 1)));
    out << "(define (problem npuzzle-n" << num_tiles << ")\n"
        << "  (:domain n-puzzle-typed)\n"
        << "  (:objects\n   ";
    // Every tile and position referenced in :init / :goal must be declared,
    // so no object may be skipped here.
    for (unsigned i = 1; i <= num_tiles; ++i) {
        out << " t_" << i;
    }
    out << " - tile\n   ";
    for (unsigned i = 0; i <= num_tiles; ++i) {
        out << " p_" << i;
    }
    out << " - position";
    out << "\n"
        << "  )\n"
        << "  (:init"
        << "\n    (empty p_0)";
    for (unsigned tile = 1; tile <= num_tiles; ++tile) {
        out << "\n    (at t_" << tile << " p_" << tile << ")";
    }
    // Horizontal neighbours: cell (x, y) is LEFT of cell (x + 1, y).
    for (unsigned y = 0; y < dimension; ++y) {
        for (unsigned x = 0; x + 1 < dimension; ++x) {
            out << "\n    (LEFT p_" << (y * dimension + x) << " p_"
                << (y * dimension + x + 1) << ")";
        }
    }
    // Vertical neighbours: cell (x, y) is ABOVE cell (x, y + 1).
    for (unsigned y = 0; y + 1 < dimension; ++y) {
        for (unsigned x = 0; x < dimension; ++x) {
            out << "\n    (ABOVE p_" << (y * dimension + x) << " p_"
                << ((y + 1) * dimension + x) << ")";
        }
    }
    // NOTE(review): NEXT facts are currently not emitted; re-enable if the
    // domain requires them.
    // for (unsigned i = 0; i + 1 <= num_tiles; ++i) {
    //     out << "\n    (NEXT p_" << i << " p_" << (i + 1) << ")";
    // }
    out << "\n  )\n"
        << "  (:goal (and";
    auto goal = LinearCondition::from_expression(property.reach);
    if (goal.size() != 1ul) {
        POLICE_RUNTIME_ERROR("expected conjunctive goal");
    }
    for (const auto& constraint : goal.front()) {
        if (constraint.size() != 1ul) {
            POLICE_RUNTIME_ERROR(
                "goal constraint references more than one variable");
        }
        const auto coef = constraint.coefs()[0];
        if (coef != 1.) {
            POLICE_RUNTIME_ERROR("invalid goal constraint: coefficient not 1");
        }
        if (constraint.type != LinearConstraint::EQUAL) {
            POLICE_RUNTIME_ERROR("invalid goal constraint: type not equal");
        }
        const auto var_id = constraint.refs()[0];
        const auto& var_name = model.variables.get_name(var_id);
        if (var_name == "empty") {
            const int value = constraint.rhs;
            out << "\n    (empty p_" << value << ")";
            continue;
        } else if (!var_name.starts_with("tile_")) {
            POLICE_RUNTIME_ERROR(
                "invalid goal constraint: variable "
                << var_name << " is not tile position");
        }
        const int value = constraint.rhs;
        out << "\n    (at t_" << var_name.substr(5) << " p_" << value << ")";
    }
    out << "\n  ))\n"
        << ")\n";
}

/// Writes the jani2nnet JSON descriptor that maps network inputs and outputs.
///
/// - "input" lists, in model order, every variable named "empty" or starting
///   with "tile_", each as {"automaton": null, "name": <variable name>}.
/// - "output" lists every action name in action-id order.
/// Names are written verbatim, without JSON escaping.
void dump_jani2nnet(const jani::Model& model, std::ostream& out)
{
    out << "{\n"
        << "  \"elements\": [],\n"
        << "  \"file\": \"\",\n"
        << "  \"filter\": false,\n"
        << "  \"input\": [\n";

    // Empty before the first entry; every later entry is preceded by ",\n".
    const char* separator = "";
    for (unsigned idx = 0; idx < model.variables.size(); ++idx) {
        const auto& var = model.variables.get_name(idx);
        if (!(var == "empty") && !var.starts_with("tile_")) {
            continue;
        }
        out << separator << "    {\n"
            << "      \"automaton\": null,\n"
            << "      \"name\": \"" << var << "\"\n    }";
        separator = ",\n";
    }

    out << "\n  ],\n" << "  \"output\": [\n";
    const auto num_actions = model.action_names.size();
    for (unsigned idx = 0; idx < num_actions; ++idx) {
        const bool is_last = idx + 1 == num_actions;
        out << "    \"" << model.action_names[idx] << "\""
            << (is_last ? "" : ",") << "\n";
    }
    out << "  ]\n}\n";
}

} // namespace

int main(int argc, const char** argv)
{
    CommandLineParser cli;
    add_jani_options(cli);

    std::string model_interface, pddl, jani2nnet;
    cli.add_raw_argument(
        "out-jani2pddl",
        [&model_interface](const auto&, std::string_view path) {
            model_interface = path;
        },
        "",
        {},
        true);
    cli.add_raw_argument(
        "out-pddl",
        [&pddl](const auto&, std::string_view path) { pddl = path; },
        "",
        {},
        true);
    cli.add_raw_argument(
        "out-jani2nnet",
        [&jani2nnet](const auto&, std::string_view path) { jani2nnet = path; },
        "",
        {},
        true);

    auto args = cli.parse(argc, argv);

    const auto& model = *args.jani_model;
    const auto& prop = *args.property;

    std::ofstream interface;
    interface.open(model_interface.c_str());
    dump_actions(model, interface);
    dump_variables(model, interface);

    std::ofstream problem;
    problem.open(pddl.c_str());
    dump_pddl_problem(model, prop, problem);

    std::ofstream pol;
    pol.open(jani2nnet.c_str());
    dump_jani2nnet(model, pol);

    return 0;
}
