"""
Streaming OpenAI-compatible provider for Inspect.

Motivation: Inspect's default OpenAI-compatible provider issues a single
non-streaming request and waits for the complete response. For backends that
support chunked streaming (e.g., vLLM, OpenAI), enabling streaming can reduce
latency to first token and keep long generations responsive. This opt-in
subclass sets ``stream=True`` and reassembles the streamed deltas into a
``ChatCompletion`` so the rest of Inspect remains unchanged.
"""

from typing import Any

from openai import BadRequestError, PermissionDeniedError, UnprocessableEntityError
from openai._types import NOT_GIVEN
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage
from typing_extensions import override

from inspect_ai.model import (
    ChatMessage,
    GenerateConfig,
    ModelCall,
    ModelOutput,
    modelapi,
)
from inspect_ai.model._openai import (
    model_output_from_openai,
    openai_chat_messages,
    openai_chat_tool_choice,
    openai_chat_tools,
    openai_media_filter,
)
from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model._providers.util.hooks import HttpxHooks
from inspect_ai.tool import ToolChoice, ToolInfo


class StreamOpenAICompatibleAPI(OpenAICompatibleAPI):
    """OpenAI-compatible provider that aggregates streamed chat deltas.

    The implementation mirrors the base class flow but flips on streaming and
    folds the streamed chunks back into a single ``ChatCompletion``:

    - content and refusal text are concatenated per choice index;
    - tool-call deltas are merged by their ``index`` (the id and function name
      normally arrive on the first fragment, the JSON arguments arrive in
      pieces across many fragments);
    - token logprobs are collected per choice;
    - id/model/created/usage/system_fingerprint are taken from the most recent
      chunk that actually carried a value.

    Construction is inherited unchanged from ``OpenAICompatibleAPI``.
    """

    @override
    async def generate(
        self,
        input: list[ChatMessage],
        tools: list[ToolInfo],
        tool_choice: ToolChoice,
        config: GenerateConfig,
    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
        """Generate a response using streaming and return standard Inspect types.

        Args:
            input: Conversation history to send to the model.
            tools: Tools the model may call.
            tool_choice: Directive controlling which tool (if any) to call.
            config: Generation options for this request.

        Returns:
            A ``(ModelOutput, ModelCall)`` tuple on success. When the backend
            rejects the request (bad request / unprocessable entity /
            permission denied), the first element is whatever
            ``handle_bad_request`` returns for that error.

        Rationale for key choices:
        - We allocate a request id early so it is sent as a request header and
          used to time the call recorded in ``ModelCall``.
        - We set ``stream=True`` and then aggregate chunks into one completion
          to preserve downstream semantics (Inspect expects a full object).
        - We ask for ``stream_options.include_usage`` (unless the completion
          params already set ``stream_options``) because OpenAI-style servers
          omit token usage from streams otherwise.
        """
        # allocate request_id (so we can see it from ModelCall)
        request_id = self._http_hooks.start_request()

        # setup request and response for ModelCall; both names are rebound
        # below and the closure reads whatever they hold when it is called
        request: dict[str, Any] = {}
        response: dict[str, Any] = {}

        def model_call() -> ModelCall:
            return ModelCall.create(
                request=request,
                response=response,
                filter=openai_media_filter,
                time=self._http_hooks.end_request(request_id),
            )

        tools, tool_choice, config = self.resolve_tools(tools, tool_choice, config)
        have_tools = len(tools) > 0

        # get completion params
        completion_params = self.completion_params(
            config=config,
            tools=have_tools,
        )

        request = dict(
            messages=await openai_chat_messages(input),
            tools=openai_chat_tools(tools) if have_tools else NOT_GIVEN,
            tool_choice=openai_chat_tool_choice(tool_choice)
            if have_tools
            else NOT_GIVEN,
            stream=True,
            extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
            **completion_params,
        )
        # Without this, streamed responses carry no usage and every call would
        # be reported with zero tokens. Respect an explicit caller setting.
        request.setdefault("stream_options", {"include_usage": True})

        try:
            stream = await self.client.chat.completions.create(**request)
            completion = await self._collect_stream(
                stream,
                fallback_model=completion_params.get("model", "unknown"),
            )

            response = completion.model_dump()
            self.on_response(response)

            # return output and call
            choices = self.chat_choices_from_completion(completion, tools)
            return model_output_from_openai(completion, choices), model_call()

        except (BadRequestError, UnprocessableEntityError, PermissionDeniedError) as ex:
            return self.handle_bad_request(ex), model_call()

    async def _collect_stream(self, stream: Any, fallback_model: str) -> ChatCompletion:
        """Consume ``stream`` and fold its chunks into one ``ChatCompletion``.

        Args:
            stream: Async iterable of chat completion chunks.
            fallback_model: Model name to report if no chunk carried one.

        Returns:
            A completion with one choice per streamed choice index (or a
            single empty choice at index 0 if the stream had no choices).
        """
        states: dict[int, dict[str, Any]] = {}
        completion_id: str | None = None
        created: int | None = None
        model: str | None = None
        usage: CompletionUsage | None = None
        system_fingerprint: str | None = None

        async for chunk in stream:
            # Keep the latest non-empty metadata; the usage-only final chunk
            # of some backends has empty choices and sparse metadata.
            completion_id = chunk.id or completion_id
            created = chunk.created or created
            model = chunk.model or model
            chunk_usage = getattr(chunk, "usage", None)
            if chunk_usage is not None:
                usage = chunk_usage
            chunk_fingerprint = getattr(chunk, "system_fingerprint", None)
            if chunk_fingerprint is not None:
                system_fingerprint = chunk_fingerprint

            for choice in chunk.choices or []:
                index = choice.index or 0
                if index not in states:
                    states[index] = self._new_choice_state()
                state = states[index]

                delta = choice.delta
                if delta is not None:
                    if delta.content:
                        state["content"].append(delta.content)
                    refusal = getattr(delta, "refusal", None)
                    if refusal:
                        state["refusal"].append(refusal)
                    if delta.tool_calls:
                        self._merge_tool_call_deltas(
                            state["tool_calls"], delta.tool_calls
                        )

                token_logprobs = getattr(
                    getattr(choice, "logprobs", None), "content", None
                )
                if token_logprobs:
                    state["logprobs"].extend(token_logprobs)

                if choice.finish_reason:
                    state["finish_reason"] = choice.finish_reason

        # Always emit at least one choice so downstream code has a message.
        if not states:
            states[0] = self._new_choice_state()

        return ChatCompletion(
            id=completion_id or "stream-completion",
            object="chat.completion",
            created=created or 0,
            model=model or fallback_model,
            choices=[
                self._build_choice(index, state)
                for index, state in sorted(states.items())
            ],
            usage=usage
            if usage is not None
            else CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            system_fingerprint=system_fingerprint,
        )

    @staticmethod
    def _new_choice_state() -> dict[str, Any]:
        """Return an empty accumulator for one streamed choice."""
        return {
            "content": [],
            "refusal": [],
            "tool_calls": {},
            "logprobs": [],
            "finish_reason": None,
        }

    @staticmethod
    def _merge_tool_call_deltas(merged: dict[int, dict[str, Any]], deltas: Any) -> None:
        """Merge streamed tool-call fragments into ``merged``, keyed by index.

        String fragments (function name, JSON arguments) are appended in
        arrival order; the first non-empty id seen for an index is kept.
        """
        for delta in deltas:
            index = delta.index if delta.index is not None else 0
            entry = merged.setdefault(index, {"id": None, "name": [], "arguments": []})
            if delta.id and not entry["id"]:
                entry["id"] = delta.id
            function = delta.function
            if function is not None:
                if function.name:
                    entry["name"].append(function.name)
                if function.arguments:
                    entry["arguments"].append(function.arguments)

    @staticmethod
    def _build_choice(index: int, state: dict[str, Any]) -> Choice:
        """Turn one choice accumulator into a completed ``Choice``."""
        tool_calls = [
            {
                # Synthesize an id if the backend never sent one; the
                # completed tool-call type requires a string id.
                "id": entry["id"] or f"call_{call_index}",
                "type": "function",
                "function": {
                    "name": "".join(entry["name"]),
                    "arguments": "".join(entry["arguments"]),
                },
            }
            for call_index, entry in sorted(state["tool_calls"].items())
        ]
        # Validate from plain data so pydantic builds the concrete tool-call
        # models for us (deltas are a different type than completed calls).
        message = ChatCompletionMessage.model_validate(
            {
                "role": "assistant",
                "content": "".join(state["content"]),
                "refusal": "".join(state["refusal"]) or None,
                "tool_calls": tool_calls or None,
            }
        )
        return Choice.model_validate(
            {
                "index": index,
                "message": message,
                "finish_reason": state["finish_reason"] or "stop",
                "logprobs": {"content": state["logprobs"]}
                if state["logprobs"]
                else None,
            }
        )


# Register the streaming provider
@modelapi(name="stream-openai-api")
def stream_openai_api():
    """
    Return the provider class registered under the "stream-openai-api" name.

    Model names for this provider look like "stream-openai-api/service/model",
    where 'service' determines the API key and base URL environment variables.
    """
    provider_class = StreamOpenAICompatibleAPI
    return provider_class
