#include "distributed.h"

#include <cstdio>
#include <functional>
#include <future>
#include <memory>
#include <utility>
#include <vector>

template <class LABEL_T>
DistributedDatasetWrapper<LABEL_T>::DistributedDatasetWrapper(Dataset<LABEL_T> * ds, int world_size, int rank) : world_size(world_size), ds(ds), rank(rank) {
    // Round-robin shard view over `ds`: this shard owns the global indices
    // rank, rank + world_size, rank + 2 * world_size, ... that are < ds->size().
    // Its size is therefore ceil((N - rank) / world_size) when rank < N, and 0 otherwise.
    // The explicit rank < N check matters: C++ integer division truncates toward
    // zero, so the bare formula would give (-k)/world_size + 1 == 1 for a rank
    // past the end of the data instead of 0.
    // Assumes world_size > 0 and 0 <= rank < world_size (not checked here).
    const int total = static_cast<int>(ds->size());
    local_size = (total > rank) ? (total - rank - 1) / world_size + 1 : 0;
}

template <class LABEL_T>
int DistributedDatasetWrapper<LABEL_T>::get_label(int id) {
    // Translate the shard-local index into the underlying dataset's index
    // (round-robin layout: local id k maps to global rank + k * world_size).
    const auto global_id = rank + id * world_size;
    return ds->get_label(global_id);
}

template <class LABEL_T>
std::vector<std::vector<Event> > DistributedDatasetWrapper<LABEL_T>::get_binned_input(int id) {
    // Same local-to-global index mapping as get_label; the binned input of the
    // corresponding sample in the wrapped dataset is returned unchanged.
    const auto global_id = rank + id * world_size;
    return ds->get_binned_input(global_id);
}

template <class LABEL_T>
int DistributedDatasetWrapper<LABEL_T>::size() {
    // Number of samples in this shard; computed once by the constructor.
    return this->local_size;
}

DistributedModel::DistributedModel(Model model_org, int world_size) : world_size(world_size) {
    // Replica 0 is (a copy of) the caller's model; every further replica is
    // built module-by-module from each module's distributed_clone().
    model_list.push_back(model_org);
    int replica = 1;
    while (replica < world_size) {
        Model replica_model;
        for (const auto & module_ptr : model_org) {
            replica_model.push_back(module_ptr->distributed_clone());
        }
        model_list.push_back(replica_model);
        ++replica;
    }
}

DistributedReadOut::DistributedReadOut(ReadOut * readout_org, int world_size) : world_size(world_size) {
    // Slot 0 reuses the caller's readout pointer; slots 1..world_size-1 each
    // receive their own distributed_clone() of it.
    readout_list.push_back(readout_org);
    int worker = 1;
    while (worker < world_size) {
        readout_list.push_back(readout_org->distributed_clone());
        ++worker;
    }
}

// Evaluates a replicated model on `ds` by splitting the dataset round-robin into
// net.world_size shards and running `test` on each shard in its own task.
//
// Parameters:
//   net          - per-worker model replicas; worker i uses net.model_list[i].
//   ds           - dataset to evaluate; must stay valid until this call returns.
//                  NOTE(review): workers read it concurrently, so its accessors
//                  are assumed to be safe for concurrent reads — confirm.
//   readout_func - per-worker readouts; worker i uses readout_func.readout_list[i]
//                  (assumes it holds at least net.world_size entries).
//   verbose      - forwarded only to worker 0; also enables the merged-accuracy print.
//   binned       - forwarded unchanged to every `test` call.
// Returns the sample-weighted mean of the per-shard results (the accuracy over
// the whole dataset), or 0 when there are no samples to evaluate.
double test_distributed(DistributedModel & net, Dataset<int> * ds, DistributedReadOut & readout_func, bool verbose, bool binned) {
    using Shard = DistributedDatasetWrapper<int>;

    // Owns the shard views. Declared before `result_list` so that, if we leave
    // early via an exception, the futures (whose destructors wait for their
    // std::async tasks) are destroyed before the shards those tasks read.
    std::vector<std::unique_ptr<Shard> > shard_list;
    std::vector<std::future<double> > result_list;
    std::vector<int> dds_size_list;

    for (int i = 0; i < net.world_size; ++i) {
        std::unique_ptr<Shard> shard(new Shard(ds, net.world_size, i));
        const int shard_size = shard->size();
        // Fewer samples than workers leaves some shards empty; they carry zero
        // weight in the merge, and `test` may not handle an empty dataset.
        if (shard_size == 0)
            continue;
        Dataset<int> * shard_view = shard.get();
        shard_list.push_back(std::move(shard));
        dds_size_list.push_back(shard_size);
        // std::launch::async: the default policy may defer every task and run
        // them serially inside get(), which defeats the point of sharding.
        result_list.push_back(std::async(std::launch::async,
            std::bind(test, net.model_list[i], shard_view, readout_func.readout_list[i], i == 0 ? verbose : false, binned)));
    }

    // Weighted merge: each shard's accuracy counts in proportion to its size.
    double acc = 0;
    double total = 0;
    for (size_t k = 0; k < result_list.size(); ++k) {
        acc += result_list[k].get() * dds_size_list[k];
        total += dds_size_list[k];
    }
    acc = (total > 0) ? acc / total : 0;
    if (verbose)
        printf("Merged Acc: %.2lf %%\n", acc * 100);
    return acc;
}