import os
import time
from statistics import mean

import numpy as np
import openai
import pandas as pd
import spacy

# Configure the OpenAI client from the OPENAI_API_KEY environment variable
# (os.getenv yields None when the variable is unset).
openai.api_key = os.getenv("OPENAI_API_KEY")
# Shared spaCy pipeline, loaded once at import time; it is used by
# get_openai_one_sentence_response to split a completion into sentences.
nlp = spacy.load("en_core_web_sm")

# Maps the escaped UTF-8 byte text of a character to an ASCII stand-in.
# Keys are literal backslash-x escape text (e.g. the 12 characters
# \xe2\x80\x99), i.e. what is left of a completion token once its "bytes:"
# prefix has been removed and consecutive byte tokens have been concatenated.
unicodetoascii_mapping = {
    # Apostrophe / accented letter
    '\\xe2\\x80\\x99': "'",    # U+2019 right single quotation mark
    '\\xc3\\xa9': 'e',         # U+00E9 e with acute accent
    # Hyphens and dashes (U+2010 .. U+2014)
    '\\xe2\\x80\\x90': '-',
    '\\xe2\\x80\\x91': '-',
    '\\xe2\\x80\\x92': '-',
    '\\xe2\\x80\\x93': '-',
    '\\xe2\\x80\\x94': '-',
    # Single quotation marks (U+2018, U+201B)
    '\\xe2\\x80\\x98': "'",
    '\\xe2\\x80\\x9b': "'",
    # Double quotation marks (U+201C .. U+201F)
    '\\xe2\\x80\\x9c': '"',
    '\\xe2\\x80\\x9d': '"',
    '\\xe2\\x80\\x9e': '"',
    '\\xe2\\x80\\x9f': '"',
    # Horizontal ellipsis (U+2026)
    '\\xe2\\x80\\xa6': '...',
    # Primes and reversed primes (U+2032 .. U+2037)
    '\\xe2\\x80\\xb2': "'",
    '\\xe2\\x80\\xb3': "'",
    '\\xe2\\x80\\xb4': "'",
    '\\xe2\\x80\\xb5': "'",
    '\\xe2\\x80\\xb6': "'",
    '\\xe2\\x80\\xb7': "'",
    # Superscript operators and parentheses (U+207A .. U+207E)
    '\\xe2\\x81\\xba': "+",
    '\\xe2\\x81\\xbb': "-",
    '\\xe2\\x81\\xbc': "=",
    '\\xe2\\x81\\xbd': "(",
    '\\xe2\\x81\\xbe': ")",
}


def get_default_kwargs_for_openai_call():
    """Return a new dict of default keyword arguments for a completion call.

    A fresh dict is built on every call, so callers can add or override
    entries (for example a "prompt") without affecting later calls.

    Returns:
        dict: the keys "engine", "temperature", "max_tokens", "top_p",
        "logprobs", "frequency_penalty" and "presence_penalty".
    """
    defaults = dict(
        engine="text-davinci-003",
        temperature=0,
        max_tokens=50,
        top_p=0,
        logprobs=5,
        frequency_penalty=0,
        presence_penalty=0,
    )
    return defaults


def get_openai_response(kwargs, return_complete_response=False, max_retries=None):
    """Call ``openai.Completion.create`` and retry when the call fails.

    Every failure is followed by a pause before the next attempt, so a
    persistent error (bad credentials, malformed request, ...) no longer
    spins in a tight loop against the API.

    Args:
        kwargs: keyword arguments forwarded unchanged to
            ``openai.Completion.create``.
        return_complete_response: when true, return the response object as
            received; otherwise return the stripped text of the first choice.
        max_retries: maximum number of retries after the first attempt.
            ``None`` (the default) retries indefinitely.

    Returns:
        The complete response, or the first choice's text with surrounding
        whitespace removed.

    Raises:
        Exception: the last error encountered, once ``max_retries`` retries
            have been used up (never raised when ``max_retries`` is None).
    """
    rate_limit_pause = 10  # seconds to wait after a RateLimitError
    other_error_pause = 1  # seconds to wait after any other failure
    failures = 0
    while True:
        try:
            response = openai.Completion.create(**kwargs)
            if return_complete_response:
                return response
            # Kept inside the try so a malformed response is retried as well.
            return response["choices"][0]["text"].strip()
        except Exception as e:
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            print('Error occurred retrying. Error type: ', type(e).__name__)
            # Matched by class name so the check does not depend on where the
            # installed openai package defines its exception classes.
            if type(e).__name__ == 'RateLimitError':
                time.sleep(rate_limit_pause)
            else:
                time.sleep(other_error_pause)


