



import argparse 
import json 
import sys 
from pathlib import Path 
import re 
from tqdm import tqdm 
import ray 
import os 

# System prompt used when --system-prompt is not given. It tells the model to
# wrap its reasoning in <think> tags and its final answer in <answer> tags,
# which is the format parse_final_output() extracts from.
DEFAULT_SYSTEM_PROMPT = (
    "You are a highly capable assistant. "
    "Your task is to estimate the output of the given Python code. "
    "The reasoning process and output are enclosed within "
    "<think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> output here </answer>"
)



def parse_final_output(text: str) -> str:
    """Extract the answer text from a model/reference completion.

    Looks for the first ``<answer> ... </answer>`` block (the match may span
    newlines) and returns its contents with surrounding whitespace removed.
    Returns an empty string when no complete ``<answer>`` block exists.
    """
    found = re.search(r"<answer>\s*(.*?)\s*</answer>", text, flags=re.DOTALL)
    return "" if found is None else found.group(1).strip()

# Detects whether a complete <answer>...</answer> block is present. It matches
# exactly when parse_final_output's pattern matches, so it can tell
# "no answer block" apart from "answer block with empty contents".
_ANSWER_BLOCK_RE = re.compile(r"<answer>.*?</answer>", re.DOTALL)


def _error_record(line_idx: int, data_item, message: str) -> dict:
    """Build the result dict reported for an item that could not be scored.

    This is called from error paths, so prompt extraction must never raise:
    any malformed ``data_item`` (non-dict, missing/empty ``input`` list,
    non-dict message) falls back to the ``"ERROR_NO_PROMPT_DATA"`` placeholder
    instead of crashing the whole batch.
    """
    try:
        prompt = data_item["input"][0]["content"]
    except (KeyError, IndexError, TypeError):
        prompt = "ERROR_NO_PROMPT_DATA"
    return {
        "line_index": line_idx,
        "error": message,
        "prompt": prompt,
        "model_output": "",
        "expected_final_output": "",
        "model_final_output": "",
        "is_correct": False,
        "is_truncated_parse_error": True,
    }


