import json
import numpy as np
import joblib
import argparse
import string
import random
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder, StandardScaler
from typing import List, Union


def get_time(file_path: str):
    """Print the arithmetic mean of the ``total_size`` list in a JSON file.

    Note: despite the function name, the value averaged is read from the
    ``"total_size"`` key, not from any time field.

    Args:
        file_path: path to a JSON object containing a non-empty numeric list
            under ``"total_size"``.
    """
    with open(file_path, "r") as handle:
        sizes = json.load(handle)["total_size"]

    mean_value = sum(sizes) / len(sizes)
    print(mean_value)


def get_size(file_path: str, D: int):
    """Print the average share of random per-element padding across traces.

    For every trace in ``data["traces"]``, one integer is drawn uniformly
    from ``[0, D]`` (inclusive) per element and the draws are summed into an
    ``overhead``. The trace's ratio is ``overhead / (sum(trace) + overhead)``.
    The mean ratio over all traces is printed as a percentage.

    Args:
        file_path: JSON file with a ``"traces"`` key holding a list of numeric
            lists.
        D: upper bound (inclusive) of each random padding draw.
    """
    with open(file_path, "r") as handle:
        payload = json.load(handle)

    ratios = []
    for trace in payload["traces"]:
        # One independent draw per element, in element order.
        overhead = sum(random.randint(0, D) for _ in trace)
        ratios.append(overhead / (sum(trace) + overhead))

    mean_pct = sum(ratios) / len(ratios) * 100
    print(f"D={D} gives {mean_pct}%")


def pad_list(input_list, target_length, padding_value=0):
    """
    Force a list to exactly ``target_length`` items by padding or truncating.

    Parameters:
    input_list (list): The list to resize. It is never modified in place.
    target_length (int): Desired length of the result.
    padding_value: Value appended when the list is too short. Default is 0.

    Returns:
    list: A new list. It is ``input_list`` truncated to ``target_length`` when it
    is at least that long, otherwise ``input_list`` followed by enough
    ``padding_value`` entries to reach ``target_length``.
    """
    shortfall = target_length - len(input_list)
    if shortfall <= 0:
        # Already long enough: keep only the leading target_length items.
        return input_list[:target_length]
    return input_list + [padding_value] * shortfall


def compute_interarrival(
    times: List[List[Union[float, str]]],
    noise_threshold_ms: float = 57.0,
    noise_sigma: float = 1.0,
    rng=None,
) -> List[List[float]]:
    """Convert timestamp segments into noisy inter-arrival gaps in milliseconds.

    For each inner list of timestamps, consecutive differences
    ``t[i+1] - t[i]`` are computed and multiplied by 1000. Every gap strictly
    greater than ``noise_threshold_ms`` is then multiplied by a log-normal
    factor ``exp(N(0, noise_sigma**2))``. Gaps at or below the threshold are
    returned unchanged.

    Args:
        times: list of segments. Each entry must be accepted by ``float()``
            (numbers or numeric strings). The values are presumably seconds,
            since the gaps are scaled by 1000 to "ms"; confirm with the producer.
        noise_threshold_ms: gaps (after scaling) above this value get noise.
            Defaults to 57, the value previously hard-coded.
        noise_sigma: sigma of the underlying normal for the log-normal factor.
            ``0`` makes the factor exactly 1.
        rng: optional ``numpy.random.Generator`` / ``RandomState``. When
            ``None``, the global ``np.random`` state is used (original
            behaviour; reproducible via ``np.random.seed``).

    Returns:
        One list per input segment, of length ``len(segment) - 1``. A segment
        with fewer than two timestamps yields an empty list.
    """
    draw = np.random.lognormal if rng is None else rng.lognormal

    result: List[List[float]] = []
    for segment in times:
        stamps = [float(t) for t in segment]
        gaps_ms: List[float] = []
        for prev, curr in zip(stamps, stamps[1:]):
            gap = (curr - prev) * 1000
            if gap > noise_threshold_ms:
                # Multiplicative noise is applied only to gaps above the threshold.
                gap = gap * draw(mean=0, sigma=noise_sigma)
            gaps_ms.append(gap)
        result.append(gaps_ms)
    return result


