import json

def preprocess(input_file, output_file, task_name, max_lines=200):
    """Convert a raw JSON Lines dataset into prompt/reference records.

    Each non-blank input line is parsed as one JSON object, turned into a
    task-specific plain-text prompt (no chat-template tokens are added), and
    written to ``output_file`` as one JSON object per line with the keys
    ``prompt``, ``natural_text`` and ``prompt_length`` (the prompt length in
    characters).

    Args:
        input_file: Path of the UTF-8 encoded JSON Lines file to read.
        output_file: Path of the JSON Lines file to create or overwrite.
        task_name: One of ``'t1'``, ``'t2'``, ``'t3'`` or ``'t4'``; selects the
            prompt template and which input fields are read:
            ``t1``/``t2`` read ``context``, ``input`` and ``outputs``;
            ``t3`` reads ``question`` and ``answer``; ``t4`` reads ``article``.
        max_lines: Maximum number of input lines to consume (blank lines count
            toward this limit). ``None`` consumes the whole file.

    Raises:
        ValueError: If ``task_name`` is not a supported task. Raised before
            any file is opened, so an existing output file is left untouched.
        KeyError: If a record lacks a field required by the selected task.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    supported_tasks = ('t1', 't2', 't3', 't4')
    if task_name not in supported_tasks:
        raise ValueError(
            f"Unknown task_name {task_name!r}; expected one of {supported_tasks}"
        )

    # Explicit UTF-8: the output is written with ensure_ascii=False, so relying
    # on the platform default encoding can fail or corrupt non-ASCII text.
    with open(input_file, 'r', encoding='utf-8') as src, \
            open(output_file, 'w', encoding='utf-8') as dst:
        # Stream the file instead of loading every line just to keep a prefix.
        for index, raw_line in enumerate(src):
            if max_lines is not None and index >= max_lines:
                break
            if not raw_line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue

            record = json.loads(raw_line)

            if task_name == 't1':
                prompt = f"Please give answer to the following question about knowledge. Note: If you are asked for true or false, just answer \"true\" or \"false\" only. If you are asked for similarity, just answer with the entity name only. Do not give anything other than the answers. Question:\n{record['context']}\n{record['input']}"
                # The whole 'outputs' value is kept as the reference here,
                # whereas t2 keeps only its first element.
                ground_truth = record['outputs']
            elif task_name == 't2':
                prompt = f"Please answer the following question within 200 words:\n{record['context']}\n{record['input']}"
                ground_truth = record['outputs'][0]
            elif task_name == 't3':
                prompt = f"Answer the following math word problem. Show your reasoning step by step, and finish with the final answer. Question: {record['question']}\nAnswer: "
                ground_truth = record['answer']
            else:  # 't4' (task_name was validated above)
                prompt = f"Summarize the following news article in English, focusing on the main facts and key points.\n\nNews:\n{record['article']}\n\nSummary:"
                # NOTE(review): the reference text is the article itself, i.e.
                # the same text embedded in the prompt, not a summary field.
                # Confirm this is intended.
                ground_truth = record['article']

            result = {
                'prompt': prompt,
                'natural_text': ground_truth,
                'prompt_length': len(prompt),
            }
            dst.write(json.dumps(result, ensure_ascii=False) + '\n')

if __name__ == '__main__':
    # One entry per task: (task name, raw input path, processed output path).
    # Entries are processed in the order listed.
    # NOTE: path strings are kept verbatim (including the 'konwledge'
    # spelling); presumably they match the file names on disk.
    jobs = (
        ('t1', './datasets/t1/1-2_konwledge_understanding.jsonl', './datasets/t1/t1.jsonl'),
        ('t2', './datasets/t2/2-1_longform_qa.jsonl', './datasets/t2/t2.jsonl'),
        ('t3', './datasets/t3/gsm8k.jsonl', './datasets/t3/t3.jsonl'),
        ('t4', './datasets/t4/cnn_dailymail.jsonl', './datasets/t4/t4.jsonl'),
    )
    for task, input_file, output_file in jobs:
        preprocess(input_file, output_file, task)

    # One-off helper (kept disabled) for producing the t3 input file from a
    # parquet download:
    #
    # import pandas as pd
    #
    # # Read the parquet file
    # df = pd.read_parquet("./datasets/t3/test-00000-of-00001.parquet")
    #
    # # Save as JSON Lines (one record per line)
    # df.to_json("./datasets/t3/gsm8k.jsonl", orient="records", lines=True, force_ascii=False)