#include "police/ffnn_lp_encoder.hpp"

#include <cassert>
#include <utility>

#include "police/nnlp.hpp"
#include "police/storage/ffnn.hpp"
#include "police/storage/variable_space.hpp"

namespace police {

namespace {

// Appends `num_vars` real-valued variables to `lp` and returns the index that
// lp.add_variable() reported for the first of them. Callers treat the block
// as contiguous, i.e. they assume add_variable() hands out consecutive
// indices (NOTE(review): property of NNLP, not checked here).
size_t add_real_variables(NNLP& lp, size_t num_vars = 1)
{
    assert(num_vars > 0);
    const size_t first = lp.add_variable(RealType());
    // One variable is already added above; add the remaining num_vars - 1.
    for (size_t added = 1; added < num_vars; ++added) {
        lp.add_variable(RealType());
    }
    return first;
}

// Calls lp.set_input_index(get_id(i), i) for every i in [0, num), i.e. tags
// the LP variable returned by get_id(i) as input number i.
template <typename T>
void mark_input_variables(NNLP& lp, size_t num, T get_id)
{
    // size_t counter: an `unsigned` one can never reach a bound > UINT_MAX
    // and would loop forever.
    for (size_t i = 0; i < num; ++i) {
        lp.set_input_index(get_id(i), i);
    }
}

// Calls lp.set_output_index(get_id(i), i) for every i in [0, num), i.e. tags
// the LP variable returned by get_id(i) as output number i.
template <typename T>
void mark_output_variables(NNLP& lp, size_t num, T get_id)
{
    // size_t counter: an `unsigned` one can never reach a bound > UINT_MAX
    // and would loop forever.
    for (size_t i = 0; i < num; ++i) {
        lp.set_output_index(get_id(i), i);
    }
}

// Allocates the LP variables for every layer of `nn` and returns, per layer,
// the index of that layer's first variable.
//  - hidden layer l (all but the last): 2 * size(l) variables, the first
//    size(l) for pre-activation values, the next size(l) for ReLU outputs;
//  - output layer: size(last) variables (no activation).
vector<size_t>
create_variables(NNLP& lp, const FeedForwardNeuralNetwork<real_t>& nn)
{
    // nn.layers.back() below is undefined on an empty network.
    assert(nn.layers.size() > 0);
    const size_t num_layers = nn.layers.size();
    vector<size_t> layer_vars;
    layer_vars.reserve(num_layers);
    for (size_t layer = 0; layer + 1 < num_layers; ++layer) {
        assert(nn.layers[layer].size() > 0u);
        layer_vars.push_back(
            add_real_variables(lp, 2 * nn.layers[layer].size()));
    }
    layer_vars.push_back(add_real_variables(lp, nn.layers.back().size()));
    return layer_vars;
}

// Encodes a neuron without activation as a single linear equality:
//     neuron_var = sum_i weights[i] * get_var_id(i) + bias
// stored as  -neuron_var + sum_i weights[i] * x_i == -bias.
// Terms whose weight is zero (number_utils::is_zero) are omitted.
struct CreateNeuron {
    template <typename GetVarId>
    void operator()(
        NNLP& lp,
        const vector<real_t>& weights,
        real_t bias,
        size_t neuron_var,
        size_t /* activated_var: unused, this neuron has no activation */,
        GetVarId get_var_id) const
    {
        LinearConstraint base_constraint(LinearConstraint::Type::EQUAL);
        base_constraint.rhs = -bias;
        // Up to one term per weight, plus the neuron variable itself.
        base_constraint.reserve(weights.size() + 1);
        base_constraint.insert(neuron_var, -1.);
        for (size_t i = 0; i < weights.size(); ++i) {
            if (!number_utils::is_zero(weights[i])) {
                base_constraint.insert(get_var_id(i), weights[i]);
            }
        }
        lp.add_constraint(base_constraint);
    }
};

// Encodes a ReLU neuron with two constraints, added in this order:
//   1. non_activated_id = sum_i weights[i] * input_i + bias   (via CreateNeuron)
//   2. ReluConstraint(non_activated_id, activated_id), i.e.
//      activated_id = relu(non_activated_id).
// The activated variable is whatever id the caller passes in; it is not
// derived from non_activated_id.
struct CreateHiddenNeuron {
    static constexpr CreateNeuron create_neuron{};