def _take_first_per_class(traces, labels, total_samples, num_samples):
    """Keep the first ``num_samples`` items of every consecutive ``total_samples`` chunk."""
    if total_samples <= 0:
        raise ValueError(f"total_samples must be positive, got {total_samples}")
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")
    # Clamp so that a chunk never spills into the next one; otherwise items
    # would be taken twice when num_samples > total_samples.
    take = min(num_samples, total_samples)
    kept_traces, kept_labels = [], []
    for start in range(0, len(labels), total_samples):
        kept_traces.extend(traces[start:start + take])
        kept_labels.extend(labels[start:start + take])
    return kept_traces, kept_labels


def load_data(file_path: str, total_samples: int, num_samples: int, scaler_dump_path: str,
              label_dump_path: str, toggle=True, pad_length: int = 100):
    """Load traces/labels from JSON, build a scaled feature matrix, and persist the encoders.

    Args:
        file_path: JSON file containing ``"labels"`` plus either ``"traces"``
            (used when ``toggle`` is truthy) or ``"times"`` (converted with
            ``compute_interarrival`` otherwise).
        total_samples: size of each consecutive chunk of samples. Presumably
            samples are stored grouped per label in runs of this length
            (the CLI passes 30, matching the ``_30.json`` file suffix); TODO confirm.
        num_samples: if non-zero, keep only the first ``num_samples`` of every
            chunk (clamped to ``total_samples``). ``0`` keeps everything.
        scaler_dump_path: where the fitted ``StandardScaler`` is written with joblib.
        label_dump_path: where the fitted ``LabelEncoder`` is written with joblib.
        toggle: truthy = use ``data["traces"]``; falsy = use inter-arrival
            features derived from ``data["times"]``.
        pad_length: every trace is padded or truncated to this many features.

    Returns:
        ``(X, y)``: standardized 2-D feature array and integer-encoded labels.

    Raises:
        ValueError: if the number of labels differs from the number of traces,
            or the subsampling parameters are invalid.
    """
    ### 1. Load data from traces file
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if toggle:
        raw_traces = data["traces"]
    else:
        raw_traces = compute_interarrival(data["times"])

    labels = data["labels"]
    if len(labels) != len(raw_traces):
        raise ValueError(
            f"{file_path}: {len(labels)} labels but {len(raw_traces)} traces"
        )
    traces = [pad_list(trace, pad_length) for trace in raw_traces]

    # Alternative trace transformations from earlier experiments (disabled).
    # Each one rebuilds `traces` from `raw_traces`.
    #
    # Fixed packet size of 1024:
    # traces = [[1024 for _ in trace] for trace in raw_traces]
    # traces = [pad_list(trace, 100) for trace in traces]
    #
    # Random per-element padding in [0, D]:
    # D = 48
    # traces = [[count + random.randint(0, D) for count in trace] for trace in raw_traces]
    # traces = [pad_list(trace, 100) for trace in traces]
    #
    # Granularity: sum every `agg` consecutive elements (a trailing partial group is dropped):
    # agg = 20
    # traces = [[sum(trace[i:i + agg]) for i in range(0, len(trace) - agg + 1, agg)] for trace in raw_traces]
    # traces = [pad_list(trace, 50) for trace in traces]

    if num_samples != 0:
        traces, labels = _take_first_per_class(traces, labels, total_samples, num_samples)

    label_encoder = LabelEncoder()
    scaler = StandardScaler()

    # Convert labels to integer class ids and standardize each feature column.
    y = label_encoder.fit_transform(labels)
    X = scaler.fit_transform(np.array(traces))

    # Check the shape of X and y
    print("Shape of X:", X.shape)
    print("Shape of y:", y.shape)

    # Persist the fitted transformers so the same preprocessing can be reused later.
    joblib.dump(scaler, scaler_dump_path)
    joblib.dump(label_encoder, label_dump_path)

    return X, y


