#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl/filesystem.h>
#include <pybind11/stl_bind.h>

#include "sc2_serializer/data_structures/replay_all.hpp"

#include <algorithm>
#include <charconv>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <string>
#include <system_error>
#include <unordered_set>
#include <vector>

namespace py = pybind11;

namespace pfr = boost::pfr;

/**
 * @brief Map every unit of one time step to the index of the unit it targets.
 *
 * @param replayData Step data; `replayData.units[index]` is the unit list that is inspected.
 * @param index Time step to inspect.
 * @return One entry per unit: the position (within the same unit list) of the first unit whose `id` equals that
 *         unit's `tgtId`, or -1 when no unit in the list has that id.
 * @throws std::out_of_range if `index` does not address an entry of `replayData.units`.
 *
 * @note Intentionally not `noexcept`: allocating the result array and the temporary id vector can throw, and a
 *       `noexcept` function would turn that into `std::terminate` instead of a Python exception.
 */
[[nodiscard]] auto getUnitTargetIndices(const cvt::StepDataSoA &replayData,
    std::int64_t index) -> py::array_t<std::int64_t>
{
    // A negative index would otherwise wrap to a huge std::size_t and read out of bounds.
    if (index < 0 || static_cast<std::size_t>(index) >= replayData.units.size()) {
        throw std::out_of_range("Step index is outside of the replay's unit data");
    }
    const auto &unitData = replayData.units[static_cast<std::size_t>(index)];
    py::array_t<std::int64_t> mappedIds(static_cast<py::ssize_t>(unitData.size()));

    // Alternative hash-map based lookup, kept for reference:
    // std::unordered_map<cvt::UID, std::int64_t> unitId2Idx;
    // for (std::int64_t idx = 0; idx < unitData.size(); ++idx) { unitId2Idx[unitData[idx].id] = idx; }
    // std::ranges::transform(unitData, mappedIds.mutable_data(), [&](const cvt::Unit &unit) {
    //     const auto it = unitId2Idx.find(unit.tgtId);
    //     if (it == unitId2Idx.end()) { return -1L; }
    //     return it->second;
    // });

    // Flat list of ids so the target id can be searched for with a linear scan.
    std::vector<cvt::UID> unitIds(unitData.size());
    std::ranges::transform(unitData, unitIds.begin(), [](const cvt::Unit &unit) { return unit.id; });

    // Explicit return type: `-1L` and std::int64_t are different types where int64_t is `long long`.
    std::ranges::transform(unitData, mappedIds.mutable_data(), [&](const cvt::Unit &unit) -> std::int64_t {
        const auto it = std::ranges::find(unitIds, unit.tgtId);
        if (it == unitIds.end()) { return -1; }
        return static_cast<std::int64_t>(std::distance(unitIds.begin(), it));
    });

    return mappedIds;
}

/**
 * @brief Collect the `order0.target_pos` of every unit of one time step.
 *
 * @param replayData Step data; `replayData.units[index]` is the unit list that is inspected.
 * @param index Time step to inspect.
 * @return Array of shape [num_units, 2] holding (x, y) per unit. A unit whose target position has a zero x or a
 *         zero y component is treated as having no target and is written as (-1, -1).
 * @throws std::out_of_range if `index` does not address an entry of `replayData.units`.
 *
 * @note Intentionally not `noexcept`: allocating the result array can throw, which must surface as a Python
 *       exception rather than `std::terminate`.
 */
