#ifndef UTILS_H
#define UTILS_H

#include <cstdint>
#include <exception>
#include <filesystem>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

/// @brief Load a numeric CSV file into a row-major matrix of floats.
///
/// Each non-empty line becomes one row and each comma-separated cell is
/// parsed with std::stof. A trailing '\r' (CRLF line endings) is stripped
/// before parsing, and lines that are empty after stripping are skipped.
/// Rows are returned as read, so ragged input yields ragged output.
///
/// @param filename Path to the CSV file.
/// @return One vector<float> per non-empty line, in file order.
/// @throws std::runtime_error if the file cannot be opened.
/// @throws std::invalid_argument / std::out_of_range (from std::stof) if a
///         cell is not a parseable float, including an empty cell.
inline std::vector<std::vector<float>> load_data_csv(const std::string &filename) {
    std::ifstream file(filename);
    if (!file) {
        throw std::runtime_error("Could not open CSV file: " + filename);
    }

    std::vector<std::vector<float>> data;
    std::string line;
    std::string value;

    while (std::getline(file, line)) {
        // Tolerate files written with Windows line endings.
        if (!line.empty() && line.back() == '\r') {
            line.pop_back();
        }
        // A blank line (e.g. a trailing empty line) has no cells to parse.
        if (line.empty()) {
            continue;
        }

        std::vector<float> row;
        std::stringstream ss(line);
        while (std::getline(ss, value, ',')) {
            row.push_back(std::stof(value));
        }
        data.push_back(std::move(row));
    }

    return data;
}

/// @brief Load one key string per line from a text file.
///
/// Every line, including empty ones, becomes one entry, so the i-th key is
/// the i-th line of the file. A trailing '\r' from CRLF line endings is
/// stripped so keys compare equal regardless of which platform wrote them.
///
/// @param filename Path to the key file.
/// @return The keys, in file order.
/// @throws std::runtime_error if the file cannot be opened.
inline std::vector<std::string> load_key_csv(const std::string &filename) {
    std::ifstream file(filename);
    if (!file) {
        throw std::runtime_error("Could not open CSV file: " + filename);
    }

    std::vector<std::string> data;
    std::string line;
    while (std::getline(file, line)) {
        if (!line.empty() && line.back() == '\r') {
            line.pop_back();
        }
        data.push_back(line);
    }

    return data;
}

/// @brief Write a rectangular float matrix to a binary file.
///
/// Layout (native endianness, native size_t width): size_t rows,
/// size_t cols, then rows * cols floats in row-major order. An empty matrix
/// is written as rows = cols = 0.
///
/// @param filename Destination path; an existing file is overwritten.
/// @param data     Matrix to write; every row must have the same length.
/// @throws std::runtime_error if rows differ in length (checked before the
///         file is opened, so an existing file is left untouched), if the
///         file cannot be opened, or if writing fails.
inline void save_data_bin(const std::string &filename, const std::vector<std::vector<float>> &data) {
    const size_t rows = data.size();
    // Indexing data[0] on an empty vector would be undefined behavior.
    const size_t cols = data.empty() ? 0 : data[0].size();

    // The header stores a single column count, so ragged rows would produce
    // a file whose payload does not match its header.
    for (const auto &row : data) {
        if (row.size() != cols) {
            throw std::runtime_error("save_data_bin: all rows must have the same length");
        }
    }

    std::ofstream file(filename, std::ios::binary);
    if (!file) {
        throw std::runtime_error("Could not open file for writing: " + filename);
    }

    // Write the dimensions first
    file.write(reinterpret_cast<const char*>(&rows), sizeof(rows));
    file.write(reinterpret_cast<const char*>(&cols), sizeof(cols));

    // Write the data row by row
    for (const auto &row : data) {
        file.write(reinterpret_cast<const char*>(row.data()),
                   static_cast<std::streamsize>(row.size() * sizeof(float)));
    }

    file.flush();
    if (!file) {
        throw std::runtime_error("Error writing file: " + filename);
    }
}

