import sys
import pandas as pd
from os import path
from nltk import Tree
from random import seed, randrange
from transform_rlang import transform

# Directory containing this script; data files are resolved relative to it.
script_dir = path.dirname(__file__)
# Target total number of statements; `generate_nl` divides it evenly across
# the tokenized templates.
NUM_STATEMENTS = 50000
# Supported RLang statement types. The keys are the accepted command-line
# arguments (validated in `main`).
RLANG_TYPE_MAP = {
    'policy': 'Policy',
    'option': 'Option', 
}

# Lunar lander data
# Maps each placeholder token that may appear in a tokenized template to the
# names it can be randomly replaced with (see `randomize_line`).
VARIABLE_NAME_TO_REPLACEMENTS = {
    'policy_name': ['adjust_lander', 'position_lander', 'avoid_mountain', 'do_nothing'],
    'variable_name_1': ['lander_position', 'mountain', 'lander_coordinates'],
    'variable_name_2': ['target_position', 'target_coordinates'],
    'execute_variable_name': ['fire_right_thruster', 'point_nose_left', 'point_nose_right', 'main_engine', 'fire_thrusters', 'fire_left_thruster'],
    'conditional_variable': ['above_landing_target', 'left_leg_in_contact', 'right_leg_in_contact'],
    'numeric_variable': ['horizontal_speed', 'vertical_speed', 'fuel_level', 'altitude', 'velocity_y', 'velocity_x', 'remaining_angle', 'remaining_hover'],
    # Empty on purpose: `randomize_variable` replaces 'number' with a random
    # integer instead of choosing from a list.
    'number': []
}

# POSSIBLE_POLICY_NAME = ['adjust lander', 'position lander', 'avoid mountain', 'do nothing']
# POSSIBLE_VARIABLE_NAME_1 = ['lander position', 'mountain', 'lander coordinates']
# POSSIBLE_VARIABLE_NAME_2 = ['target position', 'target coordinates']
# POSSIBLE_CONDITIONAL_VARIABLE_NAME = ['above landing target', 'left leg_in_contact', 'right leg in contact']
# POSSIBLE_NUMERIC_VARIABLE_NAME = ['horizontal speed', 'vertical speed', 'fuel level', 'altitude', 'velocity y', 'velocity x', 'remaining angle', 'remaining_hover']
# POSSIBLE_EXECUTE_VARIABLE_NAME = ['fire right thruster', 'point nose left', 'point nose right', 'main engine', 'fire thrusters', 'fire left thruster']

"""
The purpose of this script is to generate an RLang/NL dataset.

- INPUT: Tokenized RLang file.

- OUTPUT: CSV of randomly generated RLang/NL statements that are based on the tokenized RLang input.
          The NL values used to generate the sentences (e.g. policy names and variable names) are
          pulled from the replacement lists in `VARIABLE_NAME_TO_REPLACEMENTS` above.

- HOW TO USE:
    1. Add the possible variable names to the lists in `VARIABLE_NAME_TO_REPLACEMENTS` above
       - ensure that its keys cover every placeholder you want exchanged in the tokenized templates
       - ensure that these values are replaced in the `randomize_line` function (it iterates over every key of that dict)
    2. Make sure that your desired RLANG_TYPE is added to the RLANG_TYPE_MAP
    3. Run `python3 generate_dataset.py <RLANG_TYPE>`
"""

# 
def randomize_line(tokenized_line):
    """Fill every placeholder in a tokenized template with a random value.

    Each key of VARIABLE_NAME_TO_REPLACEMENTS is substituted in turn (in the
    dict's insertion order) by `randomize_variable`.

    Returns:
        A `(line, vocab)` tuple: the template with its placeholders replaced,
        and the list of replacement names that were inserted, in the order
        they were substituted.
    """
    current = tokenized_line
    used_names = []

    # replacements based on CFG
    for placeholder, candidates in VARIABLE_NAME_TO_REPLACEMENTS.items():
        current, picked = randomize_variable(current, placeholder, candidates)
        used_names.extend(picked)

    return current, used_names

def randomize_variable(original_string, var_to_replace, possible_replacements):
    """Replace each occurrence of one placeholder with an independent random pick.

    Args:
        original_string: Text containing zero or more occurrences of the placeholder.
        var_to_replace: Placeholder text to substitute. The special value
            'number' is replaced by a random integer in [-1000, 1000).
        possible_replacements: Candidate names for the placeholder (not used
            for 'number').

    Returns:
        A `(new_string, replacements_used)` tuple. `replacements_used` lists
        the chosen names in substitution order; random numbers are not recorded.
    """
    picked = []
    result = original_string
    is_number = var_to_replace == 'number'
    # The number of substitutions is fixed up front from the untouched input.
    remaining = original_string.count(var_to_replace)

    while remaining > 0:
        remaining -= 1
        if is_number:
            result = result.replace('number', str(randrange(-1000, 1000)), 1)
            continue
        name = possible_replacements[randrange(0, len(possible_replacements))]
        picked.append(name)
        # Only the first remaining occurrence is swapped on each pass.
        result = result.replace(var_to_replace, name, 1)

    return result, picked

