import argparse
import os
import sys
import time
import pickle
import json
import random
import numpy as np
import csv
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from thirdparty.PHBF.filters import HPBF

def load_key(file_path):
    '''
    Read a key file and return its lines as a list of strings.

    Python counterpart of load_key_csv in "../../src/utils/utils.h".
    Line terminators are stripped; every other character of a line is kept.
    '''
    with open(file_path, "r") as key_file:
        contents = key_file.read()
    # str.splitlines drops the terminators and does not produce a trailing
    # empty entry when the file ends with a newline.
    return contents.splitlines()

def load_data(file_path, type="float"):
    '''
    Load a numeric CSV file into a 2-D NumPy array.

    Python counterpart of load_data_csv in "../../src/utils/utils.h".

    Args:
        file_path: Path of the CSV file to read.
        type: "float" to parse values as float32, "int" to parse them as
            int32. (The name shadows the builtin; it is kept so existing
            keyword callers keep working.)

    Returns:
        np.ndarray with one row per non-blank CSV line. A file without any
        data rows yields an empty array of shape (0, 0) with the requested
        dtype.

    Raises:
        ValueError: If `type` is neither "float" nor "int", or if a field
            cannot be parsed as the requested number type.
    '''
    # Validate the requested type up front so a bad argument is reported even
    # when the file turns out to be empty.
    parsers = {"float": (float, np.float32), "int": (int, np.int32)}
    if type not in parsers:
        raise ValueError("Invalid type")
    convert, dtype = parsers[type]

    # Rows are converted to compact NumPy chunks as we go so that the bulky
    # Python-list representation never holds more than chunk_size rows.
    chunk_size = 10000
    chunks = []
    buffer = []

    # newline="" lets the csv module do its own line-ending handling.
    with open(file_path, "r", newline="") as f:
        for row in csv.reader(f):
            if not row:
                # Blank line (e.g. a trailing newline): carries no data and
                # would otherwise make the chunk ragged.
                continue
            buffer.append([convert(x) for x in row])

            if len(buffer) >= chunk_size:
                chunks.append(np.array(buffer, dtype=dtype))
                buffer = []

    if buffer:
        chunks.append(np.array(buffer, dtype=dtype))

    if not chunks:
        # np.vstack rejects an empty sequence; return an explicit empty matrix.
        return np.empty((0, 0), dtype=dtype)

    return np.vstack(chunks)

def construct_phbf(
    all_pos_key, all_pos_X, X_val_key, X_val, y_val, phbf_folder_path, bitarray_size, hash_count,
    X_test_key, X_test, y_test, pos_query_ratio, query_num, result_folder_path
):
    '''
    Build an HPBF filter, save its construction statistics, and benchmark its
    false-positive rate on negative test queries.

    Args:
        all_pos_key: Unused; accepted so existing callers keep working.
        all_pos_X: Feature vectors of all positive keys. Every row is added
            to the filter; the length of the first row is used as the input
            dimension, so it must be non-empty.
        X_val_key: Unused; accepted so existing callers keep working.
        X_val: Validation feature vectors used to configure the filter.
        y_val: Validation labels; y_val[i][0] == 1 marks a positive sample
            and y_val[i][0] == 0 a negative one.
        phbf_folder_path: Folder that receives info.json (created if missing).
        bitarray_size: Bit-array size forwarded to HPBF; treated as a bit
            count when estimating "model_size_kb".
        hash_count: Hash count forwarded to HPBF.
        X_test_key: Unused; accepted so existing callers keep working.
        X_test: Test feature vectors.
        y_test: Test labels; rows with y_test[i][0] == 0 form the negative
            query pool.
        pos_query_ratio: Ratio of positive queries; only 0.0 is supported.
        query_num: Number of negative queries sampled (seeded, without
            replacement) for the benchmark.
        result_folder_path: Folder that receives result.json (created if
            missing, just before the result is written).

    Raises:
        ValueError: If pos_query_ratio is not 0.0, or if query_num is
            negative or exceeds the number of negative test samples.
    '''
    # Reject unsupported configurations before any expensive work. An explicit
    # raise (unlike assert) is still enforced under `python -O`.
    if pos_query_ratio != 0.0:
        raise ValueError(
            f"Only pos_query_ratio == 0.0 is supported, got {pos_query_ratio}"
        )

    ##################
    # Construct PHBF #
    ##################

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(phbf_folder_path, exist_ok=True)
    phbf_info_path = phbf_folder_path + "/info.json"

    input_dim = len(all_pos_X[0])
    sample_factor = 10

    # PHBF phbf(bits_assigned_per_partition, hash_count, input_dim, sample_factor)
    hpbf = HPBF(
        bitarray_size,
        hash_count,
        input_dim,
        sample_factor=sample_factor,
    )

    # perf_counter is monotonic and high-resolution, so the measured durations
    # cannot be distorted by wall-clock adjustments.
    configuration_start = time.perf_counter() * 1000
    pos_X_val = [X for i, X in enumerate(X_val) if y_val[i][0] == 1]
    neg_X_val = [X for i, X in enumerate(X_val) if y_val[i][0] == 0]
    hpbf.initialize(pos_X_val, neg_X_val)
    configuration_end = time.perf_counter() * 1000

    hpbf.bulk_add(all_pos_X)
    add_key_end_time = time.perf_counter() * 1000

    # Bits -> kilobytes of the bit array only.
    phbf_size_bf = bitarray_size / 1000 / 8  # [WARNING] Actual size is larger than this

    # NOTE: the filter object itself is not persisted; only its metadata is.
    phbf_info = {
        "model_type": "phbf",
        "bitarray_size": bitarray_size,
        "hash_count": hash_count,
        # Total construction = configuration followed by key insertion.
        "model_construct_time_ms": add_key_end_time - configuration_start,
        "configuration_time_ms": configuration_end - configuration_start,
        "add_keys_time_ms": add_key_end_time - configuration_end,
        "model_size_kb": phbf_size_bf
    }
    # Save PHBF info
    with open(phbf_info_path, "w") as f:
        json.dump(phbf_info, f, indent=4)

    #############
    # Test PHBF #
    #############

    result_json_path = f"{result_folder_path}/result.json"

    # prepare data for testing
    neg_X_test = [X for i, X in enumerate(X_test) if y_test[i][0] == 0]
    if not 0 <= query_num <= len(neg_X_test):
        raise ValueError(
            f"query_num ({query_num}) must be between 0 and the number of "
            f"negative test samples ({len(neg_X_test)})"
        )
    random.seed(0)  # fixed seed: the same queries are drawn on every run
    neg_X_test = np.array(random.sample(neg_X_test, query_num))

    # warm-up
    warmup_num = 10
    test_num = 10
    for _ in range(warmup_num):
        hpbf.compute_fpr(neg_X_test)

    # test
    test_time_list = []
    fpr = 1.0
    for _ in range(test_num):
        start_time = time.perf_counter() * 1000
        fpr = hpbf.compute_fpr(neg_X_test)
        end_time = time.perf_counter() * 1000
        test_time_list.append(end_time - start_time)

    # Round instead of truncating: assuming compute_fpr returns a fraction of
    # the queries, fpr * query_num can land just below the true integer count
    # (e.g. 0.29 * 100 == 28.999999999999996).
    false_positives = int(round(fpr * query_num))

    # Save result
    result = {
        "bf_model_folder_path": phbf_folder_path,
        "pos_query_ratio": pos_query_ratio,
        "query_num": query_num,
        "fpr": fpr,
        "fnr": 0.0,
        "query_pos_num": 0,
        "query_neg_num": query_num,
        "false_positives": false_positives,
        "false_negatives": 0,
        "test_time_ms": sum(test_time_list),
        "test_time_list_ms": test_time_list
    }
    # Create the result folder only now that there is a result to store, so a
    # failed run does not leave behind an empty folder.
    os.makedirs(result_folder_path, exist_ok=True)
    print("Save the result to", result_json_path)
    with open(result_json_path, "w") as f:
        json.dump(result, f, indent=4)

