import numpy as np
import time
from copy import deepcopy
import os
import sys
import torch
import pickle
import openai
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Module-level GPT-2 tokenizer for the 'gpt2' checkpoint, loaded once when this module is imported.
# NOTE(review): this is a different object from `gpt2_tokenizer` configured in setup_gpt2(); it is not
# referenced in the portion of the file reviewed here — confirm it is still used before keeping the
# import-time load.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Directory that contains this file; results are written to a 'saved_results' folder next to it.
ROOT_DIR = os.path.dirname(os.path.realpath(__file__))
SAVE_DIR = os.path.join(ROOT_DIR, 'saved_results')
if not os.path.isdir(SAVE_DIR):
    # exist_ok=True keeps this from crashing with FileExistsError when another process creates the
    # directory between the isdir() check above and the creation call (e.g. parallel experiment runs).
    os.makedirs(SAVE_DIR, exist_ok=True)
    print(f"mkdir at {SAVE_DIR} for saving results")

def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices holding at most ``n`` items each.

    The last slice is shorter whenever ``len(lst)`` is not a multiple of ``n``.
    """
    yield from (lst[begin:begin + n] for begin in range(0, len(lst), n))

def chunk_size_helper(params):
    """Return the batch size (i.e. the chunk length) to use for ``params['model']``.

    An explicit ``params['bs']`` always wins. When it is None the default is 1 for any model whose
    name contains 'gpt2' and 20 for the known OpenAI engines; any other model name trips the assert.
    """
    requested = params['bs']
    if requested is not None:
        return requested

    model = params['model']
    if 'gpt2' in model:
        return 1

    # only the OpenAI engines listed here have a default batch size
    assert model in ['ada', 'babbage', 'curie', 'davinci', 'ada-beta', 'babbage-beta', 'curie-beta', 'davinci-beta']
    return 20

def random_sampling(sentences, labels, num):
    """Draw ``num`` aligned (sentence, label) pairs uniformly at random, without replacement.

    Returns two lists (sentences, labels) that are deep copies of the selected items, so callers may
    mutate them without touching the pool. Asserts that both inputs have equal length and that
    ``num`` does not exceed the pool size.
    """
    pool_size = len(labels)
    assert len(sentences) == pool_size
    assert num <= pool_size, f"you tried to randomly sample {num}, which is more than the total size of the pool {pool_size}"

    picked = np.random.choice(pool_size, size=num, replace=False)
    chosen_sentences = []
    chosen_labels = []
    for idx in picked:
        chosen_sentences.append(sentences[idx])
        chosen_labels.append(labels[idx])
    return deepcopy(chosen_sentences), deepcopy(chosen_labels)

# Lazily-initialised singletons; populated by setup_gpt2() on first use.
gpt2_model = None
gpt2_tokenizer = None
def setup_gpt2(model_name):
    """Load the GPT-2 model and tokenizer named ``model_name`` into the module-level globals.

    Does nothing when a model has already been loaded (whatever its name). The model is put in eval
    mode and moved to the GPU via ``.cuda()``.
    """
    global gpt2_model, gpt2_tokenizer
    if gpt2_model is not None:
        return

    print("Setting up GPT-2 model")
    gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)
    gpt2_model.eval().cuda()

    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # Batched generation pads on the left (padded positions are masked out later); the EOS token
    # is reused as the padding token on both the tokenizer and the model config.
    gpt2_tokenizer.padding_side = "left"
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    gpt2_model.config.pad_token_id = gpt2_model.config.eos_token_id
    print("Finished")

def setup_gpt3():
    """Read the OpenAI access key from 'openai_key.txt' (first line, stripped) next to this file
    and install it as ``openai.api_key``."""
    key_path = os.path.join(ROOT_DIR, 'openai_key.txt')
    with open(key_path, 'r') as key_file:
        api_key = key_file.readline().strip()
    openai.api_key = api_key

def complete_gpt2(prompt, l=10, model_name='gpt2-xl', num_log_probs=None, echo=False):
    ''' Run GPT-2 locally and package the outputs into a dict that looks just like the JSON
     response provided by the OpenAI API.

     Uses the module-level `gpt2_model` / `gpt2_tokenizer` (populated by setup_gpt2) and moves
     tensors to the GPU with `.cuda()`.

     Args:
         prompt: a single string or a list of strings; a single string is wrapped in a list.
         l: number of tokens to generate greedily. l == 0 is only accepted together with echo=True,
            in which case nothing is generated and only the context is scored.
         model_name: not read inside this function (the already-loaded global model is used).
         num_log_probs: if not None, the number of top tokens (k of the top-k) reported per position.
         echo: if True, the returned text and log-probabilities cover the context as well as the
            generated tokens; otherwise only the last l tokens.

     Returns:
         {'choices': [...]} with one entry per prompt. Each entry has 'text' and, when
         num_log_probs is not None, 'logprobs' with the lists 'top_logprobs', 'token_logprobs'
         and 'tokens'.
     '''
    if isinstance(prompt, str):
        prompt = [prompt] # the code below assumes a list
    input_ids = gpt2_tokenizer.batch_encode_plus(prompt, return_tensors="pt", padding=True)

    # greedily generate l tokens
    if l > 0:
        # the generate function can handle left padded inputs automatically in HF
        # total_sequences is now the input + possible generated output
        total_sequences = gpt2_model.generate(input_ids=input_ids['input_ids'].cuda(), attention_mask=input_ids['attention_mask'].cuda(), max_length=l + len(input_ids['input_ids'][0]), do_sample=False)
    else:
        # nothing to generate: only valid when the caller wants the context echoed back and scored
        assert echo == True and l == 0
        total_sequences = input_ids['input_ids'].cuda()

    # they want the probs of the top tokens
    if num_log_probs is not None:
        # we are left padding, so we need to adjust the position IDs
        # NOTE(review): 50256 is hard-coded as the padding id here (setup_gpt2 sets the pad token to the
        # EOS token) — confirm it equals gpt2_tokenizer.pad_token_id. Any genuine occurrence of this id
        # inside a sequence is masked out as well.
        attention_mask = (total_sequences != 50256).float()
        # positions count only the unmasked tokens; masked positions get the dummy position id 1
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        # get the logits for the context and the next l tokens
        logits = gpt2_model.forward(input_ids=total_sequences, attention_mask=attention_mask, position_ids=position_ids, return_dict=True).logits.detach().cpu()
        if not echo:
            # get the top tokens and probs for the generated l tokens
            # (the last l+1 positions are kept; the final one is dropped again in the loop below)
            probs = torch.softmax(logits[:,-l-1:], dim=2).cpu()
        else:
            # get the top tokens and probs for the context and the generated l tokens
            probs = torch.softmax(logits, dim=2).cpu()
        top_probs, top_tokens = torch.topk(probs, k=num_log_probs)
        logprobs = torch.log(probs)
        top_log_probs = torch.log(top_probs)

    # create the return value to resemble OpenAI
    return_json = {}
    choices = []
    for batch_id in range(len(prompt)):
        curr_json = {}
        # text is just the optional context and next l tokens
        if not echo:
            curr_json['text'] = gpt2_tokenizer.decode(total_sequences[batch_id][-l:], skip_special_tokens=True)
        else:
            curr_json['text'] = gpt2_tokenizer.decode(total_sequences[batch_id], skip_special_tokens=True)

        # fill the return json with the top tokens and probs to match the OpenAI return value.
        if num_log_probs is not None:
            curr_json['logprobs'] = {}
            curr_json['logprobs']['top_logprobs'] = []
            curr_json['logprobs']['token_logprobs'] = []
            curr_json['logprobs']['tokens'] = []
            if not echo:
                # cutoff the -1 here because the probs are shifted one over for LMs
                # (the distribution at position i is the prediction for the token at position i+1)
                for current_element_top_log_probs, current_element_top_tokens in zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1]):
                    # tokens is a list of the top token at each position
                    curr_json['logprobs']['tokens'].append(gpt2_tokenizer.decode([current_element_top_tokens[0]]))
                    # token_logprobs is a list of the logprob of the top token at each position
                    curr_json['logprobs']['token_logprobs'].append(current_element_top_log_probs[0].item())
                    # top_logprobs is a list of dicts for the top K tokens. with each entry being {'token_name': log_prob}
                    temp = {}
                    for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
                        temp[gpt2_tokenizer.decode(token.item())] = log_prob.item()
                    curr_json['logprobs']['top_logprobs'].append(temp)
            else:
                # same as the non-echo branch above, with small tweaks
                # we add null to the front because for the GPT models, they have null probability for the first token
                # (for some reason they don't have a beginning of sentence token)
                # Note the placeholder is the string 'null', not None.
                curr_json['logprobs']['top_logprobs'].append('null')
                # cutoff the -1 here because the probs are shifted one over for LMs
                for index, (current_element_top_log_probs, current_element_top_tokens) in enumerate(zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1])):
                    # skip padding tokens
                    if total_sequences[batch_id][index].item() == 50256:
                        continue
                    temp = {}
                    for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
                        temp[gpt2_tokenizer.decode(token.item())] = log_prob.item()
                    curr_json['logprobs']['top_logprobs'].append(temp)
                # NOTE(review): 'tokens' and 'token_logprobs' below cover every position, including
                # padding, whereas 'top_logprobs' above skips positions whose token id is 50256 — the
                # three lists can therefore differ in length for padded batches; confirm callers cope.
                for index in range(len(probs[batch_id])):
                    curr_json['logprobs']['tokens'].append(gpt2_tokenizer.decode([total_sequences[batch_id][index]]))
                curr_json['logprobs']['token_logprobs'].append('null')
                for index, log_probs_token_position_j in enumerate(logprobs[batch_id][:-1]):
                    # probs are left shifted for LMs
                    # NOTE(review): this appends the indexed tensor element (no .item()), unlike the
                    # non-echo branch which appends Python floats — confirm this is what callers expect.
                    curr_json['logprobs']['token_logprobs'].append(log_probs_token_position_j[total_sequences[batch_id][index+1]])

        choices.append(curr_json)
    return_json['choices'] = choices
    return return_json

def complete_gpt3(prompt, l, model_name, temp=0, num_log_probs=None, echo=False, n=None):
    """Call the OpenAI Completion API, retrying once per second until a response arrives.

    Args:
        prompt: prompt string (or list of prompt strings) sent to the API.
        l: maximum number of tokens to generate (``max_tokens``).
        model_name: OpenAI engine name.
        temp: sampling temperature.
        num_log_probs: number of top log-probabilities requested per token (``logprobs``).
        echo: whether the API should echo the prompt back in the completion.
        n: number of completions requested per prompt.

    Returns:
        The response object returned by ``openai.Completion.create``.

    Raises:
        AssertionError: the API rejected the request (``InvalidRequestError``, e.g. the prompt is
            too long). Retrying cannot help, so the loop is aborted; the original error is chained.
    """
    while True:
        try:
            return openai.Completion.create(engine=model_name, prompt=prompt, max_tokens=l, temperature=temp,
                                            logprobs=num_log_probs, echo=echo, stop='\n', n=n)
        except Exception as exc:
            # Only Exception subclasses are retried, so KeyboardInterrupt / SystemExit still stop the
            # program instead of being swallowed by the retry loop.
            if isinstance(exc, openai.error.InvalidRequestError): # something is wrong: e.g. prompt too long
                print(f"InvalidRequestError\nPrompt passed in:\n\n{prompt}\n\n")
                # Raised explicitly (not via `assert False`) so the abort also happens under `python -O`.
                raise AssertionError("OpenAI API rejected the request (InvalidRequestError)") from exc

            print("API error:", type(exc))
            time.sleep(1)

def complete(prompt, l, model, temp=0, num_log_probs=None, echo=False, n=None):
    """Complete ``prompt`` with a language model and return an OpenAI-style response.

    Models whose name contains 'gpt2' run locally through ``complete_gpt2``; everything else is sent
    to the OpenAI API through ``complete_gpt3``.

    Args:
        prompt: prompt string or list of prompt strings.
        l: number of tokens to generate; must be >= 0.
        model: model / engine name.
        temp: sampling temperature; must be >= 0, and exactly 0 for GPT-2 (sampling unsupported there).
        num_log_probs: number of top log-probabilities to return per token, or None.
        echo: whether the prompt itself is included in the returned text / log-probabilities.
        n: number of completions per prompt; must be None for GPT-2.
    """
    assert l >= 0
    assert temp >= 0
    if 'gpt2' in model:
        assert n is None # unsupported at the moment
        assert temp == 0 # unsupported at the moment
        setup_gpt2(model)
        return complete_gpt2(prompt, l=l, model_name=model, num_log_probs=num_log_probs, echo=echo)

    setup_gpt3()
    # temp is forwarded explicitly; previously it was dropped, so the API always ran at temperature 0.
    return complete_gpt3(prompt, l=l, model_name=model, temp=temp, num_log_probs=num_log_probs, echo=echo, n=n)

def construct_prompt(params, train_sentences, train_labels, test_sentence):
    """Build one few-shot prompt: prefix, labelled demonstrations, then the unanswered test query.

    A non-None ``params['prompt_func']`` takes over completely and its result is returned as is.
    Sentences may be plain strings (appended after the question prefix) or dicts (substituted into
    the question prefix with ``format_map``). Integer labels are mapped through
    ``params['label_dict']`` (classification); string labels are used verbatim (qa).
    """
    custom_builder = params.get('prompt_func')
    if custom_builder is not None:
        return custom_builder(params, train_sentences, train_labels, test_sentence)

    text = params["prompt_prefix"]
    question_tag = params["q_prefix"]
    answer_tag = params["a_prefix"]
    int_label_types = (int, np.int32, np.int64)

    for example, label in zip(train_sentences, train_labels):
        if isinstance(example, dict):
            text += question_tag.format_map(example) + '\n'
        else:
            text += question_tag + example + "\n"

        if isinstance(label, int_label_types): # integer labels for classification
            assert params['task_format'] == 'classification'
            entry = params["label_dict"][label]
            # a label may map to a list of verbalisations; the first one is shown in the prompt
            answer = entry[0] if isinstance(entry, list) else entry
        else:
            assert isinstance(label, str) # string labels
            assert params['task_format'] == 'qa'
            answer = label
        text += answer_tag + answer + "\n\n"

    if isinstance(test_sentence, dict):
        text += question_tag.format_map(test_sentence) + "\n"
    else:
        text += question_tag + test_sentence + "\n"

    assert answer_tag[-1] == ' '
    # GPT models do not want a trailing space, so the final answer prefix loses its last character
    return text + answer_tag[:-1]



def construct_prompt_answer_first(params, train_sentences, train_labels, test_sentence):
    """Build a few-shot prompt in which every demonstration is preceded by a "Type: <label>" line.

    Each demonstration is rendered as "Type: <label>", the question line, then the answer line; the
    test query gets no "Type:" line. A non-None ``params['prompt_func']`` overrides everything.
    Sentences are expected to be plain strings here.
    """
    custom_builder = params.get('prompt_func')
    if custom_builder is not None:
        return custom_builder(params, train_sentences, train_labels, test_sentence)

    text = params["prompt_prefix"]
    question_tag = params["q_prefix"]
    answer_tag = params["a_prefix"]

    for example, label in zip(train_sentences, train_labels):
        if isinstance(label, (int, np.int32, np.int64)): # integer labels for classification
            assert params['task_format'] == 'classification'
            entry = params["label_dict"][label]
            answer = entry[0] if isinstance(entry, list) else entry
        else:
            assert isinstance(label, str) # string labels
            assert params['task_format'] == 'qa'
            answer = label

        text += "Type: " + answer + "\n"
        text += question_tag + example + "\n"
        text += answer_tag + answer + "\n\n"

    text += question_tag + test_sentence + "\n"
    assert answer_tag[-1] == ' '
    # GPT models do not want a trailing space, so the last character of the answer prefix is dropped
    return text + answer_tag[:-1]


def construct_prompt_with_correctness(params, train_sentences, train_labels, test_sentence):
    """Build the demonstration part of a prompt in which every candidate label is judged.

    For a classification example, one answer line is written per key of ``params['inv_label_dict']``
    in the form "<candidate>    Correctness: Yes|No", with "Yes" only for the gold label string.
    For a qa example the single gold answer is written with "Correctness: Yes". The test sentence
    is NOT appended here (the caller adds it); ``test_sentence`` is only passed on to a custom
    ``params['prompt_func']`` when one is configured.
    """
    custom_builder = params.get('prompt_func')
    if custom_builder is not None:
        return custom_builder(params, train_sentences, train_labels, test_sentence)

    text = params["prompt_prefix"]
    question_tag = params["q_prefix"]
    answer_tag = params["a_prefix"]
    # maps every label string to its class id, e.g.
    # label_dict     {0: ['World'], 1: ['Sports'], 2: ['Business'], 3: ['Technology', 'Science']}
    # inv_label_dict {'World': 0, 'Sports': 1, 'Business': 2, 'Technology': 3, 'Science': 3}
    candidates = params['inv_label_dict']

    for example, label in zip(train_sentences, train_labels):
        if isinstance(example, dict):
            text += question_tag.format_map(example) + '\n'
        else:
            text += question_tag + example + "\n"

        if isinstance(label, (int, np.int32, np.int64)): # integer labels for classification
            assert params['task_format'] == 'classification'
            entry = params["label_dict"][label]
            gold = entry[0] if isinstance(entry, list) else entry
            for candidate in candidates:
                verdict = "Yes" if candidate == gold else 'No'
                text += answer_tag + candidate + '    Correctness: ' + verdict + "\n"
            text += '\n'
        else:
            assert isinstance(label, str) # string labels
            assert params['task_format'] == 'qa'
            text += answer_tag + label + '    Correctness: Yes' + "\n\n"

    return text

# def construct_prompt_without_test_nonextline(params, train_sentences, train_labels, test_sentence):
#     """construct a single prompt to be fed into the model"""
#     # special case when the user defines a custom prompt function. 
#     # print("prompt = ", prompt)
#     # print("instance_prompt = ", instance_prompt)
#     if ('prompt_func' in params.keys()) and (params['prompt_func'] is not None):
#         return params['prompt_func'](params, train_sentences, train_labels, test_sentence, separate_input=True)

#     # take the prompt template and fill in the training and test example
#     prompt = params["prompt_prefix"]
#     q_prefix = params["q_prefix"]
#     a_prefix = params["a_prefix"]
#     for s, l in zip(train_sentences, train_labels):
#         if isinstance(s, dict):
#             prompt += q_prefix.format_map(s)
#             prompt += '-'
#         else:
#             prompt += q_prefix
#             prompt += s + "-"
#         if isinstance(l, int) or isinstance(l, np.int32) or isinstance(l, np.int64): # integer labels for classification
#             assert params['task_format'] == 'classification'
#             l_str = params["label_dict"][l][0] if isinstance(params["label_dict"][l], list) else params["label_dict"][l]
#         else:
#             assert isinstance(l, str) # string labels
#             assert params['task_format'] == 'qa'
#             l_str = l

#         prompt += a_prefix
#         prompt += l_str + "-"
#     if isinstance(test_sentence, dict):
#         instance_prompt = q_prefix.format_map(test_sentence)
#         instance_prompt += "-"
#     else:
#         instance_prompt = q_prefix
#         instance_prompt += test_sentence + "-"
#     # instance_prompt = q_prefix

#     # instance_prompt += test_sentence + "\n"
#     assert a_prefix[-1] == ' '
#     instance_prompt += a_prefix[:-1] # GPT models do not want a trailing space, so we cut off -1
#     # print("prompt = ", prompt)
#     # print("instance_prompt = ", instance_prompt)
#     return prompt, instance_prompt

def construct_prompt_without_test(params, train_sentences, train_labels, test_sentence):
    """Build the demonstration prompt and the test-instance prompt as two separate strings.

    Returns ``(prompt, instance_prompt)``: ``prompt`` holds the prefix plus all labelled
    demonstrations, ``instance_prompt`` holds the test question followed by the answer prefix
    without its trailing space. A non-None ``params['prompt_func']`` is called with
    ``separate_input=True`` and its result is returned instead.
    """
    custom_builder = params.get('prompt_func')
    if custom_builder is not None:
        return custom_builder(params, train_sentences, train_labels, test_sentence, separate_input=True)

    demos = params["prompt_prefix"]
    question_tag = params["q_prefix"]
    answer_tag = params["a_prefix"]

    for example, label in zip(train_sentences, train_labels):
        if isinstance(example, dict):
            demos += question_tag.format_map(example) + '\n'
        else:
            demos += question_tag + example + "\n"

        if isinstance(label, (int, np.int32, np.int64)): # integer labels for classification
            assert params['task_format'] == 'classification'
            entry = params["label_dict"][label]
            answer = entry[0] if isinstance(entry, list) else entry
        else:
            assert isinstance(label, str) # string labels
            assert params['task_format'] == 'qa'
            answer = label
        demos += answer_tag + answer + "\n\n"

    if isinstance(test_sentence, dict):
        query = question_tag.format_map(test_sentence) + "\n"
    else:
        query = question_tag + test_sentence + "\n"

    assert answer_tag[-1] == ' '
    # GPT models do not want a trailing space, so the answer prefix loses its last character
    query += answer_tag[:-1]
    return demos, query

def construct_repeat_prompt(params, train_sentences, train_labels, test_sentence):
    """Build a "repeat the sentence" prompt: every demonstration shows a sentence and its copy.

    The labels are ignored except that demonstrations are paired with them via ``zip`` (so the
    shorter of the two sequences bounds the number of demonstrations). The fixed instruction and
    the "Sentence: " / "Repeat: " prefixes are used instead of the ones in ``params``. A non-None
    ``params['prompt_func']`` overrides everything.
    """
    custom_builder = params.get('prompt_func')
    if custom_builder is not None:
        return custom_builder(params, train_sentences, train_labels, test_sentence)

    question_tag = "Sentence: "
    answer_tag = "Repeat: "
    text = "Repeat the sentence. \n\n"

    for example, _ in zip(train_sentences, train_labels):
        text += question_tag + example + "\n"
        text += answer_tag + example + "\n\n"

    text += question_tag + test_sentence + "\n"
    assert answer_tag[-1] == ' '
    # GPT models do not want a trailing space, so the final "Repeat: " loses its last character
    return text + answer_tag[:-1]

def construct_repeat_prompt_without_test(params, train_sentences, train_labels, test_sentence):
    """Build the "repeat the sentence" demonstrations and the test-instance prompt separately.

    Returns ``(prompt, instance_prompt)``: ``prompt`` is the fixed instruction followed by one
    "Sentence: s / Repeat: s" pair per demonstration, ``instance_prompt`` is the test sentence
    followed by "Repeat:" (no trailing space). A non-None ``params['prompt_func']`` is called with
    ``separate_input=True`` and its result is returned instead.
    """
    custom_builder = params.get('prompt_func')
    if custom_builder is not None:
        return custom_builder(params, train_sentences, train_labels, test_sentence, separate_input=True)

    question_tag = "Sentence: "
    answer_tag = "Repeat: "
    demos = "Repeat the sentence. \n\n"

    # labels only bound the number of demonstrations (through zip); their values are unused
    for example, _ in zip(train_sentences, train_labels):
        demos += question_tag + example + "\n"
        demos += answer_tag + example + "\n\n"

    query = question_tag + test_sentence + "\n"
    assert answer_tag[-1] == ' '
    # GPT models do not want a trailing space, so "Repeat: " loses its last character
    query += answer_tag[:-1]
    return demos, query

def construct_repeat_one_prompt(params, train_sentences, train_labels, test_sentence):
    """Build a "repeat" prompt from the FIRST demonstration only, copied four times.

    The prompt is the fixed instruction, "Sentence: <first sentence>", four "Repeat: <first
    sentence>" lines, and a final open "Repeat:". The test sentence is not included; it is only
    forwarded to a non-None ``params['prompt_func']``, which overrides everything.
    """
    custom_builder = params.get('prompt_func')
    if custom_builder is not None:
        return custom_builder(params, train_sentences, train_labels, test_sentence)

    question_tag = "Sentence: "
    answer_tag = "Repeat: "
    text = "Repeat the sentence. \n\n"

    # only the first (sentence, label) pair is used; with no pairs nothing is added
    first_pair = next(iter(zip(train_sentences, train_labels)), None)
    if first_pair is not None:
        example = first_pair[0]
        text += question_tag + example + "\n"
        text += (answer_tag + example + "\n") * 4

    assert answer_tag[-1] == ' '
    # GPT models do not want a trailing space, so the final "Repeat: " loses its last character
    return text + answer_tag[:-1]

def construct_repeat_one_prompt_without_test(params, train_sentences, train_labels, test_sentence):
    """Build a one-demonstration "repeat" prompt and the open answer prefix as separate strings.

    Returns ``(prompt, instance_prompt)``: ``prompt`` is the fixed instruction plus the first
    demonstration rendered as "Sentence: s" / "Repeat: s"; ``instance_prompt`` is just "Repeat:".
    The test sentence is not used, except that it is forwarded to a non-None
    ``params['prompt_func']`` (called with ``separate_input=True``), which overrides everything.
    """
    custom_builder = params.get('prompt_func')
    if custom_builder is not None:
        return custom_builder(params, train_sentences, train_labels, test_sentence, separate_input=True)

    question_tag = "Sentence: "
    answer_tag = "Repeat: "
    demos = "Repeat the sentence. \n\n"

    # only the first (sentence, label) pair is used; with no pairs nothing is added
    first_pair = next(iter(zip(train_sentences, train_labels)), None)
    if first_pair is not None:
        example = first_pair[0]
        demos += question_tag + example + "\n"
        demos += answer_tag + example + "\n"

    assert answer_tag[-1] == ' '
    # GPT models do not want a trailing space, so "Repeat: " loses its last character
    return demos, answer_tag[:-1]

def construct_prompt_without_test_emptyanswer(params, train_sentences, train_labels, test_sentence):
    """Build demonstrations whose answers are left blank, plus a separate test-instance prompt.

    Returns ``(prompt, instance_prompt)``. Every demonstration is the question line followed by the
    bare answer prefix (the label value itself is never written; labels only bound the number of
    demonstrations through ``zip``). ``instance_prompt`` is the test question followed by the
    answer prefix without its trailing space. A non-None ``params['prompt_func']`` is called with
    ``separate_input=True`` and its result is returned instead.
    """
    custom_builder = params.get('prompt_func')
    if custom_builder is not None:
        return custom_builder(params, train_sentences, train_labels, test_sentence, separate_input=True)

    demos = params["prompt_prefix"]
    question_tag = params["q_prefix"]
    answer_tag = params["a_prefix"]

    for example, _ in zip(train_sentences, train_labels):
        if isinstance(example, dict):
            demos += question_tag.format_map(example) + '\n'
        else:
            demos += question_tag + example + "\n"
        # the answer slot is intentionally left empty
        demos += answer_tag + "\n\n"

    if isinstance(test_sentence, dict):
        query = question_tag.format_map(test_sentence) + "\n"
    else:
        query = question_tag + test_sentence + "\n"

    assert answer_tag[-1] == ' '
    # GPT models do not want a trailing space, so the answer prefix loses its last character
    query += answer_tag[:-1]
    return demos, query

def construct_prompt_with_random_prefix(params, train_sentences, train_labels, test_sentence):
    """Build a few-shot prompt in which every example uses its own question / answer prefix.

    ``params['q_prefix']`` and ``params['a_prefix']`` must be lists: demonstration ``i`` uses entry
    ``i`` of each, and the test query uses the entry right after the last demonstration, so both
    lists need at least (number of demonstrations + 1) entries. Unlike the single-prefix builders,
    the final answer prefix is appended unchanged (its trailing space is kept). A non-None
    ``params['prompt_func']`` overrides everything.
    """
    custom_builder = params.get('prompt_func')
    if custom_builder is not None:
        return custom_builder(params, train_sentences, train_labels, test_sentence)

    text = params["prompt_prefix"]
    question_tags = params["q_prefix"]
    answer_tags = params["a_prefix"]
    assert isinstance(question_tags, list)
    assert isinstance(answer_tags, list)

    shots = list(zip(train_sentences, train_labels))
    for slot, (example, label) in enumerate(shots):
        if isinstance(example, dict):
            text += question_tags[slot].format_map(example) + '\n'
        else:
            text += question_tags[slot] + example + "\n"

        if isinstance(label, (int, np.int32, np.int64)): # integer labels for classification
            assert params['task_format'] == 'classification'
            entry = params["label_dict"][label]
            answer = entry[0] if isinstance(entry, list) else entry
        else:
            assert isinstance(label, str) # string labels
            assert params['task_format'] == 'qa'
            answer = label
        text += answer_tags[slot] + answer + "\n\n"

    # the test query takes the first prefix pair not consumed by a demonstration
    query_slot = len(shots)
    if isinstance(test_sentence, dict):
        text += question_tags[query_slot].format_map(test_sentence) + "\n"
    else:
        text += question_tags[query_slot] + test_sentence + "\n"
    # the answer prefix is added as is; no trailing character is removed here
    return text + answer_tags[query_slot]

def construct_prompt_with_random_prefix_without_test(params, train_sentences, train_labels, test_sentence):
    """Build only the demonstration part of a prompt with per-example question / answer prefixes.

    ``params['q_prefix']`` and ``params['a_prefix']`` must be lists; demonstration ``i`` uses entry
    ``i`` of each. The test sentence is not rendered; it is only forwarded to a non-None
    ``params['prompt_func']``, which overrides everything.
    """
    custom_builder = params.get('prompt_func')
    if custom_builder is not None:
        return custom_builder(params, train_sentences, train_labels, test_sentence)

    text = params["prompt_prefix"]
    question_tags = params["q_prefix"]
    answer_tags = params["a_prefix"]
    assert isinstance(question_tags, list)
    assert isinstance(answer_tags, list)

    for slot, (example, label) in enumerate(zip(train_sentences, train_labels)):
        if isinstance(example, dict):
            text += question_tags[slot].format_map(example) + '\n'
        else:
            text += question_tags[slot] + example + "\n"

        if isinstance(label, (int, np.int32, np.int64)): # integer labels for classification
            assert params['task_format'] == 'classification'
            entry = params["label_dict"][label]
            answer = entry[0] if isinstance(entry, list) else entry
        else:
            assert isinstance(label, str) # string labels
            assert params['task_format'] == 'qa'
            answer = label
        text += answer_tags[slot] + answer + "\n\n"

    return text




def construct_prompt_for_one_demonstration(params, train_sentences, train_labels):
    """
    Render the demonstration (training) examples as prompt text, without any prompt prefix or test query.

    :param params: experiment parameters; reads 'q_prefix' and 'a_prefix' (strings), 'task_format', and
        'label_dict' (only when a label is an integer)
    :param train_sentences: demonstration inputs; each is either a str appended after q_prefix, or a dict
        whose fields are substituted into q_prefix with str.format_map
    :param train_labels: demonstration labels; integers (task_format 'classification', mapped through
        label_dict) or strings (task_format 'qa', used verbatim)
    :return: the concatenated "<question>\\n<a_prefix><label>\\n\\n" blocks, or '' when there are no examples
    """
    q_prefix = params["q_prefix"]
    a_prefix = params["a_prefix"]

    pieces = []
    for sentence, label in zip(train_sentences, train_labels):
        if isinstance(sentence, dict):
            # structured example: q_prefix is a template with named fields
            pieces.append(q_prefix.format_map(sentence) + "\n")
        else:
            pieces.append(q_prefix + sentence + "\n")

        if isinstance(label, (int, np.integer)):  # integer labels for classification (any NumPy int width)
            assert params['task_format'] == 'classification'
            label_entry = params["label_dict"][label]
            # a label may map to a list of verbalizers; the first one is used in the prompt
            label_str = label_entry[0] if isinstance(label_entry, list) else label_entry
        else:
            assert isinstance(label, str)  # string labels
            assert params['task_format'] == 'qa'
            label_str = label

        pieces.append(a_prefix + label_str + "\n\n")

    return "".join(pieces)

def construct_prompt_instance_prompt_text(params, train_sentences, train_labels, test_sentence):
    """
    Build the pieces of a few-shot prompt separately instead of as one string.

    :param params: experiment parameters; reads 'prompt_prefix', 'q_prefix' and 'a_prefix' (strings),
        'task_format', 'label_dict' (only when a label is an integer) and the optional 'prompt_func'
    :param train_sentences: demonstration inputs; each is either a str appended after q_prefix, or a dict
        whose fields are substituted into q_prefix with str.format_map
    :param train_labels: demonstration labels; integers (task_format 'classification', mapped through
        label_dict) or strings (task_format 'qa', used verbatim)
    :param test_sentence: the query to be answered; a str or a dict, handled like the training sentences
    :return: a tuple (instance, prompt, test_prompt) where instance is the rendered demonstrations, prompt
        is params['prompt_prefix'], and test_prompt is the rendered query followed by a_prefix without its
        trailing space. When params['prompt_func'] is set, its return value is returned unchanged instead.
    """
    # special case when the user defines a custom prompt function.
    if params.get('prompt_func') is not None:
        return params['prompt_func'](params, train_sentences, train_labels, test_sentence, separate_input=False, separate_prompt=True)

    # take the prompt template and fill in the training and test example
    prompt = params["prompt_prefix"]
    q_prefix = params["q_prefix"]
    a_prefix = params["a_prefix"]

    pieces = []
    for sentence, label in zip(train_sentences, train_labels):
        if isinstance(sentence, dict):
            # structured example: q_prefix is a template with named fields
            pieces.append(q_prefix.format_map(sentence) + "\n")
        else:
            pieces.append(q_prefix + sentence + "\n")

        if isinstance(label, (int, np.integer)):  # integer labels for classification (any NumPy int width)
            assert params['task_format'] == 'classification'
            label_entry = params["label_dict"][label]
            # a label may map to a list of verbalizers; the first one is used in the prompt
            label_str = label_entry[0] if isinstance(label_entry, list) else label_entry
        else:
            assert isinstance(label, str)  # string labels
            assert params['task_format'] == 'qa'
            label_str = label

        pieces.append(a_prefix + label_str + "\n\n")
    instance = "".join(pieces)

    # The query must be rendered from test_sentence (not from the last training example), and the
    # accumulator has to start from an empty string so this also works with zero demonstrations.
    if isinstance(test_sentence, dict):
        test_prompt = q_prefix.format_map(test_sentence) + "\n"
    else:
        test_prompt = q_prefix + test_sentence + "\n"

    assert a_prefix[-1] == ' '
    test_prompt += a_prefix[:-1]  # GPT models do not want a trailing space, so we cut off -1
    return instance, prompt, test_prompt





def get_model_response(params, train_sentences, train_labels, test_sentences, return_all_prompts=False,
                       num_tokens_to_predict_override=None, override_prompt=None):
    """
    Query the model on every test sentence and collect the raw completions.

    :param params: parameters for the experiment; reads 'model', 'api_num_log_prob' and, unless overridden,
        'num_tokens_to_predict' (plus whatever chunk_size_helper and construct_prompt read)
    :param train_sentences: few-shot training sentences
    :param train_labels: few-shot training labels
    :param test_sentences: test sentences, one prompt is built per sentence
    :param return_all_prompts: when true, also return the list of prompts that were sent
    :param num_tokens_to_predict_override: if not None, used instead of params['num_tokens_to_predict']
    :param override_prompt: if not None, used as the list of prompts instead of building them
        (used for contextual calibration)
    :return: the entries of every response's 'choices' field, in prompt order; followed by the prompts
        when return_all_prompts is set
    """
    # can optionally ignore the normal prompt and feed in a custom prompt
    if override_prompt is not None:
        prompts = override_prompt
    else:
        prompts = [
            construct_prompt(params, train_sentences, train_labels, sentence)
            for sentence in test_sentences
        ]

    collected_answers = []
    # each chunk is sent to the model as one batch
    for batch in chunks(prompts, chunk_size_helper(params)):
        token_budget = (
            params['num_tokens_to_predict']
            if num_tokens_to_predict_override is None
            else num_tokens_to_predict_override
        )
        response = complete(batch, token_budget, params['model'], num_log_probs=params['api_num_log_prob'])
        collected_answers.extend(response['choices'])

    if not return_all_prompts:
        return collected_answers
    return collected_answers, prompts

def load_pickle(params):
    """
    Read back the pickled results of an experiment.

    :param params: parameters for the experiment; params['expr_name'] names the .pkl file inside SAVE_DIR
    :return: the unpickled object
    """
    pkl_path = os.path.join(SAVE_DIR, f"{params['expr_name']}.pkl")
    found = os.path.isfile(pkl_path)
    assert found, f"file does not exist: {pkl_path}"
    # NOTE: unpickling runs code embedded in the file, so only load results you produced yourself
    with open(pkl_path, 'rb') as handle:
        loaded = pickle.load(handle)
    print(f"Loaded data from {pkl_path}")
    return loaded

def save_pickle(params, data):
    """
    Pickle the results of an experiment into SAVE_DIR.

    :param params: parameters for the experiment; params['expr_name'] names the .pkl file
    :param data: the object to pickle
    :return: data, unchanged
    """
    pkl_path = os.path.join(SAVE_DIR, f"{params['expr_name']}.pkl")
    already_exists = os.path.isfile(pkl_path)
    if already_exists:
        # an earlier run with the same experiment name is replaced, not merged
        print("WARNING! overwriting existing saved files")
    with open(pkl_path, 'wb') as handle:
        pickle.dump(data, handle)
    print(f"Saved to {pkl_path}")
    return data

def print_results(tree, names=('Original Accuracy  ','Calibrated Accuracy')):
    """
    Print summary statistics for a nested result tree.

    :param tree: nested mapping dataset -> model -> num_shots -> seed -> accuracies, where the accuracies
        of one seed form a sequence with one entry per reported method
    :param names: row labels, indexed by position within each accuracies sequence
    :return: None; mean, min, max and std over the seeds are written to stdout
    """
    snapshot = deepcopy(tree)  # work on a copy of the caller's tree
    for dataset, per_model in snapshot.items():
        print(f"\n\nDataset: {dataset}")
        for model, per_shots in per_model.items():
            print(f"\nModel: {model}")
            for num_shots, per_seed in per_shots.items():
                # one row per seed; axis 0 therefore aggregates over the seeds
                scores = np.array(list(per_seed.values()))
                stats = (
                    np.mean(scores, axis=0),
                    np.min(scores, axis=0),
                    np.max(scores, axis=0),
                    np.std(scores, axis=0),
                )

                print(f"\n{num_shots}-shot, {len(scores)} seeds")
                for idx, (mean, low, high, std) in enumerate(zip(*stats)):
                    print(f"{names[idx]} | Mean: {mean:.4f}, Low: {low:.4f}, High: {high:.4f}, Std: {std:.4f}")
                print()

def load_results(params_list):
    """
    Load the saved results of several experiments and print their summary.

    :param params_list: one params dict per experiment; each supplies 'dataset', 'model', 'num_shots' and
        'seed' (plus whatever load_pickle reads)
    :return: None; the accuracies are grouped as dataset -> model -> num_shots -> seed and handed to
        print_results
    """
    result_tree = {}
    for params in params_list:
        saved_result = load_pickle(params)
        leaf = result_tree  # start at the root and descend, creating levels on demand
        for level_key in (params['dataset'], params['model'], params['num_shots']):
            leaf = leaf.setdefault(level_key, {})
        leaf[params['seed']] = saved_result['accuracies']
    print_results(result_tree)