import openai
import os
import nltk
import itertools
import json
import random
import sys
import time
import pandas as pd

# Directory containing this file; keys.json and the data files read by the
# helpers below are resolved against it. NOTE(review): `__file__` is usually
# absolute, but on Python < 3.9 a script launched via a relative path gets a
# relative `__file__`, so this may be relative too.
script_dir = os.path.dirname(__file__)

# API keys are loaded at import time. keys['codex'] becomes the default OpenAI
# key; translate() and translate_named() temporarily swap in keys['general'].
with open(os.path.join(script_dir, 'keys.json'), 'r') as f:
    keys = json.load(f)

openai.api_key = keys['codex']
# Lowercase -> capitalized RLang type names ('policy' -> 'Policy', ...).
# NOTE(review): not referenced elsewhere in this file view — confirm it is used.
RLANG_TYPE_MAP = {
    'policy': 'Policy',
    'effect': 'Effect', 
}

def save_translation_results(translation_results, results_csv_path=None):
    """Append English -> RLang translation pairs to a CSV file.

    Args:
        translation_results: rows that ``pd.DataFrame`` can turn into the two
            columns ``english`` and ``rlang_translation`` (e.g. a list of
            ``(english, rlang)`` tuples).
        results_csv_path: CSV file to append to. Defaults to
            ``translations.csv`` in this script's directory.

    Existing rows are kept and the new rows are appended, then the whole file
    is rewritten. The DataFrame index is not written, so repeated calls do
    not accumulate ``Unnamed: N`` index columns.
    """
    if results_csv_path is None:
        results_csv_path = os.path.join(script_dir, 'translations.csv')
    columns = ['english', 'rlang_translation']

    results_df = pd.DataFrame(translation_results, columns=columns)
    if os.path.exists(results_csv_path):
        existing_df = pd.read_csv(results_csv_path, header=0)
        # Files written by earlier versions stored the index as an unnamed
        # leading column; drop such columns so they do not pile up.
        existing_df = existing_df.loc[:, ~existing_df.columns.str.startswith('Unnamed')]
        results_df = pd.concat([existing_df, results_df], ignore_index=True)

    results_df.to_csv(results_csv_path, index=False)

def parse_vocab(path_to_rlang_file):
    """Collect the names declared in an RLang file as a comma-separated string.

    A line contributes a name when its first whitespace-separated token is one
    of the declaration keywords below and a second token exists; that second
    token is taken as the name. Names are grouped by keyword, with groups in
    order of each keyword's first appearance and names in file order within
    each group. Each name and the final grouping dict are printed (debugging).

    Args:
        path_to_rlang_file: path resolved relative to this script's directory.

    Returns:
        The collected names joined with ', ' (empty string if none).
    """
    # Built once rather than once per line.
    declaration_keywords = {'Factor', 'Action', 'Proposition', 'Feature', 'MarkovFeature'}
    vocab = {}  # keyword -> names, insertion-ordered by first appearance

    with open(os.path.join(script_dir, path_to_rlang_file), 'r') as f:
        for line in f:
            tokens = line.split()
            if len(tokens) > 1 and tokens[0] in declaration_keywords:
                print(tokens[1])
                vocab.setdefault(tokens[0], []).append(tokens[1])

    print(vocab)
    # Single join instead of repeated list concatenation (quadratic).
    return ', '.join(name for names in vocab.values() for name in names)

def read_relative_file(filename):
    """Return the lines of *filename*, resolved against this script's directory.

    Lines keep their trailing newline characters.
    """
    full_path = os.path.join(script_dir, filename)
    with open(full_path, 'r') as handle:
        contents = list(handle)
    return contents

def read_input(input_file_name):
    """Return the lines of the file at *input_file_name*, with the path used as given.

    The path is not joined with the script directory, so a relative path
    resolves against the current working directory. Lines keep their
    trailing newline characters.
    """
    with open(input_file_name, 'r') as handle:
        contents = list(handle)
    return contents

def read_paired_examples(english_filename, rlang_filename, template='{}\n{}'):
    """Zip two line-aligned files into formatted example strings.

    Both files are resolved relative to this script's directory. Line i of
    each file, with surrounding newlines stripped, fills the two placeholders
    of *template* to form example i.

    Returns:
        List of formatted strings; with the default template each is
        ``'english\\nrlang'``.

    Raises:
        ValueError: if the files have different line counts. (An ``assert``
            would be stripped under ``python -O``.)
    """
    english_lines = read_relative_file(english_filename)
    rlang_lines = read_relative_file(rlang_filename)
    if len(english_lines) != len(rlang_lines):
        raise ValueError(
            '{} has {} lines but {} has {}'.format(
                english_filename, len(english_lines), rlang_filename, len(rlang_lines)))

    return [
        template.format(english.strip('\n'), rlang.strip('\n'))
        for english, rlang in zip(english_lines, rlang_lines)
    ]
        
