import os
import torch
import string
from torch.nn.functional import softmax
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    BitsAndBytesConfig
)
from accelerate import Accelerator
from utils.util import letters

def get_base_model(llm):
    """
    Build the Hugging Face repository id for a Llama model
    llm: string, hyphen-separated short model name (e.g. 'llama-3-8b-instruct')
    returns: string, repository id of the form 'meta-llama/<name>'

    Names whose second hyphen-separated field is exactly '3' are title-cased
    and prefixed with 'Meta-'; every other name only has its first character
    upper-cased (the rest is lower-cased). A name that still ends in the
    lowercase word 'instruct' afterwards has 'instruct' rewritten as
    'Instruct'.

    Raises IndexError when the name contains no hyphen.
    """
    version_field = llm.split('-')[1]
    repo_name = 'Meta-' + llm.title() if version_field == '3' else llm.capitalize()

    # title() already yields 'Instruct', so this only fires on the
    # capitalize() path, where the suffix is still lowercase.
    if repo_name.endswith('instruct'):
        repo_name = repo_name.replace('instruct', 'Instruct')

    return "meta-llama/" + repo_name

def get_model_and_tokenizer(base_model):
    """
    Load a 4-bit quantized causal LM and its tokenizer
    base_model: string, model identifier passed to from_pretrained
    returns: tuple (model, tokenizer)
    """
    # 4-bit NF4 quantization, float16 compute, no double quantization.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Target the device whose index matches this process's local index.
    # NOTE(review): the "" key is understood to map the whole model to that
    # device -- confirm against the Hugging Face device_map documentation.
    local_index = Accelerator().local_process_index
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map={"": local_index},
    )

    return model, tokenizer

def set_up_inference_pipeline(llm):
    """
    Set up the model, tokenizer and text-generation pipeline
    llm: string, short name of the LLM (resolved with get_base_model)
    returns: tuple (model, tokenizer, inference_pipeline)

    Side effects: reads 'huggingface_api_key.txt' from the current working
    directory, stores its content in the HF_TOKEN environment variable,
    empties the CUDA cache and prints the resolved model name.
    """
    # Configure environment: the key file may wrap the token in square
    # brackets, so strip whitespace, then brackets, then whitespace again.
    with open('huggingface_api_key.txt', 'r') as key_file:
        raw_key = key_file.read()
    os.environ['HF_TOKEN'] = raw_key.strip().lstrip('[').rstrip(']').strip()

    # Release cached GPU memory before loading the model
    torch.cuda.empty_cache()

    base_model = get_base_model(llm)
    print("Setting up inference pipeline for model:", base_model)

    # Load the quantized model together with its tokenizer
    model, tokenizer = get_model_and_tokenizer(base_model)

    # Wrap both in a text-generation pipeline
    inference_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        return_full_text=False,
    )
    return model, tokenizer, inference_pipeline

def get_option_probabilities(model, tokenizer, prompt_text, options):
    """
    Score lettered answer options by next-token probability
    model: model called on the tokenized prompt; its output must expose .logits
    tokenizer: tokenizer used for both the prompt and the option labels
    prompt_text: string, prompt that the answer would follow
    options: sequence of answer options; only its length is used
    returns: dict mapping each of the first len(options) entries of `letters`
             to a float probability

    Each probability is the softmax mass, at the final prompt position, of
    the first token produced by encoding '<letter>)'. The softmax runs over
    the whole last logits dimension, so the returned values are not
    renormalised across the options and need not sum to 1.
    """
    # Tokenize the prompt and move it onto the model's device
    encoded_prompt = tokenizer(prompt_text, return_tensors="pt").to(model.device)

    # Forward pass without gradient tracking; keep the logits of the first
    # batch item at the last sequence position only.
    with torch.no_grad():
        final_position_logits = model(**encoded_prompt).logits[0, -1, :]

    # One label per option, each mapped to the first token id of "<label>)"
    labels = letters[:len(options)]
    first_token_ids = [
        tokenizer.encode(label + ")", add_special_tokens=False)[0]
        for label in labels
    ]

    next_token_probs = softmax(final_position_logits, dim=-1)

    option_probs = {}
    for label, token_id in zip(labels, first_token_ids):
        option_probs[label] = next_token_probs[token_id].item()
    return option_probs