import json
import os

# Directory to search for jsonl files; the output JSON is written here too.
input_dir = '/home/abdulr81/RestWork/REST/llm_judge/data/mt_bench/model_answer'


def build_question_trace(record: dict) -> list:
    """Return the flat per-token time trace for one parsed JSONL record.

    Every entry of ``record["choices"]`` that has both an
    ``"individual_token_times"`` and an ``"accept_lengths:"`` list
    contributes to the trace: each time is divided by its paired length and
    the quotient is repeated ``length`` times. Pairs whose length is not
    positive contribute nothing. Choices missing either key are skipped, and
    a record without ``"choices"`` yields an empty trace.
    """
    trace = []
    for choice in record.get("choices", []):
        # NOTE(review): the trailing colon in "accept_lengths:" is kept
        # verbatim; presumably it matches the key written by the producer of
        # these files -- confirm before "fixing" it.
        if "individual_token_times" not in choice or "accept_lengths:" not in choice:
            continue
        token_times = choice["individual_token_times"]
        accept_lengths = choice["accept_lengths:"]
        # zip() stops at the shorter list, so unpaired trailing entries are
        # ignored.
        for step_time, length in zip(token_times, accept_lengths):
            if length > 0:
                trace.extend([step_time / length] * length)
    return trace


def collect_traces(directory: str) -> tuple:
    """Collect traces and labels from matching JSONL files in *directory*.

    A file is read when its name contains ``"test_temp"`` and ends with
    ``".jsonl"``. Files are visited in sorted name order so the result is
    reproducible (``os.listdir`` alone returns entries in arbitrary order).
    Blank lines are skipped instead of crashing the JSON parser.

    Returns:
        A ``(traces, labels)`` pair of equally long lists: one trace (see
        ``build_question_trace``) and one ``question_id`` per JSONL record.
        ``question_id`` is ``None`` for records that lack the key.
    """
    traces = []
    labels = []
    for filename in sorted(os.listdir(directory)):
        if "test_temp" not in filename or not filename.endswith('.jsonl'):
            continue
        file_path = os.path.join(directory, filename)
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(file_path, 'r', encoding='utf-8') as jsonl_file:
            for line in jsonl_file:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                record = json.loads(line)
                labels.append(record.get("question_id"))
                traces.append(build_question_trace(record))
    return traces, labels


def main() -> None:
    """Build the traces/labels JSON for ``input_dir`` and save it there."""
    traces, labels = collect_traces(input_dir)

    # Create the final JSON object
    final_output = {
        "traces": traces,
        "labels": labels
    }

    # Save the final output to a JSON file
    output_file_path = os.path.join(input_dir, 'final_output_9.json')
    with open(output_file_path, 'w', encoding='utf-8') as f:
        json.dump(final_output, f, indent=2)

    print(f"Data saved to {output_file_path}")
    print(f"length of traces: {len(traces)}")
    print(f"length of labels: {len(labels)}")


if __name__ == "__main__":
    main()