"""
# UNCOMMENT IT IF YOU WANT TO RUN GEMINI
import os
from time import sleep


#import backoff
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
"""

from tqdm import tqdm

import torch # speed up?

"""
# UNCOMMENT IT IF YOU WANT TO RUN GEMINI
GOOGLE_API_KEY = os.environ["GOOGLE_GEMINI_API_KEY"] # change to your own api key
assert GOOGLE_API_KEY is not None

genai.configure(api_key=GOOGLE_API_KEY)
gemini_model = genai.GenerativeModel('gemini-1.0-pro-001')

#@backoff.on_exception(backoff.expo, RateLimitError)
def chatcompletions_with_backoff(**kwargs):
    response = gemini_model.generate_content(**kwargs)
    return response

#def get_error_messages_list():
#    return ["Invalid Request Error"]

def _get_safety_config(threshold=HarmBlockThreshold.BLOCK_NONE):
    # Ideally, should use BLOCK_NONE, but in case it is not available for now, use BLOCK_ONLY_HIGH instead...
    # Ref: https://www.googlecloudcommunity.com/gc/AI-ML/Gemini-s-safety-config-400-error-from-today/m-p/715146
    # Note that some categories are commented out because we cannot override them (this may be subject to change)
    cur = [#HarmCategory.HARM_CATEGORY_DANGEROUS, 
           HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
           #HarmCategory.HARM_CATEGORY_DEROGATORY,
           HarmCategory.HARM_CATEGORY_HARASSMENT,
           HarmCategory.HARM_CATEGORY_HATE_SPEECH,
           #HarmCategory.HARM_CATEGORY_MEDICAL,
           #HarmCategory.HARM_CATEGORY_SEXUAL,
           HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
           #HarmCategory.HARM_CATEGORY_TOXICITY,
           #HarmCategory.HARM_CATEGORY_UNSPECIFIED,
           #HarmCategory.HARM_CATEGORY_VIOLENCE
          ]
    safety_config = {i: threshold for i in cur if i in HarmCategory}
    return safety_config

def get_response_list_from_gemini(messages_list,
                                  model="gemini-1.0-pro-001",
                                  temperature=0,
                                  max_tokens=128):
    gemini_model = genai.GenerativeModel(model)
    safety_config = _get_safety_config(HarmBlockThreshold.BLOCK_ONLY_HIGH) # Ideally, should use HarmBlockThreshold.BLOCK_NONE
    response_list = []
    for messages in tqdm(messages_list) if len(messages_list) > 1 else messages_list:
        not_processed = True
        while not_processed:
            try:
                response = chatcompletions_with_backoff(
                    contents=messages,
                    # https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/tutorials/python_quickstart.ipynb#scrollTo=0UIt5LKp16jL&line=3&uniqifier=1
                    # https://ai.google.dev/docs/safety_setting_gemini
                    safety_settings = safety_config,
                    # not sure if block_none is effective... (ref: https://www.googlecloudcommunity.com/gc/AI-ML/Gemini-s-safety-config-400-error-from-today/m-p/715146)
                    #safety_settings={'HARASSMENT': 'block_none'},
                    # https://ai.google.dev/api/python/google/generativeai/GenerationConfig
                    generation_config=genai.types.GenerationConfig(candidate_count=1,
                                                                   max_output_tokens=max_tokens,
                                                                   temperature=temperature,
                                                                   top_p=1)
                )
                not_processed = False
            except Exception as e:
                not_processed = True
                tqdm.write(f"{e}")
                tqdm.write("Some error occurs, sleep 30 seconds..")
                sleep(30)
            finally:
                pass
        response_list.append(response)
    return response_list
"""
def get_response_list_from_gemini(messages_list,
                                  model="gemini-1.0-pro-001",
                                  temperature=0,
                                  max_tokens=128):
    """Placeholder for the disabled Gemini backend; always fails.

    The signature mirrors the Gemini implementation that is kept, disabled,
    inside a string literal earlier in this file, so callers do not need to
    change when that implementation is re-enabled. None of the arguments are
    used here.

    Args:
        messages_list: Ignored.
        model: Ignored.
        temperature: Ignored.
        max_tokens: Ignored.

    Raises:
        AssertionError: Always, after printing an explanation to stdout.
    """
    print("We do not run gemini in KEIC data... as Google set an upper bound limit for daily use when using a free API key... you can comment this function and uncomment above to run the model tho")
    # Raise explicitly instead of `assert False`: assert statements are
    # stripped under `python -O`, which would make this stub silently
    # return None. AssertionError is kept so existing handlers still match.
    raise AssertionError("Gemini backend is disabled in this module")

def get_response_list_from_gemma_2(messages_list,
                                   model_id="google/gemma-2-9b-it",
                                   tokenizer=None, 
                                   model=None,
                                   temperature=0,
                                   max_tokens=512):
    """Generate one response per conversation with a locally loaded model.

    Each element of ``messages_list`` is rendered with the tokenizer's chat
    template, passed to ``model.generate``, and only the newly generated
    tokens (everything after the prompt) are decoded into the response.

    Args:
        messages_list: Sized collection of conversations; each one is handed
            to ``tokenizer.apply_chat_template``. A progress bar is shown
            only when there is more than one conversation.
        model_id: Not used by this function; kept for interface
            compatibility.
        tokenizer: Tokenizer providing ``apply_chat_template``,
            ``batch_decode`` and ``eos_token_id``. Required.
        model: Model providing ``generate``, ``device`` and
            ``generation_config``. Required.
        temperature: Value forwarded to ``model.generate`` and also written
            to ``model.generation_config.temperature`` (a persistent side
            effect on ``model``).
        max_tokens: Upper bound on newly generated tokens per response
            (forwarded as ``max_new_tokens``).

    Returns:
        A list of decoded response strings, one per conversation, in input
        order.

    Raises:
        AssertionError: If ``tokenizer`` or ``model`` is ``None``.
    """
    # Raise explicitly rather than `assert ...`: assert statements are
    # stripped under `python -O`, which would postpone the failure to an
    # opaque AttributeError on None. AssertionError is kept so existing
    # callers that catch it keep working.
    if tokenizer is None or model is None:
        raise AssertionError("both `tokenizer` and `model` must be provided")

    num_conversations = len(messages_list)
    if num_conversations == 0:
        # Nothing to generate; leave the model's generation config untouched.
        return []

    # Loop-invariant, so set it once instead of on every iteration.
    model.generation_config.temperature = temperature
    # NOTE(review): in Hugging Face transformers, `temperature` and `top_p`
    # are presumably only applied when sampling is enabled (do_sample=True);
    # confirm that greedy decoding is the intended behaviour here.

    response_list = []
    for messages in tqdm(messages_list) if num_conversations > 1 else messages_list:
        # Ref: https://huggingface.co/docs/transformers/chat_templating#how-do-i-use-chat-templates
        inputs = tokenizer.apply_chat_template(messages,
                                               add_generation_prompt=True,
                                               return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(inputs,
                                     temperature=temperature,
                                     top_p=1,
                                     max_new_tokens=max_tokens,
                                     # Explicit pad token avoids the "Setting pad_token_id to eos_token_id" warning:
                                     # https://stackoverflow.com/questions/69609401/suppress-huggingface-logging-warning-setting-pad-token-id-to-eos-token-id
                                     pad_token_id=tokenizer.eos_token_id)
        # Drop the prompt tokens so only the model's continuation is decoded.
        generated_tokens = outputs[:, inputs.shape[1]:]
        response_list.append(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0])
    return response_list
