import contextlib
import glob
import json
import os
import random

import fire


def _load_split(path, key):
    """Return the list of dataset names stored under *key* in a split file."""
    with open(path, encoding='utf-8') as split_file:
        return json.load(split_file)[key]


def _pair_lines(example):
    """Yield one JSONL line per (response, score) pair of *example*."""
    for response, score in example['responses']:
        yield json.dumps({
            'instruction': example['instruction'],
            'response': response,
            'label': score,
            'answer_choices': example['answer_choices'],
            'dataset_name': example['dataset_name'],
            'template_name': example['template_name'],
            'example_idx': example['example_idx']
        }) + '\n'


def main(t0_data_dir='./processed_t0_data', output_dir='./cappy_pretrain_data'):
    """Flatten processed T0 examples into shuffled training shards and a test file.

    Every input line is a JSON object with an 'instruction', a list of
    (response, score) pairs under 'responses', and a few metadata fields.
    Each pair becomes one output line.

    Training data is written to ``{output_dir}/train_{k}.jsonl``. A new shard
    is started every ``n_data // 16`` written pairs (``n_data`` being the total
    number of training pairs); the last shard holds whatever remains and may
    be empty. The rotation check runs per pair, so the pairs of one example
    can straddle two shards. Evaluation data is written unshuffled to
    ``{output_dir}/test.jsonl``.

    Args:
        t0_data_dir: Directory containing one sub-directory of ``*.jsonl``
            files per dataset name.
        output_dir: Directory for the output files; created if missing.
    """
    # NOTE(review): the training split is read from 'data_utils/split.json'
    # while the evaluation split below is read from
    # 'pretrain_data_utils/split.json' -- confirm both paths are intended.
    train_filenames = [
        filename
        for dataset_name in _load_split('data_utils/split.json', 'T0')
        for filename in glob.glob(f'{t0_data_dir}/{dataset_name}/*.jsonl')
    ]

    # Pass 1: record, for every example, which file it lives in, and count
    # the total number of instruction-response pairs.
    file_idx_list = []
    n_data = 0
    for file_idx, filename in enumerate(train_filenames):
        with open(filename, encoding='utf-8') as input_file:
            for line in input_file:
                file_idx_list.append(file_idx)
                n_data += len(json.loads(line)['responses'])
        print(filename, f'total={n_data}')

    print('total examples', len(file_idx_list))
    print('total instruction-response pairs', n_data)

    # Shuffling the per-example file indices and then reading the next line
    # of the chosen file interleaves the files at random without loading any
    # of them into memory; the order inside a single file is preserved.
    random.shuffle(file_idx_list)

    os.makedirs(output_dir, exist_ok=True)
    shard_size = n_data // 16

    # Pass 2: all training files stay open at once (one handle per file), so
    # the ExitStack guarantees they are closed even if a line fails to parse.
    with contextlib.ExitStack() as stack:
        readers = [
            stack.enter_context(open(filename, encoding='utf-8'))
            for filename in train_filenames
        ]

        output_idx = 0
        n_written = 0
        output_file = open(
            f'{output_dir}/train_0.jsonl', 'w', encoding='utf-8')
        try:
            for file_idx in file_idx_list:
                example = json.loads(readers[file_idx].readline())

                for record_line in _pair_lines(example):
                    output_file.write(record_line)

                    n_written += 1
                    if n_written >= shard_size:
                        print(f'{output_file.name}: '
                              f'{n_written} examples written.')

                        # Close the finished shard before rotating so its
                        # contents are flushed deterministically.
                        output_file.close()
                        output_idx += 1
                        output_file = open(
                            f'{output_dir}/train_{output_idx}.jsonl', 'w',
                            encoding='utf-8')
                        n_written = 0
        finally:
            output_file.close()

    with open(f'{output_dir}/test.jsonl', 'w', encoding='utf-8') as output_file:
        for dataset_name in _load_split(
                'pretrain_data_utils/split.json', 'evaluation'):
            for filename in glob.glob(f'{t0_data_dir}/{dataset_name}/*.jsonl'):
                with open(filename, encoding='utf-8') as input_file:
                    for line in input_file:
                        output_file.writelines(_pair_lines(json.loads(line)))


# Script entry point: hand `main` to python-fire, which exposes its keyword
# parameters (t0_data_dir, output_dir) as command-line flags.
if __name__ == '__main__':
    fire.Fire(main)
