import os, sys, json, argparse
from tqdm import tqdm
import datasets

sys.path.append(os.path.join(sys.path[0], '..'))
from src.utils.data import dataset_info, data_keys


def _index_by_id(path):
    """Load a JSONL file and map each record's ``id`` to the record.

    Blank lines are skipped. If an ``id`` appears more than once, the first
    record wins, mirroring ``list.index`` lookup semantics.
    """
    by_id = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            record = json.loads(line)
            by_id.setdefault(record['id'], record)
    return by_id


def _write_jsonl(path, records):
    """Write *records* to *path* as JSON Lines, replacing any existing file."""
    # 'w' (not 'a') so re-running the script does not duplicate records.
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(record) + '\n' for record in records)


def main(args):
    """Re-split the ECQA JSONL data to match the Hugging Face ECQA splits.

    Reads every record from ``{args.in_path}/all.jsonl``. For each of the
    ``train``/``dev``/``test`` splits, it selects the records whose ``id``
    equals the ``q_no`` of an instance in the corresponding Hugging Face
    split. The split name and expected size come from
    ``dataset_info['ecqa'][split]``. The selected records are written, in
    Hugging Face split order, to ``{args.out_path}/{split}_resplit.jsonl``.

    Args:
        args: Parsed CLI namespace providing ``in_path`` and ``out_path``.

    Raises:
        KeyError: If a Hugging Face instance's ``q_no`` has no matching
            ``id`` in ``all.jsonl``.
        ValueError: If a split's size differs from the expected count in
            ``dataset_info['ecqa'][split]``.
    """
    in_file = f'{args.in_path}/all.jsonl'
    by_id = _index_by_id(in_file)

    # The HF dataset does not depend on the split; load it once and reuse it.
    raw_splits = datasets.load_dataset(dataset_info['ecqa']['hf'])
    os.makedirs(args.out_path, exist_ok=True)

    for split in ['train', 'dev', 'test']:
        hf_split, num_instances = dataset_info['ecqa'][split]
        dataset = []
        for instance in tqdm(raw_splits[hf_split], desc=f'Processing {split} set'):
            q_no = instance['q_no']
            if q_no not in by_id:
                raise KeyError(
                    f'{split}: q_no {q_no!r} not found in {in_file}'
                )
            dataset.append(by_id[q_no])

        # Validate before writing so a bad split never produces a partial file.
        if len(dataset) != num_instances:
            raise ValueError(
                f'{split}: expected {num_instances} instances, got {len(dataset)}'
            )

        _write_jsonl(f'{args.out_path}/{split}_resplit.jsonl', dataset)

if __name__ == "__main__":
    # CLI entry point: both options are plain string directory paths.
    cli = argparse.ArgumentParser(description='Re-split ECQA')
    option_specs = (
        (('--in_path', '-i'), '../data/gpt/data/ecqa/old'),
        (('--out_path', '-o'), '../data/gpt/data/ecqa'),
    )
    for flags, default_dir in option_specs:
        cli.add_argument(*flags, type=str, default=default_dir)
    main(cli.parse_args())
