#include "police/jani/parser/expression.hpp"
#include "police/base_types.hpp"
#include "police/expressions/comparison.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/expressions.hpp"
#include "police/jani/parser/language.hpp"
#include "police/jani/parser/schema_factory.hpp"
#include "police/macros.hpp"

#include <initializer_list>
#include <utility>

namespace police::jani::parser {

namespace {

using namespace police::expressions;

// Generic factory functor: constructs a `T` from the forwarded arguments and
// returns it as an `Expression` (this relies on `T` being convertible to
// `Expression`).
template <typename T>
struct MakeExpression {
    template <typename... Params>
    auto operator()(Params&&... params) const -> Expression
    {
        return T(std::forward<Params>(params)...);
    }
};

// Factory functor that turns a built-in constant selector into an Expression.
struct BuiltinConstants {
    // Selector for the supported built-in constants.
    // NOTE(review): expression_schema() lists lang::EULER, lang::PI in this
    // same order; presumably the enumerator values must stay aligned with
    // that list -- confirm before reordering.
    enum class Constant {
        EULER = 0,
        PI = 1,
    };

    // Returns the result of invoking MakeConstantEuler / MakeConstantPi for
    // the requested constant.
    auto operator()(Constant constant) const -> Expression
    {
        switch (constant) {
        case Constant::EULER: {
            return MakeConstantEuler()();
        }
        case Constant::PI: {
            return MakeConstantPi()();
        }
        }
        // Only reached when `constant` holds a value that is not one of the
        // enumerators above.
        POLICE_UNREACHABLE();
    }
};

struct MakeBooleanCombination {
    enum class Operand { AND, OR, IMPLIES };
    Expression operator()(Operand op, Expression left, Expression right) const
    {
        switch (op) {
        case Operand::AND: return left && right;
        case Operand::OR: return left || right;
        case Operand::IMPLIES: return !left || right;
        }
        POLICE_UNREACHABLE();
    }
};

struct MakeDerivedComparison {
    enum class Operator { GREATER, GREATER_EQUAL };
    Expression operator()(Operator op, Expression left, Expression right) const
    {
        switch (op) {
        case Operator::GREATER: return greater(left, right);
        case Operator::GREATER_EQUAL: return greater_equal(left, right);
        }
        POLICE_UNREACHABLE();
    }
};

struct MakeDerivedMinMax {
    enum class Variant { MIN, MAX };
    Expression operator()(Variant v, Expression left, Expression right) const
    {
        switch (v) {
        case Variant::MIN:
            return police::expressions::ite(
                less_equal(left, right),
                left,
                right);
        case Variant::MAX:
            return police::expressions::ite(
                greater_equal(left, right),
                left,
                right);
        }
        POLICE_UNREACHABLE();
    }
};

} // namespace

// Returns the schema used to parse JANI expressions.
//
// The schema is self-referential: the operator schemas below take the
// top-level schema object as the schema of their operands. It is therefore
// created empty first and only assigned its list of alternatives after all
// parts have been built. Construction happens on the first call; the result
// is kept in function-local statics and every call returns a copy of it.
//
// NOTE(review): the "already built" flag is a plain bool with no
// synchronization, so concurrent first calls are not guarded here.
JaniSchema<Expression> expression_schema()
{
    static JaniSchema<Expression> any_expr{};
    static bool is_built = false;

    if (is_built) {
        return any_expr;
    }

    // The sub-schemas are declared in the same order in which they are
    // listed as alternatives at the end of this function.

    // A plain string is turned into an IdentifierReference.
    auto identifier_schema = JaniString<MakeExpression<IdentifierReference>>{};

    // Literal values (bool / integer / real), the named built-in constants,
    // and -- as in the alternative list below -- identifiers.
    auto constant_schema = JaniMultiSchemata(
        identifier_schema,
        JaniBool<MakeConstant>{},
        JaniInteger<MakeConstant>{},
        JaniReal<MakeConstant>{},
        make_dictionary<BuiltinConstants>(
            {},
            {},
            JaniDictArgument(
                lang::CONSTANT,
                JaniEnum<
                    factories::Construct<BuiltinConstants::Constant>,
                    lang::EULER,
                    lang::PI>())));

    // If-then-else: the lang::OP entry uses factories::Discard (presumably
    // it is only matched and not passed on to the factory); the three
    // branches are arbitrary expressions.
    auto ite_schema = make_dictionary<MakeExpression<IfThenElse>>(
        {},
        {JaniDictElement(
            lang::OP,
            JaniEnum<factories::Discard, lang::ITE>())},
        JaniDictArgument(lang::IF, any_expr),
        JaniDictArgument(lang::THEN, any_expr),
        JaniDictArgument(lang::ELSE, any_expr));

    // Binary boolean connectives (and / or / implies).
    auto boolean_schema = make_dictionary<MakeBooleanCombination>(
        {},
        {},
        JaniDictArgument(
            lang::OP,
            JaniEnum<
                factories::Construct<MakeBooleanCombination::Operand>,
                lang::AND,
                lang::OR,
                lang::IMPLIES>{}),
        JaniDictArgument(lang::LEFT, any_expr),
        JaniDictArgument(lang::RIGHT, any_expr));

    // Unary negation of a single sub-expression.
    auto negation_schema = make_dictionary<MakeExpression<Negation>>(
        {},
        {JaniDictElement(
            lang::OP,
            JaniEnum<factories::Discard, lang::NEGATE>())},
        JaniDictArgument(lang::EXP, any_expr));

    // Comparisons that map directly onto Comparison::Operator.
    auto comparison_schema = make_dictionary<MakeExpression<Comparison>>(
        {},
        {},
        JaniDictArgument(
            lang::OP,
            JaniEnum<
                factories::Construct<Comparison::Operator>,
                lang::EQUAL,
                lang::NOT_EQUAL,
                lang::LESS,
                lang::LESS_EQUAL>{}),
        JaniDictArgument(lang::LEFT, any_expr),
        JaniDictArgument(lang::RIGHT, any_expr));

    // Binary arithmetic operators.
    auto arithmetic_schema =
        make_dictionary<MakeExpression<NumericOperation>>(
            {},
            {},
            JaniDictArgument(
                lang::OP,
                JaniEnum<
                    factories::Construct<NumericOperation::Operand>,
                    lang::PLUS,
                    lang::MINUS,
                    lang::TIMES,
                    lang::DIV,
                    lang::MOD>{}),
            JaniDictArgument(lang::LEFT, any_expr),
            JaniDictArgument(lang::RIGHT, any_expr));

    // Two-argument functions (pow / log).
    auto binary_function_schema = make_dictionary(
        MakeExpression<BinaryFunctionCall>(),
        {},
        JaniDictArgument(
            lang::OP,
            JaniEnum<
                factories::Construct<BinaryFunctionCall::Function>,
                lang::POW,
                lang::LOG>()),
        JaniDictArgument(lang::LEFT, any_expr),
        JaniDictArgument(lang::RIGHT, any_expr));

    // One-argument functions (floor / ceil).
    auto unary_function_schema = make_dictionary(
        MakeExpression<FunctionCall>(),
        {},
        JaniDictArgument(
            lang::OP,
            JaniEnum<
                factories::Construct<FunctionCall::Function>,
                lang::FLOOR,
                lang::CEIL>()),
        JaniDictArgument(lang::EXP, any_expr));

    // Greater / greater-or-equal, built via MakeDerivedComparison.
    auto derived_comparison_schema = make_dictionary(
        MakeDerivedComparison(),
        {},
        JaniDictArgument(
            lang::OP,
            JaniEnum<
                factories::Construct<MakeDerivedComparison::Operator>,
                lang::GREATER,
                lang::GREATER_EQUAL>{}),
        JaniDictArgument(lang::LEFT, any_expr),
        JaniDictArgument(lang::RIGHT, any_expr));

    // Min / max, built via MakeDerivedMinMax.
    auto min_max_schema = make_dictionary(
        MakeDerivedMinMax(),
        {},
        JaniDictArgument(
            lang::OP,
            JaniEnum<
                factories::Construct<MakeDerivedMinMax::Variant>,
                lang::MIN,
                lang::MAX>()),
        JaniDictArgument(lang::LEFT, any_expr),
        JaniDictArgument(lang::RIGHT, any_expr));

    // Derivative of a variable that is referenced by name.
    auto derivative_schema = make_dictionary(
        MakeExpression<Derivative>(),
        {JaniDictElement(
            lang::OP,
            JaniEnum<factories::Discard, lang::DER>())},
        JaniDictArgument(
            lang::VAR,
            JaniString<factories::Construct<identifier_name_t>>()));

    // Close the recursion: the top-level schema becomes the list of all
    // alternatives built above.
    any_expr = JaniMultiSchemata(
        std::move(identifier_schema),
        std::move(constant_schema),
        std::move(ite_schema),
        std::move(boolean_schema),
        std::move(negation_schema),
        std::move(comparison_schema),
        std::move(arithmetic_schema),
        std::move(binary_function_schema),
        std::move(unary_function_schema),
        std::move(derived_comparison_schema),
        std::move(min_max_schema),
        std::move(derivative_schema));

    is_built = true;
    return any_expr;
}

} // namespace police::jani::parser