@ray.remote
class VLLMInferenceActor:
    """Ray actor that owns one vLLM engine and scores batches of JSONL items.

    Each actor loads a Hugging Face tokenizer (used to render the chat
    template) and a ``vllm.LLM`` engine. ``process_batch`` then turns
    ``(line_index, item)`` pairs into per-item result dicts.
    """

    def __init__(self, model_path: str, tokenizer_path: str, tensor_parallel_size: int, system_prompt: str, llm_config_extras: dict = None):
        """Load the tokenizer and the vLLM engine for this actor.

        Args:
            model_path: Model checkpoint passed to ``vllm.LLM`` as ``model``.
            tokenizer_path: Tokenizer used both for the chat template and, via
                ``tokenizer=``, by vLLM itself, so the prompt is rendered and
                tokenized by the same tokenizer.
            tensor_parallel_size: Forwarded to ``vllm.LLM``.
            system_prompt: System message prepended to every conversation.
            llm_config_extras: Extra ``vllm.LLM`` keyword arguments. These are
                applied last, so they override the defaults above.
        """
        # Imported inside the actor so the driver process never needs to load them.
        from transformers import AutoTokenizer
        from vllm import LLM

        actor_pid = os.getpid()
        print(f"Actor (PID {actor_pid}) loading tokenizer from {tokenizer_path} ...", file=sys.stderr)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        print(f"Actor (PID {actor_pid}) loading VLLM LLM model from {model_path} (TP={tensor_parallel_size})...", file=sys.stderr)
        effective_llm_config = {
            "model": model_path,
            "tokenizer": tokenizer_path,
            "tensor_parallel_size": tensor_parallel_size,
        }
        if llm_config_extras:
            effective_llm_config.update(llm_config_extras)

        self.llm = LLM(**effective_llm_config)
        self.system_prompt = system_prompt

    def process_batch(self, batch_input_items: list[tuple[int, dict]], sampling_params_obj):
        """Generate completions for one batch and score each item.

        Args:
            batch_input_items: ``(line_index, item)`` pairs. Each item is
                expected to provide ``item["input"][0]["content"]`` (the user
                prompt) and ``item["output"]["content"]`` (the reference
                completion containing ``<answer>`` tags).
            sampling_params_obj: Sampling parameters forwarded to
                ``self.llm.generate``.

        Returns:
            One result dict per input item, in input order. Items that could
            not be turned into a prompt carry an ``"error"`` key. An item is
            counted correct only when the model output contains an
            ``<answer>`` block whose contents equal the reference answer.
        """
        results = [None] * len(batch_input_items)
        prompts = []
        pending = []  # metadata for items whose prompt made it into `prompts`

        for batch_pos, (line_idx, data_item) in enumerate(batch_input_items):
            try:
                user_prompt = data_item["input"][0]["content"]
                assistant_content = data_item["output"]["content"]
                expected_final_output = parse_final_output(assistant_content)
                messages = [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": user_prompt},
                ]
                # NOTE(review): the rendered template may already include a BOS
                # token; confirm vLLM does not prepend a second one for this tokenizer.
                prompt_text = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except (KeyError, IndexError, TypeError):
                results[batch_pos] = _error_record(line_idx, data_item, "Missing input or output structure")
                continue

            prompts.append(prompt_text)
            pending.append({
                "original_batch_idx": batch_pos,
                "line_index": line_idx,
                "prompt_text": prompt_text,
                "expected_final_output": expected_final_output,
            })

        if prompts:
            request_outputs = self.llm.generate(
                prompts,
                sampling_params=sampling_params_obj,
                use_tqdm=False,
            )
            # Outputs come back in prompt order; zip pairs each with its metadata.
            for meta, request_output in zip(pending, request_outputs):
                model_output_text = request_output.outputs[0].text
                model_final_output = parse_final_output(model_output_text)

                is_truncated_flag = not model_final_output
                # Without an <answer> block the parsed answer is "", which would
                # spuriously equal an unparseable reference (also "").
                has_answer_block = _ANSWER_BLOCK_RE.search(model_output_text) is not None
                is_correct_flag = has_answer_block and meta["expected_final_output"] == model_final_output

                results[meta["original_batch_idx"]] = {
                    "line_index": meta["line_index"],
                    "prompt": meta["prompt_text"],
                    "model_output": model_output_text,
                    "expected_final_output": meta["expected_final_output"],
                    "model_final_output": model_final_output,
                    "is_correct": is_correct_flag,
                    "is_truncated_parse_error": is_truncated_flag,
                }

        # Any slot still empty (e.g. generate returned fewer outputs) gets an error record.
        for batch_pos, record in enumerate(results):
            if record is None:
                line_idx, data_item = batch_input_items[batch_pos]
                results[batch_pos] = _error_record(line_idx, data_item, "Unknown processing error in batch")

        return results

def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Batch inference & evaluation on a JSONL file (using VLLM and Ray).")
    parser.add_argument("--model-path", type=str, required=True,
                        help="Path to the model checkpoint directory (Hugging Face format).")
    parser.add_argument("--tokenizer-path", type=str, required=True,
                        help="Path to the tokenizer directory (Hugging Face format).")
    parser.add_argument("--jsonl-file", type=str, required=True,
                        help="Path to the input JSONL file.")
    parser.add_argument("--output-file", type=str, default=None,
                        help="Path to write inference results (JSONL). If not set, prints to stdout.")
    parser.add_argument("--max-new-tokens", type=int, default=512,
                        help="Maximum number of new tokens for generation.")
    parser.add_argument("--system-prompt", type=str, default=None,
                        help="System-prompt text itself. If omitted, the built-in default system prompt is used.")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature.")
    parser.add_argument("--top-p", type=float, default=0.95,
                        help="Top-p sampling parameter.")
    parser.add_argument("--no-sample", action="store_true",
                        help="If set, disables sampling (greedy).")
    parser.add_argument("--tensor-parallel-size", type=int, default=1,
                        help="Tensor parallel size for VLLM (e.g., 1, 2, 4, etc.). Each actor will use this many GPUs.")
    parser.add_argument("--num-workers", type=int, default=1,
                        help="Number of Ray worker actors to create.")
    parser.add_argument("--batch-size", type=int, default=1,
                        help="Number of prompts to batch together for inference per actor call.")
    parser.add_argument("--gpu-memory-utilization", type=float, default=0.9,
                        help="GPU memory utilization for VLLM.")
    parser.add_argument("--use-guided-decoding", action="store_true", default=False,
                        help="If set, enables guided decoding with a predefined regex pattern.")
    parser.add_argument("--dtype", type=str, default="auto",
                        help="dtype for VLLM (e.g., float16, bfloat16, auto).")
    return parser.parse_args()


