#pragma once

#include "police/base_types.hpp"
#include "police/action.hpp"
#include "police/linear_condition.hpp"
#include "police/linear_expression.hpp"
#include "police/model.hpp"
#include "police/storage/flat_state.hpp"
#include "police/storage/segmented_vector.hpp"
#include "police/utils/stopwatch.hpp"
#include "police/verifiers/ic3/syntactic/abstraction.hpp"
#include "police/verifiers/ic3/syntactic/frames_storage.hpp"
#include "police/verifiers/ic3/syntactic/policy_reasoner.hpp"
#include "police/verifiers/ic3/syntactic/policy_reasons.hpp"
#include "police/verifiers/ic3/syntactic/sufficient_condition.hpp"
#include "police/verifiers/ic3/syntactic/variable_classification.hpp"
#include "police/verifiers/ic3/frame.hpp"

#include <cstdint>
#include <iosfwd>
#include <memory>
#include <utility>

namespace police::ic3::syntactic {

/**
 * Frame refiner for the syntactic IC3 procedure.
 *
 * The public entry point is add_reason(), which takes a concrete state and a
 * target frame index. NOTE(review): judging by member names, it derives a
 * cube ("reason") from the goal/avoid conditions, the model's actions and the
 * policy reasoner, then inserts that cube into the frame storage. The actual
 * algorithm lives in the implementation file; confirm there.
 *
 * Ownership: `frames_`, `abstraction_`, `model_` and `reasons_` are stored as
 * raw pointers. No destructor is declared, so this class never deletes them.
 * They are non-owning and the caller must keep them alive for the lifetime of
 * this object. `goal_`, `avoid_` and `policy_reasoner_` are held by
 * `std::shared_ptr`.
 */
class SyntacticFrameRefiner {
public:
    /**
     * Timers and counters collected by the refiner.
     * NOTE(review): the exact meaning of each entry is defined where they are
     * updated (implementation file). The descriptions below are name-based.
     */
    struct Statistics {
        Statistics();

        StopWatch total_time;       // presumably: overall refinement time
        StopWatch policy_time;      // presumably: time spent in the policy reasoner
        StopWatch update_time;      // presumably: time spent updating frames
        StopWatch hitting_set_time; // presumably: time spent in hitting-set search

        unsigned long long policy_calls_total = 0;
        unsigned long long policy_calls_max = 0;
        unsigned long long goal_reason = 0;
        unsigned long long no_policy_reason = 0;
        unsigned long long policy_reason = 0;
    };

    /**
     * @param frames          Frame storage to refine (stored, not owned).
     * @param goal            Goal condition (shared ownership).
     * @param avoid           Avoid condition (shared ownership).
     * @param model           Model whose actions are reasoned about (stored,
     *                        not owned).
     * @param var_ranks       Variable ranking, taken by value (presumably moved
     *                        into `ranked_vars_`; confirm in the .cpp).
     * @param abstraction     Syntactic abstraction to keep in sync (stored,
     *                        not owned).
     * @param policy_reasoner Reasoner queried about the policy (shared).
     * @param policy_reasons  Optional reason store; defaults to nullptr.
     */
    SyntacticFrameRefiner(
        FramesStorage* frames,
        std::shared_ptr<LinearCondition> goal,
        std::shared_ptr<LinearCondition> avoid,
        const Model* model,
        vector<size_t> var_ranks,
        SyntacticAbstraction* abstraction,
        std::shared_ptr<PolicyReasoner> policy_reasoner,
        PolicyReasons* policy_reasons = nullptr);

    /**
     * Derive a reason cube for `state` with respect to frame `target_frame`
     * and record it in the frames.
     * @return NOTE(review): the meaning of the returned index is defined in
     *         the implementation (presumably the frame the reason ended up
     *         in). Confirm before relying on it.
     */
    size_t add_reason(const flat_state& state, size_t target_frame);

    /// Read-only access to the statistics gathered so far.
    [[nodiscard]]
    const Statistics& get_statistics() const noexcept { return stats_; }

private:
    // --- Sufficient conditions ---------------------------------------------
    // Overloads computing alternatives of sufficient conditions under which
    // `state` (or `cube`) satisfies or violates `cond`. Name-based
    // description; see the implementation for the exact semantics.

    [[nodiscard]]
    SuffCondAlternatives get_suff_cond_satisfied(
        const flat_state& state,
        const LinearConstraintConjunction& cond) const;

    [[nodiscard]]
    SuffCondAlternatives get_suff_cond_satisfied(
        const flat_state& state,
        const LinearCondition& cond) const;

    [[nodiscard]]
    SuffCondAlternatives get_suff_cond_violated(
        const flat_state& state,
        const LinearConstraintConjunction& cond) const;

    // `relaxed_state` is a non-const reference and may be modified.
    [[nodiscard]]
    SuffCondAlternatives get_suff_cond_violated(
        const flat_state& state,
        const LinearConstraintConjunction& cond,
        Cube& relaxed_state) const;

    [[nodiscard]]
    SuffCondAlternatives
    get_suff_cond_violated(const flat_state& state, const LinearCondition& cond)
        const;

    [[nodiscard]]
    vector<Cube> get_suff_cond_violated(
        const Cube& cube,
        const LinearConstraintConjunction& cond) const;

    [[nodiscard]]
    vector<Cube>
    get_suff_cond_violated(const Cube& cube, const LinearCondition& cond) const;

    // --- Per-action processing ---------------------------------------------
    // These append to `reasons`. The encoding of the returned status byte is
    // defined in the implementation file.

