#include "police/jani/model.hpp"
#include "police/base_types.hpp"
#include "police/action.hpp"
#include "police/linear_condition.hpp"
#include "police/linear_expression.hpp"
#include "police/expressions/comparison.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/expression_visitor.hpp"
#include "police/expressions/expressions.hpp"
#include "police/expressions/expression_normalizer.hpp"
#include "police/jani/parser/parser.hpp"
#include "police/macros.hpp"
#include "police/storage/variable_space.hpp"

#include <algorithm>
#include <cassert>
#include <filesystem>
#include <fstream>
#include <iterator>
#include <limits>
#include <nlohmann/json.hpp>
#include <string_view>
#include <system_error>
#include <type_traits>
#include <variant>

namespace police::jani {

// Constructs a JANI model from its already-parsed parts. Every argument is
// taken by value and moved into the member of the same name.
//
// NOTE(review): the mem-initializer list names global_vars before
// initial_state; members are initialized in their declaration order in the
// class definition (not visible here) regardless of this list -- confirm the
// two orders agree so -Wreorder stays quiet.
Model::Model(
    VariableSpace variables,
    vector<Automaton> automata,
    vector<SynchronizationVector> syncs,
    vector<expressions::Expression> properties,
    vector<identifier_name_t> action_names,
    vector<identifier_name_t> property_names,
    expressions::Expression initial_state,
    vector<size_t> global_vars) noexcept
    : variables(std::move(variables))
    , automata(std::move(automata))
    , syncs(std::move(syncs))
    , properties(std::move(properties))
    , action_names(std::move(action_names))
    , property_names(std::move(property_names))
    , global_vars(std::move(global_vars))
    , initial_state(std::move(initial_state))
{
}

// Reads the JSON document stored in `file_name` and parses it as a JANI model
// via parse_model. Errors raised by nlohmann::json::parse or parse_model
// propagate to the caller unchanged.
Model Model::from_file(std::string_view file_name)
{
    // std::string_view::data() is not guaranteed to be NUL-terminated, so it
    // must not be handed to an API that expects a C string. Build a path from
    // the view itself instead (braces avoid the most vexing parse).
    std::ifstream inf{std::filesystem::path{file_name}};
    auto json = nlohmann::json::parse(inf);
    return police::jani::parse_model(json);
}

// Looks up a property by `name`.
//
// If `name` refers to an existing file system entry, that file is parsed as a
// state condition of this model (parse_state_condition). Otherwise `name` is
// matched against property_names and the property at the same position is
// returned.
//
// Returns std::nullopt when neither applies.
std::optional<expressions::Expression>
Model::get_property(std::string_view name) const
{
    // Construct the path from the view itself: string_view::data() need not
    // be NUL-terminated. The error_code overload of exists() keeps names that
    // cannot be stat'ed (e.g. overly long ones) from throwing
    // std::filesystem::filesystem_error; such names are simply looked up as
    // property names below.
    std::error_code ec;
    if (std::filesystem::exists(std::filesystem::path(name), ec)) {
        return parse_state_condition(name, *this);
    }
    const auto it =
        std::find(property_names.begin(), property_names.end(), name);
    if (it != property_names.end()) {
        return properties[std::distance(property_names.begin(), it)];
    }
    return std::nullopt;
}

namespace {

class ExpressionTypeDeducer : public expressions::ExpressionVisitor {
public:
    explicit ExpressionTypeDeducer(const VariableSpace& vars)
        : vars_(vars)
    {
    }

    Value::Type operator()(const expressions::Expression& expr)
    {
        expr.accept(*this);
        return var_type_;
    }

    [[nodiscard]]
    static Value::Type max_type(const Value::Type& a, const Value::Type& b)
    {
        if (a == Value::Type::REAL || b == Value::Type::REAL) {
            return Value::Type::REAL;
        }
        if (a == Value::Type::INT || b == Value::Type::INT) {
            return Value::Type::INT;
        }
        return Value::Type::BOOL;
    }

    void visit(const expressions::BinaryFunctionCall&)
    {
        var_type_ = Value::Type::REAL;
    }

    void visit(const expressions::BinaryFunctionCallGeneric& expr)
    {
        expr.left.accept(*this);
        Value::Type left = var_type_;
        expr.right.accept(*this);
        var_type_ = max_type(left, var_type_);
    }

    void visit(const expressions::Conjunction&)
    {
        var_type_ = Value::Type::BOOL;
    }

    void visit(const expressions::Disjunction&)
    {
        var_type_ = Value::Type::BOOL;
    }

