#pragma once

#include "police/successor_generator/applicable_actions_generator.hpp"
#include "police/successor_generator/compress_state_adaptor.hpp"
#include "police/successor_generator/determinize_outcomes_adapator.hpp"
#include "police/successor_generator/successor_generator.hpp"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>

namespace police::successor_generator {

/// Successor generator that plugs a plain ApplicableActionsGenerator into
/// SuccessorGenerator, i.e. no policy-based filtering is applied to the
/// applicable actions.
class BaseSuccessorGenerator
    : public SuccessorGenerator<ApplicableActionsGenerator> {
public:
    /// @param actions Action list; the same pointer is handed to the
    ///        SuccessorGenerator base and to the ApplicableActionsGenerator
    ///        the base is constructed with.
    explicit BaseSuccessorGenerator(const vector<Action>* actions)
        : SuccessorGenerator(actions, ApplicableActionsGenerator(actions))
    {
    }
};

namespace _detail {
/// Applicable-actions functor that narrows the applicable set to what a
/// policy selects.
///
/// For a given state it queries an ApplicableActionsGenerator and keeps only
/// those action indices whose action is labelled SILENT_ACTION or carries the
/// label the policy returns for that state. The relative order of the kept
/// indices is preserved.
///
/// @tparam Policy Invoked as `(*policy)(state)` with a `const FlatState&`; the
///         result is compared against `Action::label` with `==`.
template <typename Policy>
class PolicyActionsGenerator {
public:
    /// @param actions Action list the applicable indices refer to. Only the
    ///        pointer is stored and it is dereferenced on every call, so the
    ///        list must stay valid while this object is used.
    /// @param policy Policy queried once per call; only the pointer is
    ///        stored, so it must stay valid while this object is used.
    PolicyActionsGenerator(const vector<Action>* actions, const Policy* policy)
        : aops_gen_(actions)
        , policy_(policy)
        , actions_(actions)
    {
    }

    /// @param state State whose applicable actions are requested.
    /// @return Indices of the applicable actions in @p state that are either
    ///         silent or labelled with the policy's choice for @p state.
    [[nodiscard]]
    ApplicableActions operator()(const FlatState& state) const
    {
        ApplicableActions aops = aops_gen_(state);
        const auto selected = (*policy_)(state);

        // Stable in-place compaction: `keep_end` is the write cursor and
        // never runs ahead of the read cursor `it`. Iterators are used rather
        // than `unsigned` counters so the cursor type always matches the
        // container's own size range.
        auto keep_end = aops.begin();
        for (auto it = aops.begin(); it != aops.end(); ++it) {
            const auto& action = (*actions_)[*it];
            if (action.label == SILENT_ACTION || action.label == selected) {
                *keep_end = *it;
                ++keep_end;
            }
        }
        aops.erase(keep_end, aops.end());
        return aops;
    }

private:
    ApplicableActionsGenerator aops_gen_;
    const Policy* policy_;
    const vector<Action>* actions_;
};

/// Applicable-actions functor for policies that choose among the labels that
/// are actually applicable ("masked" policy query).
///
/// For a given state it queries an ApplicableActionsGenerator, keeps every
/// leading SILENT_ACTION entry, and passes the labels of the remaining
/// (non-silent) applicable actions to the policy. Of the non-silent actions
/// only those carrying the label chosen by the policy are kept; if the policy
/// returns the end iterator, all non-silent actions are dropped.
///
/// @tparam Policy Invoked as `(*policy)(state, first, last)` where
///         `[first, last)` iterates over a `vector<size_t>` of candidate
///         labels. Must return an iterator into that range (the chosen
///         label) or `last` for "no choice".
template <typename Policy>
class MaskedPolicyActionsGenerator {
public:
    /// @param actions Action list the applicable indices refer to. Only the
    ///        pointer is stored and it is dereferenced on every call, so the
    ///        list must stay valid while this object is used.
    /// @param policy Policy queried when at least one non-silent action is
    ///        applicable; only the pointer is stored.
    MaskedPolicyActionsGenerator(
        const vector<Action>* actions,
        const Policy* policy)
        : aops_gen_(actions)
        , policy_(policy)
        , actions_(actions)
    {
    }

