#include "police/storage/ffnn.hpp"
#include "police/base_types.hpp"

#include <algorithm>
#include <cassert>
#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>

namespace police {

namespace {
/// End marker for SplitIterator that is reached after a fixed number of
/// increments rather than at the end of the string. `count` is the number of
/// fields a range [SplitIterator, SplitIteratorSentinal) yields.
struct SplitIteratorSentinal {
    std::size_t count = 0;
};

/// Single-pass iterator over the ','-separated fields of a string.
///
/// The iterator does not own the string: `*str` must outlive the iterator and
/// must not be modified while the iterator is in use. Empty fields (two
/// adjacent commas, a leading comma, or the remainder after a trailing comma)
/// are reported as empty strings. Once the last field has been passed, `pos`
/// becomes `std::string::npos` and the iterator must not be dereferenced.
struct SplitIterator {
    using value_type = std::string;
    using difference_type = int;

    /// @param s   string to split; only the pointer is stored.
    /// @param pos index at which the first field starts, or
    ///            `std::string::npos` to build a past-the-end iterator.
    SplitIterator(const std::string* s, std::size_t pos = 0)
        : str(s)
        , pos(pos)
        // Members are initialised in declaration order (str, pos, next), so
        // find_next() already sees the final values of `str` and `pos`.
        , next(pos == std::string::npos ? pos : find_next())
    {
    }

    /// Two iterators are equal when they point at the same field start.
    [[nodiscard]]
    bool operator==(const SplitIterator& iter) const
    {
        return pos == iter.pos;
    }

    /// The sentinel is reached after exactly `sentinal.count` increments,
    /// independently of how many fields the string really contains.
    [[nodiscard]]
    bool operator==(const SplitIteratorSentinal& sentinal) const
    {
        return this->increments == sentinal.count;
    }

    /// Returns a copy of the current field (without the separating comma).
    [[nodiscard]]
    std::string operator*() const
    {
        assert(pos != std::string::npos);
        return str->substr(
            pos,
            next == std::string::npos ? next : (next - pos));
    }

    /// Advances to the field after the next comma, or to the past-the-end
    /// state (`pos == npos`) if the current field was the last one.
    SplitIterator& operator++()
    {
        ++increments;
        if (next == std::string::npos) {
            pos = next;
        } else {
            pos = next + 1;
            next = find_next();
        }
        return *this;
    }

    SplitIterator operator++(int)
    {
        auto temp(*this);
        ++*this;
        return temp;
    }

    /// Index of the comma terminating the field that starts at `pos`, or
    /// `npos` if the field runs to the end of the string.
    [[nodiscard]]
    std::size_t find_next() const
    {
        // Search from `pos` itself (not `pos + 1`): a comma located exactly at
        // `pos` terminates an empty field and must not be skipped, otherwise
        // the empty field would be glued to the following one.
        return str->find(',', pos);
    }

