from pathlib import Path

import fire
import pandas as pd


# Default input/output parquet locations; both can be overridden from the
# command line because main() uses them as keyword-argument defaults.
_LUFFY_DATA_DIR = (
    "/mnt/shared-storage-user/p1-shared/wangfuting/codes/"
    "project_tts_extrapolation/data/luffy"
)
DEFAULT_INPUT = _LUFFY_DATA_DIR + "/aime24_qwen3_8.parquet"
DEFAULT_OUTPUT = _LUFFY_DATA_DIR + "/aime24_qwen3_8_sampled16.parquet"


def main(
    input_file=DEFAULT_INPUT,
    output_file=DEFAULT_OUTPUT,
    sample_size=16,
    seed=42,
    data_source="aime",
):
    """Randomly sample rows of one data source from a parquet file and save them.

    Rows whose ``data_source`` column equals ``data_source`` are selected, up to
    ``sample_size`` of them are drawn without replacement, and the result is
    written to ``output_file`` with a fresh 0..n-1 index. A short summary of
    the run is printed to stdout.

    Args:
        input_file: Path of the source parquet file; it must contain a
            ``data_source`` column.
        output_file: Destination parquet path. Missing parent directories are
            created.
        sample_size: Number of rows to draw. Must be positive; values larger
            than the number of matching rows are clamped to that number.
        seed: Random seed forwarded to ``DataFrame.sample`` so the draw is
            reproducible. ``None`` gives a non-deterministic sample.
        data_source: Value of the ``data_source`` column to keep.

    Raises:
        ValueError: If ``sample_size`` is not positive, the input has no
            ``data_source`` column, or no rows match ``data_source``.
        FileNotFoundError: If ``input_file`` does not exist.
    """
    # Validate cheap CLI arguments before reading the (possibly large) input,
    # so a typo fails fast instead of after a full parquet load. Without this,
    # sample_size=0 would silently write an empty output file.
    requested = int(sample_size)
    if requested <= 0:
        raise ValueError(f"sample_size must be positive, got {sample_size!r}")
    random_state = None if seed is None else int(seed)

    input_path = Path(input_file)
    output_path = Path(output_file)

    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")

    df = pd.read_parquet(input_path)
    if "data_source" not in df.columns:
        raise ValueError("Input parquet has no 'data_source' column.")

    # Boolean indexing already yields a new frame, and sample() below copies
    # again, so no explicit .copy() is needed here.
    subset = df[df["data_source"] == data_source]
    if subset.empty:
        available = sorted(map(str, df["data_source"].dropna().unique()))
        raise ValueError(
            f"No rows found for data_source={data_source!r}; "
            f"available values: {available}"
        )

    # DataFrame.sample(replace=False) raises if n exceeds the population, so
    # clamp the request to the number of matching rows.
    sample_size = min(requested, len(subset))
    sampled = subset.sample(
        n=sample_size, random_state=random_state
    ).reset_index(drop=True)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    sampled.to_parquet(output_path)

    print(f"input_file={input_path}")
    print(f"output_file={output_path}")
    print(f"data_source={data_source}")
    print(f"sample_size={sample_size}")
    print(f"seed={seed}")
    print(f"rows_saved={len(sampled)}")


if __name__ == "__main__":
    # Python Fire turns main()'s keyword arguments into command-line flags,
    # e.g. `--sample_size=32 --seed=0 --data_source=aime`.
    fire.Fire(component=main)
