#include "dataset.h"
#include "distributed.h"
#include "nn.h"
#include "utils.h"
#include "event.h"

// DVS-Gesture

// Membrane-potential floor passed as the last constructor argument to every
// hidden spiking layer below (net_0..net_3 and fc1).
// NOTE(review): presumably a lower clamp on v_mem — confirm against nn.h.
float_T min_v_mem = -1.0;
// Convolutional feature extractor. The 5th/6th arguments go
// 32x32 -> 16x16 -> 8x8 -> 4x4 from layer to layer, presumably each layer's
// input height/width. The other positional arguments (channels, kernel size,
// stride, padding, ...) are defined in nn.h — TODO confirm the exact order.
SpikingConv2d net_0(2, 8, 3, 1, 32, 32, 1, 1, 2, 2, min_v_mem);
SpikingConv2d net_1(8, 16, 3, 1, 16, 16, 1, 1, 2, 1, min_v_mem);
SpikingConv2d net_2(16, 32, 3, 1, 8, 8, 1, 1, 2, 1, min_v_mem);
SpikingConv2d net_3(32, 32, 3, 1, 4, 4, 1, 1, 2, 1, min_v_mem);
// Flattens a 32 x 4 x 4 feature map into the 32 * 4 * 4 inputs fc1 expects.
Flatten2d fltn(32, 4, 4);
SpikingLinear fc1(32 * 4 * 4, 64, 1, 1, min_v_mem);
// Output layer with 11 units, matching MembraneArgmax(11) below. The 1e10 and
// -1e10 arguments presumably set an unreachable threshold and floor, so the
// layer never spikes and its membrane potential is read out directly.
// TODO confirm against nn.h.
SpikingLinear fc2(64, 11, 1e10, 0, -1e10); // MembraneOut

// Layer list handed to DistributedModel in main(); presumably forward order
// (the layer dimensions chain consistently in this order).
Model my_net = {(Module*)&net_0, (Module*)&net_1, (Module*)&net_2, (Module*)&net_3, (Module*)&fltn, (Module*)&fc1, (Module*)&fc2};
// "[Dataset Path]" is a placeholder and must be replaced with the real dataset
// location before running. NOTE(review): the meaning of (1, false) is defined
// in dataset.h — confirm.
auto ds = DVSGesture("[Dataset Path]", 1, false);
// Allocated with new and never deleted in this file.
// NOTE(review): confirm whether DistributedReadOut takes ownership of it.
ReadOut * readout = new MembraneArgmax(11);


// Common prefix of the exported weight files; main() appends
// "<index>.weight.txt" for each layer.
std::string model_path("../result_tmp/MTT_DVSGESTURE_BASET40_woBN_AdamW20_net.");


// Entry point. Loads the trained weights into every parameterised layer, wraps
// the network and read-out for distributed evaluation, runs the test pass over
// `ds`, and prints the elapsed wall-clock time in milliseconds.
int main() {
    // Fixed seed so runs are repeatable if anything downstream uses rand().
    srand(233);

    // Builds "<model_path><index>.weight.txt". The index is presumably the
    // layer's position in the exported training network (hence the gaps
    // 2, 6, 11, ...) — confirm against the export script.
    const auto weight_file = [](const char* layer_index) {
        return model_path + layer_index + ".weight.txt";
    };

    // Each temporary std::string lives until the end of its full expression,
    // so the c_str() pointer stays valid for the whole load_weight call.
    // NOTE(review): only net_0 takes the extra argument (2); its meaning is
    // defined in nn.h — confirm.
    net_0.load_weight(weight_file("2").c_str(), 2);
    net_1.load_weight(weight_file("6").c_str());
    net_2.load_weight(weight_file("11").c_str());
    net_3.load_weight(weight_file("15").c_str());
    fc1.load_weight(weight_file("20").c_str());
    fc2.load_weight(weight_file("22").c_str());

    // Shared by the model and the read-out so the two values cannot drift
    // apart. NOTE(review): presumably the number of simulation time steps
    // (cf. "BASET40" in model_path) — confirm against distributed.h.
    constexpr int kTimeSteps = 40;
    DistributedModel dist_net(my_net, kTimeSteps);
    DistributedReadOut dist_readout(readout, kTimeSteps);

    const auto start_ms = cross_sys_clock_ms();
    test_distributed(dist_net, &ds, dist_readout, true, false);
    const auto elapsed_ms = cross_sys_clock_ms() - start_ms;

    // "%lld" requires a long long argument. Passing any other type (whatever
    // cross_sys_clock_ms() actually returns) is undefined behaviour, so cast
    // explicitly. The trailing newline terminates the output line.
    printf("%lld ms\n", static_cast<long long>(elapsed_ms));
    return 0;
}
