from openai import AsyncOpenAI
import asyncio
from tqdm.asyncio import tqdm
from openai import RateLimitError, InternalServerError, BadRequestError

from .constants import (
    LITELLM_API_KEY,
    LITELLM_API_BASE,
    MAX_CONCURRENT_CALLS,
)

##### Main functions #####


def connect():
    """Create an ``AsyncOpenAI`` client from the module's configured constants.

    The client is built with ``LITELLM_API_KEY`` as the API key and
    ``LITELLM_API_BASE`` as the base URL.

    Returns:
        A new ``AsyncOpenAI`` instance.
    """
    client = AsyncOpenAI(
        base_url=LITELLM_API_BASE,
        api_key=LITELLM_API_KEY,
    )
    return client


async def call(
    client,
    model_name,
    prompt_list,
    max_concurrent=MAX_CONCURRENT_CALLS,
    max_retries=None,
):
    """Send every prompt to ``model_name`` concurrently and collect the replies.

    Each prompt is sent as a single user message through
    ``client.chat.completions.create``. At most ``max_concurrent`` requests
    are in flight at a time, and progress is shown with a tqdm bar.

    Args:
        client: Async client exposing ``chat.completions.create`` (e.g. the
            object returned by ``connect()``).
        model_name: Model identifier passed as ``model=`` to the API.
        prompt_list: Sequence of prompt strings, one request per entry.
        max_concurrent: Upper bound on simultaneous in-flight requests.
        max_retries: Maximum number of retries per prompt after a rate-limit
            or internal-server error. ``None`` (the default) retries
            indefinitely.

    Returns:
        A list with one entry per prompt, as produced by ``tqdm.gather``.
        Each entry is the message content of the first choice, or ``""`` when
        the request failed with a bad-request error, an unexpected error, or
        exhausted ``max_retries``.
    """
    # Created before the worker so the closure's dependency is obvious.
    semaphore = asyncio.Semaphore(max_concurrent)

    async def single_call(prompt):
        retries = 0
        async with semaphore:
            while True:
                try:
                    response = await client.chat.completions.create(
                        model=model_name,
                        messages=[{"role": "user", "content": prompt}],
                    )
                    return response.choices[0].message.content
                except RateLimitError:
                    print("Rate limit exceeded, waiting for 1 second")
                    delay = 1
                except InternalServerError:
                    print("Internal server error")
                    # Back off instead of retrying immediately: a tight retry
                    # loop hammers a server that is already failing.
                    delay = min(2**retries, 30)
                except BadRequestError:
                    # Retrying an invalid request cannot succeed.
                    print("Bad request")
                    return ""
                except Exception as e:
                    print(f"Unexpected Error: {e}")
                    return ""

                # Only retryable errors reach this point.
                retries += 1
                if max_retries is not None and retries > max_retries:
                    print(f"Giving up after {max_retries} retries")
                    return ""
                await asyncio.sleep(delay)

    tasks = [single_call(prompt) for prompt in prompt_list]
    contents = await tqdm.gather(
        *tasks, desc=f"Calling {model_name}", total=len(prompt_list)
    )
    return contents
