#include "police/cg_relaxation.hpp"
#include "police/compute_graph.hpp"
#include "police/macros.hpp"

#if DEBUG_PRINTS
#include "police/utils/io.hpp"
#endif

#include <algorithm>
#include <limits>

namespace police::cg {

// Default capability query: a relaxer is not optimizable unless a subclass
// overrides this (the ReLU and optimizable max-pool relaxers in this file
// return true).
bool NodeRelaxer::is_optimizable() const
{
    return false;
}

// Default implementation: the supplied input bounds are ignored and nothing
// happens. Relaxers whose relaxation depends on the input bounds override it.
void NodeRelaxer::set_input_bounds(const LayerBounds&)
{
}

// Default implementation: no per-round state to prepare, so the size argument
// is ignored and the call is a no-op.
void NodeRelaxer::init_optimization_round(size_t)
{
}

void NodeRelaxer::regress_generic([[maybe_unused]] Matrix& coefs)
{
    POLICE_RUNTIME_ERROR("regress_generic has not been implemented");
}

void NodeRelaxer::progress_generic([[maybe_unused]] Matrix& coefs)
{
    POLICE_RUNTIME_ERROR("progress_generic has not been implemented");
}

#ifdef POLICE_EIGEN
namespace {

// Adam-style first-order optimizer operating on a whole matrix of parameters.
//
// The optimizer keeps running first and second moment estimates of the
// gradients (momentum_ and variance_) together with the running products
// ADAM_BETA1^t and ADAM_BETA2^t (beta1_norm_ and beta2_norm_) that are used for
// bias correction. initialize() must be called with the gradient shape before
// the first step() of a round.
class Optimizer {
public:
    constexpr static real_t ADAM_ETA = 0.01;
    constexpr static real_t ADAM_BETA1 = 0.9;
    constexpr static real_t ADAM_BETA2 = 0.999;
    constexpr static real_t ADAM_EPSILON = 1e-8;

    // Creates an optimizer with empty (0 x 0) moment estimates.
    Optimizer()
        : momentum_(0, 0)
        , variance_(0, 0)
    {
    }

    // Feeds one gradient matrix into the moment estimates and returns the
    // resulting update step (same shape as grads).
    //
    // The step is ADAM_ETA * m_hat / max(sqrt(v_hat), ADAM_EPSILON), where
    // m_hat and v_hat are the bias-corrected moment estimates. The result is
    // returned as an evaluated Matrix rather than as a lazy Eigen expression,
    // so it does not keep referring to this optimizer's mutable state.
    [[nodiscard]]
    Matrix step(const Matrix& grads)
    {
#if DEBUG_PRINTS >= 2
        ScopedStopWatch time("adam step");
#endif
        beta1_norm_ *= ADAM_BETA1;
        beta2_norm_ *= ADAM_BETA2;
        momentum_ = ADAM_BETA1 * momentum_ + (1. - ADAM_BETA1) * grads;
        variance_ = ADAM_BETA2 * variance_ +
                    (1 - ADAM_BETA2) * grads.cwiseProduct(grads);
        // bias-corrected first moment, already scaled by the learning rate
        const auto mom = momentum_ * (ADAM_ETA / (1. - beta1_norm_));
        // inverse of the bias-corrected second moment's square root; the
        // cwiseMax guards against division by (almost) zero
        auto var = (variance_ / (1. - beta2_norm_))
                       .cwiseSqrt()
                       .cwiseMax(ADAM_EPSILON)
                       .cwiseInverse();
        return mom.cwiseProduct(var.matrix());
    }

    // Resets the optimizer state for a parameter matrix of the given shape:
    // both moment estimates become zero and the bias-correction products
    // restart at 1.
    void initialize(size_t rows, size_t cols)
    {
        momentum_.resize(rows, cols);
        momentum_.setZero();
        variance_.resize(rows, cols);
        variance_.setZero();
        beta1_norm_ = 1.;
        beta2_norm_ = 1.;
    }

private:
    Matrix momentum_;       // running first moment of the gradients
    Matrix variance_;       // running second moment of the gradients
    real_t beta1_norm_ = 1; // ADAM_BETA1^t for the current step count t
    real_t beta2_norm_ = 1; // ADAM_BETA2^t for the current step count t
};

class LinearLayerRelaxer final : public NodeRelaxer {
public:
    explicit LinearLayerRelaxer(std::shared_ptr<const LinearLayer> node)
        : node_(std::move(node))
    {
    }

    [[nodiscard]]
    size_t num_inputs() const override
    {
        return node_->num_inputs();
    }

    [[nodiscard]]
    size_t num_outputs() const override
    {
        return node_->num_outputs();
    }

    [[nodiscard]]
    std::pair<LinVector, LinVector>
    propagate_bounds(const LinVector& in_lbs, const LinVector& in_ubs) override
    {
        const auto& W = node_->get_weights();
        const auto& b = node_->get_biases();
        LinVector out_lbs = W.cwiseMax(0) * in_lbs + W.cwiseMin(0) * in_ubs + b;
        LinVector out_ubs = W.cwiseMax(0) * in_ubs + W.cwiseMin(0) * in_lbs + b;
        return std::pair<LinVector, LinVector>(
            std::move(out_lbs),
            std::move(out_ubs));
    }

    std::string name() const override { return "linear"; }

private:
    void regress_generic(Matrix& coefs) override
    {
        coefs = coefs * node_->get_weights();
    }

