import argparse
import torch
import sys
import os
import json

## code patch 
# from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
# import torch
# PyNcclCommunicator.device = torch.device(f"cuda:{torch.cuda.current_device()}")

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from datasets import load_dataset, concatenate_datasets

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

def _positive_int(value):
    """Parse *value* as an int for argparse, rejecting anything below 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def _parse_gpu_ids(gpus):
    """Split a comma-separated GPU string into a list of non-empty, stripped ids."""
    return [gpu.strip() for gpu in gpus.split(",") if gpu.strip()]


def _parse_args(argv=None):
    """Build the CLI parser, parse *argv* (``sys.argv[1:]`` when None) and validate it.

    Invalid values are reported through ``parser.error`` (exit status 2) before
    any model is loaded. The returned namespace carries an extra ``gpu_ids``
    list derived from ``--gpus``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help="Path to the model directory or identifier")
    parser.add_argument('--lora', type=str, default=None, help="Path to the lora directory or identifier")
    parser.add_argument("--data_path", type=str, default="pissa-dataset", help="Path to the dataset")
    parser.add_argument('--sub_task', nargs='+', help='List of subtasks to include')
    parser.add_argument('--dataset_split', type=str, default="test", help='Dataset split to use')
    parser.add_argument('--output_file', type=str, default="model_response.jsonl", help="File to write outputs to")
    parser.add_argument("--batch_size", type=_positive_int, default=200, help="Batch size for generation")
    parser.add_argument('--temperature', type=float, default=0.0, help="Sampling temperature")
    parser.add_argument('--top_p', type=float, default=1.0, help="Nucleus sampling probability")
    parser.add_argument('--max_tokens', type=int, default=1024, help="Max tokens to generate")
    parser.add_argument("--gpus", default="0", help="Comma-separated GPU ids you want to use, e.g. '0,2,3'")
    parser.add_argument('--temp_path', type=str, required=True, help="folder to save temp model")
    parser.add_argument(
        '--gpu_memory_utilization', type=float, default=0.7,
        help="gpu_memory_utilization passed to vLLM (lower it when sharing GPUs with other users)",
    )

    args = parser.parse_args(argv)

    # Empty entries ("0,1," or "") would otherwise inflate tensor_parallel_size
    # beyond the number of devices actually made visible.
    args.gpu_ids = _parse_gpu_ids(args.gpus)
    if not args.gpu_ids:
        parser.error("--gpus must list at least one GPU id")
    return args


def _export_merged_model(model_path, lora_path, output_dir):
    """Save the base model (with the LoRA merged in, if given) and its tokenizer to *output_dir*.

    The model is loaded on CPU. Every model reference is local to this
    function, so the weights become collectable as soon as it returns.
    """
    print(f'Load base model from {model_path}')
    merged_model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cpu")
    if lora_path is not None:
        print(f'Load LoRA from {lora_path}')
        merged_model = PeftModel.from_pretrained(merged_model, lora_path).merge_and_unload()
    else:
        print('No LoRA specified, directly use the base model')
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    print(f'Merge the base and LoRA and save to {output_dir}')
    merged_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


def _load_eval_dataset(data_path, sub_tasks, split):
    """Load *split* of the dataset, either whole or as the concatenation of *sub_tasks*.

    Each sub-task name may carry a ``:<suffix>`` and/or a trailing ``-ep<N>``
    marker; both are stripped to obtain the ``data_dir`` that is loaded.
    """
    if sub_tasks is None:
        return load_dataset(data_path, split=split)

    import re

    epoch_suffix = re.compile(r"-ep\d+$")
    parts = []
    for task in sub_tasks:
        print(task)
        task_name = epoch_suffix.sub("", task.split(":", 1)[0])
        print(task_name)
        parts.append(load_dataset(data_path, data_dir=task_name, split=split))
    return concatenate_datasets(parts)


def _batch_data(data_list, batch_size=1):
    """Split *data_list* into consecutive chunks of at most *batch_size* items."""
    return [data_list[start:start + batch_size] for start in range(0, len(data_list), batch_size)]


def _generate_and_write(llm, sampling_params, dataset, batch_size, output_file):
    """Generate completions batch by batch and append one JSON record per sample to *output_file*.

    Completions are paired with prompts positionally, i.e. this relies on
    ``llm.generate`` returning one result per prompt in prompt order.
    """
    # Materialise the columns as plain lists so slicing behaves the same
    # regardless of the container type `dataset[...]` returns.
    queries = list(dataset["instruction"])
    answers = list(dataset["output"])
    tasks = list(dataset["type"])

    batches = zip(
        _batch_data(queries, batch_size),
        _batch_data(answers, batch_size),
        _batch_data(tasks, batch_size),
    )
    for batch_query, batch_answer, batch_task in batches:
        with torch.no_grad():
            completions = llm.generate(batch_query, sampling_params)
        # Re-open in append mode per batch so finished batches are on disk
        # even if a later batch fails.
        with open(output_file, 'a', encoding='utf-8') as f:
            for query, completion, answer, task in zip(batch_query, completions, batch_answer, batch_task):
                record = {
                    'type': task,
                    'query': query,
                    'output': completion.outputs[0].text,
                    'answer': answer
                }
                f.write(json.dumps(record, ensure_ascii=False) + "\n")


def main(argv=None):
    """Merge an optional LoRA into a base model, run vLLM generation, and write JSONL results.

    Steps: parse arguments, restrict the visible GPUs, export the (merged)
    model into a private sub-directory of ``--temp_path``, load it with vLLM,
    generate for every dataset sample, and finally remove the exported model.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        SystemExit: On invalid command-line arguments (exit status 2).
    """
    args = _parse_args(argv)

    # Multiprocessing start method for vLLM
    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
    # Restrict visible devices
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(args.gpu_ids)
    tensor_parallel_size = len(args.gpu_ids)
    print(f"tensor_parallel_size: {tensor_parallel_size}")

    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        stop=[]
    )

    import gc
    import shutil
    import tempfile

    # Export into a fresh sub-directory so that cleanup never deletes files
    # that already lived in --temp_path and concurrent runs cannot collide.
    temp_root_existed = os.path.isdir(args.temp_path)
    os.makedirs(args.temp_path, exist_ok=True)
    model_dir = tempfile.mkdtemp(prefix="temp_merged_model_", dir=args.temp_path)
    try:
        _export_merged_model(args.model, args.lora, model_dir)
        # Release the CPU copy of the weights before vLLM loads its own.
        gc.collect()

        llm = LLM(
            model=model_dir,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=args.gpu_memory_utilization,
        )
        dataset = _load_eval_dataset(args.data_path, args.sub_task, args.dataset_split)
        _generate_and_write(llm, sampling_params, dataset, args.batch_size, args.output_file)
    finally:
        # Always remove the exported checkpoint, even when a step above failed.
        print("Cleaning up temp merged model...")
        try:
            shutil.rmtree(model_dir)
        except OSError as exc:
            print(f"Warning: could not remove {model_dir}: {exc}")
        if not temp_root_existed:
            try:
                os.rmdir(args.temp_path)
            except OSError:
                # Directory is not empty (e.g. another run is using it); keep it.
                pass

# Script entry point: run the evaluation only when this file is executed directly.
if __name__ == "__main__":
    main()