import os
import yaml
from io import StringIO
import time
import torch

from transformers import AutoTokenizer,  GenerationConfig, DataCollatorWithPadding
from reader.llama_prompt_generator import prompt_generator

def ds2_reader(model, tokenizer, params, query, system_prompt):
    """Generate one response for ``query`` and then call ``model.destroy()``.

    The chat prompt is built with ``prompt_generator`` and rendered to a plain
    string via ``tokenizer.apply_chat_template(..., tokenize=False)``. That
    string is passed to ``model`` together with the sampling settings from
    ``params``.

    Args:
        model: Callable taking a prompt string plus ``top_p``, ``temperature``,
            ``do_sample`` and ``max_new_tokens`` keyword arguments. Its result
            is indexed with ``[0]`` and read via ``.generated_text``. It must
            also expose ``destroy()``, which this function always calls once
            the generation call returns or raises.
            NOTE(review): ``destroy`` presumably releases the model's
            resources, so the model should not be reused afterwards.
            Confirm against the wrapper.
        tokenizer: Object providing ``apply_chat_template``.
        params: Mapping with keys ``"system_prompt"``, ``"top_p"``,
            ``"temperature"``, ``"do_sample"`` and ``"max_new_tokens"``.
        query: The user query forwarded to ``prompt_generator``.
        system_prompt: System prompt forwarded to ``prompt_generator``.

    Returns:
        The generated text. If the model output does not have the expected
        shape (empty, ``None``, or lacking ``generated_text``), a newline is
        printed and ``None`` is returned.
    """
    # Build the chat-style prompt and render it to text; the model receives
    # the string, not token ids (tokenize=False).
    prompt = prompt_generator(query, system_prompt, params["system_prompt"])
    input_ = tokenizer.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        tokenize=False,
    )

    # Destroy the model even if generation raises, so a failed call does not
    # leak whatever destroy() is meant to release.
    try:
        outputs = model(
            input_,
            top_p=params["top_p"],
            temperature=params["temperature"],
            do_sample=params["do_sample"],
            max_new_tokens=params["max_new_tokens"],
        )
    finally:
        model.destroy()

    # Only tolerate an unexpectedly shaped result. The original bare
    # `except:` also swallowed KeyboardInterrupt/SystemExit and real bugs.
    try:
        return outputs[0].generated_text
    except (IndexError, KeyError, AttributeError, TypeError):
        print("\n")
        return None

def ds2_reader_batch(model, tokenizer, params, query_list, system_prompt_list, batch_size):
    """Generate responses for many queries, calling ``model`` in fixed-size batches.

    Each ``(query, system_prompt)`` pair is turned into a chat prompt with
    ``prompt_generator`` and rendered to a string via
    ``tokenizer.apply_chat_template(..., tokenize=False)``. The rendered
    prompts are then sent to ``model`` ``batch_size`` at a time.

    Args:
        model: Callable taking a list of prompt strings plus ``top_p``,
            ``temperature``, ``do_sample`` and ``max_new_tokens`` keyword
            arguments. It returns an iterable of objects with a
            ``generated_text`` attribute.
        tokenizer: Object providing ``apply_chat_template``.
        params: Mapping with keys ``"system_prompt"``, ``"top_p"``,
            ``"temperature"``, ``"do_sample"`` and ``"max_new_tokens"``.
        query_list: Queries to answer (any iterable; materialized to a list).
        system_prompt_list: System prompts, one per query.
        batch_size: Number of prompts per model call; must be >= 1.

    Returns:
        List of generated texts in the order ``model`` yields them.
        Presumably there is one per query; that depends on ``model``.
        Each text is also printed as it is collected.

    Raises:
        ValueError: If ``query_list`` and ``system_prompt_list`` differ in
            length, or if ``batch_size`` < 1.
    """
    query_list = list(query_list)
    system_prompt_list = list(system_prompt_list)

    # zip() would silently drop the unmatched tail, so responses would no
    # longer line up with the caller's queries.
    if len(query_list) != len(system_prompt_list):
        raise ValueError(
            "query_list and system_prompt_list must have the same length "
            f"(got {len(query_list)} and {len(system_prompt_list)})"
        )
    # A zero step makes range() raise an opaque error, and a negative step
    # yields no batches at all, so [] would be returned silently.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Render every chat prompt to text up front (tokenize=False).
    prompt_list = [
        tokenizer.apply_chat_template(
            prompt_generator(query, system_prompt, params["system_prompt"]),
            add_generation_prompt=True,
            tokenize=False,
        )
        for query, system_prompt in zip(query_list, system_prompt_list)
    ]

    response = []
    for start in range(0, len(prompt_list), batch_size):
        outputs = model(
            prompt_list[start:start + batch_size],
            top_p=params["top_p"],
            temperature=params["temperature"],
            do_sample=params["do_sample"],
            max_new_tokens=params["max_new_tokens"],
        )
        for output in outputs:
            response.append(output.generated_text)
            # Echo each generation to stdout as it is collected.
            print(output.generated_text)

    return response

