#include "police/cg_pruner.hpp"
#include "police/cg_relaxation.hpp"
#include "police/compute_graph.hpp"
#include "police/layer_bounds.hpp"
#include "police/macros.hpp"
#include "police/storage/vector.hpp"
#include "police/utils/stopwatch.hpp"

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <memory>

namespace police::cg {

#if POLICE_EIGEN
namespace {

/// Computes, for every original index, its position after all discarded
/// entries have been removed.
///
/// @param discard  discard[i] is true when entry i is removed.
/// @return a vector of the same length as `discard`; kept entries receive
///         consecutive indices starting at 0 (in original order), discarded
///         entries receive the sentinel `size_t(-1)`.
[[nodiscard]]
vector<size_t> get_index_map(const vector<bool>& discard)
{
    const size_t total = discard.size();
    // Every slot starts out as "discarded"; kept slots are overwritten below.
    vector<size_t> mapping(total, static_cast<size_t>(-1));
    size_t next_free = 0;
    for (size_t old_index = 0; old_index < total; ++old_index) {
        if (discard[old_index]) {
            continue;
        }
        mapping[old_index] = next_free;
        ++next_free;
    }
    return mapping;
}

class Pruner final : public NodeVisitor {
public:
    explicit Pruner(
        real_t epsilon,
        const vector<LayerBounds>* layer_bounds,
        size_t input_size)
        : layer_bounds_(layer_bounds)
        , pruned_(input_size, false)
        , epsilon_(epsilon)
    {
    }

    void visit(const LinearLayer* node) override
    {
        const auto& bounds = layer_bounds_->at(cur_layer_);
        ++cur_layer_;

        vector<bool> pruned(node->num_outputs(), false);
        for (size_t neuron = 0; neuron < node->num_outputs(); ++neuron) {
            pruned[neuron] = std::abs(bounds.lb(neuron)) <= epsilon_ &&
                             std::abs(bounds.ub(neuron)) <= epsilon_;
        }
        pruned_.swap(pruned);

        recursion(node);
        assert(pruned_.size() == node->num_outputs());

        vector<bool> nonz(node->num_inputs(), false);
        for (size_t o = 0; o < node->num_outputs(); ++o) {
            if (pruned_[o]) continue;
            for (size_t i = 0; i < node->num_inputs(); ++i) {
                nonz[i] =
                    nonz[i] || std::abs(node->get_weights()(o, i)) > epsilon_;
            }
        }
        for (int i = nonz.size() - 1; i >= 0; --i) {
            pruned[i] = pruned[i] || !nonz[i];
        }

        size_t num_inputs = std::count(pruned.begin(), pruned.end(), false);
        size_t num_outputs = std::count(pruned_.begin(), pruned_.end(), false);

        std::shared_ptr<LinearLayer> new_node = nullptr;
        if (num_inputs == node->num_inputs() &&
            num_outputs == node->num_outputs()) {
            new_node = std::make_shared<LinearLayer>(
                node->get_weights(),
                node->get_biases());
        } else {
            Matrix w(num_outputs, num_inputs);
            LinVector b(num_outputs);
            for (size_t o = 0, oo = 0; o < node->num_outputs(); ++o) {
                if (pruned_[o]) continue;
                for (size_t i = 0, ii = 0; i < node->num_inputs(); ++i) {
                    if (pruned[i]) continue;
                    w(oo, ii) = node->get_weights()(o, i);
                    ++ii;
                }
                b[oo] = node->get_biases()[o];
                ++oo;
            }
            new_node =
                std::make_shared<LinearLayer>(std::move(w), std::move(b));
        }

        num_pruned_ += node->num_outputs() - num_outputs;

        auto lin = std::dynamic_pointer_cast<LinearLayer>(successor_);
        if (lin != nullptr) {
            Matrix w = lin->get_weights() * new_node->get_weights();
            LinVector b = lin->get_biases() +
                          (lin->get_weights() * new_node->get_biases());
            new_node =
                std::make_shared<LinearLayer>(std::move(w), std::move(b));
            successor_ = successor_->successor();
        }

        new_node->set_successor(std::move(successor_));
        successor_ = std::move(new_node);

        pruned_.swap(pruned);
    }

    void visit(const ReluLayer* node) override
    {
        const auto& bounds = layer_bounds_->at(cur_layer_);
        ++cur_layer_;
        for (size_t neuron = 0; neuron < node->num_outputs(); ++neuron) {
            pruned_[neuron] = pruned_[neuron] || bounds.ub(neuron) < epsilon_;
        }
        recursion(node);
        const size_t size = std::count(pruned_.begin(), pruned_.end(), false);
        std::shared_ptr<Node> relu = std::make_shared<ReluLayer>(size);
        relu->set_successor(std::move(successor_));
        successor_ = std::move(relu);
        num_pruned_ += node->num_outputs() - size;
    }