def get_openai_one_sentence_response(kwargs):
    """Request a completion and return token-level data for its first sentence.

    The completion is fetched with ``get_openai_response``, split into
    sentences with the module-level spaCy pipeline ``nlp``, and the generated
    tokens are walked in order until they cover the first sentence.

    Each token's probability is a softmax over the ``top_logprobs`` entries
    returned for that position, so it is normalised over the returned
    candidates only. Tokens reported as raw bytes (prefixed with "bytes:")
    are buffered until they form a complete character; that character is
    replaced by its ``unicodetoascii_mapping`` entry when one exists and by
    the decoded character itself otherwise. Such a character carries the
    probability of its final byte token.

    Args:
        kwargs: keyword arguments for the completion call. Must contain
            "prompt"; its length is subtracted from the reported text offsets.
            NOTE(review): the response is expected to carry logprobs
            ("tokens", "top_logprobs", "text_offset") -- confirm callers
            always request them.

    Returns:
        None when the completion contains no sentence, otherwise a dict with:
            "all_probs": probability of each emitted token,
            "all_tokens": the emitted tokens (byte tokens already resolved),
            "char_positions": [start, end] per token, relative to the start
                of the completion, with start placed after leading whitespace,
            "current_generation": concatenation of the emitted tokens,
            "sentence": first sentence, including the completion's leading
                whitespace,
            "response_text": the unmodified completion text.
    """

    def softmax(x):
        # Shift by the maximum before exponentiating for numerical stability.
        e_x = np.exp(np.asarray(x, dtype=float) - np.max(x))
        return e_x / e_x.sum()

    def find_prefix(response_text, sentence):
        # The sentence was cut from the stripped completion; put back whatever
        # precedes its first character in the raw completion text.
        if not sentence:
            return sentence
        start = response_text.find(sentence[0])
        return response_text[:max(start, 0)] + sentence

    def decode_escaped_bytes(escaped):
        # Decode accumulated escape text (literal \xNN groups) as UTF-8.
        # Returns None while the byte sequence is still incomplete and raises
        # ValueError when it can never become valid UTF-8.
        import codecs  # local import: only this helper needs it

        raw = bytes.fromhex(escaped.replace("\\x", ""))
        if not raw:
            return None
        decoder = codecs.getincrementaldecoder("utf-8")()
        text = decoder.decode(raw)
        if decoder.getstate()[0]:
            # Bytes are still buffered: the last character is unfinished.
            return None
        return text

    response = get_openai_response(kwargs, return_complete_response=True)
    choice = response["choices"][0]
    response_text = choice["text"]
    df = pd.DataFrame(choice["logprobs"])

    len_prompt = len(kwargs["prompt"])
    list_of_sentences = list(nlp(response_text.strip()).sents)
    if not list_of_sentences:
        return None
    sentence = find_prefix(response_text, str(list_of_sentences[0]).rstrip())

    probs = []
    tokens = []
    char_positions_using_offset = []
    current_generation = ""
    pending_bytes = ""  # escape text collected from consecutive byte tokens
    pending_space = ""  # whitespace that preceded the bytes in the token

    for _, row in df.iterrows():
        raw_token = row["tokens"]
        top_logprobs = row["top_logprobs"]
        row_softmax = dict(zip(top_logprobs, softmax(list(top_logprobs.values()))))

        token = raw_token
        stripped = raw_token.lstrip()
        if stripped[:6] == "bytes:":
            payload = stripped[6:]
            pending_bytes += payload.strip()
            # Only whitespace in front of the escape text belongs to the output.
            leading_ws = payload[:len(payload) - len(payload.lstrip())]
            if leading_ws:
                pending_space = leading_ws

            if pending_bytes in unicodetoascii_mapping:
                replacement = unicodetoascii_mapping[pending_bytes]
            else:
                try:
                    replacement = decode_escaped_bytes(pending_bytes)
                except ValueError:
                    # Undecodable bytes: drop them so that later byte tokens
                    # are not appended to a sequence that can never match.
                    pending_bytes = ""
                    pending_space = ""
                    continue
                if replacement is None:
                    # Character not complete yet; wait for the next token.
                    continue

            token = pending_space + replacement
            pending_bytes = ""
            pending_space = ""
        else:
            # An ordinary token ends any unfinished byte sequence.
            pending_bytes = ""
            pending_space = ""

        # Probabilities are keyed by the token exactly as the API reported it.
        probs.append(row_softmax[raw_token])
        tokens.append(token)
        current_generation += token

        number_of_leading_spaces = len(token) - len(token.lstrip())
        if number_of_leading_spaces == len(token):
            # Whitespace-only token: keep its full span.
            number_of_leading_spaces = 0

        starting_index = row["text_offset"] - len_prompt
        char_positions_using_offset.append(
            [starting_index + number_of_leading_spaces, starting_index + len(token)]
        )

        if len(current_generation) >= len(sentence):
            break

    if len(sentence.strip()) != len(current_generation.strip()):
        print("************* Check for bugs ************ \n sentence != current_generation")
        print(len(sentence), len(current_generation))

    return {
        "all_probs": probs,
        "all_tokens": tokens,
        "char_positions": char_positions_using_offset,
        "current_generation": current_generation,
        "sentence": sentence,
        "response_text": response_text
    }
