#pragma once

#include "police/base_types.hpp"
#include "police/storage/vector.hpp"

#include <algorithm>
#include <cassert>
#include <limits>
#include <utility>

#define DUMP_VALUES 0

#if DUMP_VALUES
#include "police/utils/io.hpp"
#include <iostream>
#endif

namespace police {

/**
 * Greedy policy derived from a value function over output neurons.
 *
 * The state is projected onto `input_vars` (one real value per input
 * neuron), the value function is evaluated on that projection, and the
 * action attached to the output neuron with the largest value is selected.
 * Ties are broken in favour of the smaller action id.
 *
 * NOTE(review): NaN outputs compare unordered and are not handled specially;
 * this assumes the value function never produces NaN.
 *
 * @tparam ValueFunction callable invoked with a `vector<real_t>` that returns
 *         an indexable, sized sequence holding one value per output neuron.
 **/
template <typename ValueFunction>
class VBasedPolicy {
    /// Marks actions that have no output neuron in `action_to_output_`.
    constexpr static size_t UNDEFINED = std::numeric_limits<size_t>::max();

public:
    /**
     * @param vfunc value function evaluated on the projected state.
     * @param input_vars variable ids, one per input neuron, in neuron order.
     * @param output_actions action ids, one per output neuron, in neuron
     *        order; every id must be smaller than `num_actions`.
     * @param num_actions size of the action id space; sizes the reverse
     *        action -> output-neuron lookup table.
     **/
    VBasedPolicy(
        ValueFunction vfunc,
        vector<size_t> input_vars,
        vector<size_t> output_actions,
        size_t num_actions)
        : vfunc_(std::move(vfunc))
        , input_vars_(std::move(input_vars))
        , output_actions_(std::move(output_actions))
        , action_to_output_(num_actions, UNDEFINED)
    {
        for (size_t i = 0; i < output_actions_.size(); ++i) {
            assert(output_actions_[i] < num_actions);
            // If an action id occurs more than once, the last neuron wins.
            action_to_output_[output_actions_[i]] = i;
        }
    }

    /**
     * Selects the action of the output neuron with the largest value
     * (smaller action id on ties).
     *
     * @param state indexable by variable id; entries convertible to real_t.
     * @return the selected action id, or `std::numeric_limits<size_t>::max()`
     *         if the value function produced no outputs (asserted against
     *         in debug builds).
     **/
    template <typename State>
    size_t operator()(const State& state) const
    {
        const auto nn_out = vfunc_(get_input(state));
        assert(!nn_out.empty());
        // Every output neuron must have an associated action.
        assert(nn_out.size() <= output_actions_.size());

#if DUMP_VALUES
        std::cout << print_sequence(state)
                  << " -> "
                     "out: "
                  << print_sequence(nn_out);
#endif

        if (nn_out.empty()) {
            return UNDEFINED;
        }

        // Seed with the first neuron rather than a -inf sentinel so that an
        // output consisting solely of -inf values still selects an action.
        size_t best = 0;
        for (size_t i = 1; i < nn_out.size(); ++i) {
            if (prefers(nn_out, i, best)) {
                best = i;
            }
        }
        const size_t action = output_actions_[best];

#if DUMP_VALUES
        std::cout << " => (" << nn_out[best] << ", " << action << ")"
                  << std::endl;
#endif

        return action;
    }

