import gzip
import json
import random


# Run configuration.
middle_id = 4        # substituted into the "gold_at_{...}" part of both filenames
doc_num = 10         # substituted into "doc{...}" and "nq-open-{...}" in both filenames
get_random = True    # adds a "_random" suffix to the output name; also enables document shuffling

# Shared prefix of the input and output paths; only the suffix differs.
_stem = f'./qa/doc{doc_num}/nq-open-{doc_num}_total_documents_gold_at_{middle_id}'
path = _stem + '.jsonl.gz'

dump_data_path = _stem + ('_random.json' if get_random else '.json')
dump_data = []  # accumulates one [full_prompt, prompt_parts, answers] record per example

# Prompt fragments: [0] is the fixed instruction header, [1] is a
# str.format template with a single {question} field.
_instruction_header = (
    "[INST] Write a high-quality one-sentence answer for the given question "
    "using only the provided search results (some of which might be irrelevant).\n\n"
)
_question_template = "\n\nQuestion: {question}?\n\n"
prompt = [_instruction_header, _question_template]

def _format_doc(doc):
    """Render one retrieved context as a prompt line: ``Document (Title: ...) text``."""
    return f"Document (Title: {doc['title']}) {doc['text']}\n"


def _build_example(data):
    """Return the ``[full_prompt, prompt_parts, answers]`` record for one JSONL example.

    ``data`` is one decoded JSONL record. This reads its ``'ctxs'`` list (each
    entry a dict with ``'title'`` and ``'text'``), ``'question'`` and
    ``'answers'``. When the module-level ``get_random`` flag is set,
    ``data['ctxs']`` is shuffled in place before the prompts are built.
    """
    if get_random:
        # NOTE: uses the global ``random`` state and no seed is set in this
        # file, so the document order differs from run to run.
        random.shuffle(data['ctxs'])
    doc_prompts = [_format_doc(doc) for doc in data['ctxs']]
    question_prompt = prompt[1].format(question=data['question'])
    # Single string: instruction, all documents, question, closing tag.
    final_origin_prompt = prompt[0] + "".join(doc_prompts) + question_prompt + "[/INST]"
    # Same pieces kept separate, with the documents as their own list.
    final_ps_prompt = [prompt[0], doc_prompts, question_prompt, "[/INST]"]
    return [final_origin_prompt, final_ps_prompt, data['answers']]


# Binary mode as before: json.loads accepts bytes and detects the encoding.
with gzip.open(path, "rb") as f:
    for raw_line in f:
        raw_line = raw_line.strip()
        # Skip blank lines rather than stopping at them; the previous
        # ``while line:`` loop silently truncated the dataset at the first one.
        if not raw_line:
            continue
        dump_data.append(_build_example(json.loads(raw_line)))
print(len(dump_data))
with open(dump_data_path, 'w', encoding="utf-8") as f:
    json.dump(dump_data, f, indent=4)