import os
import sys 
import yaml
from io import StringIO
import time
import torch
import deepspeed
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from transformers import AutoTokenizer,  GenerationConfig, DataCollatorWithPadding
from reader.llama_prompt_generator import prompt_generator

def ds_reader(model, tokenizer, params, query, system_prompt):
    """Generate one chat response for ``query`` on this process's GPU.

    The messages are built by ``prompt_generator``, rendered with the
    tokenizer's chat template (generation prompt appended) and fed to
    ``model.generate`` on device ``cuda:{LOCAL_RANK}``. ``LOCAL_RANK`` is read
    from the environment and defaults to 0 when unset.

    Args:
        model: Object exposing a Hugging Face style ``generate`` method
            (presumably a DeepSpeed-wrapped causal LM -- confirm with caller).
        tokenizer: Hugging Face tokenizer that has a chat template.
        params: Generation settings. Reads ``system_prompt``, ``top_p``,
            ``temperature``, ``do_sample`` and ``max_new_tokens``.
        query: User query forwarded to ``prompt_generator``.
        system_prompt: System prompt forwarded to ``prompt_generator``.

    Returns:
        The decoded completion only (prompt tokens removed, special tokens
        skipped).
    """
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    device = f"cuda:{local_rank}"

    # Build the chat messages and render them, appending the assistant header
    # so the model continues as the assistant.
    prompt = prompt_generator(query, system_prompt, params["system_prompt"])
    input_ids = tokenizer.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(device)

    # Llama-3 instruct models end a turn with <|eot_id|>, which may differ from
    # the tokenizer's eos token, so stop on either. Tokenizers that lack the
    # token return None or the unk id from convert_tokens_to_ids; skip those.
    terminators = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != tokenizer.unk_token_id:
        terminators.append(eot_id)
    terminators = [t for t in dict.fromkeys(terminators) if t is not None] or None

    # Llama tokenizers usually have no pad token; fall back to eos so that
    # generate() does not have to guess (and warn).
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids,
        # Single unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        eos_token_id=terminators,
        pad_token_id=pad_token_id,
        top_p=params["top_p"],
        temperature=params["temperature"],
        do_sample=params["do_sample"],
        max_new_tokens=params["max_new_tokens"],
    )

    # generate() returns prompt + completion; keep only the new tokens.
    response_ids = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response_ids, skip_special_tokens=True)

def ds_reader_batch(model, tokenizer, params, query_list, system_prompt_list, batch_size):
    """Generate chat responses for many queries, ``batch_size`` at a time.

    Each ``(query, system_prompt)`` pair is turned into chat messages by
    ``prompt_generator`` and rendered with the tokenizer's chat template. The
    rendered prompts are tokenized and left-padded per batch, then run through
    ``model.generate`` on device ``cuda:{LOCAL_RANK}``. ``LOCAL_RANK`` is read
    from the environment and defaults to 0 when unset.

    Side effects on ``tokenizer``:
    - If it has no pad token, ``pad_token_id`` is set to ``eos_token_id`` and
      left that way.
    - ``padding_side`` is set to ``"left"`` for the duration of the call, then
      restored.

    Args:
        model: Object exposing a Hugging Face style ``generate`` method
            (presumably a DeepSpeed-wrapped causal LM -- confirm with caller).
        tokenizer: Hugging Face tokenizer that has a chat template.
        params: Generation settings. Reads ``system_prompt``, ``top_p``,
            ``temperature``, ``do_sample`` and ``max_new_tokens``.
        query_list: User queries.
        system_prompt_list: System prompts, one per query.
        batch_size: Number of prompts passed to ``generate`` per call.

    Returns:
        A list of decoded completions, in the same order as ``query_list``.

    Raises:
        ValueError: If the two input lists differ in length, or if
            ``batch_size`` is less than 1.
    """
    if len(query_list) != len(system_prompt_list):
        raise ValueError(
            f"query_list has {len(query_list)} items but system_prompt_list "
            f"has {len(system_prompt_list)}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not query_list:
        return []

    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    device = f"cuda:{local_rank}"

    prompt_list = [
        prompt_generator(query, system_prompt, params["system_prompt"])
        for query, system_prompt in zip(query_list, system_prompt_list)
    ]

    # Llama tokenizers usually ship without a pad token; reuse eos for padding.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Llama-3 instruct models end a turn with <|eot_id|>, which may differ from
    # the tokenizer's eos token, so stop on either. Tokenizers that lack the
    # token return None or the unk id from convert_tokens_to_ids; skip those.
    terminators = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != tokenizer.unk_token_id:
        terminators.append(eot_id)
    terminators = [t for t in dict.fromkeys(terminators) if t is not None] or None

    # Render every conversation to text, with the assistant header appended.
    texts = tokenizer.apply_chat_template(
        prompt_list,
        add_generation_prompt=True,
        tokenize=False,
    )

    responses = []
    # Decoder-only models must be left-padded for batched generation, so that
    # every prompt ends at the last position and new tokens follow real text.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        for start in range(0, len(texts), batch_size):
            batch = tokenizer(
                texts[start:start + batch_size],
                # The chat template already emitted BOS and other special
                # tokens; adding them again would duplicate BOS.
                add_special_tokens=False,
                truncation=True,
                padding=True,
                return_tensors="pt",
            )
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            gen_tokens = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=terminators,
                top_p=params["top_p"],
                temperature=params["temperature"],
                do_sample=params["do_sample"],
                max_new_tokens=params["max_new_tokens"],
            )

            # With left padding, every row's prompt occupies the first
            # input_ids.shape[1] positions; the rest is generated text.
            new_tokens = gen_tokens[:, input_ids.shape[1]:]
            responses.extend(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))

            # Release GPU tensors before the next batch is allocated.
            del input_ids, attention_mask, gen_tokens, new_tokens
    finally:
        tokenizer.padding_side = original_padding_side

    return responses



    