#include <cupti.h>
#include <cuda.h>
#include <stdio.h>
#include <stdlib.h>
#include <mutex>
#include <string.h>
#include <stdint.h>
#include <vector>
#include <string>
#include <cstring>
#include <cxxabi.h>
#include <unistd.h>
static bool tracingEnabled = false;

// Runtime-API callback names (as returned by cuptiGetCallbackName) that are
// written to the trace; a name is kept when it starts with one of these.
//
// Deliberately a constant-initialized array of string literals rather than a
// std::vector<std::string>: it has no destructor, so it stays valid when the
// destructor hook flushes CUPTI buffers late in process teardown. That can
// happen after this library's C++ static destructors have already run (e.g.
// when the library is dlopen'ed as an injection library).
static const char* const allowed_prefixes[] = {
    "cudaLaunchKernel",
    "cudaGraphLaunch",
    "cudaMemcpy",
    "cudaMemset",
    "cudaMalloc",
    "cudaFree",
    "cudaIpc",
    "cudaHostAlloc",
    "cudaDeviceSynchronize",
    "cudaStreamSynchronize",
    "cudaStreamWaitEvent"
};

/**
 * Returns true if `funcName` begins with any entry of `allowed_prefixes`.
 * A null `funcName` never matches.
 */
bool has_allowed_prefix(const char* funcName) {
    if (!funcName) return false;

    for (const char* prefix : allowed_prefixes) {
        if (strncmp(funcName, prefix, strlen(prefix)) == 0) {
            return true;
        }
    }
    return false;
}


// Writes `json_str` followed by a newline to `fp` (one JSONL record per call)
// and flushes the stream right away so the line is handed to the OS even if
// the process later terminates abruptly. A null `fp` makes this a no-op.
void write_json(FILE* fp, const char* json_str) {
    if (!fp) return;  // log file was never opened (or failed to open)

    fprintf(fp, "%s\n", json_str);  // single call keeps the line + newline together
    fflush(fp);
}

// Demangles an Itanium-ABI C++ symbol name with abi::__cxa_demangle.
// Returns the demangled form on success, or `name` unchanged when it is not a
// valid mangled name (e.g. an extern "C" symbol). A null `name` yields an
// empty string: constructing std::string from nullptr is undefined behavior.
static std::string demangle(const char* name) {
    if (name == nullptr) return std::string();

    int status = 0;
    char* demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status);
    std::string result = (status == 0 && demangled) ? demangled : name;
    free(demangled);  // __cxa_demangle mallocs its result; free(nullptr) is a no-op
    return result;
}


// File-scope logger state.
static int g_device_id{-1};   // NOTE(review): not referenced by the visible code
static FILE* g_fp{nullptr};   // JSONL output stream; null until lazily opened
static std::mutex g_mutex;    // serializes the lazy open of g_fp

#include <sys/stat.h>  
#include <errno.h>      
// Lazily opens the per-process log "output/output_pid<PID>.tmp.jsonl" in
// append mode, creating the "output" directory first if needed, and stores the
// stream in g_fp. Serialized by g_mutex so concurrent buffer callbacks open the
// file at most once.
//
// If creating the directory or opening the file fails, the error is printed
// once and later calls return immediately. Without this, every completed
// activity buffer would retry and re-print the same error.
static void open_output_file_if_needed() {
    std::lock_guard<std::mutex> lock(g_mutex);
    static bool open_failed = false;  // constant-initialized; guarded by g_mutex
    if (g_fp || open_failed) return;

    const char* output_dir = "output";
    if (mkdir(output_dir, 0777) == -1 && errno != EEXIST) {
        const int err = errno;
        fprintf(stderr, "❌ Failed to create directory %s (%s)\n", 
                output_dir, strerror(err));
        open_failed = true;
        return;
    }

    char filename[256];
    snprintf(filename, sizeof(filename), 
             "%s/output_pid%d.tmp.jsonl",  
             output_dir, getpid());

    g_fp = fopen(filename, "a");
    if (!g_fp) {
        const int err = errno;
        fprintf(stderr, "❌ Failed to open log file %s (%s)\n", 
                filename, strerror(err));
        open_failed = true;
    } else {
        fprintf(stderr, "✅ Logging to %s\n", filename);
    }
}



// CUPTI "buffer requested" callback: hands CUPTI a fresh 16 KiB heap buffer to
// fill with activity records. The buffer is released with free() in the
// matching buffer-completed callback.
static void CUPTIAPI activityBufferRequested(uint8_t **buffer, size_t *size, size_t *maxNumRecords) {
    const size_t kBufferBytes = 16 * 1024;

    *size = kBufferBytes;
    *buffer = static_cast<uint8_t *>(malloc(kBufferBytes));
    *maxNumRecords = 0;  // no explicit record-count limit (see cuptiActivityRegisterCallbacks docs)
}



