from pathlib import Path
import json

import fire
import pandas as pd


# Default parquet file read by `main` when no `input_file` is given.
# NOTE(review): absolute path on a shared mount; pass `input_file` explicitly
# when running on a machine that does not have this mount.
DEFAULT_INPUT = (
    "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/luffy/aime24_qwen3_8.parquet"
)
# Default parquet file written by `main` when no `output_file` is given.
DEFAULT_OUTPUT = (
    "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/luffy/aime24_qwen3_unique30.parquet"
)


def _normalize_prompt_key(prompt):
    if isinstance(prompt, (list, tuple)):
        if len(prompt) > 0 and isinstance(prompt[-1], dict):
            content = prompt[-1].get("content")
            if content is not None:
                return str(content).strip()

        serialized = []
        for msg in prompt:
            if isinstance(msg, dict):
                serialized.append(
                    {
                        "role": msg.get("role", ""),
                        "content": msg.get("content", ""),
                    }
                )
            else:
                serialized.append(msg)
        return json.dumps(serialized, ensure_ascii=False, sort_keys=True)

    if isinstance(prompt, dict):
        return str(prompt.get("content", "")).strip()

    return str(prompt).strip()


def main(
    input_file=DEFAULT_INPUT,
    output_file=DEFAULT_OUTPUT,
    num_questions=30,
    seed=42,
    data_source="aime",
    keep_repeats=False,
):
    """Sample a subset of unique questions from a parquet file and save it.

    Rows are optionally filtered by ``data_source``, grouped by a question
    key derived from the ``prompt`` column, and a random subset of those
    keys is selected. The matching rows are written to ``output_file`` and
    a short summary is printed.

    Args:
        input_file: Path of the parquet file to read.
        output_file: Path of the parquet file to write; missing parent
            directories are created.
        num_questions: Number of unique questions requested. Capped at the
            number of unique questions available.
        seed: Random state used for sampling the question keys.
        data_source: Value the ``data_source`` column must equal; ``None``
            disables the filter.
        keep_repeats: When false, only the first row of each sampled
            question is kept; when true, all of its rows are kept.

    Raises:
        FileNotFoundError: If ``input_file`` does not exist.
        ValueError: If a required column is missing, no rows remain after
            filtering, or no question keys could be extracted.
    """
    src = Path(input_file)
    dst = Path(output_file)

    if not src.exists():
        raise FileNotFoundError(f"Input file not found: {src}")

    frame = pd.read_parquet(src).copy()

    # Optional restriction to a single data source.
    if data_source is not None:
        if "data_source" not in frame.columns:
            raise ValueError("Input parquet has no 'data_source' column.")
        source_mask = frame["data_source"] == data_source
        frame = frame[source_mask].copy()

    if frame.shape[0] == 0:
        raise ValueError("No rows left after filtering.")

    if "prompt" not in frame.columns:
        raise ValueError("Input parquet has no 'prompt' column.")

    # Temporary helper column; removed again before writing the output.
    frame["_question_key"] = frame["prompt"].map(_normalize_prompt_key)
    distinct_keys = frame["_question_key"].drop_duplicates().tolist()

    if not distinct_keys:
        raise ValueError("Failed to extract any question keys.")

    # Never ask for more questions than exist.
    sample_size = min(int(num_questions), len(distinct_keys))
    key_pool = pd.Series(distinct_keys)
    picked = key_pool.sample(
        n=sample_size, random_state=int(seed), replace=False
    )
    chosen_keys = picked.tolist()

    is_chosen = frame["_question_key"].isin(chosen_keys)
    result = frame[is_chosen].copy()

    if not keep_repeats:
        result = result.drop_duplicates(subset=["_question_key"], keep="first")

    result = result.drop(columns=["_question_key"])
    result = result.reset_index(drop=True)

    dst.parent.mkdir(parents=True, exist_ok=True)
    result.to_parquet(dst)

    print(f"input_file={src}")
    print(f"output_file={dst}")
    print(f"data_source={data_source}")
    print(f"requested_num_questions={num_questions}")
    print(f"saved_num_questions={sample_size}")
    print(f"keep_repeats={keep_repeats}")
    print(f"rows_saved={len(result)}")


# Script entry point: hand `main` to Fire, which exposes it as a
# command-line interface (its parameters become CLI arguments).
if __name__ == "__main__":
    fire.Fire(main)