[[nodiscard]] auto getUnitTargetCoordinates(const cvt::StepDataSoA &replayData,
    std::int64_t index) -> py::array_t<std::int32_t>
{
    // A negative index would otherwise wrap to a huge std::size_t and read out of bounds.
    if (index < 0 || static_cast<std::size_t>(index) >= replayData.units.size()) {
        throw std::out_of_range("Step index is outside of the replay's unit data");
    }
    const auto &unitData = replayData.units[static_cast<std::size_t>(index)];

    // Both extents share one type so the shape initializer list deduces on every platform.
    py::array_t<std::int32_t> data({ static_cast<py::ssize_t>(unitData.size()), py::ssize_t{ 2 } });

    // Write through the typed accessor instead of reinterpreting the int32 buffer as cvt::Point2d.
    auto out = data.mutable_unchecked<2>();
    py::ssize_t row = 0;
    for (const auto &unit : unitData) {
        const auto &target_pos = unit.order0.target_pos;
        const bool hasTarget = target_pos.x != 0 && target_pos.y != 0;
        out(row, 0) = hasTarget ? static_cast<std::int32_t>(target_pos.x) : std::int32_t{ -1 };
        out(row, 1) = hasTarget ? static_cast<std::int32_t>(target_pos.y) : std::int32_t{ -1 };
        ++row;
    }
    return data;
}

template<> struct std::hash<cvt::Point2d>
{
    [[nodiscard]] auto operator()(const cvt::Point2d &p) const noexcept -> std::size_t
    {
        static_assert(sizeof(p) == sizeof(std::size_t) && "can't just bitcast as sizes are different");
        return std::bit_cast<std::size_t>(p);
    }
};

/**
 * @brief Gather the distinct `order0.target_pos` values used by the units of one time step.
 *
 * @param replayData Step data; `replayData.units[index]` is the unit list that is inspected.
 * @param index Time step to inspect.
 * @return Array of shape [num_unique, 2] with one (x, y) row per distinct target position. Positions with a zero
 *         x or a zero y component are skipped. Row order follows the iteration order of the internal
 *         `std::unordered_set`, i.e. it is not sorted.
 * @throws std::out_of_range if `index` does not address an entry of `replayData.units`.
 *
 * @note Intentionally not `noexcept`: the set insertions and the array allocation can throw, which must surface
 *       as a Python exception rather than `std::terminate`.
 */
[[nodiscard]] auto getUniqueUnitTargetCoordinates(const cvt::StepDataSoA &replayData,
    std::int64_t index) -> py::array_t<std::int32_t>
{
    // A negative index would otherwise wrap to a huge std::size_t and read out of bounds.
    if (index < 0 || static_cast<std::size_t>(index) >= replayData.units.size()) {
        throw std::out_of_range("Step index is outside of the replay's unit data");
    }
    const auto &unitData = replayData.units[static_cast<std::size_t>(index)];

    // De-duplicate the target positions that count as "set".
    std::unordered_set<cvt::Point2d> targetCoords;
    for (const auto &unit : unitData) {
        const auto &target_pos = unit.order0.target_pos;
        if (target_pos.x != 0 && target_pos.y != 0) { targetCoords.insert(target_pos); }
    }

    // Both extents share one type so the shape initializer list deduces on every platform.
    py::array_t<std::int32_t> data({ static_cast<py::ssize_t>(targetCoords.size()), py::ssize_t{ 2 } });

    // Write through the typed accessor instead of reinterpreting the int32 buffer as cvt::Point2d.
    auto out = data.mutable_unchecked<2>();
    py::ssize_t row = 0;
    for (const auto &coord : targetCoords) {
        out(row, 0) = static_cast<std::int32_t>(coord.x);
        out(row, 1) = static_cast<std::int32_t>(coord.y);
        ++row;
    }
    return data;
}

/**
 * @brief Map every unit of one time step to the row of `coordinates` that equals its `order0.target_pos`.
 *
 * @param replayData Step data; `replayData.units[index]` is the unit list that is inspected.
 * @param index Time step to inspect.
 * @param coordinates Candidate positions of shape [N, 2] holding (x, y) per row. An array without elements is
 *        accepted and means "no candidates".
 * @return One entry per unit: the index of the first row of `coordinates` equal to the unit's target position,
 *         or -1 when the target position has a zero x or zero y component or is not contained in `coordinates`.
 * @throws std::out_of_range if `index` does not address an entry of `replayData.units`.
 * @throws std::invalid_argument if `coordinates` is non-empty and not of shape [N, 2].
 *
 * @note Intentionally not `noexcept`: allocations and the argument checks can throw, which must surface as a
 *       Python exception rather than `std::terminate`.
 */
