import requests
import json
from typing import Tuple, List
import os
import time

# Baidu AI Cloud (Qianfan) application credentials, used by get_access_token().
# Prefer supplying them through the environment so they are not committed to
# source control; the literal fallbacks keep existing setups working.
# NOTE(review): these keys are already exposed in source history -- rotate them
# and drop the fallbacks once the environment variables are in place.
API_KEY = os.environ.get("QIANFAN_AK", "dEIsNwgBudf0FzZgVjPXPrH5")
SECRET_KEY = os.environ.get("QIANFAN_SK", "HZfxPt3EWmFdKDyy5mmRNeRo4n5T4O5d")

def main():
    """Send a single-turn "你好" request to the ERNIE-Speed chat endpoint and print the raw response body."""
    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed?access_token=" + get_access_token()

    payload = json.dumps({
        "messages": [
            {
                "role": "user",
                "content": "你好"
            },
        ]
    })
    headers = {
        'Content-Type': 'application/json'
    }

    # Without a timeout, requests waits indefinitely on a stalled connection.
    response = requests.post(url, headers=headers, data=payload, timeout=60)

    print(response.text)


def get_access_token(timeout: float = 30.0) -> str:
    """Exchange API_KEY / SECRET_KEY for a Baidu OAuth access token.

    Args:
        timeout: Seconds to wait for the token endpoint before giving up.

    Returns:
        The ``access_token`` string from the OAuth response.

    Raises:
        RuntimeError: if the response carries no ``access_token`` (e.g. bad
            credentials). Previously this case returned the string ``"None"``,
            which was silently embedded in request URLs.
        requests.RequestException: on network failure or timeout.
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY}
    body = requests.post(url, params=params, timeout=timeout).json()
    token = body.get("access_token")
    if not token:
        # On failure the endpoint reports the reason in error / error_description.
        detail = "{} {}".format(body.get("error", ""), body.get("error_description", "")).strip()
        raise RuntimeError("Failed to obtain Baidu access token: " + (detail or "no access_token in response"))
    return str(token)

def _build_chat_payload(turns: List[str], max_len: int, system_message: str) -> str:
    """Serialize a conversation into the JSON request body for the ERNIE chat endpoint."""
    # Turns alternate user/assistant, starting with the user.
    messages = [
        {"role": "user" if i % 2 == 0 else "assistant", "content": content}
        for i, content in enumerate(turns)
    ]
    payload = {"messages": messages, "max_output_tokens": max_len}
    if system_message:
        # ERNIE takes the system prompt as a separate field, not as a message role.
        payload["system"] = system_message
    return json.dumps(payload)


def _parse_chat_response(text: str) -> Tuple[str, int]:
    """Return ``(result, usage.total_tokens)`` from an ERNIE response body.

    Raises:
        RuntimeError: if the body reports an API error (``error_code``).
        ValueError / KeyError / TypeError: if the body is not the expected JSON shape.
    """
    body = json.loads(text)
    if "error_code" in body:
        raise RuntimeError("ERNIE API error {}: {}".format(body["error_code"], body.get("error_msg", "")))
    return body["result"], body["usage"]["total_tokens"]


def get_chat(turns: List[str], max_len: int = 200, max_tries: int = 100, system_message: str = "",
             retry_delay: float = 1.0, timeout: float = 60.0) -> Tuple[str, int]:
    """Send a conversation to ERNIE-Speed and return its reply and token usage.

    Args:
        turns: Conversation history alternating user/assistant, starting and
            ending with a user turn, so its length must be odd.
        max_len: Maximum number of tokens to generate (sent as ``max_output_tokens``).
        max_tries: Maximum number of request attempts; at least one is always made.
        system_message: Optional system prompt, sent in the ``system`` field when non-empty.
        retry_delay: Seconds to sleep between failed attempts.
        timeout: Per-request timeout in seconds passed to ``requests``.

    Returns:
        ``(reply_text, total_tokens)`` where ``total_tokens`` comes from ``usage.total_tokens``.

    Raises:
        ValueError: if ``turns`` has an even length (including empty).
        Exception: ``'Error in calling LLM.'`` once every attempt has failed;
            the last underlying error is chained as ``__cause__``.
    """
    if len(turns) % 2 == 0:
        raise ValueError("turns must have an odd length (alternating user/assistant, ending with user)")

    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed?access_token=" + get_access_token()
    headers = {
        'Content-Type': 'application/json'
    }
    payload = _build_chat_payload(turns, max_len, system_message)

    attempts = max(max_tries, 1)
    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            response = requests.post(url, headers=headers, data=payload, timeout=timeout)
            return _parse_chat_response(response.text)
        except (requests.RequestException, ValueError, KeyError, TypeError, RuntimeError) as e:
            print("Error : ", e)
            last_error = e
            if attempt < attempts:
                # Back off instead of hammering the API (e.g. on rate-limit errors).
                time.sleep(retry_delay)
    raise Exception('Error in calling LLM.') from last_error