#include "masking.h"

#include <algorithm>
#include <cstdint>
#include <tuple>
#include <unordered_set>
#include <vector>

#include "omp.h"

/**
 * Build per-part index masks that pair entries of `caches[i]` with equal-valued
 * entries of `dst`, ignoring values that the part's reuse mask already selects.
 *
 * Layout of `dst` (as implied by the offset computation below): part `dst_id`
 * searches the whole of `dst`; every other part i searches one boundary segment
 * of length `dst_bndries[i].size(0)`. These segments are packed in part order at
 * the tail of `dst`, after its leading non-boundary entries.
 *
 * All tensors are dereferenced through raw host pointers, so they are expected
 * to be CPU tensors holding int64 values.
 *
 * @param caches       Per-part tensors of int64 values.
 * @param dst_id       Part whose search range is the whole of `dst`.
 * @param dst          Tensor of int64 values, laid out as described above.
 * @param dst_bndries  Per-part tensors; only `size(0)` is used, to size the segments.
 * @param reuse_masks  Per-part int64 indices into `dst`; the values they select
 *                     are excluded from both masks of that part.
 * @param num_threads  OpenMP team size; values <= 0 fall back to one thread per part.
 * @return (cache_masks, dst_masks). For part i: positions j in `caches[i]` whose
 *         value lies in part i's `dst` range and is not reused, and absolute
 *         positions j in `dst` (within part i's range) whose value occurs in
 *         `caches[i]` and is not reused. Both lists are in ascending order.
 */
std::tuple<std::vector<torch::Tensor>, std::vector<torch::Tensor>>
gen_cache_mask_cpu(std::vector<torch::Tensor> caches,
                    int dst_id, torch::Tensor dst, std::vector<torch::Tensor> dst_bndries,
                    std::vector<torch::Tensor> reuse_masks,
                    int num_threads) {

    const int64_t num_parts = static_cast<int64_t>(caches.size());
    // Every per-part list is indexed by part id below, so they must line up.
    AT_ASSERTM(static_cast<int64_t>(dst_bndries.size()) == num_parts,
               "dst_bndries and caches should have the same size");
    AT_ASSERTM(static_cast<int64_t>(reuse_masks.size()) == num_parts,
               "reuse_masks and caches should have the same size");

    // OpenMP requires a positive team size; never use more threads than parts.
    int team_size = num_threads > 0 ? num_threads : static_cast<int>(num_parts);
    team_size = std::max(1, std::min(team_size, static_cast<int>(num_parts)));

    // Make every tensor contiguous (so raw-pointer indexing is valid) and fetch its
    // data pointer here: data_ptr<int64_t>() throws on a dtype mismatch, and an
    // exception escaping an OpenMP region terminates the process.
    dst = dst.contiguous();
    const int64_t* dst_data = dst.data_ptr<int64_t>();
    const int64_t dst_size = dst.size(0);

    std::vector<const int64_t*> cache_data(num_parts);
    std::vector<int64_t> cache_sizes(num_parts);
    std::vector<const int64_t*> reuse_data(num_parts);
    std::vector<int64_t> reuse_sizes(num_parts);
    for (int64_t i = 0; i < num_parts; i++) {
        caches[i] = caches[i].contiguous();
        cache_data[i] = caches[i].data_ptr<int64_t>();
        cache_sizes[i] = caches[i].size(0);
        reuse_masks[i] = reuse_masks[i].contiguous();
        reuse_data[i] = reuse_masks[i].data_ptr<int64_t>();
        reuse_sizes[i] = reuse_masks[i].size(0);
    }

    // Part dst_id searches all of dst; every other part owns one boundary segment,
    // packed in part order at the tail of dst.
    int64_t total_dst_bndry_size = 0;
    for (int64_t i = 0; i < num_parts; i++) {
        if (i != dst_id)
            total_dst_bndry_size += dst_bndries[i].size(0);
    }
    // The segment offsets always end at dst_size by construction; what must be
    // checked is that the segments fit, i.e. the leading prefix is non-negative.
    AT_ASSERTM(total_dst_bndry_size <= dst_size,
               "The sum of dst_bndries should not exceed the size of the tensor");

    std::vector<int64_t> dst_start_offsets(num_parts, 0);
    std::vector<int64_t> dst_end_offsets(num_parts, dst_size);
    int64_t cur_dst_offset = dst_size - total_dst_bndry_size;
    for (int64_t i = 0; i < num_parts; i++) {
        if (i == dst_id)
            continue;  // keeps the full range [0, dst_size)
        dst_start_offsets[i] = cur_dst_offset;
        cur_dst_offset += dst_bndries[i].size(0);
        dst_end_offsets[i] = cur_dst_offset;
    }

    std::vector<std::vector<int64_t>> cache_mask(num_parts);
    std::vector<std::vector<int64_t>> dst_mask(num_parts);
    // Flags parts whose reuse mask indexes outside dst; reported after the region.
    std::vector<char> bad_reuse_index(num_parts, 0);
    #pragma omp parallel for num_threads(team_size)
    for (int64_t i = 0; i < num_parts; i++) {
        // Values of dst selected by this part's reuse mask are excluded from both masks.
        std::unordered_set<int64_t> reuse_set;
        reuse_set.reserve(static_cast<size_t>(reuse_sizes[i]));
        for (int64_t j = 0; j < reuse_sizes[i]; j++) {
            const int64_t idx = reuse_data[i][j];
            if (idx < 0 || idx >= dst_size) {
                bad_reuse_index[i] = 1;
                break;
            }
            reuse_set.insert(dst_data[idx]);
        }
        if (bad_reuse_index[i])
            continue;

        const int64_t* part_cache = cache_data[i];
        const int64_t part_cache_size = cache_sizes[i];
        const std::unordered_set<int64_t> dst_set(dst_data + dst_start_offsets[i],
                                                  dst_data + dst_end_offsets[i]);
        const std::unordered_set<int64_t> cache_set(part_cache, part_cache + part_cache_size);

        for (int64_t j = 0; j < part_cache_size; j++) {
            const int64_t value = part_cache[j];
            if (reuse_set.count(value) == 0 && dst_set.count(value) != 0)
                cache_mask[i].push_back(j);
        }
        for (int64_t j = dst_start_offsets[i]; j < dst_end_offsets[i]; j++) {
            const int64_t value = dst_data[j];
            if (reuse_set.count(value) == 0 && cache_set.count(value) != 0)
                dst_mask[i].push_back(j);
        }
    }

    // Checks run serially: throwing from inside the OpenMP region is not allowed.
    for (int64_t i = 0; i < num_parts; i++) {
        AT_ASSERTM(!bad_reuse_index[i], "reuse_masks contains an index outside of dst");
    }
    for (int64_t i = 0; i < num_parts; i++) {
        AT_ASSERTM(cache_mask[i].size() == dst_mask[i].size(),
                   "cache_mask and dst_mask should have the same size");
    }

    // from_blob only wraps the vector's buffer; clone() gives each tensor its own storage.
    std::vector<torch::Tensor> cache_mask_tensors;
    std::vector<torch::Tensor> dst_mask_tensors;
    cache_mask_tensors.reserve(num_parts);
    dst_mask_tensors.reserve(num_parts);
    for (int64_t i = 0; i < num_parts; i++) {
        cache_mask_tensors.push_back(torch::from_blob(cache_mask[i].data(), {static_cast<int64_t>(cache_mask[i].size())}, caches[i].options()).clone());
        dst_mask_tensors.push_back(torch::from_blob(dst_mask[i].data(), {static_cast<int64_t>(dst_mask[i].size())}, dst.options()).clone());
    }

    return std::make_tuple(cache_mask_tensors, dst_mask_tensors);
}

