from dotenv import load_dotenv
from openai import AzureOpenAI
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
import os
import json
import io
import sys
import time
from openai import OpenAI
from typing_extensions import override
from openai import AssistantEventHandler
import tiktoken
import re
import numpy as np
import ast
import anthropic
from mistralai import Mistral

# API credentials for the three providers used below. Each key is taken from
# the provider's conventional environment variable so secrets do not have to
# be committed to source; when the variable is unset the value stays '' as
# before.
openai_api_key = os.environ.get('OPENAI_API_KEY', '')
claude_api_key = os.environ.get('ANTHROPIC_API_KEY', '')
mixtral_api_key = os.environ.get('MISTRAL_API_KEY', '')


def extract_and_check(response):
    """Extract the final ``<<<...>>>`` answer and detect code usage in *response*.

    Returns:
        A ``(final_answer, uses_code)`` tuple. ``final_answer`` is the text
        inside the last ``<<<`` / ``>>>`` pair (``''`` when there is none).
        ``uses_code`` is True when the response mentions ``code_interpreter``
        or contains a python code fence that begins on a new line.
    """
    final_answer = ''
    # Walk every delimited span; the last one seen wins.
    for found in re.finditer(r'<<<(.*?)>>>', response, re.DOTALL):
        final_answer = found.group(1)

    code_markers = ('code_interpreter', '\n```python')
    uses_code = any(marker in response for marker in code_markers)
    return final_answer, uses_code

def extract_code(text):
    """Return the contents of every python-fenced code block found in *text*.

    A block is recognised only when it opens with a python fence followed by a
    newline and closes with a newline followed by a bare fence. The result is
    a list of the block bodies in order of appearance, or an empty list when
    no such block exists.
    """
    fenced_python = r'```python\n(.*?)\n```'
    # DOTALL lets the body span multiple lines.
    return re.findall(fenced_python, text, re.DOTALL)

def count_total_tokens(user_prompt_list, response_total_list, model_name="gpt-3.5-turbo"):
    """Count the tokens in all prompts and responses combined.

    Args:
        user_prompt_list: Iterable of prompt strings.
        response_total_list: Iterable of response strings.
        model_name: Model whose tiktoken encoding is used for counting.

    Returns:
        The summed token count over every string in both lists.
    """
    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # tiktoken only maps OpenAI model names; for any other name (e.g. the
        # Claude and Mistral models used in this file) fall back to a fixed
        # encoding, so the count is an approximation rather than a crash.
        encoding = tiktoken.get_encoding("cl100k_base")

    return sum(
        len(encoding.encode(text))
        for texts in (user_prompt_list, response_total_list)
        for text in texts
    )

def message_construct_func(user_prompt_list, response_total_list, system_message, model_name):
    """Build the chat ``messages`` list for a multi-turn conversation.

    User prompts and assistant responses are interleaved in order; the final
    user prompt is left unanswered so the model produces the next reply.

    Args:
        user_prompt_list: User prompts, oldest first.
        response_total_list: Earlier assistant responses; entry ``i`` answers
            ``user_prompt_list[i]``. Must hold at least
            ``len(user_prompt_list) - 1`` entries.
        system_message: System prompt, inserted as the first message for the
            GPT and Mistral models and omitted for the o1 and Claude models.
        model_name: One of the model names listed below.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.

    Raises:
        ValueError: If ``model_name`` is not one of the known models.
        IndexError: If ``response_total_list`` is too short.
    """
    models_with_system_role = (
        'gpt-4o', 'gpt-4o-mini', 'gpt-35-turbo-16k-0613', 'gpt-4-turbo',
        'gpt-3.5-turbo-0125', 'gpt-3.5-turbo',
        'open-mixtral-8x7b', "mistral-large-latest",
    )
    models_without_system_role = (
        "o1-preview", 'o1-mini',
        "claude-3-sonnet-20240229", "claude-3-opus-20240229", "claude-3-haiku-20240307",
    )

    if model_name in models_with_system_role:
        messages = [{"role": "system", "content": system_message}]
    elif model_name in models_without_system_role:
        messages = []
    else:
        # Previously this fell through and failed later with an opaque
        # UnboundLocalError on `messages`.
        raise ValueError(f"Unsupported model name: {model_name!r}")

    last_index = len(user_prompt_list) - 1
    for index, user_prompt in enumerate(user_prompt_list):
        messages.append({"role": "user", "content": user_prompt})
        if index < last_index:
            messages.append({"role": "assistant", "content": response_total_list[index]})
    return messages

