import json
import numpy as np
import joblib
import argparse
import random
from sklearn.metrics import f1_score
import string

def load_model(scaler_dump_path:str, label_dump_path:str, model_dump_path:str):
    """Load three joblib dumps and return them as a tuple.

    The objects are loaded in argument order and returned in that same order.
    The caller in this file unpacks them as ``(scaler, label_encoder, clf)``.

    Parameters:
    scaler_dump_path (str): Path of the first dump (used as the feature scaler).
    label_dump_path (str): Path of the second dump (used as the label encoder).
    model_dump_path (str): Path of the third dump (used as the classifier).

    Returns:
    tuple: ``(scaler, label_encoder, model)`` exactly as stored in the dumps.
    """
    # NOTE: joblib.load unpickles arbitrary objects; only point this at
    # dump files that come from a trusted source.
    scaler = joblib.load(scaler_dump_path)
    label_encoder = joblib.load(label_dump_path)
    model = joblib.load(model_dump_path)
    return scaler, label_encoder, model


def eval(tester, file_path:str, label_encoder, clf, in_name:str, repeats:int=5):
    """Predict labels for ``tester``, save them and print the weighted F1 score.

    NOTE: this function shadows the builtin ``eval``; the name is kept because
    callers depend on it.

    Parameters:
    tester: Feature matrix handed to ``clf.predict``.
    file_path (str): Text file that receives one decoded label per line.
    label_encoder: Encoder providing ``inverse_transform`` / ``transform``.
    clf: Fitted classifier providing ``predict``.
    in_name (str): Base name of the prompt file; ground truth is read from
        ``data/{in_name}.txt``.
    repeats (int): How many consecutive predictions belong to each prompt.
        Default is 5 (the value that used to be hard-coded).

    Raises:
    ValueError: If the number of predictions does not equal
        ``repeats * number_of_prompts``.
    """
    predicted_labels_encoded = clf.predict(tester)
    predicted_labels = label_encoder.inverse_transform(predicted_labels_encoded)

    # Write predicted labels to a .txt file, one label per line
    with open(file_path, "w") as file:
        file.writelines(f"{label}\n" for label in predicted_labels)

    # The ground truth is every prompt repeated `repeats` times, in file order
    # (assumes the traces are grouped per prompt in that same order).
    prompts = load_prompts_from_txt(f"data/{in_name}.txt")
    expected = [label for label in prompts for _ in range(repeats)]
    if len(expected) != len(predicted_labels_encoded):
        raise ValueError(
            f"Expected {len(expected)} predictions "
            f"({len(prompts)} prompts x {repeats} repeats) "
            f"but got {len(predicted_labels_encoded)}"
        )

    true_labels_encoded = label_encoder.transform(expected)
    f1 = f1_score(true_labels_encoded, predicted_labels_encoded, average='weighted')
    print(f"F1 Score: {f1}")

    # Report the file that was actually written, not a fixed name.
    print(f"Predicted labels have been written to {file_path}")


def pad_list(input_list, target_length, padding_value=0):
    """
    Force a list to exactly ``target_length`` items by truncating or padding.

    Parameters:
    input_list (list): The list to be resized. It is never modified in place.
    target_length (int): The desired length of the returned list.
    padding_value: The value appended when the list is too short. Default is 0.

    Returns:
    list: The first ``target_length`` items when the input is long enough,
    otherwise the input followed by enough copies of ``padding_value``.
    """
    shortfall = target_length - len(input_list)

    # Nothing to add: the input already has enough items, so just cut it.
    if shortfall <= 0:
        return input_list[:target_length]

    return input_list + [padding_value] * shortfall