    void progress_generic(Matrix& input) override
    {
        input = (node_->get_weights() * input);
        for (int out = input.cols() - 1; out >= 0; --out) {
            input.col(out) += node_->get_biases();
        }
    }

private:
    std::shared_ptr<const LinearLayer> node_;
};

// Relaxer for an element-wise ReLU layer with an optimizable lower-bound
// slope.
//
// set_input_bounds() classifies every neuron from its input interval and
// stores, per neuron, the slope (coef_) and shift (offset_) of the fixed
// linear relaxation, which is applied as coef * (x + offset). For a neuron
// whose interval straddles zero this is ub / (ub - lb) * (x - lb). The slope
// used on the other side of the relaxation (lb_coef_) starts out equal to
// coef_, is kept per output row of the compute graph, and is adjusted by the
// optimizer while being clamped to [0, 1].
class ReluLayerRelaxer final : public NodeRelaxer {
    // NOTE(review): with the lookup tables in set_input_bounds, the state
    // labelled OFF receives multiplier 1 (identity) and the one labelled ON
    // receives multiplier 0; the labels look swapped relative to the usual
    // ReLU terminology, but the resulting coefficients are consistent.
    enum ReluState { OFF = 0, ON = 1, UNSTABLE = 2 };

public:
    // node is only queried for its number of outputs; it is not stored.
    // optimizer is the (shared) optimizer used to update lb_coef_.
    ReluLayerRelaxer(
        std::shared_ptr<const ReluLayer> node,
        std::shared_ptr<Optimizer> optimizer)
        : unstable_(1, node->num_outputs())
        , coef_(1, node->num_outputs())
        , offset_(1, node->num_outputs())
        , optimizer_(std::move(optimizer))
        , size_(node->num_outputs())
    {
    }

    [[nodiscard]]
    size_t num_inputs() const override
    {
        return size_;
    }

    [[nodiscard]]
    size_t num_outputs() const override
    {
        return size_;
    }

    [[nodiscard]]
    size_t size() const
    {
        return size_;
    }

    // Derives the per-neuron relaxation parameters from the input bounds.
    void set_input_bounds(const LayerBounds& bounds) override
    {
        // indexed by 2 * (lb < 0) + (ub > 0)
        const ReluState states[] = {ON, OFF, ON, UNSTABLE};
        for (int neuron = size() - 1; neuron >= 0; --neuron) {
            const real_t lb = bounds.lb(neuron);
            const real_t ub = bounds.ub(neuron);
            const ReluState state =
                states[2 * static_cast<size_t>(lb < 0.) + (ub > 0.)];
            // b is 0 iff lb == ub; adding (1 - b) to the denominator avoids a
            // division by zero in that case (a is then 0)
            const real_t b = static_cast<real_t>(lb != ub);
            const real_t a = b * ub / (ub - lb + 1. - b);
            const real_t multipliers[] = {
                1, // off, ub
                0, // on, ub
                a, // unstable, ub
            };
            unstable_(0, neuron) = state == UNSTABLE;
            coef_(0, neuron) = multipliers[state];
            offset_(0, neuron) = -static_cast<real_t>(state == UNSTABLE) * lb;
        }
    }

    // Prepares the per-round state for num_outputs compute-graph outputs:
    // the optimizable slopes restart from coef_ (one row per output) and the
    // optimizer is reset to the matching shape.
    void init_optimization_round(size_t num_outputs) override
    {
        lb_coef_ = coef_.replicate(num_outputs, 1);
        weights_.resize(num_outputs, size());
        biases_.resize(num_outputs, size());
        gradients_.resize(num_outputs, size());
        optimizer_->initialize(num_outputs, size());
    }

    [[nodiscard]]
    bool is_optimizable() const override
    {
        return true;
    }

    // Interval propagation: both bounds are clamped from below at zero.
    [[nodiscard]]
    std::pair<LinVector, LinVector>
    propagate_bounds(const LinVector& in_lbs, const LinVector& in_ubs) override
    {
        return std::pair<LinVector, LinVector>(
            in_lbs.cwiseMax(0),
            in_ubs.cwiseMax(0));
    }

    std::string name() const override { return "relu"; }

private:
    // Backward step. Depending on the sign of each incoming coefficient and
    // on whether an upper (Ub) or lower bound is computed, either the fixed
    // relaxation (coef_, offset_) or the optimizable slope (lb_coef_) is
    // selected. The selection is remembered in weights_/biases_ for the
    // forward step, and the coefficients that hit the optimizable slope of an
    // unstable neuron are remembered in output_coefs_ for the gradient.
    template <bool Ub>
    void regress(Matrix& coefs)
    {
        const auto ltz = coefs.cwiseMin(0);
        const auto gtz = coefs.cwiseMax(0);
        if constexpr (Ub) {
            output_coefs_ =
                ltz.cwiseProduct(unstable_.replicate(coefs.rows(), 1));
            weights_ =
                gtz.cwiseSign().cwiseProduct(coef_.replicate(coefs.rows(), 1)) -
                ltz.cwiseSign().cwiseProduct(lb_coef_);
            biases_ = gtz.cwiseSign().cwiseProduct(
                offset_.replicate(coefs.rows(), 1));
        } else {
            output_coefs_ =
                gtz.cwiseProduct(unstable_.replicate(coefs.rows(), 1));
            weights_ =
                gtz.cwiseSign().cwiseProduct(lb_coef_) -
                ltz.cwiseSign().cwiseProduct(coef_.replicate(coefs.rows(), 1));
            biases_ = -ltz.cwiseSign().cwiseProduct(
                offset_.replicate(coefs.rows(), 1));
        }
        coefs = coefs.cwiseProduct(weights_);
    }