/// @brief Write a list of strings to a binary file.
///
/// Layout (native endianness, native size_t width): a size_t count, then
/// for each string a size_t byte length followed by that many raw bytes.
///
/// @param filename Destination path; an existing file is overwritten.
/// @param data     Strings to write, in order.
/// @throws std::runtime_error if the file cannot be opened or writing fails.
inline void save_key_bin(const std::string &filename, const std::vector<std::string> &data) {
    std::ofstream file(filename, std::ios::binary);
    if (!file) {
        throw std::runtime_error("Could not open file for writing: " + filename);
    }

    // Write the number of strings
    const size_t num_strings = data.size();
    file.write(reinterpret_cast<const char*>(&num_strings), sizeof(num_strings));

    for (const auto &str : data) {
        // Length prefix, then the raw bytes (no terminator).
        const size_t length = str.size();
        file.write(reinterpret_cast<const char*>(&length), sizeof(length));
        file.write(str.data(), static_cast<std::streamsize>(length));
    }

    file.flush();
    if (!file) {
        throw std::runtime_error("Error writing file: " + filename);
    }
}

/// @brief Read a float matrix written by save_data_bin.
///
/// Expected layout (native endianness, native size_t width): size_t rows,
/// size_t cols, then rows * cols floats in row-major order.
///
/// @param filename Path to the binary data file.
/// @return A rows x cols matrix.
/// @throws std::runtime_error if the file cannot be opened, is truncated, or
///         declares more data than the file contains.
inline std::vector<std::vector<float>> load_data_bin(const std::string &filename) {
    std::ifstream file(filename, std::ios::binary);
    if (!file) {
        throw std::runtime_error("Could not open file for reading: " + filename);
    }

    // Zero-initialized so a failed read can never leave garbage dimensions.
    size_t rows = 0;
    size_t cols = 0;
    file.read(reinterpret_cast<char*>(&rows), sizeof(rows));
    file.read(reinterpret_cast<char*>(&cols), sizeof(cols));
    if (!file) {
        throw std::runtime_error("Truncated header in binary data file: " + filename);
    }

    // Check the declared dimensions against the on-disk size before
    // allocating, so a corrupt or foreign file fails with a clear error
    // instead of an enormous allocation. If the size cannot be determined,
    // fall back on the per-row read checks below.
    std::error_code ec;
    const std::uintmax_t file_bytes = std::filesystem::file_size(filename, ec);
    const std::uintmax_t header_bytes = sizeof(rows) + sizeof(cols);
    if (!ec && cols != 0 && file_bytes >= header_bytes) {
        const std::uintmax_t payload = file_bytes - header_bytes;
        // Test cols first so cols * sizeof(float) below cannot overflow.
        if (cols > payload / sizeof(float) ||
            rows > payload / (static_cast<std::uintmax_t>(cols) * sizeof(float))) {
            throw std::runtime_error("Header of binary data file exceeds its size: " + filename);
        }
    }

    std::vector<std::vector<float>> data(rows, std::vector<float>(cols));
    if (cols == 0) {
        return data;  // zero-width rows: there is no float payload to read
    }

    // Read the data row by row
    const auto row_bytes = static_cast<std::streamsize>(cols * sizeof(float));
    for (auto &row : data) {
        if (!file.read(reinterpret_cast<char*>(row.data()), row_bytes)) {
            throw std::runtime_error("Truncated data in binary data file: " + filename);
        }
    }

    return data;
}