def load_prompts_from_txt(file_path):
    """Read one prompt per line and normalize each one to end in a single '?'.

    Each line is stripped of surrounding whitespace and then of surrounding
    double quotes. Any run of trailing punctuation is removed and a '?' is
    appended. Lines that end up empty are kept as ``""``.

    Returns:
        list[str]: one entry per line of the file, in file order.
    """
    def _normalise(raw_line):
        text = raw_line.strip().strip('"')
        if not text:
            return text
        # rstrip is a no-op when the last char is not punctuation, so this
        # covers both the "replace punctuation" and the "just append" cases.
        return text.rstrip(string.punctuation) + "?"

    with open(file_path, "r") as handle:
        return [_normalise(line) for line in handle]


def fit_model(X, y, dump_path: str, **rf_params):
    """Train a RandomForestClassifier on ``(X, y)``, save it with joblib, and return it.

    Args:
        X: feature matrix accepted by ``RandomForestClassifier.fit``.
        y: target labels.
        dump_path: file the fitted classifier is written to via ``joblib.dump``.
        **rf_params: optional overrides for the default hyperparameters below
            (any ``RandomForestClassifier`` constructor argument).

    Returns:
        The fitted ``RandomForestClassifier``. Callers that ignore the return
        value behave exactly as before.
    """
    params = dict(
        n_estimators=150,
        max_depth=15,
        min_samples_split=10,
        min_samples_leaf=1,
        max_features='sqrt',
        random_state=42,  # fixed seed for reproducible forests
    )
    params.update(rf_params)

    clf = RandomForestClassifier(**params)
    clf.fit(X, y)
    joblib.dump(clf, dump_path)
    return clf


if __name__ == "__main__":
    import os

    scaler_path, label_path, model_path = "model/scaler.pkl", "model/label_encoder.pkl", "model/random_forest_model.pkl"
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--size",
        type=int,
        default=0,
        help="The training size.",
    )
    parser.add_argument(
        "--sd",
        type=str,
        default="EAGLE",
        help="The speculative decoding.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.3,
        help="The temperature.",
    )
    parser.add_argument(
        "--use_trace",
        type=int,
        choices=[0, 1],
        default=1,
        help="Using trace (1) or time (0).",
    )
    parser.add_argument(
        "--in_data",
        type=str,
        default='',
        help="The input dataset.",
    )
    args = parser.parse_args()

    if args.use_trace == 1:
        # Trace mode is deliberately disabled. Use an explicit error rather than
        # an assert, because asserts are stripped under `python -O`. To
        # re-enable trace mode, replace this error with one of:
        # X, y = load_data(f"{args.sd}_output/{args.sd}_{args.in_data}_{args.temperature}_30.json", 30, args.size, scaler_path, label_path, args.use_trace)
        # X, y = load_data(f"non-deepmind/wireshirk_{args.sd}_output/{args.in_data}_{args.sd}_{args.temperature}_30.json", 30, args.size, scaler_path, label_path)
        parser.error("--use_trace 1 is currently disabled; run with --use_trace 0.")

    # joblib.dump does not create parent directories; create them up front so
    # saving does not fail after the (potentially slow) data loading.
    for dump_path in (scaler_path, label_path, model_path):
        os.makedirs(os.path.dirname(dump_path) or ".", exist_ok=True)

    # Time mode: features are inter-arrival gaps derived from data["times"] (toggle=False).
    # X, y = load_data(f"non-deepmind/wireshirk_{args.sd}_output/{args.in_data}_{args.sd}_{args.temperature}_30.json", 30, args.size, scaler_path, label_path)
    X, y = load_data(f"wireshirk_{args.sd}_output/{args.in_data}_{args.sd}_{args.temperature}_30.json", 30, args.size, scaler_path, label_path, False)

    fit_model(X, y, model_path)