    void regress_lb(Matrix& coefs) override { regress<false>(coefs); }

    void regress_ub(Matrix& coefs) override { regress<true>(coefs); }

    // Forward step: records the gradient with respect to the optimizable
    // slopes (input value times the coefficient stored by regress) and then
    // applies the relaxation selected by regress to the input values.
    void progress_generic(Matrix& input) override
    {
        gradients_ = input.transpose().cwiseProduct(output_coefs_);
#if DEBUG_PRINTS >= 2
        std::cout << "-- in\n" << input << std::endl;
        std::cout << "-- out\n" << output_coefs_ << std::endl;
        std::cout << "-- gradients\n" << gradients_ << std::endl;
#endif
        input =
            (input + biases_.transpose()).cwiseProduct(weights_.transpose());
    }

    // Gradient ascent step on the optimizable slopes (to raise lower bounds),
    // clamped to [0, 1].
    void optimize_lb() override
    {
        auto step = optimizer_->step(gradients_);
        lb_coef_ = (lb_coef_ + step).cwiseMax(0).cwiseMin(1);
#if DEBUG_PRINTS >= 2
        std::cout << "- step\n" << step << std::endl;
        std::cout << "- new alpha\n" << lb_coef_ << std::endl;
#endif
    }

    // Gradient descent step on the optimizable slopes (to lower upper
    // bounds), clamped to [0, 1].
    void optimize_ub() override
    {
        auto step = optimizer_->step(gradients_);
        lb_coef_ = (lb_coef_ - step).cwiseMax(0).cwiseMin(1);
#if DEBUG_PRINTS >= 2
        std::cout << "- step\n" << step << std::endl;
        std::cout << "- new alpha\n" << lb_coef_ << std::endl;
#endif
    }

private:
    Matrix unstable_; // 1 x size: 1 for neurons classified UNSTABLE, else 0
    Matrix coef_;     // 1 x size: slope of the fixed relaxation
    Matrix offset_;   // 1 x size: input shift of the fixed relaxation

    Matrix lb_coef_; // outputs x size: optimizable slopes

    Matrix weights_;      // outputs x size: slopes selected by regress
    Matrix biases_;       // outputs x size: shifts selected by regress
    Matrix output_coefs_; // coefficients feeding the gradient computation
    Matrix gradients_;    // gradient w.r.t. lb_coef_ from the last progress
    std::shared_ptr<Optimizer> optimizer_;

    size_t size_;
};

class MaxPoolRelaxerBase : public NodeRelaxer {
public:
    explicit MaxPoolRelaxerBase(std::shared_ptr<const MaxPoolLayer> node)
        : node_(std::move(node))
    {
        vector<size_t>* idxs[] = {&non_empty_, &empty_};
        for (size_t o = 0; o < node_->num_outputs(); ++o) {
            idxs[node_->get_input_refs(o).empty()]->push_back(o);
        }
    }

    [[nodiscard]]
    size_t num_inputs() const final
    {
        return node_->num_inputs();
    }

    [[nodiscard]]
    size_t num_outputs() const final
    {
        return node_->num_outputs();
    }

    [[nodiscard]]
    std::pair<LinVector, LinVector>
    propagate_bounds(const LinVector& in_lbs, const LinVector& in_ubs) final
    {
        LinVector out_lbs(num_outputs());
        out_lbs.setZero();
        LinVector out_ubs(num_outputs());
        out_ubs.setZero();
        for (const size_t& neuron : non_empty_) {
            real_t lb = -std::numeric_limits<real_t>::infinity();
            real_t ub = -std::numeric_limits<real_t>::infinity();
            for (const auto& idx : node_->get_input_refs(neuron)) {
                lb = std::max(lb, in_lbs[idx]);
                ub = std::max(ub, in_ubs[idx]);
            }
            out_lbs(neuron) = lb;
            out_ubs(neuron) = ub;
        }
        return std::pair<LinVector, LinVector>(
            std::move(out_lbs),
            std::move(out_ubs));
    }

    std::string name() const override { return "maxpool"; }

protected:
    std::shared_ptr<const MaxPoolLayer> node_;
    vector<size_t> empty_;
    vector<size_t> non_empty_;
};

class MaxPoolRelaxer final : public MaxPoolRelaxerBase {
public:
    MaxPoolRelaxer(
        std::shared_ptr<const MaxPoolLayer> node,
        std::shared_ptr<Optimizer> optimizer)
        : MaxPoolRelaxerBase(node)
        , lb_coefs_(node->num_outputs(), node->num_inputs())
        , ub_coefs_(node->num_outputs(), node->num_inputs())
        , offset_(node->num_outputs(), 1)
        , optimizable_(node->num_outputs(), node->num_inputs())
        , optimizer_(std::move(optimizer))
    {
    }

