import torch, os
import  torch.nn.functional as F

# These environment variables are set *before* importing `transformers` so
# that they are already in place when the library is imported.
# NOTE(review): TF_ENABLE_ONEDNN_OPTS=0 presumably disables TensorFlow's oneDNN
# custom ops (and their startup notice); it only matters if TensorFlow is
# installed and gets imported.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
# Put the Hugging Face cache (downloaded models/tokenizers) in a `.cache`
# directory relative to the current working directory.
os.environ['HF_HOME'] = '.cache'
from transformers import AutoTokenizer, AutoModelForCausalLM


def load_model(model_name):
    '''
    Load a pretrained causal language model together with its tokenizer.

    Intended for demoing the interpretation module with small models (the
    author tested Qwen 0.5B on a local RTX 3060 Ti with 8 GB of VRAM).

    @param model_name: Hugging Face model identifier or local path
    @return: A ``(model, tokenizer)`` tuple. The model is loaded with
        ``device_map="auto"`` so its placement is chosen automatically.
    '''
    tokenizer_obj = AutoTokenizer.from_pretrained(model_name)
    return (
        AutoModelForCausalLM.from_pretrained(model_name, device_map="auto"),
        tokenizer_obj,
    )


def generate_response(tokenizer, model, user_prompt, sys_prompt=None, device='cuda',
                      max_new_tokens=50, do_sample=True):
    '''
    Generate a response to the given prompt with the given model.

    The messages are rendered with the tokenizer's chat template (with the
    generation prompt appended), tokenized, and passed to ``model.generate``
    with per-step logits requested so callers can inspect them.

    @param tokenizer: The tokenizer of the model
    @param model: The model used to generate the response
    @param user_prompt: Either a plain user message (str), or a list of chat
        messages (dicts with "role"/"content" keys) appended verbatim
    @param sys_prompt: Optional system prompt, inserted as the first message
    @param device: The device the tokenized inputs are moved to
    @param max_new_tokens: Maximum number of tokens to generate (default 50)
    @param do_sample: Sample instead of greedy decoding (default True)
    @return: A tuple ``(input_ids, generated_content)``: the prompt token ids
        and the dict-like output of ``model.generate`` (containing
        ``sequences`` and ``logits``).
    @raise ValueError: If ``user_prompt`` is neither a str nor a list
    '''
    messages = []
    if sys_prompt is not None:
        messages.append({"role": "system", "content": sys_prompt})
    if isinstance(user_prompt, list):
        messages.extend(user_prompt)
    elif isinstance(user_prompt, str):
        messages.append({"role": "user", "content": user_prompt})
    else:
        raise ValueError("Invalid prompt")

    # Only fall back to EOS as the pad token when the tokenizer has none;
    # overwriting an existing pad token would silently mutate the caller's
    # tokenizer on every call.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    generated_content = model.generate(
        model_inputs.input_ids,
        # Pass the mask explicitly; otherwise generate() has to infer it and
        # warns that the attention mask was not set.
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        return_dict_in_generate=True,
        output_logits=True,
    )

    return model_inputs.input_ids, generated_content


def get_logit_from_llm(
        prompt: str | list[dict],
        model_name: str = "Qwen/Qwen2-1.5B",
        ans_parser=None,
        ans_options: list[str] | tuple[str, ...] = ("A", "B", "C", "D"),
        device: str = "cuda",
        *,
        model=None,
        tokenizer=None,
    ) -> list[float]:
    '''
    Return a score for each candidate answer to the prompt question.

    The model generates a response. The logits at one generated position
    (chosen by ``ans_parser``, or the last generated token by default) are
    L2-normalised over the vocabulary with ``F.normalize``. The entry for
    each answer option's first token is then returned. These are
    L2-normalised logits, not softmax probabilities. The normalisation
    rescales every entry by the same positive factor, so the ranking of the
    options matches the raw logits.

    @param prompt:
        The prompt question: a user message string, or a list of chat
        messages (dicts with "role"/"content" keys)
    @param model_name:
        The model to load when ``model``/``tokenizer`` are not supplied
    @param ans_parser: Function taking the decoded response text and
        returning ``(answer_idx, answer)``, where ``answer_idx`` is the index
        of the answer token among the generated tokens. Only ``answer_idx``
        is used.
    @param ans_options:
        The possible answer options as token strings. Only the first token
        of each option's encoding is scored.
    @param device:
        The device the tokenized prompt is moved to
    @param model, tokenizer:
        Keyword-only. An already-loaded model and its tokenizer; pass both
        to avoid reloading the model on every call.
    @return: One float per entry of ``ans_options``, in the same order.
    @raise ValueError: If only one of ``model``/``tokenizer`` is given.
    '''
    if (model is None) != (tokenizer is None):
        raise ValueError("Pass both model and tokenizer, or neither")
    if model is None:
        model, tokenizer = load_model(model_name)

    input_ids, response = generate_response(tokenizer, model, prompt, device=device)

    # Strip the prompt tokens so only the newly generated ones remain.
    generated_ids = [
        output_ids[len(prompt_ids):]
        for prompt_ids, output_ids in zip(input_ids, response['sequences'])
    ]

    if ans_parser is not None:
        response_txt = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        answer_idx, _ = ans_parser(response_txt)
    else:
        # Default: score the last generated token.
        answer_idx = -1

    # response['logits'] holds one (batch, vocab) tensor per generated step;
    # take the chosen step for the single (first) batch element.
    logits = F.normalize(response['logits'][answer_idx][0], dim=-1)

    # NOTE(review): only the first sub-token of each option is scored, and
    # "A" may tokenize differently from " A" as emitted mid-text; confirm
    # this matches how the model writes its answers.
    option_ids = [tokenizer.encode(ans, add_special_tokens=False)[0] for ans in ans_options]
    return [logits[token_id].item() for token_id in option_ids]
    