/**
 * For every part, find the positions in `src` and `dst` whose values also occur
 * in the other tensor's range for that part.
 *
 * Both tensors share one layout (as implied by the offset computation below):
 * the owner part (`src_id` for `src`, `dst_id` for `dst`) searches the whole
 * tensor, while every other part i searches one boundary segment of length
 * `bndries[i].size(0)`. These segments are packed in part order at the tail of
 * the tensor, after its leading non-boundary entries.
 *
 * Tensors are dereferenced through raw host pointers, so they are expected to be
 * CPU tensors holding int64 values.
 *
 * @param src_id       Owner part of `src`; must differ from `dst_id`.
 * @param src          Tensor of int64 values laid out as described above.
 * @param src_bndries  Per-part tensors; only `size(0)` is used.
 * @param dst_id       Owner part of `dst`.
 * @param dst          Tensor of int64 values laid out as described above.
 * @param dst_bndries  Per-part tensors; must have as many entries as `src_bndries`.
 * @param num_threads  OpenMP team size; values <= 0 fall back to one thread per part.
 * @return (src_masks, dst_masks). For part i: absolute positions in `src` (within
 *         part i's range) whose value occurs in part i's `dst` range, and vice
 *         versa. Lists are ascending, and the two lists of a part must have equal
 *         length (asserted).
 */
