"""
This script is used to filter by game ids

python src/dataset/tmp_filter.py --dataset a24_longterm_sft_bp_hh --save_as data/a24_longterm_sft_bp_hh_1108 --include data/a24_hh_game_ids_1108.txt
"""

from absl import flags, app, logging
import datasets
from pathlib import Path
from src.utils import read_txt_as_list

# Command-line flags; all three are mandatory.
flags.DEFINE_string(
    "dataset",
    None,
    "Name of the dataset directory under data/ to filter "
    "(loaded with datasets.load_from_disk).",
    required=True,
)
flags.DEFINE_string(
    "save_as",
    None,
    "Output directory for the filtered dataset; a plain.json export is "
    "also written inside it.",
    required=True,
)
flags.DEFINE_string(
    "include",
    None,
    "Path to a text file listing the game ids to keep.",
    required=True,
)

FLAGS = flags.FLAGS


def main(_):
    """Keep only rows whose `game_id` is listed in `--include`, then save them.

    Loads `data/<--dataset>` with `datasets.load_from_disk` and keeps the rows
    whose `game_id` is among the ids read from `--include`. The result is
    written twice: with `save_to_disk` to `--save_as`, and as JSON to
    `<--save_as>/plain.json`.

    Args:
        _: Positional command-line leftovers from absl's `app.run`; unused.
    """
    input_dir = Path("data")

    dset = datasets.load_from_disk(input_dir / FLAGS.dataset)
    # NOTE(review): ids read from a text file are presumably strings. If the
    # `game_id` column is numeric, nothing will match; the missing-id warning
    # below makes that visible instead of silently saving an empty dataset.
    game_ids = set(read_txt_as_list(FLAGS.include))

    # Batched filtering over only the `game_id` column avoids one Python call
    # per row and avoids decoding every other column during the scan. It keeps
    # exactly the same rows as a per-row membership test.
    filtered = dset.filter(
        lambda ids: [gid in game_ids for gid in ids],
        input_columns="game_id",
        batched=True,
    )

    # Report requested ids that matched no row (typos, stale id lists,
    # or str/int mismatches).
    missing = game_ids - set(filtered.unique("game_id"))
    if missing:
        logging.warning(
            "%d of %d requested game ids were not found in %s",
            len(missing),
            len(game_ids),
            FLAGS.dataset,
        )
    logging.info("Kept %d of %d rows", filtered.num_rows, dset.num_rows)

    filtered.save_to_disk(FLAGS.save_as)
    filtered.to_json(str(Path(FLAGS.save_as) / "plain.json"))
    print(filtered)
    logging.info(f"Filtered dataset saved to {FLAGS.save_as}")


if __name__ == "__main__":
    app.run(main)
