#include "police/cg_constant_compressor.hpp"
#include "police/compute_graph.hpp"
#include "police/macros.hpp"
#include "police/storage/lin_vector.hpp"
#include "police/storage/matrix.hpp"
#include <memory>

namespace police::cg {

namespace {

std::shared_ptr<Node> compress_constants(
    std::shared_ptr<LinearLayer> root,
    const vector<std::pair<size_t, real_t>>& constants)
{
    vector<bool> removed(root->num_inputs(), false);
    LinVector input(root->num_inputs());
    input.setZero();
    for (const auto& [idx, val] : constants) {
        removed[idx] = true;
        input[idx] = val;
    }
    const auto& W = root->get_weights();
    Matrix weights(root->num_outputs(), root->num_inputs() - constants.size());
    for (int r = root->num_outputs() - 1; r >= 0; --r) {
        for (size_t i = 0, j = 0; i < removed.size(); ++i) {
            if (!removed[i]) {
                weights(r, j) = W(r, i);
                ++j;
            }
        }
    }
    auto biases = W * input;

    auto res = std::make_shared<LinearLayer>(
        std::move(weights),
        root->get_biases() + biases);
    res->set_successor(root->successor());
    return res;
}

} // namespace

/**
 * Fixes the given inputs of the graph rooted at @p root to constant values.
 *
 * With an empty @p constants list the graph is returned unchanged. Otherwise
 * @p root must be a LinearLayer; the work is delegated to compress_constants.
 *
 * @param root      Root node of the compute graph.
 * @param constants Pairs of (input index, constant value).
 * @return @p root itself if @p constants is empty, otherwise the node returned
 *         by compress_constants.
 *
 * Reports a POLICE_RUNTIME_ERROR if @p root is not a LinearLayer or if any
 * index in @p constants is not smaller than the layer's number of inputs.
 */
std::shared_ptr<Node> ConstantCompressor::operator()(
    std::shared_ptr<Node> root,
    const vector<std::pair<size_t, real_t>>& constants)
{
    if (constants.empty()) {
        return root;
    }
    auto linear = std::dynamic_pointer_cast<LinearLayer>(root);
    if (linear == nullptr) {
        POLICE_RUNTIME_ERROR(
            "cg constant compression currently supports only linear layers");
    }
    // The helper indexes per-input buffers with these values without any
    // bounds check, so reject out-of-range indices up front.
    for (const auto& entry : constants) {
        if (entry.first >= linear->num_inputs()) {
            POLICE_RUNTIME_ERROR(
                "cg constant compression: constant input index out of range");
        }
    }
    return compress_constants(linear, constants);
}

} // namespace police::cg
