import os

from nltk import Tree
from random import randint, seed

# Directory containing this script; main() resolves the data file paths against it.
script_dir = os.path.dirname(__file__)

# Node labels whose first child is rendered as a plain name by
# transform_VariableName.
POSSIBLE_VARIABLE_NAME_TOKENS = {
    'VariableName',
    'VariableName2',
    'PolicyName',
    'ExecuteVariableName',
    'ConditionalVariableName',
    'NumericVariableName',
}

def get_random_statement(possible_statements):
    """Return one element of ``possible_statements`` picked uniformly at random.

    A single ``randint`` draw selects the index, so the result is reproducible
    once the ``random`` module has been seeded.
    """
    last_index = len(possible_statements) - 1
    chosen_index = randint(0, last_index)
    return possible_statements[chosen_index]

def transform(t):
    """Dispatch a root node to the transformer that matches its label.

    ``Policy`` nodes go to ``transform_Policy`` and ``Option`` nodes go to
    ``transform_Option``. Any other label yields ``None``.
    """
    root_kind = t.label()
    if root_kind == 'Policy':
        return transform_Policy(t)
    if root_kind == 'Option':
        return transform_Option(t)
    return None

def transform_Option(t):
    """Build the natural-language tree for an ``Option`` node.

    The result starts with a randomly chosen conditional lead-in phrase and is
    followed by the transformed ``Init``, ``Until``, ``Policy``,
    ``ConditionalExecute`` and ``BoolExp`` children, in their original order.
    The first child of ``t`` is skipped, and children carrying any other label
    are dropped.
    """
    lead_in_phrases = [
        'given that',
        'in the case when',
        'in a scenario where',
        'if you see that',
    ]
    result = Tree('Option', [get_random_statement(lead_in_phrases)])

    for child in t[1:]:
        kind = child.label()
        if kind == 'Init':
            piece = transform_Init(child)
        elif kind == 'Until':
            piece = transform_Until(child)
        elif kind == 'Policy':
            piece = transform_Policy(child)
        elif kind == 'ConditionalExecute':
            piece = transform_ConditionalExecute(child)
        elif kind == 'BoolExp':
            piece = transform_BoolExp(child)
        else:
            # Unrecognised labels contribute nothing to the output.
            continue
        result.append(piece)

    return result

def transform_Policy(t):
    """Build the natural-language tree for a ``Policy`` node.

    The first child of ``t`` is skipped. Of the remaining children:

    * ``PolicyName`` is rendered, on a coin flip, as the name followed by a
      random connecting phrase; otherwise it is omitted entirely.
    * ``Colon`` becomes a random recommendation phrase.
    * ``ConditionalExecute`` and ``Execute`` are transformed recursively.

    Children with any other label are ignored (the original comment lists
    functors such as ``:``, ``\\n`` and ``\\t``).
    """
    name_phrases = [
        'states that you should',
        'suggests that it would be advantageous to',
        'suggests that it would not be advantageous to',
        'means that it would be good to',
        'is a strategy where you',
        'is a good strategy where you should',
        'is not a good strategy',
        'is an approach where you',
        'is an approach where it would be good to',
        'is a tactic where it is advantageous to',
    ]
    colon_phrases = [
        'it would be advantageous to',
        'it would be good to',
        'it is a good idea if you',
        'you should',
        'you can try',
        'you can',
    ]

    result = Tree('Policy', [])
    for child in t[1:]:
        kind = child.label()
        if kind == 'PolicyName':
            # Skip inserting the policy name half the time.
            if randint(0, 1) != 1:
                continue
            result.append(transform_VariableName(child))
            result.append(get_random_statement(name_phrases))
        elif kind == 'Colon':
            result.append(get_random_statement(colon_phrases))
        elif kind == 'ConditionalExecute':
            result.append(transform_ConditionalExecute(child))
        elif kind == 'Execute':
            result.append(transform_Execute(child))

    return result

