import json
import random
import string

from transformers import AutoTokenizer
# Llama-3.1 tokenizer; the script uses it only to count context tokens.
tok = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path='meta-llama/Meta-Llama-3.1-8B-Instruct',
)

def make_new_keys():
    """Create four extra passkey needles plus the questions that probe them.

    Each key is a random 5-digit string (leading zeros allowed). The keys are
    drawn with ``random.choices``, one call per key, in name order.

    Returns:
        A pair ``(new_template, multi_turns)``:
        - ``new_template``: list of four needle sentences to embed in a context,
          e.g. "The private key is 01234. Remember it. 01234 is the private key."
        - ``multi_turns``: list of four dicts ``{'input': question, 'answer': key}``
          in the same order as the needles.
    """
    template = "The {name} is {num}. Remember it. {num} is the {name}."
    input_template = "What is the {name}?\n\nThe {name} is"

    # Four new keys; together with the passkey already present in the data
    # this gives five keys per case.
    names = ['private key', 'public key', 'magic number', 'master key']

    keys = []
    for _ in names:
        keys.append(''.join(random.choices(string.digits, k=5)))

    needles = []
    questions = []
    for name, key in zip(names, keys):
        needles.append(template.format(name=name, num=key))
        questions.append({'input': input_template.format(name=name), 'answer': key})

    return needles, questions

def find_all(string, substring):
    """Lazily yield every index at which ``substring`` occurs in ``string``.

    Overlapping matches are reported: after each hit the search resumes one
    character later, not after the end of the match. An empty ``substring``
    yields every index from 0 through ``len(string)``.
    """
    haystack, needle = string, substring
    pos = haystack.find(needle)
    while pos != -1:
        yield pos
        pos = haystack.find(needle, pos + 1)

# Build the multi-turn summary + needle (passkey) retrieval test set.

# Read explicitly as UTF-8 so decoding does not depend on the platform locale.
with open('data/multi_turn_summary.jsonl', 'r', encoding='utf-8') as f:
    # Skip blank lines so stray empty lines in the JSONL file do not crash json.loads.
    summary_data = [json.loads(line) for line in f if line.strip()]

## insert needle to context


def _select_positions(positions, count):
    """Pick up to ``count`` entries of ``positions``, spread evenly over the list.

    Index ``i * len(positions) // count`` is taken for each ``i``, so the picks
    cover the whole list rather than clustering at the front. When fewer than
    ``count`` positions exist, all of them are returned (possibly none).
    """
    total = len(positions)
    count = min(count, total)
    return [positions[i * total // count] for i in range(count)]


def _insert_needles(context, positions, needles):
    """Return ``context`` with ``needles[i]`` inserted just before ``positions[i]``.

    Each needle is wrapped in blank lines. ``positions`` must be ascending
    offsets into the original ``context``; insertion runs back-to-front so
    that earlier offsets remain valid after later text is spliced in. Extra
    needles beyond ``len(positions)`` are ignored.
    """
    for pos, needle in reversed(list(zip(positions, needles))):
        context = context[:pos] + '\n\n' + needle + '\n\n' + context[pos:]
    return context


def _interleave_turns(summary_turns, passkey_turns, passkey_first):
    """Alternate summary and passkey turns, tagging each copy with its task name.

    Turns are emitted in pairs: (summary, passkey), or (passkey, summary) when
    ``passkey_first`` is true. Once the shorter list is exhausted, the rest of
    the longer list follows in order. Input dicts are copied, not mutated.
    """
    tagged_summary = [{**turn, 'task': 'multi_turn_summary'} for turn in summary_turns]
    tagged_passkey = [{**turn, 'task': 'multi_turn_passkey'} for turn in passkey_turns]
    if passkey_first:
        first, second = tagged_passkey, tagged_summary
    else:
        first, second = tagged_summary, tagged_passkey

    combined = []
    for i in range(max(len(first), len(second))):
        if i < len(first):
            combined.append(first[i])
        if i < len(second):
            combined.append(second[i])
    return combined


new_data = []
for doc_idx, case in enumerate(summary_data):
    needles, multi_turns = make_new_keys()

    # Each paper in the context starts with this marker; place one needle right
    # before each of len(needles) evenly spread paper starts.
    doc_starts = list(find_all(case['context'], 'Paper Title: '))
    positions = _select_positions(doc_starts, len(needles))
    case['context'] = _insert_needles(case['context'], positions, needles)

    # Only ask about needles that were actually inserted: with fewer papers
    # than needles, the remaining questions would have no answer in the context.
    multi_turns = multi_turns[:len(positions)]

    # Even-indexed cases lead each pair with a summary turn, odd ones with a
    # passkey turn; keep at most 8 turns per case.
    combined_turns = _interleave_turns(
        case['multi_turns'], multi_turns, passkey_first=(doc_idx % 2 == 1)
    )
    case['multi_turns'] = combined_turns[:8]

    # Drop cases whose context exceeds the token budget.
    context_length = len(tok.encode(case['context']))
    if context_length > 125_000:
        continue
    print(context_length)
    new_data.append(case)

# Persist the augmented cases as JSON Lines: one JSON object per line.
output_file = 'data/multi_turn_summary_with_needles.jsonl'
with open(output_file, 'w') as out:
    out.writelines(json.dumps(record) + '\n' for record in new_data)

print(f"Data with needles has been written to {output_file}")

