import argparse
import logging
import time
from typing import List
import json

try:
    from vllm import LLM, SamplingParams
    VLLM_AVAILABLE = True
except ImportError:
    print("vLLM is not installed. Please install it using: pip install vllm")
    VLLM_AVAILABLE = False
    exit(1)

# --- Logging configuration ---
# Every record is rendered as "<timestamp> - [<LEVEL>] - <message>"; records
# below INFO are dropped by the root logger configured here.
_LOG_FORMAT = '%(asctime)s - [%(levelname)s] - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)


# runner.py
# from __future__ import annotations
import os, sys, json, csv, time, argparse, asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import logging
from openai import AsyncOpenAI

from data_handler import BinaryDataGenerator, ScalarSortedDataGenerator, Dyck2DataGenerator, PatternBasedDataGenerator, PrimeDataGenerator
from data_handler import BinaryDataGenerator, \
    ScalarSortedDataGenerator, \
    Dyck2DataGenerator, \
    PatternBasedDataGenerator, \
    PalindromeDataGenerator, \
    PrimeDataGenerator, \
    PrimeBinaryDataGenerator, \
    PrimeDecimalNonPrimeNoSmallDivisorsDataGenerator, \
    PrimeDecimalTailRestrictedDataGenerator, \
    CoprimeRandomSplitDecimalDataGenerator

from typing import Mapping

# =========================
# Prompt (cache-friendly)
# =========================
def build_user_prompt(data_examples, seq_len: int, decimal: bool = False, incontext: bool = False, tdata: Optional[str] = None) -> str:
    """Assemble the user prompt sent to the model.

    The prompt consists of a problem statement, the labelled examples inside a
    fenced block, optionally a single unlabelled test input (in-context mode),
    and a closing instruction asking for a JSON object with a ``label`` key.

    Args:
        data_examples: Iterable of strings, one labelled example per item;
            they are joined with newlines verbatim.
        seq_len: Input length quoted in the problem statement.
        decimal: When True the inputs are described as "decimal", otherwise
            as "binary".
        incontext: When True the model is asked to label ``tdata``; when False
            it is asked for a Python function ``f(x)``.
        tdata: Test sample of the form ``"<input> -> <label>"``. Required when
            ``incontext`` is True; only the ``<input>`` part is shown to the
            model. A string without the ``" -> "`` separator is used as-is.

    Returns:
        The complete prompt text.

    Raises:
        TypeError: If ``incontext`` is True and ``tdata`` is None.
    """
    # The wording must follow the `decimal` flag in both modes (the in-context
    # branch previously had the two words swapped).
    alphabet = "decimal" if decimal else "binary"
    intro = (
        f"**Problem Statement:**\n"
        f"Given a sequence of input vectors ({alphabet}, length {seq_len}) and their corresponding scalar binary outputs ('0' or '1'), "
    )

    if incontext:
        if tdata is None:
            raise TypeError("tdata must be a '<input> -> <label>' string when incontext=True")
        task = (
            "you have to learn a hypothesis that approximates the underlying relationship. "
            "Given the data below, determine what is the label for the given string and output ONLY the label. "
        )
    else:
        task = (
            "find a concise Python function `f(x)` that accurately approximates the underlying relationship. "
            "The function should not be a trainable model, but a direct logical or mathematical representation of the target function."
        )

    sections = [
        f"{intro}{task}\n",
        "**Data Examples:**\n",
        "```\n",
        "\n".join(data_examples),
        "\n```\n\n",
    ]

    if incontext:
        # Hide the label: keep only what precedes the last " -> " separator
        # instead of assuming the suffix is exactly five characters long.
        test_input, separator, _label = tdata.rpartition(" -> ")
        if not separator:
            test_input = tdata
        sections.extend(["**Test Input:**\n", "```\n", test_input, "\n```\n\n"])

    # NOTE(review): this closing instruction asks for a JSON label even when
    # incontext=False, where the problem statement asks for a Python function.
    # Left unchanged because the intended behaviour for that mode is unclear.
    sections.append("""You must output ONLY a single JSON object: {"label": "<your predicted label>"}.""")
    return "".join(sections)