[[nodiscard]] auto getUnitCoordinateIndices(const cvt::StepDataSoA &replayData,
    std::int64_t index,
    py::array_t<int> coordinates) -> py::array_t<std::int64_t>
{
    // A negative index would otherwise wrap to a huge std::size_t and read out of bounds.
    if (index < 0 || static_cast<std::size_t>(index) >= replayData.units.size()) {
        throw std::out_of_range("Step index is outside of the replay's unit data");
    }
    const auto &unitData = replayData.units[static_cast<std::size_t>(index)];

    // Copy the candidates through the strided accessor: reading the raw data pointer as Point2d would be wrong
    // for non-contiguous arrays and out of bounds for arrays that are not [N, 2].
    std::vector<cvt::Point2d> candidates;
    if (coordinates.size() != 0) {
        if (coordinates.ndim() != 2 || coordinates.shape(1) != 2) {
            throw std::invalid_argument("coordinates must have shape [N, 2]");
        }
        const auto coords = coordinates.unchecked<2>();
        candidates.reserve(static_cast<std::size_t>(coords.shape(0)));
        for (py::ssize_t row = 0; row < coords.shape(0); ++row) {
            candidates.push_back(cvt::Point2d(coords(row, 0), coords(row, 1)));
        }
    }

    py::array_t<std::int64_t> assignment(py::ssize_t_cast(unitData.size()));
    // Explicit return type: `-1L` and std::int64_t are different types where int64_t is `long long`.
    std::ranges::transform(unitData, assignment.mutable_data(), [&](const cvt::Unit &unit) -> std::int64_t {
        const auto &target_pos = unit.order0.target_pos;
        // A zero component marks "no target position".
        if (target_pos.x == 0 || target_pos.y == 0) { return -1; }
        const auto it = std::ranges::find(candidates, target_pos);
        if (it == candidates.end()) { return -1; }
        return static_cast<std::int64_t>(std::distance(candidates.begin(), it));
    });
    return assignment;
}

/**
 * @brief Append the elements [offset, offset + nelem) of every field of `source` to the matching field of `dest`.
 *
 * Fields are visited via boost::pfr from index `I` down to 0; callers normally leave `I` at its default (the last
 * field) so that all fields are processed.
 *
 * @tparam T Struct type whose fields are all containers offering `begin()` and `push_back`.
 * @tparam I Index of the field handled by this instantiation; lower indices are handled recursively.
 * @param source Struct to copy from.
 * @param dest Struct to append to; existing elements are kept.
 * @param offset Index of the first element to copy from each field.
 * @param nelem Number of elements to copy from each field.
 *
 * @note The range is not validated here; the caller must ensure `offset + nelem` fits every field.
 * @note Intentionally not `noexcept`: appending allocates and may throw `std::bad_alloc`.
 */
template<typename T, std::size_t I = pfr::tuple_size_v<T> - 1>
void copySubRange(const T &source, T &dest, std::size_t offset, std::size_t nelem)
{
    const auto &sourceField = pfr::get<I>(source);
    const auto first = sourceField.begin() + static_cast<std::ptrdiff_t>(offset);
    const auto last = first + static_cast<std::ptrdiff_t>(nelem);
    std::ranges::copy(first, last, std::back_inserter(pfr::get<I>(dest)));
    // Recurse over the remaining fields at compile time.
    if constexpr (I > 0) { copySubRange<T, I - 1>(source, dest, offset, nelem); }
}

/**
 * @brief Create a new replay that contains only the steps [offset, offset + nelem) of `replayData`.
 *
 * The header is copied and "_<offset>" is appended to its `replayHash` so the sub-sequence gets its own
 * identifier.
 *
 * @param replayData Replay to take the sub-sequence from.
 * @param offset Index of the first step to keep.
 * @param nelem Number of steps to keep.
 * @return The newly created replay.
 * @throws std::out_of_range if the requested range does not fit into `replayData`.
 */
