#include "police/compute_graph_factory.hpp"
#include "police/cg_pruner.hpp"
#include "police/compute_graph.hpp"
#include "police/layer_bounds.hpp"
#include "police/macros.hpp"
#include "police/storage/ffnn.hpp"

#include <fstream>
#include <memory>
#include <string>
#include <string_view>

namespace police::cg {

namespace {
// Magnitude threshold shared by the graph builders in this file: weights and
// biases whose absolute value is below it are stored as 0, and the same value
// is passed to the ComputeGraphPruner constructed in post_process.
constexpr real_t TRUNCATE_EPSILON = 1e-3;

// Builds a Matrix from a list of coefficient rows.
//
// The result has coefs.size() rows and coefs.front().size() columns. Entries
// whose absolute value is below TRUNCATE_EPSILON are stored as 0; all other
// entries are copied unchanged.
//
// Preconditions (checked with assert): `coefs` is non-empty and every row has
// the same number of entries as the first row.
Matrix get_matrix(const vector<vector<real_t>>& coefs)
{
    assert(!coefs.empty());
    const size_t num_rows = coefs.size();
    const size_t num_cols = coefs.front().size();
    Matrix result(num_rows, num_cols);
    for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
        const auto& row = coefs[row_idx];
        // A ragged row would either index past the columns of `result` or
        // leave some of its entries unassigned, so reject it explicitly.
        assert(row.size() == num_cols);
        for (size_t col_idx = 0; col_idx < num_cols; ++col_idx) {
            const real_t value = row[col_idx];
            result(row_idx, col_idx) =
                std::abs(value) < TRUNCATE_EPSILON ? 0 : value;
        }
    }
    return result;
}

// Copies `vec` into a LinVector of the same length.
//
// Entries whose absolute value is below TRUNCATE_EPSILON are stored as 0; all
// other entries are copied unchanged. An empty `vec` yields a LinVector
// constructed with size 0.
LinVector get_vector(const vector<real_t>& vec)
{
    const size_t length = vec.size();
    LinVector result(length);
    // Index with size_t so that the element count is never narrowed to int.
    for (size_t i = 0; i < length; ++i) {
        const real_t value = vec[i];
        result(i) = std::abs(value) < TRUNCATE_EPSILON ? 0 : value;
    }
    return result;
}

// Collects the value range of every input variable into a LayerBounds.
//
// Position k of the result receives the lower and upper bound (converted to
// real_t) of the type that `variables` reports for input_vars[k].
LayerBounds get_input_bounds(
    const VariableSpace& variables,
    const vector<size_t>& input_vars)
{
    LayerBounds result(input_vars.size());
    size_t position = 0;
    for (const size_t var : input_vars) {
        const auto type = variables.get_type(var);
        const auto lower = static_cast<real_t>(type.get_lower_bound());
        const auto upper = static_cast<real_t>(type.get_upper_bound());
        result.set_bounds(position, lower, upper);
        ++position;
    }
    return result;
}

} // namespace

// Prunes the compute graph rooted at `root` and reports the surviving inputs.
//
// The pruner is constructed with TRUNCATE_EPSILON and run against the bounds
// of `input_vars` taken from `variables`. The statistics of the resulting
// graph are printed to stdout. Each entry of the result's `inputs` is used as
// a position into `input_vars` and replaced by the entry found there before
// the result is returned.
PostProcessingResult post_process(
    std::shared_ptr<Node> root,
    const vector<size_t>& input_vars,
    const VariableSpace& variables)
{
    ComputeGraphPruner pruner(TRUNCATE_EPSILON);
    LayerBounds bounds = get_input_bounds(variables, input_vars);
    auto pruned = pruner.prune(root, bounds);
    std::cout << "after post-processing: " << pruned.cg->statistics()
              << std::endl;
    // Translate positions within input_vars into the entries stored there.
    for (auto& input : pruned.inputs) {
        input = input_vars[input];
    }
    return pruned;
}

