# extract_semantic_codes.py
import argparse
import os
import json
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig  # environment compatibility
from huggingface_hub import snapshot_download
import tqdm

import multiprocessing as mp
import torch
from typing import Tuple, List

# Faster JSON (optional)
try:
    import orjson as fastjson
    def _loads(s): return fastjson.loads(s)
    def _dumps(o): return fastjson.dumps(o).decode("utf-8")
except Exception:
    _loads = json.loads
    def _dumps(o): return json.dumps(o, ensure_ascii=False)

from kimia_infer.api.prompt_manager import KimiAPromptManager


def _read_all_lines(path: str) -> List[str]:
    with open(path, "r") as f:
        return f.readlines()


def _resume_state(output_file: str) -> Tuple[int, bool]:
    """
    Return (completed_line_count, last_line_fixed).
    Logic:
      - File does not exist -> (0, False)
      - Exists -> count lines N; if the last line fails JSON parsing,
        drop it (rewrite file) and return (N-1, True)
    """
    if not os.path.exists(output_file):
        return 0, False

    # For simplicity and robustness in crash scenarios, read all lines.
    lines = _read_all_lines(output_file)
    if not lines:
        return 0, False

    last = lines[-1].rstrip("\n")
    try:
        _ = _loads(last)
        return len(lines), False
    except Exception:
        # Drop the last (partial/corrupt) line
        safe_lines = lines[:-1]
        with open(output_file, "w") as fw:
            for ln in safe_lines:
                fw.write(ln if ln.endswith("\n") else (ln + "\n"))
        return len(safe_lines), True


def _worker(rank, device_id, cache_path, kimia_token_offset, kimia_text_audiodelaytokens, lines, out_path):
    """
    Tokenize one shard of input records on a single GPU and write it to a shard file.

    Args:
        rank: worker index; also the tqdm bar position so per-process bars do not overlap.
        device_id: CUDA device index passed to ``torch.cuda.set_device``.
        cache_path: local model directory handed to ``KimiAPromptManager``.
        kimia_token_offset: forwarded to ``KimiAPromptManager``.
        kimia_text_audiodelaytokens: forwarded to ``KimiAPromptManager``.
        lines: JSONL input lines for this shard, in order.
        out_path: shard file; truncated and written from scratch.

    Each output line is the input record with ``audio_tokens`` added to every
    message whose ``message_type == "audio"``. If tokenization of a message
    raises, ``audio_tokens`` is None and ``tokenize_error`` holds the exception
    text. Output lines correspond one-to-one, in order, with ``lines``.
    """
    torch.cuda.set_device(device_id)
    torch.set_grad_enabled(False)  # inference only
    torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls (speed over full fp32 precision)

    prompt_manager = KimiAPromptManager(
        model_path=cache_path,
        kimia_token_offset=kimia_token_offset,
        kimia_text_audiodelaytokens=kimia_text_audiodelaytokens
    )

    # Explicit UTF-8: _dumps leaves non-ASCII characters unescaped, so relying
    # on the locale's default encoding could fail on (or mangle) them.
    with open(out_path, "w", encoding="utf-8", buffering=1024 * 1024) as f_out:
        pbar = tqdm.tqdm(lines, position=rank, leave=False, desc=f"GPU{device_id}")
        for line in pbar:
            data = _loads(line)
            conv = data.get("conversation", [])
            for msg in conv:
                if msg.get("message_type") == "audio":
                    # Outside the try: a missing "content" key aborts the worker
                    # rather than being recorded as a per-message error.
                    audio_path = msg["content"]
                    try:
                        audio_tokens = prompt_manager._tokenize_audio(audio_path)
                        msg["audio_tokens"] = audio_tokens
                    except Exception as e:
                        msg["audio_tokens"] = None
                        msg["tokenize_error"] = str(e)
            f_out.write(_dumps(data) + "\n")


