# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the Law dataset to parquet format
"""

import json
import os
import datasets

import argparse
from tqdm import tqdm


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='experiments/data/math_total')
    parser.add_argument('--chat_model', action='store_false', help='Whether to use chat model')

    args = parser.parse_args()

    data_source = 'math_dapo'

    dataset = datasets.load_from_disk("experiments/raw_data/dapo_math")["train"]

    processed_data = []
    for item in tqdm(dataset):
        processed_data.append(
            {"input": item['prompt'].strip(),
             "ground_truth": item['solution'].strip()}
        )
    print(processed_data[0])

    # follow R1
    system_prompt = r"""
You must write your answer strictly following the XML-like format defined below. Failure to comply with this format will result in an invalid response.

**Definitions and Rules:**

* `<guideline>`: A container for one or more `<plan>` tags. It sets the objective for the current stage of reasoning.
* `<plan>i:</plan>`: A single, specific, and actionable task or hypothesis to be executed. Multiple plans within a guideline represent parallel exploration.
* `<step>i:</step>`: The detailed execution of the corresponding `<plan>i`. The number of `<step>` tags must exactly match the number of `<plan>` tags in the preceding `<guideline>`. **Crucially, the content of this step must be generated *as if* you have no knowledge of the content of its sibling steps.**
* `<takeaway>`: Use the `<takeaway>` tag to analyze steps and generate a *concise* summary. Compare the outcomes of the different steps, identify the most promising path, or consolidate the findings. The takeaway determines the next action: either proceeding to the next `<guideline>` for deeper analysis or moving to the final answer. **Only analyze the executed steps, NO additional computation or reasoning is allowed here.**
* After analysis, add the final, user-facing conclusion that summarizes the entire logical journey from all preceding steps and takeaways into a clear, final response for the user. For questions with a definitive, short answer, you must include `\\boxed{...}` containing only the final result.

**Strict Requirements:**

1. **Execute Independently:** For each `<plan>`, generate a corresponding `<step>`.
    * Each of the plans and steps must be a *self-contained, complete strategy* for solving the task or subtask.
    * You must treat each `<step>` as an independent execution unit. The reasoning within `<step>i:` must only be based on `<plan>i:`, not on the content of any other `<step>`.
    * The number of `<step>` tags must always equal the number of `<plan>` tags in the directly preceding `<guideline>`.
    * Avoid words implying sequence or dependency (e.g. “then”, “after”, “next”).
2. **Explore in Parallel:** When a problem or previous analysis involves multiple hypotheses, alternative methods, or independent sub-tasks, your next `<guideline>` should contain multiple `<plan>` tags.
    * Each `<plan>` represents a parallel line of reasoning.
    * `<guideline>` with a single `<plan>` is allowed if one plan is needed.
    * Multiple alternative plans are recommended and will be awarded.
3. **Meaningful content:** All tags must contain meaningful content. Do not add any text or explanation between the tags.
4. No other tags or text outside the defined structure is allowed. Directly generate output. Do not wrap it in triple backticks or any other code block formatting.


**Example Output Format:**

<guideline>
<plan>1: [A concise one-sentence, indepedent high-level plan.]</plan>
...
</guideline>
<step>
1: [Detailed analysis trajectory of plan 1. Must be entirely self-contained.]
</step>
...
<takeaway>
[Compare the results from the steps above. Synthesize the findings and determine the next action.]
</takeaway>

<guideline>
<plan>1: [A one-sentence, high-level strategy]</plan>
<plan>2: [A one-sentence, high-level strategy]</plan>
...
</guideline>
<step>
1: [Detailed analysis trajectory of plan 1. Must be entirely self-contained.]
</step>
<step>
2: [Detailed analysis trajectory of plan 2. Must be entirely self-contained.]
</step>
...
<takeaway>
[Compare the results from the steps above. Synthesize the findings and determine the next action.]
</takeaway>

... [more guidelines, steps and takeaways]

[The final, summarized conclusion based on all takeaways. Include definitive answers in \\boxed{...} format.]
"""
    print(system_prompt)

    # add a row to each data item that represents a unique id
    def make_map_fn(split):

        def process_fn(example, idx):

            question = example.pop("input")
            answer = example.pop("ground_truth")

            data = {
                "data_source": data_source,
                "prompt": [
                    # {"role": "system", "content": system_prompt}, 
                    {"role": "user", "content": question + "\n" + system_prompt}
                ],
                "ability": "math",
                "reward_model": {"style": "rule", "ground_truth": answer},
                "extra_info": {
                    'apply_chat_template': True,
                    "split": split,
                    'index': idx,
                    "answer": answer,
                    "question": question,
                }
            }
            return data

        return process_fn
    

    dataset = datasets.Dataset.from_list(processed_data)

    dataset = dataset.map(function=make_map_fn("train"), with_indices=True)

    print(dataset[0])

    local_dir = args.local_dir

    dataset.to_parquet(os.path.join(local_dir, 'train_dapo.parquet'))
