import argparse
import sys
import os
import json
import torch
import jsonlines
from tqdm import tqdm
from vllm import LLM, SamplingParams
import argparse

# Extra directory added to the module search path; left empty in this file.
BASE_DIR = ""
print(BASE_DIR)
sys.path.append(BASE_DIR)

# Largest index value reported by the interpreter (sys.maxsize).
MAX_INT = sys.maxsize
# Retry budget for batches whose generated output fails validation.
MAX_TRY = 10
# Placeholder stored in place of a response that fails validation.
INVALID_ANS = "[INVALID]"


def filter_output(response):
    """Return ``response`` unchanged if it is usable, otherwise ``None``.

    A response is considered usable when more than one character remains
    after surrounding whitespace is stripped.
    """
    meaningful = response.strip()
    return response if len(meaningful) > 1 else None


def process_batch_data(data_list, batch_size=1):
    """Split ``data_list`` into consecutive batches of at most ``batch_size``.

    Order is preserved. Every batch except possibly the last holds exactly
    ``batch_size`` items; the last one holds the remainder. An empty input
    yields an empty list of batches.

    Args:
        data_list: Sliceable sequence of items to split.
        batch_size: Maximum number of items per batch; must be >= 1.

    Returns:
        A list of slices of ``data_list``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    # Stride-based chunking never lets a batch exceed batch_size, unlike
    # folding the remainder into the final full batch.
    return [
        data_list[start : start + batch_size]
        for start in range(0, len(data_list), batch_size)
    ]


def append_response_to_file(data_items, generated_texts, output_file, idx):
    """Write one batch of records, each tagged with its response, as JSONL.

    The file is truncated when ``idx`` is 0 and appended to otherwise. Each
    record in ``data_items`` is mutated in place: its ``"response"`` key is
    set to the paired entry of ``generated_texts`` before it is written.
    """
    if idx == 0:
        open_mode = "w"
    else:
        open_mode = "a"

    with jsonlines.open(output_file, mode=open_mode) as sink:
        for record, text in zip(data_items, generated_texts):
            record["response"] = text
            sink.write(record)


def _read_prompt_items(input_file, limit=None):
    """Load JSONL records from ``input_file``, keeping at most ``limit``.

    A falsy ``limit`` (``None`` or ``0``) keeps every record.
    """
    records = []
    with open(input_file, "r", encoding="utf-8") as f:
        for idx, item in enumerate(jsonlines.Reader(f)):
            if limit and idx >= limit:
                break
            records.append(item)
    return records


def _generate_with_retry(llm, conversations, sampling_params):
    """Generate completions, retrying while any output fails ``filter_output``.

    At most ``MAX_TRY`` attempts are made; the completions of the final
    attempt are returned even if some of them are still invalid.
    """
    completions = None
    for _ in range(MAX_TRY):
        with torch.no_grad():
            completions = llm.generate(conversations, sampling_params)
        if completions and all(
            filter_output(completion.outputs[0].text) for completion in completions
        ):
            break
    else:
        print("Exceeded max invalid output attempts.")
    return completions


def infer_batch(
    model_path,
    batch_size,
    input_file,
    output_file,
    max_token_length=8192,
    phase="",
    limit=None,
):
    """Run vLLM generation over a JSONL prompt file and save the responses.

    Each input record must carry a ``"prompt"`` key. Records are processed
    in batches; after every batch the records, extended with a
    ``"response"`` key, are written to ``output_file`` (truncated on the
    first batch, appended afterwards).

    Args:
        model_path: Model location passed to ``vllm.LLM``; substrings of it
            also select the tensor-parallel size and GPU memory fraction.
        batch_size: Number of prompts submitted per ``generate`` call.
        input_file: Path of the JSONL file holding the prompts.
        output_file: Path of the JSONL file to write.
        max_token_length: Value passed as ``max_model_len`` to ``vllm.LLM``.
        phase: Unused; kept so existing callers keep working.
        limit: If truthy, only the first ``limit`` records are processed.
    """
    print(f"Input file: {input_file}", flush=True)

    data_items = _read_prompt_items(input_file, limit)
    if not data_items:
        # Nothing to do; bail out before loading the model.
        print("No prompts found; nothing to infer.", flush=True)
        return
    prompts = [item["prompt"] for item in data_items]

    print("First prompt example:", prompts[0], flush=True)
    print("Number of samples to infer:", len(prompts))

    prompt_batches = process_batch_data(prompts, batch_size=batch_size)
    data_batches = process_batch_data(data_items, batch_size=batch_size)

    # Per-model engine settings, keyed on substrings of the model path.
    if "qwen2_72B" in model_path or "70B-Instruct" in model_path:
        tp_size = 4
        gpu_util = 0.9
    elif "gemma" in model_path:
        os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
        tp_size = 4
        gpu_util = 0.95
    elif "phi-4" in model_path:
        tp_size = 2
        gpu_util = 0.95
    elif "DeepSeek-Coder-V2-Lite-Instruct" in model_path:
        tp_size = 1
        gpu_util = 0.90
    else:
        tp_size = 4
        gpu_util = 0.95

    llm = LLM(
        model=model_path,
        tensor_parallel_size=tp_size,
        trust_remote_code=True,
        gpu_memory_utilization=gpu_util,
        max_model_len=max_token_length,
    )

    tokenizer = llm.get_tokenizer()
    stop_tokens = ["</FINAL_ANSWER>", "<|EOT|>"]
    sampling_params = SamplingParams(
        temperature=0.0,
        top_p=1.0,
        max_tokens=3000,
        presence_penalty=0.0,
        frequency_penalty=0.0,
        stop=stop_tokens,
        stop_token_ids=[tokenizer.eos_token_id],
    )
    print("Sampling Params:", sampling_params)

    # Prompts for this model are sent verbatim, without a chat template.
    use_chat_template = "sqlcoder-7b-2" not in model_path.lower()

    for idx, (prompts_in_batch, items_in_batch) in enumerate(
        tqdm(zip(prompt_batches, data_batches), total=len(prompt_batches))
    ):
        print(f"Inferencing batch {idx}...", flush=True)

        if use_chat_template:
            conversations = [
                tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                for prompt in prompts_in_batch
            ]
        else:
            conversations = prompts_in_batch

        # NOTE(review): temperature is 0.0, so a retry presumably reproduces
        # the same text — confirm whether retrying is still useful here.
        completions = _generate_with_retry(llm, conversations, sampling_params)

        batch_generated_texts = []
        for completion in completions:
            generated_text = completion.outputs[0].text
            if not filter_output(generated_text):
                generated_text = INVALID_ANS
            batch_generated_texts.append(generated_text)

        append_response_to_file(items_in_batch, batch_generated_texts, output_file, idx)

    print("All batch inference completed.")


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the batch inference script.

    The three path options are mandatory; argparse exits with a usage
    error when any of them is missing.

    Returns:
        Namespace with ``model_path``, ``prompt_path``, ``output_path``,
        ``batch_size``, ``max_token_length`` and ``cuda_visible_devices``.
    """
    parser = argparse.ArgumentParser(
        description="Run vLLM batch inference over a JSONL prompt file."
    )
    # ── required paths ───────────────────────────────────────
    parser.add_argument(
        "--model_path",
        required=True,
        help="HF model folder or hub repo",
    )
    parser.add_argument(
        "--prompt_path",
        required=True,
        help="JSONL containing prompts",
    )
    parser.add_argument(
        "--output_path",
        required=True,
        help="Where to save responses",
    )
    # ── optional tuning knobs ────────────────────────────────
    parser.add_argument(
        "--batch_size", type=int, default=100, help="Batch size (default: 100)"
    )
    parser.add_argument(
        "--max_token_length",
        type=int,
        default=15000,
        help="Max model length (default: 15000)",
    )
    parser.add_argument(
        "--cuda_visible_devices",
        default="4,5,6,7",
        help='CUDA device list for this run (default: "4,5,6,7")',
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: read the CLI options, pin the visible GPUs for
    # this process, then run inference over the prompt file.
    cli = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = cli.cuda_visible_devices

    infer_batch(
        cli.model_path,
        cli.batch_size,
        cli.prompt_path,
        cli.output_path,
        max_token_length=cli.max_token_length,
    )
