import os
import cohere
import numpy as np
from typing import List
from tqdm.auto import tqdm
from dotenv import load_dotenv

# Populate os.environ from a .env file (when one is found) so the API key
# can be supplied without exporting it in the shell.
load_dotenv()

# Module-wide Cohere client shared by the embedding helpers below.
co = cohere.Client(os.environ.get("COHERE_API_KEY"))


def get_cohere_embedding(
    prompt: str, model="embed-english-light-v3.0", input_type="search_document"
):
    """Embed a single text with Cohere and return its vector as a numpy array.

    The prompt is sent to ``co.embed`` as a one-item ``texts`` list, and the
    first (only) embedding in the response is converted with ``np.array``.
    No ``truncate`` argument is passed, so the API's default applies.

    Args:
        prompt: The text to embed.
        model: Cohere embedding model name.
        input_type: Cohere ``input_type`` value for the request.
    """
    result = co.embed(
        model=model,
        input_type=input_type,
        texts=[prompt],
    )
    vectors = result.embeddings
    return np.array(vectors[0])


def get_cohere_embeddings_batched(
    prompts: List[str],
    batch_size=96,
    model="embed-english-light-v3.0",
    input_type="search_document",
    pbar=True,
):
    """Embed many texts with Cohere, sending them in fixed-size batches.

    Args:
        prompts: Texts to embed; output rows follow this order.
        batch_size: Number of texts per ``co.embed`` request. Must be in
            ``[1, 96]`` (96 is Cohere's per-request limit).
        model: Cohere embedding model name.
        input_type: Cohere ``input_type`` value for every request.
        pbar: Whether to show a tqdm progress bar over the batches.

    Returns:
        A numpy array built by concatenating each batch's embeddings, with
        one row per prompt in input order. When ``prompts`` is empty, no
        request is made and an empty array of shape ``(0, 0)`` is returned.

    Raises:
        ValueError: If ``batch_size`` is outside ``[1, 96]``.
    """
    # An explicit check rather than ``assert`` so validation survives
    # ``python -O``. It also rejects 0 and negative sizes, which previously
    # divided by zero or produced nonsensical slices.
    if not 1 <= batch_size <= 96:
        raise ValueError(
            f"cohere limits batch size at 96; batch_size must be in [1, 96], "
            f"got {batch_size}"
        )

    if not prompts:
        # Don't send an empty request. The embedding width is unknown without
        # a response, hence (0, 0).
        return np.empty((0, 0))

    batches = []
    # Stepping the start index by batch_size visits every prompt exactly once.
    # This replaces the fragile ``int(n / batch_size + 0.99)`` ceiling.
    for start in tqdm(range(0, len(prompts), batch_size), disable=not pbar):
        chunk = prompts[start : start + batch_size]
        # truncate="END": per Cohere's docs, over-long inputs are cut at the
        # end instead of causing the request to fail.
        response = co.embed(
            texts=chunk, model=model, input_type=input_type, truncate="END"
        )
        batches.append(response.embeddings)
    return np.concatenate(batches)