    void visit(const MaxPoolLayer* node) override
    {
        const auto& bounds = layer_bounds_->at(cur_layer_);
        ++cur_layer_;
        vector<bool> pruned(node->num_outputs(), false);
        for (size_t neuron = 0; neuron < node->num_outputs(); ++neuron) {
            const auto& old_group = node->get_input_refs(neuron);
            const size_t inputs = std::count_if(
                old_group.begin(),
                old_group.end(),
                [&](size_t idx) { return !pruned_[idx]; });
            pruned[neuron] =
                inputs == 0 || (std::abs(bounds.lb(neuron)) <= epsilon_ &&
                                std::abs(bounds.ub(neuron)) <= epsilon_);
        }
        pruned_.swap(pruned);
        recursion(node);
        vector<bool> used(node->num_inputs(), false);
        for (size_t o = 0; o < node->num_outputs(); ++o) {
            if (pruned_[o]) continue;
            for (size_t idx : node->get_input_refs(o)) {
                used[idx] = true;
            }
        }
        for (int i = used.size() - 1; i >= 0; --i) {
            pruned[i] = pruned[i] || !used[i];
        }

        const size_t num_inputs =
            std::count(pruned.begin(), pruned.end(), false);
        const size_t num_outputs =
            std::count(pruned_.begin(), pruned_.end(), false);

        std::shared_ptr<Node> new_node = nullptr;
        if (num_inputs == node->num_inputs() &&
            num_outputs == node->num_outputs()) {
            new_node =
                std::make_shared<MaxPoolLayer>(node->get_pools(), num_inputs);
        } else {
            const auto new_index = get_index_map(pruned);
            vector<vector<size_t>> pools;
            pools.reserve(num_outputs);
            for (size_t neuron = 0; neuron < node->num_outputs(); ++neuron) {
                if (pruned_[neuron]) continue;
                const auto& old_group = node->get_input_refs(neuron);
                vector<size_t> group;
                group.reserve(old_group.size());
                for (const auto& idx : old_group) {
                    if (!pruned[idx]) {
                        group.push_back(new_index[idx]);
                    }
                }
                assert(!group.empty());
                pools.push_back(std::move(group));
            }
            new_node =
                std::make_shared<MaxPoolLayer>(std::move(pools), num_inputs);
        }

        num_pruned_ += node->num_outputs() - num_outputs;

        new_node->set_successor(std::move(successor_));
        successor_ = std::move(new_node);

        pruned_.swap(pruned);
    }

    void visit(const ExpandToConstLayer* node) override
    {
        ++cur_layer_;
        {
            vector<bool> pruned(
                node->num_outputs(),
                std::abs(node->get_scalar()) <= epsilon_);
            for (int i = node->num_inputs() - 1; i >= 0; --i) {
                pruned[node->get_input_remap()[i]] = pruned_[i];
            }
            pruned.swap(pruned_);
        }
        recursion(node);
        vector<std::pair<size_t, size_t>> input_remap;
        input_remap.reserve(node->num_inputs());
        for (int i = node->num_inputs() - 1; i >= 0; --i) {
            input_remap.emplace_back(node->get_input_remap()[i], i);
        }
        std::sort(input_remap.begin(), input_remap.end());
        vector<size_t> inputs(node->num_inputs());
        size_t size = 0;
        vector<bool> pruned(node->num_inputs(), true);
        for (size_t o = 0, i = 0; o < node->num_outputs(); ++o) {
            if (pruned_[o]) continue;
            for (; i < input_remap.size() && input_remap[i].first < o; ++i) {
            }
            if (i < input_remap.size() && input_remap[i].first == o) {
                const auto j = input_remap[i].second;
                pruned[j] = false;
                inputs[j] = size;
            }
            ++size;
        }
        size_t j = 0;
        for (size_t i = 0; i < inputs.size(); ++i) {
            if (!pruned[i]) {
                inputs[j] = inputs[i];
                ++j;
            }
        }
        inputs.erase(inputs.begin() + j, inputs.end());
        std::shared_ptr<Node> new_node = std::make_shared<ExpandToConstLayer>(
            size,
            std::move(inputs),
            node->get_scalar());
        new_node->set_successor(std::move(successor_));
        pruned_.swap(pruned);
        successor_ = std::move(new_node);
        num_pruned_ += node->num_outputs() - size;
    }

    [[nodiscard]]
    std::shared_ptr<Node>& get_transformed_cg()
    {
        return successor_;
    }

    [[nodiscard]]
    vector<size_t> get_non_pruned_inputs() const
    {
        vector<size_t> result;
        result.reserve(pruned_.size());
        for (size_t i = 0; i < pruned_.size(); ++i) {
            if (!pruned_[i]) {
                result.push_back(i);
            }
        }
        return result;
    }

