import os
import json
import argparse
import datetime
import re

import torch
from transformers import AutoTokenizer

# Location of the benchmark dataset (a JSON file). The path is relative, so it
# is resolved against the process's current working directory when opened.
data_file_path = '../datas/system_benchmark_eval_datas.json'

def load_examples(dataset_filepath):
    """Load and return the parsed contents of a UTF-8 encoded JSON file.

    Args:
        dataset_filepath: Path of the JSON file to read.

    Returns:
        The deserialized JSON document (whatever top-level value the file
        holds).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(dataset_filepath, encoding="utf-8") as dataset_file:
        return json.load(dataset_file)

def converation_generator(sysmeg_id):
    """Yield, in order, every message of the conversation with the given id.

    The dataset at ``data_file_path`` is loaded on the first iteration and
    searched for the first entry whose ``'system_id'`` equals ``sysmeg_id``.
    All of that entry's messages are yielded, whatever their role.

    Args:
        sysmeg_id: Value to match against each entry's ``'system_id'``.

    Yields:
        Each item of the matching entry's ``'messages'`` list.

    Raises:
        ValueError: If no entry has the requested id. Because this is a
            generator, the error surfaces when iteration starts, not when
            the function is called.
    """
    # Take the first matching entry; later entries are never inspected.
    found = next(
        (record for record in load_examples(data_file_path)
         if record['system_id'] == sysmeg_id),
        None,
    )
    if found is None:
        raise ValueError(f"System message with id {sysmeg_id} not found")
    yield from found['messages']

def calc_token(sysmeg_id, tokenizer):
    """Count prefill and decode tokens for one benchmark conversation.

    The conversation is replayed message by message:

    * On every ``user`` message, the whole history so far is rendered through
      the chat template with the generation prompt appended; its token length
      is added to the prefill total (the full history is counted again on
      every user turn).
    * On every ``assistant`` message, the history including that reply is
      rendered without a generation prompt; its length minus the most recent
      prompt length is added to the decode total.
    * Messages with any other role (e.g. ``system``) are only appended to the
      history.

    Args:
        sysmeg_id: Id of the conversation, passed to ``converation_generator``.
        tokenizer: Object providing ``apply_chat_template``; the result is
            assumed to expose ``.shape`` (a tensor, since
            ``return_tensors="pt"`` is requested).

    Returns:
        A ``(n_prefill, n_decode)`` tuple of token counts.
    """
    prefill_total = 0
    decode_total = 0
    prompt_len = 0  # length of the most recent user-turn prompt
    history = []

    for turn in converation_generator(sysmeg_id):
        history.append(turn)
        role = turn['role']

        if role == 'user':
            encoded = tokenizer.apply_chat_template(
                history,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            )
            prompt_len = encoded.shape[-1]
            prefill_total += prompt_len
        elif role == 'assistant':
            encoded = tokenizer.apply_chat_template(
                history,
                tokenize=True,
                add_generation_prompt=False,
                return_tensors="pt",
            )
            # Tokens beyond the preceding prompt are attributed to the reply.
            decode_total += encoded.shape[-1] - prompt_len

    return prefill_total, decode_total

if __name__ == '__main__':
    # Command line: --model/-m is the tokenizer name or path, --sid/-s is the
    # conversation id to measure (-1 means "sum over ids 1..500").
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model", '-m', type=str, required=True)
    arg_parser.add_argument("--sid", '-s', type=int, required=True)
    args = arg_parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model)

    if args.sid != -1:
        # Single conversation.
        n_prefill, n_decode = calc_token(args.sid, tokenizer)
    else:
        # Whole benchmark: accumulate over the hard-coded id range 1..500.
        from tqdm import tqdm
        n_prefill = n_decode = 0
        for sid in tqdm(range(1, 501)):
            sid_prefill, sid_decode = calc_token(sid, tokenizer)
            n_prefill += sid_prefill
            n_decode += sid_decode

    print("Tokenizer:", args.model)
    print(f"Prefill tokens: {n_prefill}, Decode tokens: {n_decode}")

    # Per-token rates in USD; presumably 0.015 / 0.06 USD per 1K tokens for
    # input / output respectively -- verify against the intended pricing.
    usd_per_prefill_token = 0.015 * 1e-3
    usd_per_decode_token = 0.06 * 1e-3
    price = n_prefill * usd_per_prefill_token + n_decode * usd_per_decode_token
    exchange_rate = 7.0  # multiplier used to convert the USD price to RMB
    print(f"Price: {price:.2f} USD, {price * exchange_rate:.2f} RMB.")