
import torch

def format_answer(tokenizer, answer_ids):
    """Decode the first sequence of ``answer_ids`` into text.

    Only ``answer_ids[0]`` is decoded; any further sequences are ignored.
    Special tokens are skipped by passing ``skip_special_tokens=True`` to
    ``tokenizer.decode``.
    """
    first_sequence = answer_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)


def compare(model, tokenizer, input_ids, max_new_tokens,
            top_k=None, top_p=None, min_p=None, num_beams=1, greedy=False, **kwargs):
    """Run ``model.generate`` once and return the first output sequence as text.

    At most one of ``top_k``, ``top_p`` and ``min_p`` may be given; supplying
    any of them turns sampling on. ``greedy=True`` forces ``num_beams`` to 1.
    If a sampling option is supplied together with ``greedy=True``, a warning
    is printed and sampling is still used.

    Args:
        model: Object exposing ``generate(input_ids, **kwargs)``.
        tokenizer: Object exposing ``pad_token_id``, ``eos_token_id`` and
            ``decode``.
        input_ids: Prompt ids, passed to ``model.generate`` unchanged.
        max_new_tokens: Forwarded to ``model.generate``.
        top_k, top_p, min_p: Mutually exclusive sampling options.
        num_beams: Forwarded to ``model.generate`` unless ``greedy`` is set.
        greedy: When truthy, ``num_beams`` is overridden to 1.
        **kwargs: Extra keyword arguments forwarded to ``model.generate``.

    Returns:
        The value returned by ``format_answer`` for the generated ids.

    Raises:
        AssertionError: If more than one of ``top_k``, ``top_p`` and
            ``min_p`` is not ``None``.
    """
    sampling_options = (top_k, top_p, min_p)
    options_given = len([opt for opt in sampling_options if opt is not None])
    assert options_given <= 1, "Can only set at most one of top_k, top_p, or min_p"

    if greedy:
        # Greedy decoding uses a single beam.
        num_beams = 1

    # Any explicit sampling option wins over the greedy flag.
    sampling_requested = options_given > 0
    if sampling_requested and greedy:
        print("WARNING: greedy=True not compatible with sampling config, setting to False")

    # Fall back to the EOS id only when the tokenizer defines no pad id.
    if tokenizer.pad_token_id is None:
        pad_value = tokenizer.eos_token_id
    else:
        pad_value = tokenizer.pad_token_id

    generation_settings = dict(
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        top_p=top_p,
        min_p=min_p,
        do_sample=sampling_requested,
        num_beams=num_beams,
        pad_token_id=pad_value,
        eos_token_id=tokenizer.eos_token_id,
    )

    with torch.inference_mode():
        # Caller-supplied kwargs are forwarded alongside the settings above.
        generated = model.generate(input_ids, **generation_settings, **kwargs)

    return format_answer(tokenizer, generated)


def compare_batch(model, tokenizer, input_ids, max_new_tokens,
                  top_k=None, top_p=None, min_p=None, num_beams=1, greedy=False,
                  attention_mask=None, **kwargs):
    """Run ``model.generate`` on a batch and decode every output sequence.

    At most one of ``top_k``, ``top_p`` and ``min_p`` may be given; supplying
    any of them turns sampling on. ``greedy=True`` forces ``num_beams`` to 1.
    If a sampling option is supplied together with ``greedy=True``, a warning
    is printed and sampling is still used.

    Args:
        model: Object exposing ``generate(input_ids, **kwargs)``.
        tokenizer: Object exposing ``pad_token_id``, ``eos_token_id`` and
            ``batch_decode``.
        input_ids: Batch of prompt ids, passed to ``model.generate`` unchanged.
        max_new_tokens: Forwarded to ``model.generate``.
        top_k, top_p, min_p: Mutually exclusive sampling options.
        num_beams: Forwarded to ``model.generate`` unless ``greedy`` is set.
        greedy: When truthy, ``num_beams`` is overridden to 1.
        attention_mask: Forwarded to ``model.generate`` (may be ``None``).
        **kwargs: Extra keyword arguments forwarded to ``model.generate``.

    Returns:
        The result of ``tokenizer.batch_decode`` on the generated sequences,
        with ``skip_special_tokens=True``.

    Raises:
        AssertionError: If more than one of ``top_k``, ``top_p`` and
            ``min_p`` is not ``None``.
    """
    configurable = [top_k, top_p, min_p]
    assert sum(x is not None for x in configurable) <= 1, "Only one of top_k, top_p, min_p can be set"

    if greedy:
        # Greedy decoding uses a single beam.
        num_beams = 1

    # Any explicit sampling option wins over the greedy flag.
    do_sample = any(option is not None for option in configurable)
    if do_sample and greedy:
        print("WARNING: greedy=True not compatible with sampling, setting to False")

    # Compare against None explicitly: a pad id of 0 is valid but falsy, so
    # `pad_token_id or eos_token_id` would wrongly replace it with the EOS id.
    pad_value = tokenizer.pad_token_id
    if pad_value is None:
        pad_value = tokenizer.eos_token_id

    with torch.inference_mode():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            do_sample=do_sample,
            num_beams=num_beams,
            pad_token_id=pad_value,
            eos_token_id=tokenizer.eos_token_id,
            **kwargs
        )

    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Drop the generated ids before asking torch to release cached CUDA memory.
    del outputs
    torch.cuda.empty_cache()

    return decoded