import json
import numpy as np
import joblib
import random
from sklearn.metrics import f1_score
import string
import argparse
from typing import List, Union


def load_model(scaler_dump_path:str, label_dump_path:str, model_dump_path:str):
    """
    Load the three persisted objects needed for evaluation.

    Parameters:
    scaler_dump_path (str): Path of the joblib dump loaded first.
    label_dump_path (str): Path of the joblib dump loaded second.
    model_dump_path (str): Path of the joblib dump loaded third.

    Returns:
    tuple: The three loaded objects, in the same order as the arguments.
    """
    # NOTE: joblib.load unpickles the file contents, so only trusted dumps
    # should ever be passed in here.
    scaler = joblib.load(scaler_dump_path)
    label_encoder = joblib.load(label_dump_path)
    model = joblib.load(model_dump_path)
    return scaler, label_encoder, model

def compute_interarrival(
    times: List[List[Union[float,str]]],
    threshold_ms: float = 57,
    noise_mean: float = 1,
    noise_sigma: float = 0,
) -> List[List[float]]:
    """
    Convert lists of timestamps into lists of inter-arrival gaps in milliseconds.

    Each inner list of ``times`` is converted to floats, the differences
    between consecutive entries are taken and multiplied by 1000. Every gap
    strictly greater than ``threshold_ms`` is then multiplied by a factor
    drawn from ``np.random.lognormal(noise_mean, noise_sigma)``.

    With the default ``noise_sigma=0`` the factor is the constant
    ``exp(noise_mean)`` (``e`` for the default mean), so the output is
    deterministic even though the global NumPy RNG is still called.

    Parameters:
    times: List of timestamp sequences; entries may be numbers or numeric strings.
    threshold_ms: Gaps above this value (in ms) are scaled by the noise factor. Default 57.
    noise_mean: ``mean`` argument of the lognormal draw. Default 1.
    noise_sigma: ``sigma`` argument of the lognormal draw. Default 0.

    Returns:
    One list of gaps per input sequence; a sequence with fewer than two
    timestamps yields an empty list.

    Raises:
    ValueError: If an entry cannot be converted to ``float``.
    """
    result: List[List[float]] = []
    for segment in times:
        stamps = [float(t) for t in segment]
        gaps_ms = []
        for prev, curr in zip(stamps, stamps[1:]):
            gap_ms = (curr - prev) * 1000
            if gap_ms > threshold_ms:
                # One draw per long gap, in input order, so results stay
                # reproducible under a fixed np.random seed.
                gap_ms = gap_ms * np.random.lognormal(mean=noise_mean, sigma=noise_sigma)
            gaps_ms.append(gap_ms)
        result.append(gaps_ms)
    return result


def eval(tester, file_path:str, in_data:str, label_encoder, clf, repeats:int=5):
    """
    Predict labels for ``tester``, save them to ``file_path`` and print the weighted F1 score.

    The ground truth is built from the lines of ``{in_data}.txt``: every
    line is repeated ``repeats`` times in a row, so the predictions are
    expected to be ordered prompt by prompt with ``repeats`` consecutive
    samples per prompt.

    NOTE: the function name shadows the builtin ``eval``; it is kept
    unchanged so that existing callers keep working.

    Parameters:
    tester: Feature matrix handed to ``clf.predict``.
    file_path (str): Text file that receives one decoded predicted label per line.
    in_data (str): Path of the prompt file without its ``.txt`` extension.
    label_encoder: Object providing ``transform`` and ``inverse_transform``
        (presumably a fitted sklearn LabelEncoder -- confirm against the dump).
    clf: Object providing ``predict``.
    repeats (int): Number of consecutive samples per prompt. Default 5.

    Returns:
    None. The score is printed to stdout.
    """
    predicted_labels_encoded = clf.predict(tester)
    predicted_labels = label_encoder.inverse_transform(predicted_labels_encoded)

    # Write predicted labels to the requested output file, one per line
    with open(file_path, "w") as file:
        for label in predicted_labels:
            file.write(f"{label}\n")

    prompts = load_prompts_from_txt(f"{in_data}.txt")
    expected = [label for label in prompts for _ in range(repeats)]
    true_labels_encoded = label_encoder.transform(expected)
    f1 = f1_score(true_labels_encoded, predicted_labels_encoded, average='weighted')
    print(f"F1 Score: {f1}")

    # Report the real destination instead of a hard-coded file name
    print(f"Predicted labels have been written to {file_path}")


