import json
from transformers import AutoTokenizer
import numpy as np

# Tokenizer loaded once at import time; measure_speed / measure_speed_v2 use it
# to re-count tokens in the baseline files' answer texts.
tokenizer=AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.3")


# Result files to evaluate. Entry i is compared against entry i of
# jsonl_file_base_list (paired positionally by the zip() at the bottom of the
# file), so both lists must have the same length and matching order.
# The names appear to encode <method><setting><tree config>, with settings
# 0.1 / 0.6 / 1 and configs binary5 / branch4 / mc_sim_7b_63 -- TODO confirm
# what the numeric setting means (temperature?).
jsonl_file_list = [
   "../../mt_bench/7b-final-seed2024/GCSpS0.1binary5.txt",
   "../../mt_bench/7b-final-seed2024/spechub0.1binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS_wo_replacement0.1binary5.txt",
    "../../mt_bench/7b-final-seed2024/GCSpS0.6binary5.txt",
   "../../mt_bench/7b-final-seed2024/spechub0.6binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS_wo_replacement0.6binary5.txt",
    "../../mt_bench/7b-final-seed2024/GCSpS1binary5.txt",
   "../../mt_bench/7b-final-seed2024/spechub1binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS_wo_replacement1binary5.txt",

    # branch4 runs (no spechub entries for this config).
    "../../mt_bench/7b-final-seed2024/GCSpS0.1branch4.txt",
    "../../mt_bench/7b-final-seed2024/RRS_wo_replacement0.1branch4.txt",
     "../../mt_bench/7b-final-seed2024/GCSpS0.6branch4.txt",
    "../../mt_bench/7b-final-seed2024/RRS_wo_replacement0.6branch4.txt",
     "../../mt_bench/7b-final-seed2024/GCSpS1branch4.txt",
    "../../mt_bench/7b-final-seed2024/RRS_wo_replacement1branch4.txt",


    # mc_sim_7b_63 runs (no spechub entries for this config).
    "../../mt_bench/7b-final-seed2024/GCSpS0.1mc_sim_7b_63.txt",
    "../../mt_bench/7b-final-seed2024/RRS_wo_replacement0.1mc_sim_7b_63.txt",
     "../../mt_bench/7b-final-seed2024/GCSpS0.6mc_sim_7b_63.txt",
    "../../mt_bench/7b-final-seed2024/RRS_wo_replacement0.6mc_sim_7b_63.txt",
     "../../mt_bench/7b-final-seed2024/GCSpS1mc_sim_7b_63.txt",
    "../../mt_bench/7b-final-seed2024/RRS_wo_replacement1mc_sim_7b_63.txt",
]
# Baseline (RRS) file for the entry at the same index in jsonl_file_list:
# same numeric setting and same tree config, repeated once per compared method.
jsonl_file_base_list = [
    "../../mt_bench/7b-final-seed2024/RRS0.1binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.1binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.1binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.6binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.6binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.6binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS1binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS1binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS1binary5.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.1branch4.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.1branch4.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.6branch4.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.6branch4.txt",
    "../../mt_bench/7b-final-seed2024/RRS1branch4.txt",
    "../../mt_bench/7b-final-seed2024/RRS1branch4.txt",

    "../../mt_bench/7b-final-seed2024/RRS0.1mc_sim_7b_63.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.1mc_sim_7b_63.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.6mc_sim_7b_63.txt",
    "../../mt_bench/7b-final-seed2024/RRS0.6mc_sim_7b_63.txt",
    "../../mt_bench/7b-final-seed2024/RRS1mc_sim_7b_63.txt",
    "../../mt_bench/7b-final-seed2024/RRS1mc_sim_7b_63.txt",

]

def measure_speed(jsonl_file, jsonl_file_base, tok=None):
    """Print (and return) the ratio of mean per-question decoding speeds.

    Each line of either file is one JSON record. The speed of a record
    (question) is computed as follows:

    * in *jsonl_file*: ``sum(choices[0]['new_tokens']) / sum(choices[0]['wall_time'])``;
    * in *jsonl_file_base*: the token count is recomputed by tokenizing every
      text in ``choices[0]['turns']`` with *tok*, then divided by
      ``sum(choices[0]['wall_time'])``.

    The reported value is ``mean(speeds of jsonl_file) / mean(speeds of
    jsonl_file_base)``, so values above 1 mean *jsonl_file* decodes faster.

    Args:
        jsonl_file: results file to evaluate.
        jsonl_file_base: baseline results file.
        tok: tokenizer used to count baseline tokens; defaults to the
            module-level ``tokenizer``.

    Returns:
        The speed ratio.

    Raises:
        ValueError: if either file contains no records.
    """
    if tok is None:
        tok = tokenizer

    # Blank lines (e.g. a trailing newline) are skipped rather than fed to json.loads.
    with open(jsonl_file, 'r', encoding='utf-8') as file:
        data = [json.loads(line) for line in file if line.strip()]
    with open(jsonl_file_base, 'r', encoding='utf-8') as file:
        data_base = [json.loads(line) for line in file if line.strip()]
    if not data:
        raise ValueError(f"no records found in {jsonl_file}")
    if not data_base:
        raise ValueError(f"no records found in {jsonl_file_base}")

    # NOTE(review): the evaluated file uses its recorded new_tokens while the
    # baseline is re-tokenized; the reason for the asymmetry is not visible here.
    speeds = [
        sum(datapoint["choices"][0]['new_tokens']) / sum(datapoint["choices"][0]['wall_time'])
        for datapoint in data
    ]

    speeds0 = []
    for datapoint in data_base:
        choice = datapoint["choices"][0]
        # "- 1" discounts one special token per turn (presumably the BOS token
        # the tokenizer prepends) so only the answer text is counted.
        tokens = sum(len(tok(text).input_ids) - 1 for text in choice['turns'])
        speeds0.append(tokens / sum(choice['wall_time']))

    ratio = np.mean(speeds) / np.mean(speeds0)
    print(f"Speed ratio of {jsonl_file} and {jsonl_file_base}: {ratio:.2f}")
    return ratio