    /// @param state State whose applicable actions are requested.
    /// @return The silent applicable actions followed by the applicable
    ///         actions labelled with the policy's choice (possibly none).
    /// @pre The underlying generator yields all SILENT_ACTION entries before
    ///      any non-silent entry (checked by an assertion in debug builds).
    [[nodiscard]]
    ApplicableActions operator()(const FlatState& state) const
    {
        // Single definition of the predicate shared by the precondition check
        // and the partition-point search.
        const auto is_silent = [this](auto idx) {
            return (*actions_)[idx].label == SILENT_ACTION;
        };

        ApplicableActions aops = aops_gen_(state);
        assert(std::is_partitioned(aops.begin(), aops.end(), is_silent));

        const auto first_non_silent =
            std::find_if_not(aops.begin(), aops.end(), is_silent);
        if (first_non_silent == aops.end()) {
            // Only silent actions (or none at all): the policy is not asked.
            return aops;
        }

        // Labels of the non-silent applicable actions, in applicable order.
        vector<size_t> labels;
        labels.reserve(
            static_cast<size_t>(std::distance(first_non_silent, aops.end())));
        std::transform(
            first_non_silent,
            aops.end(),
            std::back_inserter(labels),
            [this](size_t action_idx) {
                return (*actions_)[action_idx].label;
            });

        auto pos = (*policy_)(state, labels.begin(), labels.end());

        // Stable in-place compaction of the non-silent tail: `keep_end` is the
        // write cursor and never runs ahead of the read cursor `it`. If the
        // policy made no choice, `keep_end` stays put and the whole tail is
        // erased.
        auto keep_end = first_non_silent;
        if (pos != labels.end()) {
            for (auto it = first_non_silent; it != aops.end(); ++it) {
                if ((*actions_)[*it].label == *pos) {
                    *keep_end = *it;
                    ++keep_end;
                }
            }
        }
        aops.erase(keep_end, aops.end());
        return aops;
    }

private:
    ApplicableActionsGenerator aops_gen_;
    const Policy* policy_;
    const vector<Action>* actions_;
};
} // namespace _detail

/// Successor generator whose applicable actions are produced by
/// _detail::PolicyActionsGenerator<Policy>, i.e. restricted to silent actions
/// and the action label the policy returns for the state.
///
/// @tparam Policy Policy type forwarded to _detail::PolicyActionsGenerator.
template <typename Policy>
class PolicySuccessorGenerator
    : public SuccessorGenerator<_detail::PolicyActionsGenerator<Policy>> {
public:
    /// @param actions Action list; passed to the SuccessorGenerator base and
    ///        to the policy-filtering actions generator.
    /// @param policy Policy handed to the actions generator.
    explicit PolicySuccessorGenerator(
        const vector<Action>* actions,
        const Policy* policy)
        : SuccessorGenerator<_detail::PolicyActionsGenerator<Policy>>(
              actions,
              _detail::PolicyActionsGenerator<Policy>(actions, policy))
    {
    }
};

/// Successor generator whose applicable actions are produced by
/// _detail::MaskedPolicyActionsGenerator<Policy>, i.e. the policy picks among
/// the labels of the non-silent actions applicable in the state.
///
/// @tparam Policy Policy type forwarded to
///         _detail::MaskedPolicyActionsGenerator.
template <typename Policy>
class MaskedPolicySuccessorGenerator
    : public SuccessorGenerator<_detail::MaskedPolicyActionsGenerator<Policy>> {
public:
    /// @param actions Action list; passed to the SuccessorGenerator base and
    ///        to the masked policy actions generator.
    /// @param policy Policy handed to the actions generator.
    explicit MaskedPolicySuccessorGenerator(
        const vector<Action>* actions,
        const Policy* policy)
        : SuccessorGenerator<_detail::MaskedPolicyActionsGenerator<Policy>>(
              actions,
              _detail::MaskedPolicyActionsGenerator<Policy>(actions, policy))
    {
    }
};

/// BaseSuccessorGenerator wrapped in compress_state_adaptor.
/// NOTE(review): presumably the adaptor lets the generator operate on
/// compressed states -- confirm against compress_state_adaptor.hpp.
using CompressedStateSuccessorGenerator =
    compress_state_adaptor<BaseSuccessorGenerator>;

/// CompressedStateSuccessorGenerator wrapped in determinize_outcomes_adapator
/// (the spelling "adapator" matches the adaptor's declared name).
/// NOTE(review): presumably exposes each probabilistic outcome as a separate
/// deterministic successor -- confirm against the adaptor's header.
using DeterminizedSuccessorGenerator =
    determinize_outcomes_adapator<CompressedStateSuccessorGenerator>;

} // namespace police::successor_generator
