"""Helpers for calling LLM APIs: Anthropic (Claude), OpenAI, and the Stanford CRFM/HELM proxy."""

import os
from functools import partial
import tiktoken
# from schema import TooLongPromptError, LLMError

enc = tiktoken.get_encoding("cl100k_base")


def _read_key_file(path):
    """Return the whitespace-stripped contents of the API key file at ``path``.

    The file is read inside a ``with`` block so the handle is closed immediately
    rather than being left open until garbage collection.
    """
    with open(path) as key_file:
        return key_file.read().strip()


# Each provider is set up independently: a missing package or key file only
# prints a message, and the names bound in that block stay undefined.
try:
    from helm.common.authentication import Authentication
    from helm.common.request import Request, RequestResult
    from helm.proxy.accounts import Account
    from helm.proxy.services.remote_service import RemoteService
    # setup CRFM API
    auth = Authentication(api_key=_read_key_file("crfm_api_key.txt"))
    service = RemoteService("https://crfm-models.stanford.edu")
    account: Account = service.get_account(auth)
except Exception as e:
    print(e)
    print("Could not load CRFM API key crfm_api_key.txt.")

try:
    import anthropic
    # setup anthropic API key
    anthropic_client = anthropic.Anthropic(api_key=_read_key_file("claude_api_key.txt"))
except Exception as e:
    print(e)
    print("Could not load anthropic API key claude_api_key.txt.")

try:
    import openai
    from openai import OpenAI
    # The key file holds "organization:api_key"; the organization is parsed
    # but not passed to the client.
    organization, api_key = _read_key_file("openai_api_key.txt").split(":")
    os.environ["OPENAI_API_KEY"] = api_key
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
except Exception as e:
    print(e)
    print("Could not load OpenAI API key openai_api_key.txt.")


def log_to_file(log_file, prompt, completion, model, max_tokens_to_sample):
    """Append a prompt/completion pair and their token counts to a log file.

    The prompt is recorded wrapped in ``anthropic.HUMAN_PROMPT`` /
    ``anthropic.AI_PROMPT`` whichever backend produced the completion, so this
    requires the ``anthropic`` package to have been imported. Token counts use
    the module-level ``cl100k_base`` tiktoken encoding for every model.

    Args:
        log_file: Path of the log file; it is opened in append mode.
        prompt: The prompt text that was sent.
        completion: The text returned by the model.
        model: Model name, written into the response header.
        max_tokens_to_sample: Requested token limit, written into the header.
    """
    formatted_prompt = f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}"
    # disallowed_special=() makes tiktoken treat text such as "<|endoftext|>"
    # as ordinary characters instead of raising ValueError, so logging cannot
    # crash on prompts or model output that happen to contain it.
    num_prompt_tokens = len(enc.encode(formatted_prompt, disallowed_special=()))
    num_sample_tokens = len(enc.encode(completion, disallowed_special=()))
    # Explicit UTF-8 so non-ASCII model output can be written on any locale.
    with open(log_file, "a", encoding="utf-8") as f:
        f.write("\n===================prompt=====================\n")
        f.write(formatted_prompt)
        f.write(f"\n==================={model} response ({max_tokens_to_sample})=====================\n")
        f.write(completion)
        f.write("\n===================tokens=====================\n")
        f.write(f"Number of prompt tokens: {num_prompt_tokens}\n")
        f.write(f"Number of sampled tokens: {num_sample_tokens}\n")
        f.write("\n\n")