void CUPTIAPI activityBufferCompleted(CUcontext ctx,
                                      uint32_t streamId,
                                      uint8_t* buffer,
                                      size_t size,
                                      size_t validSize) {
    CUpti_Activity* record = NULL;
    CUptiResult status;

    open_output_file_if_needed(); 
    while ((status = cuptiActivityGetNextRecord(buffer, validSize, &record)) == CUPTI_SUCCESS) {
        char json[4096];  
        switch (record->kind) {
            case CUPTI_ACTIVITY_KIND_MARKER: {
                CUpti_ActivityMarker2* marker = (CUpti_ActivityMarker2*)record;

                uint32_t processId = 0;
                uint32_t threadId = 0;


                processId = marker->objectId.pt.processId;
                threadId = marker->objectId.pt.threadId;

                const char* name = marker->name ? marker->name : "null";

                snprintf(json, sizeof(json),
                    "{ \"type\": \"NVTX_MARKER\", \"name\": \"%s\", \"timestamp\": %lu, \"id\": %u, \"process_id\": %u, \"thread_id\": %u }",
                    name, marker->timestamp, marker->id, processId, threadId);

                write_json(g_fp, json);
                break;
            }



            case CUPTI_ACTIVITY_KIND_KERNEL: {
                CUpti_ActivityKernel9* kernel = (CUpti_ActivityKernel9*)record;

                // std::string demangledName = demangle(kernel->name);
                std::string demangledName = kernel->name;

                snprintf(json, sizeof(json),
                    "{ \"type\": \"KERNEL\", \"name\": \"%s\", \"gpu_start\": %lu, \"gpu_end\": %lu, \"duration\": %lu, \"correlation_id\": %u }",
                    demangledName.c_str(),
                    kernel->start,
                    kernel->end,
                    kernel->end - kernel->start,
                    kernel->correlationId);

                write_json(g_fp, json);
                break;
            }

            case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: {
                CUpti_ActivityKernel9* kernel = (CUpti_ActivityKernel9*)record;

                // std::string demangledName = demangle(kernel->name);
                std::string demangledName = kernel->name;

                snprintf(json, sizeof(json),
                    "{ \"type\": \"CONCURRENT_KERNEL\", \"name\": \"%s\", \"gpu_start\": %lu, \"gpu_end\": %lu, \"duration\": %lu, \"correlation_id\": %u }",
                    demangledName.c_str(),
                    kernel->start,
                    kernel->end,
                    kernel->end - kernel->start,
                    kernel->correlationId);
                write_json(g_fp, json);
                break;
            }


            
            case CUPTI_ACTIVITY_KIND_RUNTIME: {
                CUpti_ActivityAPI* runtime = (CUpti_ActivityAPI*)record;

                const char* funcName = nullptr;
                CUptiResult res = cuptiGetCallbackName(CUPTI_CB_DOMAIN_RUNTIME_API, runtime->cbid, &funcName);

                if (res != CUPTI_SUCCESS || funcName == nullptr) {
                    funcName = "Unknown";
                }

                if (has_allowed_prefix(funcName)) {
                    snprintf(json, sizeof(json),
                        "{ \"type\": \"RUNTIME\", \"cbid\": %u, \"name\": \"%s\", \"start\": %lu, \"end\": %lu, \"duration\": %lu, \"correlation_id\": %u, \"process_id\": %u, \"thread_id\": %u }",
                        runtime->cbid, funcName,
                        runtime->start, runtime->end, runtime->end - runtime->start,
                        runtime->correlationId,
                        runtime->processId,
                        runtime->threadId);
                    write_json(g_fp, json);
                }

                break;
            }

            case CUPTI_ACTIVITY_KIND_DRIVER: {
                CUpti_ActivityAPI* driver = (CUpti_ActivityAPI*)record;

                const char* funcName = nullptr;
                CUptiResult res = cuptiGetCallbackName(CUPTI_CB_DOMAIN_DRIVER_API, driver->cbid, &funcName);

                if (res != CUPTI_SUCCESS || funcName == nullptr) {
                    funcName = "Unknown";
                }

                if (strncmp(funcName, "cuLaunchKernel", strlen("cuLaunchKernel")) == 0) {
                    snprintf(json, sizeof(json),
                        "{ \"type\": \"DRIVER\", \"cbid\": %u, \"name\": \"%s\", \"start\": %lu, \"end\": %lu, \"duration\": %lu, \"correlation_id\": %u, \"process_id\": %u, \"thread_id\": %u }",
                        driver->cbid, funcName,
                        driver->start, driver->end, driver->end - driver->start,
                        driver->correlationId,
                        driver->processId,
                        driver->threadId);
                    write_json(g_fp, json);
                }

                break;
            }


            default:
                break;
        }
    }

    if (status != CUPTI_SUCCESS && status != CUPTI_ERROR_MAX_LIMIT_REACHED) {
        const char* errstr;
        cuptiGetResultString(status, &errstr);
        fprintf(stderr, "❌ CUPTI error: %s\n", errstr);
    }

    // fclose(g_fp);
    free(buffer);
}




extern "C" int InitializeInjection() {

    // cuInit(0); 

    if (tracingEnabled) return 1;


    //Enable activity kinds (kernel + memcpy)
    cuptiActivityEnable(CUPTI_ACTIVITY_KIND_KERNEL);
    cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL);
    //cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY);
    cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME);
    cuptiActivityEnable(CUPTI_ACTIVITY_KIND_DRIVER);

    // Enable NVTX
    cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MARKER);
    //cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MARKER_DATA);

    // Register buffer management
    cuptiActivityRegisterCallbacks(activityBufferRequested, activityBufferCompleted);
    
    tracingEnabled = true;
    return 1;
}


// Load-time hook: runs when this shared object is loaded and starts tracing
// immediately via InitializeInjection(). A later explicit call to
// InitializeInjection() returns early because tracingEnabled is then set.
__attribute__((constructor))
static void init() {
    (void)InitializeInjection();  // result is always 1; nothing to handle
}


__attribute__((destructor))
static void fini() {
    if (tracingEnabled) {
        cuptiActivityFlushAll(0);
        if (g_fp) fclose(g_fp);
    }
}