    template <typename GetVarId>
    void operator()(
        NNLP& lp,
        const vector<real_t>& weights,
        real_t bias,
        size_t non_activated_id,
        size_t activated_id,
        GetVarId get_input_var_id) const
    {
        // Pre-activation (affine) part.
        create_neuron(
            lp, weights, bias, non_activated_id, activated_id,
            std::move(get_input_var_id));

        // Activation part: tie the output variable to relu(pre-activation).
        lp.add_constraint(ReluConstraint(non_activated_id, activated_id));
    }
};

template <typename Neuron>
struct CreateLayer {
    constexpr static Neuron create_neuron{};

    template <typename A, typename B, typename C>
    void operator()(
        NNLP& lp,
        const FeedForwardNeuralNetwork<real_t>::Layer& layer,
        A get_nonactived_var,
        B get_actived_var,
        C get_input_var_id) const
    {
        for (auto neuron = 0u; neuron < layer.size(); ++neuron) {
            create_neuron(
                lp,
                layer.weights[neuron],
                layer.biases[neuron],
                get_nonactived_var(neuron),
                get_actived_var(neuron),
                get_input_var_id);
        }
    }
};

using CreateOutLayer = CreateLayer<CreateNeuron>;
using CreateHiddenLayer = CreateLayer<CreateHiddenNeuron>;

} // namespace

// Encodes the feed-forward network `nn` into `lp`.
//
// input_vars[i] is the LP variable used as the network's i-th input; the
// input variables are tagged via lp.set_input_index, and the output-layer
// variables via lp.set_output_index.
//
// Variable layout (per create_variables, assuming NNLP::add_variable hands
// out consecutive indices): hidden layer l owns [base_l, base_l + size_l)
// for pre-activation values and [base_l + size_l, base_l + 2 * size_l) for
// ReLU outputs; the output layer owns [base_out, base_out + size_out).
//
// Returns base_out, the index of the first output-layer variable.
size_t encode_ffnn_in_lp(
    NNLP& lp,
    const FeedForwardNeuralNetwork<real_t>& nn,
    const vector<size_t>& input_vars)
{
    assert(nn.layers.size() > 0);
    const size_t num_layers = nn.layers.size();

#ifndef NDEBUG
    // get_input_var() indexes input_vars without bounds checking, so every
    // first-layer weight row must fit within the supplied input variables.
    {
        const auto& first_layer = nn.layers.front();
        for (size_t neuron = 0; neuron < first_layer.size(); ++neuron) {
            assert(first_layer.weights[neuron].size() <= input_vars.size());
        }
    }
#endif

    const auto layer_vars = create_variables(lp, nn);
    auto get_input_var = [&](size_t idx) { return input_vars[idx]; };
    // Accessor for the pre-activation variables of `layer`.
    auto get_nonactivated_var = [&](size_t layer) {
        const size_t base_idx = layer_vars[layer];
        auto get_var = [base_idx](size_t idx) { return base_idx + idx; };
        return get_var;
    };
    // Accessor for the ReLU-output variables of hidden `layer`. For the output
    // layer these ids are never allocated, but CreateNeuron ignores them.
    auto get_activated_var = [&](size_t layer) {
        const size_t base_idx = layer_vars[layer];
        const size_t layer_size = nn.layers[layer].size();
        auto get_var = [base_idx, layer_size](size_t idx) {
            return base_idx + layer_size + idx;
        };
        return get_var;
    };
    if (num_layers == 1) {
        // The single layer is the output layer and reads the inputs directly.
        CreateOutLayer()(
            lp,
            nn.layers.back(),
            get_nonactivated_var(0),
            get_activated_var(0),
            get_input_var);
    } else {
        // The first hidden layer is handled separately: its input accessor has
        // a different (lambda) type from the inter-layer accessors.
        CreateHiddenLayer()(
            lp,
            nn.layers.front(),
            get_nonactivated_var(0),
            get_activated_var(0),
            get_input_var);
        for (size_t l = 1; l + 1 < num_layers; ++l) {
            CreateHiddenLayer()(
                lp,
                nn.layers[l],
                get_nonactivated_var(l),
                get_activated_var(l),
                get_activated_var(l - 1));
        }
        CreateOutLayer()(
            lp,
            nn.layers.back(),
            get_nonactivated_var(num_layers - 1),
            get_activated_var(num_layers - 1),
            get_activated_var(num_layers - 2));
    }
    mark_input_variables(lp, input_vars.size(), get_input_var);
    mark_output_variables(
        lp,
        nn.layers.back().size(),
        get_nonactivated_var(num_layers - 1));
    return layer_vars.back();
}

} // namespace police
