import json
import os
import re
import time
os.environ["TOKENIZERS_PARALLELISM"]="true"
import torch
import pandas as pd

from argparse import ArgumentParser
from transformers import AutoTokenizer
from multiprocessing import Pool
from tqdm import tqdm

parser = ArgumentParser(
    description="Tokenize chat-formatted SFT data with the tokenizer's chat template and "
    "mask non-assistant and ```output``` tokens in the labels."
)
parser.add_argument(
    '--input_file',
    default="sft_outputs/sft_data_toy.jsonl",
    help="UTF-8 JSONL file; every non-blank line must be a JSON object with a 'messages' field.",
)
parser.add_argument(
    '--output_file',
    default="sft_outputs/sft_data_toy.preprocessed.jsonl",
    help="Destination JSONL with input_ids / labels / attention_mask per example.",
)
parser.add_argument(
    '--tokenizer_name_or_path',
    default="./deepseek_distill_qwen_tokenizer_fix",
    help="Passed to AutoTokenizer.from_pretrained.",
)
parser.add_argument(
    '--max_seq_length',
    type=int,
    default=32768,
    help="Maximum number of tokens per example; longer conversations are truncated.",
)
parser.add_argument(
    '--preprocessing_num_workers',
    type=int,
    default=64,
    help="Number of worker processes used for tokenization.",
)
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
print(f"load tokenizer from {args.tokenizer_name_or_path} done.")
max_seq_length = args.max_seq_length

# Load every example up front. Explicit UTF-8 so non-ASCII chat data does not depend on
# the platform's default encoding.
input_data = []
with open(args.input_file, encoding="utf-8") as fd:
    for line_no, line in enumerate(fd, start=1):
        if not line.strip():
            # tolerate blank lines instead of failing inside json.loads
            continue
        example = json.loads(line)
        # explicit error (not `assert`, which is stripped under `python -O`)
        if "messages" not in example:
            raise ValueError(f"{args.input_file}:{line_no}: example has no 'messages' field")
        input_data.append(example)
print(f"load input data from {args.input_file} done. len(input_data): {len(input_data)}")

# Matches one fenced ```output ...``` block (tool/interpreter output embedded in an
# assistant message). The lazy body stops at the FIRST closing fence after the opening
# one, and `\n?` also swallows a single trailing newline when there is one.
_OUTPUT_BLOCK_RE = re.compile(r"```output[\s\S]*?```\n?")


def _find_output_blocks(content):
    """Return the (start, end) character spans of every ```output``` block in `content`.

    Unlike a pattern that insists on a newline right after the closing fence, this still
    finds a block that ends the message. A closing fence followed by other text cannot
    make the match run on into a later fenced block, which would mask assistant-written
    tokens.
    """
    return [(match.start(), match.end()) for match in _OUTPUT_BLOCK_RE.finditer(content)]


def _chat_template_len(conversation, add_generation_prompt):
    """Return how many tokens `conversation` renders to under the tokenizer's chat template.

    Uses the same truncation settings as the full-example encoding in `encode_sft_example`.
    The result is used as a position in the full `input_ids`. That assumes the template
    renders a conversation prefix as a token prefix of the whole conversation (TODO
    confirm for the tokenizer in use).
    """
    return tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=max_seq_length,
        add_generation_prompt=add_generation_prompt,
    ).shape[1]