    std::uint8_t handle_silent_action(
        vector<SuffCondAlternatives>& reasons,
        const flat_state& state,
        const Action& action,
        size_t frame_idx) const;

    // `action_idx` is an in/out cursor (non-const reference).
    bool handle_labeled_actions(
        vector<std::uint8_t>& applicability,
        vector<SuffCondAlternatives>& reasons,
        const flat_state& state,
        size_t label,
        size_t& action_idx,
        size_t frame_idx) const;

    std::uint8_t handle_labeled_action(
        vector<SuffCondAlternatives>& reasons,
        const flat_state& state,
        const Action& action,
        size_t frame_idx) const;

    bool handle_outcome(
        vector<SuffCondAlternatives>& reasons,
        const flat_state& state,
        const Cube& relaxed_state,
        const Action& action,
        const vector<Assignment>& outcome,
        size_t frame_idx) const;

    /**
     * Collect the conjunction C of half spaces such that post[C \land
     * action.guard, a] => cube.
     **/
    SufficientCondition regress_cube(
        const Cube& relaxed_state,
        const Cube& cube,
        const Action& action,
        const vector<Assignment>& outcome) const;

    /**
     * Insert `cube` at frame index `frame` and update the dependent data
     * structures (abstraction, pruning of subsumed cubes in other frames,
     * etc.). The cube is passed in ready-made; this function does not build
     * it from a state.
     * `applicable` holds one flag per action. NOTE(review): these are
     * presumably the flags collected by handle_labeled_actions; confirm.
     * @return NOTE(review): meaning defined in the implementation (presumably
     *         the resulting frame index or cube id).
     **/
    size_t update_frames(
        const vector<std::uint8_t>& applicable,
        const Cube& cube,
        size_t frame);

    /// Restrict `state` to the variables/constraints referenced by `vars`.
    [[nodiscard]]
    Cube project_state(const flat_state& state, const SufficientCondition& vars)
        const;

    size_t insert_transition(
        size_t cube_id,
        const Cube& cube,
        const Cube& context_cube,
        size_t action_idx,
        const Action& action,
        const vector<Assignment>& outcomes);

    // --- Cube manipulation --------------------------------------------------

    void apply_unit_constraints(
        Cube& cube,
        const LinearConstraintConjunction& constraints) const;

    void propagate_constraints(
        Cube& cube,
        const LinearConstraintConjunction& constraints) const;

    Cube
    get_post_condition(const Cube& cube, const vector<Assignment>& outcomes)
        const;

    size_t update_abstraction(
        const vector<std::uint8_t>& applicable,
        const Cube& cube,
        size_t cube_id);

    [[nodiscard]]
    vector<size_t>
    get_satisfied_cubes(const Cube& state, size_t frame_index) const;

    [[nodiscard]]
    vector<size_t> get_implied_cubes(const Cube& cube) const;

    /// Interval estimate [lb, ub] of a linear combination (presumably over
    /// the variable bounds encoded in a cube).
    struct BoundEstimates {
        real_t lb = 0;
        real_t ub = 0;
    };

    [[nodiscard]]
    BoundEstimates bound_combination(
        const LinearCombination<size_t, real_t>& comb,
        const Cube& cube) const;

    [[nodiscard]]
    bool
    is_constraint_violated(const LinearConstraint& constraint, const Cube& cube)
        const;

    [[nodiscard]]
    bool is_conjunction_violated(
        const LinearConstraintConjunction& conj,
        const Cube& cube) const;

    void drop_trivial_bounds(Cube& cube) const;

    void
    drop_trivial_conditions(SufficientCondition& cond, const flat_state& state)
        const;

    void
    drop_trivial_conditions(SuffCondAlternatives& cond, const flat_state& state)
        const;

    Cube prepare_progression(
        const Cube& cube,
        const LinearConstraintConjunction& guard) const;

    // Takes the cube by value (sink). The bool presumably reports whether the
    // reason was newly registered; confirm in the implementation.
    std::pair<size_t, bool> register_reason(Cube reason);

    // NOTE: member order is significant for the constructor's initialization
    // order; do not reorder without checking the mem-initializer list.
    vector<size_t> ranked_vars_;
    FramesStorage* frames_;             // non-owning
    SyntacticAbstraction* abstraction_; // non-owning
    const Model* model_;                // non-owning
    std::shared_ptr<LinearCondition> goal_;
    std::shared_ptr<LinearCondition> avoid_;
    std::shared_ptr<PolicyReasoner> policy_reasoner_;
    Cube var_bounds_;
    PolicyReasons* reasons_; // non-owning, may be nullptr

    vector<VariableCategory> var_class_;

#ifndef POLICE_NO_STATISTICS
    std::shared_ptr<std::ostream> stats_file_ = nullptr;
#endif

    // mutable so that const query helpers can update timers/counters.
    mutable Statistics stats_;
};

} // namespace police::ic3::syntactic

namespace police {

/**
 * Stream the refiner statistics `stats` into `out`. The output format is
 * defined in the implementation file. Returns a stream reference
 * (conventionally `out`, to allow chaining; confirm in the .cpp).
 *
 * NOTE(review): this operator is declared in `police`, while `Statistics`
 * lives in `police::ic3::syntactic`. Argument-dependent lookup therefore does
 * not find it. Code outside namespace `police` must bring it into scope
 * (e.g. `using police::operator<<;`) or call it qualified.
 */
std::ostream& operator<<(
    std::ostream& out,
    const ic3::syntactic::SyntacticFrameRefiner::Statistics& stats);

} // namespace police
