#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
把 DataFrame 格式的 pkl 转换为 list[dict] 格式的 pkl，
保证能被 train.py/base_dataset 正常读取。
"""

import argparse
import pandas as pd
import pickle
import os

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--in_pkl", required=True, help="输入的 DataFrame 格式 pkl 文件")
    parser.add_argument("--out_pkl", required=True, help="输出的 list[dict] 格式 pkl 文件")
    args = parser.parse_args()

    raw = pd.read_pickle(args.in_pkl)

    if not isinstance(raw, pd.DataFrame):
        raise ValueError(f"{args.in_pkl} 不是 DataFrame，请检查输入。")

    df = raw.copy()

    # 确保关键字段类型正确
    if "hd_target" in df.columns:
        df["hd_target"] = df["hd_target"].astype(float)
    if "hd_bin" in df.columns:
        df["hd_bin"] = df["hd_bin"].astype(int)

    # 转为 list[dict]
    records = df.to_dict(orient="records")

    os.makedirs(os.path.dirname(args.out_pkl), exist_ok=True)
    with open(args.out_pkl, "wb") as f:
        pickle.dump(records, f, protocol=pickle.HIGHEST_PROTOCOL)

    print(f"[done] 转换完成: {args.out_pkl}, 共 {len(records)} 条样本")

if __name__ == "__main__":
    main()