    const std::string* str;                  // string being split (not owned)
    std::size_t pos = std::string::npos;     // start of the current field
    std::size_t next = std::string::npos;    // comma ending the current field
    std::size_t increments = 0;              // number of ++ applied so far
};

} // namespace

/**
 * Reads a feed-forward network from a comma-separated text file.
 *
 * Layout consumed by this parser, line by line:
 *   - any number of leading lines starting with "//" (skipped);
 *   - a line whose first two fields are the number of layers and the number
 *     of inputs (further fields on that line are ignored);
 *   - a line with `num_layers + 1` layer sizes;
 *   - one line that is read and ignored;
 *   - four lines with `input_size` values each: minimum, maximum, mean and
 *     range of every input;
 *   - for every layer `l`: `layer_size[l + 1]` rows of `layer_size[l]`
 *     weights, followed by `layer_size[l + 1]` lines holding one bias each.
 *
 * @param path file to read.
 * @return the parsed network.
 * @throws std::invalid_argument if the file cannot be opened, ends before all
 *         expected lines were read, declares a negative size, or contains a
 *         field that std::stoi / std::stod cannot parse.
 * @throws std::out_of_range if std::stoi / std::stod report a value that does
 *         not fit, or (in builds without assertions) a line holds fewer
 *         fields than expected.
 */
template <typename Real>
FeedForwardNeuralNetwork<Real>
FeedForwardNeuralNetwork<Real>::parse_nnet(std::string_view path)
{
    FeedForwardNeuralNetwork net;

    // A std::string_view is not guaranteed to be null-terminated, so it must
    // be copied into a std::string before it is handed to the stream.
    const std::string file_name(path);
    std::ifstream f(file_name);
    if (!f) {
        // std::invalid_argument is what an unreadable file already surfaced
        // as (via std::stoi on an empty line); keep that type for callers.
        throw std::invalid_argument(
            "parse_nnet: cannot open '" + file_name + "'");
    }

    std::string line;

    // Reads the next line into `line`. A failed std::getline may leave `line`
    // untouched, so without this check a truncated file would silently
    // re-parse the previous line.
    const auto read_line = [&f, &line, &file_name]() {
        if (!std::getline(f, line)) {
            throw std::invalid_argument(
                "parse_nnet: unexpected end of file in '" + file_name + "'");
        }
    };

    // Parses a non-negative integer field.
    const auto to_size = [](const std::string& field) -> size_t {
        const int value = std::stoi(field);
        if (value < 0) {
            throw std::invalid_argument(
                "parse_nnet: negative size '" + field + "'");
        }
        return static_cast<size_t>(value);
    };

    // Reads the next line and appends its first `count` fields to `out`.
    const auto read_reals = [&read_line,
                             &line](vector<Real>& out, size_t count) {
        read_line();
        out.reserve(count);
        std::ranges::transform(
            SplitIterator(&line),
            SplitIteratorSentinal(count),
            std::back_inserter(out),
            [](const std::string& field) {
                return static_cast<Real>(std::stod(field));
            });
    };

    // skip header lines; `line` ends up holding the first non-header line
    do {
        read_line();
    } while (line.compare(0, 2, "//") == 0);

    // size parameters (output size and maximal layer size are not needed)
    SplitIterator sizes(&line);
    const size_t num_layers = to_size(*sizes);
    const size_t input_size = to_size(*++sizes);

    // layer sizes
    read_line();
    vector<size_t> layer_size;
    layer_size.reserve(num_layers + 1);
    std::ranges::transform(
        SplitIterator(&line),
        SplitIteratorSentinal(num_layers + 1),
        std::back_inserter(layer_size),
        to_size);
    assert(layer_size.size() == num_layers + 1);

    // unused flag
    read_line();

    // Input
    {
        vector<Real> min;
        read_reals(min, input_size);

        vector<Real> max;
        read_reals(max, input_size);

        vector<Real> mean;
        read_reals(mean, input_size);

        vector<Real> range;
        read_reals(range, input_size);

        net.input.reserve(input_size);
        for (size_t i = 0; i < input_size; ++i) {
            net.input.emplace_back(min[i], max[i], mean[i], range[i]);
        }
    }

    // layers
    {
        net.layers.reserve(num_layers);
        for (size_t layer = 0; layer < num_layers; ++layer) {
            const size_t fan_in = layer_size[layer];
            const size_t fan_out = layer_size[layer + 1];

            // one row of `fan_in` weights per neuron of this layer
            vector<vector<Real>> weights(fan_out);
            for (size_t i = 0; i < fan_out; ++i) {
                read_reals(weights[i], fan_in);
            }

            // one bias per line; std::stod stops at a trailing comma
            vector<Real> biases;
            biases.resize(fan_out);
            for (size_t i = 0; i < fan_out; ++i) {
                read_line();
                biases[i] = static_cast<Real>(std::stod(line));
            }

            net.layers.emplace_back(std::move(weights), std::move(biases));
        }
    }

    return net;
}

// Explicit instantiation: parse_nnet is defined in this source file, so the
// class template is instantiated here for real_t.
template class FeedForwardNeuralNetwork<real_t>;
} // namespace police
