import argparse

from transformers import AutoTokenizer
from utils.utils import load_jsonl

from utils.build_prompt import build_prompt
from utils.utils import CodexTokenizer, CodeGenTokenizer, StarCoderTokenizer,DeepseekCoderTokenizer

def parser_args():
    """Parse the script's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` carrying ``input_file_name``, ``model``,
        ``mode``, ``max_top_k`` and ``max_new_tokens``.
    """
    cli = argparse.ArgumentParser(description="Generate response from llm")

    # Options are registered in this order so the generated help text keeps
    # the same layout.
    option_specs = (
        ('--input_file_name', {'default': 'api_level', 'type': str}),
        ('--model', {'default': 'gpt-3.5-turbo-instruct', 'type': str}),
        ('--mode', {'default': 'retrieval', 'type': str, 'choices': ['infile', 'retrieval']}),
        ('--max_top_k', {'default': 10, 'type': int}),
        ('--max_new_tokens', {'default': 100, 'type': int}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def _load_tokenizer(model):
    """Return ``(tokenizer, max_num_tokens)`` for the given model name.

    Args:
        model: Model identifier as passed via ``--model``.

    Returns:
        A tuple of the project tokenizer wrapper and the total token budget
        used for that model (2048 for every model handled here).

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    if model == 'gpt-3.5-turbo-instruct':
        return CodexTokenizer(), 2048

    if model == 'deepseek-coder':
        wrapper_cls = DeepseekCoderTokenizer
    elif model in ('CodeLlama', 'codegen2-7b', 'codegen2-1b'):
        # CodeLlama is wrapped with CodeGenTokenizer, same as the codegen2 models.
        wrapper_cls = CodeGenTokenizer
    else:
        # Fail with a clear message instead of hitting an unbound local later.
        raise ValueError(
            f"Unsupported model: {model!r}. Expected one of: "
            "'gpt-3.5-turbo-instruct', 'deepseek-coder', 'CodeLlama', "
            "'codegen2-7b', 'codegen2-1b'."
        )

    tokenizer_raw = AutoTokenizer.from_pretrained(f"./models_cache/{model}-tokenizer/", trust_remote_code=True)
    return wrapper_cls(tokenizer_raw), 2048


def main(args, input_cases):
    """Print (and return) the total number of prompt tokens over all cases.

    For each case a prompt is built with ``build_prompt`` and tokenized; its
    length is capped at ``max_num_tokens - args.max_new_tokens`` before being
    added to the running total.

    Args:
        args: Namespace providing ``model``, ``max_new_tokens``, ``max_top_k``
            and ``mode``.
        input_cases: Iterable of cases, each passed unchanged to
            ``build_prompt``.

    Returns:
        The summed (capped) token count, which is also printed.

    Raises:
        ValueError: If ``args.model`` is not a supported model name.
    """
    tokenizer, max_num_tokens = _load_tokenizer(args.model)

    # Budget left for the prompt once room for generated tokens is reserved.
    max_prompt_tokens = max_num_tokens - args.max_new_tokens

    total_tokens = 0
    for case in input_cases:
        prompt = build_prompt(case, tokenizer, max_prompt_tokens, max_top_k=args.max_top_k, mode=args.mode)
        # Count at most the prompt budget per case, even if the built prompt
        # tokenizes to something longer.
        total_tokens += min(len(tokenizer.tokenize(prompt)), max_prompt_tokens)

    print(f"input Total tokens: {total_tokens}")
    return total_tokens

if __name__ == "__main__":
    # Script entry point: parse the CLI options, then count tokens over the
    # search results stored under ./search_results/ for the chosen input name.
    args = parser_args()
    search_results_path = f"./search_results/{args.input_file_name}.search_res.jsonl"
    main(args, load_jsonl(search_results_path))

# Example usage:
#   python compute_input_token.py --input_file_name line_level.python.test --model codegen2-7b --max_new_tokens 100