[[nodiscard]] auto createReplaySubsequnce(const cvt::ReplayDataSoA &replayData,
    std::size_t offset,
    std::size_t nelem) -> cvt::ReplayDataSoA
{
    // Compare without computing `offset + nelem`, which could wrap around and wrongly pass the check.
    const std::size_t available = replayData.size();
    if (offset > available || nelem > available - offset) {
        throw std::out_of_range("Offset + Length exceeds Replay Length");
    }
    cvt::ReplayDataSoA newReplay;
    newReplay.header = replayData.header;
    // Add the index of the first sample to the replay hash to make it unique
    newReplay.header.replayHash.append(std::string("_") + std::to_string(offset));
    copySubRange(replayData.data, newReplay.data, offset, nelem);
    return newReplay;
}

/**
 * @brief Normalize the first two columns (x, y) of every array in `dataSequence` in place.
 *
 * Each value becomes `(value - center) / (size / 2)` using the x or y component of `center` and `size`.
 *
 * @param dataSequence Python list of writeable 2D float32 numpy arrays with at least two columns.
 * @param center 1D array whose first two elements are the (x, y) center.
 * @param size 1D array whose first two elements are the (x, y) extent.
 * @throws py::type_error if an item is not a float32 numpy array. Such an item would be converted into a
 *         temporary copy, so the caller's data would silently stay unmodified.
 * @throws std::invalid_argument if `center` or `size` has fewer than two elements, or a non-empty array has fewer
 *         than two columns.
 *
 * @note Arrays are processed in list order; if an item is rejected, the items before it have already been
 *       normalized.
 */
void normalizeCoordinatesInplace(py::list dataSequence, py::array_t<float> center, py::array_t<float> size)
{
    const auto center_ = center.unchecked<1>();
    const auto size_ = size.unchecked<1>();
    if (center_.shape(0) < 2 || size_.shape(0) < 2) {
        throw std::invalid_argument("center and size must each hold at least two elements [x, y]");
    }
    const std::array<float, 2> inv_half_size{ 1.f / (size_(0) / 2.f), 1.f / (size_(1) / 2.f) };

    for (auto &&item : dataSequence) {
        // Casting a non-float32 object would create a converted copy and the update would be lost.
        if (!py::isinstance<py::array_t<float>>(item)) {
            throw py::type_error("normalize_coordinates_inplace requires float32 numpy arrays");
        }
        auto data = py::cast<py::array_t<float>>(item);
        auto ref = data.mutable_unchecked<2>();
        if (ref.shape(0) > 0 && ref.shape(1) < 2) {
            throw std::invalid_argument("arrays to normalize must have at least two columns [x, y]");
        }
        for (py::ssize_t idx = 0; idx < ref.shape(0); ++idx) {
            ref(idx, 0) = (ref(idx, 0) - center_(0)) * inv_half_size[0];
            ref(idx, 1) = (ref(idx, 1) - center_(1)) * inv_half_size[1];
        }
    }
}

/**
 * @brief Test for every row of a 2D array whether its first two columns (x, y) lie strictly inside the bounds.
 *
 * @param data 2D array [N, C]; only the columns 0 (x) and 1 (y) are read.
 * @param lower Exclusive lower bound applied to both x and y.
 * @param upper Exclusive upper bound applied to both x and y.
 * @return Boolean array of length N that is true where `lower < x < upper` and `lower < y < upper`.
 * @throws std::invalid_argument if `data` has rows but fewer than two columns.
 */
[[nodiscard]] auto inBounds2D(py::array_t<float> data, float lower, float upper) -> py::array_t<bool>
{
    const auto acc = data.unchecked<2>();
    // Reading column 1 of a narrower array would be out of bounds.
    if (acc.shape(0) > 0 && acc.shape(1) < 2) {
        throw std::invalid_argument("data must have at least two columns [x, y]");
    }
    py::array_t<bool> mask(acc.shape(0));
    auto mask_a = mask.mutable_unchecked<1>();
    for (py::ssize_t idx = 0; idx < acc.shape(0); ++idx) {
        const float x = acc(idx, 0);
        const float y = acc(idx, 1);
        const bool xInside = lower < x && x < upper;
        const bool yInside = lower < y && y < upper;
        mask_a(idx) = xInside && yInside;
    }
    return mask;
}

