#pragma once

#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "police/base_types.hpp"
#include "police/compute_graph.hpp"
#include "police/layer_bounds.hpp"
#include "police/macros.hpp"
#include "police/utils/stopwatch.hpp"

namespace police::cg {

/// Abstract per-node building block used by ComputeGraphRelaxation.
///
/// The public templates regress / optimize / progress pick the lower-bound
/// (Ub == false) or upper-bound (Ub == true) variant at compile time and
/// forward to private virtual hooks that subclasses override. Unless
/// overridden, both variants of regress/progress share one "generic"
/// implementation, and both optimize hooks are no-ops.
class NodeRelaxer {
public:
    virtual ~NodeRelaxer() = default;

    /// Number of inputs of this node.
    virtual size_t num_inputs() const = 0;

    /// Number of outputs of this node.
    virtual size_t num_outputs() const = 0;

    /// Whether this node takes part in bound optimization; optimization of a
    /// chain is only attempted if at least one node answers true.
    [[nodiscard]]
    virtual bool is_optimizable() const;

    /// Hands this node the bounds on its input.
    virtual void set_input_bounds(const LayerBounds& bounds);

    /// Called once on every node of a chain before an optimization round;
    /// `output_layer_size` is the number of outputs of the chain's last node.
    virtual void init_optimization_round(size_t output_layer_size);

    /// Applies the lower- or upper-bound regression hook to `coefs`.
    template <bool Ub>
    void regress(Matrix& coefs)
    {
        if constexpr (!Ub) {
            regress_lb(coefs);
        } else {
            regress_ub(coefs);
        }
    }

    /// Runs one lower- or upper-bound optimization update on this node.
    /// The default hooks do nothing.
    template <bool Ub>
    void optimize()
    {
        if constexpr (!Ub) {
            optimize_lb();
        } else {
            optimize_ub();
        }
    }

    /// Applies the lower- or upper-bound progression hook to `input`.
    template <bool Ub>
    void progress(Matrix& input)
    {
        if constexpr (!Ub) {
            progress_lb(input);
        } else {
            progress_ub(input);
        }
    }

    /// Maps bound vectors on this node's input to a pair of bound vectors;
    /// presumably (lower, upper) bounds on its output — TODO confirm in
    /// implementations.
    [[nodiscard]]
    virtual std::pair<LinVector, LinVector>
    propagate_bounds(const LinVector& lbs, const LinVector& ubs) = 0;

    /// Human-readable name of this node.
    [[nodiscard]]
    virtual std::string name() const = 0;

private:
    // Shared fallback for both regression variants.
    virtual void regress_generic([[maybe_unused]] Matrix& coefs);
    virtual void regress_lb(Matrix& m) { regress_generic(m); }
    virtual void regress_ub(Matrix& m) { regress_generic(m); }

    // Shared fallback for both progression variants.
    virtual void progress_generic([[maybe_unused]] Matrix& coefs);
    virtual void progress_lb(Matrix& m) { progress_generic(m); }
    virtual void progress_ub(Matrix& m) { progress_generic(m); }

    // Nothing to optimize by default.
    virtual void optimize_lb() {}
    virtual void optimize_ub() {}
};

/// Creates NodeRelaxer instances for compute-graph nodes.
///
/// create() and create_recursive() are non-virtual entry points defined out
/// of line. The private virtual create_impl overloads (one per layer type
/// below) are presumably what they dispatch to, via the FactoryImpl helper.
/// Subclasses may override individual overloads.
/// NOTE(review): confirm the dispatch against the out-of-line definitions.
class NodeRelaxerFactory {
public:
    virtual ~NodeRelaxerFactory() = default;

    /// Creates the relaxer for a single node.
    [[nodiscard]]
    std::shared_ptr<NodeRelaxer> create(const Node* node);

    /// Creates relaxers for `node` and, presumably, the nodes reachable from
    /// it — TODO confirm traversal order against the implementation.
    [[nodiscard]]
    vector<std::shared_ptr<NodeRelaxer>> create_recursive(const Node* node);

private:
    // Defined out of line; presumably performs the per-type dispatch.
    class FactoryImpl;

    // Relaxer for a LinearLayer node.
    [[nodiscard]]
    virtual std::shared_ptr<NodeRelaxer> create_impl(const LinearLayer* node);

    // Relaxer for a ReluLayer node.
    [[nodiscard]]
    virtual std::shared_ptr<NodeRelaxer> create_impl(const ReluLayer* node);