    void set_input_bounds(const LayerBounds& bounds) override
    {
        lb_coefs_.setZero();
        ub_coefs_.setZero();
        offset_.setZero();
        optimizable_.setZero();
        vector<size_t> active;
        for (int neuron = node_->num_outputs() - 1; neuron >= 0; --neuron) {
            // find referenced input with maximal lower bound
            size_t max_lb_idx = -1;
            real_t max_lb = -std::numeric_limits<real_t>::infinity();
            real_t max_ub = -std::numeric_limits<real_t>::infinity();
            for (const auto& idx : node_->get_input_refs(neuron)) {
                const auto lb = bounds.lb(idx);
                max_ub = std::max(max_ub, bounds.ub(idx));
                if (lb > max_lb) {
                    max_lb = lb;
                    max_lb_idx = idx;
                }
            }
            // set ub passthrough to that input element, unless there is
            // another referenced input whose ub is larger than max_lb
            // active keeps track of the input_ref indices whose ub is
            // larger than max_lb; all other inputs are guaranteed to be
            // dominated by the max_lb input
            // compute active and update passthrough_ub_ as necessary...
            for (const auto& idx : node_->get_input_refs(neuron)) {
                if (idx == max_lb_idx) {
                    active.push_back(idx);
                    continue;
                }
                const auto ub = bounds.ub(idx);
                if (ub > max_lb) {
                    active.push_back(idx);
                }
            }
            if (active.size() == 1u) {
                // full passthrough for both bounds
                ub_coefs_(neuron, max_lb_idx) = 1;
                lb_coefs_(neuron, max_lb_idx) = 1;
            } else {
                // constant upper bound:
                offset_(neuron) = max_ub;
                if (max_lb < 0) {
                    // if max_lb < 0, linear combination not guaranteed to be
                    // lower bound => passthrough to the var with max lb
                    lb_coefs_(neuron, max_lb_idx) = 1;
                } else {
                    // otherwise take linear combination over active inputs;
                    // mark linear combination weights as optimizable variables
                    for (size_t idx : active) {
                        lb_coefs_(neuron, idx) = 1. / active.size();
                        optimizable_(neuron, idx) = 1;
                    }
                }
            }
            active.clear();
        }
    }

    void init_optimization_round(size_t cg_outputs) override
    {
        out_coefs_.resize(cg_outputs, num_outputs());
        per_res_lb_coefs_.resize(cg_outputs * num_outputs(), num_inputs());
        for (size_t o = 0; o < cg_outputs; ++o) {
            per_res_lb_coefs_.block<Eigen::Dynamic, Eigen::Dynamic>(
                o * num_outputs(),
                0,
                num_outputs(),
                num_inputs()) = lb_coefs_;
        }
        gradients_.resize(cg_outputs * num_outputs(), num_inputs());
        optimizer_->initialize(cg_outputs * this->num_outputs(), num_inputs());
    }

    [[nodiscard]]
    bool is_optimizable() const override
    {
        return true;
    }

private:
    template <bool UB>
    void regress(Matrix& coefs)
    {
        out_coefs_ = coefs;
        auto ltz = coefs.cwiseMin(0);
        auto gtz = coefs.cwiseMax(0);
        Matrix result(coefs.rows(), num_inputs());
        auto result_rows = result.rowwise();
        auto r = result_rows.begin();
        if constexpr (UB) {
            for (int o = 0; o < result.rows(); ++o, ++r) {
                *r =
                    (gtz.row(o) * ub_coefs_ +
                     ltz.row(o) * per_res_lb_coefs_
                                      .block<Eigen::Dynamic, Eigen::Dynamic>(
                                          o * num_outputs(),
                                          0,
                                          num_outputs(),
                                          per_res_lb_coefs_.cols()));
            }
        } else {
            for (int o = 0; o < result.rows(); ++o, ++r) {
                *r =
                    (gtz.row(o) * per_res_lb_coefs_
                                      .block<Eigen::Dynamic, Eigen::Dynamic>(
                                          o * num_outputs(),
                                          0,
                                          num_outputs(),
                                          per_res_lb_coefs_.cols()) +
                     ltz.row(o) * ub_coefs_);
            }
        }
        coefs.swap(result);
    }

    void regress_lb(Matrix& coefs) override { regress<false>(coefs); }

    void regress_ub(Matrix& coefs) override { regress<true>(coefs); }

    template <bool Ub>
    void progress(Matrix& input)
    {
        auto result = compute_relaxation<Ub>(input);
        update_gradients(input);
        input.swap(result);
    }

    void progress_lb(Matrix& coefs) override { progress<false>(coefs); }

    void progress_ub(Matrix& coefs) override { progress<true>(coefs); }

    void optimize_lb() override
    {
        auto step = optimizer_->step(gradients_);
        per_res_lb_coefs_ += step;
        normalize_lb_coefs();
    }

    void optimize_ub() override
    {
        auto step = optimizer_->step(gradients_);
        per_res_lb_coefs_ -= step;
        normalize_lb_coefs();
    }

private:
    template <bool Ub>
    [[nodiscard]]
    Matrix compute_relaxation(const Matrix& input) const
    {
        Matrix result(num_outputs(), input.cols());
        size_t o = 0;
        auto ltz = out_coefs_.cwiseSign().cwiseMin(0);
        auto gtz = out_coefs_.cwiseSign().cwiseMax(0);
        for (auto c : result.colwise()) {
            if constexpr (Ub) {
                c = gtz.row(o).transpose().cwiseProduct(
                        ub_coefs_ * input.col(o) + offset_) -
                    ltz.row(o).transpose().cwiseProduct(
                        per_res_lb_coefs_.block<Eigen::Dynamic, Eigen::Dynamic>(
                            o * num_outputs(),
                            0,
                            num_outputs(),
                            num_inputs()) *
                        input.col(o));
            } else {
                c = gtz.row(o).transpose().cwiseProduct(
                        per_res_lb_coefs_.block<Eigen::Dynamic, Eigen::Dynamic>(
                            o * num_outputs(),
                            0,
                            num_outputs(),
                            num_inputs()) *
                        input.col(o)) -
                    ltz.row(o).transpose().cwiseProduct(
                        ub_coefs_ * input.col(o) + offset_);
            }
            ++o;
        }
        return result;
    }