/**
 * @brief Collect the set of distinct `unitType` values found across all steps of the replay data.
 *
 * @param replayData Step data whose `units` member holds one unit list per step.
 * @return Set of every unit type that occurs at least once.
 */
[[nodiscard]] auto gatherUniqueUnitTypes(const cvt::StepDataSoA &replayData) -> std::unordered_set<int>
{
    std::unordered_set<int> seenTypes;
    const auto recordType = [&seenTypes](const auto &unit) { seenTypes.insert(unit.unitType); };
    // Visit every unit of every step.
    std::ranges::for_each(
        replayData.units, [&recordType](const auto &stepUnits) { std::ranges::for_each(stepUnits, recordType); });
    return seenTypes;
}

// Opt std::unordered_map<int, int> out of pybind11's automatic STL <-> dict conversion so that it can be exposed
// as a bound class (see the py::bind_map call in the module definition) and passed to C++ without copying.
PYBIND11_MAKE_OPAQUE(std::unordered_map<int, int>);


/**
 * @brief Build a mapping from unit type id to a contiguous index from a comma-separated id file.
 *
 * The ids are read, sorted ascending and de-duplicated; every id is then mapped to its position in that sorted
 * list, so the resulting indices are exactly 0 .. N-1.
 *
 * @param filepath Text file with integer ids separated by ','. Whitespace around ids and blank entries (such as
 *        a trailing ",\n") are ignored.
 * @return Mapping id -> contiguous index.
 * @throws std::runtime_error if the file cannot be opened.
 * @throws std::invalid_argument if an entry is not a valid integer.
 */
[[nodiscard]] auto createUnitTypeToContiguousMap(const std::filesystem::path &filepath) -> std::unordered_map<int, int>
{
    std::ifstream file(filepath);
    if (!file.is_open()) { throw std::runtime_error("Unable to open unit type id file: " + filepath.string()); }

    std::vector<int> ids;
    std::string token;
    while (std::getline(file, token, ',')) {
        constexpr const char *whitespace = " \t\r\n";
        const auto first = token.find_first_not_of(whitespace);
        // Blank entries must not be turned into a spurious id 0.
        if (first == std::string::npos) { continue; }
        const auto last = token.find_last_not_of(whitespace);
        const char *const begin = token.data() + first;
        const char *const end = token.data() + last + 1;

        int id = 0;
        const auto [parsedEnd, errorCode] = std::from_chars(begin, end, id);
        if (errorCode != std::errc{} || parsedEnd != end) {
            throw std::invalid_argument(
                "Invalid unit type id '" + token.substr(first, last - first + 1) + "' in " + filepath.string());
        }
        ids.push_back(id);
    }

    // Sort for a stable order and drop duplicates so the assigned indices have no gaps.
    std::sort(ids.begin(), ids.end());
    ids.erase(std::unique(ids.begin(), ids.end()), ids.end());

    std::unordered_map<int, int> mapping;
    mapping.reserve(ids.size());
    for (std::size_t i = 0; i < ids.size(); ++i) { mapping.emplace(ids[i], static_cast<int>(i)); }
    return mapping;
}

/**
 * @brief Replace, in place, the unit type id stored in one column of a 2D feature array by its contiguous index.
 *
 * @param data 2D array [N, C]; column `ch_index` holds the unit type id stored as float.
 * @param ch_index Column holding the unit type id. Negative values count from the last column (-1 is the last).
 * @param mapping Mapping from original unit type id to contiguous index.
 * @throws std::out_of_range if `ch_index` does not address a column of a non-empty `data`, or if a unit type id
 *         is missing from `mapping`.
 *
 * @note Rows are rewritten one after another; if an id is missing from `mapping`, the rows before it have already
 *       been remapped.
 */
