#include "police/sas/parser.hpp"
#include "police/action.hpp"
#include "police/base_types.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/negation.hpp"
#include "police/linear_condition.hpp"
#include "police/linear_constraint.hpp"
#include "police/linear_expression.hpp"
#include "police/macros.hpp"
#include "police/verification_property.hpp"
#include <algorithm>
#include <cctype>
#include <charconv>
#include <cwctype>
#include <fstream>
#include <istream>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <system_error>
#include <utility>

namespace police::sas {

namespace {

/// Reads the next line of `in` and aborts with an invalid-input error unless
/// it equals `magic` exactly (e.g. "begin_operator").
void consume_magic(std::istream& in, const std::string& magic)
{
    std::string line;
    const bool matched =
        static_cast<bool>(std::getline(in, line)) && line == magic;
    if (matched) {
        return;
    }
    POLICE_EXIT_INVALID_INPUT("Expected " << magic << "; but read " << line);
}

int getint(std::istream& in)
{
    std::string l;
    if (!std::getline(in, l)) {
        POLICE_EXIT_INVALID_INPUT("Expected int but reached end of file");
    }
    return std::stoi(l);
}

/// Returns the next line of `in` verbatim (newline stripped by getline);
/// aborts with an invalid-input error at end of file.
std::string getstr(std::istream& in)
{
    std::string line;
    if (std::getline(in, line)) {
        return line;
    }
    POLICE_EXIT_INVALID_INPUT("Expected string but reached end of file");
    return line;
}

/// Skips a `begin_<block_name>` ... `end_<block_name>` section without
/// interpreting its contents.
///
/// Aborts with an invalid-input error if the next line is not the opening
/// marker, or if the input ends before the closing marker is found.
void skip_block(std::istream& in, const std::string& block_name)
{
    const std::string opening = std::string("begin_") + block_name;
    const std::string closing = "end_" + block_name;
    std::string line;
    const bool opened =
        static_cast<bool>(std::getline(in, line)) && line == opening;
    if (!opened) {
        POLICE_EXIT_INVALID_INPUT(
            "Expected begin_" << block_name << "; but read " << line);
    }
    for (;;) {
        if (!std::getline(in, line)) {
            break;
        }
        if (line == closing) {
            return;
        }
    }
    POLICE_EXIT_INVALID_INPUT("Missing end_" << block_name);
}

void parse_variable(std::istream& in, Model& model)
{
    consume_magic(in, "begin_variable");
    auto var_name = getstr(in);
    getint(in);
    auto num_values = getint(in);
    vector<identifier_name_t> val_names(num_values);
    for (int val = 0; val < num_values; ++val) {
        val_names[val] = (getstr(in));
    }
    consume_magic(in, "end_variable");
    model.variables.add_variable(var_name, BoundedIntType(0, num_values - 1));
    model.value_names[model.variables.size() - 1] = std::move(val_names);
}

/// Parses the `begin_state` ... `end_state` block, which holds one value
/// line per variable in index order.
///
/// @param num_vars number of value lines to read.
/// @return the conjunction `true && (v0 == x0) && ... && (vN-1 == xN-1)`.
expressions::Expression parse_initial_state(std::istream& in, int num_vars)
{
    consume_magic(in, "begin_state");
    expressions::Expression conj = expressions::Constant(Value(true));
    int var = 0;
    while (var < num_vars) {
        const int value = getint(in);
        conj = conj
            && expressions::equal(
                   expressions::Variable(var),
                   expressions::Constant(Value(value)));
        ++var;
    }
    consume_magic(in, "end_state");
    return conj;
}

/// Parses the `begin_goal` ... `end_goal` block: a count line followed by one
/// "var val" line per goal fact.
///
/// @return the conjunction `true && (var == val) && ...` over all goal facts.
expressions::Expression parse_goal(std::istream& in)
{
    consume_magic(in, "begin_goal");
    expressions::Expression goal = expressions::Constant(Value(true));
    const int num_facts = getint(in);
    for (int fact = 0; fact < num_facts; ++fact) {
        int var = 0;
        int val = 0;
        in >> var >> val;
        goal = goal
            && expressions::equal(
                   expressions::Variable(var),
                   expressions::Constant(Value(val)));
        // Discard the remainder of the current line so the next getline-based
        // read starts on a fresh line.
        getstr(in);
    }
    consume_magic(in, "end_goal");
    return goal;
}

void parse_operator(std::istream& in, Model& model)
{
    consume_magic(in, "begin_operator");
    auto name = getstr(in);
    LinearConstraintConjunction pres;
    vector<Assignment> effs;

    int num_prevails = getint(in);
    for (; num_prevails > 0; --num_prevails) {
        int var, val;
        in >> var;
        in >> val;
        pres.push_back(
            LinearConstraint::unit_constraint(
                var,
                LinearConstraint::EQUAL,
                val));
        getstr(in);
    }

    int num_effs = getint(in);
    for (; num_effs > 0; --num_effs) {
        int aux, var, pre, post;
        in >> aux;
        in >> var;
        in >> pre;
        in >> post;
        if (pre != -1) {
            pres.push_back(
                LinearConstraint::unit_constraint(
                    var,
                    LinearConstraint::EQUAL,
                    pre));
        }
        effs.push_back(Assignment(var, LinearExpression::constant(post)));
        getstr(in);
    }

    std::sort(
        pres.begin(),
        pres.end(),
        [](const LinearConstraint& x, const LinearConstraint& y) {
            return x.refs()[0] < y.refs()[0];
        });
    std::sort(
        effs.begin(),
        effs.end(),
        [](const Assignment& x, const Assignment& y) {
            return x.var_id < y.var_id;
        });

    // cost
    getint(in);

    consume_magic(in, "end_operator");

    model.actions.push_back(Action(
        model.labels.size(),
        std::move(pres),
        {Outcome(std::move(effs))}));
    model.labels.push_back(std::move(name));
}

/// Parses the `begin_avoid_condition` ... `end_avoid_condition` block into a
/// disjunction of cubes.
///
/// After the cube count, each cube is a fact count k followed by k "var val"
/// pairs; the rest of the cube's line is then discarded.
///
/// @return `false || cube_1 || ... || cube_n`, where each cube is
///         `true && (var == val) && ...`.
expressions::Expression parse_avoid_condition(std::istream& in)
{
    consume_magic(in, "begin_avoid_condition");
    expressions::Expression disjunction = expressions::Constant(Value(false));
    const int num_cubes = getint(in);
    for (int c = 0; c < num_cubes; ++c) {
        int num_facts = 0;
        in >> num_facts;
        expressions::Expression cube = expressions::Constant(Value(true));
        for (int f = 0; f < num_facts; ++f) {
            int var = 0;
            int val = 0;
            in >> var >> val;
            cube = cube
                && expressions::equal(
                       expressions::Variable(var),
                       expressions::Constant(Value(val)));
        }
        disjunction = disjunction || cube;
        getstr(in);
    }
    consume_magic(in, "end_avoid_condition");
    return disjunction;
}

/// Resolves ground atoms written as "pred arg1 arg2 ..." to the
/// (variable, value) pair whose value name is "Atom pred(arg1, arg2, ...)".
/// Only a pointer to `model` is kept, so the model must outlive the matcher.
class FactMatcher {
public:
    explicit FactMatcher(const Model& model)
        : model(&model)
    {
    }

