#pragma once

#include "police/base_types.hpp"
#include "police/storage/lin_vector.hpp"
#include "police/storage/vector.hpp"

#include <iostream>
#include <cassert>

namespace police {

class LayerBounds {
public:
    explicit LayerBounds(size_t size);
    explicit LayerBounds(const LinVector& values);

    template <typename VecLb, typename VecUb>
    LayerBounds(const VecLb& lbs, const VecUb& ubs)
        : bounds_(2 * lbs.size(), 0)
    {
        assert(lbs.size() == ubs.size());
        for (int i = lbs.size() - 1; i >= 0; --i) {
            bounds_[2 * i] = lbs[i];
            bounds_[2 * i + 1] = ubs[i];
        }
    }

    [[nodiscard]]
    real_t get(size_t neuron, bool ub) const
    {
        return bounds_[2 * neuron + ub];
    }

    [[nodiscard]]
    const real_t* data() const
    {
        return bounds_.data();
    }

    [[nodiscard]]
    real_t& lb(size_t neuron)
    {
        assert(2 * neuron < bounds_.size());
        return bounds_[2 * neuron];
    }

    [[nodiscard]]
    real_t& ub(size_t neuron)
    {
        assert(2 * neuron + 1 < bounds_.size());
        return bounds_[2 * neuron + 1];
    }

    [[nodiscard]]
    real_t lb(size_t neuron) const
    {
        assert(2 * neuron < bounds_.size());
        return bounds_[2 * neuron];
    }

    [[nodiscard]]
    real_t ub(size_t neuron) const
    {
        assert(2 * neuron + 1 < bounds_.size());
        return bounds_[2 * neuron + 1];
    }

    [[nodiscard]]
    size_t size() const
    {
        return bounds_.size() / 2;
    }

    [[nodiscard]]
    const vector<real_t>& get_bounds() const
    {
        return bounds_;
    }

    void set_bounds(size_t neuron, real_t lb, real_t ub)
    {
        bounds_[2 * neuron] = lb;
        bounds_[2 * neuron + 1] = ub;
    }

    void get_lower_bounds(LinVector& bounds) const
    {
        assert(2 * bounds.size() == static_cast<int>(bounds_.size()));
        for (int i = bounds.size() - 1; i >= 0; --i) {
            bounds[i] = bounds_[2 * i];
        }
    }

private:
    vector<real_t> bounds_;
};

/// Streams @p bounds to @p out. The output format is defined by the
/// out-of-line implementation; presumably returns @p out to allow chaining
/// (NOTE(review): confirm against the definition).
std::ostream& operator<<(std::ostream& out, const LayerBounds& bounds);

} // namespace police
