import re
import pandas as pd
from datasets import load_dataset, Dataset
import os
from typing import Optional
import json
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

from transformers import AutoTokenizer
from tqdm import tqdm


# Module-level tokenizer cache; stays None until get_tokenizer() first loads it.
tokenizer = None


def get_tokenizer():
    """Return the shared tokenizer, loading it on the first call.

    The loaded instance is kept in the module-level ``tokenizer`` global, so
    later calls in the same process reuse it instead of loading it again.
    """
    global tokenizer
    if tokenizer is not None:
        return tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-1.5B",
        trust_remote_code=True,
    )
    return tokenizer

def get_token_length(messages):
    """Return how many tokens ``messages`` occupy once chat-templated.

    The messages are first rendered to text with the tokenizer's chat
    template and that text is then tokenized; the result is the number of
    input ids produced.

    Args:
        messages: Chat messages in the form accepted by
            ``apply_chat_template`` (here, a list of role/content dicts).

    Returns:
        The token count as an ``int``.
    """
    tk = get_tokenizer()
    chat_text = tk.apply_chat_template(messages, tokenize=False)
    # Encoding a single string without ``return_tensors`` yields a flat list
    # of ids, so ``len`` gives the same count as the tensor's second
    # dimension without requiring PyTorch just to measure a length.
    return len(tk(chat_text)["input_ids"])


def parse_operations(search_path, target):
    """Extract the winning sequence of operations from a search trace.

    Only the text before the first ``"<a>,<b> equal: Goal Reached"`` line is
    inspected. The final operation is the last ``Exploring Operation`` entry
    in that prefix; the operations leading up to it are taken from the
    ``Operations: [...]`` list of the last ``Current State`` entry.

    Args:
        search_path: Full text of the search trace.
        target: Expected target value, compared against the integer result
            of the final operation.

    Returns:
        A list of operation strings ending with the final operation, or the
        string ``"No goal reached statement found."`` when the trace has no
        goal line.

    Raises:
        ValueError: If no operation or no current-state line precedes the
            goal line, if the final operation's result cannot be parsed, or
            if that result differs from ``target``.
    """
    # Only the first goal statement matters, so a single search is enough.
    goal_match = re.search(r"\d+,\d+ equal: Goal Reached", search_path)
    if goal_match is None:
        return "No goal reached statement found."

    # Everything of interest happens before the goal statement.
    prefix = search_path[:goal_match.start()]

    operations = re.findall(
        r"Exploring Operation: (.*?=\d+), Resulting Numbers: \[(.*?)\]",
        prefix,
    )
    if not operations:
        raise ValueError("No operations found leading to the goal.")

    final_operation = operations[-1][0]
    try:
        predicted_result = int(final_operation.split('=')[1])
    except (IndexError, ValueError) as err:
        print("couldnt parse last op", final_operation)
        raise ValueError("Couldnt parse last op") from err
    if predicted_result != target:
        raise ValueError("Invalid path: Final operation does not result in target.")

    # The last recorded state lists the operations applied before the final one.
    state_operations = re.findall(
        r"Current State: \d+:\[.*?\], Operations: \[(.*?)\]",
        prefix,
    )
    if not state_operations:
        raise ValueError("No current state found before the goal.")

    # An empty "Operations: []" list splits into [''], so drop empty entries
    # instead of emitting a blank operation.
    operation_list = [
        op.replace("'", "")
        for op in state_operations[-1].split(', ')
        if op
    ]
    operation_list.append(final_operation)

    return operation_list


def format_sample(item, max_length):
    """Convert one raw Countdown record into a chat-style sample.

    The assistant turn is the stripped search trace; when
    ``parse_operations`` finds a solution, a ``Final operations:`` section
    listing the operations is appended.

    Args:
        item: Mapping with the keys ``nums``, ``target``, ``search_type``,
            ``search_path``, ``heuristic`` and ``rating``.
        max_length: Token budget for the rendered conversation. A value of
            0 or less disables the length check.

    Returns:
        A dict holding ``messages`` plus the record's metadata, or ``None``
        when the conversation is longer than ``max_length`` tokens.
    """
    # Read the fields in a fixed order before any parsing happens.
    nums, target, search_type, search_path, heuristic, rating = (
        item[key]
        for key in ("nums", "target", "search_type", "search_path", "heuristic", "rating")
    )
    parsed = parse_operations(search_path, target)

    if not isinstance(parsed, list):
        # The only non-list result is the "no goal" sentinel string.
        assert parsed == "No goal reached statement found."
        assistant_content = search_path.strip()
    else:
        joined_operations = "\n".join(parsed)
        assistant_content = f"{search_path.strip()}\nFinal operations:\n{joined_operations}"

    user_turn = {
        "role": "user",
        "content": f"Given the numbers {nums} and the target {target}, find the operations to reach the target.",
    }
    assistant_turn = {
        "role": "assistant",
        "content": assistant_content,
    }
    messages = [user_turn, assistant_turn]

    # Length filtering is skipped entirely for a non-positive budget.
    if max_length > 0:
        tokenized_len = get_token_length(messages)
        if tokenized_len > max_length:
            print(f"Skipping item with length {tokenized_len}")
            return None

    return dict(
        messages=messages,
        nums=nums,
        target=target,
        search_type=search_type,
        heuristic=heuristic,
        rating=rating,
    )


