import argparse
import torch
import sys
import os
import json

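# Optional workaround (disabled): uncomment the lines below to set
# PyNcclCommunicator.device to the current CUDA device.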
# from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
# import torch
# PyNcclCommunicator.device = torch.device(f"cuda:{torch.cuda.current_device()}")

from vllm import LLM, SamplingParams
from datasets import load_dataset, concatenate_datasets


def _batch_data(data_list, batch_size=1):
    """Split *data_list* into consecutive chunks of at most *batch_size* items.

    The last chunk holds any remainder when ``len(data_list)`` is not a
    multiple of *batch_size*; an empty input yields ``[]``.

    Raises:
        ValueError: if *batch_size* is not a positive integer.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    return [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]


def _parse_gpu_ids(gpus):
    """Turn a comma-separated GPU id string into a list of non-empty ids.

    Whitespace around ids and empty entries (e.g. from a trailing comma) are
    dropped, so ``"0, 2,"`` yields ``["0", "2"]`` and the list length matches
    the number of devices actually named.
    """
    return [gpu.strip() for gpu in gpus.split(",") if gpu.strip()]


def _load_eval_dataset(data_path, sub_tasks, split):
    """Load *split* from *data_path*; with *sub_tasks*, load each as a
    ``data_dir`` and concatenate them in the given order."""
    if sub_tasks is None:
        return load_dataset(data_path, split=split)
    parts = [load_dataset(data_path, data_dir=task, split=split) for task in sub_tasks]
    return concatenate_datasets(parts)


def main():
    """Run batched vLLM generation over a dataset split and append JSONL results.

    Command-line options select the model, the dataset (optionally restricted
    to the sub-directories given by ``--sub_task``), sampling settings and the
    GPUs to use. Each output line is a JSON object with keys ``type``,
    ``query`` (the ``instruction`` column), ``output`` (the first generated
    completion) and ``answer`` (the ``output`` column). Records are appended
    to ``--output_file``; an existing file is not truncated.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help="Path to the model directory or identifier")
    parser.add_argument("--data_path", type=str, default="pissa-dataset", help="Path to the dataset")
    parser.add_argument('--sub_task', nargs='+', help='List of subtasks to include')
    parser.add_argument('--dataset_split', type=str, default="test", help='Dataset split to use')
    parser.add_argument('--output_file', type=str, default="model_response.jsonl", help="File to write outputs to")
    parser.add_argument("--batch_size", type=int, default=400, help="Batch size for generation")
    parser.add_argument('--temperature', type=float, default=0.0, help="Sampling temperature")
    parser.add_argument('--top_p', type=float, default=1.0, help="Nucleus sampling probability")
    parser.add_argument('--max_tokens', type=int, default=1024, help="Max tokens to generate")
    parser.add_argument("--gpus", default="0", help="Comma-separated GPU ids you want to use, e.g. '0,2,3'")
    args = parser.parse_args()

    # Validate cheap-to-check arguments before any dataset/model loading so a
    # bad value fails immediately instead of after the LLM is initialised.
    if args.batch_size <= 0:
        parser.error(f"--batch_size must be a positive integer, got {args.batch_size}")
    gpu_list = _parse_gpu_ids(args.gpus)
    if not gpu_list:
        parser.error(f"--gpus must name at least one GPU id, got {args.gpus!r}")

    # Multiprocessing start method for vLLM workers.
    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
    # Restrict visible devices to the normalised id list.
    # NOTE(review): this only takes effect if CUDA has not already been
    # initialised in this process (torch and vllm are imported at module
    # top) — confirm for the vLLM version in use.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_list)
    tensor_parallel_size = len(gpu_list)
    print(f"tensor_parallel_size: {tensor_parallel_size}")

    # Output is opened in append mode below; warn so stale results from a
    # previous run are not silently mixed with the new ones.
    if os.path.exists(args.output_file):
        print(f"warning: {args.output_file} already exists; new records will be appended to it",
              file=sys.stderr)

    # Load the data before the (expensive) model so dataset errors surface early.
    dataset = _load_eval_dataset(args.data_path, args.sub_task, args.dataset_split)
    # Materialise columns as plain lists so slicing does not depend on the
    # column type that ``dataset[column]`` returns in a given datasets version.
    batch_queries = _batch_data(list(dataset["instruction"]), args.batch_size)
    batch_answers = _batch_data(list(dataset["output"]), args.batch_size)
    batch_tasks = _batch_data(list(dataset["type"]), args.batch_size)

    # No extra stop strings are configured.
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        stop=[],
    )

    # One tensor-parallel rank per selected GPU.
    llm = LLM(model=args.model, tensor_parallel_size=tensor_parallel_size)

    for batch_query, batch_answer, batch_task in zip(batch_queries, batch_answers, batch_tasks):
        with torch.no_grad():
            completions = llm.generate(batch_query, sampling_params)
        # Re-open (and close) the file per batch so every finished batch is on
        # disk even if a later batch fails or the run is interrupted.
        with open(args.output_file, 'a', encoding='utf-8') as f:
            for query, completion, answer, task in zip(batch_query, completions, batch_answer, batch_task):
                record = {
                    'type': task,
                    'query': query,
                    'output': completion.outputs[0].text,  # first completion only
                    'answer': answer,
                }
                f.write(json.dumps(record, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
