from __future__ import annotations

import argparse
import json
import os
import time
import traceback
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Any, Dict, List, Tuple
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen


def _json_dumps(obj: Any) -> bytes:
    """Serialize *obj* to JSON and return it UTF-8 encoded.

    Non-ASCII characters are written literally rather than as ``\\uXXXX`` escapes.
    """
    text = json.dumps(obj, ensure_ascii=False)
    return text.encode("utf-8")


def _read_json_body(handler: BaseHTTPRequestHandler) -> dict[str, Any]:
    """Read the request body from *handler* and decode it as a JSON object.

    A missing or empty Content-Length header, or an empty body, yields ``{}``.

    Raises:
        ValueError: if Content-Length is not a non-negative integer, if the body
            is not valid UTF-8 JSON (``UnicodeDecodeError`` and
            ``json.JSONDecodeError`` are ``ValueError`` subclasses), or if the
            decoded JSON value is not an object.
    """
    raw_length = handler.headers.get("content-length") or "0"
    try:
        length = int(raw_length)
    except ValueError:
        raise ValueError(f"Invalid Content-Length header: {raw_length!r}") from None
    if length < 0:
        # rfile.read(-1) would read until EOF, i.e. wait for the client to close.
        raise ValueError(f"Invalid Content-Length header: {raw_length!r}")

    raw = handler.rfile.read(length) if length else b""
    if not raw:
        return {}
    parsed = json.loads(raw.decode("utf-8"))
    if not isinstance(parsed, dict):
        # Every caller treats the body as a mapping (payload.get(...)).
        raise ValueError("Request body must be a JSON object")
    return parsed