def complete_text_claude(prompt, stop_sequences=None, model="claude-v1", max_tokens_to_sample = 2000, temperature=0.5, log_file=None, **kwargs):
    """Call the Claude API to complete a prompt.

    Model names starting with ``claude-3`` use the Messages API. Every other
    name uses the legacy Text Completions API, where the prompt is wrapped as
    ``HUMAN_PROMPT <prompt> <ai_prompt>``.

    Args:
        prompt: Prompt text.
        stop_sequences: Strings that end generation. Defaults to
            ``[anthropic.HUMAN_PROMPT]``.
        model: Anthropic model name.
        max_tokens_to_sample: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        log_file: If given, the prompt and completion are appended to this
            file via ``log_to_file``.
        **kwargs: ``ai_prompt`` overrides the assistant marker appended to the
            legacy prompt. Everything else is forwarded to the legacy
            ``completions.create`` call and is not sent for ``claude-3`` models.

    Returns:
        The completion text.

    Raises:
        SystemExit: with status 1 if the legacy completions request fails.
            Errors from the Messages API propagate to the caller unchanged.
    """
    # Resolved at call time: a list default would be shared between calls and
    # would require ``anthropic`` to be importable just to define this function.
    if stop_sequences is None:
        stop_sequences = [anthropic.HUMAN_PROMPT]
    # pop() also keeps ``ai_prompt`` out of the kwargs forwarded to the API.
    ai_prompt = kwargs.pop("ai_prompt", anthropic.AI_PROMPT)

    if model.startswith("claude-3"):
        messages = [
            {'role': 'user', 'content': f"{anthropic.HUMAN_PROMPT} {prompt}"}
        ]
        rsp = anthropic_client.messages.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens_to_sample,
            # Previously dropped for claude-3, which ignored the caller's
            # temperature and stop sequences (e.g. "Observation:").
            temperature=temperature,
            stop_sequences=stop_sequences,
        )
        # The reply is a list of content blocks that may be empty, so join the
        # text blocks instead of indexing [0].
        completion = "".join(block.text for block in rsp.content if block.type == "text")
    else:
        try:
            rsp = anthropic_client.completions.create(
                prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {ai_prompt}",
                stop_sequences=stop_sequences,
                model=model,
                temperature=temperature,
                max_tokens_to_sample=max_tokens_to_sample,
                **kwargs
            )
        except Exception as e:
            # Report the error and stop with a failure status. The previous
            # bare exit() reported success (status 0) to the calling process.
            print(e)
            raise SystemExit(1) from e
        completion = rsp.completion

    if log_file is not None:
        log_to_file(log_file, prompt, completion, model, max_tokens_to_sample)
    return completion


def get_embedding_crfm(text, model="openai/gpt-4-0314"):
    """Return the embedding vector the CRFM/HELM service computes for ``text``.

    NOTE(review): ``model`` is currently ignored; the request always uses
    "openai/text-similarity-ada-001". Confirm before relying on the argument.
    """
    embedding_request = Request(
        model="openai/text-similarity-ada-001",
        prompt=text,
        embedding=True,
    )
    return service.make_request(auth, embedding_request).embedding

def complete_text_crfm(prompt=None, stop_sequences = None, model="openai/gpt-4-0314",  max_tokens_to_sample=2000, temperature = 0.5, log_file=None, messages = None, **kwargs):
    """Complete a prompt through the Stanford CRFM (HELM) proxy service.

    Args:
        prompt: Prompt text; it is sent even when ``messages`` is given.
        stop_sequences: Strings that end generation, or ``None``.
        model: HELM model name, e.g. ``"openai/gpt-4-0314"``.
        max_tokens_to_sample: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        log_file: If given, the prompt and completion are appended to this
            file via ``log_to_file``. The value is also sent as the request's
            ``random`` field.
        messages: Optional chat messages; when truthy they are added to the
            request.
        **kwargs: Accepted for interface parity with the other
            ``complete_text_*`` functions; currently ignored.

    Returns:
        The text of the first completion.

    Raises:
        SystemExit: with status 1 if ``service.make_request`` raises.
        RuntimeError: if the service reports the request as unsuccessful.
    """
    # NOTE(review): HELM's ``random`` field distinguishes otherwise identical
    # requests. Keying it on the log file presumably keeps different runs from
    # sharing cached completions; confirm the intent.
    request_fields = {
        "prompt": prompt,
        "model": model,
        "stop_sequences": stop_sequences,
        "temperature": temperature,
        "max_tokens": max_tokens_to_sample,
        "random": log_file,
    }
    if messages:
        request_fields["messages"] = messages
    else:
        print("model", model)
        print("max_tokens", max_tokens_to_sample)
    request = Request(**request_fields)

    try:
        request_result: RequestResult = service.make_request(auth, request)
    except Exception as e:
        # Probably a too-long prompt. Report it and stop with a failure status;
        # a bare exit() would report success to the calling process.
        print(e)
        raise SystemExit(1) from e

    if not request_result.success:
        # The error is carried by the result, not by the Request object. Raise
        # here instead of falling through to index a possibly empty
        # completions list.
        raise RuntimeError(f"CRFM request failed: {request_result.error}")
    completion = request_result.completions[0].text
    if log_file is not None:
        log_to_file(log_file, prompt, completion, model, max_tokens_to_sample)
    return completion


