
"""
Load questions from generations.pkl
"""
from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse
import concurrent.futures
import torch
import datasets
import pickle
import openai
import os
import numpy as np
import re

# Default OpenAI completion model shared by the three request-parameter dicts
# below. A cheaper Babbage model (e.g. 'text-babbage-001') is handy for
# debugging; use a Curie model for the real generation run.
model = 'text-curie-001'

# parameters for generating most_likely_generation
most_likely_generation_params = {
    'model': model,
    'max_tokens': 64,
    'temperature': 0,  # 0 for deterministic generation
    'n': 1,
}

# parameters for generating normal generations
normal_generations_params = {
    'model': model,
    'max_tokens': 64,
    'temperature': 0.5,  # for non-deterministic generation
    'n': 5,  # generate 5 generations for each prompt
}

# parameters for generating similarity
similarity_params = {
    'model': model,
    'max_tokens': 64,
    'temperature': 0,  # 0 for deterministic generation
    'n': 1,
}

def parse_questions(dataset_name):
    """Load a QA dataset from disk and pickle its questions for online generation.

    The dataset is read with ``datasets.load_from_disk`` from a fixed path under
    ``../../datasets`` and iterated one sample at a time. Each sample becomes a
    dict with keys ``'question'``, ``'id'``, ``'answer'`` and
    ``'additional_answers'``. The resulting list is written to
    ``{dataset_name}_for_online.pkl`` in the current directory.

    :param dataset_name: one of ``'trivia_qa'``, ``'sciq'`` or ``'coqa'``.
    :raises ValueError: if ``dataset_name`` is not a supported dataset.
    """
    seed_value = 10
    if dataset_name == 'trivia_qa':
        dataset = datasets.load_from_disk('../../datasets/trivia_qa_opt')
        # test_size is 0.94, so the 'train' side is a seeded 6% subset.
        dataset = dataset.train_test_split(test_size=(1 - 0.06), seed=seed_value)['train']
    elif dataset_name == 'sciq':
        dataset = datasets.load_from_disk('../../datasets/sciq_opt')
    elif dataset_name == 'coqa':
        dataset = datasets.load_from_disk('../../datasets/coqa_dataset')
    else:
        # Without this branch `dataset` would be unbound below.
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; "
            "expected 'trivia_qa', 'sciq' or 'coqa'.")
    # batch_size=1: collated fields are batches of one element, hence the [0]
    # indexing below (and presumably in downstream consumers).
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

    questions = []
    for sample in dataloader:
        if dataset_name == 'coqa':
            questions.append({
                'question': [sample['story'][0] + ' Q: ' + sample['question'][0] + ' A:'],
                'id': sample['id'],
                'answer': sample['answer']['text'],
                'additional_answers': [_[0] for _ in sample['additional_answers']]
            })
        else:
            questions.append({
                'question': sample['question'],
                'id': sample['question_id'],
                'answer': sample['answer'],
                'additional_answers': None
            })
    with open(f'{dataset_name}_for_online.pkl', 'wb') as f:
        pickle.dump(questions, f)

def get_most_likely_generation(prompt, params):
    '''
    Request a single completion for ``prompt`` and return its first line.

    The prompt is suffixed with " Please directly return the answer." and sent
    to ``openai.Completion.create`` with the model, temperature and max_tokens
    taken from ``params`` (``params['n']`` is not forwarded).

    :param prompt: question text to complete.
    :param params: dict providing 'model', 'temperature' and 'max_tokens'.
    :return: the completion text with leading/trailing newlines stripped,
        cut at the first remaining newline.
    '''
    full_prompt = "{} Please directly return the answer.".format(prompt)
    completion = openai.Completion.create(
        model=params['model'],
        temperature=params['temperature'],
        max_tokens=params['max_tokens'],
        prompt=full_prompt,
    )
    text = completion.choices[0].text.strip('\n')
    first_line, _, _ = text.partition('\n')
    return first_line


def get_normal_generations_and_logprobs(prompt, params):
    '''
    Sample ``params['n']`` completions for ``prompt`` with token-wise logprobs.

    For logprobs, see
    https://platform.openai.com/docs/api-reference/completions/create#completions/create-logprobs

    :param prompt: question text; " Please directly return the answer." is appended.
    :param params: dict providing 'model', 'temperature', 'max_tokens' and 'n'.
    :return: a dict with one entry per returned choice in each list:
        {
            'generations': [text_1, ..., text_n],
            'generated_tokens': [[tok, tok, ...], ...],   # tokens of each text
            'token_logprobs': [[lp, lp, ...], ...],       # logprob of each token
        }

    ##### Very Important
    For each generation, the number of tokens and the number of logprobs must be
    the same; this is asserted below for the API's own tokens. When re-tokenizing
    locally, check something like
    assert len(tokenizer.encode(text_1)) == len(logprobs of text_1).
    More information about the tokenizer can be found at
    https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer
    Note that tokenizer.encode() may add/remove special tokens such as
    pad_token or eos_token.
    '''
    completion = openai.Completion.create(
        model=params['model'],
        temperature=params['temperature'],
        max_tokens=params['max_tokens'],
        logprobs=1,
        prompt="{} Please directly return the answer.".format(prompt),
        n=params['n'],
    )
    choices = completion.choices
    struct = {
        'generations': [choice.text for choice in choices],
        'generated_tokens': [choice.logprobs.tokens for choice in choices],
        'token_logprobs': [choice.logprobs.token_logprobs for choice in choices],
    }

    for tokens, logprobs in zip(struct['generated_tokens'], struct['token_logprobs']):
        assert len(tokens) == len(logprobs)

    return struct


