import json
import re
import functools
from typing import Any, Dict, Iterable, List, Type, cast

from pydantic import BaseModel


from openai import OpenAI, AsyncOpenAI
from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartParam,
    ChatCompletionContentPartTextParam,
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartRefusalParam,
    ChatCompletionMessage,
    ChatCompletionMessageParam,
    ChatCompletionMessageToolCall,
    ChatCompletionSystemMessageParam,
    ChatCompletionToolParam,
    ChatCompletionToolMessageParam,
    ChatCompletionUserMessageParam,
    ChatCompletion,
)
from uuid import uuid4
from opentelemetry import trace
from mcp.types import (
    CallToolRequestParams,
    CallToolRequest,
    CallToolResult,
    EmbeddedResource,
    ImageContent,
    ListToolsResult,
    ModelPreferences,
    TextContent,
    TextResourceContents,
)

from mcp_agent.config import OpenAISettings
from mcp_agent.executor.workflow_task import workflow_task
from mcp_agent.tracing.telemetry import get_tracer, telemetry
from mcp_agent.tracing.semconv import (
    GEN_AI_AGENT_NAME,
    GEN_AI_REQUEST_MODEL,
    GEN_AI_RESPONSE_FINISH_REASONS,
    GEN_AI_TOOL_CALL_ID,
    GEN_AI_TOOL_NAME,
    GEN_AI_USAGE_INPUT_TOKENS,
    GEN_AI_USAGE_OUTPUT_TOKENS,
)
from mcp_agent.tracing.telemetry import is_otel_serializable
from mcp_agent.utils.common import ensure_serializable, typed_dict_extras
from mcp_agent.utils.mime_utils import image_url_to_mime_and_base64
from mcp_agent.utils.pydantic_type_serializer import serialize_model, deserialize_model
from mcp_agent.workflows.llm.augmented_llm import (
    AugmentedLLM,
    MessageTypes,
    ModelT,
    MCPMessageParam,
    MCPMessageResult,
    ProviderToMCPConverter,
    RequestParams,
)
from mcp_agent.logging.logger import get_logger
from mcp_agent.workflows.llm.multipart_converter_openai import OpenAIConverter


class RequestCompletionRequest(BaseModel):
    config: OpenAISettings
    payload: dict


class RequestStructuredCompletionRequest(BaseModel):
    config: OpenAISettings
    response_model: Any | None = None
    serialized_response_model: str | None = None
    response_str: str
    model: str
    user: str | None = None


