#include "utils.h"

/// Wall-clock time as whole milliseconds since the system_clock epoch,
/// truncated toward zero. system_clock may jump if the system time is
/// adjusted; use steady_clock when measuring intervals.
long long cross_sys_clock_ms() {
    namespace chr = std::chrono;
    const chr::system_clock::duration elapsed = chr::system_clock::now().time_since_epoch();
    return chr::duration_cast<chr::milliseconds>(elapsed).count();
}

/// Wall-clock time as whole microseconds since the system_clock epoch,
/// truncated toward zero. system_clock may jump if the system time is
/// adjusted; use steady_clock when measuring intervals.
long long cross_sys_clock_us() {
    namespace chr = std::chrono;
    const chr::system_clock::duration elapsed = chr::system_clock::now().time_since_epoch();
    return chr::duration_cast<chr::microseconds>(elapsed).count();
}

/// Wall-clock time as nanoseconds since the system_clock epoch. The result
/// resolution is whatever system_clock actually provides. system_clock may
/// jump if the system time is adjusted; use steady_clock for intervals.
long long cross_sys_clock_ns() {
    namespace chr = std::chrono;
    const chr::system_clock::duration elapsed = chr::system_clock::now().time_since_epoch();
    return chr::duration_cast<chr::nanoseconds>(elapsed).count();
}


/// Print an ASCII occupancy map of the events whose timestamp lies in
/// [t_min, t_max] (both inclusive).
///
/// First a ruler of out_W '-' characters is printed, then out_H rows of out_W
/// characters each: '#' where the number of counted events at (x, y) is at
/// least `threshold`, ' ' otherwise. Event::x selects the row and Event::y the
/// column.
///
/// @param x          events grouped into inner vectors (all groups are merged;
///                   presumably time bins — TODO confirm against callers)
/// @param out_H      number of rows; valid Event::x values are [0, out_H)
/// @param out_W      number of columns; valid Event::y values are [0, out_W)
/// @param t_min      inclusive lower bound on Event::ts
/// @param t_max      inclusive upper bound on Event::ts
/// @param mode       -1 counts every event; a value >= 0 counts only events
///                   whose Event::c equals mode; other negative values count none
/// @param threshold  minimum count for a cell to be drawn as '#'
void visualize_2d(const std::vector<std::vector<Event> > &x, int out_H, int out_W, int t_min, int t_max, int mode, int threshold) {
    // Heap-backed grid instead of a VLA: VLAs are not standard C++ (clang
    // rejects initializing one) and a large grid could overflow the stack.
    const int rows = out_H > 0 ? out_H : 0;
    const int cols = out_W > 0 ? out_W : 0;
    std::vector<std::vector<int> > buf(rows, std::vector<int>(cols, 0));

    for (int i = 1; i <= out_W; ++i)
        putchar('-');
    puts("");

    // Bind by const reference: the old `auto` loops copied every inner vector
    // and every Event.
    for (const auto &group : x) {
        for (const auto &ev : group) {
            if (ev.ts < t_min || ev.ts > t_max)
                continue;

            assert (ev.x < out_H && ev.y < out_W);
            assert (ev.x >= 0 && ev.y >= 0);
            // The asserts vanish under NDEBUG; never write outside the grid.
            if (ev.x < 0 || ev.x >= rows || ev.y < 0 || ev.y >= cols)
                continue;

            if (mode == -1 || (mode >= 0 && ev.c == mode))
                buf[ev.x][ev.y]++;
        }
    }

    for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < cols; ++j)
            putchar(buf[i][j] >= threshold ? '#' : ' ');
        puts("");
    }
}

/// Push an event stream through every module of `net`, first to last.
/// Each module's forward() output becomes the next module's input; the
/// output of the final module is returned (the input itself if `net` is empty).
std::vector<Event> run_net(std::vector<Module*> & net, std::vector<Event> x) {
    for (Module *layer : net)
        x = layer->forward(x);
    return x;
}

/// Binned-input overload: push grouped events through every module of `net`,
/// first to last, via each module's binned forward(). Returns the output of the
/// final module (the input itself if `net` is empty).
std::vector<std::vector<Event> > run_net(std::vector<Module*> & net, std::vector<std::vector<Event> > binned_x) {
    for (Module *layer : net)
        binned_x = layer->forward(binned_x);
    return binned_x;
}

/// Evaluate `net` on every sample of `ds` and return the classification accuracy.
///
/// For each sample i the network output is turned into a class with
/// readout_func->to_class(net, spikes) and compared with ds->get_label(i).
///
/// @param net           modules to run, in order
/// @param ds            dataset providing get_binned_input(i) and get_label(i)
/// @param readout_func  maps the network's output events to a class id
/// @param verbose       print a header and a progress bar to stdout
/// @param binned        true: run the binned run_net overload on the binned
///                      input and unbin the result; false: unbin the input first
///                      and run the event-level overload
/// @return fraction of correctly classified samples in [0, 1]; NaN when the
///         dataset is empty (0 / 0)
double test(std::vector<Module*> & net, Dataset<int> * ds, ReadOut * readout_func, bool verbose, bool binned) {
    const int N = ds->size();
    const int total_bar_len = std::min(N, 20);
    int cor = 0;

    if (verbose) {
        printf("Testing %d pts per thread\n", N);
        for (int _ = 1; _ <= total_bar_len; ++_)
            putchar('-');
        puts("");
    }
    for (int i = 0; i < N; ++i) {
        if (verbose) {
            // Emit a '#' whenever floor(i * total_bar_len / N) advances (plus one
            // at i == 0). That yields exactly total_bar_len marks under the ruler.
            // The former step `i / (N / total_bar_len)` printed up to ~2x too many
            // when N was not a multiple of total_bar_len (N = 39 -> 39 marks).
            // 64-bit products avoid overflow for very large N.
            const long long cur = static_cast<long long>(i) * total_bar_len / N;
            const long long prev = static_cast<long long>(i - 1) * total_bar_len / N;
            if (i == 0 || cur > prev) {
                printf("#");
                fflush(stdout); // for distributed output
            }
        }

        std::vector<Event> spike_out;
        if (binned)
            spike_out = unbin_events(run_net(net, ds->get_binned_input(i)));
        else
            spike_out = run_net(net, unbin_events(ds->get_binned_input(i)));

        const int ret = readout_func->to_class(net, spike_out);
        if (ds->get_label(i) == ret) ++cor;
    }
    const double acc = 1.0 * cor / N;
    if (verbose) {
        printf("Done\n");
        puts("--------------------");
    }
    return acc;
}