#include "police/command_line_parser.hpp"
#include "police/utils/algorithms.hpp"

#include <algorithm>
#include <cassert>
#include <sstream>

namespace police {

namespace {

struct CallAction {
    // Action to run; it only needs access to the arguments being built.
    std::function<void(GlobalArguments&)> func;

    // Adapts a value-less action to the common parser signature. Actions are
    // registered with needs_arg_ == false, so the value passed here is the
    // empty string and is deliberately ignored.
    void operator()(GlobalArguments& args, std::string_view /*value*/)
    {
        func(args);
    }
};

struct StoreFlag {
    explicit StoreFlag(bool GlobalArguments::* dest)
        : dest(dest)
    {
    }
    void operator()(GlobalArguments& args)
    {
        if (dest) (args.*dest) = true;
    }
    bool GlobalArguments::* dest;
};

} // namespace

// Returns true if `name` is already taken, either by a registered option or by
// a dependency group (both share one namespace).
bool CommandLineParser::has_argument(const std::string& name) const
{
    if (options_.find(name) != options_.end()) {
        return true;
    }
    return group_to_id_.find(name) != group_to_id_.end();
}

// Registers a value-less option `name` that runs `action` when present.
//
// - `help`: help text stored for the option.
// - `dependencies`: names (options or dependency groups) this option depends
//   on; parse() checks that each named option is given (or has a fallback)
//   whenever this one is, and uses them to order option processing.
// - `mandatory`: if true, parse() fails when the option is absent.
// - `fallback`: if set, invoked by parse() when the option is absent.
//
// Returns the new option's id. Reports POLICE_INTERNAL_ERROR if `name` is
// already used by an option or dependency group.
CommandLineParser::OptionId CommandLineParser::add_action(
    std::string_view name,
    std::function<void(GlobalArguments&)> action,
    std::string_view help,
    std::initializer_list<std::string_view> dependencies,
    bool mandatory,
    std::optional<std::function<void(GlobalArguments&)>> fallback)
{
    std::string key{name};
    if (has_argument(key)) {
        POLICE_INTERNAL_ERROR(
            "adding multiple arguments with name "
            << name << " to command line parser");
    }
    // Ids are registration indices into the parallel per-option tables below.
    const OptionId new_id = options_.size();
    names_.push_back(key);
    options_[key] = new_id;
    fallback_.push_back(std::move(fallback));
    // Wrap the action so it matches the (args, value) parser signature.
    parsers_.push_back(CallAction{std::move(action)});
    needs_arg_.push_back(false);
    help_.emplace_back(help);
    dependencies_.emplace_back(dependencies.begin(), dependencies.end());
    if (mandatory) {
        mandatory_.push_back({new_id});
    }
    return new_id;
}

// Registers an optional, value-less flag `name`. When the flag is present,
// the boolean member `dest` of the parsed GlobalArguments is set to true (a
// null `dest` makes the flag a no-op). Other parameters are forwarded to
// add_action(); the returned value is the new option's id.
CommandLineParser::OptionId CommandLineParser::add_flag(
    std::string_view name,
    bool GlobalArguments::* dest,
    std::string_view help,
    std::initializer_list<std::string_view> dependencies,
    std::optional<std::function<void(GlobalArguments&)>> fallback)
{
    // Flags are never mandatory.
    constexpr bool is_mandatory = false;
    return add_action(
        name,
        StoreFlag(dest),
        help,
        dependencies,
        is_mandatory,
        std::move(fallback));
}

// Registers an option `name` that takes a value: parse() consumes the token
// following the option and passes it to `parser`.
//
// - `help`: help text stored for the option.
// - `dependencies`: names (options or dependency groups) this option depends
//   on; parse() checks that each named option is given (or has a fallback)
//   whenever this one is, and uses them to order option processing.
// - `mandatory`: if true, parse() fails when the option is absent.
// - `fallback`: if set, invoked by parse() when the option is absent.
//
// Returns the new option's id. Reports POLICE_INTERNAL_ERROR if `name` is
// already used by an option or dependency group.
CommandLineParser::OptionId CommandLineParser::add_raw_argument(
    std::string_view name,
    std::function<void(GlobalArguments&, std::string_view)> parser,
    std::string_view help,
    std::initializer_list<std::string_view> dependencies,
    bool mandatory,
    std::optional<std::function<void(GlobalArguments&)>> fallback)
{
    std::string key{name};
    if (has_argument(key)) {
        POLICE_INTERNAL_ERROR(
            "adding multiple arguments with name "
            << name << " to command line parser");
    }
    // Ids are registration indices into the parallel per-option tables below.
    const OptionId new_id = options_.size();
    names_.push_back(key);
    options_[key] = new_id;
    fallback_.push_back(std::move(fallback));
    parsers_.push_back(std::move(parser));
    // Marks that the next command-line token is this option's value.
    needs_arg_.push_back(true);
    help_.emplace_back(help);
    dependencies_.emplace_back(dependencies.begin(), dependencies.end());
    if (mandatory) {
        mandatory_.push_back({new_id});
    }
    return new_id;
}

// Registers `name` as a dependency group standing for the names in `members`
// (options or other groups). When a group name is listed as a dependency,
// get_dependency_graph() expands it to its members. Group names share the
// option namespace, so a name already in use is reported via
// POLICE_INTERNAL_ERROR.
void CommandLineParser::create_dependency_group(
    std::string_view name,
    std::initializer_list<std::string_view> members)
{
    std::string key{name};
    if (has_argument(key)) {
        POLICE_INTERNAL_ERROR(
            "adding multiple arguments with name "
            << name << " to command line parser");
    }
    const size_t group_index = dependency_groups_.size();
    group_to_id_[key] = group_index;
    dependency_groups_.emplace_back(members.begin(), members.end());
}

void CommandLineParser::ensure_oneof(std::initializer_list<OptionId> options)
{
    mandatory_.emplace_back(options);
}

// Convenience overload for main(): drops argv[0] (the program name) and
// parses the remaining tokens.
GlobalArguments CommandLineParser::parse(int argc, const char** argv)
{
    std::vector<std::string_view> tokens;
    if (argc > 1) {
        tokens.reserve(static_cast<size_t>(argc - 1));
        for (const char** cur = argv + 1; cur != argv + argc; ++cur) {
            tokens.emplace_back(*cur);
        }
    }
    return parse(std::move(tokens));
}

namespace {
// Builds the option dependency graph: graph[i] lists the ids of the options
// that option i names in `dependencies[i]`. Dependency names that refer to a
// group (via `group_ids`/`groups`) are expanded recursively to the group's
// members; names that are neither options nor groups are ignored.
//
// Both the options and each option's dependency list are walked back to
// front, so edges appear in reverse declaration order. This order is
// preserved deliberately because it is what topological_sort() receives.
//
// NOTE(review): there is no visited set, so a group that (transitively)
// contains itself recurses without bound.
vector<vector<size_t>> get_dependency_graph(
    const unordered_map<std::string, CommandLineParser::OptionId>& ids,
    const unordered_map<std::string, size_t>& group_ids,
    const vector<vector<std::string>>& groups,
    const vector<vector<std::string>>& dependencies)
{
    assert(ids.size() == dependencies.size());
    vector<vector<size_t>> graph(ids.size());
    std::function<void(vector<size_t>&, const std::string&)> add_to_graph;
    add_to_graph = [&](vector<size_t>& result, const std::string& name) {
        if (auto id = ids.find(name); id != ids.end()) {
            result.push_back(id->second);
            return;
        }
        if (auto gid = group_ids.find(name); gid != group_ids.end()) {
            for (const auto& member : groups[gid->second]) {
                add_to_graph(result, member);
            }
        }
    };
    // Unsigned countdown / reverse iterators: same visiting order as before,
    // without narrowing size_t sizes into int indices.
    for (size_t i = dependencies.size(); i-- > 0;) {
        const auto& deps = dependencies[i];
        for (auto dep = deps.rbegin(); dep != deps.rend(); ++dep) {
            add_to_graph(graph[i], *dep);
        }
    }
    return graph;
}

// Sentinel position meaning "option not present on the command line"
// (the largest size_t value).
constexpr size_t NOT_FOUND = static_cast<size_t>(-1);

// Maps every registered option id to the position in `args` where it last
// occurs, or NOT_FOUND if it does not occur.
//
// Accepted spellings are `--name`, and `-x` for a single-character name;
// both are looked up in the same `ids` table. An option whose `expects_arg`
// entry is set consumes the following token as its value, which is skipped
// here. A malformed token or an unknown name is reported via
// POLICE_UNKNOWN_ARGUMENT. A value-taking option at the very end of `args`
// is reported via POLICE_INVALID_ARGUMENT.
vector<size_t> find_option_indices(
    const unordered_map<std::string, CommandLineParser::OptionId>& ids,
    const vector<bool>& expects_arg,
    const vector<std::string_view>& args)
{
    vector<size_t> indices(ids.size(), NOT_FOUND);
    for (size_t i = 0; i < args.size(); ++i) {
        const std::string_view arg = args[i];
        const bool is_long = arg.starts_with("--");
        const bool is_short =
            !is_long && arg.size() == 2u && arg.starts_with("-");
        if (!is_long && !is_short) {
            POLICE_UNKNOWN_ARGUMENT(arg);
        }
        // The prefix to strip depends on this token's own form, not on how
        // many tokens were passed.
        auto id = ids.find(std::string(arg.substr(is_long ? 2 : 1)));
        if (id == ids.end()) {
            POLICE_UNKNOWN_ARGUMENT(arg);
        }
        const bool takes_value = expects_arg[id->second];
        if (takes_value && i + 1 == args.size()) {
            POLICE_INVALID_ARGUMENT(arg, "expects argument");
        }
        indices[id->second] = i;
        if (takes_value) {
            ++i; // skip the option's value token
        }
    }
    return indices;
}

void check_required_arguments(
    const vector<size_t>& indices,
    const vector<vector<CommandLineParser::OptionId>>& mandatory,
    const vector<std::string>& names)
{
    auto it = std::find_if(
        mandatory.begin(),
        mandatory.end(),
        [&](const vector<CommandLineParser::OptionId>& group) {
            return std::all_of(group.begin(), group.end(), [&](auto id) {
                return indices[id] == NOT_FOUND;
            });
        });
    if (it != mandatory.end()) {
        if (it->size() == 1u) {
            POLICE_EXIT_INVALID_INPUT(
                "Argument " << names[it->front()] << " is missing.");
        }
        std::ostringstream namestr;
        bool sep = false;
        for (auto id = it->begin(); id != it->end(); ++id) {
            namestr << (sep ? ", " : "") << names[*id];
            sep = true;
        }
        POLICE_EXIT_INVALID_INPUT(
            "One of the arguments " << namestr.str()
                                    << " is required but not given.");
    }
}

void check_dependencies(
    const unordered_map<std::string, CommandLineParser::OptionId>& ids,
    const vector<size_t>& indices,
    const vector<vector<std::string>>& dependencies,
    const vector<std::optional<std::function<void(GlobalArguments&)>>>&
        fallback)
{
    for (size_t i = 0; i < indices.size(); ++i) {
        if (indices[i] != NOT_FOUND) {
            for (const auto& dep : dependencies[i]) {
                auto id = ids.find(dep);
                if (id != ids.end() && indices[id->second] == NOT_FOUND &&
                    !fallback[id->second].has_value()) {
                    POLICE_EXIT_INVALID_INPUT("Missing argument " << dep);
                }
            }
        }
    }
}

} // namespace

// Parses the option tokens `args`; the argc/argv overload strips the program
// name before delegating here.
//
// All tokens are first copied into `result.arguments`. Options are then
// located and validated: unknown options, missing mandatory options, and
// missing dependencies are reported. Finally, every option is handled in the
// order returned by topological_sort() over the dependency graph:
// - a present option runs its parser with its value token, or "" if it
//   takes none;
// - an absent option runs its fallback, if one was registered.
GlobalArguments CommandLineParser::parse(const vector<std::string_view>& args)
{
    GlobalArguments result;
    for (const std::string_view token : args) {
        result.arguments.emplace_back(token);
    }
    const vector<size_t> indices =
        find_option_indices(options_, needs_arg_, args);
    check_required_arguments(indices, mandatory_, names_);
    check_dependencies(options_, indices, dependencies_, fallback_);
    const auto graph = get_dependency_graph(
        options_,
        group_to_id_,
        dependency_groups_,
        dependencies_);
    for (const OptionId id : topological_sort(graph)) {
        const size_t pos = indices[id];
        if (pos == NOT_FOUND) {
            if (fallback_[id].has_value()) {
                (*fallback_[id])(result);
            }
            continue;
        }
        std::string_view value = "";
        if (needs_arg_[id]) {
            value = args[pos + 1];
        }
        parsers_[id](result, value);
    }
    return result;
}

} // namespace police
