#!/usr/bin/env python3
"""sem_a2a.py — A2A agent server exposing SEM tools as agent skills.

Each tool from the SEM platform (register_profile, update_profile, read_signifiers,
all_signifiers) is exposed as an A2A AgentSkill. The agent accepts either a JSON
object (a "skill" key plus the skill's parameters) or a plain-text command of the
form "<skill> key=value ..." naming the skill to invoke and its parameters.

Environment variables:
  SEM_BASE_URL                  SEM Flask app base URL (default: http://localhost:5000)
  SEM_HTTP_TIMEOUT_SECONDS      HTTP timeout for SEM calls (default: 30)
  SEM_HTTP_RETRY_ATTEMPTS       Retry attempts for signifier fetching (default: 2)
  SEM_HTTP_RETRY_BACKOFF_SECONDS Backoff between retries (default: 1.5)
  SEM_A2A_HOST                  Bind host for this server (default: 0.0.0.0)
  SEM_A2A_PORT                  Bind port for this server (default: 9998)
  SEM_A2A_PUBLIC_URL            Public URL written into agent card (default: http://127.0.0.1:9998/)
"""

import asyncio
import json
import os
import re
import socket
import sys
import threading
import time
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError
from urllib.parse import quote, urlencode, urlparse
from urllib.request import Request, urlopen

ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import uvicorn
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.apps import A2AStarletteApplication
from a2a.server.events import EventQueue
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from a2a.utils import new_agent_text_message

from rdflib import BNode, Graph, Literal, URIRef

from signifier import HMAS, HCTL, HTTP, JS, TD

# ─── Configuration ─────────────────────────────────────────────────────────────

DEFAULT_SEM_BASE_URL = "http://localhost:5000"
# Upstream SEM service. The base URL is kept without a trailing slash so that
# request paths can be appended as f"{SEM_BASE_URL}/...".
SEM_BASE_URL = os.environ.get("SEM_BASE_URL", DEFAULT_SEM_BASE_URL).rstrip("/")
SEM_HTTP_TIMEOUT_SECONDS = float(os.environ.get("SEM_HTTP_TIMEOUT_SECONDS", "30"))
# Clamped so that at least one attempt is always made.
SEM_HTTP_RETRY_ATTEMPTS = max(1, int(os.environ.get("SEM_HTTP_RETRY_ATTEMPTS", "2")))
SEM_HTTP_RETRY_BACKOFF_SECONDS = float(os.environ.get("SEM_HTTP_RETRY_BACKOFF_SECONDS", "1.5"))

# Bind address of this A2A server.
HOST = os.environ.get("SEM_A2A_HOST", "0.0.0.0")
PORT = int(os.environ.get("SEM_A2A_PORT", "9998"))
# URL written into the agent card, normalised to end with exactly one slash.
PUBLIC_URL = os.environ.get("SEM_A2A_PUBLIC_URL", f"http://127.0.0.1:{PORT}/").rstrip("/") + "/"

# ─── Tool registry ─────────────────────────────────────────────────────────────

# Registry of invocable tools keyed by tool/skill name. Each entry is the dict
# written by _register_tool: "description", "inputs", "handler" and "native".
_tool_registry: Dict[str, Dict[str, Any]] = {}
# Bookkeeping for dynamically registered signifier tools. In the code visible
# here it is only ever cleared (by _reset_signifier_tools) — presumably it is
# populated elsewhere; TODO confirm what the keys and values represent.
_registered_signifiers: Dict[str, str] = {}
# Taken by _invoke_skill and _build_agent_card while reading _tool_registry,
# and expected to be held by callers of _reset_signifier_tools.
_lock = threading.Lock()


def _register_tool(
    name: str,
    description: str,
    inputs: Dict[str, Any],
    handler: Callable[[Dict[str, Any]], Dict[str, Any]],
    native: bool = False,
) -> None:
    """Store a tool in the module-level registry under *name*.

    An existing entry with the same name is overwritten. The stored entry is a
    dict with the keys "description", "inputs", "handler" and "native".
    This function does not acquire ``_lock`` itself.
    """
    entry: Dict[str, Any] = dict(
        description=description,
        inputs=inputs,
        handler=handler,
        native=native,
    )
    _tool_registry[name] = entry