def encode_sft_example(example, verbose=False):
    """Encode one chat example for SFT training, masking tokens that should not get loss.

    `example["messages"]` is a list of dicts with "role" and "content" fields. The whole
    conversation is tokenized with `tokenizer.apply_chat_template` (truncated to
    `max_seq_length`). `labels` is a copy of `input_ids` with -100 written over:

    * every non-assistant message, plus the assistant generation prefix the template
      inserts after it when the next message is from the assistant;
    * every ```output ...``` block inside assistant messages (tool output the model
      reads but should not be trained to produce).

    Args:
        example: dict with a "messages" list.
        verbose: print the rendered chat text, the raw labels, the detected output blocks
            and the decoded trainable tokens, for manual inspection.

    Returns:
        dict with flat Python lists "input_ids", "labels" and "attention_mask" (all ones;
        no padding is applied).

    Raises:
        KeyError: if `example` has no "messages" key.
        ValueError: if the "messages" list is empty.
    """
    messages = example["messages"]
    if len(messages) == 0:
        raise ValueError("messages field is empty.")
    if verbose:
        chat_messages = tokenizer.apply_chat_template(
            conversation=messages,
            tokenize=False,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=max_seq_length,
            add_generation_prompt=False,
        )
        print(f"chat_messages:\n[{chat_messages}]")

    # token ids of the whole conversation; indexed below as (1, seq_len)
    input_ids = tokenizer.apply_chat_template(
        conversation=messages,
        tokenize=True,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=max_seq_length,
        add_generation_prompt=False,
    )
    labels = input_ids.clone()
    if verbose:
        print(f"labels: {labels[0].tolist()}")

    for message_idx, message in enumerate(messages):
        if message["role"] != "assistant":
            # 1. Mask the entire non-assistant message.
            if message_idx == 0:
                message_start_idx = 0
            else:
                # token length of everything before this message
                message_start_idx = _chat_template_len(
                    messages[:message_idx], add_generation_prompt=False
                )
            # When an assistant reply follows, also cover the generation prefix the template
            # inserts before it (e.g. `<|assistant|>`) so that prefix gets no loss either.
            next_is_assistant = (
                message_idx < len(messages) - 1
                and messages[message_idx + 1]["role"] == "assistant"
            )
            message_end_idx = _chat_template_len(
                messages[: message_idx + 1], add_generation_prompt=next_is_assistant
            )
            labels[:, message_start_idx:message_end_idx] = -100
            continue

        # 2. Assistant message: keep it trainable except for its ```output``` blocks.
        content = message["content"]
        output_blocks = _find_output_blocks(content)
        if not output_blocks:
            continue
        if verbose:
            print(f"Found {len(output_blocks)} output blocks")
            for i, (start, end) in enumerate(output_blocks):
                print(f"Output block {i+1}: '{content[start:end]}'")

        # Token index where this assistant message's content begins: the preceding
        # conversation plus the assistant generation prefix.
        if message_idx == 0:
            message_start_idx = 0
        else:
            message_start_idx = _chat_template_len(
                messages[:message_idx], add_generation_prompt=True
            )

        # NOTE(review): the offsets below assume that (a) the template inserts `content`
        # verbatim (a template that rewrites assistant content, e.g. by stripping reasoning,
        # would shift them), and (b) tokenizing a text piece on its own yields the same
        # tokens as in context. With verbose=True a mismatch is reported.
        for block_start_pos, block_end_pos in output_blocks:
            # tokens of the message text preceding this block -> block start position
            text_before_block = content[:block_start_pos]
            tokens_before_block = tokenizer.encode(text_before_block, add_special_tokens=False)
            if verbose:
                print(f"text_before_block: [{text_before_block}]")
                print(f"tokens_before_block: {tokens_before_block}")
            block_start_token_idx = message_start_idx + len(tokens_before_block)

            # tokens of the block itself -> block end position
            output_block_text = content[block_start_pos:block_end_pos]
            output_block_tokens = tokenizer.encode(output_block_text, add_special_tokens=False)
            if verbose:
                print(f"output_block_text: [{output_block_text}]")
                print(f"output_block_tokens: {output_block_tokens}")
            block_end_token_idx = block_start_token_idx + len(output_block_tokens)

            if verbose:
                in_context_tokens = input_ids[0, block_start_token_idx:block_end_token_idx].tolist()
                if in_context_tokens != output_block_tokens:
                    print(
                        f"WARNING: output block tokens differ from "
                        f"input_ids[{block_start_token_idx}:{block_end_token_idx}]; "
                        f"masking may be misaligned."
                    )

            labels[:, block_start_token_idx:block_end_token_idx] = -100

    attention_mask = torch.ones_like(input_ids)

    if verbose:
        # Decode labels to show only what contributes to the loss: masked (-100) positions
        # are replaced by the pad id (or 0 if the tokenizer has none) so decode() accepts them.
        print("\n== DECODED LABELS (showing only content included in loss calculation) ==")
        decode_labels = labels.clone()
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        decode_labels[decode_labels == -100] = pad_token_id
        decoded_text = tokenizer.decode(decode_labels[0])
        print(decoded_text)
        print("== END OF DECODED LABELS ==\n")

    return {
        "input_ids": input_ids.flatten().tolist(),
        "labels": labels.flatten().tolist(),
        "attention_mask": attention_mask.flatten().tolist(),
    }

# Guarded so that spawn-based multiprocessing (macOS/Windows defaults), whose workers
# re-import this module as `__mp_main__`, does not re-run the driver and start a nested Pool.
if __name__ == "__main__":
    # Sanity check: encode the first example verbosely so the rendered chat text and the
    # decoded trainable tokens can be inspected before the full run.
    if input_data:
        print(encode_sft_example(input_data[0], verbose=True))
        print("please check your sft example works as expected!")
        # pause so the output can be read (and the run aborted) before the full pass starts
        time.sleep(10)

    # Results are streamed straight to the output file rather than buffered in memory.
    # An aborted run therefore leaves a partial file behind.
    input_ids_lens = []
    with Pool(args.preprocessing_num_workers) as p, open(
        args.output_file, "w", encoding="utf-8"
    ) as fw:
        # Wrap imap's *results* in tqdm: imap's feeder thread consumes its input eagerly,
        # so wrapping the input would track dispatch rather than completion.
        results = p.imap(encode_sft_example, input_data)
        for tokenized_example in tqdm(results, total=len(input_data), desc="tokenizing"):
            input_ids_lens.append(len(tokenized_example["input_ids"]))
            fw.write(json.dumps(tokenized_example) + "\n")
    print(f"input_ids_lens stats on {args.input_file}")
    print(pd.Series(input_ids_lens).describe())
    print(f"written to {args.output_file}")