def _build_sampling_params(args: argparse.Namespace):
    """Translate the CLI sampling options into a ``vllm.SamplingParams``."""
    from vllm import SamplingParams

    if args.no_sample:
        # Greedy decoding: zero temperature, top-p filtering disabled.
        temperature, top_p = 0.0, 1.0
    else:
        temperature, top_p = args.temperature, args.top_p

    sampling_kwargs = {
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": args.max_new_tokens,
    }
    if args.use_guided_decoding:
        # Imported only here so GuidedDecodingParams is required only when the flag is set.
        from vllm.sampling_params import GuidedDecodingParams

        print("Guided decoding is enabled.", file=sys.stderr)
        sampling_kwargs["guided_decoding"] = GuidedDecodingParams(regex=r"<think>\n0 Statement(?s).*?")
    else:
        print("Guided decoding is disabled.", file=sys.stderr)
    return SamplingParams(**sampling_kwargs)


def _load_jsonl_items(jsonl_path: str) -> list:
    """Return ``(line_number, parsed_object)`` pairs from a JSONL file.

    Line numbers are 1-based. Blank lines are skipped silently; lines that
    are not valid JSON are skipped with a warning on stderr.
    """
    items = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line_idx, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                items.append((line_idx, json.loads(line)))
            except json.JSONDecodeError:
                print(f"[Warning] Line {line_idx} could not be parsed as JSON. Skipping.", file=sys.stderr)
    return items


def _launch_actors(args: argparse.Namespace, system_prompt_text: str, sampling_params) -> list:
    """Create ``args.num_workers`` inference actors, each reserving ``tensor_parallel_size`` GPUs."""
    llm_config_extras_arg = {
        "gpu_memory_utilization": args.gpu_memory_utilization,
        "dtype": args.dtype,
    }

    print(f"Creating {args.num_workers} Ray worker actors...", file=sys.stderr)
    actor_cls = VLLMInferenceActor.options(num_gpus=args.tensor_parallel_size)
    actors = [
        actor_cls.remote(
            args.model_path, args.tokenizer_path, args.tensor_parallel_size, system_prompt_text, llm_config_extras_arg
        )
        for _ in range(args.num_workers)
    ]

    try:
        # Ping each actor with an empty batch. Given the very short timeout this
        # usually just reports that the models are still loading.
        ray.get([actor.process_batch.remote([], sampling_params) for actor in actors], timeout=0.2)
    except ray.exceptions.GetTimeoutError:
        print("Actor initialization check timed out (actors might still be loading models).", file=sys.stderr)
    except Exception as e:
        print(f"Error during actor readiness check: {e}", file=sys.stderr)
    print(f"All {args.num_workers} actors launched.", file=sys.stderr)
    return actors


def _submit_batches(actors: list, items: list, batch_size: int, sampling_params) -> list:
    """Split ``items`` into consecutive batches and dispatch them round-robin to ``actors``.

    Returns ``(object_ref, batch_len)`` pairs in submission order.
    """
    submitted = []
    for batch_no, start in enumerate(range(0, len(items), batch_size)):
        batch = items[start:start + batch_size]
        actor = actors[batch_no % len(actors)]
        submitted.append((actor.process_batch.remote(batch, sampling_params), len(batch)))
    return submitted


