from typing import Optional, Union
import numpy as np
from Prompt.delivery_man_prompt import delivery_man_user_prompt, delivery_man_reasoning_user_prompt, delivery_man_system_prompt, delivery_man_context_prompt
from Base.ActionSpace import ActionSpace
from Base.ReasoningSpace import ReasoningSpace
from simworld.llm.base_llm import BaseLLM
from utils.prompt import extract_json_and_fix_escapes
import time
import json
import requests

class BaseModel(BaseLLM):
    """LLM wrapper with two entry points.

    ``generate`` sends one chat completion through ``self.client`` (provided by
    ``BaseLLM``). ``react`` runs a two-step reason-then-act exchange against the
    OpenRouter chat-completions HTTP endpoint and returns the parsed action JSON.
    """

    # Models for which ``generate`` attaches images to the user message
    # (exact match against the lower-cased ``self.model``).
    _MULTIMODAL_MODELS = ("gpt-4o", "gpt-4o-mini", "o1", "o1-mini")

    # Endpoint ``react`` posts to directly; ``react`` does not read the ``url``
    # passed to ``__init__``.
    _OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"

    # Text appended to a prompt right before the JSON schema the reply must follow.
    _JSON_SCHEMA_HINT = "\nPlease respond in valid JSON format following this schema: "

    def __init__(self, model: str = 'meta-llama/Meta-Llama-3.1-70B-Instruct', url: str = "https://api.openai.com/v1", api_key: Optional[str] = None):
        super().__init__(model_name=model, url=url, api_key=api_key)
        # ``max_tokens`` / ``temperature`` are the fallbacks ``generate`` uses when
        # the caller passes None; the remaining attributes are not read in this class.
        self.max_tokens = 1000
        self.temperature = 0.7
        self.top_p = 1.0
        self.num_return_sequences = 1
        self.rate_limit_per_min = 20
        self.logprobs = None
        self.model = model

    def generate(
        self,
        system_prompt: Optional[Union[str, list[str]]],
        user_prompt: Optional[Union[str, list[str]]],
        images: Optional[Union[str, list[str], np.ndarray, list[np.ndarray]]] = None,
        max_tokens: int = None,
        top_p: float = 1.0,
        num_return_sequences: int = 1,
        rate_limit_per_min: Optional[int] = 20,
        logprobs: Optional[int] = None,
        temperature=None,
        additional_prompt=None,
        retry=64,
        action_history: Optional[list[str]] = None,
        is_instruct_model: bool = False,
        **kwargs,
    ):
        """Send one chat completion through ``self.client`` and return the first choice's text.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Text of the user message. May be empty if images or an
                action history are supplied.
            images: One image or a list of images. Each is either a base64 string
                (used as-is) or an array converted with ``_process_image_to_base64``.
                Images are only attached when ``self.model`` is in
                ``_MULTIMODAL_MODELS``.
            max_tokens: Completion token limit; defaults to ``self.max_tokens``.
            top_p: Forwarded as ``top_p``.
            num_return_sequences: Forwarded as ``n``; only the first choice is returned.
            rate_limit_per_min: If not None, sleep ``60 / rate_limit_per_min``
                seconds before every attempt.
            logprobs, additional_prompt, is_instruct_model: Accepted for interface
                compatibility; not used.
            temperature: Sampling temperature; defaults to ``self.temperature``.
            retry: Maximum number of attempts.
            action_history: If given, appended to the user message as an extra
                text part.
            **kwargs: Passed through to ``client.chat.completions.create``.

        Returns:
            The ``content`` of the first returned message.

        Raises:
            ValueError: If there is nothing to put in the user message.
            RuntimeError: If every attempt failed; the last error is chained.
        """
        max_tokens = self.max_tokens if max_tokens is None else max_tokens
        temperature = self.temperature if temperature is None else temperature
        supports_vision = self.model.lower() in self._MULTIMODAL_MODELS

        user_content = []
        if user_prompt:
            user_content.append({"type": "text", "text": user_prompt})
        # Compare with None explicitly: a numpy array has no unambiguous truth value.
        if supports_vision and images is not None:
            # Wrap a single image (base64 string or array) so it is not iterated element-wise.
            if isinstance(images, (str, np.ndarray)):
                images = [images]
            for image in images:
                img_data = image if isinstance(image, str) else self._process_image_to_base64(image)
                user_content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{img_data}"},
                })
        if action_history:
            user_content.append({"type": "text", "text": f"Your action history is: {action_history}"})

        if not user_content:
            raise ValueError("generate() needs a user prompt, an image, or an action history")
        # A lone text part is sent as a plain string. Anything else, including a
        # lone image part, must use the list-of-parts form.
        if len(user_content) == 1 and user_content[0]["type"] == "text":
            user_message = user_content[0]["text"]
        else:
            user_message = user_content
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        last_error = None
        for attempt in range(1, retry + 1):
            # Sleep before every request to stay under the per-minute rate limit.
            if rate_limit_per_min is not None:
                time.sleep(60 / rate_limit_per_min)
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    n=num_return_sequences,
                    **kwargs,
                )
            except Exception as e:  # retry boundary: any client/API error is retried
                last_error = e
                print(f"An error occurred on attempt {attempt}/{retry}: {e}")
                if attempt < retry:
                    # Linear backoff, capped so a large retry budget cannot stall for many minutes.
                    time.sleep(min(attempt, 10))
                continue
            return response.choices[0].message.content

        raise RuntimeError(
            f"BaseModel failed to generate output, even after {retry} tries"
        ) from last_error

    @staticmethod
    def _default_action() -> dict:
        """Return a fresh zeroed action dict, used when no parseable action reply was obtained."""
        return {
            "choice": 0,
            "index": 0,
            "target_point": [0, 0],
            "meeting_point": [0, 0],
            "next_waypoint": [0, 0],
            "new_speed": 0,
        }

    def _post_chat_with_retry(self, system_prompt: str, prompt: str, *, retry: int,
                              start_time: float, max_retry_time: float,
                              timeout: float, sampling: dict) -> Optional[str]:
        """POST one system+user chat request to OpenRouter with retries.

        Returns the first choice's ``content``. Returns None if every attempt
        failed or the ``max_retry_time`` budget (measured from ``start_time``)
        ran out before a new attempt could start.
        """
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": [
                    # cache_control is a property of the text part; there is no
                    # standalone "cache_control" content-part type.
                    {"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}},
                ]},
                {"role": "user", "content": [
                    {"type": "text", "text": prompt},
                ]},
            ],
            **sampling,
        }
        headers = {"Authorization": f"Bearer {self.api_key}"}

        for attempt in range(1, retry + 1):
            if time.time() - start_time > max_retry_time:
                print(f"Exceeded maximum retry time of {max_retry_time} seconds")
                return None
            try:
                # json= serializes the payload and sets the Content-Type header.
                http_response = requests.post(
                    self._OPENROUTER_URL, headers=headers, json=payload, timeout=timeout,
                )
                # Surface HTTP errors (401/429/5xx) as requests.HTTPError instead of a
                # confusing KeyError on the missing 'choices' field.
                http_response.raise_for_status()
                return http_response.json()['choices'][0]['message']['content']
            except requests.Timeout:  # must precede RequestException, its base class
                print(f"Timeout occurred on attempt {attempt}/{retry}")
                backoff = min(attempt * 5, 20)  # wait longer after a timeout
            except requests.RequestException as e:
                print(f"Request error occurred: {e}, attempt {attempt}/{retry}")
                backoff = min(attempt * 2, 10)
            except Exception as e:
                print(f"Unexpected error occurred: {e}, attempt {attempt}/{retry}")
                backoff = min(attempt * 2, 10)
            if attempt < retry:
                time.sleep(backoff)
        return None

    def react(self, system_prompt: str,
              context_prompt: str,
              user_prompt: str,
              reasoning_prompt: str,
              images: Optional[Union[str, list[str], np.ndarray, list[np.ndarray]]] = None,
              max_tokens: int = None,
              top_p: float = 1.0,
              num_return_sequences: int = 1,
              rate_limit_per_min: Optional[int] = 20,
              logprobs: Optional[int] = None,
              temperature=None,
              additional_prompt=None,
              retry=4,
              max_retry_time=60,  # maximum retry time (seconds)
              action_history: Optional[list[str]] = None,
              is_instruct_model: bool = False,
              **kwargs):
        """ReAct-style two-step call: reason first, then choose an action.

        1. ``reasoning_prompt`` (formatted with ``context=context_prompt``) is sent
           with the ``ReasoningSpace`` JSON schema appended. The reply is parsed
           into a ``ReasoningSpace``.
        2. ``user_prompt`` (formatted with ``reasoning=<reasoning text>``) is sent
           with the ``ActionSpace`` JSON schema appended. The reply is parsed as JSON.

        Both requests go to the OpenRouter chat-completions endpoint, with
        ``system_prompt`` as the system message.

        Args:
            system_prompt: System message for both requests.
            context_prompt: Substituted for ``{context}`` in ``reasoning_prompt``.
            user_prompt: Action-step template containing ``{reasoning}``.
            reasoning_prompt: Reasoning-step template containing ``{context}``.
            max_tokens, temperature, top_p: Forwarded to the API only when set
                explicitly (``max_tokens``/``temperature`` not None, ``top_p`` != 1.0).
                Otherwise the provider defaults apply.
            retry: Maximum attempts per request.
            max_retry_time: Wall-clock budget in seconds, shared by both steps.
                Once it is exceeded, no new attempt is started.
            **kwargs: Only ``timeout`` (per-request seconds, default 30) is read.
            images, num_return_sequences, rate_limit_per_min, logprobs,
            additional_prompt, action_history, is_instruct_model: Accepted for
                interface compatibility; not used.

        Returns:
            The action dict parsed from the second reply. If no parseable reply
            was obtained, a zeroed fallback action is returned instead.
        """
        timeout = kwargs.get('timeout', 30)
        start_time = time.time()

        # Only forward sampling settings the caller set explicitly, so the
        # provider defaults stay in effect otherwise.
        sampling = {}
        if max_tokens is not None:
            sampling["max_tokens"] = max_tokens
        if temperature is not None:
            sampling["temperature"] = temperature
        if top_p is not None and top_p != 1.0:
            sampling["top_p"] = top_p

        def _ask(prompt: str) -> Optional[str]:
            # Both steps share the same retry policy and wall-clock budget.
            return self._post_chat_with_retry(
                system_prompt, prompt,
                retry=retry, start_time=start_time, max_retry_time=max_retry_time,
                timeout=timeout, sampling=sampling,
            )

        # Step 1: reasoning.
        reasoning_prompt = reasoning_prompt.format(context=context_prompt)
        reasoning_prompt += self._JSON_SCHEMA_HINT + str(ReasoningSpace.to_json_schema())
        reasoning_response = _ask(reasoning_prompt)
        reasoning_json = None
        if reasoning_response is None:
            print("Warning: Failed to get reasoning response, using default")
        else:
            reasoning_json = extract_json_and_fix_escapes(reasoning_response)
        if reasoning_json is None:
            reasoning_json = {"reasoning": "No reasoning available"}
        reasoning = ReasoningSpace.from_json(reasoning_json)

        # Step 2: action, conditioned on the reasoning text.
        user_prompt = user_prompt.format(reasoning=reasoning.reasoning)
        user_prompt += self._JSON_SCHEMA_HINT + str(ActionSpace.to_json_schema())
        action_response = _ask(user_prompt)
        action_json = None
        if action_response is None:
            print("Warning: Failed to get action response, using default")
        else:
            action_json = extract_json_and_fix_escapes(action_response)
        if action_json is None:
            action_json = self._default_action()

        return action_json