import functools

import datasets

import _settings


@functools.lru_cache(maxsize=1)
def read_all_contexts():
    """Map every CoQA example id to its story text.

    The dataset is loaded from
    ``{_settings.DATA_FOLDER}/semantic_uncertainty/coqa_dataset`` on the first
    call only. The result is memoised, so every call returns the *same* dict
    object; callers should treat it as read-only.
    """
    coqa = datasets.load_from_disk(f'{_settings.DATA_FOLDER}/semantic_uncertainty/coqa_dataset')
    contexts = {}
    for record in coqa:
        # Later rows with a repeated id overwrite earlier ones.
        contexts[record['id']] = record['story']
    return contexts

def preprocess_data(tokenizer, split='validation'):
    """Load the local CoQA copy and add a tokenized story + question prompt.

    Adapted from
    https://github.com/lorenzkuhn/semantic_uncertainty/blob/main/code/parse_coqa.py

    Args:
        tokenizer: Callable tokenizer. It is called on each prompt string with
            ``truncation=False, padding=False``, and its returned fields
            (e.g. ``input_ids``, ``attention_mask``) become dataset columns.
        split: Unused. Kept for interface compatibility; the whole dataset
            saved at ``semantic_uncertainty/coqa_dataset`` is always processed.

    Returns:
        The mapped dataset. ``answer`` is replaced by the answer text, and a
        ``prompt`` column is added alongside the tokenizer outputs.
        ``input_ids`` and ``attention_mask`` are formatted as torch tensors;
        all other columns are still returned.
    """
    dataset = datasets.load_from_disk(f'{_settings.DATA_FOLDER}/semantic_uncertainty/coqa_dataset')

    def encode_coqa(example):
        prompt = example['story'] + ' Q: ' + example['question'] + ' A:'
        # Return every changed column explicitly instead of mutating `example`
        # in place: depending on the `datasets` version and the dataset's
        # format, `Dataset.map` may merge the original columns rather than the
        # mutated example, which would silently drop `answer`/`prompt` edits.
        return {
            'answer': example['answer']['text'],
            'prompt': prompt,
            **tokenizer(prompt, truncation=False, padding=False),
        }

    dataset = dataset.map(encode_coqa, batched=False, load_from_cache_file=False)
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'], output_all_columns=True)

    return dataset


def _generate_config(tokenizer):

    if tokenizer.__class__.__name__ == 'LlamaTokenizer':
        eos_token_id = [tokenizer.encode(_)[-1] for _ in ['.', '\n']] + [29889]  # seems to be '.' as well
        #eos_token_id = [tokenizer(_)['input_ids'] for _ in ['\n', ',', '.']]
    elif tokenizer.__class__.__name__ == 'GPT2Tokenizer':
        eos_token_id = [tokenizer.encode(_)[1] for _ in ['.', '\n']]
    else:
        raise NotImplementedError
    eos_token_id += [tokenizer.eos_token_id]
    question_framing_ids = ['Question:', ' Question:', '\n', 'Answer:', ' Answer:', 'Q:']
    # Follows Kuhn et al 2023 as Llama does not have CoQA
    question_framing_ids = [[tokenizer(eos_token)['input_ids'][1]] for eos_token in question_framing_ids]
    return dict(eos_token_id=eos_token_id, bad_words_ids=question_framing_ids)