def _reset_signifier_tools() -> None:
    """Drop every non-native registry entry and clear the signifier bookkeeping.

    Must be called with ``_lock`` held.
    """
    # Collect the names first: the registry cannot be mutated while iterating it.
    dynamic_names = []
    for tool_name, entry in _tool_registry.items():
        if not entry["native"]:
            dynamic_names.append(tool_name)
    for tool_name in dynamic_names:
        _tool_registry.pop(tool_name)
    _registered_signifiers.clear()


# ─── Shared utilities ──────────────────────────────────────────────────────────


def _slugify(value: str) -> str:
    slug = re.sub(r"[^a-zA-Z0-9_]+", "_", value).strip("_")
    return slug or "signifier"


def _collect_input_fields(
    schema: Dict[str, Any], path: Tuple[str, ...] = (), required: bool = True
) -> List[Any]:
    if not isinstance(schema, dict):
        return []
    if "const" in schema:
        return []

    json_type = schema.get("type")
    if json_type == "object":
        props = schema.get("properties", {}) or {}
        reqs = set(schema.get("required", []))
        leaves: List[Any] = []
        for name, prop_schema in props.items():
            leaves.extend(
                _collect_input_fields(
                    prop_schema if isinstance(prop_schema, dict) else {},
                    path + (name,),
                    name in reqs,
                )
            )
        return leaves

    if json_type == "array":
        items_schema = schema.get("items")
        min_items = schema.get("minItems")
        max_items = schema.get("maxItems")
        if isinstance(min_items, str) and min_items.isdigit():
            min_items = int(min_items)
        if isinstance(max_items, str) and max_items.isdigit():
            max_items = int(max_items)
        if isinstance(items_schema, dict) and min_items == max_items == 1:
            return _collect_input_fields(items_schema, path + ("0",), required)

    return [(path, schema, required)]


def _param_name(path: Tuple[str, ...]) -> str:
    return "__".join(path) if path else "arguments"


def _assign_nested(root: Any, path: Tuple[str, ...], value: Any) -> Any:
    if not path:
        return value
    head, *rest = path
    if head.isdigit():
        idx = int(head)
        arr = root if isinstance(root, list) else []
        while len(arr) <= idx:
            arr.append({})
        arr[idx] = _assign_nested(arr[idx], tuple(rest), value)
        return arr
    obj = root if isinstance(root, dict) else {}
    obj[head] = _assign_nested(obj.get(head, {}), tuple(rest), value)
    return obj


def _build_payload_from_arguments(field_specs: List[Any], arguments: Dict[str, Any]) -> Any:
    """Assemble a nested payload from flat, name-keyed *arguments*.

    Each spec in *field_specs* is a ``(path, leaf_schema, required)`` triple.
    For every spec whose flat parameter name is present in *arguments*, the
    value is coerced against the leaf schema and written at ``path``. Specs
    whose name is absent are skipped; the required flag is not consulted here.
    """
    result: Any = {}
    for field_path, field_schema, _required in field_specs:
        key = _param_name(field_path)
        if key in arguments:
            coerced = _coerce_value_for_schema(arguments[key], field_schema, key)
            result = _assign_nested(result, field_path, coerced)
    return result


def _is_schema_less_json_field(schema: Dict[str, Any]) -> bool:
    if not isinstance(schema, dict):
        return True
    if "const" in schema:
        return False
    return not any(
        k in schema for k in {"type", "properties", "items", "enum", "oneOf", "anyOf", "allOf", "$ref"}
    )