    /**
     * Selects, among the candidate actions in [first, last), the one whose
     * output neuron has the largest value (smaller action id on ties).
     *
     * Candidates without an output neuron, including ids outside the action
     * space, rank below every mapped candidate. The iterators must satisfy
     * std::max_element's requirements, i.e. be at least forward iterators.
     *
     * @return iterator to the selected candidate, or `last` if the range is
     *         empty or no candidate has an output neuron.
     **/
    template <typename State, typename InputIterator>
    [[nodiscard]]
    auto operator()(const State& state, InputIterator first, InputIterator last)
        const
    {
#if DUMP_VALUES
        std::cout << "state: " << print_sequence(state) << "\n"
                  << "input: " << print_sequence(get_input(state))
                  << " (for variables " << print_sequence(get_input()) << ")"
                  << std::endl;
#endif
        const auto nn_out = vfunc_(get_input(state));
#if DUMP_VALUES
        std::cout << "action values: " << print_sequence(nn_out) << std::endl
                  << "selected action: ";
#endif
        // Output neuron of `action`, or UNDEFINED if it has none or the
        // value function produced no value for that neuron.
        const auto index_of = [&](size_t action) {
            const size_t idx = output_index(action);
            return idx < nn_out.size() ? idx : UNDEFINED;
        };

        // Strict weak ordering ("a ranks below b"): unmapped actions are
        // mutually equivalent and rank below all mapped ones; mapped actions
        // are ordered by `prefers`.
        const auto max_elem =
            std::max_element(first, last, [&](size_t a, size_t b) {
                const size_t i = index_of(a);
                const size_t j = index_of(b);
                if (j == UNDEFINED) return false;
                if (i == UNDEFINED) return true;
                return prefers(nn_out, j, i);
            });
        if (max_elem == last) {
#if DUMP_VALUES
            std::cout << "*UNDEFINED*" << std::endl;
#endif
            return last;
        }
        const size_t best_index = index_of(*max_elem);
        if (best_index == UNDEFINED) {
#if DUMP_VALUES
            std::cout << "*UNDEFINED*" << std::endl;
#endif
            return last;
        }
#if DUMP_VALUES
        std::cout << (*max_elem) << " (index: " << best_index
                  << ", value: " << nn_out[best_index] << ")" << std::endl;
#endif
        return max_elem;
    }

    /// The wrapped value function.
    const ValueFunction& get_value_function() const { return vfunc_; }

    /**
     * Sequence of variable ids corresponding to the network's input neurons,
     * ordered by ascending neuron index.
     **/
    const vector<size_t>& get_input() const { return input_vars_; }
    vector<size_t>& get_input() { return input_vars_; }

    /**
     * Sequence of action ids corresponding to the network's output neurons,
     * ordered by ascending neuron index.
     *
     * NOTE: the action -> neuron lookup (see get_action_indices()) is built
     * once in the constructor; modifying the sequence through the non-const
     * overload does not update it.
     **/
    const vector<size_t>& get_output() const { return output_actions_; }
    vector<size_t>& get_output() { return output_actions_; }

    /**
     * Lookup table indexed by action id that yields the output-neuron index,
     * or `std::numeric_limits<size_t>::max()` for actions without a neuron.
     **/
    const vector<size_t>& get_action_indices() const { return action_to_output_; }

private:
    /// Projects `state` onto the input variables, converting each to real_t.
    template <typename State>
    vector<real_t> get_input(const State& state) const
    {
        vector<real_t> result(input_vars_.size());
        for (size_t i = 0; i < input_vars_.size(); ++i) {
            result[i] = static_cast<real_t>(state[input_vars_[i]]);
        }
        return result;
    }

    /// Output-neuron index of `action`, or UNDEFINED if the id is out of
    /// range or the action has no output neuron.
    size_t output_index(size_t action) const
    {
        return action < action_to_output_.size() ? action_to_output_[action]
                                                 : UNDEFINED;
    }

    /// True iff output neuron `i` ranks strictly above output neuron `j`:
    /// larger value, or equal value and smaller action id.
    template <typename Values>
    bool prefers(const Values& values, size_t i, size_t j) const
    {
        return values[i] > values[j] ||
               (values[i] == values[j] &&
                output_actions_[i] < output_actions_[j]);
    }

    ValueFunction vfunc_;
    vector<size_t> input_vars_;
    vector<size_t> output_actions_;
    vector<size_t> action_to_output_;
};

} // namespace police

#undef DUMP_VALUES