if __name__ == "__main__":
    # Command-line options, in the order they are registered. Every option is
    # mandatory; each entry is (flag, extra add_argument keywords).
    option_table = (
        ("--all_pos_key_path", {"type": str, "help": "Path to all positive key file"}),
        ("--all_pos_X_path", {"type": str, "help": "Path to all positive X file"}),
        ("--X_val_key_path", {"type": str, "help": "Path to X validation key file"}),
        ("--X_val_path", {"type": str, "help": "Path to X validation file"}),
        ("--y_val_path", {"type": str, "help": "Path to y validation file"}),
        ("--X_test_key_path", {"type": str, "help": "Path to X test key file"}),
        ("--X_test_path", {"type": str, "help": "Path to X test file"}),
        ("--y_test_path", {"type": str, "help": "Path to y test file"}),
        ("--bit_sizes", {"type": int, "nargs": "+", "help": "Bit sizes"}),
        ("--hash_counts", {"type": int, "nargs": "+", "help": "Hash counts"}),
        ("--pos_query_ratios", {"type": float, "nargs": "+", "help": "Positive query ratios"}),
        ("--query_nums", {"type": int, "nargs": "+", "help": "Query numbers"}),
        ("--model_dir_root", {"type": str, "help": "Root directory to save PHBF models"}),
        ("--result_dir_root", {"type": str, "help": "Root directory to save PHBF results"}),
    )
    parser = argparse.ArgumentParser(description="Construct PHBF")
    for flag, extra in option_table:
        parser.add_argument(flag, required=True, **extra)

    # Example values:
    # pos_query_ratios=(0.0)
    # query_nums=(40000)

    args = parser.parse_args()

    # Load every dataset once, before sweeping the configurations.
    all_pos_key = load_key(args.all_pos_key_path)
    all_pos_X = load_data(args.all_pos_X_path, type="float")
    X_val_key = load_key(args.X_val_key_path)
    X_val = load_data(args.X_val_path, type="float")
    y_val = load_data(args.y_val_path, type="int")
    X_test_key = load_key(args.X_test_key_path)
    X_test = load_data(args.X_test_path, type="float")
    y_test = load_data(args.y_test_path, type="int")
    print("Data loaded")

    # Sweep bit size x hash count x positive-query ratio x query count.
    # A combination whose result folder already exists is skipped.
    for bit_size in args.bit_sizes:
        for hash_count in args.hash_counts:
            phbf_folder_path = f"{args.model_dir_root}/bit_size_{bit_size}_hash_count_{hash_count}"
            for pos_query_ratio in args.pos_query_ratios:
                # Dots are replaced so the ratio can be used in a folder name.
                ratio_tag = str(pos_query_ratio).replace(".", "_")
                for query_num in args.query_nums:
                    result_folder_path = f"{args.result_dir_root}/bit_size_{bit_size}_hash_count_{hash_count}/pr_{ratio_tag}_qn_{query_num}"
                    if not os.path.exists(result_folder_path):
                        assert pos_query_ratio == 0.0, "Has not implemented for pos_query_ratio != 0.0"
                        construct_phbf(
                            all_pos_key, all_pos_X, X_val_key, X_val, y_val,
                            phbf_folder_path, bit_size, hash_count,
                            X_test_key, X_test, y_test,
                            pos_query_ratio, query_num, result_folder_path
                        )
                    else:
                        print(f"Result folder already exists at {result_folder_path}. Skipping constructing the PHBF.")
