import os
import yaml
from io import StringIO
import time
import torch

from transformers import AutoTokenizer,  GenerationConfig, DataCollatorWithPadding
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from reader.llama_prompt_generator import prompt_generator

def vllm_reader(model, tokenizer, params, query, system_prompt):
    """Generate one response for ``query`` and return its text.

    Args:
        model: Object exposing ``generate(prompts, sampling_params, ...)``;
            presumably a ``vllm.LLM`` instance -- confirm against the caller.
        tokenizer: Tokenizer exposing ``apply_chat_template``.
        params: Mapping that supplies ``"system_prompt"``, ``"top_p"``,
            ``"temperature"``, ``"max_new_tokens"`` and ``"enable_lora"``.
            ``"lora_path"`` is also read when ``"enable_lora"`` is truthy.
        query: Forwarded to ``prompt_generator`` as its first argument.
        system_prompt: Forwarded to ``prompt_generator`` as its second
            argument.

    Returns:
        The ``text`` attribute of the first completion of the first request
        output, for both the LoRA and the base-model path.
    """
    # generate instruction prompt
    messages = prompt_generator(query, system_prompt, params["system_prompt"])

    # tokenize=False asks the chat template for rendered text rather than
    # token ids; enable_thinking is forwarded to the template as-is.
    prompt_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
        enable_thinking=False,
    )

    # define sampling params (fixed seed so repeated calls sample identically)
    sampling_params = SamplingParams(
        top_p=params["top_p"],
        temperature=params["temperature"],
        max_tokens=params["max_new_tokens"],
        seed=42,
    )

    # Only pass lora_request when LoRA is enabled, so the base-model call is
    # exactly generate(prompt, sampling_params).
    generate_kwargs = {}
    if params["enable_lora"]:
        # The adapter path is used as both the adapter name and its location;
        # 1 is the adapter id.
        generate_kwargs["lora_request"] = LoRARequest(
            params["lora_path"], 1, params["lora_path"]
        )

    outputs = model.generate(prompt_text, sampling_params, **generate_kwargs)

    # generate() yields one request output per prompt; take the first
    # completion of the single prompt submitted here.
    return outputs[0].outputs[0].text

def vllm_reader_batch(model, tokenizer, params, query_list, system_prompt_list, batch_size):
    """Generate responses for many queries, submitting them in batches.

    Args:
        model: Object exposing ``generate(prompts, sampling_params, ...)``;
            presumably a ``vllm.LLM`` instance -- confirm against the caller.
        tokenizer: Tokenizer exposing ``apply_chat_template``.
        params: Mapping that supplies ``"system_prompt"``, ``"top_p"``,
            ``"temperature"``, ``"max_new_tokens"``, ``"repetition_penalty"``
            and ``"enable_lora"``. ``"lora_path"`` is also read when
            ``"enable_lora"`` is truthy.
        query_list: Queries; element ``i`` is paired with
            ``system_prompt_list[i]``.
        system_prompt_list: Per-query system prompts, indexed in step with
            ``query_list``.
        batch_size: Maximum number of prompts passed to one ``generate``
            call. Must be at least 1.

    Returns:
        A list with the ``text`` of the first completion of every request
        output, in the order the batches were submitted.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. A negative value
            would otherwise produce no batches and silently return ``[]``.
        IndexError: If ``system_prompt_list`` is shorter than ``query_list``.
    """
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    # Format prompt: one message structure per (query, system prompt) pair.
    # Indexing (rather than zip) keeps a too-short system_prompt_list an error
    # instead of silently dropping queries.
    messages_list = [
        prompt_generator(query, system_prompt_list[i], params["system_prompt"])
        for i, query in enumerate(query_list)
    ]

    # tokenize=False asks the chat template for rendered text rather than
    # token ids; enable_thinking is forwarded to the template as-is.
    prompts = [
        tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=False,
        )
        for messages in messages_list
    ]

    # define sampling params (fixed seed so repeated runs sample identically)
    sampling_params = SamplingParams(
        top_p=params["top_p"],
        temperature=params["temperature"],
        max_tokens=params["max_new_tokens"],
        repetition_penalty=params["repetition_penalty"],
        seed=42,
    )

    # The LoRA request does not change between batches, so build it once.
    generate_kwargs = {}
    if params["enable_lora"]:
        # The adapter path is used as both the adapter name and its location;
        # 1 is the adapter id.
        generate_kwargs["lora_request"] = LoRARequest(
            params["lora_path"], 1, params["lora_path"]
        )

    # Generate responses batch by batch, collecting the first completion of
    # each request output.
    responses = []
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start:start + batch_size]
        batch_outputs = model.generate(batch, sampling_params, **generate_kwargs)
        responses.extend(
            request_output.outputs[0].text for request_output in batch_outputs
        )
    return responses




    