def _as_text(content: Any) -> str:
    """Flatten OpenAI-style message content into a plain string.

    - ``None`` becomes ``""``.
    - A string is returned unchanged.
    - A list is reduced to its string items plus the ``text`` of
      ``text``/``input_text``/``output_text`` parts. Empty pieces are dropped
      and the rest are joined with newlines.
    - Any other value is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    text_part_types = {"text", "input_text", "output_text"}
    pieces: list[str] = []
    for item in content:
        if isinstance(item, str):
            pieces.append(item)
            continue
        if not isinstance(item, dict) or item.get("type") not in text_part_types:
            continue
        candidate = item.get("text")
        if isinstance(candidate, str):
            pieces.append(candidate)
    return "\n".join(filter(None, pieces))


def _infer_provider(model: str) -> str:
    """Choose the upstream provider for *model*.

    A non-blank ``ROUTER_PROVIDER`` environment variable (trimmed and
    lower-cased) always wins. Otherwise model names starting with ``gemini``
    (case-insensitive) go to ``"gemini"``; everything else, including
    ``claude*`` and unrecognized names, goes to ``"anthropic"``.
    """
    override = (os.getenv("ROUTER_PROVIDER") or "").strip().lower()
    if override:
        return override
    normalized = (model or "").lower()
    return "gemini" if normalized.startswith("gemini") else "anthropic"


def _openai_tools_to_anthropic(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
    """Convert OpenAI function tools into Anthropic ``tools`` entries.

    Both OpenAI shapes are accepted:

    - Chat Completions: ``{"type": "function", "function": {"name": ..., "parameters": ...}}``
    - Responses API (flat): ``{"type": "function", "name": ..., "parameters": ...}``

    Non-dict entries, non-function tools and tools without a name are skipped.

    Returns:
        A list of ``{"name", "description", "input_schema"}`` dicts, or ``None``
        when nothing convertible remains. A missing schema defaults to an empty
        object schema.
    """
    if not tools:
        return None
    converted: list[dict[str, Any]] = []
    for tool in tools:
        if not isinstance(tool, dict) or tool.get("type") != "function":
            continue
        # Chat Completions nests the definition under "function"; the Responses API does not.
        fn = tool.get("function") or tool
        name = fn.get("name")
        if not name:
            continue
        converted.append(
            {
                "name": name,
                "description": fn.get("description") or "",
                "input_schema": fn.get("parameters") or {"type": "object", "properties": {}},
            }
        )
    return converted or None


def _openai_tools_to_gemini(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
    """Convert OpenAI function tools into a Gemini ``tools`` list.

    Both the nested Chat Completions shape (definition under ``"function"``)
    and the flat Responses API shape are accepted. Non-dict entries,
    non-function tools and nameless tools are skipped.

    Returns:
        ``[{"functionDeclarations": [...]}]`` grouping every declaration, or
        ``None`` when nothing convertible remains. A missing schema defaults to
        an empty object schema.
    """
    if not tools:
        return None
    decls: list[dict[str, Any]] = []
    for tool in tools:
        if not isinstance(tool, dict) or tool.get("type") != "function":
            continue
        # Chat Completions nests the definition under "function"; the Responses API does not.
        fn = tool.get("function") or tool
        name = fn.get("name")
        if not name:
            continue
        decls.append(
            {
                "name": name,
                "description": fn.get("description") or "",
                "parameters": fn.get("parameters") or {"type": "object", "properties": {}},
            }
        )
    if not decls:
        return None
    return [{"functionDeclarations": decls}]


def _openai_messages_to_anthropic(messages: list[dict[str, Any]]) -> tuple[str | None, list[dict[str, Any]]]:
    """Translate OpenAI chat messages into Anthropic Messages API input.

    - ``system`` and ``developer`` messages are gathered into the top-level
      ``system`` string, joined with blank lines. It is ``None`` when empty.
    - ``user``/``assistant`` messages become content-block lists. Assistant
      ``tool_calls`` become ``tool_use`` blocks whose ``input`` is the parsed
      JSON arguments. Unparseable or non-object arguments are wrapped as
      ``{"_raw": ...}``.
    - ``tool`` messages become ``tool_result`` blocks in a ``user`` turn.
      Consecutive tool results are grouped into one turn, so parallel tool
      calls are answered together.
    - Messages that end up with no content blocks are dropped instead of being
      sent as an empty text block, which Anthropic rejects.
    - Other roles and non-dict entries are ignored.

    Returns:
        ``(system, messages)`` for the Anthropic request body.
    """
    system_parts: list[str] = []
    anthropic: list[dict[str, Any]] = []

    for msg in messages:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role")
        content = msg.get("content")

        # "developer" is OpenAI's newer name for system-level instructions.
        if role in {"system", "developer"}:
            system_parts.append(_as_text(content))
            continue

        if role in {"user", "assistant"}:
            blocks: list[dict[str, Any]] = []
            text = _as_text(content)
            if text:
                blocks.append({"type": "text", "text": text})

            if role == "assistant":
                for index, call in enumerate(msg.get("tool_calls") or []):
                    if not isinstance(call, dict):
                        continue
                    fn = call.get("function") or {}
                    args_raw = fn.get("arguments") or "{}"
                    try:
                        args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
                    except json.JSONDecodeError:
                        args = {"_raw": args_raw}
                    if not isinstance(args, dict):
                        # tool_use.input must be a JSON object.
                        args = {"_raw": args_raw}
                    blocks.append(
                        {
                            "type": "tool_use",
                            # The index suffix keeps fallback ids unique within one millisecond.
                            "id": call.get("id") or f"toolu_{int(time.time()*1000)}_{index}",
                            "name": fn.get("name") or "unknown",
                            "input": args,
                        }
                    )
            if blocks:
                anthropic.append({"role": role, "content": blocks})
            continue

        if role == "tool":
            result = {
                "type": "tool_result",
                "tool_use_id": msg.get("tool_call_id") or "unknown_tool_call",
                "content": _as_text(content),
            }
            previous = anthropic[-1] if anthropic else None
            if (
                previous is not None
                and previous["role"] == "user"
                and all(block.get("type") == "tool_result" for block in previous["content"])
            ):
                # Extend the open tool-result turn instead of starting a new one.
                previous["content"].append(result)
            else:
                anthropic.append({"role": "user", "content": [result]})
            continue

    system = "\n\n".join(s for s in system_parts if s).strip() or None
    return system, anthropic


def _anthropic_to_openai_message(anthropic_resp: dict[str, Any]) -> dict[str, Any]:
    """Convert an Anthropic Messages API response into an OpenAI assistant message.

    Non-empty ``text`` blocks are newline-joined into ``content``, which is
    ``""`` when there are none. Each ``tool_use`` block becomes an OpenAI
    function ``tool_call`` whose ``arguments`` is the JSON-encoded ``input``.
    A millisecond-based ``toolu_`` id is used when the block has no id. The
    ``tool_calls`` key is present only if at least one call exists.
    """
    texts: list[str] = []
    calls: list[dict[str, Any]] = []

    for block in anthropic_resp.get("content") or []:
        kind = block.get("type")
        if kind == "text":
            value = block.get("text")
            if isinstance(value, str) and value:
                texts.append(value)
        elif kind == "tool_use":
            encoded_input = json.dumps(block.get("input") or {}, ensure_ascii=False)
            calls.append(
                {
                    "id": block.get("id") or f"toolu_{int(time.time()*1000)}",
                    "type": "function",
                    "function": {"name": block.get("name") or "unknown", "arguments": encoded_input},
                }
            )

    message: dict[str, Any] = {"role": "assistant", "content": "\n".join(texts)}
    if calls:
        message["tool_calls"] = calls
    return message


def _openai_messages_to_gemini(messages: list[dict[str, Any]]) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
    """Translate OpenAI chat messages into Gemini ``systemInstruction`` and ``contents``.

    - ``system`` and ``developer`` messages are joined with blank lines into
      ``{"parts": [{"text": ...}]}``, or ``None`` when there is no system text.
    - ``user`` maps to role ``user`` and ``assistant`` to role ``model``.
      Assistant ``tool_calls`` become ``functionCall`` parts with parsed JSON
      ``args``. Unparseable or non-object arguments are wrapped as
      ``{"_raw": ...}``.
    - ``tool`` messages become ``functionResponse`` parts in a ``user`` turn.
      Consecutive tool results share one turn, so N parallel ``functionCall``
      parts are answered by a single turn holding N responses.
    - Messages with no parts are dropped rather than sent as an empty text part.
    - Other roles and non-dict entries are ignored.
    """
    system_parts: list[str] = []
    contents: list[dict[str, Any]] = []
    # functionResponse parts need the function *name*, but OpenAI tool messages
    # only carry the call id; remember id -> name from earlier assistant turns.
    tool_id_to_name: dict[str, str] = {}

    for msg in messages:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role")
        content = msg.get("content")

        # "developer" is OpenAI's newer name for system-level instructions.
        if role in {"system", "developer"}:
            system_parts.append(_as_text(content))
            continue

        if role in {"user", "assistant"}:
            parts: list[dict[str, Any]] = []
            text = _as_text(content)
            if text:
                parts.append({"text": text})
            if role == "assistant":
                for call in msg.get("tool_calls") or []:
                    if not isinstance(call, dict):
                        continue
                    fn = call.get("function") or {}
                    tool_id = call.get("id")
                    if tool_id and fn.get("name"):
                        tool_id_to_name[tool_id] = fn["name"]
                    args_raw = fn.get("arguments") or "{}"
                    try:
                        args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw
                    except json.JSONDecodeError:
                        args = {"_raw": args_raw}
                    if not isinstance(args, dict):
                        # functionCall.args must be a JSON object.
                        args = {"_raw": args_raw}
                    parts.append({"functionCall": {"name": fn.get("name") or "unknown", "args": args}})
            if parts:
                contents.append({"role": "user" if role == "user" else "model", "parts": parts})
            continue

        if role == "tool":
            tool_call_id = msg.get("tool_call_id") or "unknown_tool_call"
            response_part = {
                "functionResponse": {
                    "name": tool_id_to_name.get(tool_call_id, "unknown"),
                    "response": {"content": _as_text(content)},
                }
            }
            previous = contents[-1] if contents else None
            if (
                previous is not None
                and previous["role"] == "user"
                and all("functionResponse" in part for part in previous["parts"])
            ):
                # Extend the open function-response turn instead of starting a new one.
                previous["parts"].append(response_part)
            else:
                contents.append({"role": "user", "parts": [response_part]})
            continue

    system_text = "\n\n".join(s for s in system_parts if s).strip()
    system = {"parts": [{"text": system_text}]} if system_text else None
    return system, contents


def _gemini_to_openai_message(gemini_resp: dict[str, Any]) -> dict[str, Any]:
    """Convert a Gemini ``generateContent`` response into an OpenAI assistant message.

    Only the first candidate is read. Non-empty ``text`` parts are
    newline-joined into ``content``. Every ``functionCall`` part becomes a
    ``tool_calls`` entry with a synthesized ``call_<ms>_<n>`` id and
    JSON-encoded ``args``. With no candidates, an empty assistant message is
    returned.
    """
    candidates = gemini_resp.get("candidates") or []
    if not candidates:
        return {"role": "assistant", "content": ""}

    first_content = candidates[0].get("content") or {}
    texts: list[str] = []
    calls: list[dict[str, Any]] = []
    for part in first_content.get("parts") or []:
        if "text" in part:
            value = part["text"]
            if isinstance(value, str) and value:
                texts.append(value)
        if "functionCall" not in part:
            continue
        call = part["functionCall"]
        if not isinstance(call, dict):
            continue
        calls.append(
            {
                "id": f"call_{int(time.time()*1000)}_{len(calls)}",
                "type": "function",
                "function": {
                    "name": call.get("name") or "unknown",
                    "arguments": json.dumps(call.get("args") or {}, ensure_ascii=False),
                },
            }
        )

    message: dict[str, Any] = {"role": "assistant", "content": "\n".join(texts)}
    if calls:
        message["tool_calls"] = calls
    return message


def _http_json(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    timeout: float = 300,
) -> tuple[int, dict[str, Any], dict[str, str]]:
    """POST *payload* as JSON to *url* and return ``(status, parsed_body, headers)``.

    Transport problems are reported through the status instead of raising:

    - An HTTP error response keeps its status code. A non-JSON error body is
      wrapped as ``{"error": <text>}``.
    - A successful response whose body is not valid JSON yields 502
      (Bad Gateway) with the body text, truncated, under ``"error"``.
    - Connection failures, timeouts, resets and malformed HTTP exchanges
      yield 599 with ``{"error": <description>}`` and empty headers.

    Args:
        url: Target URL.
        headers: Request headers.
        payload: JSON-serializable request body.
        timeout: Socket timeout in seconds (default 300).
    """
    from http.client import HTTPException

    req = Request(url, data=_json_dumps(payload), headers=headers, method="POST")
    try:
        with urlopen(req, timeout=timeout) as resp:
            status = resp.status
            resp_headers = dict(resp.headers.items())
            body = resp.read().decode("utf-8", errors="replace")
    except HTTPError as e:
        error_body = e.read().decode("utf-8", errors="replace")
        try:
            parsed = json.loads(error_body)
        except json.JSONDecodeError:
            parsed = {"error": error_body}
        error_headers = dict(e.headers.items()) if e.headers is not None else {}
        return e.code, parsed, error_headers
    except (URLError, OSError, HTTPException) as e:
        # URLError: DNS/connect failures. OSError also covers read timeouts and
        # connection resets. HTTPException covers e.g. IncompleteRead.
        return 599, {"error": str(e)}, {}

    try:
        return status, json.loads(body), resp_headers
    except json.JSONDecodeError:
        return (
            HTTPStatus.BAD_GATEWAY,
            {"error": f"Upstream returned a non-JSON response: {body[:2000]}"},
            resp_headers,
        )


def _anthropic_finish_reason(stop_reason: Any) -> str:
    """Map an Anthropic ``stop_reason`` to an OpenAI ``finish_reason``.

    Unknown or missing reasons fall back to ``"stop"``, so clients always see a
    value from the OpenAI vocabulary.
    """
    mapping = {
        "end_turn": "stop",
        "stop_sequence": "stop",
        "max_tokens": "length",
        "tool_use": "tool_calls",
        "refusal": "content_filter",
    }
    return mapping.get(stop_reason, "stop") if isinstance(stop_reason, str) else "stop"


def _gemini_finish_reason(resp_json: dict[str, Any], msg: dict[str, Any]) -> str:
    """Derive an OpenAI ``finish_reason`` for a Gemini reply.

    Gemini reports ``STOP`` even when it returned function calls, so tool calls
    in the converted message *msg* take precedence. ``MAX_TOKENS`` maps to
    ``"length"``; safety/recitation-type reasons, or a blocked prompt with no
    candidates, map to ``"content_filter"``.
    """
    if msg.get("tool_calls"):
        return "tool_calls"
    candidates = resp_json.get("candidates") or []
    if not candidates:
        blocked = (resp_json.get("promptFeedback") or {}).get("blockReason")
        return "content_filter" if blocked else "stop"
    reason = candidates[0].get("finishReason")
    if reason == "MAX_TOKENS":
        return "length"
    if reason in {"SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII"}:
        return "content_filter"
    return "stop"


def _openai_usage(prompt_tokens: Any, completion_tokens: Any) -> dict[str, int] | None:
    """Build an OpenAI ``usage`` object, or return ``None`` unless both counts are ints."""
    if not isinstance(prompt_tokens, int) or not isinstance(completion_tokens, int):
        return None
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }


def _chat_completion(
    model: str, msg: dict[str, Any], finish_reason: str, usage: dict[str, int] | None
) -> dict[str, Any]:
    """Wrap one assistant message in an OpenAI ``chat.completion`` object."""
    completion: dict[str, Any] = {
        "id": f"chatcmpl_router_{int(time.time()*1000)}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "message": msg, "finish_reason": finish_reason}],
    }
    if usage:
        completion["usage"] = usage
    return completion


def _chat_via_anthropic(
    model: str, messages: list[dict[str, Any]], tools: Any, temperature: Any, max_tokens: int
) -> tuple[int, dict[str, Any]]:
    """Run a chat completion against the Anthropic Messages API."""
    key = os.getenv("ANTHROPIC_API_KEY")
    if not key:
        return HTTPStatus.BAD_REQUEST, {"error": {"message": "Missing ANTHROPIC_API_KEY"}}

    system, anthropic_messages = _openai_messages_to_anthropic(messages)
    body: dict[str, Any] = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": anthropic_messages,
    }
    if system:
        body["system"] = system
    if temperature is not None:
        body["temperature"] = temperature
    anth_tools = _openai_tools_to_anthropic(tools)
    if anth_tools:
        body["tools"] = anth_tools

    status, resp_json, _ = _http_json(
        "https://api.anthropic.com/v1/messages",
        headers={
            "content-type": "application/json",
            "x-api-key": key,
            "anthropic-version": os.getenv("ANTHROPIC_VERSION", "2023-06-01"),
        },
        payload=body,
    )
    if status >= 400:
        return status, {"error": resp_json}

    msg = _anthropic_to_openai_message(resp_json)
    usage_raw = resp_json.get("usage") or {}
    prompt = usage_raw.get("input_tokens")
    if isinstance(prompt, int):
        # Anthropic counts cache reads/writes separately from input_tokens;
        # OpenAI's prompt_tokens includes every prompt token.
        for extra_key in ("cache_creation_input_tokens", "cache_read_input_tokens"):
            extra = usage_raw.get(extra_key)
            if isinstance(extra, int):
                prompt += extra
    usage = _openai_usage(prompt, usage_raw.get("output_tokens"))
    return HTTPStatus.OK, _chat_completion(
        model, msg, _anthropic_finish_reason(resp_json.get("stop_reason")), usage
    )


def _chat_via_gemini(
    model: str, messages: list[dict[str, Any]], tools: Any, temperature: Any, max_tokens: int
) -> tuple[int, dict[str, Any]]:
    """Run a chat completion against the Gemini ``generateContent`` REST endpoint."""
    from urllib.parse import quote

    key = os.getenv("GOOGLE_API_KEY")
    if not key:
        return HTTPStatus.BAD_REQUEST, {"error": {"message": "Missing GOOGLE_API_KEY"}}

    system, contents = _openai_messages_to_gemini(messages)
    gen_cfg: dict[str, Any] = {}
    if temperature is not None:
        gen_cfg["temperature"] = temperature
    gen_cfg["maxOutputTokens"] = max_tokens
    body: dict[str, Any] = {"contents": contents, "generationConfig": gen_cfg}
    if system:
        body["systemInstruction"] = system
    gem_tools = _openai_tools_to_gemini(tools)
    if gem_tools:
        body["tools"] = gem_tools

    # The model is percent-encoded so it stays a single path segment. The key
    # travels in a header so the secret is not part of the request URL.
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{quote(model, safe='')}:generateContent"
    status, resp_json, _ = _http_json(
        url,
        headers={"content-type": "application/json", "x-goog-api-key": key},
        payload=body,
    )
    if status >= 400:
        return status, {"error": resp_json}

    msg = _gemini_to_openai_message(resp_json)
    meta = resp_json.get("usageMetadata") or {}
    completion = meta.get("candidatesTokenCount")
    thoughts = meta.get("thoughtsTokenCount")
    if isinstance(completion, int) and isinstance(thoughts, int):
        # OpenAI's completion_tokens includes reasoning tokens.
        completion += thoughts
    usage = _openai_usage(meta.get("promptTokenCount"), completion)
    return HTTPStatus.OK, _chat_completion(model, msg, _gemini_finish_reason(resp_json, msg), usage)


def _handle_chat_completions(payload: dict[str, Any]) -> tuple[int, dict[str, Any]]:
    """Serve an OpenAI Chat Completions request via Anthropic or Gemini.

    ``model`` falls back to ``OPENAI_MODEL`` and the provider comes from
    ``_infer_provider``. The token limit is the first truthy value of
    ``max_tokens``, ``max_completion_tokens`` or ``max_output_tokens``,
    defaulting to 2048.

    Returns:
        ``(status, body)``. On success, ``body`` is an OpenAI ``chat.completion``
        object, including ``usage`` when the upstream reports token counts.
        Otherwise it is ``{"error": ...}``, with 400 for missing configuration
        or invalid input, or the upstream's status for upstream failures.
    """
    model = payload.get("model") or os.getenv("OPENAI_MODEL") or ""
    if not model:
        return HTTPStatus.BAD_REQUEST, {"error": {"message": "Missing model (set 'model' or OPENAI_MODEL)"}}
    provider = _infer_provider(model)
    messages = payload.get("messages") or []
    tools = payload.get("tools")
    temperature = payload.get("temperature")
    max_tokens_raw = (
        payload.get("max_tokens")
        or payload.get("max_completion_tokens")
        or payload.get("max_output_tokens")
        or 2048
    )
    try:
        max_tokens = int(max_tokens_raw)
    except (TypeError, ValueError):
        return HTTPStatus.BAD_REQUEST, {"error": {"message": f"Invalid max_tokens: {max_tokens_raw!r}"}}

    if provider == "anthropic":
        return _chat_via_anthropic(model, messages, tools, temperature, max_tokens)
    if provider == "gemini":
        return _chat_via_gemini(model, messages, tools, temperature, max_tokens)
    return HTTPStatus.BAD_REQUEST, {"error": {"message": f"Unsupported provider: {provider}"}}


def _responses_input_to_chat_messages(
    payload: dict[str, Any],
) -> list[dict[str, Any]]:
    """Convert a Responses API request's ``input`` into Chat Completions messages.

    A non-empty ``messages`` value in *payload* takes precedence and is returned
    as-is if it is a list (``[]`` otherwise). Otherwise ``input`` is handled as
    follows:

    - ``None`` -> no messages; a string -> one user message; any other non-list
      -> one user message holding its text form.
    - A list of items, each converted as follows:

      * an item with ``role`` and ``content`` -> a message of that role
        (content kept unchanged);
      * ``type == "function_call"`` -> a ``tool_calls`` entry on an assistant
        message. It is appended to the preceding assistant message when there
        is one, so parallel calls share a turn;
      * ``type == "function_call_output"`` -> a ``tool`` message;
      * ``type == "message"`` -> a message with flattened text content;
      * ``type == "reasoning"`` -> dropped, since it has no chat equivalent;
      * anything else -> a user message holding its text form.
    """
    if payload.get("messages"):
        messages = payload["messages"]
        return messages if isinstance(messages, list) else []

    input_obj = payload.get("input")
    if input_obj is None:
        return []
    if isinstance(input_obj, str):
        return [{"role": "user", "content": input_obj}]
    if not isinstance(input_obj, list):
        return [{"role": "user", "content": _as_text(input_obj)}]

    out: list[dict[str, Any]] = []
    for item in input_obj:
        if isinstance(item, dict) and "role" in item and "content" in item:
            out.append({"role": item["role"], "content": item["content"]})
            continue

        item_type = item.get("type") if isinstance(item, dict) else None

        if item_type == "function_call":
            # A tool call the model made in an earlier turn. The following
            # function_call_output refers to it by call_id, so it must be kept
            # as an assistant tool call, not flattened into user text.
            arguments = item.get("arguments")
            if not isinstance(arguments, str):
                arguments = json.dumps(arguments if arguments is not None else {}, ensure_ascii=False)
            call = {
                "id": item.get("call_id") or "unknown_tool_call",
                "type": "function",
                "function": {"name": item.get("name") or "unknown", "arguments": arguments},
            }
            previous = out[-1] if out else None
            if previous is not None and previous.get("role") == "assistant":
                previous.setdefault("tool_calls", []).append(call)
            else:
                out.append({"role": "assistant", "content": None, "tool_calls": [call]})
            continue

        if item_type == "function_call_output":
            call_id = item.get("call_id") or item.get("tool_call_id") or "unknown_tool_call"
            out.append({"role": "tool", "tool_call_id": call_id, "content": _as_text(item.get("output"))})
            continue

        if item_type == "message":
            out.append({"role": item.get("role") or "user", "content": _as_text(item.get("content"))})
            continue

        if item_type == "reasoning":
            continue

        out.append({"role": "user", "content": _as_text(item)})

    return out


def _chat_completion_to_responses(
    chat_resp: dict[str, Any],
    *,
    model: str,
) -> dict[str, Any]:
    """Reshape an OpenAI ``chat.completion`` into a Responses API ``response`` object.

    Only the first choice is used:

    - Its text, if non-empty, becomes one ``message`` output item with a single
      ``output_text`` part.
    - Each tool call becomes a ``function_call`` item whose ``call_id`` is the
      tool-call id. Missing ids get a per-index fallback.
    - Items carry ``status: "completed"`` and text parts an empty
      ``annotations`` list, matching the Responses output-item shape.
    - Chat ``usage`` with integer counts is mapped to
      ``input_tokens``/``output_tokens``/``total_tokens``.
    """
    created = int(time.time())
    choice = (chat_resp.get("choices") or [{"message": {"role": "assistant", "content": ""}}])[0]
    msg = choice.get("message") or {"role": "assistant", "content": ""}
    text = _as_text(msg.get("content"))
    tool_calls = msg.get("tool_calls") or []

    output: list[dict[str, Any]] = []
    if text:
        output.append(
            {
                "id": f"msg_{created}",
                "type": "message",
                "role": "assistant",
                "status": "completed",
                "content": [{"type": "output_text", "text": text, "annotations": []}],
            }
        )

    for index, tc in enumerate(tool_calls):
        tc = tc or {}
        fn = tc.get("function") or {}
        # The index suffix keeps fallback ids distinct when several calls lack one.
        call_id = tc.get("id") or f"call_{created}_{index}"
        output.append(
            {
                "id": f"fc_{call_id}",
                "type": "function_call",
                "status": "completed",
                "call_id": call_id,
                "name": fn.get("name") or "unknown",
                "arguments": fn.get("arguments") or "{}",
            }
        )

    response: dict[str, Any] = {
        "id": f"resp_router_{created}",
        "object": "response",
        "created_at": created,
        "model": model,
        "status": "completed",
        "output": output,
    }
    usage = chat_resp.get("usage")
    if isinstance(usage, dict):
        input_tokens = usage.get("prompt_tokens")
        output_tokens = usage.get("completion_tokens")
        if isinstance(input_tokens, int) and isinstance(output_tokens, int):
            total = usage.get("total_tokens")
            response["usage"] = {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": total if isinstance(total, int) else input_tokens + output_tokens,
            }
    return response


def _handle_responses(payload: dict[str, Any]) -> tuple[int, dict[str, Any]]:
    """Serve ``POST /v1/responses`` by translating it to a chat completion and back.

    - ``model`` falls back to ``OPENAI_MODEL``.
    - Truthy ``instructions`` are prepended as a system message.
    - ``tools`` and ``temperature`` are forwarded unchanged.
    - ``max_output_tokens`` (or else ``max_tokens``) is forwarded when it is not
      ``None``.

    A non-OK chat result is returned unchanged. Success is reshaped into a
    Responses API object.
    """
    model = payload.get("model") or os.getenv("OPENAI_MODEL") or ""

    chat_messages = _responses_input_to_chat_messages(payload)
    instructions = payload.get("instructions")
    if instructions:
        chat_messages = [{"role": "system", "content": _as_text(instructions)}, *chat_messages]

    chat_payload: dict[str, Any] = {
        "model": model,
        "messages": chat_messages,
        "tools": payload.get("tools"),
        "temperature": payload.get("temperature"),
    }
    token_limit = payload.get("max_output_tokens") or payload.get("max_tokens")
    if token_limit is not None:
        chat_payload["max_output_tokens"] = token_limit

    status, chat_resp = _handle_chat_completions(chat_payload)
    if status == HTTPStatus.OK:
        return HTTPStatus.OK, _chat_completion_to_responses(chat_resp, model=model)
    return status, chat_resp


class Handler(BaseHTTPRequestHandler):
    """HTTP handler exposing a small OpenAI-compatible API.

    Routes:
        GET  /health, /healthz       -> ``{"ok": true}``
        GET  /v1/models              -> a fixed model list
        POST /v1/chat/completions    -> ``_handle_chat_completions``
        POST /v1/responses           -> ``_handle_responses``

    The upstream call is never streamed. When a request sets ``"stream"``, the
    completed result is replayed as server-sent events.
    """

    server_version = "openai-router/0.1"

    def _send(self, status: int, payload: dict[str, Any], *, content_type: str = "application/json") -> None:
        """Send *payload* as a complete JSON response with an explicit Content-Length."""
        body = _json_dumps(payload)
        self.send_response(status)
        self.send_header("content-type", content_type + "; charset=utf-8")
        self.send_header("content-length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_sse(self, events: list[dict[str, Any]]) -> None:
        """Send unnamed ``data:`` SSE events, then the ``data: [DONE]`` terminator (Chat Completions style)."""
        self.send_response(HTTPStatus.OK)
        self.send_header("content-type", "text/event-stream; charset=utf-8")
        self.send_header("cache-control", "no-cache")
        self.end_headers()
        for ev in events:
            data = json.dumps(ev, ensure_ascii=False)
            self.wfile.write(f"data: {data}\n\n".encode("utf-8"))
        self.wfile.write(b"data: [DONE]\n\n")

    def _send_sse_named(self, events: list[tuple[str, dict[str, Any]]]) -> None:
        """Send ``event:``/``data:`` SSE pairs (Responses API style), with no ``[DONE]`` terminator."""
        self.send_response(HTTPStatus.OK)
        self.send_header("content-type", "text/event-stream; charset=utf-8")
        self.send_header("cache-control", "no-cache")
        self.end_headers()
        for name, ev in events:
            data = json.dumps(ev, ensure_ascii=False)
            self.wfile.write(f"event: {name}\n".encode("utf-8"))
            self.wfile.write(f"data: {data}\n\n".encode("utf-8"))

    @staticmethod
    def _chat_stream_chunk(resp: dict[str, Any]) -> dict[str, Any]:
        """Wrap a finished ``chat.completion`` as a single ``chat.completion.chunk``.

        The whole assistant message becomes the delta. Each tool call gets the
        positional ``index`` that streamed tool-call deltas carry, because
        streaming clients merge tool-call fragments by that index. *resp* is
        not modified.
        """
        choice = resp["choices"][0]
        delta = dict(choice["message"])
        if delta.get("tool_calls"):
            delta["tool_calls"] = [{**call, "index": i} for i, call in enumerate(delta["tool_calls"])]
        return {
            "id": resp.get("id"),
            "object": "chat.completion.chunk",
            "created": resp.get("created"),
            "model": resp.get("model"),
            "choices": [{"index": 0, "delta": delta, "finish_reason": choice.get("finish_reason")}],
        }

    @staticmethod
    def _responses_stream_events(resp: dict[str, Any]) -> list[tuple[str, dict[str, Any]]]:
        """Replay a finished Responses API object as a named SSE event sequence.

        The sequence is:

        1. ``response.created``.
        2. For each output item, ``response.output_item.added`` and then
           ``response.output_item.done``. Message items also get
           ``content_part`` and ``output_text`` delta/done events in between,
           keyed by ``item_id``/``output_index``/``content_index``.
        3. ``response.completed`` carrying *resp* unchanged.

        Each event's ``type`` equals its SSE name, and ``sequence_number``
        counts up from 0.
        """
        events: list[tuple[str, dict[str, Any]]] = []

        def emit(name: str, fields: dict[str, Any]) -> None:
            events.append((name, {"type": name, "sequence_number": len(events), **fields}))

        emit(
            "response.created",
            {
                "response": {
                    "id": resp.get("id"),
                    "object": "response",
                    "created_at": resp.get("created_at") or int(time.time()),
                    "model": resp.get("model"),
                    "status": "in_progress",
                    "output": [],
                }
            },
        )
        for output_index, item in enumerate(resp.get("output") or []):
            if item.get("type") != "message":
                emit("response.output_item.added", {"output_index": output_index, "item": {**item, "status": "in_progress"}})
                emit("response.output_item.done", {"output_index": output_index, "item": item})
                continue

            emit(
                "response.output_item.added",
                {"output_index": output_index, "item": {**item, "status": "in_progress", "content": []}},
            )
            for content_index, part in enumerate(item.get("content") or []):
                where = {"item_id": item.get("id"), "output_index": output_index, "content_index": content_index}
                emit("response.content_part.added", {**where, "part": {**part, "text": ""}})
                if part.get("type") == "output_text":
                    text = part.get("text") or ""
                    if text:
                        emit("response.output_text.delta", {**where, "delta": text})
                    emit("response.output_text.done", {**where, "text": text})
                emit("response.content_part.done", {**where, "part": part})
            emit("response.output_item.done", {"output_index": output_index, "item": item})

        emit("response.completed", {"response": resp})
        return events

    def do_GET(self) -> None:  # noqa: N802
        """Serve the health probe and the static model list; anything else is 404."""
        if self.path in {"/health", "/healthz"}:
            self._send(HTTPStatus.OK, {"ok": True})
            return
        if self.path.rstrip("/") == "/v1/models":
            self._send(
                HTTPStatus.OK,
                {
                    "object": "list",
                    "data": [
                        {"id": "claude-3-5-sonnet-20241022", "object": "model"},
                        {"id": "gemini-1.5-pro", "object": "model"},
                    ],
                },
            )
            return
        self._send(HTTPStatus.NOT_FOUND, {"error": {"message": f"Not found: {self.path}"}})

    def do_POST(self) -> None:  # noqa: N802
        """Dispatch POST requests to the chat-completions and responses handlers."""
        try:
            try:
                payload = _read_json_body(self)
            except ValueError as e:
                # Bad JSON, bad encoding or a bad Content-Length is a client error, not a 500.
                self._send(HTTPStatus.BAD_REQUEST, {"error": {"message": f"Invalid request body: {e}"}})
                return
            if not isinstance(payload, dict):
                self._send(HTTPStatus.BAD_REQUEST, {"error": {"message": "Request body must be a JSON object"}})
                return

            route = self.path.rstrip("/")
            if route == "/v1/chat/completions":
                status, resp = _handle_chat_completions(payload)
                if payload.get("stream") and status == HTTPStatus.OK:
                    self._send_sse([self._chat_stream_chunk(resp)])
                else:
                    self._send(status, resp)
                return
            if route == "/v1/responses":
                status, resp = _handle_responses(payload)
                if payload.get("stream") and status == HTTPStatus.OK:
                    self._send_sse_named(self._responses_stream_events(resp))
                else:
                    self._send(status, resp)
                return
            self._send(HTTPStatus.NOT_FOUND, {"error": {"message": f"Not found: {self.path}"}})
        except Exception as e:  # noqa: BLE001
            tb = traceback.format_exc()
            self.log_error("unhandled error: %s", tb)
            self._send(
                HTTPStatus.INTERNAL_SERVER_ERROR,
                {"error": {"message": str(e), "traceback": tb}},
            )


def main(argv: list[str] | None = None) -> None:
    """Parse CLI options and serve the router until interrupted.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` uses the
            process arguments.

    ``--host`` and ``--port`` default to ``ROUTER_HOST`` and ``ROUTER_PORT``,
    falling back to ``127.0.0.1`` and ``8090``. Ctrl-C stops the server without
    a traceback, and the listening socket is always closed.
    """
    ap = argparse.ArgumentParser(description="Minimal OpenAI-compatible router for Claude/Gemini.")
    ap.add_argument("--host", default=os.getenv("ROUTER_HOST", "127.0.0.1"))
    # argparse runs string defaults through ``type``, so a malformed ROUTER_PORT
    # produces a normal usage error instead of a ValueError traceback.
    ap.add_argument("--port", default=os.getenv("ROUTER_PORT", "8090"), type=int)
    args = ap.parse_args(argv)

    httpd = ThreadingHTTPServer((args.host, args.port), Handler)
    print(f"[router] listening on http://{args.host}:{args.port}", flush=True)
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("[router] shutting down", flush=True)
    finally:
        httpd.server_close()


if __name__ == "__main__":
    main()
