import os, sys, json, argparse
from tqdm import tqdm

sys.path.append(os.path.join(sys.path[0], '..'))
from src.utils.data import dataset_info


def main(args):
    """Write per-split JSONL files whose records follow the StrategyQA split order.

    Reads every record from ``{args.in_path}/all.jsonl``. For each of the
    ``train``/``dev``/``test`` splits it:

    1. loads the split's JSON file, whose path is taken from
       ``dataset_info['strategyqa']['file'][split]``;
    2. matches each split question (``instance['question']``, stripped) to the
       record whose ``context`` field (stripped) is identical;
    3. writes the matched records, in split order, to
       ``{args.out_path}/{split}.jsonl``, overwriting any existing file.

    Args:
        args: namespace with ``in_path`` (directory containing ``all.jsonl``)
            and ``out_path`` (output directory, created if missing).

    Raises:
        ValueError: if a split question has no matching record in
            ``all.jsonl``, or if the number of matched records differs from
            ``dataset_info['strategyqa'][split][1]``.
    """
    with open(f'{args.in_path}/all.jsonl', 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
        all_instances = [json.loads(line) for line in f if line.strip()]

    # Map each stripped context to the index of its FIRST occurrence. This
    # matches what list.index() returned, but gives O(1) lookups instead of
    # an O(n) scan per question.
    question_to_idx = {}
    for idx, instance in enumerate(all_instances):
        question_to_idx.setdefault(instance['context'].strip(), idx)

    os.makedirs(args.out_path, exist_ok=True)

    for split in ['train', 'dev', 'test']:
        # Expected split size; assumes element [1] of this entry is the
        # instance count (TODO confirm against src.utils.data).
        num_instances = dataset_info['strategyqa'][split][1]
        filename = dataset_info['strategyqa']['file'][split]
        with open(filename, 'r', encoding='utf-8') as fin:
            raw_dataset = json.load(fin)

        gpt_dataset = []
        for instance in tqdm(raw_dataset, desc=f'Processing {split} set'):
            question = instance['question'].strip()
            idx = question_to_idx.get(question)
            if idx is None:
                raise ValueError(
                    f'{split}: question not found in all.jsonl: {question!r}'
                )
            gpt_dataset.append(all_instances[idx])

        if len(gpt_dataset) != num_instances:
            raise ValueError(
                f'{split}: expected {num_instances} instances, '
                f'got {len(gpt_dataset)}'
            )

        # Open once per split in 'w' mode. Appending per record reopened the
        # file for every line and duplicated output on re-runs.
        with open(f'{args.out_path}/{split}.jsonl', 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(record) + '\n' for record in gpt_dataset)

if __name__ == "__main__":
    # Command-line entry point: --in_path points at the directory holding
    # all.jsonl, --out_path at the directory receiving {split}.jsonl files.
    cli = argparse.ArgumentParser(
        description='Reorder instances in StrategyQA GPT files',
    )
    cli.add_argument(
        '--in_path', '-i',
        type=str,
        default='../data/strategyqa/generated-rationales/gpt-neox/v2',
    )
    cli.add_argument(
        '--out_path', '-o',
        type=str,
        default='../data/strategyqa/generated-rationales/gpt-neox',
    )
    args = cli.parse_args()
    main(args)