    void update_gradients(const Matrix& input)
    {
        Matrix diag(num_outputs(), num_outputs());
        diag.setZero();
        diag.diagonal().setOnes();
        for (int o = 0; o < input.cols(); ++o) {
            gradients_.block<Eigen::Dynamic, Eigen::Dynamic>(
                o * num_outputs(),
                0,
                num_outputs(),
                num_inputs()) =
                out_coefs_.row(o)
                    .replicate(num_outputs(), 1)
                    .cwiseProduct(diag) *
                optimizable_.cwiseProduct(
                    input.col(o).transpose().replicate(num_outputs(), 1));
        }
    }

    void normalize_lb_coefs()
    {
        for (auto r : per_res_lb_coefs_.rowwise()) {
            const auto total = std::max(r.sum(), static_cast<real_t>(1e-3));
            r *= 1. / total;
        }
    }

    Matrix lb_coefs_;         // OUT x IN
    Matrix ub_coefs_;         // OUT x IN
    Matrix offset_;           // OUT x 1
    Matrix optimizable_;      // OUT x IN
    Matrix out_coefs_;        // RES x OUT
    Matrix per_res_lb_coefs_; // RES * (OUT x IN)
    Matrix gradients_;        // RES * (OUT x IN)

    std::shared_ptr<Optimizer> optimizer_;
};

class MaxPoolRelaxerSimple final : public MaxPoolRelaxerBase {
public:
    explicit MaxPoolRelaxerSimple(std::shared_ptr<const MaxPoolLayer> node)
        : MaxPoolRelaxerBase(node)
        , ub_(node->num_outputs())
        , ub_idx_(node->num_outputs(), -1)
        , lb_idx_(node->num_outputs(), -1)
    {
    }

    void set_input_bounds(const LayerBounds& bounds) override
    {
        std::fill(ub_.begin(), ub_.end(), 0);
        std::fill(lb_idx_.begin(), lb_idx_.end(), -1);
        std::fill(ub_idx_.begin(), ub_idx_.end(), -1);
        for (const size_t& neuron : non_empty_) {
            // find referenced input with maximal lower bound
            size_t max_lb_idx = -1;
            real_t max_lb = -std::numeric_limits<real_t>::infinity();
            real_t max_ub = -std::numeric_limits<real_t>::infinity();
            for (const auto& idx : node_->get_input_refs(neuron)) {
                const auto lb = bounds.lb(idx);
                max_ub = std::max(max_ub, bounds.ub(idx));
                if (lb > max_lb) {
                    max_lb = lb;
                    max_lb_idx = idx;
                }
            }
            size_t active = 0;
            for (const auto& idx : node_->get_input_refs(neuron)) {
                if (idx == max_lb_idx) {
                    ++active;
                    continue;
                }
                const auto ub = bounds.ub(idx);
                if (ub > max_lb) {
                    ++active;
                }
            }
            lb_idx_[neuron] = max_lb_idx;
            if (active <= 1u) {
                ub_idx_[neuron] = max_lb_idx;
            } else {
                ub_[neuron] = max_ub;
            }
        }
    }

    void init_optimization_round(size_t output_layer_size) override
    {
        chosen_idx_.resize(output_layer_size * num_outputs());
    }

private:
    template <bool Ub>
    void regress(Matrix& coefs)
    {
        Matrix result(coefs.rows(), num_inputs());
        result.setZero();
        std::fill(chosen_idx_.begin(), chosen_idx_.end(), -1);
        for (int o = 0; o < coefs.rows(); ++o) {
            for (const size_t& n : non_empty_) {
                if constexpr (Ub) {
                    if (coefs(o, n) > 0) {
                        if (ub_idx_[n] >= 0) {
                            result(o, ub_idx_[n]) += coefs(o, n);
                            chosen_idx_[o * num_outputs() + n] = ub_idx_[n];
                        }
                    } else if (coefs(o, n) < 0) {
                        result(o, lb_idx_[n]) += coefs(o, n);
                        chosen_idx_[o * num_outputs() + n] = lb_idx_[n];
                    }
                } else {
                    if (coefs(o, n) > 0) {
                        result(o, lb_idx_[n]) += coefs(o, n);
                        chosen_idx_[o * num_outputs() + n] = lb_idx_[n];
                    } else if (coefs(o, n) < 0. && ub_idx_[n] >= 0) {
                        result(o, ub_idx_[n]) += coefs(o, n);
                        chosen_idx_[o * num_outputs() + n] = ub_idx_[n];
                    }
                }
            }
        }
        coefs.swap(result);
    }

    void regress_ub(Matrix& coefs) override { regress<true>(coefs); }

    void regress_lb(Matrix& coefs) override { regress<false>(coefs); }

    void progress_generic(Matrix& input) override
    {
        Matrix result(num_outputs(), input.cols());
        for (size_t n = 0; n < num_outputs(); ++n) {
            for (int o = 0; o < input.cols(); ++o) {
                if (chosen_idx_[o * num_outputs() + n] >= 0) {
                    result(n, o) = input(chosen_idx_[o * num_outputs() + n], o);
                } else {
                    result(n, o) = ub_[n];
                }
            }
        }
        input.swap(result);
    }

    vector<real_t> ub_;
    vector<int> ub_idx_;
    vector<int> lb_idx_;
    vector<int> chosen_idx_;
};

class ExpandToConstRelaxer final : public NodeRelaxer {
public:
    explicit ExpandToConstRelaxer(
        std::shared_ptr<const ExpandToConstLayer> node)
        : node_(std::move(node))
    {
    }