def _coerce_any_json_value(value: Any, param_name: str) -> Any:
    if isinstance(value, str):
        try:
            return json.loads(value.strip())
        except json.JSONDecodeError as exc:
            raise ValueError(f"Invalid JSON for parameter '{param_name}': {exc}") from exc
    if isinstance(value, (dict, list, int, float, bool)) or value is None:
        return value
    raise ValueError(
        f"Invalid JSON for parameter '{param_name}': unsupported type {type(value).__name__}"
    )


def _coerce_value_for_schema(value: Any, schema: Dict[str, Any], param_name: str) -> Any:
    """Coerce *value* to the container type declared by *schema*.

    Schema-less fields are parsed as arbitrary JSON. For a declared type of
    ``"array"`` or ``"object"`` a string value is parsed as JSON and the result
    must be a list or dict respectively. Values for any other declared type
    are returned untouched. When ``"type"`` is a list, its first entry other
    than ``"null"`` is used (or its first entry if all are ``"null"``).

    Raises:
        ValueError: If JSON parsing fails or the parsed value has the wrong
            container type.
    """
    if _is_schema_less_json_field(schema):
        return _coerce_any_json_value(value, param_name)

    declared = schema.get("type") if isinstance(schema, dict) else None
    if isinstance(declared, list):
        non_null = [candidate for candidate in declared if candidate != "null"]
        if non_null:
            declared = non_null[0]
        elif declared:
            declared = declared[0]
        else:
            declared = None

    # Only container types are coerced; scalars pass through as given.
    if declared not in ("array", "object"):
        return value
    wants_array = declared == "array"

    if isinstance(value, str):
        try:
            value = json.loads(value.strip())
        except json.JSONDecodeError as exc:
            if wants_array:
                raise ValueError(f"Invalid JSON array for parameter '{param_name}': {exc}") from exc
            raise ValueError(f"Invalid JSON object for parameter '{param_name}': {exc}") from exc

    if wants_array:
        if not isinstance(value, list):
            raise ValueError(f"Invalid value for parameter '{param_name}': expected a JSON array")
    elif not isinstance(value, dict):
        raise ValueError(f"Invalid value for parameter '{param_name}': expected a JSON object")
    return value


def _extract_const_structure(schema: Dict[str, Any]) -> Any:
    if not isinstance(schema, dict):
        return None
    if "const" in schema:
        return schema["const"]
    if schema.get("type") == "object":
        collected: Dict[str, Any] = {}
        for name, prop_schema in (schema.get("properties") or {}).items():
            v = _extract_const_structure(prop_schema if isinstance(prop_schema, dict) else {})
            if v is not None:
                collected[name] = v
        return collected or None
    return None


def _deep_merge(base: Any, override: Any) -> Any:
    if isinstance(base, dict) and isinstance(override, dict):
        merged = dict(base)
        for k, v in override.items():
            merged[k] = _deep_merge(merged.get(k), v)
        return merged
    return override if override is not None else base


def _merge_const_defaults(schema: Dict[str, Any], data: Any) -> Any:
    """Overlay *data* on the ``const`` values declared in *schema*.

    The schema's consts serve as the base and *data* as the override. A falsy
    value on either side is replaced by an empty dict before merging.
    """
    defaults = None
    if isinstance(schema, dict):
        defaults = _extract_const_structure(schema)
    base = defaults if defaults else {}
    override = data if data else {}
    return _deep_merge(base, override)