    /// Returns the (variable, value) pair whose value name matches the atom.
    /// Variables and values are scanned from the highest index downwards and
    /// the first match is returned. If nothing matches, reports via
    /// POLICE_RUNTIME_ERROR.
    std::pair<int, int> operator[](std::string_view fact_name) const
    {
        const auto atom = get_atom_string(fact_name);
        for (int var = model->value_names.size() - 1; var >= 0; --var) {
            const auto& facts = model->value_names.find(var)->second;
            for (int val = facts.size() - 1; val >= 0; --val) {
                if (facts[val] == atom) {
                    return {var, val};
                }
            }
        }
        POLICE_RUNTIME_ERROR("failed to find fact " << fact_name);
    }

private:
    /// Converts "pred arg1 arg2" into "Atom pred(arg1, arg2)". With no
    /// arguments the result is "Atom pred()".
    ///
    /// Tokens are separated by runs of any whitespace (spaces, tabs,
    /// newlines). Atoms that are line-wrapped or padded with extra spaces
    /// therefore give the same string as single-space-separated ones. Before
    /// this change, they produced empty arguments or predicate names
    /// containing whitespace that could never match.
    static std::string get_atom_string(std::string_view fact_name)
    {
        const auto is_space = [](char ch) {
            return std::isspace(static_cast<unsigned char>(ch)) != 0;
        };
        std::string res = "Atom ";
        std::size_t num_tokens = 0;
        std::size_t i = 0;
        for (;;) {
            while (i < fact_name.size() && is_space(fact_name[i])) {
                ++i;
            }
            if (i == fact_name.size()) {
                break;
            }
            const std::size_t start = i;
            while (i < fact_name.size() && !is_space(fact_name[i])) {
                ++i;
            }
            const std::string_view token = fact_name.substr(start, i - start);
            if (num_tokens == 0) {
                // First token is the predicate name.
                res.append(token.data(), token.size());
                res += '(';
            } else {
                if (num_tokens > 1) {
                    res += ", ";
                }
                res.append(token.data(), token.size());
            }
            ++num_tokens;
        }
        if (num_tokens == 0) {
            res += '(';
        }
        res += ')';
        return res;
    }

