from __future__ import annotations
from typing import List, Optional, Tuple
import json
import os
import tempfile
from concurrent.futures import ProcessPoolExecutor, as_completed
import glob

def _remove_consecutive_duplicates(nums: List[int]) -> List[int]:
    """Collapse every run of equal adjacent values down to a single value.

    Only *adjacent* repeats are dropped; a value may still appear again later
    in the result (``[1, 1, 2, 1]`` becomes ``[1, 2, 1]``).

    Args:
        nums: Sequence of values to collapse. It is not modified.

    Returns:
        A new list with adjacent repeats removed. An empty (falsy) input is
        returned as-is, i.e. the very same object.
    """
    if not nums:
        return nums

    collapsed: List[int] = []
    for position, value in enumerate(nums):
        # The first element is always kept; later ones only when they differ
        # from the most recently kept value.
        if position and value == collapsed[-1]:
            continue
        collapsed.append(value)
    return collapsed

def _process_line(line: str) -> str:
    """Augment one JSONL record with de-duplicated path fields.

    The line is stripped and parsed as JSON. Two derived keys may be added:

    * ``path_node_unique`` -- built from ``path_node`` when that value is a
      list. Each inner list made up solely of ints has its consecutive
      duplicates collapsed; any other element is carried over unchanged.
    * ``path_edge_unique`` -- built from ``path_edge`` when that value is a
      list. Each inner list made up solely of ints/floats has every value
      equal to ``0`` removed; any other element is carried over unchanged.

    Args:
        line: One raw line of the input file (surrounding whitespace allowed).

    Returns:
        The re-serialized record followed by ``"\\n"``, or ``""`` when the
        line is blank.

    Raises:
        json.JSONDecodeError: If the non-blank line is not valid JSON.
    """
    text = line.strip()
    if not text:
        return ""

    record = json.loads(text)

    if "path_node" in record and isinstance(record["path_node"], list):
        record["path_node_unique"] = [
            _remove_consecutive_duplicates(entry)
            if isinstance(entry, list) and all(isinstance(v, int) for v in entry)
            else entry
            for entry in record["path_node"]
        ]

    if "path_edge" in record and isinstance(record["path_edge"], list):
        record["path_edge_unique"] = [
            [v for v in entry if v != 0]
            if isinstance(entry, list)
            and all(isinstance(v, (int, float)) for v in entry)
            else entry
            for entry in record["path_edge"]
        ]

    serialized = json.dumps(record, ensure_ascii=False)
    return serialized + "\n"