def transform_Until(t):
    """Build the natural-language tree for an ``Until`` node.

    The output opens with a randomly chosen temporal connective and then
    contains every transformed ``BoolExp`` child of ``t``. The literal
    ``'until'`` keyword is skipped and other labels are ignored.
    """
    connectives = ['until', 'once', 'as soon as', 'when']
    result = Tree('Until', [get_random_statement(connectives)])
    for child in t:
        if child == 'until':
            # The keyword is a plain string leaf with no label to inspect.
            continue
        if child.label() == 'BoolExp':
            result.append(transform_BoolExp(child))
    return result

def transform_Init(t):
    """Build the natural-language tree for an ``Init`` node.

    Only ``BoolExp`` children are kept (after transformation). The literal
    ``'init'`` keyword is skipped and other labels are ignored.
    """
    result = Tree('Init', [])
    for child in t:
        if child == 'init':
            # The keyword is a plain string leaf with no label to inspect.
            continue
        if child.label() == 'BoolExp':
            result.append(transform_BoolExp(child))
    return result

def transform_ConditionalExecute(t):
    """Build the natural-language tree for a ``ConditionalExecute`` node.

    Each ``If``, ``Elif`` and ``Else`` child is transformed in order; children
    with any other label are ignored.
    """
    branch_handlers = {
        'If': transform_If,
        'Elif': transform_Elif,
        'Else': transform_Else,
    }
    result = Tree('ConditionalExecute', [])
    for child in t:
        handler = branch_handlers.get(child.label())
        if handler is not None:
            result.append(handler(child))
    return result

# NOTE: the current parse structure for if statements clashes with this function, because 'Execute' is included as a string literal.
def transform_If(t):
    """Build the natural-language tree for an ``If`` node.

    The output starts as ``['if']``. Variable-name and ``BoolExp`` children
    are appended after it, while each ``Execute`` child is inserted at the
    front so the action precedes the word 'if'. The first child of ``t`` is
    skipped and other labels are ignored.
    """
    result = Tree('If', ['if'])
    for child in t[1:]:
        kind = child.label()
        if kind == 'Execute':
            # Inserted at index 0 so the action comes before the condition.
            result.insert(0, transform_Execute(child))
        elif kind == 'BoolExp':
            result.append(transform_BoolExp(child))
        elif kind in POSSIBLE_VARIABLE_NAME_TOKENS:
            result.append(transform_VariableName(child))
    return result

def transform_Elif(t):
    """Build the natural-language tree for an ``Elif`` node.

    The output starts as ``['otherwise, if']`` and is followed by the
    transformed variable-name, ``BoolExp`` and ``Execute`` children in their
    original order. The first child of ``t`` is skipped and other labels are
    ignored.
    """
    result = Tree('Elif', ['otherwise, if'])
    for child in t[1:]:
        kind = child.label()
        if kind in POSSIBLE_VARIABLE_NAME_TOKENS:
            piece = transform_VariableName(child)
        elif kind == 'BoolExp':
            piece = transform_BoolExp(child)
        elif kind == 'Execute':
            piece = transform_Execute(child)
        else:
            continue
        result.append(piece)
    return result

def transform_Else(t):
    """Build the natural-language tree for an ``Else`` node.

    The output starts as ``[', or']``, then holds every transformed
    ``Execute`` child, and ends with a randomly chosen fallback phrase. The
    first child of ``t`` is skipped and other labels are ignored.
    """
    result = Tree('Else', [', or'])
    for child in t[1:]:
        if child.label() == 'Execute':
            result.append(transform_Execute(child))

    # The fallback phrase is drawn after the Execute children are transformed,
    # so it comes later in the random sequence than their draws.
    fallback_phrases = [
        'otherwise',
        'if not',
        'as a last resort',
        'if no other options are possible',
    ]
    result.append(get_random_statement(fallback_phrases))
    return result