def get_similarity(prompt, params):
    '''
    get similarity
    Note taht once you called openai api to get this result, you may need regular expression to extract this similarity from the generated sentences.
    You can simply recognize any float and mean these floats
    :param prompt:
    :param params:
    :return: a float in (0, 1)
    '''
    response = openai.Completion.create(model=params['model'], temperature=params['temperature'], max_tokens=params['max_tokens'], logprobs=1,
                                        prompt=prompt)
    response = response.choices[0].text
    # response = response.split('\n')
    results = []
    for res in re.findall(r"[-+]?(?:\d*\.*\d+)", response):
        try:
            res = float(res)
            results.append(res)
        except:
            continue

    if len(results) == 0:
        print(response)
        results = [0]
    return torch.asarray(results).mean()


def online_generate_worker(idx, sample, generations):
    print(f'{idx} / 1000')
    question = sample['question'][0]
    id = sample['id'][0]
    answer = sample['answer'][0]

    # generate most likely generations
    most_likely_generations = get_most_likely_generation(
        prompt=question, params=most_likely_generation_params)

    # generate normal generations
    normal_generations = get_normal_generations_and_logprobs(prompt=question, params=normal_generations_params)

    # generate similarity between most_likely_generation and real answer
    # you can use the following prompt to ask a similarity between the generated answer and the real answer
    prompt = f'Given question "{question}", the real answer is "{answer}", ' \
             f'the answer generated by a language model is "{most_likely_generations}". ' \
             f'What is the similarity between the generated answer and the real answer regarding the given question.' \
             f'Higher similarity means the generated answer is closer to the real answer. Please return this similarity' \
             f'on scale of 0 to 1.'
    similarity = get_similarity(prompt=prompt, params=similarity_params)
    print(answer, most_likely_generations, similarity, normal_generations['generations'])
    generations.append({
        'id': id,
        'question': question,
        'most_likeli_generated_text': most_likely_generations,
        'generations': normal_generations['generations'],
        'token_wise_entropy': normal_generations['token_logprobs'],
        'generated_tokens': normal_generations['generated_tokens'],
        'answer': [answer],
        'additional_answers': sample['additional_answers'],
        'similarity': similarity
    })


def main(dataset, model):
    '''
    Generate answers for up to 1000 questions of ``dataset`` via the OpenAI API.

    It reads ``{dataset}_for_online.pkl`` and runs online_generate_worker for
    each question on a thread pool. The collected records are pickled to
    ``{dataset}_{model}_generations.pkl``. The order of the records follows
    worker completion, not question order.

    :param dataset: dataset name used in the input and output file names.
    :param model: OpenAI model name. When non-empty it replaces the 'model'
        entry of the module-level request-parameter dicts. It is also used in
        the output file name.
    '''
    # Read the key from the environment instead of overwriting it with a
    # hard-coded empty placeholder.
    openai.api_key = os.environ.get("OPENAI_API_KEY", "")

    if model:
        for params in (most_likely_generation_params, normal_generations_params, similarity_params):
            params['model'] = model

    with open(f"{dataset}_for_online.pkl", 'rb') as f:
        questions = pickle.load(f)
    num = 1000
    questions = questions[:num]

    generations = []
    # set 1 for debug, then you can set like 10-20 to specify how many questions will be sent to openai
    # at the same time to speed up the generation
    num_workers = 30

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {
            executor.submit(online_generate_worker, idx, sample, generations): idx
            for idx, sample in enumerate(questions)
        }
        for future in concurrent.futures.as_completed(futures):
            exc = future.exception()
            if exc is not None:
                # Best-effort: a failed question is skipped, but it is reported
                # instead of being silently lost.
                print(f'question {futures[future]} failed: {exc!r}')

    with open(f'{dataset}_{model}_generations.pkl', 'wb') as f:
        pickle.dump(generations, f)

def cmdline_args(argv=None):
    '''
    Parse the command-line options.

    :param argv: argument list to parse; None (default) parses sys.argv[1:].
    :return: argparse.Namespace with ``measurement_model``, ``dataset`` and ``model``.
    '''
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)

    p.add_argument("--measurement-model", default='cross-encoder/stsb-roberta-large',
                   choices=['cross-encoder/stsb-roberta-large',
                            'cross-encoder/stsb-distilroberta-base'],
                   help="Cross-encoder model name (currently unused by this script).")
    # Required: every other value made parse_questions fail later.
    p.add_argument('--dataset', required=True,
                   choices=['trivia_qa', 'sciq', 'coqa'],
                   help="Dataset whose questions are loaded and sent to the API.")
    p.add_argument('--model', default='',
                   help="OpenAI model name; also used in the output file name.")

    return p.parse_args(argv)

if __name__ == '__main__':
    cli_args = cmdline_args()
    # Step 1: build the question pickle for the chosen dataset.
    parse_questions(cli_args.dataset)
    # Step 2: query the API for those questions and save the generations.
    main(cli_args.dataset, cli_args.model)