    size_t num_inputs() const override { return node_->num_inputs(); }

    size_t num_outputs() const override { return node_->num_outputs(); }

    std::pair<LinVector, LinVector>
    propagate_bounds(const LinVector& inlbs, const LinVector& inubs) override
    {
        LinVector outlbs(num_outputs());
        outlbs.fill(node_->get_scalar());
        LinVector outubs(num_outputs());
        outubs.fill(node_->get_scalar());
        for (int i = num_inputs() - 1; i >= 0; --i) {
            const int j = node_->get_input_remap()[i];
            outlbs[j] = inlbs[i];
            outubs[j] = inubs[i];
        }
        return {std::move(outlbs), std::move(outubs)};
    }

    std::string name() const override { return "expand"; }

private:
    void regress_generic(Matrix& coefs) override
    {
        Matrix res(coefs.rows(), num_inputs());
        for (int r = coefs.rows() - 1; r >= 0; --r) {
            for (int i = num_inputs() - 1; i >= 0; --i) {
                const int j = node_->get_input_remap()[i];
                res(r, i) = coefs(r, j);
            }
        }
        coefs.swap(res);
    }

    void progress_generic(Matrix& coefs) override
    {
        Matrix res(num_outputs(), coefs.cols());
        res.fill(node_->get_scalar());
        for (int i = num_inputs() - 1; i >= 0; --i) {
            const int j = node_->get_input_remap()[i];
            res.row(j) = coefs.row(i);
        }
        res.swap(coefs);
    }

private:
    std::shared_ptr<const ExpandToConstLayer> node_;
};

} // namespace
#endif

// Visitor that dispatches on the concrete node type and asks the owning
// factory to build the matching relaxer; the result is left in `relaxation`.
class NodeRelaxerFactory::FactoryImpl final : public NodeVisitor {
public:
    explicit FactoryImpl(NodeRelaxerFactory* factory)
        : factory(factory)
    {
    }

    void visit(const LinearLayer* node) { build(node); }

    void visit(const ReluLayer* node) { build(node); }

    void visit(const MaxPoolLayer* node) { build(node); }

    void visit(const ExpandToConstLayer* node) { build(node); }

