import requests
import json
from tqdm.auto import tqdm
import time

from .api import connect, call
from .constants import (
    VPS_CONFIGS,
    MODEL_NAME_DICT,
    OLLAMA_RPM_TPM_DICT,
    OLLAMA_PULL_TIMNEOUT,
    OLLAMA_MAX_RETRY,
    LITELLM_API_KEY,
    LITELLM_API_BASE,
    DEFALUT_MODEL_NAMES,
    TEST_PROMPT,
    TEST_NUM_CALLS,
    OPENAI_TOKENIZER,
)


def ollama_pull_model(api_base, model, timeout=OLLAMA_PULL_TIMNEOUT):
    """Pull ``model`` onto the Ollama server at ``api_base``, showing progress.

    Streams the status messages returned by ``{api_base}/api/pull`` (one JSON
    object per line) and mirrors per-layer download progress in a tqdm bar.

    Args:
        api_base: Base URL of the Ollama server, without a trailing slash.
        model: Name of the model to pull.
        timeout: Wall-clock budget in seconds for the whole pull; a value
            ``<= 0`` disables it. The budget is only checked when a new line
            arrives from the server.

    Returns:
        True once the server reports status ``"success"``; False on timeout,
        on a message carrying an ``"error"`` field, or if the stream ends
        without a success status.
    """
    start_time = time.time()
    status = None
    completed = 0
    # Both context managers guarantee cleanup on every return path: the HTTP
    # connection is released and the progress bar is closed.
    with requests.post(
        f"{api_base}/api/pull",
        json={"name": model, "stream": True},
        stream=True,
    ) as response, tqdm(total=100.0) as pbar:
        # iter_lines() re-assembles JSON lines split across network chunks and
        # splits chunks that carry several lines; raw iter_content() chunks do
        # neither, which made json.loads fail and silently drop messages.
        for line in response.iter_lines():
            if timeout > 0 and time.time() - start_time > timeout:
                print(f"Pulling model {model} from {api_base} timeout, restarting...")
                return False
            if not line:
                continue  # keep-alive / blank line
            try:
                chunk_data = json.loads(line)
            except json.JSONDecodeError:
                continue
            if "error" in chunk_data:
                print("Error: ", chunk_data["error"], flush=True)
                return False
            new_status = chunk_data.get("status", status)
            if new_status != status:
                status = new_status
                # Each "pulling <digest>" status starts a new layer, so restart
                # the bar; "pulling manifest" carries no byte progress.
                if status and status.startswith("pulling") and status != "pulling manifest":
                    pbar.reset(total=100.0)
                    completed = 0
                pbar.set_description(f"Pulling {model}: {status}")
            if status == "success":
                return True
            total = chunk_data.get("total")
            # Skip progress updates that lack a usable total (avoids KeyError
            # and division by zero).
            if "completed" in chunk_data and total:
                done = chunk_data["completed"]
                pbar.update((done - completed) / total * 100)
                completed = done
    return False


def ollama_get_avail_models(api_base):
    """Return the names of the models already present on an Ollama server.

    Queries ``{api_base}/api/tags``. Connection errors are retried, up to
    ``OLLAMA_MAX_RETRY`` attempts in total.

    Args:
        api_base: Base URL of the Ollama server, without a trailing slash.

    Returns:
        A list of the ``"name"`` fields of the server's ``"models"`` entries.

    Raises:
        ConnectionError: if every attempt fails with a connection error.
    """
    attempts = 0
    while attempts < OLLAMA_MAX_RETRY:
        try:
            payload = requests.get(f"{api_base}/api/tags").json()
        except requests.exceptions.ConnectionError:
            attempts += 1
            print(f"Connection error to {api_base}, retrying...")
            continue
        return [entry["name"] for entry in payload["models"]]
    raise ConnectionError(
        f"Connection error to {api_base} after {attempts} retries, check your vps.yaml file and network"
    )