/// @brief Read a list of strings written by save_key_bin.
///
/// Expected layout (native endianness, native size_t width): a size_t count,
/// then for each string a size_t byte length followed by that many raw bytes.
///
/// @param filename Path to the binary key file.
/// @return The stored strings, in file order.
/// @throws std::runtime_error if the file cannot be opened, is truncated, or
///         declares more or longer strings than the file can contain.
inline std::vector<std::string> load_key_bin(const std::string &filename) {
    std::ifstream file(filename, std::ios::binary);
    if (!file) {
        throw std::runtime_error("Could not open file for reading: " + filename);
    }

    // Zero-initialized so a failed read can never leave a garbage count.
    size_t num_strings = 0;
    if (!file.read(reinterpret_cast<char*>(&num_strings), sizeof(num_strings))) {
        throw std::runtime_error("Truncated header in binary key file: " + filename);
    }

    // Sanity-check declared sizes against the on-disk size before
    // allocating, so a corrupt file fails cleanly instead of exhausting
    // memory. If the size is unavailable, rely on the per-read checks below.
    std::error_code ec;
    const std::uintmax_t file_bytes = std::filesystem::file_size(filename, ec);
    const bool size_known = !ec && file_bytes >= sizeof(num_strings);
    if (size_known) {
        // Every entry needs at least its size_t length prefix.
        const std::uintmax_t payload = file_bytes - sizeof(num_strings);
        if (num_strings > payload / sizeof(size_t)) {
            throw std::runtime_error("String count exceeds size of binary key file: " + filename);
        }
    }

    std::vector<std::string> data(num_strings);
    for (auto &str : data) {
        size_t length = 0;
        if (!file.read(reinterpret_cast<char*>(&length), sizeof(length))) {
            throw std::runtime_error("Truncated binary key file: " + filename);
        }
        if (size_known && length > file_bytes) {
            throw std::runtime_error("String length exceeds size of binary key file: " + filename);
        }

        str.resize(length);
        if (length > 0 && !file.read(&str[0], static_cast<std::streamsize>(length))) {
            throw std::runtime_error("Truncated binary key file: " + filename);
        }
    }

    return data;
}

/// @brief Load a float matrix, choosing the format by file extension.
///
/// Paths ending in ".csv" are parsed with load_data_csv; every other path
/// (including names shorter than four characters) is read with load_data_bin.
inline std::vector<std::vector<float>> load_data(const std::string &filename) {
    // Check the length first: filename.size() - 4 underflows for names
    // shorter than four characters, which made substr throw out_of_range.
    const bool is_csv = filename.size() >= 4 &&
                        filename.compare(filename.size() - 4, 4, ".csv") == 0;
    return is_csv ? load_data_csv(filename) : load_data_bin(filename);
}

/// @brief Load a list of keys, choosing the format by file extension.
///
/// Paths ending in ".csv" are read with load_key_csv; every other path
/// (including names shorter than four characters) is read with load_key_bin.
inline std::vector<std::string> load_key(const std::string &filename) {
    // Check the length first: filename.size() - 4 underflows for names
    // shorter than four characters, which made substr throw out_of_range.
    const bool is_csv = filename.size() >= 4 &&
                        filename.compare(filename.size() - 4, 4, ".csv") == 0;
    return is_csv ? load_key_csv(filename) : load_key_bin(filename);
}

/// @brief Fraction of predictions that match binary labels at a 0.5 threshold.
///
/// A sample counts as correct when its prediction is >= 0.5 and its label is
/// exactly 1, or its prediction is < 0.5 and its label is exactly 0. Labels
/// with any other value, and NaN predictions, are never counted as correct.
///
/// @param predictions Scores, one per sample.
/// @param labels      Ground-truth labels, one per sample.
/// @return correct / labels.size(), or 0 for empty input.
/// @throws std::runtime_error if predictions and labels differ in size.
inline float calculate_accuracy(const std::vector<float>& predictions, const std::vector<float>& labels) {
    // Previously a longer predictions vector read labels out of bounds.
    if (predictions.size() != labels.size()) {
        throw std::runtime_error("Predictions and labels must have the same size.");
    }
    // Avoid 0/0 (NaN) on empty input; matches the other metric helpers.
    if (labels.empty()) {
        return 0.0f;
    }

    size_t correct = 0;
    for (size_t i = 0; i < predictions.size(); ++i) {
        // Two explicit comparisons (not a negated flag) so NaN stays incorrect.
        if ((predictions[i] >= 0.5f && labels[i] == 1.0f) ||
            (predictions[i] < 0.5f && labels[i] == 0.0f)) {
            ++correct;
        }
    }
    return static_cast<float>(correct) / static_cast<float>(labels.size());
}

