from typing import Dict
import re
from datasets import load_dataset
from transformers import AutoTokenizer
from functools import partial

# Prompt template for the user turn: the question text is inserted verbatim,
# with no answer-format instructions appended.
QUERY_TEMPLATE_NOANSWER = "{Question}"

def preprocess(text):
    """Clean up bracketed markup in a piece of text.

    Steps, in order:
      1. ``None`` is mapped to a single space and returned immediately.
      2. Leading/trailing whitespace is stripped.
      3. Every `` [title]`` marker (with its leading space) becomes ``". "``.
      4. Any remaining ``[...]`` span (shortest match) is removed.
      5. Double spaces are replaced by single spaces in ONE left-to-right
         pass, so a run of three or more spaces is shortened but not fully
         collapsed.

    Args:
        text: The string to clean, or ``None``.

    Returns:
        The cleaned string, or ``" "`` when ``text`` is ``None``.
    """
    if text is None:
        return " "
    cleaned = text.strip().replace(" [title]", ". ")
    cleaned = re.sub(r"\[.*?\]", "", cleaned)
    return cleaned.replace("  ", " ")

def process_cot_example(
    example: Dict,
    tokenizer,
    num_think_tags: int,
):
    """Render chat-formatted training texts for one dataset example.

    For every think-tag index ``k`` in ``1..num_think_tags`` and every
    (thinking trajectory, answer) pair of the example, one text is produced
    by ``tokenizer.apply_chat_template(..., tokenize=False)``. The assistant
    turn has the form::

        <|im_start|><think{k}>\\n{trajectory}</think{k}>\\n<|im_start|>answer\\n{answer}

    Texts are ordered tag-major: all pairs for tag 1, then all pairs for
    tag 2, and so on.

    Args:
        example: Mapping with keys ``"thinking_trajectories"``,
            ``"question"`` and ``"attempt"``. The trajectories and attempts
            are indexed in parallel, so they are presumably equal-length
            lists of strings -- confirm against the dataset schema.
        tokenizer: Object exposing ``apply_chat_template(messages,
            tokenize=False)``.
        num_think_tags: Number of distinct ``<think{k}>`` tag variants to
            render for each pair.

    Returns:
        ``{"text": [...]}`` containing
        ``num_think_tags * len(thinking_trajectories)`` rendered strings.

    Raises:
        ValueError: If the number of trajectories and answers differ, or if
            any trajectory or answer is ``None``.
    """
    thinking_trajectories = example["thinking_trajectories"]
    question = example["question"]
    answers = example["attempt"]

    # Fail fast with a descriptive error. Dropping into an interactive
    # debugger (as a previous revision did) hangs non-interactive runs and
    # then continues with undefined or stale values.
    if len(answers) != len(thinking_trajectories):
        raise ValueError(
            f"Mismatched data for question {question!r}: "
            f"{len(thinking_trajectories)} thinking trajectories vs "
            f"{len(answers)} answers"
        )

    # Normalise each pair once; the result does not depend on the think tag.
    pairs = []
    for tgt_idx, (trajectory, answer) in enumerate(
        zip(thinking_trajectories, answers)
    ):
        if trajectory is None or answer is None:
            raise ValueError(
                f"Missing thinking trajectory or answer at index {tgt_idx} "
                f"for question {question!r}"
            )
        if "Answer:" not in answer:
            answer = "Answer: " + answer
        pairs.append((trajectory.strip(), answer.strip()))

    prompt = QUERY_TEMPLATE_NOANSWER.format(Question=question)
    # TODO: add different kinds of thinks
    texts = []
    for sot_label in range(1, num_think_tags + 1):
        s_think_tag = f"<|im_start|><think{sot_label}>\n"
        e_think_tag = f"</think{sot_label}>"
        for think_trajectory, answer in pairs:
            text = tokenizer.apply_chat_template([
                {"role": "user", "content": prompt},
                {
                    "role": "assistant",
                    "content": s_think_tag + think_trajectory + e_think_tag + "\n<|im_start|>answer\n" + answer,
                },
            ], tokenize=False)
            texts.append(text)
    return dict(text=texts)

def mathcot_sft(num_think_tags: int, upload_data_path: str, num_proc: int,
                download_data_path: str,
                tokenizer_name: str = "Qwen/Qwen2.5-32B-Instruct"):
    """Build the chat-formatted SFT dataset and push it to the Hub.

    Loads the dataset at ``download_data_path`` (always re-downloading it),
    selects the ``train`` split when one exists, adds a ``text`` column by
    mapping ``process_cot_example`` over every row, and uploads the result
    to ``upload_data_path`` as a private repository.

    Args:
        num_think_tags: Number of ``<think{k}>`` tag variants rendered per
            (trajectory, answer) pair.
        upload_data_path: Hub repository id the processed dataset is pushed
            to.
        num_proc: Number of worker processes passed to ``dataset.map``.
        download_data_path: Path or Hub id handed to ``load_dataset``.
        tokenizer_name: Checkpoint whose chat template formats the text.
            Defaults to the previously hard-coded Qwen2.5-32B-Instruct.
    """
    dataset = load_dataset(download_data_path, download_mode='force_redownload')
    if 'train' in dataset:
        dataset = dataset['train']
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    process_example_map = partial(
        process_cot_example,
        tokenizer=tokenizer,
        num_think_tags=num_think_tags,
    )
    # The chat template is applied with tokenize=False, so the new ``text``
    # column holds strings rather than token ids.
    dataset = dataset.map(
        process_example_map,
        num_proc=num_proc,
        desc="Tokenizing SoT SFT data",
    )
    dataset.push_to_hub(upload_data_path, private=True)

if __name__ == "__main__":
    # NOTE: both paths below are placeholders; replace them with the real
    # source and destination dataset ids before running this script.
    mathcot_sft(
        download_data_path="....upload that file and tokenize it here....",
        upload_data_path="...upload this tokenized file...",
        num_think_tags=4,
        num_proc=1,
    )
