from os import path
from random import sample, seed
import sys

# Directory holding this script; generate_prompts resolves its relative
# data paths against it so the script works from any working directory.
script_dir = path.dirname(__file__)

# Accepted RLang statement types: keys are the valid command-line arguments
# checked in main, values are the capitalized type names.
RLANG_TYPE_MAP = dict(
    policy='Policy',
    option='Option',
)

# Labels prefixed to the two halves of every prompt example.
RLANG_PROMPT = 'RLang:'
NL_PROMPT = 'English:'

# How many example pairs are randomly sampled into each output file.
NUM_PROMPTS = 100

def write_prompt(rlang_sentence, nl_sentence, out_file):
    """Write one RLang/English example pair to ``out_file``.

    The pair is written as two labelled lines followed by a blank line::

        RLang: <rlang_sentence>
        English: <nl_sentence>

    This matches the layout produced by ``generate_prompts``.

    Args:
        rlang_sentence: The RLang statement text.
        nl_sentence: The English text paired with the statement.
        out_file: A writable text file-like object.
    """
    # Drop trailing newlines (e.g. from readlines()) so they don't introduce
    # blank lines inside the pair; the layout's own newlines are added below.
    rlang_sentence = rlang_sentence.rstrip('\n')
    nl_sentence = nl_sentence.rstrip('\n')
    out_file.write(f'{RLANG_PROMPT} {rlang_sentence}\n{NL_PROMPT} {nl_sentence}\n\n')

def generate_prompts(rlang_type, num_prompts=None):
    """Build a GPT-3 prompt file for one RLang statement type.

    Reads the parallel files ``../data/rlang/<type>_rlang.txt`` and
    ``../data/nl/<type>_nl.txt`` (resolved relative to ``script_dir``),
    pairs them line by line, and writes a random sample of the pairs to
    ``../data/gpt3/<type>_prompts.txt``. Each pair is written as::

        RLang: <rlang line>
        English: <nl line>
        <blank line>

    Args:
        rlang_type: Statement type used to build the file names, e.g. ``'policy'``.
        num_prompts: Number of pairs to sample; defaults to ``NUM_PROMPTS``.
            If fewer pairs exist, all of them are written in random order.

    Raises:
        ValueError: If the two input files have different numbers of lines.
        OSError: If an input file cannot be read or the output cannot be written.
    """
    if num_prompts is None:
        num_prompts = NUM_PROMPTS

    input_rlang_file = f'../data/rlang/{rlang_type}_rlang.txt'
    input_nl_file = f'../data/nl/{rlang_type}_nl.txt'
    output_gpt3_file = f'../data/gpt3/{rlang_type}_prompts.txt'
    print(f'Generating {rlang_type.upper()} prompts')

    # Strip each line's own newline; the prompt layout adds newlines
    # explicitly, so a last line lacking '\n' can't merge two fields.
    with open(path.join(script_dir, input_rlang_file), 'r', encoding='utf-8') as rlang_f_input:
        rlang_inputs = [line.rstrip('\n') for line in rlang_f_input]
    with open(path.join(script_dir, input_nl_file), 'r', encoding='utf-8') as nl_f_input:
        nl_inputs = [line.rstrip('\n') for line in nl_f_input]

    # Lines are paired by position, so a length mismatch means the data is
    # misaligned; fail loudly instead of crashing mid-loop or dropping lines.
    if len(rlang_inputs) != len(nl_inputs):
        raise ValueError(
            f'{input_rlang_file} has {len(rlang_inputs)} lines but '
            f'{input_nl_file} has {len(nl_inputs)} lines'
        )

    prompts = [
        f'{RLANG_PROMPT} {rlang_line}\n{NL_PROMPT} {nl_line}\n\n'
        for rlang_line, nl_line in zip(rlang_inputs, nl_inputs)
    ]

    # random.sample raises if asked for more items than exist; cap the size.
    sample_size = min(num_prompts, len(prompts))
    with open(path.join(script_dir, output_gpt3_file), 'w', encoding='utf-8') as f_output:
        f_output.writelines(sample(prompts, sample_size))



def main(argv):
    """Command-line entry point: generate prompts for one RLang statement type.

    Args:
        argv: Argument vector in the form of ``sys.argv``; ``argv[1]`` must be
            one of the keys of ``RLANG_TYPE_MAP``.

    Returns:
        Process exit status: 0 on success, 1 on a usage error.
    """
    # Use the dict's insertion order rather than a set so the usage text is
    # identical on every run (set order of str varies with hash randomization).
    valid_rlang = list(RLANG_TYPE_MAP)
    prog = path.basename(argv[0]) if argv and argv[0] else 'generate_prompts.py'

    if len(argv) != 2:
        print('Invalid number of arguments', file=sys.stderr)
        print(f'Expected input: `python3 {prog} <{"|".join(valid_rlang)}>`', file=sys.stderr)
        print(f'i.e. `python3 {prog} policy`', file=sys.stderr)
        return 1
    if argv[1] not in RLANG_TYPE_MAP:
        print(
            f'Invalid argument "{argv[1]}". {prog} only generates the following '
            f'RLang statements: {", ".join(valid_rlang)}',
            file=sys.stderr,
        )
        return 1

    generate_prompts(argv[1])
    return 0


if __name__ == '__main__':
    # Fixed seed so repeated runs sample the same prompts.
    seed(0)
    sys.exit(main(sys.argv))