import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import json

from minimal_multitask.data import DATASETS, FileDataset
from minimal_multitask.utils import encode_with_messages_format

from tqdm import tqdm
import argparse
import os
import pickle

def read_ids(path):
    """Return the ``id`` field of every JSONL record in *path*, in file order.

    Blank lines are skipped. Duplicate ids are kept so the caller can report
    how many filter entries were read.

    Raises:
        OSError: if *path* cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``id`` key.
    """
    with open(path, 'r') as f:
        return [json.loads(line)['id'] for line in f if line.strip()]


def filter_jsonl(lines, ids_to_filter, fout):
    """Copy to *fout* every JSONL line whose ``id`` is not in *ids_to_filter*.

    Args:
        lines: iterable of JSONL strings (e.g. an open text file).
        ids_to_filter: iterable of ids to drop; ids must be hashable
            (JSON strings/numbers), since they are stored in a set.
        fout: writable text stream; kept lines are written unchanged.

    Returns:
        ``(selected_cnt, all_cnt)``: number of lines written and number of
        non-blank lines examined.
    """
    # Build the lookup set once so each membership test is O(1) rather than
    # a linear scan over every filter id.
    excluded = set(ids_to_filter)
    all_cnt = 0
    selected_cnt = 0
    for line in lines:
        if not line.strip():
            # Blank lines carry no record; skip them instead of failing to parse.
            continue
        all_cnt += 1
        if json.loads(line)['id'] not in excluded:
            selected_cnt += 1
            fout.write(line)
    return selected_cnt, all_cnt


def main():
    """Command-line entry point: write the training set minus filtered ids."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", type=str, default="alpaca")
    # Both paths are opened unconditionally below, so require them up front
    # rather than failing later on open(None).
    parser.add_argument("--filter_dataset", type=str, required=True)
    parser.add_argument("--output_file", type=str, required=True)
    args = parser.parse_args()

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(args.train_dataset):
        raise FileNotFoundError(
            f"--train_dataset path does not exist: {args.train_dataset}"
        )

    ids_to_filter = read_ids(args.filter_dataset)

    print(f'Filterting {len(ids_to_filter)} samples.')

    with open(args.output_file, "w") as fout, open(args.train_dataset, "r") as fin:
        selected_cnt, all_cnt = filter_jsonl(tqdm(fin), ids_to_filter, fout)

    print(f'Selected {selected_cnt}/{all_cnt} entries')


if __name__ == "__main__":
    main()

            