import json
import numpy as np
import joblib
import argparse
import string
import random
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder, StandardScaler


def get_time(file_path: str):
    """
    Print and return the mean of the "total_size" list stored in a JSON file.

    NOTE(review): despite the function name, this reads the "total_size" key;
    confirm whether that key holds timings or sizes.

    Parameters:
    file_path (str): JSON file containing a "total_size" list of numbers.

    Returns:
    float: The arithmetic mean, which is also printed.

    Raises:
    ValueError: If the "total_size" list is empty (the mean is undefined).
    """
    with open(file_path, "r") as f:
        data = json.load(f)

    values = data["total_size"]
    if not values:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError(f"{file_path!r} has an empty 'total_size' list")

    mean = sum(values) / len(values)
    print(mean)
    return mean


def get_size(file_path:str, D:int, rng=None):
    """
    Estimate the relative overhead of random per-entry padding and print it.

    For every trace, one integer is drawn uniformly from [0, D] (inclusive) per
    entry and the draws are summed. The trace's overhead is that padding sum
    divided by (sum of the trace's entries + padding sum). The mean overhead over
    all traces is printed as a percentage.

    Parameters:
    file_path (str): JSON file with a "traces" key holding a list of numeric lists.
    D (int): Inclusive upper bound of the random padding drawn per entry.
    rng: Optional random.Random-like object whose ``randint`` is used for the
        draws. Defaults to the module-level ``random`` functions, so results are
        not reproducible unless the global generator is seeded.

    Returns:
    float: The mean overhead as a percentage (the same value that is printed).

    Raises:
    ValueError: If no trace has a non-zero total, so the mean is undefined.
    """
    randint = random.randint if rng is None else rng.randint

    with open(file_path, "r") as f:
        data = json.load(f)

    fractions = []
    for trace in data["traces"]:
        # One draw per entry, in order, exactly as before.
        padding = sum(randint(0, D) for _ in trace)
        trace_total = sum(trace) + padding
        if trace_total == 0:
            # Empty (or all-zero, unpadded) trace: the ratio would be 0/0.
            continue
        fractions.append(padding / trace_total)

    if not fractions:
        raise ValueError(f"no trace in {file_path!r} has a non-zero total")

    percentage = sum(fractions) / len(fractions) * 100
    print(f"D={D} gives {percentage}%")
    return percentage


def pad_list(input_list, target_length, padding_value=0):
    """
    Return a copy of ``input_list`` forced to exactly ``target_length`` items.

    Longer lists are cut down to their first ``target_length`` items; shorter
    lists are extended on the right with ``padding_value``.

    Parameters:
    input_list (list): The list to pad or truncate. It is not modified.
    target_length (int): The length of the returned list.
    padding_value: Filler appended when the list is too short. Default is 0.

    Returns:
    list: A new list of the requested length.
    """
    shortfall = target_length - len(input_list)
    if shortfall <= 0:
        # Already long enough: slicing both truncates and copies.
        return input_list[:target_length]
    return input_list + [padding_value] * shortfall


def load_data(file_path: str, total_samples: int, num_samples: int, scaler_dump_path:str, label_dump_path:str):
    """
    Load traces from a JSON file, build fixed-length features, and fit the encoders.

    Each raw trace is currently replaced by a constant-size (1024) sequence of the
    same length and then padded/truncated to 100 entries, so the only information
    kept is how many entries the trace had. Alternative feature transforms from
    earlier experiments are kept below as commented-out code.

    Parameters:
    file_path (str): JSON file with a "traces" list (one list of numbers per sample)
        and a parallel "labels" list.
    total_samples (int): Size of each consecutive block of samples in the file.
        Subsampling assumes samples are grouped per label in blocks of this size.
    num_samples (int): How many samples to keep from the start of each block; 0 keeps
        every sample. Values above total_samples are clamped so blocks never overlap.
    scaler_dump_path (str): Where the fitted StandardScaler is saved (joblib).
    label_dump_path (str): Where the fitted LabelEncoder is saved (joblib).

    Returns:
    tuple: (X, y), the standardized feature matrix and the integer-encoded labels.

    Raises:
    ValueError: If "traces" and "labels" differ in length, if num_samples is
        negative, or if subsampling is requested with a non-positive total_samples.
    """
    ### 1. Load data from traces file
    with open(file_path, "r") as f:
        data = json.load(f)

    raw_traces = data["traces"]  # one list of numbers per sample
    labels = data["labels"]  # one label per trace, in the same order
    if len(raw_traces) != len(labels):
        raise ValueError(
            f"{file_path!r} has {len(raw_traces)} traces but {len(labels)} labels"
        )
    # traces = [pad_list(trace, 100) for trace in raw_traces]

    ### 2. Feature transform
    # Constant packet size: every entry becomes 1024, so only the number of
    # entries survives; then pad/truncate to 100 features.
    traces = [pad_list([1024] * len(trace), 100) for trace in raw_traces]

    # Added random packet size
    # traces = []
    # D = 48
    # for trace in raw_traces:
    #     acc = []
    #     for i, count in enumerate(trace):
    #         random_integer = random.randint(0, D)
    #         acc.append(count + random_integer)
    #     traces.append(acc)

    # traces = [pad_list(trace, 100) for trace in traces]

    # Granularity
    # traces = []
    # agg = 20
    # for trace in raw_traces:
    #     acc = []
    #     total_count = 0
    #     for i, count in enumerate(trace):
    #         if (i + 1) % agg == 0:
    #             total_count += count
    #             acc.append(total_count)
    #             total_count = 0
    #         else:
    #             total_count += count
    #     traces.append(acc)

    # traces = [pad_list(trace, 50) for trace in traces]

    ### 3. Optional subsampling: first num_samples of every block of total_samples
    if num_samples != 0:
        if num_samples < 0:
            raise ValueError(f"num_samples must be >= 0, got {num_samples}")
        if total_samples <= 0:
            raise ValueError(f"total_samples must be > 0, got {total_samples}")
        # Never take more than one block's worth; otherwise a slice spills into
        # the next label's block and duplicates its samples.
        per_block = min(num_samples, total_samples)
        new_traces, new_labels = [], []
        for start in range(0, len(labels), total_samples):
            new_traces.extend(traces[start:start + per_block])
            new_labels.extend(labels[start:start + per_block])
        traces, labels = new_traces, new_labels

    ### 4. Encode labels and standardize features
    label_encoder = LabelEncoder()
    scaler = StandardScaler()
    y = label_encoder.fit_transform(labels)
    X = scaler.fit_transform(np.array(traces))

    # Check the shape of X and y
    print("Shape of X:", X.shape)
    print("Shape of y:", y.shape)

    # Persist the fitted transforms so inference can reproduce the same features.
    joblib.dump(scaler, scaler_dump_path)
    joblib.dump(label_encoder, label_dump_path)

    return X, y