def complete_text_openai(prompt, stop_sequences=None, model="gpt-3.5-turbo", max_tokens_to_sample=2000, temperature=0.5, log_file=None, **kwargs):
    """Call the OpenAI API to complete a prompt.

    Model names starting with ``gpt-3.5``, ``gpt-4`` or ``o1`` use the Chat
    Completions endpoint, with ``prompt`` as a single user message. Any other
    name uses the legacy Completions endpoint.

    Args:
        prompt: Prompt text.
        stop_sequences: Strings that end generation. ``None`` or empty sends no
            ``stop`` parameter.
        model: OpenAI model name.
        max_tokens_to_sample: Maximum number of tokens to generate, sent as
            ``max_tokens``.
        temperature: Sampling temperature.
        log_file: If given, the prompt and completion are appended to this
            file via ``log_to_file``.
        **kwargs: Extra request parameters forwarded verbatim to the API call.

    Returns:
        The completion text (``""`` if the chat response carries no content).
    """
    raw_request = {"model": model}
    # o1 reasoning models reject a custom temperature, ``stop`` and
    # ``max_tokens``, so sampling controls are sent only to other models.
    # Previously they were dropped for every model, which silently ignored the
    # caller's settings and the "Observation:" stop that complete_text passes.
    if not model.startswith("o1"):
        raw_request["temperature"] = temperature
        raw_request["max_tokens"] = max_tokens_to_sample
        if stop_sequences:  # the API doesn't accept an empty stop list
            raw_request["stop"] = stop_sequences
    raw_request.update(kwargs)

    if model.startswith(("gpt-3.5", "gpt-4", "o1")):
        # Requires openai>=1.x (client object API)
        messages = [{"role": "user", "content": prompt}]
        response = client.chat.completions.create(**{"messages": messages, **raw_request})
        # message.content can be null (e.g. on refusals); normalise to "" so
        # logging and callers always receive a string.
        completion = response.choices[0].message.content or ""
    else:
        response = client.completions.create(**{"prompt": prompt, **raw_request})
        completion = response.choices[0].text
    if log_file is not None:
        log_to_file(log_file, prompt, completion, model, max_tokens_to_sample)
    return completion

def complete_text(prompt, log_file, model, **kwargs):
    """Send ``prompt`` to the API backend implied by the ``model`` name.

    Routing rules:
      * names starting with ``claude`` go to the Anthropic API and stop at the
        human-turn marker or ``"Observation:"``;
      * names containing ``/`` (an organization prefix such as
        ``"openai/gpt-4-0314"``) go to the CRFM/HELM proxy;
      * every other name goes to the OpenAI API.
    The CRFM and OpenAI backends stop at ``"Observation:"``. Extra keyword
    arguments are forwarded unchanged to the selected backend.
    """
    if model.startswith("claude"):
        return complete_text_claude(
            prompt,
            stop_sequences=[anthropic.HUMAN_PROMPT, "Observation:"],
            log_file=log_file,
            model=model,
            **kwargs,
        )
    backend = complete_text_crfm if "/" in model else complete_text_openai
    return backend(
        prompt,
        stop_sequences=["Observation:"],
        log_file=log_file,
        model=model,
        **kwargs,
    )

# specify fast models for summarization etc
FAST_MODEL = "claude-v1"  # model used for cheap auxiliary calls


def complete_text_fast(prompt, **kwargs):
    """Complete ``prompt`` with ``FAST_MODEL`` at near-zero temperature (0.01).

    Remaining keyword arguments (e.g. ``log_file``) are passed to
    ``complete_text``.
    """
    fast_settings = {"model": FAST_MODEL, "temperature": 0.01}
    return complete_text(prompt=prompt, **fast_settings, **kwargs)
# complete_text_fast = partial(complete_text_openai, temperature= 0.01)

# import anthropic
# from helm.common.authentication import Authentication
# from helm.common.perspective_api_request import PerspectiveAPIRequest, PerspectiveAPIRequestResult
# from helm.common.request import Request, RequestResult
# from helm.common.tokenization_request import TokenizationRequest, TokenizationRequestResult
# from helm.proxy.accounts import Account
# from helm.proxy.services.remote_service import RemoteService
# from functools import partial
# from transformers import GPT2TokenizerFast
# import os
# import openai
# import time
# # setup OpenAI API key
# openai.api_key  =  open("openai_api_key.txt").read().strip()  
# os.environ["OPENAI_API_KEY"] = openai.api_key 

# # An example of how to use the request API.
# try:
#     auth = Authentication(api_key=open("gpt4_api_key.txt").read().strip())
#     # auth = Authentication(api_key="benchmarking-123")

#     service = RemoteService("https://crfm-models.stanford.edu")

