import argparse
import json
import os
import ray
import glob


from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.base.decorator import register

from utils import dataload, VLLM_INFER, json_to_parquet

def _write_results(path, results):
    """Serialize ``results`` to ``path`` as indented UTF-8 JSON, replacing the file."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)


def main(args):
    """Run batched inference over the loaded dataset and save the results as JSON.

    Loads ``inputs``, ``gts`` and ``problems`` via ``dataload(args)``, builds a
    Ray worker group around ``VLLM_INFER``, sends the data to
    ``worker_group.passk`` batch by batch, and rewrites ``args.output_path``
    with every result accumulated so far after each batch.

    Args:
        args: Parsed command-line namespace. This function reads
            ``gpu_per_node``, ``node_num``, ``vllm_batch_size`` and
            ``output_path`` directly; the whole namespace is also passed to
            ``dataload`` and wrapped with ``VLLM_INFER`` in
            ``RayClassWithInitArgs``.

    Raises:
        ValueError: If ``args.vllm_batch_size`` is 0.
        AssertionError: If a batch yields a different number of outputs than
            the number of inputs that were sent.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Reject an unusable batch size up front, before loading data or starting
    # Ray; a step of 0 would otherwise surface as an opaque range() error.
    if args.vllm_batch_size == 0:
        raise ValueError('vllm_batch_size must be non-zero (negative means a single batch)')

    inputs, gts, problems = dataload(args)

    total_num = len(inputs)
    print('Total samples:', total_num)

    if total_num == 0:
        # Nothing to infer: emit an empty result file and skip cluster start-up.
        _write_results(args.output_path, [])
        return

    # A negative batch size selects the whole dataset as one batch.
    if args.vllm_batch_size < 0:
        bsz = total_num
    else:
        bsz = min(args.vllm_batch_size, total_num)

    ray.init()
    try:
        resource_pool = RayResourcePool([args.gpu_per_node] * args.node_num, use_gpu=True, max_colocate_count=1)
        class_with_args = RayClassWithInitArgs(cls=VLLM_INFER, args=args)
        worker_group = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="pre_exploration_infer")

        results = []

        for j in range(0, total_num, bsz):
            # Use the effective batch size (bsz), not the raw CLI value: with a
            # negative CLI value the raw arithmetic yields a negative end index
            # and the slice would drop the last sample.
            e = min(j + bsz, total_num)
            print(f'Processing {j} to {e}')
            input_bsz = inputs[j: e]
            gt_bsz = gts[j: e]
            problem_bsz = problems[j: e]
            # NOTE(review): k is hard-coded to 0 here; args.k is not consulted.
            outputs = worker_group.passk(inputs=input_bsz, gts=gt_bsz, problems=problem_bsz, k=0)
            assert len(outputs) == e - j, f'out_pass len: {len(outputs)}, bsz: {e-j}'

            results.extend(outputs)

            # Rewrite the full file after every batch so partial progress
            # survives a failure in a later batch.
            _write_results(args.output_path, results)
    finally:
        # Always shut Ray down, even when a batch fails.
        ray.shutdown()

# Token budgets used for the default vLLM limits below.
prompt_length = 1024
response_length = 8 * 1024

if __name__ == "__main__":
    # Default for both the model length and the batched-token limit:
    # prompt plus response.
    max_total_tokens = prompt_length + response_length

    # (flag, type, default) for every command-line option, in declaration order.
    cli_options = (
        ('--k', int, 1),
        ('--vllm_k', int, 16),
        ('--vllm_batch_size', int, 1024),
        ('--vllm_dtype', str, 'bfloat16'),
        ('--vllm_gpu_memory_utilization', float, 0.85),
        ('--vllm_max_model_len', int, max_total_tokens),
        ('--vllm_max_num_batched_tokens', int, max_total_tokens),
        ('--vllm_temperature', float, 1.2),
        ('--vllm_top_p', float, 0.95),
        ('--vllm_top_k', int, 20),
        ('--vllm_max_tokens', int, response_length),  # response only
        ('--vllm_model_path', str, '/mnt/xhunter/shancun/huggingface/Qwen3-1.7B'),
        ('--data_path', str, '/mnt/xhunter/shancun/dataset/rllm_data/deepscaler_train.parquet'),
        ('--data_source', str, ''),
        ('--output_path', str, '/mnt/xhunter/shancun/log_tmp/PASS_K'),
        ('--gpu_per_node', int, 8),
        ('--node_num', int, 1),
    )

    parser = argparse.ArgumentParser()
    for option_flag, option_type, option_default in cli_options:
        parser.add_argument(option_flag, type=option_type, default=option_default)

    args = parser.parse_args()

    main(args)