def load_prompts_from_txt(file_path):
    """
    Read one prompt per line and normalize each into a question.

    Each line has surrounding whitespace and double quotes removed. Any run of
    trailing punctuation is then replaced by a single "?", and a "?" is appended
    when there was none. Blank lines are kept as empty strings.

    Parameters:
    file_path (str): Path to a text file with one prompt per line.

    Returns:
    list: The normalized prompts, one per input line, in file order.
    """
    def _as_question(raw_line):
        text = raw_line.strip().strip('"')
        if not text:
            return text
        if text[-1] in string.punctuation:
            text = text.rstrip(string.punctuation)
        return text + "?"

    with open(file_path, "r") as handle:
        return [_as_question(raw_line) for raw_line in handle]


def fit_model(X, y, dump_path:str, **rf_params):
    """
    Train a random-forest classifier on (X, y), save it with joblib, and return it.

    Parameters:
    X: Feature matrix accepted by RandomForestClassifier.fit (e.g. X from load_data).
    y: Target labels aligned with the rows of X.
    dump_path (str): File path the fitted model is written to via joblib.dump.
    **rf_params: Optional RandomForestClassifier keyword arguments that override
        the defaults below (e.g. n_estimators=300). With none given, the model is
        configured exactly as before.

    Returns:
    RandomForestClassifier: The fitted classifier, so callers need not reload it.
    """
    params = {
        "n_estimators": 150,
        "max_depth": 15,
        "min_samples_split": 10,
        "min_samples_leaf": 1,
        "max_features": "sqrt",
        "random_state": 42,  # fixed seed for reproducible training
    }
    params.update(rf_params)

    clf = RandomForestClassifier(**params)
    clf.fit(X, y)
    joblib.dump(clf, dump_path)
    return clf


if __name__ == "__main__":
    import os

    # Where the fitted scaler, label encoder and classifier are written.
    scaler_path = "fingerprinting_attack/model/scaler.pkl"
    label_path = "fingerprinting_attack/model/label_encoder.pkl"
    model_path = "fingerprinting_attack/model/random_forest_model.pkl"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--Inkind",
        type=str,
        required=True,
        help="The trace Input kinds",
    )
    parser.add_argument(
        "--kind",
        type=str,
        required=True,
        help="The trace kinds",
    )
    parser.add_argument(
        "--folder",
        type=str,
        required=True,
        help="The trace folder",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for sampling.",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=0,
        help="The training size.",
    )
    # NOTE(review): --exp is parsed but not used anywhere in this script.
    parser.add_argument(
        "--exp",
        type=int,
        default=0,
        help="The experiment performed.",
    )
    args = parser.parse_args()

    # joblib.dump does not create parent directories, so ensure they exist.
    # Done after parse_args so that --help has no filesystem side effects.
    for artifact_dir in {os.path.dirname(p) for p in (scaler_path, label_path, model_path)}:
        if artifact_dir:
            os.makedirs(artifact_dir, exist_ok=True)

    # X, y = load_data(f"{args.folder}/temp_{args.temperature}_{args.Inkind}_trace_30.out", 30, args.size, scaler_path, label_path)

    # Train on the --Inkind traces; 30 is passed as total_samples (samples per
    # label block, presumably matching the "_t_30" file suffix).
    X, y = load_data(f"{args.folder}/BILD_{args.Inkind}_temp_{args.temperature}_t_30.json", 30, args.size, scaler_path, label_path)
    fit_model(X, y, model_path)

    # get_time(f"{args.folder}/{args.Inkind}_temp_{args.temperature}_t_30.out")

    # Report the random-padding overhead of the --kind traces for several bounds D.
    for D in [6, 12, 24, 48]:
        get_size(f"{args.folder}/BILD_{args.kind}_temp_{args.temperature}_t_5.json", D)