def _init_worker():
    """Pool initializer that loads the tokenizer into the calling process.

    Calling ``get_tokenizer()`` here fills the per-process cache, so later
    calls in the same worker reuse the loaded instance.
    """
    _ = get_tokenizer()

   
def main():
    """Build the Countdown SFT and validation datasets and push them to the Hub.

    Reads every ``train*`` and ``grow*`` JSON file in the data directory and
    formats the records in a process pool, skipping records for which
    ``format_sample`` returns ``None`` (longer than ``max_length`` tokens).
    The single ``val_target*`` file is formatted in this process with the
    length check disabled. Both datasets are then uploaded as private
    repositories.

    The worker count is taken from the ``NUM_WORKERS`` environment variable
    and falls back to ``os.cpu_count()`` (or 1 if that is unavailable).

    Raises:
        AssertionError: If the data directory does not contain exactly one
            file whose name starts with ``val_target``.
    """
    print("Loading Countdown dataset...")
    # Load Countdown dataset
    max_length = 2048
    name = "N4T200"
    data_path = f"/home/xxx2/stream-of-search/src/data/{name}"

    # List the directory once and classify the entries by filename prefix.
    entries = os.listdir(data_path)
    train_files = [os.path.join(data_path, f) for f in entries if f.startswith("train")]
    grow_files = [os.path.join(data_path, f) for f in entries if f.startswith("grow")]
    eval_files = [os.path.join(data_path, f) for f in entries if f.startswith("val_target")]
    # Check the evaluation file up front so a bad directory fails before the
    # training data is tokenized rather than after.
    assert len(eval_files) == 1, "Expected only one evaluation file"

    countdown_dataset = []
    num_workers = int(os.environ.get("NUM_WORKERS", os.cpu_count() or 1))
    # One pool serves every input file, so each worker runs its initializer
    # (and loads the tokenizer) once instead of once per file.
    with ProcessPoolExecutor(max_workers=num_workers, initializer=_init_worker) as executor:
        for file in train_files + grow_files:
            with open(file, "r") as f:
                data = json.load(f)
            for sample in tqdm(
                executor.map(format_sample, data, repeat(max_length)),
                total=len(data),
                desc=f"Processing {os.path.basename(file)}",
            ):
                if sample is None:
                    continue
                countdown_dataset.append(sample)
    print(f"Loaded Countdown with {len(countdown_dataset)} samples")

    countdown_val_dataset = []
    with open(eval_files[0], "r") as f:
        data = json.load(f)
    for item in tqdm(data, desc="Processing dataset"):
        # max_length=0 disables the token-length filter for validation data.
        sample = format_sample(item, 0)
        if sample is None:
            continue
        countdown_val_dataset.append(sample)
    print(f"Loaded Countdown with {len(countdown_val_dataset)} samples")

    countdown_dataset = Dataset.from_list(countdown_dataset)
    countdown_val_dataset = Dataset.from_list(countdown_val_dataset)
    print(f"Loaded Countdown with {len(countdown_dataset)} samples")

    # Upload to hub
    countdown_dataset.push_to_hub(f"xxx98/Countdown-{name}-{max_length}", private=True)
    countdown_val_dataset.push_to_hub(f"xxx98/Countdown-{name}-val", private=True)

    print("\nProcessing complete!")
    print(f"SFT samples used: {len(countdown_dataset)}")
    print(f"Val samples used: {len(countdown_val_dataset)}")


# Script entry point: run the pipeline only when this file is executed
# directly, not when the module is merely imported.
if __name__ == "__main__":
    main()