#     # Access account and show my current quotas and usages
#     account: Account = service.get_account(auth)
# except:
#     pass
# # Make a request
# # request = Request(model="openai/gpt-4-0314", prompt="Life is like a box of")
# # request_result: RequestResult = service.make_request(auth, request)
# # print(request_result.completions[0].text)
# c = anthropic.Client(open("claude_api_key.txt").read().strip())
# # c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

# tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# def log_to_file(log_file, prompt, completion):
#     with open(log_file, "a") as f:
#         f.write("\n===================prompt=====================\n")
#         f.write(f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}")
#         num_prompt_tokens = len(tokenizer(f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}")["input_ids"])
#         f.write("\n===================response=====================\n")
#         f.write(completion)
#         num_sample_tokens = len(tokenizer(completion)["input_ids"])
#         f.write("\n===================tokens=====================\n")
#         f.write(f"Number of prompt tokens: {num_prompt_tokens}\n")
#         f.write(f"Number of sampled tokens: {num_sample_tokens}\n")
#         f.write("\n\n")


# def complete_text(prompt, stop_sequences=[anthropic.HUMAN_PROMPT],
#                          model=None, max_tokens_to_sample = 2000,
#                          temperature=0.5, log_file=None, **kwargs):
#     ai_prompt = anthropic.AI_PROMPT
#     if "ai_prompt" in kwargs is not None:
#         ai_prompt = kwargs["ai_prompt"]

#     model_keys={
#         'claude': 'claude-v1',
#         'gpt3.5': 'openai/gpt-3.5-turbo-0301'
#     }

#     resp = c.completion(
#         prompt=f"{anthropic.HUMAN_PROMPT}{prompt}{ai_prompt}",
#         stop_sequences=stop_sequences,
#         model=model_keys[model],
#         temperature1=temperature,
#         max_tokens_to_sample=max_tokens_to_sample,
#         **kwargs
#     )
#     completion = resp["completion"]
#     if log_file is not None:
#         log_to_file(log_file, prompt, completion)
#     return completion


# def complete_text_claude(prompt, stop_sequences=[anthropic.HUMAN_PROMPT],
#                          model="claude-v1", max_tokens_to_sample = 100000,
#                          temperature=0.5, log_file=None, **kwargs):
#     ai_prompt = anthropic.AI_PROMPT
#     if "ai_prompt" in kwargs is not None:
#         ai_prompt = kwargs["ai_prompt"]

#     model = "claude-v2"
#     resp = c.completion(
#         prompt=f"{anthropic.HUMAN_PROMPT}{prompt}{ai_prompt}",
#         stop_sequences=stop_sequences,
#         model=model,
#         temperature1=temperature,
#         max_tokens_to_sample=max_tokens_to_sample,
#         **kwargs
#     )
#     completion = resp["completion"]
#     if log_file is not None:
#         log_to_file(log_file, prompt, completion)
#     return completion


# def complete_text_gpt4_temp(prompt, stop_sequences = None,
#                        model="openai/gpt-4-0314",  max_tokens_to_sample=2000, temperature = 0.5, log_file=None, **kwargs):

#     # request = Request(
#     #         prompt=prompt, 
#     #         model=model, 
#     #         stop_sequences=stop_sequences,
#     #         temperature = temperature,
#     #         max_tokens = max_tokens_to_sample,
#     #     )
#     # request_result: RequestResult = service.make_request(auth, request)
#     # completion = request_result.completions[0].text
#     # if log_file is not None:
#     #     log_to_file(log_file, prompt, completion)


#     return completion

# def complete_text_gpt4(prompt, stop_sequences=[], model="gpt-4", max_tokens_to_sample=1024, temperature=0.2, log_file=None, **kwargs):
#     """ Call the OpenAI API to complete a prompt."""
#     raw_request = {
#           "model": model,
#           "temperature": temperature,
#           "max_tokens": max_tokens_to_sample,
#           "stop": stop_sequences or None,  # API doesn't like empty list
#           **kwargs
#     }
#     time.sleep(1)
#     if model.startswith("gpt-3.5") or model.startswith("gpt-4"):
#         messages = [{"role": "user", "content": prompt}]
#         response = openai.ChatCompletion.create(**{"messages": messages,**raw_request})
#         completion = response["choices"][0]["message"]["content"]
#     else:
#         response = openai.Completion.create(**{"prompt": prompt,**raw_request})
#         completion = response["choices"][0]["text"]
#     if log_file is not None:
#         log_to_file(log_file, prompt, completion)
#     return completion

# # complete_text_fast = complete_text_claude
# complete_text_fast = partial(complete_text_claude, temperature= 0.01)
# complete_text_slow = complete_text_gpt4