def load_test_data(file_path:str, scaler, packet_size:int=1024, target_length:int=101):
    """Build the scaled feature matrix for the traces stored in a JSON file.

    The file must contain a JSON object with a ``"traces"`` key holding a list
    of traces (each a list of per-packet values). In the active experiment
    every value of every trace is replaced by ``packet_size``; each trace is
    then truncated / zero-padded to ``target_length`` and scaled.

    Parameters:
    file_path (str): Path of the JSON trace file.
    scaler: Already fitted scaler providing ``transform``.
    packet_size (int): Constant written in place of every packet value.
        Default is 1024.
    target_length (int): Number of features per trace after padding.
        Default is 101.

    Returns:
    The array returned by ``scaler.transform``, one row per trace.
    """
    with open(file_path, "r") as f:
        data = json.load(f)

    reduced_data = data["traces"]

    # Variant: use the raw traces, only padded.
    # reduced_data = [pad_list(trace, 101) for trace in reduced_data]

    # Active experiment: set every packet size to `packet_size`, keeping only
    # the original length of each trace.
    traces = [[packet_size] * len(trace) for trace in reduced_data]
    reduced_data = [pad_list(trace, target_length) for trace in traces]

    # Variant: add a random amount (0..D) to every packet size.
    # traces = []
    # D = 48
    # for trace in reduced_data:
    #     acc = []
    #     for i, count in enumerate(trace):
    #         random_integer = random.randint(0, D)
    #         acc.append(count + random_integer)
    #     traces.append(acc)
    #
    # reduced_data = [pad_list(trace, 100) for trace in traces]

    # Variant: coarser granularity, summing every `agg` consecutive values.
    # traces = []
    # agg = 20
    # for trace in reduced_data:
    #     acc = []
    #     total_count = 0
    #     for i, count in enumerate(trace):
    #         if (i + 1) % agg == 0:
    #             total_count += count
    #             acc.append(total_count)
    #             total_count = 0
    #         else:
    #             total_count += count
    #     traces.append(acc)
    #
    # reduced_data = [pad_list(trace, 50) for trace in traces]

    traces = np.array(reduced_data)
    # Use transform(), not fit_transform(): the scaler was loaded already
    # fitted, and re-fitting it on the test traces would discard the
    # statistics the classifier was trained with.
    X = scaler.transform(traces)
    print("Shape of X:", X.shape)
    return X


def load_prompts_from_txt(file_path):
    """Read one prompt per line and normalise each to end in a single "?".

    For every line, surrounding whitespace and then surrounding double quotes
    are removed. A non-empty prompt has all trailing punctuation stripped and
    one question mark appended. A line that is empty after stripping is kept
    as an empty string, so the result has one entry per line of the file.

    Parameters:
    file_path (str): Path of the text file to read.

    Returns:
    list: The normalised prompts, in file order.
    """
    def _normalise(raw_line):
        text = raw_line.strip().strip('"')
        if not text:
            return text
        # rstrip is a no-op when the prompt does not end in punctuation, so
        # this covers both the "replace" and the "append" case.
        return text.rstrip(string.punctuation) + "?"

    with open(file_path, "r") as handle:
        return [_normalise(raw_line) for raw_line in handle]


if __name__ == "__main__":
    # Parse the command line first so that --help and argument errors are
    # reported without needing the model dumps to exist or be loaded.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--Inkind",
        type=str,
        required=True,
        help="Base name of the prompt file with the expected labels (data/<Inkind>.txt).",
    )
    # Example values seen for --kind: "semantics_sim", "structural_sim"
    parser.add_argument(
        "--kind",
        type=str,
        required=True,
        help="The trace kind; selects the trace file and names the output label file.",
    )
    parser.add_argument(
        "--folder",
        type=str,
        required=True,
        help="The trace folder",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for sampling.",
    )
    # NOTE: --exp is accepted but not read anywhere in this script.
    parser.add_argument(
        "--exp",
        type=int,
        default=0,
        help="The experiment performed.",
    )

    args = parser.parse_args()

    scaler_path, label_path, model_path = "model/scaler.pkl", "model/label_encoder.pkl", "model/random_forest_model.pkl"
    scaler, label_encoder, clf = load_model(scaler_path, label_path, model_path)

    # Alternative trace file naming scheme:
    # X = load_test_data(f"{args.folder}/temp_{args.temperature}_{args.kind}_trace_5.out", scaler)
    X = load_test_data(f"{args.folder}/{args.kind}_temp_{args.temperature}_t_5.out", scaler)
    out_file = f"labels/predicted_{args.kind}.txt"
    eval(X, out_file, label_encoder, clf, args.Inkind)