# =========================
# Data
# =========================

# Short function ids (as used on the evaluation grid) -> target-function names
# that generate_data() dispatches on.
FUNCTION_NAME_MAPPING = {
    "fn_a": "parity_all",
    "fn_b": "parity_first_half",
    "fn_c": "patternmatch1",
    "fn_d": "patternmatch2",
    "fn_e": "parity_rand_3",
    "fn_f": "parity_rand_10",
    "fn_g": "palindrome",
    "fn_h": "dyck2",
    "fn_i": "prime_decimal",
    "fn_j": "prime_binary",
    "fn_k": "coprimality",
    "fn_l": "automata_parity",
    "fn_m": "high_divisible",
    "fn_n": "prime_decimal_tf_check",
}

# Function ids whose prompts describe the inputs as decimal rather than binary.
# NOTE(review): "fn_k" (coprimality) is generated by a class whose name
# contains "Decimal" but is not listed here — confirm whether it belongs.
DECIMAL_FN = ["fn_i", "fn_m", "fn_n"]


def generate_data(function: str, seq_len: int, data_size: int) -> List[str]:
    """Generate labelled samples for one target function.

    Args:
        function: Short function id; must be a key of FUNCTION_NAME_MAPPING.
        seq_len: Sequence length forwarded to the data generator.
        data_size: Number of samples requested from the data generator.

    Returns:
        One string per sample, formatted as ``"<input> -> <output>"`` where
        ``<input>`` is the concatenation of the sample's input elements.

    Raises:
        KeyError: If ``function`` is not a known function id.
    """
    try:
        target = FUNCTION_NAME_MAPPING[function]
    except KeyError:
        raise KeyError(
            f"Unknown function id {function!r}; expected one of {sorted(FUNCTION_NAME_MAPPING)}"
        ) from None

    # Pick the generator for the target; anything without a dedicated class
    # falls through to the generic BinaryDataGenerator.
    if target == 'dyck2':
        data_gen = Dyck2DataGenerator(seq_len, data_size)
    elif target in ('patternmatch1', 'patternmatch2'):
        data_gen = PatternBasedDataGenerator(seq_len, data_size)
    elif target == "palindrome":
        data_gen = PalindromeDataGenerator(seq_len, data_size)
    elif target == "prime_decimal":
        data_gen = PrimeDataGenerator(seq_len, data_size, True)
    elif target == "prime_binary":
        data_gen = PrimeBinaryDataGenerator(seq_len, data_size)
    elif target == "prime_decimal_tf_check":
        data_gen = PrimeDecimalTailRestrictedDataGenerator(seq_len, data_size, True)
    elif target == "high_divisible":
        data_gen = PrimeDecimalNonPrimeNoSmallDivisorsDataGenerator(seq_len, data_size, True)
    elif target == "coprimality":
        data_gen = CoprimeRandomSplitDecimalDataGenerator(seq_len, data_size, True)
    else:
        data_gen = BinaryDataGenerator(target, seq_len, data_size)

    if target == 'patternmatch2':
        # patternmatch2 is generated with an explicit pattern argument;
        # presumably this is the pattern to match — confirm in data_handler.
        samples = data_gen.generate_data('00111111')
    else:
        samples = data_gen.generate_data()

    # Assumes each sample's 'Input' exposes .tolist() yielding strings
    # (str.join requires it) — confirm against data_handler.
    return [f"{''.join(sample['Input'].tolist())} -> {sample['Output']}" for sample in samples]


# Evaluation grid: every sequence length is combined with every function id.
_EVAL_SEQ_LENS = (20, 25, 30, 50, 100)
_EVAL_FUNCTIONS = ("fn_a", "fn_f", "fn_g", "fn_l", "fn_n", "fn_i")


