import os
import time
# import openai
import backoff 
import google.generativeai as genai
import torch

from tot.solver_utils import solvercheck_propose_prompt_out
from tot.tasks.game24 import get_current_numbers

# Gemini client used by gpt(..., model="gemini-pro"). It starts unset here;
# presumably it is assigned a genai model object elsewhere before use — verify against caller.
gemini_model = None

# Running token counters that gpt_usage() reads to estimate cost.
completion_tokens = 0
prompt_tokens = 0

def _generate_gemini(prompt, temperature, max_tokens, n, stop):
    """Sample ``n`` responses from the module-level ``gemini_model``, retrying indefinitely on errors."""
    if gemini_model is None:
        # Without this guard, the AttributeError on None would be swallowed by the retry loop below,
        # and the call would spin forever.
        raise RuntimeError("gemini_model has not been initialised; assign the module-level gemini_model before calling gpt(model='gemini-pro')")
    # The config does not change between retries or samples, so build it once.
    config = genai.types.GenerationConfig(
        max_output_tokens=max_tokens,
        stop_sequences=[stop] if stop else None,
        temperature=temperature)
    while True:
        try:
            responses = [gemini_model.generate_content(prompt, generation_config=config) for _ in range(n)]
            return [response.text for response in responses]
        except Exception as e:
            # Any API/response error: report it, back off, then resample the whole batch.
            print(e)
            time.sleep(10)


def _format_gemma_prompt(task, prompt):
    """Wrap ``prompt`` in the Gemma chat template, leaving the model turn open.

    The last two lines of ``prompt`` become the (partial) model turn; the rest is the user turn.
    """
    lines = prompt.split("\n")
    chat = [
        {"role": "user", "content": "\n".join(lines[:-2])},
        {"role": "model", "content": "\n".join(lines[-2:])},
    ]
    # add_generation_prompt=False so that '<start_of_turn>model' is not added at the end of the prompt.
    # [:-14] removes the closing '<end_of_turn>' marker (13 chars, presumably plus a trailing newline —
    # depends on the tokenizer's chat template), so generation continues the model turn.
    return task.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)[:-14]


def _format_llama2_prompt(prompt, propose_prompt_flag):
    """Wrap ``prompt`` in the Llama-2 ``[INST]`` format; the last two lines follow ``[/INST]``.

    Propose prompts carry no system line. Otherwise the first line of ``prompt`` is used as the system prompt.
    """
    lines = prompt.split("\n")
    model_prompt = "\n".join(lines[-2:])
    if propose_prompt_flag:  # no system prompt
        user_prompt = "\n".join(lines[:-2])
        return f"[INST] {user_prompt} [/INST]\n{model_prompt}"
    system_prompt = lines[0]
    user_prompt = "\n".join(lines[1:-2])
    return f"[INST] <<SYS>> {system_prompt} \n<</SYS>>\n{user_prompt} [/INST]\n{model_prompt}"


def _generate_local(task, prompt_modified, temperature, max_tokens, n):
    """Sample ``n`` continuations of an already formatted prompt with ``task.model_object``.

    Returns the decoded continuations (prompt tokens excluded), stripped of surrounding whitespace.
    """
    input_ids = task.tokenizer.encode(prompt_modified, add_special_tokens=False, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = task.model_object.generate(input_ids, max_new_tokens=max_tokens, temperature=temperature, num_return_sequences=n, do_sample=True, min_length=-1)
    # Slice off the prompt so only newly generated tokens are decoded.
    decoded_responses = task.tokenizer.batch_decode(outputs[:, input_ids.shape[1]:], skip_special_tokens=True)
    return [response.strip() for response in decoded_responses]  # clean up new lines and spaces


def _record_prompt(task, prompt_modified, x_y_pair):
    """Append the final prompt sent to a local model as a new row of ``task.prompt_df``."""
    # Row columns: ['input_x', 'input_y_second_to_last', 'final_prompt']
    x, y_second_to_last = x_y_pair
    task.prompt_df = task.prompt_df._append({'input_x': x, 'input_y_second_to_last': y_second_to_last, 'final_prompt': prompt_modified}, ignore_index=True)


def gpt(prompt, task, model="gemma-2b-it", temperature=0.6, max_tokens=200, n=1, stop=None, propose_prompt_flag=False, x_y_pair=None) -> list:
    """Sample ``n`` completions of ``prompt`` from the selected backend.

    Args:
        prompt: newline-separated prompt text. For the local models, the last two lines are placed
            after the instruction/user part so that generation continues them.
        task: object providing ``tokenizer``, ``model_object`` and ``prompt_df``. It is used only by
            the local models ("gemma-2b-it", "llama2-7b").
        model: one of "gemini-pro", "gemma-2b-it", "llama2-7b".
        temperature: sampling temperature.
        max_tokens: maximum number of new tokens per sample.
        n: number of samples to return.
        stop: stop sequence; honoured only by "gemini-pro" (the local models ignore it).
        propose_prompt_flag: for local models, record the final prompt in ``task.prompt_df``.
            For "llama2-7b" it also selects the format without a system prompt.
        x_y_pair: ``(x, y_second_to_last)`` recorded with the prompt when ``propose_prompt_flag`` is set.

    Returns:
        A list of ``n`` response strings. Local-model responses are whitespace-stripped.

    Raises:
        NotImplementedError: if ``model`` is not supported.
        RuntimeError: if ``model == "gemini-pro"`` and the module-level ``gemini_model`` is unset.
    """
    if model == "gemini-pro":
        return _generate_gemini(prompt, temperature, max_tokens, n, stop)

    if model == "gemma-2b-it":
        prompt_modified = _format_gemma_prompt(task, prompt)
    elif model == "llama2-7b":
        prompt_modified = _format_llama2_prompt(prompt, propose_prompt_flag)
    else:
        print(f"model {model} has not been implemented!")
        raise NotImplementedError(f"model {model} has not been implemented!")

    decoded_responses = _generate_local(task, prompt_modified, temperature, max_tokens, n)
    if propose_prompt_flag:
        _record_prompt(task, prompt_modified, x_y_pair)
    return decoded_responses
    
def gpt_usage(backend="gemini-pro"):
    """Estimate the cost of the tokens counted so far by the module-level counters.

    Args:
        backend: model name used to pick per-1K-token prices. "gpt-4", "gpt-3.5-turbo" and
            "gemini-pro" are priced. Any other backend (e.g. locally hosted models) is reported
            with a cost of 0.0.

    Returns:
        dict with keys "completion_tokens", "prompt_tokens" and "cost".
    """
    # (completion price, prompt price) per 1K tokens.
    # The original `backend == "gpt-3.5-turbo" or "gemini-pro"` was always true, so every
    # non-gpt-4 backend silently received gpt-3.5 pricing. Use an explicit lookup instead.
    prices_per_1k = {
        "gpt-4": (0.06, 0.03),
        "gpt-3.5-turbo": (0.002, 0.0015),
        # TODO: replace with actual amount when google releases the api cost
        "gemini-pro": (0.002, 0.0015),
    }
    completion_price, prompt_price = prices_per_1k.get(backend, (0.0, 0.0))
    cost = completion_tokens / 1000 * completion_price + prompt_tokens / 1000 * prompt_price
    return {"completion_tokens": completion_tokens, "prompt_tokens": prompt_tokens, "cost": cost}
