from pycountry import languages
from src import prompt
from src  import utils
from src  import loader
import concurrent.futures
from tqdm import tqdm
import argparse

def InferenceAll(
    question_list:list[loader.Question],
    model:str,
    temperature:float,
    max_tokens:int,
    api_dict:dict=None,
    language:str="en",
    ):
    """
    Send one group of questions to the model in a single request.

    A single prompt is built for the whole group with ``getPrompt`` and
    passed to ``utils.chat_completion_deepseek_with_cot`` together with the
    sampling settings and ``api_dict``.

    Returns a dict with the question ids, the reference answers, the model
    name, the prompt that was sent, the two values returned by the
    completion helper (stored as ``generated_text`` and ``reason_content``),
    and the ``temperature`` / ``max_tokens`` used.
    """
    batch_prompt = getPrompt(question_list, language)

    # The helper returns a pair: the generated text and a second value that
    # is stored as 'reason_content' (presumably the model's reasoning trace
    # -- confirm against utils).
    generated_text, reasoning = utils.chat_completion_deepseek_with_cot(
        model,
        batch_prompt,
        temperature,
        max_tokens,
        api_dict,
    )

    question_ids = [question.id for question in question_list]
    reference_answers = [question.correct_answer for question in question_list]

    record = {
        'id_list': question_ids,
        'origin_correct_answer_list': reference_answers,
        'model': model,
        'prompt': batch_prompt,
        'generated_text': generated_text,
        'reason_content': reasoning,
        'temperature': temperature,
        'max_tokens': max_tokens,
    }
    return record

def InferenceWithBatch(
    question_list:list[loader.Question],
    batch_size:int,
    model:str,
    api_dict:dict=None,
    temperature=0.5,
    max_tokens=8192,
    max_workers=50,
    language="en",
    ):
    """
    Run ``InferenceAll`` over ``question_list`` in concurrent batches.

    The questions are split into consecutive slices of at most
    ``batch_size`` items; each slice is submitted to a thread pool with up
    to ``max_workers`` threads, and a tqdm progress bar tracks completion.

    Returns:
        A list with one ``InferenceAll`` result dict per batch, in the order
        the batches finish (not necessarily the order they were submitted).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        Exception: anything raised inside a batch is re-raised here by
            ``future.result()``; the pool is still shut down, which waits
            for batches that are already running or queued.
    """
    if batch_size < 1:
        # A negative step would silently yield zero batches; fail loudly.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    batches=[
        question_list[i:i+batch_size]
        for i in range(0,len(question_list),batch_size)
    ]
    results = []
    # The context manager shuts the pool down on exit, so worker threads are
    # released even when one of the batches raises.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                InferenceAll,
                question_list=batch,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                api_dict=api_dict,
                language=language,
            )
            for batch in batches
        ]
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(batches)):
            # Collect the result of a single finished batch.
            results.append(future.result())
    return results
    
def getPrompt(questions:list[loader.Question],language='en'):
    """
    Build the batched chat prompt for a group of questions.

    Collects the ``origin_question`` of every item and hands the resulting
    list, together with ``language``, to ``prompt.batch_chat_prompt``; its
    return value is returned unchanged.
    """
    question_texts = []
    for question in questions:
        question_texts.append(question.origin_question)
    return prompt.batch_chat_prompt(question_texts, language)

def getAnswer(question_list:list[loader.Question],batch_size,out_put_dir,max_tokens=8192,language="en"):
    """
    Run batched inference with 'deepseek-reasoner' and save the results.

    The questions are processed by ``InferenceWithBatch`` using the given
    ``batch_size``, ``max_tokens`` and ``language``; the list of per-batch
    results is then written with ``utils.write_json_list`` to
    ``<out_put_dir>/<batch_size>.jsonl``.

    The API key is taken from the ``DEEPSEEK_API_KEY`` environment variable.
    When the variable is not set an empty string is used, which matches the
    previously hard-coded value.
    """
    import os  # local import: only needed to look up the API key

    # Avoid a hard-coded (and empty) credential in the source file.
    api_key=os.environ.get('DEEPSEEK_API_KEY','')
    result=InferenceWithBatch(
        question_list=question_list,
        batch_size=batch_size,
        model='deepseek-reasoner',
        max_tokens=max_tokens,
        api_dict={
            'api_key':api_key
        },
        language=language
    )
    utils.write_json_list(result,f"{out_put_dir}/{batch_size}.jsonl")


if __name__=='__main__':
    # Script entry point: load the questions from '100_cot.jsonl', then run
    # getAnswer once for every batch size given on the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size_list', type=str, default='2', help='Batch size for processing questions')
    # getAnswer requires an output directory; expose it as an option.
    parser.add_argument('--output_dir', type=str, default='.', help='Directory where <batch_size>.jsonl result files are written')
    args=parser.parse_args()
    print(f"args:{args}")
    ok,origin_data=utils.read_jsonl('100_cot.jsonl')
    if not ok:
        print('ReadJsonl failed')
        # SystemExit works even when the site-provided exit() is unavailable.
        raise SystemExit(1)
    # Question lives in the loader module; the record's position in the file
    # is used as its id.
    questionList=[
        loader.Question(
            id=index,
            origin_question=item['prompt'],
            correct_answer=item['correct_ans'],
        )
        for index,item in enumerate(origin_data)
    ]
    # Comma-separated list of sizes; empty entries (e.g. a trailing comma)
    # are skipped.
    batch_size_list=[int(item) for item in args.batch_size_list.split(",") if item.strip()]
    for batch_size in batch_size_list:
        getAnswer(questionList,batch_size,args.output_dir)
        print(f"finish questionList num: {len(questionList)} batch_size:{batch_size}")
        