    NodeRelaxerFactory* factory;
    std::shared_ptr<NodeRelaxer> relaxation = nullptr;

private:
    // Forwards to the factory overload selected by the static layer type.
    template <typename Layer>
    void build(const Layer* node)
    {
        relaxation = factory->create_impl(node);
    }
};

std::shared_ptr<NodeRelaxer> NodeRelaxerFactory::create(const Node* node)
{
    FactoryImpl impl(this);
    node->accept(&impl);
    return std::move(impl.relaxation);
}

// Builds one relaxer per node along the successor chain starting at `node`,
// in chain order. A null start node yields an empty result.
vector<std::shared_ptr<NodeRelaxer>>
NodeRelaxerFactory::create_recursive(const Node* node)
{
    vector<std::shared_ptr<NodeRelaxer>> relaxers;
    for (; node != nullptr; node = node->successor().get()) {
        relaxers.push_back(create(node));
    }
    return relaxers;
}

// Creates the relaxer for a linear layer; returns nullptr when built without
// Eigen support.
std::shared_ptr<NodeRelaxer>
NodeRelaxerFactory::create_impl([[maybe_unused]] const LinearLayer* node)
{
#if POLICE_EIGEN
    auto layer = node->shared_from_this();
    return std::make_shared<LinearLayerRelaxer>(std::move(layer));
#else
    return nullptr;
#endif
}

// Creates the relaxer for a ReLU layer, giving it a fresh optimizer of its
// own; returns nullptr when built without Eigen support.
std::shared_ptr<NodeRelaxer>
NodeRelaxerFactory::create_impl([[maybe_unused]] const ReluLayer* node)
{
#if POLICE_EIGEN
    auto layer = node->shared_from_this();
    return std::make_shared<ReluLayerRelaxer>(
        std::move(layer),
        std::make_shared<Optimizer>());
#else
    return nullptr;
#endif
}

// Creates the relaxer for a max-pool layer; returns nullptr when built
// without Eigen support. The optimizable MaxPoolRelaxer is currently compiled
// out in favour of MaxPoolRelaxerSimple.
std::shared_ptr<NodeRelaxer>
NodeRelaxerFactory::create_impl([[maybe_unused]] const MaxPoolLayer* node)
{
#if POLICE_EIGEN
    auto layer = node->shared_from_this();
#if 0
    return std::make_shared<MaxPoolRelaxer>(
        std::move(layer),
        std::make_shared<Optimizer>());
#else
    return std::make_shared<MaxPoolRelaxerSimple>(std::move(layer));
#endif
#else
    return nullptr;
#endif
}

// Creates the relaxer for an expand-to-constant layer; returns nullptr when
// built without Eigen support.
std::shared_ptr<NodeRelaxer>
NodeRelaxerFactory::create_impl([[maybe_unused]] const ExpandToConstLayer* node)
{
#if POLICE_EIGEN
    auto layer = node->shared_from_this();
    return std::make_shared<ExpandToConstRelaxer>(std::move(layer));
#else
    return nullptr;
#endif
}

#if POLICE_EIGEN

// Takes ownership of the per-layer relaxers (in evaluation order) and of the
// relaxation options. At least one layer is required; this is only checked
// by an assertion.
ComputeGraphRelaxation::ComputeGraphRelaxation(
    std::vector<std::shared_ptr<NodeRelaxer>> layers,
    RelaxationOptions options)
    : layers_(std::move(layers))
    , options_(std::move(options))
{
    assert(!layers_.empty());
}

// Computes bounds for the input and for the output of every layer.
// Restarts the total timer, converts the input bounds into vectors, and
// delegates to the worker overload over the complete layer sequence.
vector<LayerBounds>
ComputeGraphRelaxation::compute_layer_bounds(const LayerBounds& input_bounds)
{
    total_timer_.reset();
    const auto input_vectors = extract_bound_vectors(input_bounds);
    // working copies that the worker overwrites layer by layer
    LinVector lbs(input_vectors.first);
    LinVector ubs(input_vectors.second);
    return compute_layer_bounds(
        input_bounds,
        input_vectors.first,
        input_vectors.second,
        layers_.end(),
        lbs,
        ubs);
}

// Computes the bounds of the final layer's output only.
LayerBounds
ComputeGraphRelaxation::compute_bounds(const LayerBounds& input_bounds)
{
    const vector<LayerBounds> per_layer = compute_layer_bounds(input_bounds);
    return per_layer.back();
}

// Worker overload: processes the layers from the first one up to (excluding)
// `end`. For every layer it hands the previously computed bounds to the
// layer, computes the bounds of the layer's output into lbs/ubs, and appends
// them to the result. The result starts with the input bounds, so entry k
// holds the bounds after k layers.
//
// Optimization of a layer's bounds is requested unless it is disabled in the
// options, or unless only the last layer is to be optimized and this is not
// the last layer.
vector<LayerBounds> ComputeGraphRelaxation::compute_layer_bounds(
    const LayerBounds& input_bounds,
    const LinVector& in_lbs,
    const LinVector& in_ubs,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator end,
    LinVector& lbs,
    LinVector& ubs)
{
    // no external termination condition is used here
    auto never_terminate = []() { return false; };
    vector<LayerBounds> bounds_per_layer;
    bounds_per_layer.reserve(layers_.size() + 1);
    bounds_per_layer.push_back(input_bounds);
    for (auto current = layers_.begin(); current != end; ++current) {
        (*current)->set_input_bounds(bounds_per_layer.back());
        const bool optimize =
            !options_.disable_optimization &&
            (!options_.optimize_only_last || current + 1 == layers_.end());
        compute_next_layer_bounds(
            in_lbs,
            in_ubs,
            layers_.begin(),
            current + 1,
            optimize,
            lbs,
            ubs,
            never_terminate);
        bounds_per_layer.emplace_back(lbs, ubs);
    }
    return bounds_per_layer;
}

// True once either the total time budget or the per-layer time budget from
// the options has been used up.
bool ComputeGraphRelaxation::time_is_up() const
{
    if (total_timer_.get_seconds() >= options_.max_total_time) {
        return true;
    }
    return layer_timer_.get_seconds() >= options_.max_layer_time;
}

std::pair<LinVector, LinVector>
ComputeGraphRelaxation::extract_bound_vectors(const LayerBounds& bounds)
{
    LinVector lbs(bounds.size());
    LinVector ubs(bounds.size());
    for (int i = bounds.size() - 1; i >= 0; --i) {
        lbs[i] = bounds.lb(i);
        ubs[i] = bounds.ub(i);
    }
    return std::pair<LinVector, LinVector>(std::move(lbs), std::move(ubs));
}

// Pushes the bounds in lbs/ubs through the layers [begin, end) using each
// layer's interval propagation; on return lbs/ubs hold the bounds after the
// last of these layers. In assert-enabled builds every intermediate bound is
// checked to be neither infinite nor NaN.
void ComputeGraphRelaxation::propagate_bounds(
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator begin,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator end,
    LinVector& lbs,
    LinVector& ubs)
{
    [[maybe_unused]] const auto all_finite = [](const LinVector& values) {
        return std::all_of(values.begin(), values.end(), [](real_t v) {
            return v != std::numeric_limits<real_t>::infinity() &&
                   v != -std::numeric_limits<real_t>::infinity() &&
                   !std::isnan(v);
        });
    };
    for (auto layer = begin; layer != end; ++layer) {
        auto [next_lbs, next_ubs] = (*layer)->propagate_bounds(lbs, ubs);
        assert(all_finite(next_lbs));
        assert(all_finite(next_ubs));
        lbs.swap(next_lbs);
        ubs.swap(next_ubs);
    }
}

// Performs one backward/forward pass over the layers [begin, end) and
// tightens `result` (the current upper bounds if Ub, else the lower bounds of
// the last layer's outputs).
//
// Backward: starting from the identity, the coefficient matrix is regressed
// through the layers from last to first, yielding one linear expression over
// the network inputs per output. Forward: for every output, each input is set
// to the bound selected by the sign of its coefficient (inputs with a zero
// coefficient are set to 0), and these values are progressed through the
// layers; the diagonal of the resulting matrix holds the candidate bound of
// each output. A candidate only replaces the stored bound if it is tighter.
//
// coefs and values are scratch matrices that are overwritten.
// [begin, end) must not be empty (the backward pass starts at end - 1).
// Returns the largest absolute change applied to any entry of `result`.
template <bool Ub>
real_t ComputeGraphRelaxation::optimization_step(
    Matrix& coefs,
    Matrix& values,
    const LinVector& in_lbs,
    const LinVector& in_ubs,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator begin,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator end,
    LinVector& result) const
{
#if DEBUG_PRINTS >= 2
    std::cout << "- optimization step..." << std::endl;
#endif
    coefs.resize(result.size(), result.size());
    coefs.setZero();
    coefs.diagonal().setOnes();
#if DEBUG_PRINTS >= 2
    std::cout << "- regression matrix: " << "\n";
#endif
    for (auto it = end - 1;; --it) {
#if DEBUG_PRINTS >= 3
        std::cout << "-\n" << coefs << std::endl;
#endif
        (*it)->regress<Ub>(coefs);
        if (it == begin) {
            break;
        }
    }
#if DEBUG_PRINTS >= 2
    std::cout << "--->\n" << coefs << std::endl;
#endif
    // one column per output: pick, per input, the bound matching the sign of
    // its coefficient
    const auto signs = coefs.transpose().cwiseSign();
    const auto ltz = signs.cwiseMin(0);
    const auto gtz = signs.cwiseMax(0);
    if constexpr (Ub) {
        values = gtz.cwiseProduct(in_ubs.replicate(1, result.size())) -
                 ltz.cwiseProduct(in_lbs.replicate(1, result.size()));
    } else {
        values = gtz.cwiseProduct(in_lbs.replicate(1, result.size())) -
                 ltz.cwiseProduct(in_ubs.replicate(1, result.size()));
    }
#if DEBUG_PRINTS >= 2
    std::cout << "- optimizing bounds: " << "\n" << values << std::endl;
#endif
    for (auto it = begin; it != end; ++it) {
        (*it)->progress<Ub>(values);
    }
#if DEBUG_PRINTS >= 2
    std::cout << "- result: " << "\n" << values << std::endl;
#endif
    assert(
        static_cast<size_t>(values.rows()) == (*(end - 1))->num_outputs() &&
        values.rows() == values.cols());
    real_t delta = 0;
    for (int o = result.size() - 1; o >= 0; --o) {
        assert(
            values(o, o) != std::numeric_limits<real_t>::infinity() &&
            values(o, o) != -std::numeric_limits<real_t>::infinity() &&
            !std::isnan(values(o, o)));
        const auto old = result[o];
        // never loosen a bound that is already known
        if constexpr (Ub) {
            result[o] = std::min(old, values(o, o));
        } else {
            result[o] = std::max(old, values(o, o));
        }
        delta = std::max(delta, std::abs(old - result[o]));
    }

    return delta;
}

// Explicit instantiations of optimization_step for the upper-bound (true)
// and lower-bound (false) variants, since the template is defined in this
// translation unit.
template real_t ComputeGraphRelaxation::optimization_step<true>(
    Matrix&,
    Matrix&,
    const LinVector&,
    const LinVector&,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator,
    LinVector&) const;

template real_t ComputeGraphRelaxation::optimization_step<false>(
    Matrix&,
    Matrix&,
    const LinVector&,
    const LinVector&,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator,
    LinVector&) const;

#else

// Stub used when the library is built without Eigen: construction only
// reports the missing dependency via POLICE_MISSING_DEPENDENCY.
// NOTE(review): this parameter list differs from the Eigen-enabled
// constructor defined earlier in this file (which takes the relaxer vector
// and the options) — confirm which one matches the header declaration.
ComputeGraphRelaxation::ComputeGraphRelaxation(
    std::shared_ptr<const Node>,
    std::shared_ptr<NodeRelaxerFactory>,
    RelaxationOptions)
{
    POLICE_MISSING_DEPENDENCY("eigen");
}

// Stub for builds without Eigen: only reports the missing dependency.
// There is no return statement, so POLICE_MISSING_DEPENDENCY presumably does
// not return normally — TODO confirm against its definition.
vector<LayerBounds>
ComputeGraphRelaxation::compute_layer_bounds(const LayerBounds&)
{
    POLICE_MISSING_DEPENDENCY("eigen");
}

// Stub for builds without Eigen: only reports the missing dependency.
LayerBounds ComputeGraphRelaxation::compute_bounds(const LayerBounds&)
{
    POLICE_MISSING_DEPENDENCY("eigen");
}

// Stub of the worker overload for builds without Eigen: all arguments are
// ignored and only the missing dependency is reported.
vector<LayerBounds> ComputeGraphRelaxation::compute_layer_bounds(
    const LayerBounds&,
    const LinVector&,
    const LinVector&,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator,
    LinVector&,
    LinVector&)
{
    POLICE_MISSING_DEPENDENCY("eigen");
}

// Stub for builds without Eigen: only reports the missing dependency.
bool ComputeGraphRelaxation::time_is_up() const
{
    POLICE_MISSING_DEPENDENCY("eigen");
}

// Stub for builds without Eigen: only reports the missing dependency.
std::pair<LinVector, LinVector>
ComputeGraphRelaxation::extract_bound_vectors(const LayerBounds&)
{
    POLICE_MISSING_DEPENDENCY("eigen");
}

// Stub for builds without Eigen: the bound vectors are left untouched and
// only the missing dependency is reported.
void ComputeGraphRelaxation::propagate_bounds(
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator,
    LinVector&,
    LinVector&)
{
    POLICE_MISSING_DEPENDENCY("eigen");
}

// Stub for builds without Eigen: all arguments are ignored and only the
// missing dependency is reported. Unlike the Eigen-enabled build, no explicit
// instantiations follow this definition.
template <bool Ub>
real_t ComputeGraphRelaxation::optimization_step(
    Matrix&,
    Matrix&,
    const LinVector&,
    const LinVector&,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator,
    vector<std::shared_ptr<NodeRelaxer>>::const_iterator,
    LinVector&) const
{
    POLICE_MISSING_DEPENDENCY("eigen");
}

#endif

} // namespace police::cg