def transform_line(tokenized_line):
    """Turn a bracketed tree string into the text of its transformed tree.

    The line is parsed with `Tree.fromstring`, passed through `transform`,
    and the leaves of the resulting tree are joined with single spaces.
    `generate_nl` uses the result as the English (NL) statement.
    """
    nl_tree = transform(Tree.fromstring(tokenized_line))
    return ' '.join(nl_tree.leaves())

def generate_nl(rlang_type):
    """Generate the RLang/NL dataset for one RLang statement type and write it to CSV.

    Reads the tokenized templates from
    ``../data/tokenized/tokenized_<rlang_type>_output.txt``, samples unique
    randomized statements from every template, and writes the columns
    ``english``, ``vocab`` and ``rlang`` to
    ``../data/lunar_lander_nl_rlang_<rlang_type>.csv``. Both paths are
    resolved relative to this script's directory.

    Args:
        rlang_type: Name of the statement type; used to build the file names.

    Raises:
        ValueError: If the tokenized file contains no templates.
    """
    tokenized_path = path.join(script_dir, f'../data/tokenized/tokenized_{rlang_type}_output.txt')
    # Resolved against script_dir like the input, so the result does not
    # depend on the directory the script is launched from.
    output_csv = path.join(script_dir, f'../data/lunar_lander_nl_rlang_{rlang_type}.csv')
    # Upper bound on draws per requested sample. Without it, a template that
    # cannot produce `sample_size` distinct sentences would loop forever.
    max_attempts_factor = 100

    print(f'Generating {rlang_type.upper()} NL statements')
    print('This may take a while...\n...')

    with open(tokenized_path, 'r') as tokenized_file:
        # Blank lines are not templates, so they are skipped rather than
        # handed to the tree parser.
        tokenized_lines = [line for line in tokenized_file if line.strip()]

    if not tokenized_lines:
        raise ValueError(f'No tokenized templates found in {tokenized_path}')

    all_nl_statements = []
    all_rlang_statements = []
    all_vocab_lists = []

    sample_size = NUM_STATEMENTS // len(tokenized_lines)
    max_attempts = sample_size * max_attempts_factor
    print(f'sampling {sample_size} possible statements for each tokenized template...')
    for i, template in enumerate(tokenized_lines):
        sampled_nl_statements = []
        sampled_rlang_statements = []
        sampled_vocab = []
        # Set of NL sentences already kept for this template (O(1) duplicate check).
        seen_nl = set()
        attempts = 0
        while len(sampled_nl_statements) < sample_size and attempts < max_attempts:
            attempts += 1
            # generate sample_size # of statements per tokenized line
            randomized_line, vocab = randomize_line(template)
            transformed_line = transform_line(randomized_line)
            if transformed_line in seen_nl:
                continue
            seen_nl.add(transformed_line)
            # The RLang side is taken from a freshly parsed tree.
            # NOTE(review): presumably because `transform` may modify the
            # tree it is given - confirm.
            sampled_rlang_statements.append(' '.join(Tree.fromstring(randomized_line).leaves()))
            sampled_nl_statements.append(transformed_line)
            sampled_vocab.append(vocab)

        if len(sampled_nl_statements) < sample_size:
            print(f'warning: template {i} produced only {len(sampled_nl_statements)} unique statements '
                  f'(requested {sample_size}) after {attempts} attempts')

        all_nl_statements.extend(sampled_nl_statements)
        all_rlang_statements.extend(sampled_rlang_statements)
        all_vocab_lists.extend(sampled_vocab)

        if i % 50 == 0:
            print(f'finished generating sentences for {i} / {len(tokenized_lines)} templates')

    df = pd.DataFrame({"english": all_nl_statements, "vocab": all_vocab_lists, "rlang": all_rlang_statements})
    df.to_csv(output_csv, index=False)

    print(f'Done! Dataset written to {output_csv}')

def main(argv):
    """Validate the command-line arguments and run the dataset generation.

    Prints a usage or error message and returns without generating anything
    when the arguments are invalid.

    Args:
        argv: Full argument vector in `sys.argv` form: the script name
            followed by exactly one RLang type that is a key of RLANG_TYPE_MAP.
    """
    # Sorted so the messages are stable; a set of strings has no fixed
    # iteration order between runs.
    valid_rlang = sorted(RLANG_TYPE_MAP)
    usage_choices = '|'.join(valid_rlang)
    listing = ', '.join(valid_rlang)

    if len(argv) != 2:
        print('Invalid number of arguments')
        print(f'Expected input: `python3 generate_dataset.py <{usage_choices}>`')
        print('e.g. `python3 generate_dataset.py policy`')
        return

    requested = argv[1]
    if requested not in RLANG_TYPE_MAP:
        print(f'Invalid argument "{requested}". generate_dataset.py only generates the following RLang statements: {listing}')
        print(f'{requested} is not a valid RLang statement')
        return

    generate_nl(requested)

if __name__ == '__main__':
    # Seed the `random` module with a fixed value before anything is drawn,
    # so every run produces the same sequence of random replacements.
    seed(0)
    main(sys.argv)