def _perform_http_request(
    target: str,
    headers: Dict[str, str],
    content_type: str,
    payload: Any,
    method: str = "POST",
) -> Dict[str, Any]:
    """Send *payload* to *target* and report the outcome as a plain dict.

    Args:
        target: URL to call.
        headers: Extra request headers; they override the generated
            ``Content-Type`` if they contain one. ``None`` is treated as empty.
        content_type: Value for the ``Content-Type`` header. For ``text/*``
            types the payload is sent as-is (bytes) or as its UTF-8 encoded
            ``str()``; for every other type it is JSON-encoded.
        payload: Request body.
        method: HTTP method, ``"POST"`` by default.

    Returns:
        On success: ``{"status", "headers", "body"}``.
        On an HTTP error response (4xx/5xx): the same three keys plus
        ``"error"``, so callers can still see the status and server message.
        On any other failure during the request: ``{"error": <message>}``.
    """
    # Function-scope import: only URLError is imported at module level.
    from urllib.error import HTTPError

    request_headers = {"Content-Type": content_type}
    request_headers.update(headers or {})
    if content_type.lower().startswith("text/"):
        body = payload if isinstance(payload, bytes) else str(payload).encode("utf-8")
    else:
        body = json.dumps(payload).encode("utf-8")

    req = Request(target, data=body, headers=request_headers, method=method)
    try:
        with urlopen(req, timeout=SEM_HTTP_TIMEOUT_SECONDS) as resp:
            return {
                "status": resp.status,
                "headers": dict(resp.headers),
                # Never let an odd response encoding turn a success into an error.
                "body": resp.read().decode("utf-8", errors="replace"),
            }
    except HTTPError as exc:
        # urlopen raises for 4xx/5xx. The exception doubles as the response
        # object, so keep its status, headers and body instead of dropping them.
        try:
            error_body = exc.read().decode("utf-8", errors="replace")
        except Exception:
            error_body = ""
        finally:
            exc.close()
        return {
            "error": str(exc),
            "status": exc.code,
            "headers": dict(exc.headers or {}),
            "body": error_body,
        }
    except Exception as exc:
        # Connection failures, timeouts, invalid URLs, ...
        return {"error": str(exc)}


# ─── Signifier fetching ────────────────────────────────────────────────────────


def _fetch_signifiers(profile_url: str) -> Tuple[Graph, List[URIRef]]:
    """Fetch the signifiers SEM returns for *profile_url*.

    The filtered endpoint ``/signifiers?profile=...`` is tried up to
    ``SEM_HTTP_RETRY_ATTEMPTS`` times; only timeouts are retried, with a
    backoff of ``SEM_HTTP_RETRY_BACKOFF_SECONDS * attempt``. If the last
    failure was a timeout, the unfiltered ``/signifiers`` endpoint is tried
    once as a fallback.

    Returns:
        The parsed JSON-LD graph and the distinct ``URIRef`` subjects found in
        it, in first-seen order.

    Raises:
        RuntimeError: If the filtered request timed out and the fallback failed.
        Exception: The original error for any non-timeout failure.
    """

    def _fetch_by_url(url: str) -> Tuple[Graph, List[URIRef]]:
        req = Request(url, headers={"Accept": "application/ld+json"})
        with urlopen(req, timeout=SEM_HTTP_TIMEOUT_SECONDS) as resp:
            body = resp.read().decode("utf-8")
        graph = Graph()
        graph.parse(data=body, format="json-ld", publicID=url)
        # Graph.subjects() yields a subject once per matching triple, so the
        # same URI would otherwise be listed (and counted) repeatedly.
        subjects = list(
            dict.fromkeys(s for s in graph.subjects() if isinstance(s, URIRef))
        )
        return graph, subjects

    def _is_timeout(exc: Exception) -> bool:
        if isinstance(exc, (TimeoutError, socket.timeout)):
            return True
        if isinstance(exc, URLError) and isinstance(exc.reason, (TimeoutError, socket.timeout)):
            return True
        # Last resort: some layers only expose the timeout through the message.
        return "timed out" in str(exc).lower()

    query = urlencode({"profile": profile_url})
    filtered_url = f"{SEM_BASE_URL}/signifiers?{query}"

    last_error: Optional[Exception] = None
    for attempt in range(1, SEM_HTTP_RETRY_ATTEMPTS + 1):
        try:
            return _fetch_by_url(filtered_url)
        except Exception as exc:
            last_error = exc
            if not _is_timeout(exc) or attempt >= SEM_HTTP_RETRY_ATTEMPTS:
                break
            # Clamp: a negative configured backoff would make sleep() raise.
            time.sleep(max(0.0, SEM_HTTP_RETRY_BACKOFF_SECONDS * attempt))

    if last_error is not None and _is_timeout(last_error):
        try:
            return _fetch_by_url(f"{SEM_BASE_URL}/signifiers")
        except Exception as fallback_exc:
            raise RuntimeError(
                f"Timed out reading filtered signifiers and fallback failed "
                f"(filtered: {last_error}; fallback: {fallback_exc})"
            ) from fallback_exc

    if last_error is not None:
        raise last_error
    raise RuntimeError("Unexpected error while reading signifiers")


