#include "dataset.h"
#include "distributed.h"
#include "nn.h"
#include "utils.h"
#include "event.h"

// CIFAR10DVS: VGG9-style spiking network.
// Layout: eight SpikingConv2d layers -> 2x2 SumPool2d -> Flatten2d -> SpikingLinear (10 outputs).
//
// SpikingConv2d(int in_C, int out_C, int k_size, int padding, int out_H, int out_W, float_T Vth=1.0, bool soft_reset=0, int avg_pool_size=1, int stride=1, float_T min_v_mem=-1.0)
float_T min_v_mem = -1.0;    // passed as min_v_mem to every conv layer below
bool is_soft_reset = false;  // passed as soft_reset to every conv layer below

// All conv layers: 3x3 kernel, padding 1, Vth 1, stride 1.
SpikingConv2d net_0(2, 64, 3, 1, 32, 32, 1, is_soft_reset, 4, 1, min_v_mem);     // 2 input channels, avg_pool_size 4
SpikingConv2d net_1(64, 128, 3, 1, 32, 32, 1, is_soft_reset, 1, 1, min_v_mem);
SpikingConv2d net_2(128, 256, 3, 1, 16, 16, 1, is_soft_reset, 2, 1, min_v_mem);  // avg_pool_size 2
SpikingConv2d net_3(256, 256, 3, 1, 16, 16, 1, is_soft_reset, 1, 1, min_v_mem);
SpikingConv2d net_4(256, 512, 3, 1, 8, 8, 1, is_soft_reset, 2, 1, min_v_mem);    // avg_pool_size 2
SpikingConv2d net_5(512, 512, 3, 1, 8, 8, 1, is_soft_reset, 1, 1, min_v_mem);
SpikingConv2d net_6(512, 512, 3, 1, 4, 4, 1, is_soft_reset, 2, 1, min_v_mem);    // avg_pool_size 2
SpikingConv2d net_7(512, 512, 3, 1, 4, 4, 1, is_soft_reset, 1, 1, min_v_mem);
SumPool2d sumP(2);
Flatten2d fltn(512, 2, 2);
// NOTE(review): the 1e10 / -1e10 arguments presumably make the output layer effectively
// never fire, so MembraneArgmax reads the accumulated membrane potential — confirm in nn.h.
SpikingLinear fc(512 * 2 * 2, 10, 1e10, 0, -1e10);

// Layers in forward order. static_cast (rather than a C-style cast) makes the upcast
// compile-checked: it fails to build instead of silently reinterpreting the pointer
// if a layer type is ever not derived from Module.
Model VGG9 = {static_cast<Module*>(&net_0), static_cast<Module*>(&net_1),
              static_cast<Module*>(&net_2), static_cast<Module*>(&net_3),
              static_cast<Module*>(&net_4), static_cast<Module*>(&net_5),
              static_cast<Module*>(&net_6), static_cast<Module*>(&net_7),
              static_cast<Module*>(&sumP),  static_cast<Module*>(&fltn),
              static_cast<Module*>(&fc)};
// TODO: replace the "[Dataset Path]" placeholder with the local dataset root before running.
auto ds = CIFAR10DVS("[Dataset Path]", 1, false);
// NOTE(review): this allocation is never deleted in this file; whether DistributedReadOut
// takes ownership is not visible here — confirm before adding cleanup.
ReadOut * readout = new MembraneArgmax(10);

// Prefix of the exported weight files; main() appends "<param index>.weight.txt".
std::string model_path("../result_tmp/MTT_FINETUNE_SGD_60_MERGED_net.");

// Scale argument handed to load_weight for conv layers net_1..net_7.
double div_cnt = 1;
// Alternative readout (spike-count based):
// ReadOut * readout = new SpikeCountArgmax();

int main() {
    srand(233);
    net_0.load_weight(
        (model_path + "2.weight.txt").c_str(), 1.5);
    net_1.load_weight(
        (model_path + "5.weight.txt").c_str(), div_cnt);
    net_2.load_weight(
        (model_path + "10.weight.txt").c_str(), div_cnt);
    net_3.load_weight(
        (model_path + "13.weight.txt").c_str(), div_cnt);
    net_4.load_weight(
        (model_path + "18.weight.txt").c_str(), div_cnt);
    net_5.load_weight(
        (model_path + "21.weight.txt").c_str(), div_cnt);
    net_6.load_weight(
        (model_path + "26.weight.txt").c_str(), div_cnt);
    net_7.load_weight(
        (model_path + "29.weight.txt").c_str(), div_cnt);
    fc.load_weight(
        (model_path + "34.weight.txt").c_str(), 4);
    DistributedModel dist_net(VGG9, 1);
    DistributedReadOut dist_readout(readout, 1);
    auto st = cross_sys_clock_ms();
    test_distributed(dist_net, &ds, dist_readout, true, false);
    printf("%lld ms", cross_sys_clock_ms() - st);
    return 0;
}