def GPT_response(system_message, question, model_name, code_interpreter, user_prompt_list, response_total_list):
    """Query the model with retries and return its response.

    Each attempt first calls ``GPT_response_once`` with ``use_azure=True``.
    From the ninth attempt onward a failed attempt is followed by a second
    call with ``use_azure=False``; for ``'gpt-35-turbo-16k-0613'`` that second
    call uses ``'gpt-3.5-turbo-0125'`` instead. Between attempts the function
    sleeps for 30 seconds.

    Args:
        system_message: System prompt forwarded to ``GPT_response_once``.
        question: Question text forwarded to ``GPT_response_once``.
        model_name: Name of the model to query.
        code_interpreter: Flag forwarded to ``GPT_response_once``.
        user_prompt_list: User prompts of the conversation so far.
        response_total_list: Assistant responses of the conversation so far.

    Returns:
        Whatever the first successful ``GPT_response_once`` call returns.

    Raises:
        RuntimeError: If every attempt fails; the last underlying error is
            attached as the cause.
    """
    max_attempts = 15
    fallback_from_attempt = 8  # zero-based index of the first attempt that also tries the fallback
    retry_delay_seconds = 30
    last_error = None

    for iteration_num in range(max_attempts):
        try:
            return GPT_response_once(system_message, question, model_name, code_interpreter,
                                     user_prompt_list, response_total_list, use_azure=True)
        except Exception as e:
            last_error = e
            print(f"Error on iteration {iteration_num + 1}: {e}")

        if iteration_num >= fallback_from_attempt:
            fallback_model = 'gpt-3.5-turbo-0125' if model_name == 'gpt-35-turbo-16k-0613' else model_name
            try:
                return GPT_response_once(system_message, question, fallback_model, code_interpreter,
                                         user_prompt_list, response_total_list, use_azure=False)
            except Exception as fallback_error:
                # Report the failure instead of silently discarding it; catching
                # Exception (not a bare except) lets Ctrl-C interrupt the loop.
                last_error = fallback_error
                print(f"Fallback error on iteration {iteration_num + 1}: {fallback_error}")

        # No point in waiting once the final attempt has already failed.
        if iteration_num < max_attempts - 1:
            print(f"Waiting for {retry_delay_seconds} seconds before retrying...")
            time.sleep(retry_delay_seconds)

    raise RuntimeError(f"Failed to get a response after {max_attempts} attempts") from last_error

def _run_code_interpreter(client, system_message, question, model_name):
    """Run *question* through an OpenAI assistant with the code-interpreter tool.

    The streamed events (assistant text, tool-call input and log outputs) are
    printed by the event handler; stdout is redirected for the duration of the
    run so that everything printed is returned as one string.
    """
    captured_output = io.StringIO()
    original_stdout = sys.stdout

    try:
        # Everything the event handler prints from here on lands in the buffer.
        sys.stdout = captured_output

        assistant = client.beta.assistants.create(
            instructions=system_message,
            model=model_name,
            tools=[{"type": "code_interpreter"}],
            temperature=0.0,
            top_p=1
        )

        thread = client.beta.threads.create()

        client.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=question)

        # Handles the events in the response stream by printing them.
        class EventHandler(AssistantEventHandler):
            @override
            def on_text_created(self, text) -> None:
                print("\nassistant > ", end="", flush=True)

            @override
            def on_text_delta(self, delta, snapshot):
                print(delta.value, end="", flush=True)

            def on_tool_call_created(self, tool_call):
                print(f"\nassistant > {tool_call.type}\n", flush=True)

            def on_tool_call_delta(self, delta, snapshot):
                if delta.type == 'code_interpreter':
                    if delta.code_interpreter.input:
                        print(delta.code_interpreter.input, end="", flush=True)
                    if delta.code_interpreter.outputs:
                        print("\n\noutput >", flush=True)
                        for output in delta.code_interpreter.outputs:
                            if output.type == "logs":
                                print(f"\n{output.logs}", flush=True)

        # The `stream` SDK helper creates the Run and feeds events to the handler.
        with client.beta.threads.runs.stream(
                thread_id=thread.id,
                assistant_id=assistant.id,
                instructions="",
                event_handler=EventHandler(),
        ) as stream:
            stream.until_done()
        # NOTE(review): the assistant and thread created above are never
        # deleted here — confirm whether cleanup is handled elsewhere.
    finally:
        # Always restore stdout, even when the API call fails.
        sys.stdout = original_stdout

    return captured_output.getvalue()


