/*
  Name:     main.c
  Purpose:  Main function for benchmarking LUT, ALDR and other methods.
  Authors:  CMU Probabilistic Computing Systems Lab
            And An Anonymous Lab
  Copyright (C) 2025 CMU Probabilistic Computing Systems Lab, All Rights Reserved.

  Released under Apache 2.0; refer to LICENSE.txt
*/

#include <inttypes.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <getopt.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include "rapl.h"
#include "dist_loader.h"
#include "algorithms/numpy_sampling.h"
#include "algorithms/nLUT.h"
#include "algorithms/cLUT.h"

#include "macros.h"

// Includes from amplified-loaded-dice-roller[-experiments]
#include <alias.h>
#include <flip.h>
#include <aldr.h>

/*
 * No-op post-experiment hook for samplers with nothing extra to record.
 * The argument is ignored and the result is always 0.
 */
static int _dummy_func(void *ignored)
{
    return ((void)ignored, 0);
}

/*
 * Values reported in the cLUT_r / cLUT_c columns of the result line.
 * Filled from a clut_context_t by save_cLUT_log_context().
 */
struct cLUT_log_context
{
    size_t r; /* copy of clut_context_t::r */
    size_t c; /* copy of clut_context_t::c */
};

static struct cLUT_log_context save_cLUT_log_context(clut_context_t *ctx)
{
    return (struct cLUT_log_context) {
        .r = ctx->r,
        .c = ctx->c,
    };
}

