import argparse
import concurrent.futures
import os
import random
import time
from datetime import datetime

from tqdm import tqdm

from src import prompt, utils


def get_args(argv=None):
    """Parse the command-line options of the split script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing a list allows
            programmatic use and testing.

    Returns:
        argparse.Namespace with ``batch_size_list`` (str, comma-separated),
        ``language`` ("zh" or "en"), ``input_dir`` (str) and ``sample_num``
        (int).
    """
    parser = argparse.ArgumentParser(description="Split Script")
    parser.add_argument(
        "--batch_size_list",
        type=str,
        default="1,2,3,5,10,15",
        help="Comma-separated list of batch sizes; <input_dir>/<batch_size>.jsonl is processed for each.",
    )

    parser.add_argument(
        "--language",
        type=str,
        default="en",
        choices=["zh", "en"],
        help="Language for the prompts.",
    )

    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory containing the generated-text files, one <batch_size>.jsonl per batch size.",
    )
    parser.add_argument(
        "--sample_num",
        type=int,
        required=False,
        default=-1,
        help="Number of samples to process per file; -1 (default) or any non-positive value processes all.",
    )

    return parser.parse_args(argv)


def batch_split_text(
    item_list,
    model:str,
    api_dict:dict=None,
    batch_size=1,
    temperature=0.5,
    max_tokens=8192*2,
    max_workers=500,
    language="en",
    ):
    """Run ``get_split_text`` on every item concurrently in a thread pool.

    Args:
        item_list: Items to split; each is passed to ``get_split_text`` as
            ``item``.
        model, api_dict, batch_size, temperature, max_tokens, language:
            Forwarded unchanged to every ``get_split_text`` call.
        max_workers: Maximum number of worker threads.

    Returns:
        List of the dicts returned by ``get_split_text``, in completion
        order (NOT input order).

    Raises:
        Whatever exception a ``get_split_text`` call raised, re-raised when
        its result is collected.
    """
    results = []
    # The context manager shuts the pool down (joining its worker threads)
    # once every result is collected, so repeated calls do not leak threads.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                get_split_text,
                item=item,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                api_dict=api_dict,
                batch_size=batch_size,
                language=language,
            )
            for item in item_list
        ]
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            # Fetch the processed result for a single item.
            results.append(future.result())
    return results
def check_api_split(message,output,batch_size):
    """Score a model reply that should split ``message`` into parts.

    Returns:
        -1 when ``output`` does not split into exactly ``batch_size`` parts;
        101 when the part count matches and ``len(output) * 1.5`` exceeds
        ``len(message)``;
        otherwise ``len(output) / 10000000``, a small fractional score.
    """
    part_count = len(utils.split_string_by_separator(output))
    if part_count != batch_size:
        return -1
    long_enough = len(output) * 1.5 > len(message)
    return 101 if long_enough else len(output) / 10000000
def get_split_text(
    item,
    model:str,
    api_dict:dict=None,
    batch_size=1,
    temperature=0.5,
    max_tokens=8192,
    language="en"):
    """Split one item's generated text into ``batch_size`` parts via the chat model.

    With ``batch_size == 1`` the generated text is used as-is and no prompt
    is built or API call made. Otherwise the split prompt is sent up to three
    times. A reply is accepted once it divides into exactly ``batch_size``
    parts (per ``utils.split_string_by_separator``) and ``len(reply) * 1.2``
    exceeds the source length. The loop gives up early on a ``"$ERROR$"``
    reply (presumably the error marker of ``chat_completion_deepseek``) or
    when the source text is longer than 20000 characters.

    Args:
        item: Dict with at least a ``'generated_text'`` string.
        model, api_dict, temperature, max_tokens: Passed to
            ``utils.chat_completion_deepseek``.
        batch_size: Expected number of parts.
        language: Prompt language passed to ``prompt.get_split_prompt``.

    Returns:
        A shallow copy of ``item`` with ``'split_output'`` set to the last
        model reply. That reply may be a rejected one if every attempt
        failed.
    """
    source_text = item['generated_text']
    if batch_size == 1:
        # A single-part "split" is the generated text itself.
        return {**item, 'split_output': source_text}

    message = prompt.get_split_prompt(source_text, language=language)
    max_attempts = 3
    output = ""
    for attempt in range(1, max_attempts + 1):
        output = utils.chat_completion_deepseek(
            model,
            message,
            temperature,
            max_tokens,
            api_dict,
        )
        # Check the error sentinel first so it is never mistaken for a result.
        if output == "$ERROR$":
            print("Split failed generate error")
            break
        if len(output) * 1.2 > len(source_text) and len(utils.split_string_by_separator(output)) == batch_size:
            break
        if len(source_text) > 20000:
            # Long inputs are not retried (each retry is another costly call).
            print("Split failed: source text longer than 20000 chars, not retrying")
            break
        if attempt < max_attempts:
            print("Split failed, retrying... ")
            time.sleep(1)
        else:
            print("Split failed after all retries")
    return {**item, 'split_output': output}