def GPT_response_once(system_message, question, model_name, code_interpreter, user_prompt_list, response_total_list, use_azure):
    """Send one request to the selected model and return the response text.

    Args:
        system_message: System prompt (assistant instructions in
            code-interpreter mode).
        question: User question; only used in code-interpreter mode.
        model_name: One of the supported OpenAI, Claude or Mistral model
            names. ``'gpt-35-turbo-16k-0613'`` is sent as ``'gpt-3.5-turbo'``.
        code_interpreter: ``False`` for a plain chat completion built from the
            prompt/response history, ``True`` to run ``question`` through an
            OpenAI assistant with the code-interpreter tool.
        user_prompt_list: User prompts of the conversation (chat mode only).
        response_total_list: Assistant responses so far (chat mode only).
        use_azure: Accepted for interface compatibility; not used by this
            function.

    Returns:
        In chat mode, the content of the model's reply. In code-interpreter
        mode, the text printed by the streaming event handler.

    Raises:
        ValueError: If ``model_name`` is not supported or ``code_interpreter``
            is neither ``True`` nor ``False``.
    """
    openai_chat_models = ('gpt-4o', 'gpt-4o-mini', 'gpt-35-turbo-16k-0613', 'gpt-4-turbo',
                          'gpt-3.5-turbo-0125', 'gpt-3.5-turbo')
    openai_o1_models = ("o1-preview", 'o1-mini')
    claude_models = ("claude-3-sonnet-20240229", "claude-3-opus-20240229", "claude-3-haiku-20240307")
    mistral_models = ('open-mixtral-8x7b', "mistral-large-latest")

    # Validate the arguments up front so a bad call fails before any client
    # is created.
    if model_name not in openai_chat_models + openai_o1_models + claude_models + mistral_models:
        print('\nModel name is wrong!')
        raise ValueError(f"Unsupported model name: {model_name!r}")
    if code_interpreter not in (True, False):
        print('\nCode interpreter expression is wrong!')
        raise ValueError(f"code_interpreter must be True or False, got {code_interpreter!r}")

    if model_name == 'gpt-35-turbo-16k-0613':
        model_name = 'gpt-3.5-turbo'

    if code_interpreter:
        # This path always goes through the OpenAI client, whatever the model.
        return _run_code_interpreter(OpenAI(api_key=openai_api_key), system_message, question, model_name)

    input_messages = message_construct_func(user_prompt_list, response_total_list, system_message, model_name)

    if model_name in openai_chat_models:
        client = OpenAI(api_key=openai_api_key)
        response = client.chat.completions.create(
            model=model_name,
            messages=input_messages,
            temperature=0.0,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0
        )
        return response.choices[0].message.content

    if model_name in openai_o1_models:
        # The o1 request is sent without any sampling parameters.
        client = OpenAI(api_key=openai_api_key)
        response = client.chat.completions.create(
            model=model_name,
            messages=input_messages
        )
        return response.choices[0].message.content

    if model_name in claude_models:
        client = anthropic.Anthropic(api_key=claude_api_key)
        # The system prompt is passed separately; input_messages holds only
        # the user/assistant turns for these models.
        message = client.messages.create(
            model=model_name,
            max_tokens=2000,
            temperature=0.0,
            top_p=1,
            system=system_message,
            messages=input_messages,
        )
        return message.content[0].text

    # Only the Mistral models remain after the validation above.
    # The module-level key is read directly: re-assigning it to a local of the
    # same name made every call fail with UnboundLocalError.
    client = Mistral(api_key=mixtral_api_key)
    chat_response = client.chat.complete(
        model=model_name,
        messages=input_messages,
        temperature=0.0,
        top_p=1,
    )
    return chat_response.choices[0].message.content