def ollama_deploy_model(api_base, model):
    """Ask the Ollama server at ``api_base`` to load ``model``.

    Sends a ``/api/chat`` request that names the model but carries no
    messages, then checks that the reply echoes the model name with ``done``
    set.

    Args:
        api_base: Base URL of the Ollama server, without a trailing slash.
        model: Name of the model to load.

    Returns:
        True if the reply names ``model`` and has a truthy ``done``;
        False otherwise.

    Raises:
        RuntimeError: if the server replies with an ``"error"`` field.
        KeyError: if the reply has neither ``"error"`` nor the expected fields.
    """
    response = requests.post(f"{api_base}/api/chat", json={"model": model})
    payload = response.json()  # parse the body once, not once per field
    if "error" in payload:
        # Surface the server's explanation instead of a bare KeyError on "model".
        raise RuntimeError(
            f"Cannot deploy model {model} on {api_base}, error message: {payload['error']}"
        )
    return bool(payload["model"] == model and payload["done"])


def litellim_list_models():
    """List the model deployments registered on the LiteLLM proxy.

    Queries ``{LITELLM_API_BASE}/model/info`` with the bearer key
    ``LITELLM_API_KEY``. Connection errors are retried, up to
    ``OLLAMA_MAX_RETRY`` attempts in total.

    Returns:
        A list of ``(model_name, model_id)`` tuples, one per entry of the
        response's ``"data"`` list.

    Raises:
        ConnectionError: if every attempt fails with a connection error.
    """
    auth_headers = {"Authorization": "Bearer " + LITELLM_API_KEY}
    failures = 0
    while failures < OLLAMA_MAX_RETRY:
        try:
            entries = requests.get(
                f"{LITELLM_API_BASE}/model/info", headers=auth_headers
            ).json()["data"]
        except requests.exceptions.ConnectionError:
            failures += 1
            print("Connection error to litellm, retrying...")
            continue
        return [(entry["model_name"], entry["model_info"]["id"]) for entry in entries]
    raise ConnectionError(
        f"Connection error to litellm at {LITELLM_API_BASE} after {failures} retries, did you start litellm?"
    )


def litellm_delete_model(model_id):
    """Remove the deployment with id ``model_id`` from the LiteLLM proxy.

    Args:
        model_id: Deployment id, as returned in the second element of the
            tuples from ``litellim_list_models``.

    Raises:
        RuntimeError: if the proxy answers with a status other than 200.
    """
    headers = {"Authorization": "Bearer " + LITELLM_API_KEY}
    # Named ``payload`` so the module-level ``json`` import is not shadowed.
    payload = {"id": model_id}
    response = requests.post(
        f"{LITELLM_API_BASE}/model/delete", headers=headers, json=payload
    )
    if response.status_code != 200:
        raise RuntimeError(
            f"Cannot delete model {model_id} from litellm, error message: {response.text}"
        )


def litellm_add_model(api_base, model_name):
    """Register ``model_name``, served at ``api_base``, with the LiteLLM proxy.

    The LiteLLM model string is taken from ``MODEL_NAME_DICT[model_name]``.
    The rate limits come from ``OLLAMA_RPM_TPM_DICT[model_name]``: index 0 is
    used as ``rpm`` and index 1 as ``tpm``.

    Args:
        api_base: Base URL of the Ollama server hosting the model.
        model_name: Key into ``MODEL_NAME_DICT`` and ``OLLAMA_RPM_TPM_DICT``.

    Raises:
        KeyError: if ``model_name`` is missing from either dictionary.
        RuntimeError: if the proxy answers with a status other than 200.
    """
    headers = {"Authorization": "Bearer " + LITELLM_API_KEY}
    rate_limits = OLLAMA_RPM_TPM_DICT[model_name]
    # Named ``payload`` so the module-level ``json`` import is not shadowed.
    payload = {
        "model_name": model_name,
        "litellm_params": {
            "model": MODEL_NAME_DICT[model_name],
            "api_base": api_base,
            "rpm": rate_limits[0],
            "tpm": rate_limits[1],
        },
    }
    response = requests.post(
        f"{LITELLM_API_BASE}/model/new", headers=headers, json=payload
    )
    if response.status_code != 200:
        raise RuntimeError(
            f"Cannot add model {model_name} to litellm, error message: {response.text}"
        )


def _normalize_api_base(api_base):
    """Return ``api_base`` with an ``http://`` scheme if it has none and no trailing slashes."""
    if "://" not in api_base:
        api_base = "http://" + api_base
    return api_base.rstrip("/")


def _ollama_has_model(model, avail_models):
    """Return True if ``model`` is listed, treating an untagged name as ``<name>:latest``."""
    # Ollama lists pulled models with an explicit tag; an untagged pull is
    # stored as ":latest", so a bare name never matches the list exactly.
    return model in avail_models or (
        ":" not in model and f"{model}:latest" in avail_models
    )


