#include "police/compute_graph.hpp"
#include "police/macros.hpp"
#include "police/storage/lin_vector.hpp"

#include <algorithm>
#include <cassert>
#include <limits>
#include <numeric>

#define DUMP_LAYER_OUTPUT 0

#if DUMP_LAYER_OUTPUT
#include "police/utils/io.hpp"
#endif

namespace police::cg {

/// Evaluates the chain that starts at this node on `input`.
///
/// The raw values are copied into a LinVector, passed through this node's
/// compute(), and the result is then fed through every successor in order.
/// The value returned is the output of the last node in the chain.
///
/// When DUMP_LAYER_OUTPUT is non-zero, the input and each intermediate output
/// are printed to std::cout.
LinVector Node::operator()(const vector<real_t>& input) const
{
    LinVector vec(input.size());
    // Index with size_t: narrowing the element count to int is
    // implementation-defined and breaks for very large inputs.
    for (size_t i = 0; i < input.size(); ++i) {
        vec(i) = input[i];
    }
    auto out = compute(vec);
#if DUMP_LAYER_OUTPUT
    std::cout << "  <0> " << print_sequence(vec) << "\n";
    std::cout << "  <1> " << print_sequence(out) << "\n";
    int layer = 2;
#endif
    for (const Node* node = successor_.get(); node != nullptr;
         node = node->successor_.get()) {
        out = node->compute(std::move(out));
#if DUMP_LAYER_OUTPUT
        std::cout << "  <" << layer << "> " << print_sequence(out) << "\n";
        ++layer;
#endif
    }
    return out;
}

std::shared_ptr<Node> Node::copy() const
{
    auto result = copy_impl();
    auto* node = result.get();
    const auto* orig_node = this;
    while (orig_node->successor_ != nullptr) {
        node->set_successor(orig_node->successor_->copy_impl());
        node = node->successor_.get();
        orig_node = orig_node->successor_.get();
    }
    return result;
}

std::shared_ptr<Node> Node::deep_copy() const
{
    auto result = deep_copy_impl();
    auto* node = result.get();
    const auto* orig_node = this;
    while (orig_node->successor_ != nullptr) {
        node->set_successor(orig_node->successor_->deep_copy_impl());
        node = node->successor_.get();
        orig_node = orig_node->successor_.get();
    }
    return result;
}

/// Sum of get_num_linears() over this node and all nodes after it.
size_t Node::num_linears() const
{
    const auto own = get_num_linears();
    if (successor_ == nullptr) {
        return own;
    }
    return own + successor_->num_linears();
}

/// Sum of get_num_relus() over this node and all nodes after it.
size_t Node::num_relus() const
{
    const auto own = get_num_relus();
    if (successor_ == nullptr) {
        return own;
    }
    return own + successor_->num_relus();
}

/// Sum of get_num_pools() over this node and all nodes after it.
size_t Node::num_pools() const
{
    const auto own = get_num_pools();
    if (successor_ == nullptr) {
        return own;
    }
    return own + successor_->num_pools();
}

/// Sum of get_num_neurons() over this node and all nodes after it.
size_t Node::num_neurons() const
{
    const auto own = get_num_neurons();
    if (successor_ == nullptr) {
        return own;
    }
    return own + successor_->num_neurons();
}

/// True when this node has no successor, i.e. it ends the chain.
bool Node::is_leaf() const
{
    return !successor_;
}

/// Replaces the node that follows this one in the chain.
///
/// Passing nullptr leaves this node without a successor. In builds where
/// assert is active, a non-null successor must report exactly as many inputs
/// as this node reports outputs.
void Node::set_successor(std::shared_ptr<Node> successor)
{
    assert(successor == nullptr || successor->num_inputs() == num_outputs());
    // The argument is received by value, so it can be moved into the member.
    successor_ = std::move(successor);
}

/// Number of nodes in the chain starting at (and including) this node.
size_t Node::num_layers() const
{
    if (successor_ == nullptr) {
        return 1u;
    }
    return 1u + successor_->num_layers();
}

const std::shared_ptr<Node>& Node::successor() const
{
    return successor_;
}

/// Returns the last node of the chain starting at this node.
/// The result is `this` when this node is already a leaf.
Node* Node::leaf()
{
    for (Node* current = this;; current = current->successor_.get()) {
        if (current->is_leaf()) {
            return current;
        }
    }
}

/// Const overload: returns the last node of the chain starting at this node.
/// The result is `this` when this node is already a leaf.
const Node* Node::leaf() const
{
    for (const Node* current = this;; current = current->successor_.get()) {
        if (current->is_leaf()) {
            return current;
        }
    }
}

/// Gathers size figures for the chain starting at this node: the input count
/// of this node, the output count of the leaf, and the chain-wide counts of
/// layers, neurons, linears, ReLUs and pools.
ComputeGraphSize Node::statistics() const
{
    const auto inputs = num_inputs();
    const auto outputs = leaf()->num_outputs();
    const auto layers = num_layers();
    const auto neurons = num_neurons();
    const auto linears = num_linears();
    const auto relus = num_relus();
    const auto pools = num_pools();
    return {inputs, outputs, layers, neurons, linears, relus, pools};
}

/// Creates a linear layer from a weight matrix and a bias vector.
///
/// Both arguments are taken by value and moved into separate objects managed
/// through shared pointers.
///
/// NOTE(review): nothing here checks that `biases.size()` equals
/// `weights.rows()`; num_outputs() reports the latter while
/// get_num_neurons() reports the former — confirm callers guarantee they agree.
LinearLayer::LinearLayer(Matrix weights, LinVector biases)
    : weights_(std::make_shared<Matrix>(std::move(weights)))
    , biases_(std::make_shared<LinVector>(std::move(biases)))
{
}

/// Number of inputs the layer accepts: the column count of the weight matrix.
size_t LinearLayer::num_inputs() const
{
    return weights_->cols();
}

size_t LinearLayer::num_outputs() const
{
    return weights_->rows();
}

size_t LinearLayer::get_num_neurons() const
{
    return biases_->size();
}

/// Applies the affine map: weights * input + biases.
LinVector LinearLayer::compute(const LinVector& input) const
{
    const auto& weights = *weights_;
    const auto& biases = *biases_;
    return (weights * input) + biases;
}

/// Creates a new LinearLayer constructed from copies of this layer's weight
/// matrix and bias vector.
std::shared_ptr<Node> LinearLayer::deep_copy_impl() const
{
    std::shared_ptr<Node> clone =
        std::make_shared<LinearLayer>(*weights_, *biases_);
    return clone;
}

/// Creates a ReLU layer over `dimension` values.
///
/// `zero_` is allocated with `dimension` entries and set to all zeros; the
/// layer's size queries and compute() read from it.
ReluLayer::ReluLayer(size_t dimension)
    : zero_(std::make_shared<LinVector>(dimension))
{
    // Constructing with a size does not by itself zero the entries here, so
    // they are cleared explicitly.
    zero_->setZero();
}

size_t ReluLayer::num_inputs() const
{
    return zero_->size();
}

size_t ReluLayer::num_outputs() const
{
    return zero_->size();
}

size_t ReluLayer::get_num_neurons() const
{
    return zero_->size();
}

size_t ReluLayer::get_num_relus() const
{
    return zero_->size();
}

/// Returns the coefficient-wise maximum of `input` and the stored zero vector.
/// In builds where assert is active, `input` must have the layer's dimension.
LinVector ReluLayer::compute(const LinVector& input) const
{
    assert(input.size() == zero_->size());
    const auto& lower_bound = *zero_;
    return input.cwiseMax(lower_bound);
}

/// Creates a new ReluLayer with the same dimension as this one.
std::shared_ptr<Node> ReluLayer::deep_copy_impl() const
{
    const auto dimension = num_outputs();
    std::shared_ptr<Node> clone = std::make_shared<ReluLayer>(dimension);
    return clone;
}

/// Creates a max-pool layer.
///
/// @param pools  one entry per output; each inner vector lists the input
///               indices that compute() takes the maximum over. The value is
///               moved into an object managed through a shared pointer.
/// @param inputs number of inputs the layer reports via num_inputs().
MaxPoolLayer::MaxPoolLayer(vector<vector<size_t>> pools, size_t inputs)
    : pools_(std::make_shared<vector<vector<size_t>>>(std::move(pools)))
    , inputs_(inputs)
{
}

size_t MaxPoolLayer::num_inputs() const
{
    return inputs_;
}

size_t MaxPoolLayer::num_outputs() const
{
    return pools_->size();
}

/// Returns a copy of the input indices belonging to pool `pool_idx`.
/// The index is not range-checked here.
const vector<size_t> MaxPoolLayer::get_input_refs(size_t pool_idx) const
{
    const auto& pools = *pools_;
    return pools[pool_idx];
}

size_t MaxPoolLayer::get_num_neurons() const
{
    return num_outputs();
}

size_t MaxPoolLayer::get_num_pools() const
{
    return num_outputs();
}

/// Computes one output per pool: the maximum of the input entries the pool
/// refers to, or 0 for a pool with no entries.
///
/// In builds where assert is active, `input` must have exactly num_inputs()
/// entries and every index stored in every pool must be below that count.
LinVector MaxPoolLayer::compute(const LinVector& input) const
{
    // Compare as size_t to avoid a signed/unsigned comparison.
    assert(static_cast<size_t>(input.size()) == inputs_);
    assert(std::all_of(
        pools_->begin(),
        pools_->end(),
        [this](const vector<size_t>& pool) {
            return std::all_of(pool.begin(), pool.end(), [this](size_t idx) {
                return idx < inputs_;
            });
        }));
    const size_t num_pools = pools_->size();
    LinVector result(num_pools);
    // Index with size_t: narrowing the pool count to int is
    // implementation-defined and breaks for very large layers.
    for (size_t i = 0; i < num_pools; ++i) {
        const vector<size_t>& indices = (*pools_)[i];
        if (indices.empty()) {
            result(i) = 0;
            continue;
        }
        // Start from -infinity so any finite input value replaces it.
        real_t best = -std::numeric_limits<real_t>::infinity();
        for (const size_t idx : indices) {
            const real_t value = input(idx);
            best = std::max(best, value);
        }
        result(i) = best;
    }
    return result;
}

/// Creates a new MaxPoolLayer constructed from a copy of this layer's pools
/// and the same input count.
std::shared_ptr<Node> MaxPoolLayer::deep_copy_impl() const
{
    std::shared_ptr<Node> clone =
        std::make_shared<MaxPoolLayer>(*pools_, inputs_);
    return clone;
}

/// Creates a layer that widens its input to `num_outputs` values.
///
/// @param num_outputs  length of the vector produced by compute().
/// @param remap_inputs entry i is the output position that compute() writes
///                     input i to; its length is the layer's input count. The
///                     value is moved into an object managed through a shared
///                     pointer.
/// @param scalar       value compute() fills the output with before the
///                     inputs are written.
ExpandToConstLayer::ExpandToConstLayer(
    size_t num_outputs,
    vector<size_t> remap_inputs,
    real_t scalar)
    : remap_inputs_(std::make_shared<vector<size_t>>(std::move(remap_inputs)))
    , scalar_(scalar)
    , num_outputs_(num_outputs)
{
}

/// Number of inputs: one per entry of the remap table.
size_t ExpandToConstLayer::num_inputs() const
{
    const auto& remap = *remap_inputs_;
    return remap.size();
}

size_t ExpandToConstLayer::num_outputs() const
{
    return num_outputs_;
}

/// Fixed identifier string for this layer type.
std::string ExpandToConstLayer::name() const
{
    return std::string("expand");
}

/// Produces a vector of num_outputs() values: every entry starts as the
/// configured scalar, then input i is written to position remap_inputs_[i].
///
/// In builds where assert is active, `input` must have exactly num_inputs()
/// entries and every remap target must be below the output count.
LinVector ExpandToConstLayer::compute(const LinVector& input) const
{
    const size_t input_count = num_inputs();
    // Compare as size_t to avoid a signed/unsigned comparison.
    assert(static_cast<size_t>(input.size()) == input_count);
    LinVector out(num_outputs_);
    out.fill(scalar_);
    // Iterate from the last input down to the first with a size_t index
    // (no narrowing to int). The descending order is kept on purpose: if two
    // inputs share a target position, the lower input index is written last.
    for (size_t i = input_count; i-- > 0;) {
        const size_t target = (*remap_inputs_)[i];
        assert(target < num_outputs_);
        out[target] = input[i];
    }
    return out;
}

/// Creates a new ExpandToConstLayer constructed from the same output count,
/// a copy of the remap table, and the same fill scalar.
std::shared_ptr<Node> ExpandToConstLayer::deep_copy_impl() const
{
    std::shared_ptr<Node> clone = std::make_shared<ExpandToConstLayer>(
        num_outputs_,
        *remap_inputs_,
        scalar_);
    return clone;
}

} // namespace police::cg

namespace police {
/// Writes the statistics as a single bracketed, comma-separated record of
/// the form "CG[inputs=..., outputs=..., ...]" and returns the stream.
std::ostream& operator<<(std::ostream& out, const cg::ComputeGraphSize& stats)
{
    out << "CG[";
    out << "inputs=" << stats.inputs;
    out << ", outputs=" << stats.outputs;
    out << ", layers=" << stats.layers;
    out << ", neurons=" << stats.neurons;
    out << ", linear=" << stats.linear;
    out << ", relu=" << stats.relus;
    out << ", pool=" << stats.pools;
    out << "]";
    return out;
}
} // namespace police

#undef DUMP_LAYER_OUTPUT