def _fetch_all_signifiers() -> Tuple[Graph, List[URIRef]]:
    """Fetch every signifier listed by SEM's ``/signifiers/list`` endpoint.

    The listing is expected to be a JSON object whose ``"signifiers"`` value is
    a list of URLs. Each distinct, non-blank string URL is fetched as JSON-LD
    and parsed into one shared graph; other entries are skipped.

    Returns:
        The combined graph and the fetched signifier URLs as ``URIRef`` values,
        in listing order.

    Raises:
        RuntimeError: If the listing is not a JSON object or ``"signifiers"``
            is not a list.
        Exception: Any network or parse error; a single failing signifier URL
            aborts the whole call.
    """
    req = Request(f"{SEM_BASE_URL}/signifiers/list", headers={"Accept": "application/json"})
    with urlopen(req, timeout=SEM_HTTP_TIMEOUT_SECONDS) as resp:
        payload = json.loads(resp.read().decode("utf-8"))
    # Without this check a list/scalar payload fails with an opaque AttributeError.
    if not isinstance(payload, dict):
        raise RuntimeError("Invalid /signifiers/list response: expected a JSON object")
    signifier_urls = payload.get("signifiers", [])
    if not isinstance(signifier_urls, list):
        raise RuntimeError("Invalid /signifiers/list response: 'signifiers' must be a list")

    graph = Graph()
    signifiers: List[URIRef] = []
    seen = set()
    for url in signifier_urls:
        if not isinstance(url, str):
            continue
        url = url.strip()
        # Skip blanks, and repeats so a URL is neither refetched nor counted twice.
        if not url or url in seen:
            continue
        seen.add(url)
        req = Request(url, headers={"Accept": "application/ld+json"})
        with urlopen(req, timeout=SEM_HTTP_TIMEOUT_SECONDS) as resp:
            graph.parse(data=resp.read().decode("utf-8"), format="json-ld", publicID=url)
        signifiers.append(URIRef(url))

    return graph, signifiers


# ─── Native tool handlers ──────────────────────────────────────────────────────


