import json
import numpy as np
import argparse
import string
import random
import matplotlib.pyplot as plt
from typing import List, Union


def compute_interarrival(times: List[List[Union[float,str]]]) -> List[List[float]]:
    """
    Convert per-segment timestamps into per-segment inter-arrival gaps.

    Each inner list of ``times`` is converted to floats (so numeric strings
    are accepted) and replaced by the differences between each timestamp and
    the one before it. A segment with N timestamps yields N - 1 gaps;
    segments with zero or one timestamp yield an empty list.
    """
    def _gaps(segment):
        # Convert the whole segment up front so a bad value fails before
        # any subtraction happens.
        stamps = list(map(float, segment))
        return [later - earlier for earlier, later in zip(stamps[:-1], stamps[1:])]

    return [_gaps(segment) for segment in times]


def load_data(
    file_path: str,
    total_samples: int,
    num_samples: int,
    scaler_dump_path: str,
    label_dump_path: str,
    *,
    trace_index: int = 454,
    threshold_ms: Union[float, None] = 57,
    output_path: str = "./temp.png",
) -> None:
    """
    Plot the inter-arrival times of a single trace from a JSON capture file.

    The file is expected to be a JSON object with a ``"times"`` key holding a
    list of timestamp lists (one list per trace). The selected trace's
    inter-arrival gaps are multiplied by 1000 and plotted against token IDs,
    together with a horizontal threshold line, and the figure is saved to
    ``output_path``.

    Args:
        file_path: Path of the JSON file to read.
        total_samples: Unused by this function; kept for interface
            compatibility with existing callers.
        num_samples: Unused by this function; kept for compatibility.
        scaler_dump_path: Unused by this function; kept for compatibility.
        label_dump_path: Unused by this function; kept for compatibility.
        trace_index: Index into the list of traces of the one to plot.
        threshold_ms: Height of the dashed threshold line. Pass ``None`` to
            use the 15th percentile of the plotted gaps instead of a fixed
            value.
        output_path: Where the PNG figure is written.

    Raises:
        IndexError: If ``trace_index`` is out of range for the file.
        ValueError: If ``threshold_ms`` is ``None`` and the selected trace
            has no inter-arrival gaps to take a percentile of.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    raw_traces = compute_interarrival(data["times"])

    try:
        trace = raw_traces[trace_index]
    except IndexError:
        # Same exception type as before, but say which index and file failed.
        raise IndexError(
            f"trace_index {trace_index} is out of range: "
            f"{file_path!r} contains {len(raw_traces)} trace(s)"
        ) from None

    # Scale by 1000 for the millisecond axis
    # (presumably the timestamps are in seconds -- TODO confirm).
    inter_ms = [t * 1000 for t in trace]

    if threshold_ms is None:
        if not inter_ms:
            raise ValueError(
                "cannot derive a percentile threshold from an empty trace"
            )
        threshold_ms = float(np.percentile(inter_ms, 15))

    # Token IDs start from the second token: the first gap belongs to
    # token 2 because a gap needs a preceding timestamp.
    token_ids = list(range(2, len(inter_ms) + 2))

    fig = plt.figure()
    try:
        plt.plot(token_ids, inter_ms)
        plt.axhline(threshold_ms, linestyle='--', label="Threshold 15th percentile")
        plt.xlabel("Token ID")
        plt.ylabel("Interarrival Time (ms)")
        plt.title("Interarrival Times per Token")
        plt.legend()
        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)




if __name__ == "__main__":
    # Artifact paths; load_data accepts the first two but does not use them,
    # and model_path is not referenced in this block.
    scaler_path = "model/scaler.pkl"
    label_path = "model/label_encoder.pkl"
    model_path = "model/random_forest_model.pkl"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--size",
        type=int,
        default=0,
        help="The training size.",
    )
    parser.add_argument(
        "--sd",
        type=str,
        default="EAGLE",
        help="The speculative decoding.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.3,
        help="The temperature.",
    )
    parser.add_argument(
        "--in_data",
        type=str,
        default='EK1',
        help="The input dataset.",
    )
    args = parser.parse_args()

    # The input file name is built from the CLI options; the trailing "30"
    # in the name matches the total_samples value passed below.
    # NOTE(review): "wireshirk" looks like a misspelling of "wireshark", but
    # it must match the directory name on disk, so it is left unchanged.
    load_data(
        f"wireshirk_{args.sd}_output/{args.in_data}_{args.sd}_{args.temperature}_30.json",
        30,
        args.size,
        scaler_path,
        label_path,
    )
    
