import json
from transformers import AutoTokenizer

# NOTE(review): `tokenizer` is not referenced by any other statement visible in
# this script; confirm whether it is still needed (e.g. for token-length
# checks elsewhere) before keeping this model load.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

# Template used to render one solved corpus example inside the shared context.
global_template = "{query_text}\n\nAnswer: {answer}\n\n"

# Tasks are processed in this order; the running "id" of each emitted test
# case depends on it, so reordering changes the ids in the output file.
TASKS = (
    "salient_translation_error_detection",
    "tracking_shuffled_objects_seven_objects",
    "date_understanding",
)
DATA_ROOT = "many-shot"
CONTEXT_SIZE = "32k"
# Number of consecutive test queries bundled into one multi-turn test case.
TURNS_PER_CASE = 5
OUTPUT_PATH = "data/multiturn/multi_turn_many_shot.jsonl"


def read_jsonl(path):
    """Load a JSON Lines file into a list of parsed objects.

    Args:
        path: Path of the ``.jsonl`` file to read.

    Returns:
        One parsed JSON value per non-blank line, in file order.
    """
    # Explicit UTF-8 so non-ASCII query text does not depend on the locale;
    # blank lines (e.g. a trailing newline) are skipped instead of crashing.
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def build_context(corpus):
    """Render every corpus example with ``global_template`` and concatenate.

    Args:
        corpus: Iterable of records providing ``query_text`` and ``answers``;
            ``answers[0][0]`` is used verbatim as the shown answer.

    Returns:
        The concatenated many-shot context string (empty for an empty corpus).
    """
    return "".join(
        global_template.format(
            query_text=eg["query_text"],
            answer=eg["answers"][0][0],
        )
        for eg in corpus
    )


def build_test_cases(task, context, test_data, start_id=0,
                     turns_per_case=TURNS_PER_CASE):
    """Group test queries into multi-turn test cases sharing one context.

    Args:
        task: Task name stored in each test case's ``"task"`` field.
        context: Shared many-shot context stored in ``"context"``.
        test_data: List of query records with ``query_text`` and ``answers``.
        start_id: Id assigned to the first produced case; later cases count up.
        turns_per_case: Maximum number of queries per test case; the final
            case holds the remainder when the total is not a multiple.

    Returns:
        A list of test-case dicts with keys ``context``, ``multi_turns``,
        ``id`` and ``task`` (in that order).
    """
    cases = []
    for offset in range(0, len(test_data), turns_per_case):
        multi_turns = [
            {
                "input": query["query_text"] + "\n\n" + "Answer: ",
                # Drops the first and last character of the gold answer --
                # presumably the parentheses around labels like "(A)";
                # TODO confirm against the data files.
                "answer": query["answers"][0][0][1:-1],
            }
            for query in test_data[offset:offset + turns_per_case]
        ]
        cases.append(
            {
                "context": context,
                "multi_turns": multi_turns,
                "id": start_id + len(cases),
                "task": task,
            }
        )
    return cases


def build_all_test_cases(data_root=DATA_ROOT, context_size=CONTEXT_SIZE,
                         tasks=TASKS):
    """Build the test cases of every task with ids running across tasks.

    For each task, reads ``<data_root>/<task>/<context_size>/corpus.jsonl``
    (few-shot examples) and ``.../test_queries.jsonl`` (held-out queries).

    Returns:
        A single list of test cases, task by task in the order of ``tasks``.
    """
    all_test_cases = []
    for task in tasks:
        task_dir = f"{data_root}/{task}/{context_size}"
        context = build_context(read_jsonl(f"{task_dir}/corpus.jsonl"))
        test_data = read_jsonl(f"{task_dir}/test_queries.jsonl")
        all_test_cases.extend(
            build_test_cases(
                task, context, test_data, start_id=len(all_test_cases)
            )
        )
    return all_test_cases


def write_jsonl(path, records):
    """Write ``records`` to ``path`` as JSON Lines, creating parent folders."""
    import os  # local import: keeps this block independent of the file header

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")


def main():
    """Build all multi-turn many-shot test cases and save them to disk."""
    write_jsonl(OUTPUT_PATH, build_all_test_cases())


if __name__ == "__main__":
    main()