# https://developers.generativeai.google/api

import google.generativeai as genai

from vertexai.language_models import CodeGenerationModel
# https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.language_models.CodeGenerationModel
from google.generativeai.types import generation_types

import json

from google.cloud import aiplatform

# Initialise the Vertex AI SDK with the "nl-to-sql-model-eval" project.
# This runs at import time, before any helper below is called.
aiplatform.init(project="nl-to-sql-model-eval")

# Sample prompt sent by the ``__main__`` smoke test at the bottom of the file.
test_prompt = "\nwrite a hello world program in Rust\n"

def call_vertex(
        prompt: str,
        model: str = "gemini-1.5-pro-latest"
) -> str:
    """Send ``prompt`` to a Gemini model and return the reply text.

    The API key and sampling settings are read from ``.local/vertex.json``,
    which must provide the keys ``api_key``, ``temperature``, ``top_p``,
    ``top_k`` and ``max_tokens``.

    Args:
        prompt: Message sent as the first turn of a fresh chat session.
        model: Name of the Gemini model to use. Earlier versions ignored
            this argument and always used ``"gemini-1.5-pro-latest"``; the
            default keeps that behaviour for existing callers.

    Returns:
        The text of the model's reply. If the SDK raises
        ``generation_types.StopCandidateException``, the exception is
        printed and a message string is returned instead of raising.

    Raises:
        OSError: if the config file cannot be opened.
        json.JSONDecodeError: if the config file is not valid JSON.
        KeyError: if a required config key is missing.
    """
    # ``with`` guarantees the config file is closed even if parsing fails.
    with open(".local/vertex.json", encoding="utf-8") as vertex_info_file:
        vertex_info = json.load(vertex_info_file)
    genai.configure(api_key=vertex_info["api_key"])

    generation_config = {
        "temperature": vertex_info["temperature"],
        "top_p": vertex_info["top_p"],
        "top_k": vertex_info["top_k"],
        "max_output_tokens": vertex_info["max_tokens"],
    }

    safety_settings = [
        {
            "category": "HARM_CATEGORY_HARASSMENT",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE"
        },
        {
            "category": "HARM_CATEGORY_HATE_SPEECH",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE"
        },
        {
            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "threshold": "BLOCK_NONE"
        },
        {
            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
            "threshold": "BLOCK_MEDIUM_AND_ABOVE"
        },
    ]

    # Honour the ``model`` argument rather than a hard-coded name, and use a
    # separate variable so the parameter is not shadowed.
    generative_model = genai.GenerativeModel(model_name=model,
                                             generation_config=generation_config,
                                             safety_settings=safety_settings)

    convo = generative_model.start_chat(history=[])
    try:
        convo.send_message(prompt)
    except generation_types.StopCandidateException as e:
        print(e)
        # Python 3 exceptions carry no ``message`` attribute by default, so
        # fall back to a fixed string when none was attached.
        return getattr(
            e, "message", "Encountered exception without a message attribute."
        )
    return convo.last.text


def call_google_palm(
        prompt,
        max_attempts=10,
        model='text-bison',
        max_tokens=800,
        verbose=True
) -> str:
    """Generate text (typically SQL) with a Google PaLM text model.

    The API key, ``temperature`` and ``top_p`` are read from
    ``.local/googlepalm.json``.

    Args:
        prompt: Text prompt passed to ``genai.generate_text``.
        max_attempts: Accepted for signature compatibility; currently unused.
        model: Name of the PaLM text model.
        max_tokens: Maximum number of tokens to generate.
        verbose: Accepted for signature compatibility; currently unused.

    Returns:
        The generated text with Markdown code fences (```` ```sql ```` and
        ```` ``` ````) removed and surrounding whitespace stripped, or
        ``""`` if the API returned no result.

    Raises:
        OSError: if the config file cannot be opened.
        json.JSONDecodeError: if the config file is not valid JSON.
        KeyError: if a required config key is missing.
    """
    # ``with`` guarantees the config file is closed even if parsing fails.
    with open('.local/googlepalm.json', encoding='utf-8') as palm_info_file:
        palm_info = json.load(palm_info_file)

    genai.configure(api_key=palm_info['api_key'])
    result = genai.generate_text(
        prompt=prompt,
        model=model,
        temperature=palm_info['temperature'],
        max_output_tokens=max_tokens,
        top_p=palm_info['top_p'],
        # Generation halts at the first '#', '--' or ';' (presumably to cut
        # the output after a single SQL statement, before any comments).
        stop_sequences=['#', '--', ';'],
    )
    # ``result.result`` can be None (e.g. no candidate returned); guard it so
    # we don't crash with AttributeError on ``None.replace``.
    text = result.result or ""
    # Strip both opening and closing fences, matching call_codey.
    return text.replace("```sql", "").replace("```", "").strip()

def call_codey(
        prompt,
        max_attempts=10,
        model='code-bison-32k',
        max_tokens=800,
        verbose=True
) -> str:
    """Generate code (typically SQL) with a Vertex AI Codey model.

    ``temperature`` is read from ``.local/codey.json``.

    Args:
        prompt: Prefix passed to ``CodeGenerationModel.predict``.
        max_attempts: Accepted for signature compatibility; currently unused.
        model: Name of the pretrained code-generation model to load.
        max_tokens: Maximum number of tokens to generate.
        verbose: When true, print the raw model response before cleaning.

    Returns:
        The response text with Markdown code fences (```` ```sql ```` and
        ```` ``` ````) removed and surrounding whitespace stripped.

    Raises:
        OSError: if the config file cannot be opened.
        json.JSONDecodeError: if the config file is not valid JSON.
        KeyError: if ``temperature`` is missing from the config.
    """
    # ``with`` guarantees the config file is closed even if parsing fails.
    with open(".local/codey.json", encoding="utf-8") as codey_info_file:
        codey_info = json.load(codey_info_file)

    parameters = {
        "temperature": codey_info['temperature'],
        "max_output_tokens": max_tokens,
    }

    code_generation_model = CodeGenerationModel.from_pretrained(model)
    response = code_generation_model.predict(prefix=prompt, **parameters)

    if verbose:
        print(f"Response from {model} Model: {response.text}")

    return response.text.replace("```sql", "").replace("```", "").strip()


if __name__ == "__main__":
    # Smoke test: send the sample prompt to Gemini and print the reply.
    # The other helpers can be exercised the same way, e.g.
    #   call_google_palm("Write a sql query to get sales revenue by customer "
    #                    "on a canonical customer database.")
    text = call_vertex(test_prompt)
    print(text)