def _collect_and_report(args: argparse.Namespace, submitted: list, total_items: int) -> None:
    """Gather batch results in submission order, write JSONL records, and print metrics."""
    num_total_processed_items = 0
    num_correct = 0
    num_truncated_parse_errors = 0
    num_lost_items = 0

    out_f = Path(args.output_file).open("w", encoding="utf-8") if args.output_file else sys.stdout
    try:
        desc = f"Processing prompts ({args.num_workers} workers, batch sz {args.batch_size})"
        with tqdm(total=total_items, desc=desc) as pbar:
            for future, batch_len in submitted:
                try:
                    batch_of_results = ray.get(future)
                except ray.exceptions.RayTaskError as e:
                    print(f"A Ray task (batch) failed: {e}", file=sys.stderr)
                    num_lost_items += batch_len
                    pbar.update(batch_len)
                    continue
                except Exception as e:
                    print(f"An error occurred while processing a result batch: {e}", file=sys.stderr)
                    num_lost_items += batch_len
                    pbar.update(batch_len)
                    continue

                for result in batch_of_results:
                    num_total_processed_items += 1
                    if result.get("error"):
                        print(f"[Warning] Line {result['line_index']}: {result['error']}", file=sys.stderr)
                    if result.get("is_correct"):
                        num_correct += 1
                    if result.get("is_truncated_parse_error"):
                        num_truncated_parse_errors += 1

                    result_entry = {
                        "line_index": result["line_index"],
                        "prompt": result.get("prompt", ""),
                        "model_output": result.get("model_output", ""),
                        "expected_final_output": result.get("expected_final_output", ""),
                        "model_final_output": result.get("model_final_output", ""),
                        "is_correct": result.get("is_correct", False),
                    }
                    out_f.write(json.dumps(result_entry, ensure_ascii=False) + "\n")
                    if args.output_file:
                        # Keep the file current in case the run is interrupted.
                        out_f.flush()
                    pbar.update(1)
    finally:
        if out_f is not sys.stdout:
            out_f.close()

    if args.output_file:
        print(f"Inference results written to: {args.output_file}", file=sys.stderr)

    if num_lost_items:
        print(f"[Warning] {num_lost_items} item(s) were lost to failed batches and are excluded from the metrics.", file=sys.stderr)

    if num_total_processed_items > 0:
        accuracy = num_correct / num_total_processed_items
        truncation_rate = num_truncated_parse_errors / num_total_processed_items
        print(f"Total items processed: {num_total_processed_items}")
        print(f"Correct: {num_correct}/{num_total_processed_items} = {accuracy:.4f}")
        print(f"Truncated (model output <answer> unparseable): {num_truncated_parse_errors}/{num_total_processed_items} = {truncation_rate:.4f}")
    else:
        print("No results processed.")


def main():
    """CLI entry point: run batched vLLM inference over a JSONL file on Ray actors and score it.

    Exits with status 1 on invalid ``--batch-size`` / ``--num-workers``.
    """
    args = _parse_args()

    # Validate before touching Ray or GPUs. A non-positive worker count would
    # otherwise fail later with a ZeroDivisionError in the round-robin dispatch.
    if args.batch_size <= 0:
        print("Error: --batch-size must be a positive integer.", file=sys.stderr)
        sys.exit(1)
    if args.num_workers <= 0:
        print("Error: --num-workers must be a positive integer.", file=sys.stderr)
        sys.exit(1)

    sampling_params = _build_sampling_params(args)
    system_prompt_text = args.system_prompt if args.system_prompt else DEFAULT_SYSTEM_PROMPT

    # Read the input before launching actors so a bad/empty file does not cost a model load.
    items_to_process = _load_jsonl_items(args.jsonl_file)
    if not items_to_process:
        print("No items to process from JSONL file.", file=sys.stderr)
        return

    print("Initializing Ray...", file=sys.stderr)
    ray.init(ignore_reinit_error=True)
    try:
        actors = _launch_actors(args, system_prompt_text, sampling_params)
        submitted = _submit_batches(actors, items_to_process, args.batch_size, sampling_params)
        _collect_and_report(args, submitted, len(items_to_process))
    finally:
        print("Shutting down Ray...", file=sys.stderr)
        ray.shutdown()


if __name__ == "__main__":
    main()