    const Model* model;
};

/// Returns the first position in [c, end) whose byte is not whitespace, or
/// `end` if there is none.
///
/// The byte is converted to unsigned char and classified with std::isspace.
/// Passing a plain (possibly negative) `char` to std::iswspace, as before,
/// classifies a byte as a wide character. For non-ASCII bytes on
/// signed-char platforms, that can be undefined behaviour.
const char* skipws(const char* c, const char* end)
{
    while (c != end && std::isspace(static_cast<unsigned char>(*c))) {
        ++c;
    }
    return c;
}

/// Returns the first position in [c, end) whose byte is whitespace, or `end`
/// if there is none.
///
/// Uses std::isspace on the byte converted to unsigned char. Non-ASCII bytes
/// (e.g. UTF-8 continuation bytes) are therefore treated as ordinary token
/// characters instead of being passed as negative values to std::iswspace.
const char* nextws(const char* c, const char* end)
{
    while (c != end && !std::isspace(static_cast<unsigned char>(*c))) {
        ++c;
    }
    return c;
}

/// Returns the run of non-whitespace characters starting at `begin`, and
/// advances `begin` to the position just after that run (the terminating
/// whitespace, or `end`). Yields an empty view if `begin` already points at
/// whitespace or equals `end`.
std::string_view nextstr(const char*& begin, const char* end)
{
    const char* const token_start = begin;
    const char* const token_end = nextws(token_start, end);
    begin = token_end;
    return std::string_view(
        token_start,
        static_cast<std::size_t>(token_end - token_start));
}

std::optional<expressions::Expression>
parse_expr(FactMatcher& facts, const char*& begin, const char* end)
{
    begin = skipws(begin, end);
    if (begin == end || *begin == ')') {
        return std::nullopt;
    }
    if (*begin != '(') {
        POLICE_RUNTIME_ERROR("Expected (");
    }
    begin = skipws(begin + 1, end);
    const char* old_begin = begin;
    auto op = nextstr(begin, end);
    if (op == "or") {
        expressions::Expression res = expressions::Constant(Value(false));
        for (;;) {
            auto e = parse_expr(facts, begin, end);
            if (!e.has_value()) {
                break;
            }
            res = res || e.value();
        }
        ++begin;
        return res;
    } else if (op == "and") {
        expressions::Expression res = expressions::Constant(Value(true));
        for (;;) {
            auto e = parse_expr(facts, begin, end);
            if (!e.has_value()) {
                break;
            }
            res = res && e.value();
        }
        ++begin;
        return res;
    } else if (op == "not") {
        auto e = parse_expr(facts, begin, end);
        if (begin == end || !e.has_value()) {
            POLICE_RUNTIME_ERROR("Expected )");
        }
        ++begin;
        return expressions::Negation(e.value());
    } else {
        begin = old_begin;
        while (begin != end && *begin != ')') {
            ++begin;
        }
        if (begin == end) {
            POLICE_RUNTIME_ERROR("Expected )");
        }
        auto fact = facts[std::string_view(old_begin, begin)];
        ++begin;
        return expressions::equal(
            expressions::Variable(fact.first),
            expressions::Constant(Value(fact.second)));
    }
}

} // namespace

std::pair<Model, VerificationProperty> parse(std::istream& in)
{
    Model res;
    skip_block(in, "version");
    skip_block(in, "metric");
    int num_vars = getint(in);
    for (int i = 0; i < num_vars; ++i) {
        parse_variable(in, res);
    }
    int mut_groups = getint(in);
    for (int i = 0; i < mut_groups; ++i) {
        skip_block(in, "mutex_group");
    }
    auto state = parse_initial_state(in, num_vars);
    auto goal = parse_goal(in);
    int num_ops = getint(in);
    for (int i = 0; i < num_ops; ++i) {
        parse_operator(in, res);
    }
    // axioms
    getint(in);
    auto avoid = parse_avoid_condition(in);
    VerificationProperty prop(
        std::move(state),
        std::move(goal),
        std::move(avoid));
    return {std::move(res), std::move(prop)};
}

/// Opens the SAS file at `path` and parses it with parse(std::istream&).
/// Reports via POLICE_RUNTIME_ERROR if the file cannot be opened.
///
/// `path` is copied into a std::string before opening. std::string_view is
/// not guaranteed to be null-terminated (e.g. a view into a larger buffer),
/// and passing `path.data()` as a C string could open the wrong file or read
/// past the view.
std::pair<Model, VerificationProperty> parse(std::string_view path)
{
    std::ifstream inf{std::string(path)};
    if (!inf.is_open()) {
        POLICE_RUNTIME_ERROR("couldn't open SAS file " << path);
    }
    return parse(inf);
}

/// Reads the PDDL condition file at `path` and converts it into an
/// expression over `model`'s variables.
///
/// The whole file is lower-cased before parsing, and atoms are then matched
/// against the model's value names. Reports via POLICE_RUNTIME_ERROR if the
/// file cannot be opened or contains no expression.
///
/// The path is copied into a std::string because string_view::data() need
/// not be null-terminated. Each byte is converted to unsigned char before
/// std::tolower, because passing a negative `char` (non-ASCII byte) is
/// undefined behaviour.
expressions::Expression
parse_pddl_expression(const Model& model, std::string_view path)
{
    FactMatcher matcher(model);
    std::ifstream inf{std::string(path)};
    if (!inf.is_open()) {
        POLICE_RUNTIME_ERROR("couldn't open PDDL file " << path);
    }
    std::stringstream buffer;
    buffer << inf.rdbuf();
    std::string content = buffer.str();
    for (char& ch : content) {
        ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
    }
    const char* first = content.data();
    const char* end = content.data() + content.size();
    auto result = parse_expr(matcher, first, end);
    if (!result.has_value()) {
        POLICE_RUNTIME_ERROR("failed parsing pddl expression");
    }
    return result.value();
}

} // namespace police::sas
