def get_memory_constrained_batch_size(length: int, llm_name: str) -> int:
    """Estimate the batch size that fits in memory for inputs of ``length``.

    The estimate follows the fitted inverse curve ``a / (length + b)``, where
    ``a`` and ``b`` are looked up with ``get_inverse_function_params(llm_name)``.
    The float result is truncated toward zero by ``int``, so for long enough
    inputs the returned batch size can be 0; callers should be prepared for
    that.

    Any exception raised by ``get_inverse_function_params`` (e.g. for an
    unrecognised model name) propagates unchanged.
    """
    numerator, offset = get_inverse_function_params(llm_name)
    denominator = length + offset
    estimate = numerator / denominator
    return int(estimate)


def get_inverse_function_params(llm_name: str) -> tuple[float, float]:
    """Return the fitted ``(a, b)`` parameters of the batch-size curve for a model.

    The curve is ``batch_size ≈ a / (length + b)``.

    NOTE: these parameters are computed by fitting an inverse function to data
    generated by benchmark_batch_size.py.

    Args:
        llm_name: Model identifier. Exact names are matched first; any name
            containing ``"Mistral-7B"`` falls back to the Llama-3-8B-Instruct
            parameters.

    Returns:
        The ``(a, b)`` pair for the model.

    Raises:
        ValueError: If ``llm_name`` is not recognised. This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still work.
    """
    llama3_instruct_params = (61562.069, 2.058)
    exact_matches: dict[str, tuple[float, float]] = {
        "sft10k": (53288.568, 9.164),
        "alpaca-7b": (53288.568, 9.164),
        "Meta-Llama-3-8B": (61626.403, 2.076),
        "Meta-Llama-3-8B-Instruct": llama3_instruct_params,
    }
    params = exact_matches.get(llm_name)
    if params is not None:
        return params
    # Mistral-7B variants (e.g. instruct/versioned names) share the
    # Llama-3-8B-Instruct fit; matched by substring as in the original lookup.
    if "Mistral-7B" in llm_name:
        return llama3_instruct_params
    raise ValueError(f"Unknown LLM name: {llm_name!r}")
