import os
import json
import argparse
from pathlib import Path
from typing import List, Dict, Any, Tuple
from transformers import AutoTokenizer
from src import utils,loader
from datasets import load_dataset
from datetime import datetime
def load_data():
    """Load the ``train`` split of ``math_distill_dataset_easy`` as a list of rows.

    Every column except ``id``, ``question`` and ``correct_answer`` is removed
    before conversion. The raw split, the trimmed split and the final row count
    are printed for visibility.

    Returns:
        The trimmed split converted with ``Dataset.to_list()``, one entry per row.
    """
    train_split = load_dataset("math_distill_dataset_easy")['train']
    print(train_split)

    # Columns to retain; anything else in the split is dropped.
    keep = ['id', 'question', 'correct_answer']
    drop = [name for name in train_split.column_names if name not in keep]
    trimmed = train_split.remove_columns(drop)
    print(trimmed)

    rows = trimmed.to_list()
    print(f"list_pt len:{len(rows)}")
    return rows

# English template for a multi-question prompt. It is rendered with
# str.format(numbered_questions=...), so literal braces are doubled
# ("\\boxed{{}}" renders as "\boxed{}").
en_batch_prompt='''
Please answer the following math problems in order, insert the separator <###> between each solution, and place the final answer for each question in \\boxed{{}}.

Your response should follow the format below, Please don't include question numbers:
<###> Detailed solution for problem 1...\n\\boxed{{answer1}}
<###> Detailed solution for problem 2...\n\\boxed{{answer2}}
<###> Detailed solution for problem 3...\n\\boxed{{answer3}}

Here is the list of questions:
{numbered_questions}
'''


def batch_prompt(questions: list[str], language='en'):
    """Render ``en_batch_prompt`` with *questions* as a 1-based numbered list.

    Each question appears on its own line as ``"<n>. <question>"``. The
    ``language`` argument is accepted but currently ignored; the English
    template is always used.
    """
    numbered = "\n".join(
        f"{number}. {text}" for number, text in enumerate(questions, start=1)
    )
    return en_batch_prompt.format(numbered_questions=numbered)

def build_batch_questions(question_list, batch_size=2):
    """Group question rows into consecutive batches and build one prompt per batch.

    Args:
        question_list: Sliceable sequence of rows, each indexable by ``'id'``,
            ``'question'`` and ``'correct_answer'``.
        batch_size: Maximum number of questions per prompt. The final batch may
            be shorter. Must be >= 1.

    Returns:
        A list of ``loader.BatchPrompt``, one per batch, in input order.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # A zero step would crash range() with an opaque message. A negative
        # step would silently yield no batches, leaving nothing to evaluate.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    questions = []
    for start in range(0, len(question_list), batch_size):
        batch_pt = question_list[start:start + batch_size]
        origin_question_list = [pt['question'] for pt in batch_pt]
        questions.append(loader.BatchPrompt(
            id_list=[pt['id'] for pt in batch_pt],
            origin_question_list=origin_question_list,
            origin_correct_answer_list=[pt['correct_answer'] for pt in batch_pt],
            prompt=batch_prompt(origin_question_list),
            text='',
        ))
    print(f"build_batch_questions batch_size:{batch_size} questions len:{len(questions)}")
    return questions

def get_batch_infer_answer(questions: list[loader.BatchPrompt], model: str = ' ') -> list[loader.BatchOutput]:
    """Fetch model answers for every batched prompt in one API call.

    Args:
        questions: Batched prompts, as produced by ``build_batch_questions``.
        model: Model identifier forwarded to
            ``utils.multi_chat_completion_api``. The default ``' '`` matches
            the value that used to be hard-coded here.

    Returns:
        One ``loader.BatchOutput`` per input prompt, in the same order. Each
        carries the prompt's ids, questions and reference answers plus the
        generated text.

    Raises:
        RuntimeError: if the API returns a different number of responses
            than prompts were sent.
    """
    prompts = [q.prompt for q in questions]
    responses = utils.multi_chat_completion_api(prompts, model=model)
    # Responses are paired with prompts by position. A count mismatch would
    # silently misalign or drop results, so fail loudly instead.
    if len(responses) != len(prompts):
        raise RuntimeError(
            f"expected {len(prompts)} responses from the API, got {len(responses)}"
        )
    return [
        loader.BatchOutput(
            id_list=question.id_list,
            origin_question_list=question.origin_question_list,
            origin_correct_answer_list=question.origin_correct_answer_list,
            prompt=question.prompt,
            generated_text=response,
            origin_prompt=question.prompt,
        )
        for question, response in zip(questions, responses)
    ]


def _score_outputs(batch_outputs):
    """Score generated answers; return (correct, evaluated, skipped_batches).

    A batch is skipped, and excluded from ``evaluated``, when splitting its
    generated text does not yield exactly one answer per reference answer.
    """
    correct_num = 0
    evaluated = 0
    skipped = 0
    for item in batch_outputs:
        answers = utils.split_string_by_separator(item.generated_text)
        expected = item.origin_correct_answer_list
        if len(answers) != len(expected):
            skipped += 1
            print(f"ERROR! ids={item.id_list}: expected {len(expected)} answers, got {len(answers)}")
            continue
        for answer, reference in zip(answers, expected):
            if utils.is_answer_correct(answer, reference):
                correct_num += 1
            evaluated += 1
    return correct_num, evaluated, skipped


def main():
    """Run the batched distillation pipeline end to end.

    The steps are:
      1. Load the dataset and group it into batched prompts.
      2. Query the model.
      3. Write the raw outputs as JSONL under ``--output_dir``.
      4. Print ``correct evaluated accuracy``.
    """
    parser = argparse.ArgumentParser(description="Model Distillation Pipeline")
    parser.add_argument("--model", type=str, required=False, default=' ',
                        help="Path to the model directory or file")
    parser.add_argument("--batch_size", type=int, required=False, default=3,
                        help="Number of questions packed into each prompt")
    parser.add_argument("--output_dir", type=str, default="/out_put/distillation_batch_results",
                        help="Output directory for results")
    args = parser.parse_args()
    print(args)

    data = load_data()
    questions = build_batch_questions(data, args.batch_size)
    # Pass the requested model through, so the results file label matches
    # the model that was actually queried.
    batch_outputs = get_batch_infer_answer(questions, model=args.model)
    json_list = [x.__json__() for x in batch_outputs]

    date_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    # --model may be a filesystem path. Keep only its last component so the
    # file lands directly in output_dir. A blank model name becomes "model"
    # rather than producing a filename with a leading space.
    model_tag = Path(args.model.strip()).name or "model"
    file_name = f"{model_tag}_{args.batch_size}_{date_time}.jsonl"
    os.makedirs(args.output_dir, exist_ok=True)
    utils.write_json_list(json_list, os.path.join(args.output_dir, file_name))

    correct_num, tt, skipped = _score_outputs(batch_outputs)
    # Guard against division by zero when no batch could be scored.
    accuracy = correct_num / tt if tt else 0.0
    print(correct_num, tt, accuracy)
    if skipped:
        print(f"skipped {skipped} batch(es) whose answer count did not match")


if __name__ == "__main__":
    main()