    void visit(const expressions::Comparison&)
    {
        var_type_ = Value::Type::BOOL;
    }

    void visit(const expressions::Constant& expr)
    {
        var_type_ = expr.get_value().get_type();
    }

    void visit(const expressions::Derivative&)
    {
        var_type_ = Value::Type::REAL;
    }

    void visit(const expressions::FunctionCall&)
    {
        var_type_ = Value::Type::INT;
    }

    void visit(const expressions::IdentifierReference&)
    {
        POLICE_RUNTIME_ERROR("unexpected identifier reference");
    }

    void visit(const expressions::IfThenElse& expr)
    {
        expr.consequence.accept(*this);
        Value::Type cons = var_type_;
        expr.alternative.accept(*this);
        var_type_ = max_type(var_type_, cons);
    }

    void visit(const expressions::Negation&) { var_type_ = Value::Type::BOOL; }

    void visit(const expressions::NumericOperation& expr)
    {
        expr.left.accept(*this);
        Value::Type left = var_type_;
        expr.right.accept(*this);
        var_type_ = max_type(left, var_type_);
    }

    void visit(const expressions::Variable& expr)
    {
        var_type_ = vars_[expr.var_id].type.value_type();
    }

    const VariableSpace& vars_;
    Value::Type var_type_ = Value::Type::REAL;
};

// Rewriting visitor that removes strict comparisons where the operand types
// allow it (types are deduced with ExpressionTypeDeducer):
//   * `l < r`  with no REAL operand  ->  `l + 1 <= r`
//   * `l != r` with two BOOL operands ->  `(l && !r) || (!l && r)`
//   * `l != r` with no REAL operand  ->  `(l + 1 <= r) || (r + 1 <= l)`
// Every other comparison keeps its operator. All nodes, including comparison
// operands, are rebuilt from their recursively rewritten children. Call
// operator() to obtain the rewritten expression.
class StrictInequalityRemover : public expressions::ExpressionVisitor {
public:
    explicit StrictInequalityRemover(const VariableSpace& vars)
        : type_deducer_(vars)
    {
    }

    void visit_default(const expressions::Expression& expr) { result = expr; }

    void visit(const expressions::BinaryFunctionCall& expr)
    {
        result = expressions::BinaryFunctionCall(
            expr.function,
            (*this)(expr.left),
            (*this)(expr.right));
    }

    void visit(const expressions::BinaryFunctionCallGeneric& expr)
    {
        result = expressions::BinaryFunctionCallGeneric(
            expr.function,
            (*this)(expr.left),
            (*this)(expr.right));
    }

    void visit(const expressions::Conjunction& expr)
    {
        vector<expressions::Expression> transformed;
        transformed.reserve(expr.children.size());
        for (const auto& child : expr.children) {
            transformed.push_back((*this)(child));
        }
        result = expressions::Conjunction(std::move(transformed));
    }

    void visit(const expressions::Disjunction& expr)
    {
        vector<expressions::Expression> transformed;
        transformed.reserve(expr.children.size());
        for (const auto& child : expr.children) {
            transformed.push_back((*this)(child));
        }
        result = expressions::Disjunction(std::move(transformed));
    }

    void visit(const expressions::Comparison& expr)
    {
        using Op = expressions::Comparison::Operator;
        // Types are deduced on the original operands; rewriting does not
        // change an operand's type.
        const Value::Type left_type = type_deducer_(expr.left);
        const Value::Type right_type = type_deducer_(expr.right);
        const bool no_real =
            left_type != Value::Type::REAL && right_type != Value::Type::REAL;
        const bool both_bool =
            left_type == Value::Type::BOOL && right_type == Value::Type::BOOL;

        // Rewrite the operands too, so strict comparisons nested inside them
        // (e.g. in an if-then-else condition) are removed as well.
        expressions::Expression lhs = (*this)(expr.left);
        expressions::Expression rhs = (*this)(expr.right);

        const auto plus_one =
            [](const expressions::Expression& e) -> expressions::Expression {
            return e + expressions::Constant(Value(1));
        };

        switch (expr.op) {
        case Op::LESS:
            if (no_real) {
                // Over integers: l < r  <=>  l + 1 <= r.
                result = expressions::Comparison(
                    Op::LESS_EQUAL,
                    plus_one(lhs),
                    rhs);
                return;
            }
            break;
        case Op::NOT_EQUAL:
            if (both_bool) {
                // Boolean inequality is exclusive-or.
                result = expressions::Disjunction(
                    {expressions::Conjunction({
                         lhs,
                         expressions::Negation(rhs),
                     }),
                     expressions::Conjunction({
                         !lhs,
                         rhs,
                     })});
                return;
            }
            if (no_real) {
                // Over integers: l != r  <=>  l < r || r < l
                //                        <=>  l + 1 <= r || r + 1 <= l.
                // This must be a disjunction; the conjunction of the two
                // constraints is unsatisfiable.
                result = expressions::Disjunction(
                    {expressions::Comparison(
                         Op::LESS_EQUAL,
                         plus_one(lhs),
                         rhs),
                     expressions::Comparison(
                         Op::LESS_EQUAL,
                         plus_one(rhs),
                         lhs)});
                return;
            }
            break;
        default: break;
        }
        // No strict-comparison rewrite applies: keep the operator, but use the
        // rewritten operands.
        result =
            expressions::Comparison(expr.op, std::move(lhs), std::move(rhs));
    }