def pad_list(input_list, target_length, padding_value=0):
    """
    Force a list to exactly ``target_length`` items by truncating or right-padding.

    Parameters:
    input_list (list): The list to resize; it is never modified in place.
    target_length (int): Length of the returned list.
    padding_value: Filler appended when the input is too short. Default is 0.

    Returns:
    list: A new list holding the leading items of ``input_list``, followed by
    as many copies of ``padding_value`` as are needed to reach ``target_length``.
    """
    shortfall = target_length - len(input_list)
    if shortfall <= 0:
        # Already long enough: keep only the leading target_length items
        return input_list[:target_length]
    return input_list + [padding_value] * shortfall


def load_test_data(file_path:str, scaler, toggle=True, trace_length:int=100, refit_scaler:bool=False):
    """
    Read a JSON capture and turn it into a scaled, fixed-width feature matrix.

    Parameters:
    file_path (str): JSON file to read.
    scaler: Object providing ``transform`` (and ``fit_transform`` when
        ``refit_scaler`` is set); expected to be already fitted.
    toggle: When truthy the rows come from ``data["traces"]``; otherwise they
        are computed from ``data["times"]`` with ``compute_interarrival``.
    trace_length (int): Every row is truncated or zero-padded to this length. Default 100.
    refit_scaler (bool): When True, refit the scaler on the loaded data
        (``fit_transform``). Default False, which applies the existing fit.

    Returns:
    The scaled matrix produced by the scaler, one row per sequence.
    """
    with open(file_path, "r") as f:
        data = json.load(f)

    # Pick the raw sequences according to the requested feature type
    if toggle:
        sequences = data["traces"]
    else:
        sequences = compute_interarrival(data["times"])

    # Rows can differ in length, so bring them all to a common width first
    fixed_width = [pad_list(seq, trace_length) for seq in sequences]

    traces = np.array(fixed_width)
    if refit_scaler:
        X = scaler.fit_transform(traces)
    else:
        # Reuse the statistics the scaler was fitted with; refitting on the
        # evaluation data would scale it differently from the training data.
        X = scaler.transform(traces)
    print("Shape of X:", X.shape)
    return X

# def check_test_data(file_path:str):
#     with open(file_path, "r") as f:
#         data = json.load(f)

#     # Initialize the new reduced list
#     reduced_data = data["traces"]

#     reduced_data = [pad_list(trace, 50) for trace in reduced_data]

#     traces = np.array(reduced_data)
#     print("Shape of X:", traces.shape)

def load_prompts_from_txt(file_path):
    """
    Read a text file and return its lines with surrounding whitespace removed.

    Lines are otherwise kept verbatim: no quote stripping or punctuation
    normalisation is applied, and blank lines are returned as empty strings.

    Parameters:
    file_path: Path of the text file to read.

    Returns:
    list: One stripped string per line of the file, in file order.
    """
    with open(file_path, "r") as f:
        return [raw_line.strip() for raw_line in f]


if __name__ == "__main__":
    # Script entry point: parse the options, load the persisted model
    # objects, build the feature matrix and evaluate the predictions.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sd",
        type=str,
        default="",
        help="The speculative decoding.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.3,
        help="The temperature.",
    )
    parser.add_argument(
        "--in_data",
        type=str,
        default='',
        help="The input dataset.",
    )
    parser.add_argument(
        "--use_trace",
        type=int,
        choices=[0, 1],
        default=1,
        help="Using trace (1) or time (0).",
    )
    parser.add_argument(
        "--out_data",
        type=str,
        default='',
        help="The output dataset.",
    )
    args = parser.parse_args()

    if args.use_trace == 1:
        # Trace mode was switched off with a bare `assert False`, which is
        # stripped under `python -O` and would then silently run the disabled
        # branch. Fail explicitly instead, before any model file is loaded.
        # The disabled branch read its input from
        # non-deepmind/wireshirk_{sd}_output/{out_data}_{sd}_{temperature}_5.json
        parser.error("trace mode (--use_trace 1) is currently disabled; pass --use_trace 0")

    scaler_path, label_path, model_path = "model/scaler.pkl", "model/label_encoder.pkl", "model/random_forest_model.pkl"
    scaler, label_encoder, clf = load_model(scaler_path, label_path, model_path)

    # Time mode: features are the inter-arrival gaps computed from data["times"]
    X = load_test_data(f"wireshirk_{args.sd}_output/{args.out_data}_{args.sd}_{args.temperature}_5.json", scaler, False)

    out_file = f"labels/predicted_{args.out_data}.txt"
    eval(X, out_file, args.in_data, label_encoder, clf)