import json
import random
from xml.parsers.expat import model
# Seed the module-level RNG once at import time so the random needle depths
# (random.uniform in generate_random_segments) and the magic-number keys
# (random.randint in get_config) are reproducible from run to run.
random.seed(42)

def text_to_num_tokens(tokenizer, prompt):
    """Return the number of tokens ``tokenizer`` produces for ``prompt``.

    The tokenizer is called with ``return_tensors="pt"`` and the size of
    dimension 1 of its ``input_ids`` is returned. Presumably ``tokenizer`` is a
    Hugging Face tokenizer returning a ``(1, seq_len)`` PyTorch tensor for a
    single string -- confirm against callers.
    """
    # NOTE: this file previously defined this function twice with identical
    # bodies; the shadowed duplicate has been removed.
    return tokenizer(prompt, return_tensors="pt").input_ids.size(1)

def generate_random_segments(N, max_depth=0.99):
    """Return ``N`` random depths, one drawn from each of ``N`` equal slices of ``[0, max_depth]``.

    Value ``j`` is ``j * max_depth / N`` plus a uniform offset in
    ``[0, max_depth / N]``, so the results are non-decreasing and spread across
    the whole range. With the default ``max_depth`` of 0.99 every depth stays
    below 1.0.

    Args:
        N: number of depths to generate; ``N <= 0`` yields an empty list
            instead of dividing by zero.
        max_depth: upper end of the range being partitioned.

    Returns:
        A list of ``N`` floats drawn with the module-level ``random`` generator,
        one ``random.uniform`` call per slice, in slice order.
    """
    if N <= 0:
        # Nothing to place; also avoids a ZeroDivisionError for N == 0.
        return []
    segment_length = max_depth / N
    return [j * segment_length + random.uniform(0, segment_length) for j in range(N)]


def _insert_needles(chunks, depths, needles, last_words):
    """Join text ``chunks`` into one string with ``needles[k]`` inserted at relative depth ``depths[k]``.

    Needle ``k`` is placed after the first ``int(depths[k] * len(chunks))``
    chunks, so ``depths`` should be non-decreasing and as long as ``needles``.
    The returned string always ends with a period, a newline and ``last_words``.
    """
    cut_points = [0] + [int(depth * len(chunks)) for depth in depths]
    context = ""
    for needle, begin, end in zip(needles, cut_points, cut_points[1:]):
        context += ".".join(chunks[begin:end]) + ". " + needle
    # NOTE(review): the tail is joined with ". " while the segments before each
    # needle are joined with "."; kept as-is so generated prompts are unchanged.
    context += ". ".join(chunks[cut_points[-1]:]) + ". \n" + last_words
    return context


def get_input_ctx_multi(tokenizer, ctx_len, last_words="", needles=None, meta_prompt=True,
                        data_path="data/PaulGrahamEssays.json"):
    """Build a needle-in-a-haystack prompt just over ``ctx_len`` tokens long.

    The corpus text at ``data_path`` is split on "." and grouped into chunks of
    ``step`` sentences. Chunks are added one at a time; before each addition the
    needles are re-inserted at their random relative depths and the prompt is
    re-tokenized. The first prompt whose token count (haystack plus task
    description) exceeds ``ctx_len`` is returned.

    Args:
        tokenizer: tokenizer passed to ``text_to_num_tokens`` (must accept
            ``return_tensors="pt"``; presumably a Hugging Face tokenizer).
        ctx_len: token budget; the returned prompt is the first to exceed it.
        last_words: text appended after the haystack (presumably the question).
        needles: strings to hide in the haystack; ``None`` or empty means none.
        meta_prompt: if True, prefix the prompt with a task description asking
            the model to find and memorize the hidden information.
        data_path: JSON file whose ``"text"`` field holds the corpus.

    Returns:
        The task description (or "") followed by the haystack with the needles
        inserted and ``last_words``.

    Raises:
        ValueError: if the corpus never exceeds ``ctx_len`` tokens.
    """
    needles = [] if needles is None else list(needles)
    # No needles means no depths to draw.
    depths = generate_random_segments(len(needles)) if needles else []

    with open(data_path, "r", encoding="utf-8") as f:
        sentences = json.load(f)["text"].split(".")

    if meta_prompt:
        task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n"
    else:
        task_description = ""
    # Loop-invariant: tokenize the task description once, not every iteration.
    task_tokens = text_to_num_tokens(tokenizer, task_description)

    # Every iteration re-tokenizes the whole prompt, so longer contexts grow the
    # haystack in bigger steps, presumably to bound the number of iterations.
    if ctx_len > 66000:
        step = 50
    elif ctx_len > 16000:
        step = 20
    else:
        step = 1

    chunks = []
    for i in range(0, len(sentences), step):
        context = _insert_needles(chunks, depths, needles, last_words)
        # The budget counts the haystack plus the task description.
        if text_to_num_tokens(tokenizer, context) + task_tokens > ctx_len:
            if context.startswith("."):
                context = context[1:]
            return task_description + context
        chunks.append(".".join(sentences[i:i + step]))

    raise ValueError(
        f"corpus in {data_path!r} is too short to exceed a context of {ctx_len} tokens"
    )


def get_config(eval_set, model_path):
    """Return the question, needle sentences and answer keys for ``eval_set``.

    Only ``"number-4"`` is supported. Four distinct random 6-digit "magic
    numbers" are drawn from the module-level ``random`` generator, and each one
    is wrapped in its own needle sentence.

    Args:
        eval_set: evaluation set name; only ``"number-4"`` is recognized.
        model_path: model identifier. Paths containing ``"llama3.1"`` get a
            question without the trailing ``"The numbers are"`` answer prefix.

    Returns:
        ``(question, needles, keys)``, where ``keys`` are the numbers as strings
        and ``needles[i]`` contains ``keys[i]``.

    Raises:
        ValueError: if ``eval_set`` is not recognized.
    """
    if eval_set != "number-4":
        raise ValueError(f"unsupported eval_set: {eval_set!r}")

    if "llama3.1" in model_path:
        num_4_question = "\n\nWhat are the magic numbers mentioned in the provided text?\n "
    else:
        num_4_question = "\n\nWhat are the magic numbers mentioned in the provided text?\nThe numbers are"
    print(num_4_question)

    num_needles = 4
    begin = 100000
    end = begin * 10
    num_4_keys = []
    while len(num_4_keys) < num_needles:
        # randrange excludes `end`, so every key has exactly 6 digits
        # (randint(begin, end) would include 1000000).
        key = str(random.randrange(begin, end))
        # Reject duplicates so the model must recover four distinct numbers.
        if key not in num_4_keys:
            num_4_keys.append(key)
    num_4_needles = [f" One of the magic number is: {key}. " for key in num_4_keys]
    print("using 4 needles")
    return num_4_question, num_4_needles, num_4_keys

