import string

# System prompt for multiple-choice question answering. It tells the model to
# put its reasoning inside <think> </think> tags and the letter of the chosen
# option inside <answer> </answer> tags. The text is sent to the model
# verbatim, so any wording change alters model input.
system_prompt_mcqa = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer (letter of correct option) are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer> $LETTER </answer>
"""

# Labels for answer options: the i-th option gets the i-th uppercase ASCII letter.
LETTERS = string.ascii_uppercase

def get_question_options(question, options):
    """Render a question and its answer options as one text block.

    The result has the form::

        Question: <question>
        Options:
        A. <first option>
        B. <second option>
        ...

    and always ends with a newline.

    Args:
        question: The question; interpolated into the text with an f-string.
        options: Iterable of options; each one is labelled with the letter at
            the same position in ``LETTERS``.

    Returns:
        The formatted text as a ``str``.

    Raises:
        IndexError: If there are more options than entries in ``LETTERS``.
    """
    # Index into LETTERS (rather than zip) so too many options fail loudly
    # instead of being silently dropped.
    option_lines = ''.join(
        f'{LETTERS[idx]}. {choice}\n' for idx, choice in enumerate(options)
    )
    header = f"Question: {question}\n"
    return header + 'Options:\n' + option_lines

def get_prompt(model_name, tokenizer, question, options, answer_type, image=None):
    """Build the chat-formatted prompt and its token ids for a multiple-choice question.

    The question and options are rendered with ``get_question_options`` and
    wrapped in the model's chat template. In both supported branches the
    literal ``' <think>'`` is appended after the generation prompt.

    Args:
        model_name: Model identifier. Matched case-insensitively: names
            containing ``'llama'`` take the text-only path, names containing
            ``'aya'`` take the image+text path ('llama' is checked first).
        tokenizer: For the llama path, an object providing
            ``apply_chat_template`` and ``encode``. For the aya path, a
            callable processor that exposes the actual tokenizer as
            ``.tokenizer``.
        question: Question text, passed to ``get_question_options``.
        options: Answer options, passed to ``get_question_options``.
        answer_type: Currently unused; kept for interface compatibility.
        image: Image for the aya path (required there); ignored by the llama path.

    Returns:
        A ``(prompt, input_ids)`` tuple: the prompt string and the token ids
        for it. On the aya path the leading ``'<BOS_TOKEN>'`` text is removed
        from ``prompt`` and one of the two leading BOS ids is removed from
        ``input_ids``.

    Raises:
        NotImplementedError: If ``model_name`` matches neither supported family.
        AssertionError: On the aya path, if ``image`` is ``None`` or the
            processor output does not have the expected doubled-BOS layout.
    """
    text = get_question_options(question, options)
    system_prompt = system_prompt_mcqa
    model_key = model_name.lower()

    if 'llama' in model_key:
        chat_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ]
        prompt = tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )
        # Pre-fill the opening tag; presumably so generation starts inside the
        # reasoning section requested by the system prompt.
        prompt += ' <think>'

        input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    elif 'aya' in model_key:
        # The checks in this branch are explicit raises instead of `assert`
        # statements: asserts are stripped under `python -O`, which would let
        # the BOS-stripping below run unchecked on unexpected input.
        # AssertionError is kept so the exception type seen by callers is unchanged.
        if image is None:
            raise AssertionError("image is required for aya models")

        processor = tokenizer
        tokenizer = processor.tokenizer

        # NOTE(review): unlike the llama path, no system message is sent here,
        # so `system_prompt` is unused in this branch -- confirm this is intended.
        messages = [
            {"role": "user",
             "content": [
                 {"type": "image"},
                 {"type": "text", "text": text},
             ]},
        ]

        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        prompt += ' <think>'
        inputs = processor(text=[prompt], images=[[image]], return_tensors="pt", padding=False)  # needed for max_length
        input_ids = inputs['input_ids'][0].tolist()

        bos_id = 5  # presumably the id of '<BOS_TOKEN>' for this tokenizer -- TODO confirm
        bos_text = '<BOS_TOKEN>'
        if input_ids[:2] != [bos_id, bos_id]:
            raise AssertionError(
                f"expected two leading BOS ids ({bos_id}), got {input_ids[:2]}"
            )
        if not prompt.startswith(bos_text):
            raise AssertionError(f"expected prompt to start with {bos_text!r}")

        input_ids = input_ids[1:]  # somehow a second BOS token is added
        prompt = prompt[len(bos_text):]
    else:
        raise NotImplementedError()
    return prompt, input_ids