def _build_eval_prompts(seq_lens=_EVAL_SEQ_LENS, functions=_EVAL_FUNCTIONS,
                        train_size: int = 200, test_size: int = 100) -> List[list]:
    """Build one in-context prompt per test sample for each (length, function) pair.

    Returns a list of ``[prompt, label, seq_len, function_id]`` records, where
    ``label`` is the last character of the test sample string.
    """
    records = []
    for seq_len in seq_lens:
        for fn in functions:
            train_data = generate_data(fn, seq_len, train_size)
            test_data = generate_data(fn, seq_len, test_size)
            is_decimal = fn in DECIMAL_FN
            for tdata in test_data:
                prompt_text = build_user_prompt(train_data, seq_len, is_decimal, True, tdata)
                records.append([prompt_text, tdata[-1], seq_len, fn])
    return records


def _save_outputs(all_outputs: list, output_file: str) -> None:
    """Write the collected results to ``output_file``; the extension selects the format."""
    logger.info(f"Saving {len(all_outputs)} outputs to {output_file}...")
    try:
        with open(output_file, "w", encoding="utf-8") as f:
            if output_file.endswith(".json"):
                # JSON array of [generated_text, [label, seq_len, function_id]].
                json.dump(all_outputs, f, indent=4, ensure_ascii=False)
            elif output_file.endswith(".txt"):
                for idx, entry in enumerate(all_outputs, start=1):
                    f.write(f"--- Response {idx} ---\n")
                    f.write(f"{entry}\n\n")
            else:
                logger.warning(
                    "Unsupported file extension. Saving as plain text. "
                    "Use .json or .txt for formatted output."
                )
                # Entries are [text, labels] pairs, not strings, so they must
                # be stringified before joining (a bare join raised TypeError
                # and lost the whole run's results).
                f.write("\n".join(str(entry) for entry in all_outputs))
    except OSError as e:
        logger.error(f"Error saving outputs to file: {e}", exc_info=True)
    else:
        logger.info(f"Successfully saved outputs to {output_file}.")