    // Relaxer for a MaxPoolLayer node.
    [[nodiscard]]
    virtual std::shared_ptr<NodeRelaxer> create_impl(const MaxPoolLayer* node);

    // Relaxer for an ExpandToConstLayer node.
    [[nodiscard]]
    virtual std::shared_ptr<NodeRelaxer>
    create_impl(const ExpandToConstLayer* node);
};

/// Tuning knobs for ComputeGraphRelaxation.
struct RelaxationOptions {
    // Maximum number of optimization steps per optimized bound.
    size_t max_steps = 1000;
    // A step whose change (delta) is below this value counts as "converged".
    real_t convergence_epsilon = 5e-3;
    // Optimization stops after this many consecutive converged steps.
    real_t convergence_threshold = 3;
    // Overall time budget; presumably checked by time_is_up() — units not
    // visible here (TODO confirm against StopWatch).
    real_t max_total_time = 99999;
    // Per-layer time budget; presumably checked by time_is_up() — units not
    // visible here (TODO confirm).
    real_t max_layer_time = 99999;
    // NOTE(review): not read in this header; presumably restricts
    // optimization to the last layer in compute_layer_bounds — confirm.
    bool optimize_only_last = true;
    // When true, compute_lower_bounds skips the optimization phase entirely.
    bool disable_optimization = false;
};

/// Computes bounds on the outputs of a chain of NodeRelaxer layers from
/// bounds on the input of the first layer.
///
/// Bounds are first propagated layer by layer (NodeRelaxer::propagate_bounds)
/// and can then be tightened by an iterative optimization. That optimization
/// only runs when at least one layer of the chain reports is_optimizable().
class ComputeGraphRelaxation {
public:
    ComputeGraphRelaxation(
        std::vector<std::shared_ptr<NodeRelaxer>> layers,
        RelaxationOptions options = RelaxationOptions());

    [[nodiscard]]
    vector<LayerBounds> compute_layer_bounds(const LayerBounds& input_bounds);

    [[nodiscard]]
    LayerBounds compute_bounds(const LayerBounds& input_bounds);

    /// Computes lower bounds on the outputs of the last layer.
    ///
    /// Bounds for all layers but the last are computed first. The last layer
    /// is then propagated, and its lower bound is optimized unless
    /// options_.disable_optimization is set. Its upper bound is never
    /// optimized here.
    ///
    /// @param input_bounds bounds on the input of the first layer.
    /// @param terminate    callable invoked as terminate(lbs, ubs) with the
    ///                     current bound vectors of the last layer's output;
    ///                     returning true stops the optimization early.
    /// @return the lower-bound vector of the last layer's output.
    template <typename Predicate>
    auto compute_lower_bounds(
        const LayerBounds& input_bounds,
        Predicate terminate = {})
    {
        // layers_.end() - 1 and layers_.back() require a non-empty chain.
        assert(!layers_.empty());
        total_timer_.reset();
        const auto [in_lbs, in_ubs] = extract_bound_vectors(input_bounds);
        LinVector lbs(in_lbs);
        LinVector ubs(in_ubs);
        // Adapts the user predicate to the nullary form used internally while
        // still exposing the live bound vectors to it.
        auto terminate_wrapper = [&terminate, &lbs, &ubs]() {
            return terminate(lbs, ubs);
        };
        // Bounds for every layer except the last one.
        const auto layer_bounds = compute_layer_bounds(
            input_bounds,
            in_lbs,
            in_ubs,
            layers_.end() - 1,
            lbs,
            ubs);
        layers_.back()->set_input_bounds(layer_bounds.back());
        compute_next_layer_bounds<true, false>(
            in_lbs,
            in_ubs,
            layers_.begin(),
            layers_.end(),
            !options_.disable_optimization,
            lbs,
            ubs,
            std::move(terminate_wrapper));
        return lbs;
    }

private:
    [[nodiscard]]
    vector<LayerBounds> compute_layer_bounds(
        const LayerBounds& input_bounds,
        const LinVector& in_lbs,
        const LinVector& in_ubs,
        vector<std::shared_ptr<NodeRelaxer>>::const_iterator end,
        LinVector& lbs,
        LinVector& ubs);

    [[nodiscard]]
    bool time_is_up() const;

    [[nodiscard]]
    static std::pair<LinVector, LinVector>
    extract_bound_vectors(const LayerBounds& bounds);