/// @brief False positive rate: FP / (FP + TN).
///
/// Only samples whose first label column is exactly 0 (negatives) are
/// counted; a negative predicted true is a false positive, predicted false
/// is a true negative.
///
/// @param predictions Predicted class per sample.
/// @param labels       Per-sample label rows; element [0] is the label.
/// @return The rate, or 0 if the input is empty or has no negatives.
/// @throws std::runtime_error if sizes differ or a label row is empty.
inline float calculate_false_positive_rate(const std::vector<bool>& predictions, const std::vector<std::vector<float>>& labels) {
    if (predictions.size() != labels.size()) {
        throw std::runtime_error("Predictions and labels must have the same size.");
    }
    if (predictions.empty()) {
        return 0.0f;
    }

    int false_positives = 0;
    int true_negatives = 0;
    for (size_t i = 0; i < predictions.size(); ++i) {
        // labels[i][0] on an empty row would be undefined behavior.
        if (labels[i].empty()) {
            throw std::runtime_error("Label rows must not be empty.");
        }
        if (labels[i][0] == 0) {
            if (predictions[i]) {
                false_positives++;
            } else {
                true_negatives++;
            }
        }
    }

    if (false_positives + true_negatives == 0) {
        return 0.0f;
    }
    return static_cast<float>(false_positives) / (false_positives + true_negatives);
}

/// @brief False negative rate: FN / (FN + TP).
///
/// Only samples whose first label column is exactly 1 (positives) are
/// counted; a positive predicted false is a false negative, predicted true
/// is a true positive.
///
/// @param predictions Predicted class per sample.
/// @param labels       Per-sample label rows; element [0] is the label.
/// @return The rate, or 0 if the input is empty or has no positives.
/// @throws std::runtime_error if sizes differ or a label row is empty.
inline float calculate_false_negative_rate(const std::vector<bool>& predictions, const std::vector<std::vector<float>>& labels) {
    if (predictions.size() != labels.size()) {
        throw std::runtime_error("Predictions and labels must have the same size.");
    }
    if (predictions.empty()) {
        return 0.0f;
    }

    int false_negatives = 0;
    int true_positives = 0;
    for (size_t i = 0; i < predictions.size(); ++i) {
        // labels[i][0] on an empty row would be undefined behavior.
        if (labels[i].empty()) {
            throw std::runtime_error("Label rows must not be empty.");
        }
        if (labels[i][0] == 1) {
            if (predictions[i]) {
                true_positives++;
            } else {
                false_negatives++;
            }
        }
    }

    if (false_negatives + true_positives == 0) {
        return 0.0f;
    }
    return static_cast<float>(false_negatives) / (false_negatives + true_positives);
}

/// @brief Size of a file in kibibytes (bytes / 1024).
///
/// Works for any file std::filesystem::file_size accepts, not only binary
/// data files.
///
/// @param filePath Path to the file.
/// @return File size in KiB.
/// @throws std::runtime_error if the size cannot be determined (e.g. the
///         file does not exist).
inline float getBinFileSize(const std::string& filePath) {
    std::error_code ec;
    const auto fileSize = std::filesystem::file_size(filePath, ec);
    if (ec) {
        throw std::runtime_error("Error getting file size: " + ec.message());
    }
    // Convert size to KB
    return static_cast<float>(fileSize) / 1024.0f;
}

#endif // UTILS_H