def _process_segment(args: Tuple[str, int, int, int]) -> str:
    """Process the lines of one byte range of a JSONL file into a temp file.

    A segment owns exactly the lines whose first byte lies in
    ``[start, end)``. Adjacent segments (``end`` of one equal to ``start`` of
    the next) therefore cover every line once, even when a boundary falls in
    the middle of a line or exactly on the first byte of a line.

    Args:
        args: ``(input_path, start, end, chunk_lines)`` packed in one tuple so
            the function can be submitted to an executor with a single
            argument. ``start``/``end`` are byte offsets into ``input_path``;
            ``chunk_lines`` is how many output lines are buffered between
            writes.

    Returns:
        Path of a newly created temporary file holding the processed lines.
        The caller is responsible for deleting it.

    Raises:
        OSError: If the input or temporary file cannot be opened.
        UnicodeDecodeError: If a line is not valid UTF-8.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    input_path, start, end, chunk_lines = args
    fd, tmp_path = tempfile.mkstemp(prefix="jsonl_part_", suffix=".tmp")
    os.close(fd)

    finished = False
    try:
        with open(input_path, "rb") as fin, \
             open(tmp_path, "w", encoding="utf-8", newline="") as fout:

            if start > 0:
                # Step back one byte before skipping to the next line start.
                # If byte ``start - 1`` is a newline, ``start`` is itself a
                # line start and readline() consumes only that newline, so
                # the line is kept. Seeking to ``start`` directly would throw
                # that whole line away while the previous segment has already
                # stopped in front of it -- the line would be lost.
                fin.seek(start - 1)
                fin.readline()

            buf = []
            # tell() is the offset of the next line's first byte; the line is
            # ours only while that offset is inside the segment.
            while fin.tell() < end:
                bline = fin.readline()
                if not bline:
                    break
                out = _process_line(bline.decode("utf-8"))
                if out:
                    buf.append(out)
                    if len(buf) >= chunk_lines:
                        fout.writelines(buf)
                        buf.clear()

            if buf:
                fout.writelines(buf)
        finished = True
    finally:
        if not finished:
            # Nobody will ever learn this path if we raise, so remove the
            # partial file here instead of leaking it.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    return tmp_path
def _split_file_by_newlines(path: str, num_workers: int) -> List[Tuple[int, int]]:
    """Cut a file into ``num_workers`` contiguous byte ranges.

    The cut points are plain multiples of ``file_size // num_workers``; they
    are *not* moved onto newline positions here, so whoever consumes a range
    has to deal with a boundary that falls in the middle of a line. The last
    range always runs to the end of the file. When the file is smaller than
    ``num_workers`` bytes the leading ranges are empty ``(0, 0)`` pairs.

    Args:
        path: File whose size determines the ranges.
        num_workers: Number of ranges to produce.

    Returns:
        ``(start, end)`` byte-offset pairs in file order, or ``[(0, 0)]`` for
        an empty file.
    """
    total_bytes = os.path.getsize(path)
    if total_bytes == 0:
        return [(0, 0)]

    stride = total_bytes // num_workers
    # One cut per range start, then the end of the file closes the last range.
    cuts = [index * stride for index in range(num_workers)]
    cuts.append(total_bytes)
    return list(zip(cuts[:-1], cuts[1:]))

def transform_jsonl_parallel(
    input_path: str,
    output_path: Optional[str] = None,
    num_workers: int = 4,
    chunk_lines: int = 10_000,
) -> str:
    """Transform a JSONL file by processing byte ranges of it in parallel.

    The input is divided into byte ranges, each range is processed into its
    own temporary file, and the temporary files are concatenated in range
    order into ``output_path``. Temporary files are removed afterwards,
    whether or not the run succeeded.

    Args:
        input_path: JSONL file to read.
        output_path: Destination file. Defaults to
            ``f"{input_path}.processed.jsonl"``.
        num_workers: Number of byte ranges / worker processes. With a single
            range the work is done in the current process.
        chunk_lines: Number of output lines each worker buffers per write.

    Returns:
        The path the result was written to.

    Raises:
        ValueError: If ``num_workers`` or ``chunk_lines`` is smaller than 1.
        Exception: The first error raised while processing a range is
            re-raised after all workers have finished.
    """
    if output_path is None:
        output_path = f"{input_path}.processed.jsonl"

    if num_workers < 1:
        raise ValueError("num_workers must be at least 1.")
    if chunk_lines < 1:
        raise ValueError("chunk_lines must be at least 1.")

    segments = _split_file_by_newlines(input_path, num_workers)
    jobs = [(input_path, s, e, chunk_lines) for s, e in segments]
    tmp_paths: List[Optional[str]] = [None] * len(jobs)
    try:
        if len(jobs) == 1:
            # A single range needs no process pool. The result still goes
            # through the copy step below rather than os.replace(): the temp
            # file lives in the system temp directory, and a rename to
            # output_path can fail when that is a different filesystem.
            tmp_paths[0] = _process_segment(jobs[0])
        else:
            first_error: Optional[Exception] = None
            with ProcessPoolExecutor(max_workers=num_workers) as ex:
                futures = {
                    ex.submit(_process_segment, job): idx
                    for idx, job in enumerate(jobs)
                }
                for fut in as_completed(futures):
                    # Keep collecting after a failure so that every temp file
                    # produced by a successful worker is known and gets
                    # removed in the finally block.
                    try:
                        tmp_paths[futures[fut]] = fut.result()
                    except Exception as exc:
                        if first_error is None:
                            first_error = exc
            if first_error is not None:
                raise first_error

        # Concatenate the parts in range order. newline="" on both sides
        # keeps line endings exactly as the workers wrote them.
        with open(output_path, "w", encoding="utf-8", newline="") as fout:
            for p in tmp_paths:
                if p is None:
                    continue
                with open(p, "r", encoding="utf-8", newline="") as fin:
                    for line in fin:
                        fout.write(line)
    finally:
        for p in tmp_paths:
            if p and os.path.exists(p):
                try:
                    os.remove(p)
                except OSError:
                    # Best-effort cleanup: a leftover temp file must not hide
                    # the real outcome of the run.
                    pass

    return output_path

def transform_jsonl(input_path: str, output_path: Optional[str] = None) -> str:
    """Transform a JSONL file line by line in the current process.

    Every input line is passed through ``_process_line``; lines for which it
    returns an empty string are left out of the output.

    Args:
        input_path: JSONL file to read (UTF-8).
        output_path: Destination file. Defaults to
            ``f"{input_path}.processed.jsonl"``.

    Returns:
        The path the result was written to.
    """
    if output_path is None:
        output_path = f"{input_path}.processed.jsonl"

    with open(input_path, "r", encoding="utf-8") as source:
        with open(output_path, "w", encoding="utf-8") as sink:
            for raw_line in source:
                rendered = _process_line(raw_line)
                if not rendered:
                    continue
                sink.write(rendered)

    return output_path

if __name__ == "__main__":
    # Batch driver: for every (model, dataset) pair, glob for matching result
    # files and run transform_jsonl_parallel on each one.

    # Other model line-ups that have been used with this script. To run one,
    # replace the active MODEL_NAMES list below with it.
    #
    # 7B:
    #     "Qwen2.5-Math-7B",
    #     "Qwen2.5-Math-7B-Oat-Zero",
    #     "DeepSeek-R1-Distill-Qwen-7B",
    #     "AceReason-Nemotron-7B",
    #
    # 14B:
    #     "Qwen2.5-14B",
    #     "Qwen-2.5-14B-SimpleRL-Zoo",
    #     "DeepSeek-R1-Distill-Qwen-14B",
    #     "AceReason-Nemotron-14B",

    MODEL_NAMES = [
        "Qwen2.5-Math-1.5B",
        "Qwen2.5-Math-1.5B-Oat-Zero",
        "DeepSeek-R1-Distill-Qwen-1.5B",
        "Nemotron-Research-Reasoning-Qwen-1.5B"
    ]
    DATA_NAMES = ["aime24"]

    # Settings passed unchanged to every transform_jsonl_parallel call.
    PARALLEL_OPTIONS = {"num_workers": 100, "chunk_lines": 100}

    for model in MODEL_NAMES:
        for data in DATA_NAMES:
            # NOTE(review): the leading "XXXX" is the literal prefix in this
            # file -- it looks like a placeholder for a base directory;
            # confirm before running.
            files = glob.glob(
                f"XXXX"
                f"{model}_{data}_temp0.6_n256_seed1/eval_results/global_step_0/"
                f"{data}/test_*_-1_seed1_t0.6_s0_e-1_processed.jsonl"
            )
            print(f"[{model}, {data}] Found {len(files)} files.")
            for f in files:
                print(f"Processing {f} ...")
                out_path = transform_jsonl_parallel(f, **PARALLEL_OPTIONS)
                print(f"Saved to {out_path}")
    