class OpenAIAugmentedLLM(
    AugmentedLLM[ChatCompletionMessageParam, ChatCompletionMessage]
):
    """
    The basic building block of agentic systems is an LLM enhanced with augmentations
    such as retrieval, tools, and memory provided from a collection of MCP servers.
    This implementation uses OpenAI's ChatCompletion as the LLM.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, type_converter=MCPOpenAITypeConverter, **kwargs)
        # Track last request's token usage aggregated across internal iterations
        self.last_usage: dict | None = None

        self.provider = "OpenAI"
        # Initialize logger with name if available
        self.logger = get_logger(f"{__name__}.{self.name}" if self.name else __name__)

        self.model_preferences = self.model_preferences or ModelPreferences(
            costPriority=0.3,
            speedPriority=0.4,
            intelligencePriority=0.3,
        )

        # Get default model from config if available
        if "default_model" in kwargs:
            default_model = kwargs["default_model"]
        else:
            default_model = "gpt-4o"  # Fallback default

        self._reasoning_effort = "medium"
        if self.context and self.context.config and self.context.config.openai:
            if hasattr(self.context.config.openai, "default_model"):
                default_model = self.context.config.openai.default_model
            if hasattr(self.context.config.openai, "reasoning_effort"):
                self._reasoning_effort = self.context.config.openai.reasoning_effort

        self._reasoning = lambda model: model.startswith((
            "o1",
            "o3",
            "o4",
            "gpt-5",
            "gpt-5-mini",
        ))

        if self._reasoning(default_model):
            self.logger.info(
                f"Using reasoning model '{default_model}' with '{self._reasoning_effort}' reasoning effort"
            )

        self.default_request_params = self.default_request_params or RequestParams(
            model=default_model,
            modelPreferences=self.model_preferences,
            maxTokens=12800,  
            systemPrompt=self.instruction,
            max_iterations=30,  # 从 20 增加到 30，Support更复杂的多步骤推理
            use_history=True,
        )

    @classmethod
    def convert_message_to_message_param(
        cls, message: ChatCompletionMessage, **kwargs
    ) -> ChatCompletionMessageParam:
        """Convert a response object to an input parameter object to allow LLM calls to be chained."""
        assistant_message_params = {
            "role": "assistant",
            "audio": message.audio,
            "refusal": message.refusal,
            **kwargs,
        }
        if message.content is not None:
            assistant_message_params["content"] = message.content
        if message.tool_calls is not None:
            assistant_message_params["tool_calls"] = message.tool_calls

        return ChatCompletionAssistantMessageParam(**assistant_message_params)

    async def generate(
        self,
        message,
        request_params: RequestParams | None = None,
    ):
        """
        Process a query using an LLM and available tools.
        The default implementation uses OpenAI's ChatCompletion as the LLM.
        Override this method to use a different LLM.
        """
        tracer = get_tracer(self.context)
        with tracer.start_as_current_span(
            f"{self.__class__.__name__}.{self.name}.generate"
        ) as span:
            span.set_attribute(GEN_AI_AGENT_NAME, self.agent.name)
            self._annotate_span_for_generation_message(span, message)

            messages: List[ChatCompletionMessageParam] = []
            params = self.get_request_params(request_params)

            if self.context.tracing_enabled:
                AugmentedLLM.annotate_span_with_request_params(span, params)

            if params.use_history:
                messages.extend(self.history.get())

            system_prompt = self.instruction or params.systemPrompt
            if system_prompt and len(messages) == 0:
                span.set_attribute("system_prompt", system_prompt)
                messages.append(
                    ChatCompletionSystemMessageParam(
                        role="system", content=system_prompt
                    )
                )
            messages.extend((OpenAIConverter.convert_mixed_messages_to_openai(message)))

            response: ListToolsResult = await self.agent.list_tools()
            available_tools: List[ChatCompletionToolParam] = [
                ChatCompletionToolParam(
                    type="function",
                    function={
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.inputSchema,
                        # TODO: saqadri - determine if we should specify "strict" to True by default
                    },
                )
                for tool in response.tools
            ]

            if self.context.tracing_enabled:
                span.set_attribute(
                    "available_tools",
                    [t.get("function", {}).get("name") for t in available_tools],
                )
            if not available_tools:
                available_tools = None

            responses: List[ChatCompletionMessage] = []
            model = await self.select_model(params)
            if model:
                span.set_attribute(GEN_AI_REQUEST_MODEL, model)

            # prefer user from the request params,
            # otherwise use the default from the config
            user = params.user or getattr(self.context.config.openai, "user", None)
            if self.context.tracing_enabled and user:
                span.set_attribute("user", user)

            total_input_tokens = 0
            total_output_tokens = 0
            finish_reasons = []

            for i in range(params.max_iterations):
                arguments = {
                    "model": model,
                    "messages": messages,
                    "tools": available_tools,
                }

                # If tools are available, explicitly enable tool calling. Some OpenAI-compatible servers (e.g., vLLM)
                # require setting tool_choice to "auto" to allow the model to actually emit tool calls.
                if available_tools:
                    arguments["tool_choice"] = "auto"

                if user:
                    arguments["user"] = user

                if params.stopSequences is not None:
                    arguments["stop"] = params.stopSequences

                if self._reasoning(model):
                    arguments = {
                        **arguments,
                        # DEPRECATED: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens
                        # "max_tokens": params.maxTokens,
                        "max_completion_tokens": params.maxTokens,
                        "reasoning_effort": self._reasoning_effort,
                    }
                else:
                    arguments = {**arguments, "max_tokens": params.maxTokens}
                    # if available_tools:
                    #     arguments["parallel_tool_calls"] = params.parallel_tool_calls

                if params.metadata:
                    arguments = {**arguments, **params.metadata}

                self.logger.debug(f"{arguments}")
                self._log_chat_progress(chat_turn=len(messages) // 2, model=model)

                request = RequestCompletionRequest(
                    config=self.context.config.openai,
                    payload=arguments,
                )

                self._annotate_span_for_completion_request(span, request, i)

                response: ChatCompletion = await self.executor.execute(
                    OpenAICompletionTasks.request_completion_task,
                    ensure_serializable(request),
                )

                self.logger.debug(
                    "OpenAI ChatCompletion response:",
                    data=response,
                )

                if isinstance(response, BaseException):
                    self.logger.error(f"Error: {response}")
                    span.record_exception(response)
                    span.set_status(trace.Status(trace.StatusCode.ERROR))
                    break

                self._annotate_span_for_completion_response(span, response, i)

                total_input_tokens += response.usage.prompt_tokens
                total_output_tokens += response.usage.completion_tokens

                if not response.choices or len(response.choices) == 0:
                    # No response from the model, we're done
                    break

                # TODO: saqadri - handle multiple choices for more complex interactions.
                # Keeping it simple for now because multiple choices will also complicate memory management
                choice = response.choices[0]
                message = choice.message
                responses.append(message)
                finish_reasons.append(choice.finish_reason)

                # Fixes an issue with openai validation that does not allow non alphanumeric characters, dashes, and underscores
                sanitized_name = (
                    re.sub(r"[^a-zA-Z0-9_-]", "_", self.name)
                    if isinstance(self.name, str)
                    else None
                )

                converted_message = self.convert_message_to_message_param(
                    message, name=sanitized_name
                )
                messages.append(converted_message)

                # Fallback: if no structured tool_calls returned but the model emitted textual tool call markers,
                # parse them and execute as structured tool calls to keep the agent flow working on OpenAI-compatible servers.
                fallback_tool_calls: List[ChatCompletionMessageToolCall] = []
                if not getattr(message, "tool_calls", None):
                    try:
                        fallback_tool_calls = self._extract_fallback_tool_calls_from_text(
                            self.message_str(message, content_only=True)
                        )
                    except Exception:
                        fallback_tool_calls = []

                if (
                    choice.finish_reason in ["tool_calls", "function_call"]
                    and message.tool_calls
                ):
                    # Execute all tool calls in parallel using functools.partial to bind arguments
                    tool_tasks = [
                        functools.partial(self.execute_tool_call, tool_call=tool_call)
                        for tool_call in message.tool_calls
                    ]
                    # Wait for all tool calls to complete.
                    tool_results = await self.executor.execute_many(tool_tasks)
                    self.logger.debug(
                        f"Iteration {i}: Tool call results: {str(tool_results) if tool_results else 'None'}"
                    )
                    # Add non-None results to messages.
                    for result in tool_results:
                        if isinstance(result, BaseException):
                            self.logger.error(
                                f"Warning: Unexpected error during tool execution: {result}. Continuing..."
                            )
                            span.record_exception(result)
                            continue
                        if result is not None:
                            messages.append(result)
                elif fallback_tool_calls:
                    # Execute parsed fallback tool calls even though the provider didn't return structured tool_calls
                    tool_tasks = [
                        functools.partial(self.execute_tool_call, tool_call=tool_call)
                        for tool_call in fallback_tool_calls
                    ]
                    tool_results = await self.executor.execute_many(tool_tasks)
                    self.logger.debug(
                        f"Iteration {i}: Fallback tool call results: {str(tool_results) if tool_results else 'None'}"
                    )
                    for result in tool_results:
                        if isinstance(result, BaseException):
                            self.logger.error(
                                f"Warning: Unexpected error during fallback tool execution: {result}. Continuing..."
                            )
                            span.record_exception(result)
                            continue
                        if result is not None:
                            messages.append(result)
                elif choice.finish_reason == "length":
                    # We have reached the max tokens limit
                    self.logger.debug(
                        f"Iteration {i}: Stopping because finish_reason is 'length'"
                    )
                    span.set_attribute("finish_reason", "length")
                    # TODO: saqadri - would be useful to return the reason for stopping to the caller
                    break
                elif choice.finish_reason == "content_filter":
                    # The response was filtered by the content filter
                    self.logger.debug(
                        f"Iteration {i}: Stopping because finish_reason is 'content_filter'"
                    )
                    span.set_attribute("finish_reason", "content_filter")
                    # TODO: saqadri - would be useful to return the reason for stopping to the caller
                    break
                elif choice.finish_reason == "stop":
                    self.logger.debug(
                        f"Iteration {i}: Stopping because finish_reason is 'stop'"
                    )
                    span.set_attribute("finish_reason", "stop")
                    break

            if params.use_history:
                self.history.set(messages)

            self._log_chat_finished(model=model)

            if self.context.tracing_enabled:
                span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, total_input_tokens)
                span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, total_output_tokens)
                span.set_attribute(GEN_AI_RESPONSE_FINISH_REASONS, finish_reasons)

                for i, res in enumerate(responses):
                    response_data = (
                        self.extract_response_message_attributes_for_tracing(
                            res, prefix=f"response.{i}"
                        )
                    )
                    span.set_attributes(response_data)

            # Expose aggregated usage to callers
            self.last_usage = {
                "input_tokens": int(total_input_tokens),
                "output_tokens": int(total_output_tokens),
                "total_tokens": int(total_input_tokens + total_output_tokens),
            }

            return responses

    def _extract_fallback_tool_calls_from_text(
        self, text: str
    ) -> List[ChatCompletionMessageToolCall]:
        """Parse textual tool call markers into structured tool_calls.

        Supported patterns:
          - <tool_call>{"name": "...", "arguments"|"parameters": {...}}</tool_call>
          - <|python_tag|>{...same as above...}
          - JSON code blocks: ```json {"type":"function","name":"...","parameters":{...}} ```
          - Generic JSON dict/list with recognized fields
        """
        calls: List[ChatCompletionMessageToolCall] = []
        if not text:
            return calls

        # Pattern 1: <tool_call> ... </tool_call>
        try:
            for m in re.findall(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", text, re.DOTALL):
                try:
                    obj = json.loads(m)
                except Exception:
                    continue
                calls.extend(self._tool_objs_to_calls(obj))
        except Exception:
            pass

        # Pattern 2: <|python_tag|>{...}
        if not calls and "<|python_tag|>" in text:
            try:
                idx = text.find("<|python_tag|>")
                brace = text.find("{", idx)
                if brace != -1:
                    candidate = text[brace:].strip()
                    # Try straightforward json.loads
                    try:
                        obj = json.loads(candidate)
                    except Exception:
                        # Try to cut at the last closing brace to form a JSON object
                        last = candidate.rfind("}")
                        if last != -1:
                            obj = json.loads(candidate[: last + 1])
                        else:
                            obj = None
                    if obj:
                        calls.extend(self._tool_objs_to_calls(obj))
            except Exception:
                pass

        # Pattern 3: JSON code blocks ```...```
        if not calls:
            try:
                for block in re.findall(r"```(?:json|JSON|\w+)?\s*(\{.*?\}|\[.*?\])\s*```", text, re.DOTALL):
                    try:
                        obj = json.loads(block)
                    except Exception:
                        continue
                    calls.extend(self._tool_objs_to_calls(obj))
            except Exception:
                pass

        # Pattern 4: last resort – search for first top-level JSON object in text
        if not calls:
            try:
                # Find the first '{' and try to parse a JSON object till the matching '}'
                first = text.find('{')
                if first != -1:
                    candidate = text[first:]
                    last = candidate.rfind('}')
                    if last != -1:
                        obj_str = candidate[: last + 1]
                        try:
                            obj = json.loads(obj_str)
                            calls.extend(self._tool_objs_to_calls(obj))
                        except Exception:
                            pass
            except Exception:
                pass

        return calls

    def _tool_objs_to_calls(self, obj: Any) -> List[ChatCompletionMessageToolCall]:
        """Convert recognized tool-call JSON dict/list into ChatCompletionMessageToolCall list.

        Accepts formats like:
          - {"type":"function","name":"...","arguments"|"parameters": {...}}
          - {"function": {"name": "...", "arguments": "{...}" }}
          - [ {...}, {...} ]
        """
        out: List[ChatCompletionMessageToolCall] = []
        objs = obj if isinstance(obj, list) else [obj]
        for item in objs:
            if not isinstance(item, dict):
                continue
            # OpenAI-style nested function object
            if "function" in item and isinstance(item["function"], dict):
                fn = item["function"]
                name = fn.get("name")
                args_obj = fn.get("arguments")
                if name and args_obj is not None:
                    args_str = args_obj if isinstance(args_obj, str) else json.dumps(args_obj)
                    out.append(ChatCompletionMessageToolCall(
                        id=f"call_fallback_{uuid4().hex}",
                        type="function",
                        function={"name": name, "arguments": args_str},
                    ))
                    continue
            # Flat structure with type/name/arguments or parameters
            name = item.get("name") or item.get("tool")
            type_val = item.get("type")
            args_obj = item.get("arguments")
            if args_obj is None:
                args_obj = item.get("parameters")
            if (type_val == "function" or name) and args_obj is not None:
                if not name:
                    # some formats use {"tool":"name", ...}
                    name = item.get("tool")
                if not name:
                    continue
                args_str = args_obj if isinstance(args_obj, str) else json.dumps(args_obj)
                out.append(ChatCompletionMessageToolCall(
                    id=f"call_fallback_{uuid4().hex}",
                    type="function",
                    function={"name": name, "arguments": args_str},
                ))
        return out

    async def generate_str(
        self,
        message,
        request_params: RequestParams | None = None,
    ):
        """
        Process a query using an LLM and available tools.
        The default implementation uses OpenAI's ChatCompletion as the LLM.
        Override this method to use a different LLM.
        """
        tracer = get_tracer(self.context)
        with tracer.start_as_current_span(
            f"{self.__class__.__name__}.{self.name}.generate_str"
        ) as span:
            if self.context.tracing_enabled:
                span.set_attribute(GEN_AI_AGENT_NAME, self.agent.name)
                self._annotate_span_for_generation_message(span, message)
                if request_params:
                    AugmentedLLM.annotate_span_with_request_params(span, request_params)

            responses = await self.generate(
                message=message,
                request_params=request_params,
            )

            final_text: List[str] = []

            for response in responses:
                content = response.content
                if not content:
                    continue

                if isinstance(content, str):
                    final_text.append(content)
                    continue

            res = "\n".join(final_text)
            span.set_attribute("response", res)
            return res

    async def generate_structured(
        self,
        message,
        response_model: Type[ModelT],
        request_params: RequestParams | None = None,
    ) -> ModelT:
        # First we invoke the LLM to generate a string response
        # We need to do this in a two-step process because Instructor doesn't
        # know how to invoke MCP tools via call_tool, so we'll handle all the
        # processing first and then pass the final response through Instructor
        tracer = get_tracer(self.context)
        with tracer.start_as_current_span(
            f"{self.__class__.__name__}.{self.name}.generate_structured"
        ) as span:
            span.set_attribute(GEN_AI_AGENT_NAME, self.agent.name)
            self._annotate_span_for_generation_message(span, message)

            params = self.get_request_params(request_params)

            if self.context.tracing_enabled:
                AugmentedLLM.annotate_span_with_request_params(span, params)

            response = await self.generate_str(
                message=message,
                request_params=params,
            )

            model = await self.select_model(params) or "gpt-4o"
            span.set_attribute(GEN_AI_REQUEST_MODEL, model)

            span.set_attribute("response_model", response_model.__name__)

            serialized_response_model: str | None = None

            if self.executor and self.executor.execution_engine == "temporal":
                # Serialize the response model to a string
                serialized_response_model = serialize_model(response_model)

            structured_response = await self.executor.execute(
                OpenAICompletionTasks.request_structured_completion_task,
                RequestStructuredCompletionRequest(
                    config=self.context.config.openai,
                    response_model=response_model
                    if not serialized_response_model
                    else None,
                    serialized_response_model=serialized_response_model,
                    response_str=response,
                    model=model,
                    user=params.user
                    or getattr(self.context.config.openai, "user", None),
                ),
            )
            # TODO: saqadri (MAC) - fix request_structured_completion_task to return ensure_serializable
            # Convert dict back to the proper model instance if needed
            if isinstance(structured_response, dict):
                structured_response = response_model.model_validate(structured_response)

            if self.context.tracing_enabled:
                try:
                    span.set_attribute(
                        "structured_response_json",
                        structured_response.model_dump_json(),
                    )
                # pylint: disable=broad-exception-caught
                except Exception:
                    span.set_attribute("unstructured_response", response)

            return structured_response

    async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest):
        return request

    async def post_tool_call(
        self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult
    ):
        return result

    async def execute_tool_call(
        self,
        tool_call: ChatCompletionMessageToolCall,
    ) -> ChatCompletionToolMessageParam | None:
        """
        Execute a single tool call and return the result message.
        Returns None if there's no content to add to messages.
        """
        tracer = get_tracer(self.context)
        with tracer.start_as_current_span(
            f"{self.__class__.__name__}.{self.name}.execute_tool_call"
        ) as span:
            tool_name = tool_call.function.name
            tool_args_str = tool_call.function.arguments
            tool_call_id = tool_call.id
            tool_args = {}

            if self.context.tracing_enabled:
                span.set_attribute(GEN_AI_TOOL_CALL_ID, tool_call_id)
                span.set_attribute(GEN_AI_TOOL_NAME, tool_name)
                span.set_attribute("tool_args", tool_args_str)

            try:
                if tool_args_str:
                    tool_args = json.loads(tool_args_str)
            except json.JSONDecodeError as e:
                span.record_exception(e)
                span.set_status(trace.Status(trace.StatusCode.ERROR))
                return ChatCompletionToolMessageParam(
                    role="tool",
                    tool_call_id=tool_call_id,
                    content=f"Invalid JSON provided in tool call arguments for '{tool_name}'. Failed to load JSON: {str(e)}",
                )

            tool_call_request = CallToolRequest(
                method="tools/call",
                params=CallToolRequestParams(name=tool_name, arguments=tool_args),
            )

            result = await self.call_tool(
                request=tool_call_request, tool_call_id=tool_call_id
            )

            self._annotate_span_for_call_tool_result(span, result)

            if result.content:
                return ChatCompletionToolMessageParam(
                    role="tool",
                    tool_call_id=tool_call_id,
                    content=[
                        mcp_content_to_openai_content_part(c) for c in result.content
                    ],
                )

            return None

    def message_param_str(self, message: ChatCompletionMessageParam) -> str:
        """Convert an input message to a string representation."""
        if message.get("content"):
            content = message["content"]
            if isinstance(content, str):
                return content
            else:  # content is a list
                final_text: List[str] = []
                for part in content:
                    text_part = part.get("text")
                    if text_part:
                        final_text.append(str(text_part))
                    else:
                        final_text.append(str(part))

                return "\n".join(final_text)

        return str(message)

    def message_str(
        self, message: ChatCompletionMessage, content_only: bool = False
    ) -> str:
        """Convert an output message to a string representation."""
        content = message.content
        if content:
            return content
        elif content_only:
            # If content_only is True, return empty string if no content
            return ""

        return str(message)

    def _annotate_span_for_generation_message(
        self,
        span: trace.Span,
        message: MessageTypes,
    ) -> None:
        """Annotate the span with the message content."""
        if not self.context.tracing_enabled:
            return
        if isinstance(message, str):
            span.set_attribute("message.content", message)
        elif isinstance(message, list):
            for i, msg in enumerate(message):
                if isinstance(msg, str):
                    span.set_attribute(f"message.{i}.content", msg)
                else:
                    span.set_attribute(f"message.{i}", str(msg))
        else:
            span.set_attribute("message", str(message))

    def _extract_message_param_attributes_for_tracing(
        self, message_param: ChatCompletionMessageParam, prefix: str = "message"
    ) -> dict[str, Any]:
        """Return a flat dict of span attributes for a given ChatCompletionMessageParam."""
        attrs = {}
        # TODO: rholinshead - serialize MessageParam dict
        return attrs

    def _annotate_span_for_completion_request(
        self, span: trace.Span, request: RequestCompletionRequest, turn: int
    ) -> None:
        """Annotate the span with the completion request as an event."""
        if not self.context.tracing_enabled:
            return

        event_data = {
            "completion.request.turn": turn,
            "config.reasoning_effort": request.config.reasoning_effort,
        }

        if request.config.base_url:
            event_data["config.base_url"] = request.config.base_url

        for key, value in request.payload.items():
            if key == "messages":
                for i, message in enumerate(
                    cast(List[ChatCompletionMessageParam], value)
                ):
                    role = message.get("role")
                    event_data[f"messages.{i}.role"] = role
                    message_content = message.get("content")

                    match role:
                        case "developer" | "system" | "user":
                            if isinstance(message_content, str):
                                event_data[f"messages.{i}.content"] = message_content
                            elif message_content is not None:
                                for j, part in enumerate(message_content):
                                    event_data[f"messages.{i}.content.{j}.type"] = part[
                                        "type"
                                    ]
                                    if part["type"] == "text":
                                        event_data[f"messages.{i}.content.{j}.text"] = (
                                            part["text"]
                                        )
                                    elif part["type"] == "image_url":
                                        event_data[
                                            f"messages.{i}.content.{j}.image_url.url"
                                        ] = part["image_url"]["url"]
                                        event_data[
                                            f"messages.{i}.content.{j}.image_url.detail"
                                        ] = part["image_url"]["detail"]
                                    elif part["type"] == "input_audio":
                                        event_data[
                                            f"messages.{i}.content.{j}.input_audio.format"
                                        ] = part["input_audio"]["format"]
                        case "assistant":
                            if isinstance(message_content, str):
                                event_data[f"messages.{i}.content"] = message_content
                            elif message_content is not None:
                                for j, part in enumerate(message_content):
                                    event_data[f"messages.{i}.content.{j}.type"] = part[
                                        "type"
                                    ]
                                    if part["type"] == "text":
                                        event_data[f"messages.{i}.content.{j}.text"] = (
                                            part["text"]
                                        )
                                    elif part["type"] == "refusal":
                                        event_data[
                                            f"messages.{i}.content.{j}.refusal"
                                        ] = part["refusal"]
                            if message.get("audio") is not None:
                                event_data[f"messages.{i}.audio.id"] = message.get(
                                    "audio"
                                ).get("id")
                            if message.get("function_call") is not None:
                                event_data[f"messages.{i}.function_call.name"] = (
                                    message.get("function_call").get("name")
                                )
                                event_data[f"messages.{i}.function_call.arguments"] = (
                                    message.get("function_call").get("arguments")
                                )
                            if message.get("name") is not None:
                                event_data[f"messages.{i}.name"] = message.get("name")
                            if message.get("refusal") is not None:
                                event_data[f"messages.{i}.refusal"] = message.get(
                                    "refusal"
                                )
                            if message.get("tool_calls") is not None:
                                for j, tool_call in enumerate(
                                    message.get("tool_calls")
                                ):
                                    event_data[
                                        f"messages.{i}.tool_calls.{j}.{GEN_AI_TOOL_CALL_ID}"
                                    ] = tool_call.id
                                    event_data[
                                        f"messages.{i}.tool_calls.{j}.function.name"
                                    ] = tool_call.function.name
                                    event_data[
                                        f"messages.{i}.tool_calls.{j}.function.arguments"
                                    ] = tool_call.function.arguments

                        case "tool":
                            event_data[f"messages.{i}.{GEN_AI_TOOL_CALL_ID}"] = (
                                message.get("tool_call_id")
                            )
                            if isinstance(message_content, str):
                                event_data[f"messages.{i}.content"] = message_content
                            elif message_content is not None:
                                for j, part in enumerate(message_content):
                                    event_data[f"messages.{i}.content.{j}.type"] = part[
                                        "type"
                                    ]
                                    if part["type"] == "text":
                                        event_data[f"messages.{i}.content.{j}.text"] = (
                                            part["text"]
                                        )
                        case "function":
                            event_data[f"messages.{i}.name"] = message.get("name")
                            event_data[f"messages.{i}.content"] = message_content

            elif key == "tools":
                if value is not None:
                    event_data["tools"] = [
                        tool.get("function", {}).get("name") for tool in value
                    ]
            elif is_otel_serializable(value):
                event_data[key] = value

        # Event name is based on the latest message role
        event_name = f"completion.request.{turn}"
        latest_message_role = request.payload.get("messages", [{}])[-1].get("role")

        if latest_message_role:
            event_name = f"gen_ai.{latest_message_role}.message"

        span.add_event(event_name, event_data)

    def _annotate_span_for_completion_response(
        self, span: trace.Span, response: ChatCompletion, turn: int
    ) -> None:
        """Annotate the span with the completion response as an event."""
        if not self.context.tracing_enabled:
            return

        event_data = {
            "completion.response.turn": turn,
        }

        event_data.update(
            self._extract_chat_completion_attributes_for_tracing(response)
        )

        # Event name is based on the first choice for now
        event_name = f"completion.response.{turn}"
        if response.choices and len(response.choices) > 0:
            latest_message_role = response.choices[0].message.role
            event_name = f"gen_ai.{latest_message_role}.message"

        span.add_event(event_name, event_data)

    def extract_response_message_attributes_for_tracing(
        self, message: ChatCompletionMessage, prefix: str | None = None
    ) -> Dict[str, Any]:
        """
        Extract relevant attributes from the ChatCompletionMessage for tracing.
        """
        if not self.context.tracing_enabled:
            return {}

        attr_prefix = f"{prefix}." if prefix else ""
        attrs = {
            f"{attr_prefix}role": message.role,
        }

        if message.content is not None:
            attrs[f"{attr_prefix}content"] = message.content

        if message.refusal:
            attrs[f"{attr_prefix}refusal"] = message.refusal
        if message.audio is not None:
            attrs[f"{attr_prefix}audio.id"] = message.audio.id
            attrs[f"{attr_prefix}audio.expires_at"] = message.audio.expires_at
            attrs[f"{attr_prefix}audio.transcript"] = message.audio.transcript
        if message.function_call is not None:
            attrs[f"{attr_prefix}function_call.name"] = message.function_call.name
            attrs[f"{attr_prefix}function_call.arguments"] = (
                message.function_call.arguments
            )
        if message.tool_calls:
            for j, tool_call in enumerate(message.tool_calls):
                attrs[f"{attr_prefix}tool_calls.{j}.{GEN_AI_TOOL_CALL_ID}"] = (
                    tool_call.id
                )
                attrs[f"{attr_prefix}tool_calls.{j}.function.name"] = (
                    tool_call.function.name
                )
                attrs[f"{attr_prefix}tool_calls.{j}.function.arguments"] = (
                    tool_call.function.arguments
                )

        return attrs

    def _extract_chat_completion_attributes_for_tracing(
        self, response: ChatCompletion, prefix: str | None = None
    ) -> Dict[str, Any]:
        """
        Extract relevant attributes from the ChatCompletion response for tracing.
        """
        if not self.context.tracing_enabled:
            return {}

        attr_prefix = f"{prefix}." if prefix else ""
        attrs = {
            f"{attr_prefix}id": response.id,
            f"{attr_prefix}model": response.model,
            f"{attr_prefix}object": response.object,
            f"{attr_prefix}created": response.created,
        }

        if response.service_tier:
            attrs[f"{attr_prefix}service_tier"] = response.service_tier

        if response.system_fingerprint:
            attrs[f"{attr_prefix}system_fingerprint"] = response.system_fingerprint

        if response.usage:
            attrs[f"{attr_prefix}{GEN_AI_USAGE_INPUT_TOKENS}"] = (
                response.usage.prompt_tokens
            )
            attrs[f"{attr_prefix}{GEN_AI_USAGE_OUTPUT_TOKENS}"] = (
                response.usage.completion_tokens
            )

        finish_reasons = []
        for i, choice in enumerate(response.choices):
            attrs[f"{attr_prefix}choices.{i}.index"] = choice.index
            attrs[f"{attr_prefix}choices.{i}.finish_reason"] = choice.finish_reason
            finish_reasons.append(choice.finish_reason)

            message_attrs = self.extract_response_message_attributes_for_tracing(
                choice.message, f"{attr_prefix}choices.{i}.message"
            )
            attrs.update(message_attrs)

        attrs[GEN_AI_RESPONSE_FINISH_REASONS] = finish_reasons

        return attrs


class OpenAICompletionTasks:
    @staticmethod
    @workflow_task
    @telemetry.traced()
    async def request_completion_task(
        request: RequestCompletionRequest,
    ) -> ChatCompletion:
        """
        Request a completion from OpenAI's API.
        """

        openai_client = OpenAI(
            api_key=request.config.api_key,
            base_url=request.config.base_url,
            http_client=request.config.http_client
            if hasattr(request.config, "http_client")
            else None,
            default_headers=request.config.default_headers
            if hasattr(request.config, "default_headers")
            else None,
        )

        payload = request.payload
        response = openai_client.chat.completions.create(**payload)
        response = ensure_serializable(response)
        return response

    @staticmethod
    @workflow_task
    @telemetry.traced()
    async def request_structured_completion_task(
        request: RequestStructuredCompletionRequest,
    ) -> ModelT:
        """
        Request a structured completion using Instructor's OpenAI API.
        """
        import instructor
        from instructor.exceptions import InstructorRetryException

        if request.response_model:
            response_model = request.response_model
        elif request.serialized_response_model:
            response_model = deserialize_model(request.serialized_response_model)
        else:
            raise ValueError(
                "Either response_model or serialized_response_model must be provided for structured completion."
            )

        # Next we pass the text through instructor to extract structured data
        client = instructor.from_openai(
            AsyncOpenAI(
                api_key=request.config.api_key,
                base_url=request.config.base_url,
                http_client=request.config.http_client
                if hasattr(request.config, "http_client")
                else None,
                default_headers=request.config.default_headers
                if hasattr(request.config, "default_headers")
                else None,
            ),
            mode=instructor.Mode.TOOLS_STRICT,
        )

        try:
            # Extract structured data from natural language
            structured_response = await client.chat.completions.create(
                model=request.model,
                response_model=response_model,
                messages=[
                    {"role": "user", "content": request.response_str},
                ],
                user=request.user,
            )
        except InstructorRetryException:
            # Retry the request with JSON mode
            client = instructor.from_openai(
                AsyncOpenAI(
                    api_key=request.config.api_key,
                    base_url=request.config.base_url,
                    http_client=request.config.http_client
                    if hasattr(request.config, "http_client")
                    else None,
                    default_headers=request.config.default_headers
                    if hasattr(request.config, "default_headers")
                    else None,
                ),
                mode=instructor.Mode.JSON,
            )

            structured_response = await client.chat.completions.create(
                model=request.model,
                response_model=response_model,
                messages=[
                    {"role": "user", "content": request.response_str},
                ],
                user=request.user,
            )

        return structured_response


class MCPOpenAITypeConverter(
    ProviderToMCPConverter[ChatCompletionMessageParam, ChatCompletionMessage]
):
    """
    Convert between OpenAI and MCP types.
    """

    @classmethod
    def from_mcp_message_result(cls, result: MCPMessageResult) -> ChatCompletionMessage:
        # MCPMessageResult -> ChatCompletionMessage
        if result.role != "assistant":
            raise ValueError(
                f"Expected role to be 'assistant' but got '{result.role}' instead."
            )

        return ChatCompletionMessage(
            role="assistant",
            content=result.content.text or str(result.context),
            # Lossy conversion for the following fields:
            # result.model
            # result.stopReason
        )

    @classmethod
    def to_mcp_message_result(cls, result: ChatCompletionMessage) -> MCPMessageResult:
        # ChatCompletionMessage -> MCPMessageResult
        return MCPMessageResult(
            role=result.role,
            content=TextContent(type="text", text=result.content),
            model="",
            stopReason=None,
            # extras for ChatCompletionMessage fields
            **result.model_dump(exclude={"role", "content"}),
        )

    @classmethod
    def from_mcp_message_param(
        cls, param: MCPMessageParam
    ) -> ChatCompletionMessageParam:
        # MCPMessageParam -> ChatCompletionMessageParam
        if param.role == "assistant":
            extras = param.model_dump(exclude={"role", "content"})
            return ChatCompletionAssistantMessageParam(
                role="assistant",
                content=[mcp_content_to_openai_content_part(param.content)],
                **extras,
            )
        elif param.role == "user":
            extras = param.model_dump(exclude={"role", "content"})
            return ChatCompletionUserMessageParam(
                role="user",
                content=[mcp_content_to_openai_content_part(param.content)],
                **extras,
            )
        else:
            raise ValueError(
                f"Unexpected role: {param.role}, MCP only supports 'assistant' and 'user'"
            )

    @classmethod
    def to_mcp_message_param(cls, param: ChatCompletionMessageParam) -> MCPMessageParam:
        # ChatCompletionMessage -> MCPMessageParam

        contents = openai_content_to_mcp_content(param.content)

        # TODO: saqadri - the mcp_content can have multiple elements
        # while sampling message content has a single content element
        # Right now we error out if there are > 1 elements in mcp_content
        # We need to handle this case properly going forward
        if len(contents) > 1:
            raise NotImplementedError(
                "Multiple content elements in a single message are not supported"
            )
        mcp_content: TextContent | ImageContent | EmbeddedResource = contents[0]

        if param.role == "assistant":
            return MCPMessageParam(
                role="assistant",
                content=mcp_content,
                **typed_dict_extras(param, ["role", "content"]),
            )
        elif param.role == "user":
            return MCPMessageParam(
                role="user",
                content=mcp_content,
                **typed_dict_extras(param, ["role", "content"]),
            )
        elif param.role == "tool":
            raise NotImplementedError(
                "Tool messages are not supported in SamplingMessage yet"
            )
        elif param.role == "system":
            raise NotImplementedError(
                "System messages are not supported in SamplingMessage yet"
            )
        elif param.role == "developer":
            raise NotImplementedError(
                "Developer messages are not supported in SamplingMessage yet"
            )
        elif param.role == "function":
            raise NotImplementedError(
                "Function messages are not supported in SamplingMessage yet"
            )
        else:
            raise ValueError(
                f"Unexpected role: {param.role}, MCP only supports 'assistant', 'user', 'tool', 'system', 'developer', and 'function'"
            )


def mcp_content_to_openai_content_part(
    content: TextContent | ImageContent | EmbeddedResource,
) -> ChatCompletionContentPartParam:
    if isinstance(content, TextContent):
        return ChatCompletionContentPartTextParam(type="text", text=content.text)
    elif isinstance(content, ImageContent):
        return ChatCompletionContentPartImageParam(
            type="image_url", image_url=f"data:{content.mimeType};base64,{content.data}"
        )
    elif isinstance(content, EmbeddedResource):
        if isinstance(content.resource, TextResourceContents):
            return ChatCompletionContentPartTextParam(
                type="text", text=content.resource.text
            )
        else:  # BlobResourceContents
            if content.resource.mimeType and content.resource.mimeType.startswith(
                "image/"
            ):
                return ChatCompletionContentPartImageParam(
                    type="image_url",
                    image_url=f"data:{content.resource.mimeType};base64,{content.resource.blob}",
                )
            else:
                # Best effort if mime type is unknown
                return ChatCompletionContentPartTextParam(
                    type="text",
                    text=f"{content.resource.mimeType}:{content.resource.blob}",
                )
    else:
        # Last effort to convert the content to a string
        return ChatCompletionContentPartTextParam(type="text", text=str(content))


def openai_content_to_mcp_content(
    content: str
    | Iterable[ChatCompletionContentPartParam | ChatCompletionContentPartRefusalParam],
) -> Iterable[TextContent | ImageContent | EmbeddedResource]:
    mcp_content = []

    if isinstance(content, str):
        mcp_content = [TextContent(type="text", text=content)]
    else:
        # TODO: saqadri - this is a best effort conversion, we should handle all possible content types
        for c in content:
            if (
                c["type"] == "text"
            ):  # isinstance(c, ChatCompletionContentPartTextParam):
                mcp_content.append(
                    TextContent(
                        type="text", text=c["text"], **typed_dict_extras(c, ["text"])
                    )
                )
            elif (
                c["type"] == "image_url"
            ):  # isinstance(c, ChatCompletionContentPartImageParam):
                if c["image_url"].startswith("data:"):
                    mime_type, base64_data = image_url_to_mime_and_base64(
                        c["image_url"]
                    )
                    mcp_content.append(
                        ImageContent(type="image", data=base64_data, mimeType=mime_type)
                    )
                else:
                    # TODO: saqadri - need to download the image into a base64-encoded string
                    raise NotImplementedError(
                        "Image content conversion not implemented"
                    )
            elif (
                c["type"] == "input_audio"
            ):  # isinstance(c, ChatCompletionContentPartInputAudioParam):
                raise NotImplementedError("Audio content conversion not implemented")
            elif (
                c["type"] == "refusal"
            ):  # isinstance(c, ChatCompletionContentPartRefusalParam):
                mcp_content.append(
                    TextContent(
                        type="text",
                        text=c["refusal"],
                        **typed_dict_extras(c, ["refusal"]),
                    )
                )
            else:
                raise ValueError(f"Unexpected content type: {c['type']}")

    return mcp_content