def main():
    # python -m finetune_codes.extract_semantic_codes
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="YOUR_MADEL_PATH")
    parser.add_argument("--input_file", type=str, default="DATA_SHOULD_BE_PRETOKENIZED_DIR")
    parser.add_argument("--output_file", type=str, default="PRETOKENIZED_DATA_OUPUT_DIR")
    args = parser.parse_args()

    # Prepare model path
    if os.path.exists(args.model_name_or_path):
        cache_path = args.model_name_or_path
    else:
        cache_path = snapshot_download(args.model_name_or_path)

    # Read required config fields
    model_config = AutoConfig.from_pretrained(cache_path, trust_remote_code=True)
    kimia_token_offset = getattr(model_config, "kimia_token_offset", 0)
    kimia_text_audiodelaytokens = getattr(model_config, "kimia_mimo_audiodelaytokens", 0)

    # Read input
    with open(args.input_file, "r") as f:
        lines = f.readlines()
    total = len(lines)

    # ----- Resume: check existing output and decide where to continue ----- #
    done_n, fixed = _resume_state(args.output_file)
    if done_n > 0:
        tqdm.tqdm.write(f"[resume] detected {done_n} completed line(s){' (fixed last line)' if fixed else ''}.")
    if done_n >= total:
        tqdm.tqdm.write("[resume] output already complete. nothing to do.")
        return

    # Remaining lines to process
    remaining = lines[done_n:]

    num_gpus = torch.cuda.device_count()

    # Single GPU: sequential append
    if num_gpus <= 1:
        prompt_manager = KimiAPromptManager(
            model_path=cache_path,
            kimia_token_offset=kimia_token_offset,
            kimia_text_audiodelaytokens=kimia_text_audiodelaytokens
        )
        # Append mode is critical for resume safety
        with open(args.output_file, "a", buffering=1024 * 1024) as f_out:
            for line in tqdm.tqdm(remaining, desc="GPU0"):
                data = _loads(line)
                for msg in data.get("conversation", []):
                    if msg.get("message_type") == "audio":
                        audio_path = msg["content"]
                        try:
                            audio_tokens = prompt_manager._tokenize_audio(audio_path)
                            msg["audio_tokens"] = audio_tokens
                        except Exception as e:
                            msg["audio_tokens"] = None
                            msg["tokenize_error"] = str(e)
                f_out.write(_dumps(data) + "\n")
        return

    # Multi-GPU: shard remaining lines across processes; merge shards in order
    parts = []
    rem_total = len(remaining)
    for r in range(num_gpus):
        s = r * rem_total // num_gpus
        e = (r + 1) * rem_total // num_gpus
        parts.append((s, e))

    mp.set_start_method("spawn", force=True)

    # Clean up any stale shard files from a previous crash
    for rank in range(num_gpus):
        pf = f"{args.output_file}.part{rank:02d}"
        if os.path.exists(pf):
            try:
                os.remove(pf)
            except:
                pass

    procs, part_files = [], []
    for rank, (s, e) in enumerate(parts):
        out_part = f"{args.output_file}.part{rank:02d}"
        part_files.append(out_part)
        p = mp.Process(
            target=_worker,
            args=(
                rank,               # rank
                rank,               # device_id: 0..N-1 maps to CUDA_VISIBLE_DEVICES
                cache_path,
                kimia_token_offset,
                kimia_text_audiodelaytokens,
                remaining[s:e],
                out_part,
            ),
            daemon=False
        )
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

    # Append-merge shard files into the main output (resume-critical step)
    with open(args.output_file, "a", buffering=1024 * 1024) as f_out:
        for pf in part_files:
            with open(pf, "r") as fin:
                for line in fin:
                    f_out.write(line)

    # Cleanup shards
    for pf in part_files:
        try:
            os.remove(pf)
        except:
            pass


if __name__ == "__main__":
    # The guard is required: multi-GPU runs use the "spawn" start method, under
    # which each worker process re-imports this module; without the guard every
    # worker would run main() again.
    main()
