import collections
import os
import shelve
import time

import datasets
import numpy as np
import openai
import pysbd
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from openai import AzureOpenAI

# Azure OpenAI connection settings.
# The endpoint and API key are read from the environment so that no secret is
# stored in the source file; a missing variable fails fast with a KeyError.
# NOTE(review): a key used to be hard-coded here -- treat it as compromised
# and rotate it.
api_version = "2023-12-01-preview"
endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]

client = AzureOpenAI(
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version=api_version,
    azure_endpoint=endpoint,
)

# Prompt sent to the model; `include_word` and `sentence` are filled in with
# str.format by get_completion.
template = """Paraphrase this sentence. Try to include the word "{include_word}" in your paraphrase without changing the meaning. If that's not possible, say "not possible".
Sentence: {sentence}"""

# Persistent on-disk store of accepted model replies, keyed by input sentence.
cache = shelve.open("responses_cache_horribleincredible_train")

def get_completion(sentence, include_word):
    """Ask the model to paraphrase *sentence* while using *include_word*.

    The prompt is built from the module-level ``template`` and sent as a
    single user message through the module-level ``client``.

    Args:
        sentence: The sentence to paraphrase.
        include_word: The word the paraphrase should contain.

    Returns:
        The text of the first choice in the reply, or ``None`` when no text
        could be extracted or the reply contains "not possible"
        (case-insensitive).

    Side effects:
        A reply that is returned is also written to the module-level
        ``cache`` under the key ``sentence``.
    """
    # The cache is write-only here: no lookup is performed, so every call
    # issues a fresh request.
    # NOTE(review): presumably intentional, since callers retry failed
    # sentences and a cached reply would be returned unchanged -- confirm.
    prompt = template.format(include_word=include_word, sentence=sentence)
    chat_completion = client.chat.completions.create(
        model="gpt-4-0613",
        messages=[{"role": "user", "content": prompt}],
        temperature=1.0,
    )

    # Only guard the extraction of the reply text; narrow exception types so
    # that KeyboardInterrupt/SystemExit are not swallowed.
    try:
        response = chat_completion.choices[0].message.content
    except (AttributeError, IndexError, TypeError):
        response = None

    if response is None or "not possible" in response.strip().lower():
        return None

    cache[sentence] = response
    return response


# English sentence segmenter, created with pysbd's `clean` option enabled.
seg = pysbd.Segmenter(language="en", clean=True)

# Work queue of (index, example) pairs; filled and drained further below.
example_queue = collections.deque()

ds = datasets.load_from_disk("imdb_horribleincredible_withpara")

# Merge the S0 and S1 partitions of each split into a single dataset.
train = datasets.concatenate_datasets(
    [ds[f"S{part}_train"] for part in (0, 1)]
)
test = datasets.concatenate_datasets(
    [ds[f"S{part}_test"] for part in (0, 1)]
)

# Words to inject into the paraphrase, indexed by label value.
# (An earlier variant used ['bad', 'good'].)
keys = ["horrible", "incredible"]


# Hard-coded review text, segmented instead of an example's 'Tpara1' field
# when segmenting that field raises.
raw2 = 'Oh my god what a story! This movie is great and it had to be God who had this happen! You did a awesome job.The acting was excellent you picked the right actors for sure. This movie is so great I am really glad you made this because if you had not then I would have never ever known about this story because I am not a big golf fan and I think it is kinda boring so thank you. I really enjoyed it and that is why I gave the movie a 10/10. I liked Shia Labouf too he was perfect for the roll of Fransis Quimet. I hope most of that stuff you put in there was true also. Oh and some parts were funny and others I was just really happy.'


# Switch `train` and `test` here and re-run to get the test paraphrases.
# NOTE: results are keyed by dataset index in the shelf opened below, and
# indices already present are skipped, so use a different shelf name when
# switching splits.
N = len(train)
example_queue.extend(enumerate(train))

# Upper bound on paraphrase attempts per example. Without it, an example the
# model can never paraphrase with the requested word would be re-queued
# forever. Examples that hit the limit are not stored, so a later run of the
# script will try them again.
MAX_ATTEMPTS = 10
attempts = collections.Counter()
gave_up = []

try:
    # The context manager guarantees the shelf is flushed and closed even if
    # a request raises part-way through the run.
    with shelve.open("gpt4_paraphrases_horribleincredible_train") as outputs:
        while example_queue:
            idx, ex = example_queue.popleft()
            if str(idx) in outputs:
                print(f'skip {idx+1}/{N}; remaining {len(example_queue)}')
                continue

            gold_label = ex['label']
            weak_label = ex['weak_label']

            # point is in Sigood, so include the wrong word
            if weak_label == gold_label:
                include_word = keys[1 - weak_label]
            else:  # point is in Sibad, so include the right word
                include_word = keys[gold_label]

            # NB: we could've done the above w/out knowing the true label by
            # just swapping "good"/"bad",
            # i.e., "good" present ==> include_word = "bad".
            # Then if a point is in S1good, it gets mapped to S1bad;
            # if it's in S0bad, it gets mapped to S0good.
            # This way is just more convenient given how the data is already
            # split up.

            # Segment the text into sentences. If segmentation raises, fall
            # back to the hard-coded review in `raw2` (original behaviour),
            # but say so instead of substituting it silently.
            try:
                sentences = seg.segment(ex['Tpara1'])
            except Exception:
                print(f"IDX {idx}: could not segment Tpara1; using raw2 instead")
                sentences = seg.segment(raw2)

            if not sentences:
                print(f"IDX {idx} HAS NO TPARA")
                continue

            # Pick one sentence uniformly at random to paraphrase.
            index = np.random.choice(len(sentences))
            sentence = sentences[index]
            print(f"\ntrying {sentence}\n")
            completion = get_completion(sentence, include_word)

            if completion is None or include_word not in completion.lower():
                print("failed")
                print(completion)
                attempts[idx] += 1
                if attempts[idx] < MAX_ATTEMPTS:
                    # Retry later; a different sentence may be drawn.
                    example_queue.append((idx, ex))
                else:
                    print(f"giving up on {idx} after {MAX_ATTEMPTS} attempts")
                    gave_up.append(idx)
                continue

            print(idx)
            print(sentence)
            print(completion)
            print("------")
            # Splice the paraphrase back in place of the chosen sentence. A
            # single join avoids the stray leading space the piecewise
            # concatenation produced when the first sentence was chosen.
            outputs[str(idx)] = " ".join(
                [*sentences[:index], completion, *sentences[index + 1:]]
            )
finally:
    cache.close()

if gave_up:
    print(f"no paraphrase produced for {len(gave_up)} example(s): {gave_up}")