    static void propagate_bounds(
        vector<std::shared_ptr<NodeRelaxer>>::const_iterator begin,
        vector<std::shared_ptr<NodeRelaxer>>::const_iterator end,
        LinVector& lbs,
        LinVector& ubs);

    /// Propagates lbs/ubs through the last layer of [begin, end). It then
    /// optionally tightens them by optimizing over the whole chain.
    ///
    /// @tparam Lb    optimize the lower bounds (written to lbs).
    /// @tparam Ub    optimize the upper bounds (written to ubs).
    /// @param in_lbs, in_ubs bounds on the input of *begin.
    /// @param optimize       whether the optimization phase may run at all.
    /// @param terminate      nullary predicate; true stops early.
    template <bool Lb = true, bool Ub = true, typename Predicate>
    void compute_next_layer_bounds(
        const LinVector& in_lbs,
        const LinVector& in_ubs,
        vector<std::shared_ptr<NodeRelaxer>>::const_iterator begin,
        vector<std::shared_ptr<NodeRelaxer>>::const_iterator end,
        bool optimize,
        LinVector& lbs,
        LinVector& ubs,
        Predicate terminate = {})
    {
        assert(begin != end);
        layer_timer_.reset();
        propagate_bounds(end - 1, end, lbs, ubs);
        if (!optimize || terminate() || time_is_up()) {
            return;
        }
        // Optimizing is pointless unless some layer reports it can be tuned.
        const bool optimizable = std::any_of(
            begin,
            end,
            [](const std::shared_ptr<NodeRelaxer>& layer) {
                return layer->is_optimizable();
            });
        if (!optimizable) {
            return;
        }
        if constexpr (Lb) {
            optimize_bounds<false>(
                in_lbs,
                in_ubs,
                begin,
                end,
                lbs,
                terminate);
        }
        if constexpr (Ub) {
            optimize_bounds<true>(
                in_lbs,
                in_ubs,
                begin,
                end,
                ubs,
                terminate);
        }
    }

    template <bool Ub>
    real_t optimization_step(
        Matrix& coefs,
        Matrix& values,
        const LinVector& in_lbs,
        const LinVector& in_ubs,
        vector<std::shared_ptr<NodeRelaxer>>::const_iterator begin,
        vector<std::shared_ptr<NodeRelaxer>>::const_iterator end,
        LinVector& result) const;

    /// Iteratively optimizes the lower (Ub == false) or upper (Ub == true)
    /// bounds of the chain [begin, end), passing `result` to each
    /// optimization_step.
    ///
    /// Stops after options_.max_steps steps (at least one step always runs).
    /// It also stops after options_.convergence_threshold consecutive steps
    /// whose delta is below options_.convergence_epsilon, when time is up, or
    /// when terminate() returns true. Between steps, every layer's
    /// optimize<Ub>() hook is invoked.
    template <bool Ub, typename Predicate>
    void optimize_bounds(
        const LinVector& in_lbs,
        const LinVector& in_ubs,
        vector<std::shared_ptr<NodeRelaxer>>::const_iterator begin,
        vector<std::shared_ptr<NodeRelaxer>>::const_iterator end,
        LinVector& result,
        Predicate terminate) const
    {
        const size_t input_size = (*begin)->num_inputs();
        const size_t result_size = (*(end - 1))->num_outputs();

        for (auto it = begin; it != end; ++it) {
            (*it)->init_optimization_round(result_size);
        }

        Matrix coefs(result_size, result_size);
        Matrix values(input_size, result_size);
        // Number of consecutive steps whose delta stayed below epsilon.
        size_t convergence_counter = 0;
        // Count steps upwards in size_t so that max_steps is never truncated
        // or wrapped through a signed int. max_steps == 0 behaves like 1.
        for (size_t step = 1;; ++step) {
            const auto delta = optimization_step<Ub>(
                coefs,
                values,
                in_lbs,
                in_ubs,
                begin,
                end,
                result);

            if (delta < options_.convergence_epsilon) {
                ++convergence_counter;
            } else {
                convergence_counter = 0;
            }

            if (step >= options_.max_steps ||
                convergence_counter >= options_.convergence_threshold ||
                time_is_up() || terminate()) {
                return;
            }

            for (auto it = begin; it != end; ++it) {
                (*it)->optimize<Ub>();
            }
        }

        POLICE_UNREACHABLE();
    }

    vector<std::shared_ptr<NodeRelaxer>> layers_;
    StopWatch total_timer_;
    StopWatch layer_timer_;
    RelaxationOptions options_;
};

} // namespace police::cg