std::tuple<std::vector<torch::Tensor>, std::vector<torch::Tensor>>
gen_reuse_mask_cpu(int src_id, torch::Tensor src, std::vector<torch::Tensor> src_bndries,
              int dst_id, torch::Tensor dst, std::vector<torch::Tensor> dst_bndries,
              int num_threads) {

    AT_ASSERTM(src_id != dst_id, "src_id and dst_id should not be the same");
    AT_ASSERTM(src_bndries.size() == dst_bndries.size(), "src_bndries and dst_bndries should have the same size");

    const int64_t num_parts = static_cast<int64_t>(src_bndries.size());

    // OpenMP requires a positive team size; never use more threads than parts.
    int team_size = num_threads > 0 ? num_threads : static_cast<int>(num_parts);
    team_size = std::max(1, std::min(team_size, static_cast<int>(num_parts)));

    // Fills [starts[i], ends[i]) for every part of a tensor of `total_size` rows:
    // the owner spans the whole tensor, the others get consecutive boundary
    // segments at the tail. Returns the summed size of those boundary segments.
    auto compute_offsets = [num_parts](int64_t owner_id, int64_t total_size,
                                       const std::vector<torch::Tensor>& bndries,
                                       std::vector<int64_t>& starts,
                                       std::vector<int64_t>& ends) {
        int64_t bndry_total = 0;
        for (int64_t i = 0; i < num_parts; i++) {
            if (i != owner_id)
                bndry_total += bndries[i].size(0);
        }
        starts.assign(static_cast<size_t>(num_parts), 0);
        ends.assign(static_cast<size_t>(num_parts), total_size);
        int64_t cur = total_size - bndry_total;
        for (int64_t i = 0; i < num_parts; i++) {
            if (i == owner_id)
                continue;  // keeps the full range [0, total_size)
            starts[i] = cur;
            cur += bndries[i].size(0);
            ends[i] = cur;
        }
        return bndry_total;
    };

    // Contiguous copies make raw-pointer indexing valid; data_ptr<int64_t>() is
    // resolved here because its dtype check throws, which must not happen inside
    // an OpenMP region.
    src = src.contiguous();
    dst = dst.contiguous();
    const int64_t* src_data = src.data_ptr<int64_t>();
    const int64_t* dst_data = dst.data_ptr<int64_t>();

    // Offsets always end at the tensor size by construction; what must hold is
    // that the boundary segments fit, i.e. the leading prefix is non-negative.
    std::vector<int64_t> src_start_offsets;
    std::vector<int64_t> src_end_offsets;
    const int64_t total_src_bndry_size =
        compute_offsets(src_id, src.size(0), src_bndries, src_start_offsets, src_end_offsets);
    AT_ASSERTM(total_src_bndry_size <= src.size(0),
               "The sum of src_bndries should not exceed the size of the tensor");

    std::vector<int64_t> dst_start_offsets;
    std::vector<int64_t> dst_end_offsets;
    const int64_t total_dst_bndry_size =
        compute_offsets(dst_id, dst.size(0), dst_bndries, dst_start_offsets, dst_end_offsets);
    AT_ASSERTM(total_dst_bndry_size <= dst.size(0),
               "The sum of dst_bndries should not exceed the size of the tensor");

    std::vector<std::vector<int64_t>> src_mask(num_parts);
    std::vector<std::vector<int64_t>> dst_mask(num_parts);
    #pragma omp parallel for num_threads(team_size)
    for (int64_t i = 0; i < num_parts; i++) {
        const std::unordered_set<int64_t> src_set(src_data + src_start_offsets[i],
                                                  src_data + src_end_offsets[i]);
        const std::unordered_set<int64_t> dst_set(dst_data + dst_start_offsets[i],
                                                  dst_data + dst_end_offsets[i]);
        for (int64_t j = src_start_offsets[i]; j < src_end_offsets[i]; j++) {
            if (dst_set.count(src_data[j]) != 0)
                src_mask[i].push_back(j);
        }
        for (int64_t j = dst_start_offsets[i]; j < dst_end_offsets[i]; j++) {
            if (src_set.count(dst_data[j]) != 0)
                dst_mask[i].push_back(j);
        }
    }

    // Serial on purpose: an assertion throws, and exceptions must not escape an
    // OpenMP parallel region.
    for (int64_t i = 0; i < num_parts; i++) {
        AT_ASSERTM(src_mask[i].size() == dst_mask[i].size(), "src_mask and dst_mask should have the same size");
    }

    // from_blob only wraps the vector's buffer; clone() gives each tensor its own storage.
    std::vector<torch::Tensor> src_mask_tensors;
    std::vector<torch::Tensor> dst_mask_tensors;
    src_mask_tensors.reserve(num_parts);
    dst_mask_tensors.reserve(num_parts);
    for (int64_t i = 0; i < num_parts; i++) {
        src_mask_tensors.push_back(torch::from_blob(src_mask[i].data(), {static_cast<int64_t>(src_mask[i].size())}, src.options()).clone());
        dst_mask_tensors.push_back(torch::from_blob(dst_mask[i].data(), {static_cast<int64_t>(dst_mask[i].size())}, dst.options()).clone());
    }

    return std::make_tuple(src_mask_tensors, dst_mask_tensors);
}