import hf_olmo
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import list_repo_refs

def _build_step_to_ckpt_dict():
    """Map training-step numbers to ``allenai/OLMo-7B`` checkpoint branches.

    Lists the branches of the ``allenai/OLMo-7B`` repository and, for every
    branch whose name contains ``step<digits>-``, records the integer step
    as the key and the full branch name as the value.

    Branches that do not follow that naming scheme (``main`` included) are
    skipped rather than raising, so an unexpected branch on the repository
    cannot break module import.

    Returns:
        dict[int, str]: step number -> branch name.
    """
    refs = list_repo_refs("allenai/OLMo-7B")
    step_to_ckpt = {}
    for ref in refs.branches:
        # Only digits are captured so that extra hyphens later in the branch
        # name cannot leak into the value passed to int().
        match = re.search(r'step(\d+)-', ref.name)
        if match is None:
            continue
        step_to_ckpt[int(match.group(1))] = ref.name
    return step_to_ckpt

# Step-number -> checkpoint-branch lookup table. Built once, when this module
# is imported; the builder calls list_repo_refs, so importing this module
# queries the Hugging Face Hub.
STEP_TO_CKPT = _build_step_to_ckpt_dict()

# Adapted from the instructions at https://huggingface.co/allenai/OLMo-7B-SFT
def prompt_model(model_name, prompt=None, max_tokens=None, temperature=None):
    """Generate text from the OLMo-7B model for a single prompt.

    Args:
        model_name: Accepted for interface compatibility but currently
            unused; the ``allenai/OLMo-7B`` default revision is always loaded.
        prompt: Text to condition generation on. Required.
        max_tokens: Maximum number of new tokens to generate. ``None`` keeps
            the previous fixed budget of 100.
        temperature: Sampling temperature. ``None`` leaves the temperature
            unset (library default); ``0`` selects greedy decoding instead of
            sampling, since a zero temperature is not valid for sampling.

    Returns:
        The first decoded sequence returned by ``generate``, with special
        tokens stripped.

    Raises:
        ValueError: If ``prompt`` is ``None``, ``max_tokens`` is not
            positive, or ``temperature`` is negative.
    """
    # Validate everything up front so bad arguments fail before the
    # (expensive) model load below.
    if prompt is None:
        raise ValueError("Prompt must be specified")
    if max_tokens is not None and max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer")
    if temperature is not None and temperature < 0:
        raise ValueError("temperature must be non-negative")

    generation_kwargs = {
        "max_new_tokens": 100 if max_tokens is None else max_tokens,
    }
    if temperature == 0:
        # Greedy decoding; the sampling-only knobs are deliberately omitted.
        generation_kwargs["do_sample"] = False
    else:
        generation_kwargs.update(do_sample=True, top_k=50, top_p=0.95)
        if temperature is not None:
            generation_kwargs["temperature"] = temperature

    # Load OLMo-7B model and tokenizer from pre-trained checkpoint
    # If we get access to the intermediate checkpoints, we can just add revision=<CKPT> to load them instead.
    olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-7B")
    tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B")

    # Tokenize prompt
    inputs = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False)

    # Place inputs and model on CUDA
    inputs = {k: v.to('cuda') for k, v in inputs.items()}
    olmo = olmo.to('cuda')

    # Generate response.
    response = olmo.generate(**inputs, **generation_kwargs)
    decoded_text = tokenizer.batch_decode(response, skip_special_tokens=True)[0]

    return decoded_text
    
