import time
from openai import OpenAI
import json, argparse
import os
from transformers import AutoTokenizer



# Tokenizer for the Vicuna-13B v1.3 checkpoint, loaded once at start-up.
# NOTE(review): no active statement in this script reads `tokenizer`;
# confirm it is still needed before paying the load cost on every run.
# NOTE(review): trust_remote_code=True presumably allows code shipped with
# the model repo to run at load time -- confirm it is actually required.
_TOKENIZER_REPO = "lmsys/vicuna-13b-v1.3"
_TOKENIZER_OPTIONS = {
    "trust_remote_code": True,
    "use_fast": False,  # force SentencePiece (slow tokenizer)
}
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_REPO, **_TOKENIZER_OPTIONS)

# Command-line options for one benchmarking session.
parser = argparse.ArgumentParser()
# --sd: label of the speculative-decoding method; it only names the output
# directory and file, it is not sent to the server.
parser.add_argument(
    "--sd",
    type=str,
    default="EAGLE",
    help="The speculative decoding.",
)
# --temperature: forwarded unchanged as the sampling temperature of every
# completion request.
parser.add_argument(
    "--temperature",
    type=float,
    default=1.0,
    help="The temperature.",
)
# --in_data: stem of the prompt file; prompts are read from "<in_data>.txt".
parser.add_argument(
    "--in_data",
    type=str,
    default='EK1',
    help="The input dataset.",
)
# --trial: how many times each prompt is sent to the server.
parser.add_argument(
    "--trial",
    type=int,
    default=30,
    help="Number of iterations.",
)
# Parsed once at import time from sys.argv.
args = parser.parse_args()

# Connection settings for the OpenAI-compatible endpoint on localhost.
# NOTE(review): "EMPTY" looks like a placeholder key -- presumably the local
# server does not validate it; confirm.
openai_api_base = "http://localhost:8001/v1"
openai_api_key = "EMPTY"
client = OpenAI(
    base_url=openai_api_base,
    api_key=openai_api_key,
)

# Use whichever model the server lists first for all requests.
models = client.models.list()
model = models.data[0].id

# Completions are always requested in streaming mode.
stream = True

# Prompts come from "<in_data>.txt", resolved against the working directory.
input_filename = args.in_data + ".txt"

# Results go to "./<sd>_output/<sd>_<in_data>_<temperature>_<trial>.json";
# the directory is created on demand.
output_dir = f"./{args.sd}_output"
os.makedirs(output_dir, exist_ok=True)
run_tag = "_".join(
    str(part)
    for part in (args.sd, args.in_data, args.temperature, args.trial)
)
output_filename = f"{output_dir}/{run_tag}.json"

# Load the prompts, one per line, with surrounding whitespace removed.
# The encoding is pinned to UTF-8 so the prompts do not depend on the
# platform's locale default, and blank lines are skipped so that an empty
# prompt is never sent to the server (each prompt costs `--trial` requests).
with open(input_filename, encoding="utf-8") as f:
    stripped_lines = (line.strip() for line in f)
    prompts = [text for text in stripped_lines if text]

# Accumulator that is serialized to JSON at the end of the script.  Every
# field starts as its own empty list and is meant to grow by one entry per
# (prompt, run) pair:
#   labels    - the prompt text of the run
#   responses - reserved for the full response text
#   traces    - list of per-chunk payload sizes (UTF-8 bytes)
#   tokens    - reserved for per-chunk token counts
#   times     - list of per-chunk time.time() timestamps
_RESULT_FIELDS = ("labels", "responses", "traces", "tokens", "times")
result = {field: [] for field in _RESULT_FIELDS}

# Collect one streaming trace per (prompt, run) pair.
#
# For every run the script records, per streamed chunk, the UTF-8 byte length
# of the chunk text and the wall-clock time (time.time()) at which the chunk
# was read from the stream.  Per-chunk token counts and the concatenated
# response text are not collected, so result["tokens"] and
# result["responses"] are left untouched here.

# Every prompt is repeated the same number of times (no extra warm-up run).
total_run = args.trial

for idx, prompt in enumerate(prompts):
    for run in range(total_run):
        # send the request
        resp = client.completions.create(
            model=model,
            prompt=prompt,
            echo=False,
            max_tokens=500,
            n=1,
            temperature=args.temperature,
            stream=stream,
        )
        chunk_records = []  # byte size of each chunk, in arrival order
        this_time = []      # time.time() reading for each chunk, same order
        for chunk in resp:
            # Timestamp first so the bookkeeping below does not delay it.
            now = time.time()
            # A streamed chunk may carry no choices at all (for example a
            # usage-only chunk); it has no text payload, so skip it instead
            # of failing with IndexError on choices[0].
            if not chunk.choices:
                continue
            txt = chunk.choices[0].text
            size = len(txt.encode("utf-8"))
            chunk_records.append(size)
            this_time.append(now)
            print(f"Text Streamed:    {txt}")
            print(f"Packet Length:    {size}")
            print(f"Packet Time:    {now}")

        # One entry per run; the three lists stay index-aligned.
        result["labels"].append(prompt)
        result["traces"].append(chunk_records)
        result["times"].append(this_time)

        print(f"Completed prompt {idx+1}, run {run+1}")

# Persist the accumulated results as pretty-printed JSON (2-space indent),
# overwriting any previous file at the same path.
with open(output_filename, "w") as out:
    serialized = json.dumps(result, indent=2)
    out.write(serialized)