def _handler_register_profile(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Register a profile in SEM with an empty context comment.

    Builds a two-triple graph (profile ``hasContext`` a blank node, whose
    ``rdfs:comment`` is the empty string) and PUTs it as Turtle to
    ``{SEM_BASE_URL}/profile/<profile_id>``.

    Args:
        arguments: Tool arguments; ``"profile_id"`` is read (default ``""``)
            and percent-encoded with ``/`` left intact.

    Returns:
        The result dict of ``_perform_http_request``.
    """
    # RDFS is not among the module-level rdflib imports; without this import
    # the g.add(...) below raises NameError on every call.
    from rdflib.namespace import RDFS

    profile_id = arguments.get("profile_id", "")
    encoded = quote(profile_id, safe="/")
    profile_uri = URIRef(f"{SEM_BASE_URL}/profile/{encoded}")
    g = Graph()
    ctx = BNode()
    g.add((profile_uri, HMAS["hasContext"], ctx))
    g.add((ctx, RDFS.comment, Literal("")))
    return _perform_http_request(
        f"{SEM_BASE_URL}/profile/{encoded}", {}, "text/turtle", g.serialize(format="turtle"), method="PUT"
    )


def _handler_update_profile(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Store a natural-language context for a profile.

    Reads ``"profile_id"`` and ``"nl_context"`` from *arguments* (both default
    to ``""``) and PUTs ``{"context": nl_context}`` as JSON to
    ``{SEM_BASE_URL}/profile/<profile_id>/nl_context``.

    Returns:
        The result dict of ``_perform_http_request``.
    """
    encoded_id = quote(arguments.get("profile_id", ""), safe="/")
    target = f"{SEM_BASE_URL}/profile/{encoded_id}/nl_context"
    body = {"context": arguments.get("nl_context", "")}
    return _perform_http_request(target, {}, "application/json", body, method="PUT")


def _handler_read_signifiers(arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Count the signifiers SEM returns for ``arguments["profile_url"]``.

    Returns ``{"status": "ok", "signifiers_count": n}`` on success, or
    ``{"error": ...}`` if fetching fails. The fetched graph is not returned.
    """
    target_profile = arguments.get("profile_url", "")
    try:
        _graph, found = _fetch_signifiers(target_profile)
    except Exception as exc:
        return {"error": f"Failed to read signifiers: {exc}"}
    else:
        return {"status": "ok", "signifiers_count": len(found)}


def _handler_all_signifiers(_arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Count every signifier exposed by SEM; the arguments are ignored.

    Returns ``{"status": "ok", "signifiers_count": n}`` on success, or
    ``{"error": ...}`` if fetching fails. The fetched graph is not returned.
    """
    try:
        _graph, found = _fetch_all_signifiers()
    except Exception as exc:
        return {"error": f"Failed to read all signifiers: {exc}"}
    else:
        return {"status": "ok", "signifiers_count": len(found)}


# ─── Skill invocation logic ────────────────────────────────────────────────────


def _parse_user_input(text: str) -> Tuple[Optional[str], Dict[str, Any]]:
    """Parse user input to extract skill name and parameters.

    Accepts formats like:
    - "register_profile executor"
    - "register_profile profile_id=executor"
    - {"skill": "register_profile", "profile_id": "executor"}
    - JSON string: '{"skill": "register_profile", "profile_id": "executor"}'
    """
    text = (text or "").strip()
    if not text:
        return None, {}

    try:
        payload = json.loads(text)
        if isinstance(payload, dict):
            skill = payload.pop("skill", None)
            return skill, payload
    except (json.JSONDecodeError, ValueError):
        pass

    parts = text.split(None, 1)
    skill_name = parts[0] if parts else None
    param_text = parts[1] if len(parts) > 1 else ""

    arguments: Dict[str, Any] = {}
    if param_text:
        for pair in param_text.split():
            if "=" in pair:
                key, val = pair.split("=", 1)
                arguments[key] = val
            else:
                arguments[pair] = True

    return skill_name, arguments


async def _invoke_skill(skill_name: Optional[str], arguments: Dict[str, Any]) -> str:
    """Run the registered skill *skill_name* and return its result as text.

    Args:
        skill_name: Name of a registered tool; falsy means "none given".
        arguments: Argument dict passed unchanged to the tool's handler.

    Returns:
        The handler's result serialised as indented JSON, or a string starting
        with ``"Error"`` when no skill was given, the skill is unknown, the
        handler raised, or its result could not be serialised.
    """
    if not skill_name:
        return "Error: No skill specified. Provide skill name as the first argument."

    # Read the entry and, if it is missing, the list of known names under a
    # single lock acquisition so both reflect the same registry state.
    with _lock:
        entry = _tool_registry.get(skill_name)
        available = "" if entry else ", ".join(sorted(_tool_registry))

    if not entry:
        return f"Error: Skill '{skill_name}' not found. Available skills: {available}"

    try:
        # Handlers do blocking HTTP calls (and sleep between retries); run them
        # in a worker thread so the event loop keeps serving other requests.
        result = await asyncio.to_thread(entry["handler"], arguments)
        return json.dumps(result, indent=2)
    except Exception as exc:
        return f"Error invoking skill '{skill_name}': {exc}"


# ─── A2A Agent Executor ────────────────────────────────────────────────────────


class SemA2AExecutor(AgentExecutor):
    """Agent executor that maps each incoming user message onto one SEM skill call."""

    async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Parse the user's text, invoke the named skill and enqueue the reply.

        The reply is always a single agent text message carrying either the
        skill's JSON result or an error description.
        """
        raw_text = context.get_user_input()
        if not raw_text:
            raw_text = ""
        name, params = _parse_user_input(raw_text)
        reply = await _invoke_skill(name, params)
        await event_queue.enqueue_event(new_agent_text_message(reply))

    async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Cancellation is not implemented; always raises."""
        raise Exception("cancel not supported")


# ─── Bootstrap native tools and agent card ─────────────────────────────────────


def _bootstrap_native_tools() -> None:
    """Register the four built-in SEM tools as native registry entries.

    Registers, in order: ``register_profile``, ``update_profile``,
    ``read_signifiers`` and ``all_signifiers``, each with its input schema and
    handler, flagged ``native=True``.
    """

    def _string_field(description: str) -> Dict[str, Any]:
        # Schema for a plain string parameter.
        return {"type": "string", "description": description}

    native_tools = (
        (
            "register_profile",
            "Register a profile with an empty natural-language context.",
            {
                "type": "object",
                "properties": {
                    "profile_id": _string_field(
                        "Profile identifier to register (e.g., executor)."
                    ),
                },
                "required": ["profile_id"],
            },
            _handler_register_profile,
        ),
        (
            "update_profile",
            "Update the natural-language context for the given profile.",
            {
                "type": "object",
                "properties": {
                    "profile_id": _string_field(
                        "Profile identifier to update (e.g., executor)."
                    ),
                    "nl_context": _string_field(
                        "Natural-language context to store for the profile."
                    ),
                },
                "required": ["profile_id", "nl_context"],
            },
            _handler_update_profile,
        ),
        (
            "read_signifiers",
            "Read signifiers relevant to the given profile URL.",
            {
                "type": "object",
                "properties": {
                    "profile_url": _string_field(
                        "Full profile URL to query for signifiers."
                    ),
                },
                "required": ["profile_url"],
            },
            _handler_read_signifiers,
        ),
        (
            "all_signifiers",
            "Read all signifiers exposed by SEM.",
            {"type": "object", "properties": {}},
            _handler_all_signifiers,
        ),
    )

    for tool_name, summary, input_schema, handler in native_tools:
        _register_tool(tool_name, summary, input_schema, handler, native=True)


def _build_agent_card() -> AgentCard:
    """Create the agent card advertising every tool currently in the registry.

    Each registry entry becomes one ``AgentSkill`` whose id is the tool name
    and whose display name is the tool name with underscores replaced by
    spaces, title-cased. The registry is read while holding ``_lock``.
    """
    skills = []
    with _lock:
        for skill_id, entry in _tool_registry.items():
            display_name = skill_id.replace("_", " ").title()
            skills.append(
                AgentSkill(
                    id=skill_id,
                    name=display_name,
                    description=entry["description"],
                    tags=["sem", "tool"],
                )
            )

    capabilities = AgentCapabilities(streaming=False)
    return AgentCard(
        name="SEM A2A Agent",
        description="Signifier Exposure Mechanism (SEM) tools exposed as A2A agent skills.",
        url=PUBLIC_URL,
        version="1.0.0",
        default_input_modes=["text"],
        default_output_modes=["text"],
        capabilities=capabilities,
        skills=skills,
        supports_authenticated_extended_card=False,
    )


if __name__ == "__main__":
    # Tools must be registered before the card is built, because the card
    # lists whatever is in the registry at that moment.
    _bootstrap_native_tools()

    agent_card = _build_agent_card()
    request_handler = DefaultRequestHandler(
        agent_executor=SemA2AExecutor(),
        task_store=InMemoryTaskStore(),
    )
    asgi_app = A2AStarletteApplication(
        agent_card=agent_card,
        http_handler=request_handler,
    ).build()

    # Serve until interrupted.
    uvicorn.run(asgi_app, host=HOST, port=PORT)