def _read_jsonl(path):
    """Return the JSON objects stored one per line in *path*, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]


def _time_per_token_stats(times, counts):
    """Return ``(mu_hat, var_mu_hat)`` for parallel per-turn wall times and token counts.

    ``mu_hat = sum(times) / sum(counts)`` is the pooled time per token.
    ``var_mu_hat = sum(n_i * (x_i - mu_hat)**2) / ((N - 1) * sum(counts))``
    where ``x_i = times[i] / counts[i]`` and ``N`` is the number of turns.
    This is presumably meant as the variance of ``mu_hat`` when per-token times
    are i.i.d. (TODO confirm the intended model). It is NaN when ``N < 2``.
    """
    times = np.asarray(times, dtype=float)
    counts = np.asarray(counts, dtype=float)
    total_tokens = counts.sum()
    mu_hat = times.sum() / total_tokens
    if counts.size < 2:
        # A sample variance needs at least two turns.
        return mu_hat, float('nan')
    # Direct weighted sum of squared deviations. This is algebraically equal to
    # sum(n * x**2) - sum(n) * mu_hat**2, but avoids cancellation between two
    # large, nearly equal terms. A zero-token turn makes x infinite and the
    # result NaN.
    time_per_token = times / counts
    weighted_ss = np.sum(counts * (time_per_token - mu_hat) ** 2)
    return mu_hat, weighted_ss / (counts.size - 1) / total_tokens


def measure_speed_v2(jsonl_file, jsonl_file_base, tok=None):
    """Print (and return) the pooled speed ratio of two result files and its std.

    Every turn of every record contributes one (wall_time, token count) pair:

    * in *jsonl_file*: ``choices[0]['new_tokens'][i]`` and ``choices[0]['wall_time'][i]``;
    * in *jsonl_file_base*: the count is recomputed by tokenizing
      ``choices[0]['turns'][i]`` with *tok*.

    Each file is summarized by its pooled time per token
    ``mu_hat = sum(wall_time) / sum(tokens)``. The reported ratio is
    ``mu_hat_base / mu_hat``, so values above 1 mean *jsonl_file* is faster.
    The std propagates both variance estimates through the first-order
    (delta-method) formula for a ratio.

    Args:
        jsonl_file: results file to evaluate.
        jsonl_file_base: baseline results file.
        tok: tokenizer used to count baseline tokens; defaults to the
            module-level ``tokenizer``.

    Returns:
        ``(ratio, std)``.

    Raises:
        ValueError: if a file has no turns, or a record has fewer
            ``wall_time`` entries than tokens/turns.
    """
    if tok is None:
        tok = tokenizer

    token_nums, sequence_times = [], []
    for datapoint in _read_jsonl(jsonl_file):
        choice = datapoint["choices"][0]
        counts, times = choice['new_tokens'], choice['wall_time']
        if len(times) < len(counts):
            raise ValueError(
                f"question {datapoint.get('question_id')} in {jsonl_file} has "
                f"{len(counts)} new_tokens entries but only {len(times)} wall_time entries"
            )
        token_nums.extend(counts)
        sequence_times.extend(times[:len(counts)])

    token_nums0, sequence_times0 = [], []
    for datapoint in _read_jsonl(jsonl_file_base):
        choice = datapoint["choices"][0]
        turns, times = choice['turns'], choice['wall_time']
        if len(times) < len(turns):
            raise ValueError(
                f"question {datapoint.get('question_id')} in {jsonl_file_base} has "
                f"{len(turns)} turns but only {len(times)} wall_time entries"
            )
        for text, elapsed in zip(turns, times):
            # "- 1" discounts one special token per turn (presumably the BOS
            # token the tokenizer prepends) so only the answer text is counted.
            token_nums0.append(len(tok(text).input_ids) - 1)
            sequence_times0.append(elapsed)

    for path, counts in ((jsonl_file, token_nums), (jsonl_file_base, token_nums0)):
        if not counts:
            raise ValueError(f"no turns found in {path}")

    mu_hat, sigma_hat_squared = _time_per_token_stats(sequence_times, token_nums)
    mu_hat0, sigma_hat_squared0 = _time_per_token_stats(sequence_times0, token_nums0)

    # mu values are time per token, so base/method is the method's speed-up.
    ratio = mu_hat0 / mu_hat
    final_std = ratio * np.sqrt(sigma_hat_squared0 / (mu_hat0 ** 2) + sigma_hat_squared / (mu_hat ** 2))

    print(f"Speed ratio of {jsonl_file} and {jsonl_file_base}: {ratio:.2f}")
    print(f"Std of Speed ratio of {jsonl_file} and {jsonl_file_base}: {final_std:.2f}")
    return ratio, final_std

# Each run in jsonl_file_list is paired positionally with its baseline in
# jsonl_file_base_list. zip() would silently drop unmatched trailing entries,
# so refuse to run if the two hand-maintained lists have drifted apart.
if len(jsonl_file_list) != len(jsonl_file_base_list):
    raise ValueError(
        f"jsonl_file_list has {len(jsonl_file_list)} entries but "
        f"jsonl_file_base_list has {len(jsonl_file_base_list)}"
    )
for jsonl_file, jsonl_file_base in zip(jsonl_file_list, jsonl_file_base_list):
    measure_speed_v2(jsonl_file, jsonl_file_base)