void makeUnitTypeContiguous(py::array_t<float> data, py::ssize_t ch_index, const std::unordered_map<int, int> &mapping)
{
    auto arrayHandle = data.mutable_unchecked<2>();
    const py::ssize_t numRows = arrayHandle.shape(0);
    if (numRows == 0) { return; }

    // Resolve python-style negative indices and reject anything outside of the row.
    const py::ssize_t numChannels = arrayHandle.shape(1);
    const py::ssize_t channel = ch_index < 0 ? ch_index + numChannels : ch_index;
    if (channel < 0 || channel >= numChannels) {
        throw std::out_of_range("Channel index " + std::to_string(ch_index) + " is out of range for "
                                + std::to_string(numChannels) + " channels");
    }

    for (py::ssize_t idx = 0; idx < numRows; ++idx) {
        const auto unitType = static_cast<int>(arrayHandle(idx, channel));
        const auto entry = mapping.find(unitType);
        if (entry == mapping.end()) {
            throw std::out_of_range("Unit type " + std::to_string(unitType) + " has no entry in the mapping");
        }
        arrayHandle(idx, channel) = static_cast<float>(entry->second);
    }
}

// Python module `dataset_utils`: exposes the helpers defined in this file.
PYBIND11_MODULE(dataset_utils, m)
{
    // Import the core extension module and look up StepDataSoA before registering anything here; presumably this
    // guarantees the replay types are registered with pybind11 before they are used as arguments — TODO confirm.
    py::module_::import("sc2_serializer._sc2_serializer").attr("StepDataSoA");

    m.def("get_unit_target_indices",
        &getUnitTargetIndices,
        "Returns array that maps units to the index of their target unit, -1 for no mapping.",
        py::arg("replay_data"),
        py::arg("index"));

    m.def("get_target_coordinates",
        &getUnitTargetCoordinates,
        "Get the target coordinates of order0 of the units, units without target coordinates are filled with (-1,-1).",
        py::arg("replay_data"),
        py::arg("index"));

    m.def("get_unique_target_coordinates",
        &getUniqueUnitTargetCoordinates,
        "Gather the set of unique target position coordinates that units are assigned.",
        py::arg("replay_data"),
        py::arg("index"));

    m.def("get_unit_coordinate_indices",
        &getUnitCoordinateIndices,
        "Returns array that maps units to the index in the array of unique coordinate targets, -1 for no mapping.",
        py::arg("replay_data"),
        py::arg("index"),
        py::arg("coordinates"));

    m.def("create_replay_subsequence",
        &createReplaySubsequnce,
        "Creates replay that is a subsequence of an existing replay given the new beginning (offset) and length "
        "(nelem) of the subrange.",
        py::arg("replay_data"),
        py::arg("offset"),
        py::arg("nelem"));

    m.def("normalize_coordinates_inplace",
        &normalizeCoordinatesInplace,
        "Inplace apply normalization to coordinates [-1,1] assumes [x,y] are first two elements of last dimension.",
        py::arg("sequence"),
        py::arg("center"),
        py::arg("size"));

    m.def("in_bounds_2d",
        &inBounds2D,
        "Check if position coodinates of 2D array [N, [x,y,...]] are within lower and upper bounds.",
        py::arg("data"),
        py::arg("lower"),
        py::arg("upper"));

    m.def("gather_unique_unit_types",
        &gatherUniqueUnitTypes,
        "Get the unique unit types from replay data",
        py::arg("replay"));

    // Expose the opaque id -> index map as a Python class instead of converting it to a dict.
    py::bind_map<std::unordered_map<int, int>>(m, "UnitTypeToContiguous");

    m.def("create_unit_type_to_contiguous_map",
        &createUnitTypeToContiguousMap,
        "Read the unit type ids from a file, sort in order for consistency, then create mapping from id to idx into "
        "sorted array",
        py::arg("filepath"));

    // `data` is modified in place, so implicit dtype conversion is disabled: a converted array would be a
    // temporary copy and the caller's array would silently stay unchanged.
    m.def("make_unit_type_contiguous",
        &makeUnitTypeContiguous,
        "Inplace remap original sc2 unit type id to contiguous id. data is an extracted feature tensor where the type "
        "id is the last element of the last dimension.",
        py::arg("data").noconvert(),
        py::arg("index"),
        py::arg("mapping"));
}