def read_named_examples(english_filename, names_filename, rlang_filename, template='{}${}${}'):
    """Load three line-aligned example files into a DataFrame.

    All files are resolved relative to this script's directory. Line i of
    each file forms example i.

    Args:
        english_filename: file with one English description per line.
        names_filename: file with the names (vocabulary) line for each example.
        rlang_filename: file with one RLang translation per line.
        template: unused, kept for backward compatibility. It previously fed a
            list of formatted strings that was built but never returned.

    Returns:
        DataFrame with columns ``english``, ``vocab`` and ``rlang``. The values
        are the raw lines, including any trailing newline.

    Raises:
        ValueError: if the files have different line counts.
    """
    english_lines = read_relative_file(english_filename)
    names_lines = read_relative_file(names_filename)
    rlang_lines = read_relative_file(rlang_filename)
    if not len(english_lines) == len(names_lines) == len(rlang_lines):
        raise ValueError(
            'line counts differ: {}={}, {}={}, {}={}'.format(
                english_filename, len(english_lines),
                names_filename, len(names_lines),
                rlang_filename, len(rlang_lines)))

    return pd.DataFrame({'english': english_lines, 'vocab': names_lines, 'rlang': rlang_lines})

def read_example_pairs(filename):
    """Placeholder for reading grouped examples from *filename* (not implemented).

    The file is read relative to this script's directory, so a missing file
    still raises, but nothing is parsed and the function returns None.
    """
    lines = read_relative_file(filename)
    # TODO: unimplemented. The original sketch looped over `range(n / 3)`,
    # which raises TypeError on Python 3 (float argument). Integer division
    # keeps the stub runnable. Presumably each example spans 3 lines — confirm.
    for _ in range(len(lines) // 3):
        pass

def permute_examples(examples):
    """Return every ordering of *examples* as a list of tuples (n! entries)."""
    return [*itertools.permutations(examples)]

def unroll(line):
    """
    Expand a single-line RLang encoding into multi-line, tab-indented code.

    Tokens are separated by single spaces. The marker '>>' opens a new line
    one indent deeper, '<<' opens a new line one indent shallower, and '<>'
    opens a new line at the current depth. Every other token is emitted
    followed by a space.

    Each marker replaces the last emitted character (normally that trailing
    space) with the newline and indentation. As a result, consecutive
    markers can leave a blank line, and the output ends with a space when the
    last token is not a marker.
    """
    depth_change = {'>>': 1, '<<': -1, '<>': 0}
    out = ''
    depth = 0
    for token in line.split(' '):
        if token in depth_change:
            depth += depth_change[token]
            out = out[:-1] + '\n' + '\t' * depth
        else:
            out += token + ' '
    return out

# TODO: consider removing the explicit "English" and "RLang" tags from the
# example files to allow more latitude, and doing all template formatting here.
def generate_prompt(context, examples, query, template):
    """Build a few-shot prompt: context, worked examples, then the open query.

    Args:
        context: leading instruction text.
        examples: strings whose '\\n'-separated parts fill *template*'s
            placeholders (e.g. ``'english\\nrlang'``).
        query: text for the first placeholder of the final, unanswered entry.
        template: format string whose last 3 characters are the answer slot
            ``' {}'``, e.g. ``'English: {}\\nRLang: {}'``.

    Returns:
        The sections joined by blank lines. The final entry ends right after
        the answer label (e.g. ``'RLang:'``) so the model completes it.
    """
    parts = [context]
    parts.extend(template.format(*example.split('\n')) for example in examples)
    # Only the final entry is formatted with the query. Formatting the whole
    # accumulated prompt would break on literal braces in the context or in
    # example text (e.g. RLang code).
    parts.append(template[:-3].format(query))
    return '\n\n'.join(parts)

#is this necessary? it's not clear
def generate_prompt_named(context, top_examples_df, query, template):
    n_examples = top_examples_df.size
    prompt = context

    for i, row in top_examples_df.iterrows():
        prompt = prompt + '\n\n' + template.format(row['english'].strip('\n'), row['vocab'].strip('\n'), row['rlang'].strip('\n'))
    prompt = prompt + '\n\n' + template[:-3]
    prompt = prompt.format(*query)
    print(f'--------------------- PROMPT ---------------------\n{prompt}\n------------------------------------------')

    return prompt

def english_to_rlang(prompt, kwargs):
    """Send one OpenAI Completion request and return the first choice's text.

    Args:
        prompt: full prompt text.
        kwargs: extra keyword arguments forwarded to
            ``openai.Completion.create`` (engine, max_tokens, stop, ...).
    """
    print('---------------------  english_to_rlang ---------------------')
    completion = openai.Completion.create(prompt=prompt, **kwargs)
    first_choice = completion['choices'][0]
    return first_choice['text']

def compute_bleu(reference, hypothesis):
    """Return the sentence-level BLEU score of *hypothesis* against *reference*.

    Both strings are split on whitespace. Unigram, bigram and trigram
    precisions are weighted equally (1/3 each), i.e. BLEU-3 rather than
    NLTK's default 4-gram weighting.
    """
    equal_trigram_weights = (1 / 3.,) * 3
    return nltk.translate.bleu_score.sentence_bleu(
        [reference.split()], hypothesis.split(), weights=equal_trigram_weights
    )

def compute_similarity(query, target, prompt=None, engine='text-davinci-001'):
    """Score *target* as a follow-up to *query* using model log-probabilities.

    The text ``prompt + '\\n\\n' + query + '\\n' + target`` is sent with
    ``echo=True`` and ``max_tokens=0``, so nothing is generated and only the
    prompt's per-token log-probabilities come back. Surrounding newlines are
    stripped from *query* and *target*. If *prompt* is falsy, a default
    instruction sentence is used.

    Returns:
        Sum of the log-probabilities of the tokens after the last standalone
        '\\n' token, which is the target line provided the tokenizer emits that
        newline as its own token (assumption). Higher means more likely.

    Raises:
        ValueError: if no '\\n' token appears in the echoed tokens.
    """
    preamble = prompt or "The English teacher reads two structurally similar sentences to illustrate natural language's compositional nature."
    full_prompt = preamble + '\n\n' + query.strip('\n') + '\n' + target.strip('\n')
    response = openai.Completion.create(engine=engine, prompt=full_prompt, logprobs = 0, max_tokens=0, echo=True)
    logprob_data = response['choices'][0]['logprobs']
    # Walk from the end so the first '\n' found is the last one in the prompt.
    tokens_from_end = logprob_data['tokens'][::-1]
    logprobs_from_end = logprob_data['token_logprobs'][::-1]
    cut = tokens_from_end.index('\n')
    return sum(logprobs_from_end[:cut])

def select_examples(target_english, examples_df, k=3):
    """Return the *k* examples scored most similar to *target_english*.

    Similarity is ``compute_similarity(example_english, target_english)``,
    sorted descending.

    Args:
        target_english: the English query.
        examples_df: either a DataFrame with an ``english`` column, or a list
            of ``'english\\n...'`` strings (the format translate() passes).
        k: number of examples to keep.

    Returns:
        *examples_df* unchanged if it has at most *k* examples. Otherwise, for
        a DataFrame, a new DataFrame of the top *k* rows with an added
        ``similarities`` column; for a list, a new list of the top *k*
        strings. The caller's object is never mutated.
    """
    print(f'---------------------  selecting {k} examples... ---------------------')
    # len() counts rows; DataFrame.size would be rows * columns.
    if k >= len(examples_df):
        return examples_df

    if isinstance(examples_df, pd.DataFrame):
        sims = [compute_similarity(english, target_english) for english in examples_df['english']]
        print(f'TARGET ENGLISH SENTENCE: {target_english}\n')
        # assign() returns a copy, so the caller's DataFrame is left untouched.
        scored_df = examples_df.assign(similarities=sims)
        return scored_df.sort_values(by=['similarities'], ascending=False).head(k)

    # List of 'english\nrlang' strings: score on the English line only.
    sims = [compute_similarity(example.split('\n')[0], target_english) for example in examples_df]
    print(f'TARGET ENGLISH SENTENCE: {target_english}\n')
    ranked = sorted(range(len(examples_df)), key=lambda idx: sims[idx], reverse=True)
    return [examples_df[idx] for idx in ranked[:k]]

def select_names(target, vocabulary):
    """Ask the model which vocabulary names are relevant to *target*.

    Args:
        target: the English query.
        vocabulary: names as a ', '-separated string, or an iterable of
            names. An iterable is joined with ', ' to match the format of the
            few-shot example.

    Returns:
        List of selected names with surrounding whitespace removed, or ``[]``
        when the model answers "None" or nothing.
    """
    context = "The very intelligent programmer looks at a vocabulary and a query, and selects all of the variable names from the list that might be relevant in structuring the query into a formal program, or says 'None' if there aren't any relevant names.\n\nQuery: If your head is underwater, swim up. Otherwise, breathe and then dive\nVocabulary: breathe, run, iron, swim_up, dive, water_level, head, arms\nSelection: breathe, swim_up, dive, water_level, head"
    template = '{}\n\nQuery: {}\nVocabulary: {}\nSelection:'
    if not isinstance(vocabulary, str):
        # Otherwise a list would render as "['a', 'b']".
        vocabulary = ', '.join(vocabulary)
    prompt = template.format(context, target, vocabulary)
    response = english_to_rlang(prompt, {'engine': "code-davinci-001", 'max_tokens': 200, 'temperature':0, 'frequency_penalty':0.0, 'presence_penalty':0.0, 'stop':['\n\n']})
    # The completion continues after "Selection:" and usually starts with a
    # space, so strip before comparing or splitting.
    response = response.strip()
    if response == "None":
        return []
    return [name.strip() for name in response.split(',') if name.strip()]

def select_names_batched(target, vocabulary, batchsize=10):
    """
    Call select_names on the vocabulary in batches of *batchsize* names.

    Args:
        target: the English query.
        vocabulary: a list of names, or a comma-separated string of names
            (as returned by parse_vocab). A string is split into names first;
            otherwise it would be batched character by character.
        batchsize: maximum number of names per select_names call.

    Returns:
        One list of selected names per batch (``[]`` for an empty vocabulary).
    """
    if isinstance(vocabulary, str):
        vocabulary = [name.strip() for name in vocabulary.split(',') if name.strip()]
    batches = [vocabulary[i:i + batchsize] for i in range(0, len(vocabulary), batchsize)]
    return [select_names(target, batch) for batch in batches]

def translate(query, examples, context = None, kwargs = None, template = None, k=3):
    """Translate an English *query* into RLang with a few-shot Codex prompt.

    Args:
        query: English task description.
        examples: ``'english\\nrlang'`` strings (e.g. from
            read_paired_examples). If there are more than *k*, the *k* most
            similar are chosen via select_examples.
        context, kwargs, template: prompt preamble, Completion arguments and
            example template. Falsy values fall back to the defaults below.
        k: number of few-shot examples.

    Returns:
        ``(prompt, completion_text)``.
    """
    if not context:
        context = "The expert programmer translates a user's task description from English into RLang, a Python module with a formal syntax."
    if not kwargs:
        kwargs = {'engine': "code-davinci-001", 'max_tokens': 200, 'temperature':0, 'frequency_penalty':0.0, 'presence_penalty':0.0}
    if not template:
        template = 'English: {}\nRLang: {}'
    # Example selection runs under keys['general']. The completion uses the
    # key that was active on entry, which is restored even if selection fails.
    saved_key = openai.api_key
    openai.api_key = keys['general']
    try:
        tops = select_examples(query, examples, k=k) if len(examples) > k else examples
        prompt = generate_prompt(context, tops, query, template)
    finally:
        openai.api_key = saved_key
    return prompt, english_to_rlang(prompt, kwargs)

def translate_named(query, examples_df, vocabulary, context = None, kwargs = None, template = None, k=3):
    """Translate *query* into RLang with a prompt that also lists available names.

    Args:
        query: English task description.
        examples_df: DataFrame with ``english``, ``vocab`` and ``rlang``
            columns; the *k* most similar rows become few-shot examples.
        vocabulary: names text placed in the query's "Names" slot.
        context, kwargs, template: prompt preamble, Completion arguments and
            example template. Falsy values fall back to the defaults below.
        k: number of few-shot examples.

    Returns:
        ``(prompt, completion_text)``. If the first completion call raises,
        it waits 20 seconds and retries once; a second failure propagates.
    """
    if not context:
        context = "The expert programmer translates a user's task description from English into RLang, a Python module with a formal syntax."
    if not kwargs:
        kwargs = {'engine': "code-davinci-001", 'max_tokens': 200, 'temperature':0, 'frequency_penalty':0.0, 'presence_penalty':0.0, 'stop':['\n\n']}
    if not template:
        template = 'English: {}\nNames: {}\nRLang: {}'
    # Example selection runs under keys['general']; always restore the entry key.
    saved_key = openai.api_key
    openai.api_key = keys['general']
    try:
        # All vocabulary names are passed through unfiltered (model-based
        # pruning via select_names_batched is currently disabled).
        names = vocabulary
        top_examples_df = select_examples(query, examples_df, k=k)
        prompt = generate_prompt_named(context, top_examples_df, [query, names], template)
    finally:
        openai.api_key = saved_key

    try:
        ret = english_to_rlang(prompt, kwargs)
    except Exception:
        # Presumably a transient API / rate-limit error: back off once, retry.
        # (A bare `except:` would also catch KeyboardInterrupt.)
        print('failed on query: {}'.format(query))
        time.sleep(20)
        ret = english_to_rlang(prompt, kwargs)
    return prompt, ret

def eval(examples, context=None, template=None, kwargs=None, k=3, out_file='rl.txt'):
    """Leave-one-out evaluation of translate() over *examples*.

    For each ``'english\\nrlang'`` string in the *examples* list, its first line
    is the query. All other examples, shuffled with ``random.shuffle``, form
    the few-shot pool.

    Outputs (paths relative to the current working directory; parent
    directories are created if missing):
        ``out/<out_file>``: each translation, stripped, one per line.
        ``prompts/<out_file>``: each prompt, stripped, followed by a
        ``----------`` separator line.

    NOTE: this module-level name shadows the built-in ``eval`` inside this
    module; it is kept for existing callers.
    """
    rls = []
    prompts = []
    for i, example in enumerate(examples):
        query = example.split('\n')[0]
        # Hold the current example out of its own few-shot pool.
        train_set = examples.copy()
        del train_set[i]
        random.shuffle(train_set)
        prompt, rl = translate(query, train_set, context=context, template=template, kwargs=kwargs, k=k)
        prompts.append(prompt)
        rls.append(rl)

    out_path = 'out/' + out_file
    prompts_path = 'prompts/' + out_file
    for path in (out_path, prompts_path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(out_path, 'w') as f:
        f.writelines(rl.strip() + '\n' for rl in rls)
    with open(prompts_path, 'w') as f:
        f.writelines(p.strip() + '\n----------\n' for p in prompts)

def eval_named(examples, context=None, template=None, kwargs=None, k=3, out_file='rl.txt'):
    """Leave-one-out evaluation of translate_named() over *examples*.

    Args:
        examples: either the DataFrame returned by read_named_examples
            (columns ``english``, ``vocab``, ``rlang``) or a list of
            ``'english$vocab$rlang'`` strings. Both are normalized to a
            DataFrame, because translate_named / select_examples operate on
            DataFrames.
        context, template, kwargs, k: forwarded to translate_named.
        out_file: file name used under ``out/`` and ``prompts/``.

    For each example, the English text (surrounding newlines stripped) is the
    query, and its vocab text is the names. All other examples, in a random
    order (``random.shuffle``), form the few-shot pool.

    Outputs (paths relative to the current working directory; parent
    directories are created if missing):
        ``out/<out_file>``: each translation, stripped, one per line.
        ``prompts/<out_file>``: each prompt, stripped, followed by a
        ``----------`` separator line.
    """
    columns = ['english', 'vocab', 'rlang']
    if isinstance(examples, pd.DataFrame):
        examples_df = examples[columns].reset_index(drop=True)
    else:
        examples_df = pd.DataFrame([example.split('$', 2) for example in examples], columns=columns)

    rls = []
    prompts = []
    for i in examples_df.index:
        print('Example', i)
        row = examples_df.loc[i]
        # Hold the current example out of its own few-shot pool, then shuffle.
        train_index = [j for j in examples_df.index if j != i]
        random.shuffle(train_index)
        train_df = examples_df.loc[train_index].copy()
        prompt, rl = translate_named(row['english'].strip('\n'), train_df, row['vocab'].strip('\n'),
                                     context=context, template=template, kwargs=kwargs, k=k)
        prompts.append(prompt)
        rls.append(rl)

    out_path = 'out/' + out_file
    prompts_path = 'prompts/' + out_file
    for path in (out_path, prompts_path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(out_path, 'w') as f:
        f.writelines(rl.strip() + '\n' for rl in rls)
    with open(prompts_path, 'w') as f:
        f.writelines(p.strip() + '\n----------\n' for p in prompts)

    