def ollama_setup():
    """Provision every Ollama VPS in ``VPS_CONFIGS`` and register it with LiteLLM.

    Steps:
      1. Check that at least one VPS is configured.
      2. Delete existing LiteLLM deployments whose name is a key of
         ``OLLAMA_RPM_TPM_DICT``.
      3. For each VPS, pull the model if it is missing, load it, and add it
         to LiteLLM.

    Raises:
        RuntimeError: if ``VPS_CONFIGS`` is empty, or if a LiteLLM add or
            delete call fails.
        ConnectionError: if LiteLLM or a VPS cannot be reached.
    """
    # Validate before touching LiteLLM, so an empty config does not leave the
    # proxy with its deployments deleted and nothing re-added.
    if not VPS_CONFIGS:
        raise RuntimeError(
            "No VPS config found in vps.yaml, please check your config file."
        )
    for model_name, model_id in litellim_list_models():
        if model_name in OLLAMA_RPM_TPM_DICT:
            print(f"Delete model {model_name} from litellm")
            litellm_delete_model(model_id)
    for vps_dict in tqdm(VPS_CONFIGS, desc="Setting up VPS"):
        api_base = _normalize_api_base(vps_dict["api_base"])
        model_name = vps_dict["model_name"]
        # Strip only the LiteLLM provider prefix (e.g. "ollama/"), so namespaced
        # Ollama names such as "user/model" keep their namespace.
        model = MODEL_NAME_DICT[model_name].split("/", 1)[-1]
        avail_models = ollama_get_avail_models(api_base)
        print(f"Available models on {api_base}: {avail_models}")
        if not _ollama_has_model(model, avail_models):
            print(f"Pulling model {model} onto {api_base}")
            # NOTE(review): retries without bound; a pull that always fails
            # (e.g. an unknown model name) loops forever.
            while not ollama_pull_model(api_base, model):
                pass
            print(f"Model {model} pulled onto {api_base}")
        print(f"Deploying model {model} on {api_base}")
        while not ollama_deploy_model(api_base, model):
            pass
        print(f"Model {model} deployed on {api_base}")
        litellm_add_model(api_base, model_name)
    print("Ollama VPSs setup complete!")


def litellm_test_health():
    """Probe LiteLLM's liveliness endpoint once for each default model.

    For each name in ``DEFALUT_MODEL_NAMES``, GETs
    ``{LITELLM_API_BASE}/health/liveliness`` with the model name as form data.
    The model counts as healthy only if the body is exactly the JSON string
    ``"I'm alive!"``.

    Raises:
        RuntimeError: for the first model whose reply does not match.
    """
    expected_reply = '"I\'m alive!"'
    for model_name in tqdm(DEFALUT_MODEL_NAMES, desc="Testing health of models..."):
        # NOTE(review): a liveliness endpoint usually reports proxy liveness
        # only; confirm the "model" form field is honored, otherwise this does
        # not exercise the individual model.
        reply = requests.get(
            LITELLM_API_BASE + "/health/liveliness", data={"model": model_name}
        )
        if reply.text != expected_reply:
            raise RuntimeError(
                f"Model {model_name} is not healthy! Try checking config.yaml and network connection."
            )
        print(f"Model {model_name} is healthy!")


async def litellm_test_throughput():
    """Print the output-token throughput of each default model, in K tokens/min.

    For each name in ``DEFALUT_MODEL_NAMES``, this:
      1. Sends ``TEST_NUM_CALLS`` copies of ``TEST_PROMPT`` through ``call``.
      2. Counts the returned tokens with ``OPENAI_TOKENIZER``.
      3. Divides that count by the elapsed time.
    """
    client = connect()
    for model_name in DEFALUT_MODEL_NAMES:
        prompt_list = [TEST_PROMPT] * TEST_NUM_CALLS
        # perf_counter is monotonic, so the interval cannot be skewed by
        # system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        contents = await call(client, model_name, prompt_list)
        elapsed = time.perf_counter() - start_time
        # Assumes ``call`` returns an iterable of completion strings — TODO
        # confirm how failed calls are represented.
        num_tokens = sum(len(OPENAI_TOKENIZER.encode(content)) for content in contents)
        k_tokens_per_min = num_tokens / elapsed / 1000 * 60
        print(f"Throughput of {model_name}: {k_tokens_per_min:.2f}K Tokens/Min")