def transform_Execute(t):
    """Build the natural-language tree for an ``Execute`` node.

    The output opens with a randomly chosen lead-in phrase (possibly the empty
    string) followed by the rendered name of every child whose label is in
    ``POSSIBLE_VARIABLE_NAME_TOKENS``. The first child of ``t`` is skipped.

    Args:
        t: the ``Execute`` subtree to transform.

    Returns:
        An ``Execute`` tree. When ``t`` has at most one child the tree is
        empty and no random draw is made.
    """
    out_tree = Tree('Execute', [])
    if len(t) <= 1:
        # Return an empty tree rather than None: callers append the result
        # unconditionally, and a None leaf would make
        # ' '.join(out_tree.leaves()) in main() raise TypeError.
        return out_tree

    possible_execute_statements = [
        'try to',
        'attempt to',
        'aim to',
        '',
    ]
    out_tree.append(get_random_statement(possible_execute_statements))
    for elt in t[1:]:
        if elt.label() in POSSIBLE_VARIABLE_NAME_TOKENS:
            out_tree.append(transform_VariableName(elt))
    return out_tree

# TODO: this should handle conjunctions too, but that was put off because of recursion errors.
def transform_BoolExp(t):
    """Build the natural-language tree for a ``BoolExp`` node.

    Variable-name and ``Number`` children are rendered through
    ``transform_VariableName``, ``BoolTest`` children through
    ``transform_BoolTest``. Children with any other label are ignored.
    """
    result = Tree('BoolExp', [])
    for child in t:
        kind = child.label()
        if kind == 'BoolTest':
            result.append(transform_BoolTest(child))
        elif kind in POSSIBLE_VARIABLE_NAME_TOKENS or kind == 'Number':
            # Numbers take the same rendering path as names.
            result.append(transform_VariableName(child))
    return result

def transform_BoolTest(t):
    """Translate a comparison operator node into an English phrase.

    The operator is read from the first child of ``t``. For ``==``, ``!=``,
    ``>``, ``<``, ``<=`` and ``>=`` the returned ``BoolTest`` tree holds one
    randomly chosen phrasing; for anything else the tree is empty.
    """
    result = Tree('BoolTest', [])
    # The drawn value is never used, but the call advances the shared random
    # state, so it is kept to leave the sequence of later draws unchanged.
    randint(0, 1)
    operator = t[0]

    phrasings_by_operator = (
        ('==', ['is equal to', 'is exactly the same as']),
        ('!=', ['is not equal to', 'is not the same as']),
        ('>', ['is greater than', 'is larger than']),
        ('<', ['is less than', 'is smaller than']),
        ('<=', [
            'is less than or equal to',
            'is smaller than or equal to',
            'is less than or the same as',
            'is smaller than or the same as',
        ]),
        ('>=', [
            'is greater than or equal to',
            'is larger than or equal to',
            'is greater than or the same as',
            'is larger than or the same as',
        ]),
    )
    for symbol, phrasings in phrasings_by_operator:
        if operator == symbol:
            result.append(get_random_statement(phrasings))
            break
    return result

def transform_VariableName(t):
    """Return the first child of ``t`` with underscores turned into spaces."""
    raw_name = t[0]
    return raw_name.replace('_', ' ')

def main():
    """Transform every tokenized policy line into a natural-language sentence.

    Seeds the random module with 0, reads
    ``../data/tokenized_policy_output.txt`` (relative to this script), parses
    each line as a tree, transforms it, and writes the space-joined leaves to
    ``../data/nl/nl_policy_template.txt``, one sentence per line. Progress is
    printed every 1000 statements.
    """
    seed(0)

    output_file = "nl_policy_template.txt"
    input_path = os.path.join(script_dir, "../data/tokenized_policy_output.txt")
    output_path = os.path.join(script_dir, f"../data/nl/{output_file}")

    with open(input_path, 'r') as f_input, open(output_path, 'w') as f_output:
        tokenized_lines = f_input.readlines()
        total_lines = len(tokenized_lines)
        print(f"Writing transform file for POLICY to ", output_file)
        print(f'Transforming {total_lines} total statements')
        print('This may take a while...\n...')

        for index, tokenized_rlang in enumerate(tokenized_lines):
            if index % 1000 == 0:
                print(f'Finished transforming {index}/{total_lines} statements')

            parsed = Tree.fromstring(tokenized_rlang)
            sentence_tree = transform(parsed)

            sentence = ' '.join(sentence_tree.leaves())
            f_output.write(sentence + '\n')
                
# Run the transformation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()