def get_out_put_path(in_path,new_section="split"):
    """Map an inference-output file path to the path for the split results.

    Every occurrence of ``"vllm_inference"`` in ``in_path`` is replaced by
    ``new_section``: ``"runs/vllm_inference/2.jsonl"`` ->
    ``"runs/split/2.jsonl"``. If the marker is absent,
    ``_<new_section>`` is appended to the file stem instead
    (``"runs/2.jsonl"`` -> ``"runs/2_split.jsonl"``). The returned path
    therefore differs from ``in_path`` (unless ``new_section`` is itself
    ``"vllm_inference"``).
    """
    marker = "vllm_inference"
    if marker in in_path:
        return in_path.replace(marker, new_section)
    # A bare replace() would return in_path unchanged here, and the caller
    # would then overwrite (and possibly truncate) its own input file.
    stem, ext = os.path.splitext(in_path)
    return f"{stem}_{new_section}{ext}"


def gen_split_text(dir_path='',batch_size=2,sample_num=-1,language="en",model='',api_key='',base_url=''):
    """Split every item in ``<dir_path>/<batch_size>.jsonl`` and save the results.

    The output path comes from ``get_out_put_path`` and is printed before
    processing starts.

    Args:
        dir_path: Directory holding the ``<batch_size>.jsonl`` input files.
        batch_size: Expected number of parts per generated text; also
            selects the input file.
        sample_num: If positive, only the first ``sample_num`` items are
            processed; otherwise all are.
        language: Prompt language forwarded to ``batch_split_text``.
        model, api_key, base_url: Chat-model settings forwarded to
            ``batch_split_text``. They default to empty strings, matching
            the previous hard-coded placeholders.

    Returns:
        The list of result dicts written to the output file, or ``[]`` if
        the input file could not be read. In that case nothing is written.
    """
    file_path=f"{dir_path}/{batch_size}.jsonl"
    ok,data=utils.read_jsonl(file_path)
    if not ok:
        print("Error reading file:", file_path)
        # Skip this batch size rather than sending unusable data to the API
        # or writing a bogus output file.
        return []
    out_path=get_out_put_path(file_path)
    print(out_path)
    if sample_num>0:
        data=data[:sample_num]

    ret=batch_split_text(
        item_list=data,
        model=model,
        api_dict={
            'api_key':api_key,
            'base_url':base_url,
        },
        batch_size=batch_size,
        language=language,
    )
    utils.write_json_list(ret,out_path)
    return ret


if __name__ == '__main__':
    args=get_args()
    print(args)
    # Ignore empty entries so inputs like "1,2," do not crash int().
    batch_sizes=[int(x) for x in args.batch_size_list.split(',') if x.strip()]
    print(f"Batch size list: {batch_sizes}")
    for batch_size in batch_sizes:
        # Forward --language; it was previously parsed but never used.
        gen_split_text(args.input_dir,batch_size,args.sample_num,language=args.language)

