#include "ffnn_generator.hpp"
#include "police/base_types.hpp"
#include "rng_utils.hpp"

#include <algorithm>
#include <numeric>

namespace {
// Lower/upper bounds handed to generate_number() when drawing each neuron's
// bias in FFNNGenerator::generate_layer.
constexpr double MIN_BIAS{-100.0};
constexpr double MAX_BIAS{100.0};

// Lower/upper bounds handed to generate_number() when drawing the weight of
// each sampled (non-zero) connection in FFNNGenerator::generate_layer.
constexpr double MIN_WEIGHT{-10.0};
constexpr double MAX_WEIGHT{10.0};
} // namespace

// Stores the network shape parameters and the connection density, seeds the
// random number generator, and immediately generates a first network so the
// object holds one as soon as construction finishes.
//   input_size  - number of network inputs
//   output_size - number of neurons in the final layer
//   hlayers     - number of hidden layers (0 => inputs feed the output layer)
//   hlayer_size - neurons per hidden layer
//   density     - fraction of a layer's inputs each neuron is connected to
//   seed        - seed forwarded to rng_
FFNNGenerator::FFNNGenerator(size_t input_size, size_t output_size,
                             size_t hlayers, size_t hlayer_size,
                             double density, int seed)
    : input_size_(input_size),
      output_size_(output_size),
      hlayers_(hlayers),
      hlayer_size_(hlayer_size),
      density_(density),
      rng_(seed)
{
    // The success flag is irrelevant here; we only want net_ populated.
    static_cast<void>(next());
}

// Builds a fresh random network with input_size_ inputs, hlayers_ hidden
// layers of hlayer_size_ neurons each, and a final layer of output_size_
// neurons. Layers are produced front to back via generate_layer(), so the
// order of RNG draws is fixed by the layer order.
police::FeedForwardNeuralNetwork<> FFNNGenerator::generate() const
{
    police::FeedForwardNeuralNetwork<> result;
    result.input.resize(input_size_);

    // Width of the layer feeding the next one: the inputs first, then each
    // hidden layer in turn.
    size_t fan_in = input_size_;
    for (size_t hidden = 0; hidden < hlayers_; ++hidden) {
        result.layers.push_back(generate_layer(fan_in, hlayer_size_));
        fan_in = hlayer_size_;
    }

    // Output layer; with no hidden layers it connects straight to the inputs.
    result.layers.push_back(generate_layer(fan_in, output_size_));
    return result;
}

// Creates one fully-sized layer with `out` neurons over `in` inputs.
//
// Every neuron receives a bias drawn via generate_number(rng_, MIN_BIAS,
// MAX_BIAS). A random subset of its inputs (chosen by shuffling the input
// indices) receives weights drawn via generate_number(rng_, MIN_WEIGHT,
// MAX_WEIGHT); the remaining weights stay 0. The subset size is
// min(in, floor(density_ * in)). A density_ <= 0 (or NaN) yields no
// connections, and a density_ >= 1 yields full connectivity.
police::FeedForwardNeuralNetwork<>::Layer
FFNNGenerator::generate_layer(size_t in, size_t out) const
{
    police::FeedForwardNeuralNetwork<>::Layer layer;
    layer.biases.resize(out, 0.);
    layer.weights.resize(out, police::vector<police::real_t>(in, 0.));

    // Connections per neuron. The value is loop-invariant, so compute it once.
    // The explicit range checks keep the double -> size_t conversion in range:
    // converting a negative, NaN or overflowing double is undefined behaviour.
    size_t connections = 0;
    if (density_ >= 1.) {
        connections = in;
    } else if (density_ > 0.) {
        connections = std::min(
            in, static_cast<size_t>(density_ * static_cast<double>(in)));
    }

    // Input indices 0..in-1. Count in size_t to avoid int overflow for large
    // layers.
    police::vector<size_t> inputs(in, 0);
    std::iota(inputs.begin(), inputs.end(), size_t{0});

    // The visiting order (neurons from last to first, then sampled inputs from
    // inputs[connections-1] down to inputs[0]) fixes the sequence of RNG draws.
    // Keep it, so a given seed keeps reproducing the same network (assuming
    // rng_ is deterministic per seed).
    for (size_t neuron = out; neuron-- > 0;) {
        rng_.shuffle(inputs.begin(), inputs.end());
        layer.biases[neuron] = generate_number(rng_, MIN_BIAS, MAX_BIAS);
        for (size_t i = connections; i > 0; --i) {
            layer.weights[neuron][inputs[i - 1]] =
                generate_number(rng_, MIN_WEIGHT, MAX_WEIGHT);
        }
    }
    return layer;
}

// Replaces the currently held network (net_) with a freshly generated one.
// Always returns true: generation cannot run out.
// NOTE(review): presumably the bool follows a generator interface in which
// false would signal exhaustion. Confirm against the header.
bool FFNNGenerator::next()
{
    net_ = generate();
    return true;
}