    [[nodiscard]]
    vector<size_t> get_non_pruned_outputs() const
    {
        vector<size_t> result;
        result.reserve(pruned_outputs_.size());
        for (size_t i = 0; i < pruned_outputs_.size(); ++i) {
            if (!pruned_outputs_[i]) {
                result.push_back(i);
            }
        }
        return result;
    }

    [[nodiscard]]
    size_t num_pruned() const
    {
        return num_pruned_;
    }

private:
    void recursion(const Node* node)
    {
        if (node->successor() != nullptr) {
            node->successor()->accept(this);
        } else {
            pruned_outputs_ = pruned_;
            assert(pruned_.size() == node->num_outputs());
        }
    }

    const vector<LayerBounds>* layer_bounds_ = nullptr;
    vector<bool> pruned_;
    vector<bool> pruned_outputs_;
    std::shared_ptr<Node> successor_ = nullptr;
    size_t cur_layer_ = 1;

    size_t num_pruned_ = 0;
    real_t epsilon_;
};

/// Computes per-layer bounds for the compute graph rooted at `root`, given
/// bounds on the graph inputs.
///
/// @param root          root node of the compute graph.
/// @param input_bounds  bounds forwarded to the relaxation.
/// @return whatever ComputeGraphRelaxation::compute_layer_bounds yields for
///         `input_bounds`.
[[maybe_unused]] [[nodiscard]]
vector<LayerBounds> compute_layer_bounds(
    std::shared_ptr<Node> root,
    const LayerBounds& input_bounds)
{
    // Function-local statics: built on the first call and reused afterwards.
    // The positional option values are interpreted by RelaxationOptions,
    // whose definition is not part of this file.
    static RelaxationOptions options{1000, 5e-3, 3, 600, 600, true, true};
    static std::shared_ptr<NodeRelaxerFactory> factory =
        std::make_shared<NodeRelaxerFactory>();

    ComputeGraphRelaxation relaxation(
        factory->create_recursive(root.get()),
        options);
    return relaxation.compute_layer_bounds(input_bounds);
}

} // namespace
#endif

/// Creates a pruner with the given tolerance.
///
/// @param epsilon  stored in `epsilon_`; prune() forwards it to the internal
///                 Pruner visitor, which compares bound and weight
///                 magnitudes against it.
ComputeGraphPruner::ComputeGraphPruner(real_t epsilon)
    : epsilon_(epsilon)
{
}

/// Prunes the compute graph rooted at `root`.
///
/// Computes layer bounds from `input_bounds`, runs the Pruner visitor over
/// the graph and, if outputs of the last node were dropped, appends an
/// ExpandToConstLayer (constant 0) that restores the original output count.
///
/// @return the rebuilt graph together with the indices of the graph inputs
///         that were kept. Without POLICE_EIGEN only
///         POLICE_MISSING_DEPENDENCY("eigen") is invoked.
PostProcessingResult ComputeGraphPruner::prune(
    [[maybe_unused]] std::shared_ptr<Node> root,
    [[maybe_unused]] const LayerBounds& input_bounds)
{
#if POLICE_EIGEN
    // Disabled pass-through variant (returns the graph untouched).
#if 0
    vector<size_t> in(root->num_inputs());
    std::iota(in.begin(), in.end(), 0);
    return {root, in};
#else
    ScopedStopWatch total_timer("post-processing compute graph");
    std::cout << "post-processing compute graph..." << std::endl;

    // Time the bound computation separately from the whole pass.
    ScopedStopWatch bounds_timer("computed node bounds");
    std::cout << "computing compute graph node bounds..." << std::endl;
    auto layer_bounds = compute_layer_bounds(root, input_bounds);
    bounds_timer.destroy();

    Pruner visitor(epsilon_, &layer_bounds, input_bounds.size());
    root->accept(&visitor);

    std::shared_ptr<Node> pruned_graph = visitor.get_transformed_cg();
    vector<size_t> kept_outputs = visitor.get_non_pruned_outputs();
    const size_t original_outputs = root->leaf()->num_outputs();
    const bool outputs_dropped = kept_outputs.size() != original_outputs;
    if (outputs_dropped) {
        // Re-insert the dropped outputs as the constant 0 so the rebuilt
        // graph has as many outputs as the original one.
        auto restore_outputs = std::make_shared<ExpandToConstLayer>(
            original_outputs,
            std::move(kept_outputs),
            0);
        pruned_graph->leaf()->set_successor(std::move(restore_outputs));
    }
    return {std::move(pruned_graph), visitor.get_non_pruned_inputs()};
#endif
#else
    POLICE_MISSING_DEPENDENCY("eigen");
#endif
}

} // namespace police::cg
