#pragma once

#include "police/base_types.hpp"
#include "police/storage/vector.hpp"

#include <algorithm>
#include <cassert>
#include <functional>
#include <numeric>
#include <string_view>

namespace police {

/// Rectified linear unit: returns `r` when `r >= 0`, otherwise zero.
/// A NaN argument compares false against zero and therefore maps to zero.
struct Relu {
    // The argument is taken by value, not as a forwarding reference. With
    // `Real&&`, an lvalue argument deduces `Real` as a reference type, and the
    // function would return a reference to the temporary produced by the
    // conditional expression (a dangling reference).
    template <typename Real>
    [[nodiscard]]
    Real operator()(Real r) const
    {
        return r >= 0. ? r : static_cast<Real>(0.);
    }
};

/// Fully-connected feed-forward network.
///
/// Evaluation (operator()) passes the input through every layer in order.
/// All layers except the last are followed by a ReLU activation; the last
/// (output) layer is linear.
template <typename Real = real_t>
class FeedForwardNeuralNetwork {
public:
    using real_t = Real;

    /// Normalization parameters for one network input.
    /// NOTE(review): presumably populated by parse_nnet from the .nnet
    /// header (min/max/mean/range per input) — confirm against the parser.
    struct Input {
        Real min;
        Real max;
        Real mean;
        Real range;

        /// Clamps `r` into [min, max], then centers it on `mean` and scales
        /// it by `range`: (clamp(r, min, max) - mean) / range.
        [[nodiscard]]
        Real normalize(Real r) const
        {
            return (std::clamp(r, min, max) - mean) / range;
        }
    };

    /// One dense layer. `weights[i]` holds the incoming weights of output
    /// neuron i and `biases[i]` its bias, so both must have the same length.
    struct Layer {
        vector<vector<Real>> weights;
        vector<Real> biases;

        Layer() = default;

        /// Takes ownership of the given weight rows and bias vector.
        Layer(vector<vector<Real>>&& weights, vector<Real>&& biases)
            : weights(std::move(weights))
            , biases(std::move(biases))
        {
        }

        /// Computes activation_fn(biases[i] + dot(weights[i], input)) for
        /// every neuron i and returns the resulting vector (one entry per
        /// neuron). The default activation is the identity (linear layer).
        template <typename Activation = std::identity>
        [[nodiscard]]
        vector<Real> operator()(
            const vector<Real>& input,
            Activation activation_fn = Activation()) const
        {
            // result[i] is read as the bias of neuron i below.
            assert(biases.size() == weights.size());
            vector<Real> result(biases);
            for (size_t i = 0; i < weights.size(); ++i) {
                const vector<Real>& neuron = weights[i];
                assert(neuron.size() == input.size());
                result[i] = activation_fn(std::transform_reduce(
                    neuron.begin(),
                    neuron.end(),
                    input.begin(),
                    result[i],
                    std::plus<>(),
                    std::multiplies<>()));
            }
            return result;
        }

        /// Number of neurons in this layer.
        [[nodiscard]]
        size_t size() const
        {
            return weights.size();
        }

        /// Length of each weight row (0 for a layer without neurons).
        [[nodiscard]]
        size_t num_inputs() const
        {
            return !weights.empty() ? weights.back().size() : 0u;
        }

        /// Number of values this layer produces (one per neuron).
        [[nodiscard]]
        size_t num_outputs() const
        {
            return size();
        }
    };

    FeedForwardNeuralNetwork() = default;

    [[nodiscard]]
    static FeedForwardNeuralNetwork parse_nnet(std::string_view path);

    /// Evaluates the network on `input`.
    ///
    /// The input is fed to the first layer as-is; the per-input
    /// normalization parameters in the `input` member are NOT applied here.
    /// A network without layers is the identity and returns `input`.
    [[nodiscard]]
    vector<Real> operator()(const vector<Real>& input) const
    {
        // Guard first: `layers.size() - 1` would wrap around on an empty
        // network and lead to out-of-bounds accesses.
        if (layers.empty()) {
            return input;
        }
        vector<Real> result(input);
        const size_t last = layers.size() - 1;
        for (size_t i = 0; i < last; ++i) {
            result = layers[i](result, Relu());
        }
        // The output layer has no activation.
        return layers[last](result);
    }

    /// Normalization parameters, one entry per network input.
    vector<Input> input;
    /// Layers in evaluation order; the last one is the output layer.
    vector<Layer> layers;
};

} // namespace police