    void visit(const expressions::Constant& expr) { visit_default(expr); }

    void visit(const expressions::Derivative& expr) { visit_default(expr); }

    void visit(const expressions::FunctionCall& expr)
    {
        result = expressions::FunctionCall(expr.function, (*this)(expr.expr));
    }

    void visit(const expressions::IdentifierReference& expr)
    {
        visit_default(expr);
    }

    void visit(const expressions::IfThenElse& expr)
    {
        result = expressions::IfThenElse(
            (*this)(expr.condition),
            (*this)(expr.consequence),
            (*this)(expr.alternative));
    }

    void visit(const expressions::Negation& expr)
    {
        result = expressions::Negation((*this)(expr.expr));
    }

    void visit(const expressions::NumericOperation& expr)
    {
        result = expressions::NumericOperation(
            expr.operand,
            (*this)(expr.left),
            (*this)(expr.right));
    }

    void visit(const expressions::Variable& expr) { visit_default(expr); }

    // Returns the rewritten form of `expr`.
    expressions::Expression operator()(const expressions::Expression& expr)
    {
        expr.accept(*this);
        // Every visit assigns `result` before it is read again, so moving out
        // of it is safe and avoids a copy.
        return std::move(result).value();
    }

    std::optional<expressions::Expression> result = std::nullopt;
    ExpressionTypeDeducer type_deducer_;
};

// Rewrites `expr` with StrictInequalityRemover (using the variable types in
// `vars`), then applies expressions::ExpressionNormalizer to the result.
expressions::Expression normalize_expression(
    const VariableSpace& vars,
    const expressions::Expression& expr)
{
    StrictInequalityRemover remover(vars);
    expressions::Expression out = remover(expr);
    expressions::ExpressionNormalizer normalizer;
    out.transform(normalizer);
    return out;
}

// Converts a JANI assignment into a police::Assignment whose value is the
// LinearExpression built from the assigned expression.
police::Assignment normalize(const Assignment& assignment)
{
    auto linear_value = LinearExpression::from_expression(assignment.value);
    return {assignment.var_id, std::move(linear_value)};
}

// Normalizes every assignment of one JANI assignment block and returns them
// ordered by increasing var_id.
//
// All assignments must carry the same index as the first one; otherwise
// POLICE_NOT_SUPPORTED is raised. An empty input yields an empty result.
vector<police::Assignment> normalize(const vector<Assignment>& assignments)
{
    vector<police::Assignment> normalized;
    if (assignments.empty()) {
        return normalized;
    }

    const int first_index = assignments.front().index;
    const bool uniform_index = std::all_of(
        assignments.begin(),
        assignments.end(),
        [first_index](const auto& asg) { return asg.index == first_index; });
    if (!uniform_index) {
        POLICE_NOT_SUPPORTED(
            "Normalization doesn't support indexed assignments");
    }

    normalized.reserve(assignments.size());
    for (const auto& asg : assignments) {
        normalized.push_back(normalize(asg));
    }
    std::sort(
        normalized.begin(),
        normalized.end(),
        [](const auto& lhs, const auto& rhs) {
            return lhs.var_id < rhs.var_id;
        });
    return normalized;
}

// Builds a police::Outcome from the normalized assignments of `outcome`.
police::Outcome normalize(const Outcome& outcome)
{
    auto assignments = normalize(outcome.assignments);
    return {std::move(assignments)};
}

// Normalizes each outcome, preserving their order.
vector<police::Outcome> normalize(const vector<Outcome>& outcomes)
{
    vector<police::Outcome> normalized;
    normalized.reserve(outcomes.size());
    for (const Outcome& outcome : outcomes) {
        normalized.push_back(normalize(outcome));
    }
    return normalized;
}

// Returns a lower bound of the linear combination `co` over the declared
// domains of its variables.
//
// A term `coef * x` contributes `coef * lb(x)` for coef > 0 and
// `coef * ub(x)` for coef < 0. Bool variables range over [0, 1]; variable
// types other than BoolType / BoundedIntType / BoundedRealType are treated as
// unbounded, so such a term contributes -infinity. Terms with a zero
// coefficient contribute nothing.
real_t get_lower_bound(
    const LinearCombination<size_t, real_t>& co,
    const VariableSpace& variables)
{
    real_t result = 0.;
    for (const auto& [var, coef] : co) {
        if (coef == 0.) {
            // 0 * x contributes 0. Skipping the term also avoids
            // 0 * (+/-infinity) == NaN for unbounded variables.
            continue;
        }
        const bool negative = coef < 0;
        // Bind by reference; the variable's type need not be copied.
        const auto& domain = variables[var].type;
        const real_t extreme = std::visit(
            [negative](auto&& t) -> real_t {
                using T = std::decay_t<decltype(t)>;
                if constexpr (std::is_same_v<T, BoolType>) {
                    return negative ? 1 : 0;
                } else if constexpr (std::is_same_v<T, BoundedIntType>) {
                    return negative ? t.upper_bound : t.lower_bound;
                } else if constexpr (std::is_same_v<T, BoundedRealType>) {
                    return negative ? t.upper_bound : t.lower_bound;
                } else {
                    return negative
                               ? std::numeric_limits<real_t>::infinity()
                               : -std::numeric_limits<real_t>::infinity();
                }
            },
            domain);
        result += extreme * coef;
    }
    return result;
}

// Returns an upper bound of the linear combination `co` over the declared
// domains of its variables.
//
// A term `coef * x` contributes `coef * ub(x)` for coef > 0 and
// `coef * lb(x)` for coef < 0. Bool variables range over [0, 1]; variable
// types other than BoolType / BoundedIntType / BoundedRealType are treated as
// unbounded, so such a term contributes +infinity. Terms with a zero
// coefficient contribute nothing.
real_t get_upper_bound(
    const LinearCombination<size_t, real_t>& co,
    const VariableSpace& variables)
{
    real_t result = 0.;
    for (const auto& [var, coef] : co) {
        if (coef == 0.) {
            // 0 * x contributes 0. Skipping the term also avoids
            // 0 * (+/-infinity) == NaN for unbounded variables.
            continue;
        }
        const bool positive = coef > 0.;
        // Bind by reference; the variable's type need not be copied.
        const auto& domain = variables[var].type;
        const real_t extreme = std::visit(
            [positive](auto&& t) -> real_t {
                using T = std::decay_t<decltype(t)>;
                if constexpr (std::is_same_v<T, BoolType>) {
                    return positive ? 1 : 0;
                } else if constexpr (std::is_same_v<T, BoundedIntType>) {
                    return positive ? t.upper_bound : t.lower_bound;
                } else if constexpr (std::is_same_v<T, BoundedRealType>) {
                    return positive ? t.upper_bound : t.lower_bound;
                } else {
                    return positive
                               ? std::numeric_limits<real_t>::infinity()
                               : -std::numeric_limits<real_t>::infinity();
                }
            },
            domain);
        result += extreme * coef;
    }
    return result;
}

// Canonicalizes a single-variable linear constraint and tightens it to an
// equality when the variable's domain leaves rhs as the only feasible value.
//
// Zero coefficients are removed first. If exactly one term `c * x` remains,
// the constraint is rescaled to `x op rhs / c`, swapping <= and >= when
// c < 0. Then `x >= rhs` becomes `x == rhs` if rhs equals the upper bound of
// x's domain, and `x <= rhs` becomes `x == rhs` if rhs equals its lower bound.
// Constraints with any other number of terms are only stripped of zero
// coefficients.
void incorporate_domain(
    LinearConstraint& constraint,
    const VariableSpace& variables)
{
    constraint.remove_zero_coefficients();
    if (constraint.size() != 1) {
        return;
    }

    // `auto&&` rather than `auto`: a plain `auto` binding copies the element,
    // so the writes to `coef` below were lost while rhs and type were still
    // rescaled, leaving an inconsistent constraint. `auto&&` refers to the
    // stored coefficient whether the iterator yields a reference or a proxy
    // of references. NOTE(review): assumes LinearConstraint's non-const
    // begin() yields mutable access to the coefficient -- confirm.
    auto&& [var, coef] = *constraint.begin();

    if (coef < 0) {
        coef *= -1.;
        constraint.rhs *= -1.;
        if (constraint.type == LinearConstraint::LESS_EQUAL) {
            constraint.type = LinearConstraint::GREATER_EQUAL;
        } else if (constraint.type == LinearConstraint::GREATER_EQUAL) {
            constraint.type = LinearConstraint::LESS_EQUAL;
        } else {
            constraint.type = LinearConstraint::EQUAL;
        }
    }
    if (coef != 1.) {
        constraint.rhs /= coef;
        coef = 1.;
    }

    // With the coefficient normalized to 1, the bounds of the combination are
    // exactly the bounds of the variable's domain.
    if (constraint.type == LinearConstraint::GREATER_EQUAL) {
        // x >= rhs and max(x) == rhs  =>  x == rhs
        const auto upper = get_upper_bound(constraint, variables);
        if (upper == constraint.rhs) {
            constraint.type = LinearConstraint::EQUAL;
        }
    } else if (constraint.type == LinearConstraint::LESS_EQUAL) {
        // x <= rhs and min(x) == rhs  =>  x == rhs
        const auto lower = get_lower_bound(constraint, variables);
        if (lower == constraint.rhs) {
            constraint.type = LinearConstraint::EQUAL;
        }
    }
}

// Normalizes an edge guard into a single conjunction of linear constraints.
//
// The guard is first passed through normalize_expression and converted with
// LinearCondition::from_expression. Anything but exactly one resulting
// conjunction raises POLICE_NOT_SUPPORTED. Each constraint is then refined
// against the variable domains with incorporate_domain.
LinearConstraintConjunction
normalize(const VariableSpace& variables, const expressions::Expression& guard)
{
    auto canonical = normalize_expression(variables, guard);
    auto linear = LinearCondition::from_expression(canonical);
    if (linear.size() != 1u) {
        POLICE_NOT_SUPPORTED("Normalization only supports conjunctive linear "
                             "constraints as guards");
    }
    for (auto& conjunction : linear) {
        for (auto& constraint : conjunction) {
            incorporate_domain(constraint, variables);
        }
    }
    return std::move(linear[0]);
}

// Converts an edge into a police::Action: same action label, normalized guard
// and normalized outcomes. The guard is normalized before the outcomes.
police::Action normalize_(const VariableSpace& variables, const Edge& edge)
{
    auto guard = normalize(variables, edge.guard);
    auto outcomes = normalize(edge.outcomes);
    return {edge.action, std::move(guard), std::move(outcomes)};
}

} // namespace

// Converts this JANI model into a police::Model.
//
// Only networks consisting of a single automaton with a single location are
// supported; anything else raises POLICE_RUNTIME_ERROR. Every silent and
// every synchronizing edge of location 0 is turned into a police::Action.
// Silent-edge actions come first in edge order, followed by the sync-edge
// actions sorted by label.
police::Model Model::normalize() const
{
    if (automata.size() != 1u) {
        POLICE_RUNTIME_ERROR("currently automata networks of more than one "
                             "automaton are not supported");
    }
    const auto& automaton = automata.front();
    if (automaton.num_locations() != 1u) {
        POLICE_RUNTIME_ERROR(
            "currently automata with more than one location are not supported");
    }

    // Bind each edge list once. Taking begin() and end() from separate calls
    // would pair iterators of two different containers if the getters ever
    // returned by value; a const reference extends such a temporary's life.
    const auto& silent_edges = automaton.get_silent_edges(0);
    const auto& sync_edges = automaton.get_sync_edges(0);
    const auto to_action = [this](const Edge& edge) {
        return normalize_(variables, edge);
    };

    vector<Action> actions;
    actions.reserve(silent_edges.size() + sync_edges.size());
    std::transform(
        silent_edges.begin(),
        silent_edges.end(),
        std::back_inserter(actions),
        to_action);
    const auto num_silent = actions.size();
    std::transform(
        sync_edges.begin(),
        sync_edges.end(),
        std::back_inserter(actions),
        to_action);
    std::sort(
        actions.begin() + num_silent,
        actions.end(),
        [](const auto& lhs, const auto& rhs) { return lhs.label < rhs.label; });
    return {variables, std::move(actions), action_names};
}

} // namespace police::jani