def run_batch_inference(
    model_name: str,
    prompt: str,
    num_requests: int = 100,
    batch_size: int = 8,
    max_new_tokens: int = 512,
    temperature: float = 0.8,
    top_p: float = 0.95,
    output_file: Optional[str] = None,
    max_model_len: Optional[int] = None,
    tensor_parallel_size: int = 1
) -> None:
    """
    Load a vLLM model and run the in-context evaluation prompts through it in batches.

    The prompts are generated internally for a fixed grid of sequence lengths
    and target functions; each generated text is stored together with the
    expected label, sequence length and function id of its prompt.

    Args:
        model_name (str): The Hugging Face model identifier.
        prompt (str): Accepted for interface compatibility; currently ignored
            because the prompts are generated internally.
        num_requests (int): Accepted for interface compatibility; currently
            ignored, the number of requests equals the number of generated prompts.
        batch_size (int): The number of requests to process in a single batch.
        max_new_tokens (int): The maximum number of tokens to generate for each response.
        temperature (float): The temperature for sampling.
        top_p (float): The top-p value for nucleus sampling.
        output_file (str | None): Where to save the results. ``.json`` and
            ``.txt`` get dedicated formats; any other extension is written as
            plain text. Nothing is saved when None or empty.
        max_model_len (int | None): Forwarded to the vLLM engine as ``max_model_len``.
        tensor_parallel_size (int): The number of GPUs to use for tensor parallelism.

    Returns:
        None. If the engine fails to initialize, the error is logged and the
        function returns without generating anything.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Fail before the expensive model load rather than after it.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    logger.info("--- Starting vLLM Batch Inference ---")
    logger.info(f"Model: {model_name}, Batch Size: {batch_size}")

    # 1. Initialize the vLLM engine (done once; can take a while for large models).
    logger.info(f"Loading vLLM engine for model '{model_name}'...")
    try:
        # NOTE: trust_remote_code=True allows the model repository to run its
        # own code at load time; only use with models you trust.
        llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            trust_remote_code=True,
            max_model_len=max_model_len,
        )
    except Exception as e:
        logger.error(f"Failed to initialize vLLM engine: {e}", exc_info=True)
        return
    logger.info("vLLM engine loaded successfully.")

    # 2. Define the sampling parameters for generation.
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_new_tokens
    )
    logger.info(f"Sampling Parameters: temp={temperature}, top_p={top_p}, max_tokens={max_new_tokens}")

    # 3. Build the evaluation prompts.
    eval_records = _build_eval_prompts()
    total_prompts = len(eval_records)
    logger.info(f"Prepared {total_prompts} prompts for generation.")

    # 4. Process prompts in batches.
    all_outputs: List[list] = []
    num_batches = (total_prompts + batch_size - 1) // batch_size  # ceiling division

    logger.info(f"--- Starting generation in {num_batches} batches ---")
    start_time_total = time.perf_counter()

    for start in range(0, total_prompts, batch_size):
        batch_start_time = time.perf_counter()

        # Slicing past the end is safe, so the last (short) batch needs no special case.
        batch = eval_records[start:start + batch_size]
        batch_prompts = [record[0] for record in batch]
        batch_labels = [(record[1], record[2], record[3]) for record in batch]
        batch_num = (start // batch_size) + 1

        logger.info(f"Processing batch {batch_num}/{num_batches} (size: {len(batch_prompts)})...")

        # The core vLLM call.
        request_outputs = llm.generate(batch_prompts, sampling_params, use_tqdm=False)

        # Pair each generated text with the metadata of the prompt that produced it.
        for output, test_labels in zip(request_outputs, batch_labels):
            generated_text = output.outputs[0].text.strip()
            all_outputs.append([generated_text, test_labels])

        batch_duration = time.perf_counter() - batch_start_time
        logger.info(f"Batch {batch_num} finished in {batch_duration:.2f} seconds.")

    total_duration = time.perf_counter() - start_time_total

    # 5. Save outputs to a file if a path is provided.
    if output_file:
        _save_outputs(all_outputs, output_file)

    # 6. Report completion and performance metrics.
    logger.info("--- Batch Inference Complete ---")
    throughput = len(all_outputs) / total_duration if total_duration > 0 else 0.0
    logger.info(
        f"Processed {len(all_outputs)} requests in {total_duration:.2f} seconds "
        f"({throughput:.2f} requests/second)."
    )


if __name__ == "__main__":
    # Command-line entry point: parse the options below and forward them to
    # run_batch_inference().
    #
    # Each row is (flag, type, default, help), listed in --help order.
    # NOTE(review): run_batch_inference() does not use the values of --prompt
    # and --num-requests; it builds its own prompts.
    cli_options = [
        # Alternative model previously tried: "deepseek-ai/deepseek-coder-33b-instruct"
        ("--model", str, "Qwen/QwQ-32B",
         "Hugging Face model ID to use."),
        ("--prompt", str, "Write a short story about a robot who discovers music.",
         "The prompt to run for all requests."),
        ("--num-requests", int, 100,
         "Total number of requests to generate."),
        ("--batch-size", int, 8,
         "Number of prompts to process in a single batch."),
        ("--max-new-tokens", int, 1024,
         "Maximum number of new tokens to generate per response."),
        ("--temperature", float, 0.2,
         "Sampling temperature. Set to 0 for greedy decoding."),
        ("--top-p", float, 0.95,
         "Nucleus sampling top-p value."),
        ("--output-file", str, None,
         "Path to save the generated outputs. Supports .txt and .json extensions."),
        ("--max-model-len", int, 25000,
         "The maximum total sequence length for the model. Reduces memory usage."),
        ("--tensor-parallel-size", int, 1,
         "Number of GPUs for tensor parallelism."),
    ]

    parser = argparse.ArgumentParser(description="Run batch inference with vLLM.")
    for flag, arg_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    args = parser.parse_args()

    run_batch_inference(
        model_name=args.model,
        prompt=args.prompt,
        num_requests=args.num_requests,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        output_file=args.output_file,
        max_model_len=args.max_model_len,
        tensor_parallel_size=args.tensor_parallel_size,
    )