#pragma once

#include "police/base_types.hpp"
#include "police/jani/parser/expression.hpp"
#include "police/jani/parser/schema_factory.hpp"
#include "police/jani/parser/variable.hpp"

#include <optional>
#include <type_traits>
#include <vector>

namespace police::jani::parser {

/// A (ref, value) pair; instances are stored in Location::transient_values.
/// NOTE(review): presumably mirrors one entry of a JANI "transient-values"
/// array — confirm against the schema definition.
struct TransientValue {
    /// Identifier this entry refers to.
    identifier_name_t ref;
    /// Expression associated with `ref`.
    Expression value;
};

/// Parsed representation of one location of an Automaton.
struct Location {
    /// Name of the location.
    identifier_name_t name;
    /// Optional expression attached to the location; an empty optional means
    /// none was given.
    /// NOTE(review): presumably the JANI "time-progress" condition — confirm.
    std::optional<Expression> time_progress;
    /// (ref, value) pairs attached to this location; may be empty.
    std::vector<TransientValue> transient_values;
};

/// A single assignment; instances are stored in Destination::assignments.
struct Assignment {
    /// Identifier of the assignment target.
    identifier_name_t ref;
    /// Expression whose value is assigned.
    Expression value;
    /// Defaults to 0.
    /// NOTE(review): presumably the JANI assignment "index" used to order
    /// assignments within a destination — confirm against the schema.
    int index = 0;
};

/// One destination of an Edge.
struct Destination {
    /// Name of a location.
    /// NOTE(review): presumably the target location of the edge — confirm.
    identifier_name_t location;
    /// Optional probability expression; empty by default.
    std::optional<Expression> probability = std::nullopt;
    /// Assignments attached to this destination; may be empty.
    std::vector<Assignment> assignments;
};

/// Parsed representation of one edge of an Automaton.
struct Edge {
    /// Name of a location.
    /// NOTE(review): presumably the source location of the edge — confirm.
    identifier_name_t location;
    /// Optional action name; empty by default.
    std::optional<identifier_name_t> action = std::nullopt;
    /// Optional rate expression; empty by default.
    std::optional<Expression> rate = std::nullopt;
    /// Optional guard expression; empty by default.
    std::optional<Expression> guard = std::nullopt;
    /// Destinations of this edge; may be empty as far as this type is
    /// concerned.
    std::vector<Destination> destinations;
};

/// Parsed representation of an automaton: its variables, locations and edges.
struct Automaton {
    /// Name of the automaton.
    identifier_name_t name;
    /// Variable declarations belonging to this automaton.
    std::vector<VariableDeclaration> variables;
    /// Optional expression; empty by default.
    /// NOTE(review): presumably the JANI "restrict-initial" constraint on the
    /// initial values of `variables` — confirm.
    std::optional<Expression> restrict_initial = std::nullopt;
    /// Locations of the automaton.
    std::vector<Location> locations;
    /// Names of locations marked as initial.
    /// NOTE(review): presumably each entry matches a Location::name in
    /// `locations`; nothing in this header enforces that.
    std::vector<identifier_name_t> initial_locations;
    /// Edges of the automaton.
    std::vector<Edge> edges;
};

/// Returns the JaniSchema specialised for Automaton. Only declared here; the
/// definition lives elsewhere.
/// NOTE(review): presumably the schema is what the parser uses to build an
/// Automaton from its JANI representation — confirm in the implementation.
JaniSchema<Automaton> automaton_schema();

/// Invokes `fn` on every expression reachable from `automaton`.
///
/// Traversal order:
///   1. `restrict_initial`, if set;
///   2. every element of `variables`, delegated to the
///      `apply_to_all_expressions` overload selected for that element type;
///   3. for every location: `time_progress` (if set), then the value of each
///      transient value;
///   4. for every edge: `rate` and `guard` (each if set), then for every
///      destination: `probability` (if set), then the value of each
///      assignment.
///
/// Only participates in overload resolution when `Automaton_cv` is
/// `Automaton` modulo cv/ref qualifiers. The constness of the expressions
/// handed to `fn` follows the constness of `automaton`.
///
/// @param automaton automaton whose expressions are traversed.
/// @param fn        callable invoked with each expression; defaults to a
///                  value-initialised `F`.
template <
    typename F,
    typename Automaton_cv,
    std::enable_if_t<
        std::is_same_v<std::remove_cvref_t<Automaton_cv>, Automaton>,
        int> = 0>
void apply_to_all_expressions(Automaton_cv&& automaton, F fn = F())
{
    // Calls fn on the contained expression, but only if the optional is set.
    const auto apply_if_set = [&fn](auto& maybe_expr) {
        if (maybe_expr) {
            fn(*maybe_expr);
        }
    };

    apply_if_set(automaton.restrict_initial);

    for (auto& declaration : automaton.variables) {
        apply_to_all_expressions(declaration, fn);
    }

    for (auto& location : automaton.locations) {
        apply_if_set(location.time_progress);
        for (auto& transient : location.transient_values) {
            fn(transient.value);
        }
    }

    for (auto& transition : automaton.edges) {
        apply_if_set(transition.rate);
        apply_if_set(transition.guard);
        for (auto& target : transition.destinations) {
            apply_if_set(target.probability);
            for (auto& update : target.assignments) {
                fn(update.value);
            }
        }
    }
}

/// Read-only traversal entry point taking an ExpressionVisitor. Only declared
/// here; the definition lives elsewhere.
/// NOTE(review): presumably hands every expression of `automaton` to
/// `visitor` — confirm the exact set and order in the implementation.
void visit_all(
    const Automaton& automaton,
    expressions::ExpressionVisitor& visitor);

/// Mutating traversal entry point taking an ExpressionTransformer. Only
/// declared here; the definition lives elsewhere.
/// NOTE(review): presumably applies `transformer` to every expression of
/// `automaton`, possibly rewriting them in place — confirm in the
/// implementation.
void transform_all(
    Automaton& automaton,
    expressions::ExpressionTransformer& transformer);

} // namespace police::jani::parser