int main(int argc, char *argv[]) {
    bool print_headers = false;
    bool memory_flag = false;
    char *sampler = NULL;
    char *dist_file = NULL;
    distribution_t *distribution = NULL;
    
    static struct option long_options[] = {
        {"headers", no_argument, 0, 'h'},
        {"memory",  no_argument, 0, 'm'},
        {"help",    no_argument, 0, '?'},
        {0, 0, 0, 0}
    };
    const char *usage_str = "Usage: %s [--headers] [--memory] | [SAMPLER DIST_FILE]\n";
    const char *get_opt_str = "hm";
    
    int option_index = 0;
    int c;
    
    while ((c = getopt_long(argc, argv, get_opt_str, long_options, &option_index)) != -1) {
        switch (c) {
            case 'h':
                print_headers = true;
                break;
            case 'm':
                memory_flag = true;
                break;
            default:
                fprintf(stderr, usage_str, argv[0]);
                return 1;
        }
    }
    
    if (print_headers) {
        if (optind < argc) {
            fprintf(stderr, "Error: --headers option doesn't take additional arguments\n");
            return 1;
        }
    } else {
        if (argc - optind != 2) {
            fprintf(stderr, usage_str, argv[0]);
            return 1;
        }
        
        sampler = argv[optind];
        dist_file = argv[optind + 1];
        
        distribution = load_distribution(dist_file);
        if (!distribution) {
            fprintf(stderr, "Cannot open distribution file '%s'.\n", dist_file);
            return 2;
        }
    }

    if (!memory_flag) {

        if (rapl_setup() < 0) {
            fprintf(stderr, "\tNo perf_event rapl support found (requires Linux 3.14 + sudo or low paranoid-level)\n");
            return 1;
        }

        if (print_headers) {
            printf("sampler dist_file x preprocess_time_cold preprocess_time_warm sample_time sample_bits preprocess_bytes cLUT_r cLUT_c ");
            rapl_start(); // Only opens accessible counters. Will be closed on exit
            rapl_print_domains(stdout, "preprocess_cold");
            rapl_print_domains(stdout, "preprocess_warm");
            rapl_print_domains(stdout, "sample");
            printf("\n");
            return 0;
        }

    }

    int num_samples = 1000000;
    int num_preprocess_warm = 10;

    unsigned long long preprocess_time_cold = 0;
    unsigned long long preprocess_time_warm = 0;
    unsigned long long sample_time = 0;
    size_t preprocess_bytes = 0;
    size_t energy_entries = 0;
    rapl_dst_t *preprocess_energy_cold = NULL;
    rapl_dst_t *preprocess_energy_warm = NULL;
    rapl_dst_t *sample_energy = NULL;

    uint64_t x = 0;
    int _dummy_dst;
    struct cLUT_log_context cLUT_log;

    #define SAMPLE_EXPERIMENT(key, context_type, func_preprocess, post_experiment_action, post_experiment_dst, func_sample, func_free) \
        READ_PREPROCESS_SAMPLE_TIME(key, \
            sampler, \
            context_type, \
            func_preprocess, \
            post_experiment_action, \
            func_sample, \
            func_free, \
            memory_flag, \
            distribution->array, \
            distribution->n, \
            num_samples, \
            num_preprocess_warm, \
            preprocess_time_cold, \
            preprocess_time_warm, \
            post_experiment_dst, \
            sample_time, \
            preprocess_energy_cold, \
            preprocess_energy_warm, \
            sample_energy, \
            energy_entries, \
            x)

    SAMPLE_EXPERIMENT("aldr.flat",
        struct aldr_flat_s,
        preprocess_aldr_flat,
        bytes_sample_aldr_flat,
        preprocess_bytes,
        sample_aldr_flat,
        free_aldr_flat_s)
    else SAMPLE_EXPERIMENT("fldr.flat",
        struct aldr_flat_s,
        preprocess_fldr_flat,
        bytes_sample_aldr_flat,
        preprocess_bytes,
        sample_aldr_flat,
        free_aldr_flat_s)
    else SAMPLE_EXPERIMENT("aldr.enc",
        struct array_s,
        preprocess_aldr_enc,
        bytes_array,
        preprocess_bytes,
        sample_aldr_enc,
        free_array_s)
    else SAMPLE_EXPERIMENT("fldr.enc",
        struct array_s,
        preprocess_fldr_enc,
        bytes_array,
        preprocess_bytes,
        sample_aldr_enc,
        free_array_s)
    else SAMPLE_EXPERIMENT("alias.c",
        struct sample_weighted_alias_index_s,
        preprocess_weighted_alias,
        bytes_sample_weighted_alias_index,
        preprocess_bytes,
        sample_weighted_alias_index,
        free_sample_weighted_alias_index)
    else SAMPLE_EXPERIMENT("cLUT",
        clut_context_t,
        build_cLUT,
        save_cLUT_log_context,
        cLUT_log,
        cLUT_sampling,
        free_cLUT)
    else SAMPLE_EXPERIMENT("nLUT",
        nlut_context_t,
        build_nLUT,
        _dummy_func,
        _dummy_dst,
        nLUT_sampling,
        free_nLUT)
    else SAMPLE_EXPERIMENT("numpy",
        numpy_context_t,
        build_numpy,
        _dummy_func,
        _dummy_dst,
        numpy_sampling,
        free_numpy)
    else {
        printf("Unknown sampler %s\n", sampler);
        return -1;
    }
    (void)_dummy_dst;

    if (memory_flag) {
        // No further printing
        return 0;
    }

    double d_preprocess_time_cold = ((double)preprocess_time_cold) / 1e9;
    double d_preprocess_time_warm = ((double)preprocess_time_warm) / 1e9 / num_preprocess_warm;
    double d_sample_time = ((double)sample_time) / 1e9 / num_samples;
    double d_sample_bits = ((double)(NUM_RNG_CALLS*flip_k-flip_pos)) / num_samples;

    
    printf("%s %s %"PRIu64" %1.12f %1.12f %1.15f %1.8f %zu",
            sampler,
            dist_file,
            x,
            d_preprocess_time_cold,
            d_preprocess_time_warm,
            d_sample_time,
            d_sample_bits,
            preprocess_bytes);

    if (strcmp(sampler, "cLUT") == 0) {
        printf(" %zu %zu", cLUT_log.r, cLUT_log.c);
    } else {
        printf(" - -");
    }

    rapl_dst_t *energy_readings[] = {preprocess_energy_cold, preprocess_energy_warm, sample_energy, NULL};

    for (rapl_dst_t **r = energy_readings; *r != NULL; r++) {
        // FIXME: assumes energy_entries was the same for every run.
        for (size_t i = 0; i < energy_entries; i++) {
            printf(" %" RAPL_DST_FMT, (*r)[i]);
        }
        free(*r);
    }

    printf("\n");
    free_distribution(distribution);

    return 0;
}