// Converts a feed-forward network into a chain of compute graph nodes.
//
// The first network layer becomes a LinearLayer; every further layer is
// appended as a ReluLayer followed by a LinearLayer. Weights and biases go
// through get_matrix / get_vector. The statistics of the finished chain are
// printed to stdout and its first node is returned.
//
// Precondition (checked with assert): `net` has at least one layer.
std::shared_ptr<Node> from_ffnn(const FeedForwardNeuralNetwork<>& net)
{
    assert(!net.layers.empty());

    // Shared recipe for turning one network layer into a LinearLayer node.
    const auto make_linear = [](const auto& layer) {
        return std::make_shared<cg::LinearLayer>(
            get_matrix(layer.weights),
            get_vector(layer.biases));
    };

    std::shared_ptr<Node> head = make_linear(net.layers.front());
    std::shared_ptr<Node> tail = head;
    const size_t num_layers = net.layers.size();
    for (size_t idx = 1; idx < num_layers; ++idx) {
        // Activation between the previous and the current linear layer.
        tail->set_successor(std::make_shared<ReluLayer>(tail->num_outputs()));
        tail = tail->successor();

        tail->set_successor(make_linear(net.layers[idx]));
        tail = tail->successor();
    }
    std::cout << "constructed " << head->statistics() << std::endl;
    return head;
}

namespace {
std::shared_ptr<LinearLayer>
parse_linear_layer(std::ifstream& f, size_t, size_t inputs)
{
    std::string line;
    std::getline(f, line);
    size_t num_outputs = std::stoi(line);
    Matrix weights(num_outputs, inputs);
    LinVector biases(num_outputs);
    double d;
    for (size_t o = 0; o < num_outputs; ++o) {
        for (size_t i = 0; i < inputs; ++i) {
            f >> d;
            weights(o, i) = std::abs(d) < TRUNCATE_EPSILON ? 0 : d;
        }
        std::getline(f, line);
    }
    for (size_t o = 0; o < num_outputs; ++o) {
        f >> d;
        biases(o) = std::abs(d) < TRUNCATE_EPSILON ? 0 : d;
    }
    std::getline(f, line);
    return std::make_shared<LinearLayer>(std::move(weights), std::move(biases));
}

// Builds a ReluLayer constructed with `inputs`. Nothing is read from the
// stream and the layer index is unused; both parameters exist so that the
// signature matches the other layer parsers.
std::shared_ptr<ReluLayer>
parse_relu_layer(std::ifstream&, size_t, size_t inputs)
{
    auto relu = std::make_shared<ReluLayer>(inputs);
    return relu;
}

std::shared_ptr<MaxPoolLayer>
parse_max_pool_layer(std::ifstream& f, size_t, size_t inputs)
{
    std::string line;
    std::getline(f, line);
    size_t num_outputs = std::stoi(line);
    vector<vector<size_t>> pools(num_outputs);
    for (size_t o = 0; o < num_outputs; ++o) {
        std::getline(f, line);
        size_t size = std::stoi(line);
        if (size == 0u) {
            continue;
        }
        pools[o].resize(size);
        for (size_t i = 0; i < size; ++i) {
            int x;
            f >> x;
            pools[o][i] = x;
            if (pools[o][i] > inputs) {
                POLICE_EXIT_INVALID_INPUT(
                    "index " << x << " exceeds input size (" << inputs << ")");
            }
        }
        std::getline(f, line);
    }
    return std::make_shared<MaxPoolLayer>(std::move(pools), inputs);
}
} // namespace

std::shared_ptr<Node> parse_nnet_file(std::string_view file_name)
{
    std::ifstream f(file_name.data());
    std::string line;
    std::getline(f, line);
    size_t num_inputs = std::stoi(line);
    std::getline(f, line);
    size_t num_layers = std::stoi(line);
    std::shared_ptr<Node> root = nullptr;
    std::shared_ptr<Node> parent = nullptr;
    size_t last_outputs = num_inputs;
    for (size_t l = 0; l < num_layers; ++l) {
        std::getline(f, line);
        std::shared_ptr<Node> node = nullptr;
        if (line == "linear") {
            node = parse_linear_layer(f, l, last_outputs);
        } else if (line == "relu") {
            node = parse_relu_layer(f, l, last_outputs);
        } else if (line == "max-pool") {
            node = parse_max_pool_layer(f, l, last_outputs);
        } else {
            POLICE_EXIT_INVALID_INPUT(
                "unknown computation graph node type " << line);
        }
        if (l == 0u) {
            root = node;
        } else {
            parent->set_successor(node);
        }
        parent = std::move(node);
        last_outputs = parent->num_outputs();
    }
    std::cout << "parsed " << root->statistics() << std::endl;
    return root;
}

} // namespace police::cg
