#!/usr/bin/env python3
from __future__ import annotations

import argparse
import socket
import sys
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path

import tensorboard
from tensorboard import program


# Directory containing this script; the default log directory is resolved against it.
SCRIPT_DIR = Path(__file__).resolve().parent
# Default value for the --logdir option.
DEFAULT_LOGDIR = SCRIPT_DIR.joinpath("output", "meal_projector")

# Projector JavaScript bundle, relative to the installed ``tensorboard`` package directory.
PROJECTOR_JS_RELATIVE_PATH = Path("plugins/projector/tf_projector_plugin/projector_binary.js")

# Minified projector code that selects the initially shown projection tab.
# This script treats the PCA variant as TensorBoard's native default and the
# UMAP variant as a previously applied patch to be reverted.
PCA_DEFAULT_SNIPPET = 'this.$$("#pca-sampling").style.display=s?null:"none",this.showTab("pca")'
UMAP_DEFAULT_SNIPPET = 'this.$$("#pca-sampling").style.display=s?null:"none",this.showTab("umap")'


class _ProjectorRedirectHandler(BaseHTTPRequestHandler):
    """Redirects every request to the direct TensorBoard projector URL.

    Any path is answered with ``302 Found`` pointing at ``target``. GET and HEAD
    are supported; other methods fall through to the base class (501).
    """

    # Absolute URL that every request is redirected to; must be set before serving.
    target: str = ""

    def do_GET(self) -> None:
        """Answer the request with a 302 redirect to ``target`` and an empty body."""
        self.send_response(302)
        self.send_header("Location", self.target)
        # Declare the empty body explicitly so clients never wait for content.
        self.send_header("Content-Length", "0")
        self.end_headers()

    # HEAD (e.g. ``curl -I``) gets the same headers instead of a 501 error.
    do_HEAD = do_GET

    # Suppress access log noise in the terminal.
    def log_message(self, *args: object) -> None:  # type: ignore[override]
        pass


def _start_redirect_server(host: str, port: int, target: str) -> None:
    """Serve 302 redirects to *target* on ``host:port``; blocks forever.

    Intended to be run in a daemon thread. Binding errors (e.g. port already in
    use) propagate as ``OSError`` from this call.
    """
    # Bind the target to a per-server subclass instead of mutating the shared
    # class attribute, so one server can never retarget another.
    handler_cls = type(
        "_BoundProjectorRedirectHandler",
        (_ProjectorRedirectHandler,),
        {"target": target},
    )
    # ``with`` guarantees the listening socket is closed if serving ever stops.
    with HTTPServer((host, port), handler_cls) as server:
        server.serve_forever()


def _port_in_use(host: str, port: int) -> bool:
    """Return True if something accepts TCP connections on ``host:port``.

    ``socket.create_connection`` tries every address *host* resolves to, so
    both IPv4 and IPv6 hosts are probed. Any connection failure (refused,
    timed out after 0.5 s, unresolvable host) is reported as "not in use".
    """
    try:
        with socket.create_connection((host, port), timeout=0.5):
            return True
    except OSError:
        return False


def _ensure_pca_default_projection() -> bool:
    """Make the installed projector bundle open on the PCA tab.

    Reads TensorBoard's compiled ``projector_binary.js``. If an earlier run
    patched it to open on UMAP (``UMAP_DEFAULT_SNIPPET``), the first occurrence
    is rewritten back to ``PCA_DEFAULT_SNIPPET`` in place. Note that this
    modifies a file inside the installed ``tensorboard`` package.

    Returns:
        True if the bundle now contains the PCA-default snippet (already present
        or just restored). False if the bundle is missing, unreadable, cannot be
        rewritten, or contains neither known snippet (e.g. a TensorBoard version
        with a different bundle layout), so the caller can warn.
    """
    tensorboard_package_dir = Path(tensorboard.__file__).resolve().parent
    projector_js_path = tensorboard_package_dir / PROJECTOR_JS_RELATIVE_PATH

    try:
        source = projector_js_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        # Missing or unreadable bundle: nothing we can verify or patch.
        return False

    # If UMAP patching was previously applied, revert it back to PCA.
    if UMAP_DEFAULT_SNIPPET in source:
        updated = source.replace(UMAP_DEFAULT_SNIPPET, PCA_DEFAULT_SNIPPET, 1)
        try:
            projector_js_path.write_text(updated, encoding="utf-8")
        except OSError:
            # e.g. a read-only site-packages; report instead of crashing the launcher.
            return False
        return True

    # Only claim success if the expected PCA snippet is actually present.
    return PCA_DEFAULT_SNIPPET in source


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the projector launcher.

    Options, in registration order: ``--logdir``, ``--host``, ``--port`` and
    ``--redirect-port`` (the last two parsed as ints).
    """
    cli = argparse.ArgumentParser(
        description=(
            "Launch TensorBoard so the meal embeddings can be explored in the Embedding Projector."
        ),
    )
    cli.add_argument(
        "--logdir",
        help="Directory containing projector_config.pbtxt and TSV files.",
        default=str(DEFAULT_LOGDIR),
    )
    cli.add_argument(
        "--host",
        help="Host interface for TensorBoard.",
        default="127.0.0.1",
    )

    # Both port options are integers; only their defaults and help differ.
    int_options = (
        ("--port", 6006, "Port for TensorBoard."),
        (
            "--redirect-port",
            None,
            "Port for the short redirect URL (default: TensorBoard port + 1).",
        ),
    )
    for flag, fallback, description in int_options:
        cli.add_argument(flag, type=int, default=fallback, help=description)

    return cli


def main() -> int:
    """Launch TensorBoard for the meal-embedding projector and block until Ctrl+C.

    Steps:
      1. Validate the CLI arguments: port ranges, distinct ports, logdir and
         projector config present, TensorBoard port free.
      2. Make sure the installed projector bundle opens on PCA (warning only).
      3. Start TensorBoard and, if its port is free, a redirect server whose
         short URL forwards to the direct projector page.

    Returns:
        0 once the user presses Ctrl+C. Invalid arguments terminate the process
        through ``parser.error`` (exit status 2) instead of returning.
    """
    import time

    parser = build_parser()
    args = parser.parse_args()

    redirect_port = args.redirect_port if args.redirect_port is not None else args.port + 1
    for label, value in (("--port", args.port), ("--redirect-port", redirect_port)):
        # Out-of-range ports would otherwise surface as OverflowError from the socket layer.
        if not 1 <= value <= 65535:
            parser.error(f"{label} must be between 1 and 65535 (got {value})")
    if redirect_port == args.port:
        parser.error("--redirect-port must differ from --port")

    logdir = Path(args.logdir).expanduser().resolve()
    if not logdir.exists():
        parser.error(f"Logdir does not exist: {logdir}")

    config_path = logdir / "projector_config.pbtxt"
    if not config_path.exists():
        parser.error(
            f"Missing projector config at {config_path}. Run extract_meal_embeddings.py first."
        )

    if _port_in_use(args.host, args.port):
        parser.error(f"Port {args.port} on {args.host} is already in use")

    # Only touch the installed TensorBoard bundle once the arguments are known to be valid.
    if not _ensure_pca_default_projection():
        print(
            "Warning: could not ensure PCA default; continuing with TensorBoard defaults.",
            file=sys.stderr,
            flush=True,
        )

    # Named ``board`` so the imported ``tensorboard`` module is not shadowed.
    board = program.TensorBoard()
    board.configure(
        argv=[
            None,
            "--logdir",
            str(logdir),
            "--host",
            args.host,
            "--port",
            str(args.port),
        ]
    )
    url = board.launch()

    projector_base = f"http://{args.host}:{args.port}/data/plugin/projector"
    direct_projector_url = (
        f"{projector_base}/projector_binary.html?config={projector_base}/info?run=."
    )

    print(f"TensorBoard started at {url}", flush=True)

    # The short URL is only a convenience: if its port is taken, point the user
    # at the direct projector URL instead of advertising a dead link.
    if _port_in_use(args.host, redirect_port):
        print(
            f"Warning: redirect port {redirect_port} on {args.host} is already in use; "
            "skipping the short URL.",
            file=sys.stderr,
            flush=True,
        )
        print(f"Open the projector at: {direct_projector_url}", flush=True)
    else:
        short_url = f"http://{args.host}:{redirect_port}"
        redirect_thread = threading.Thread(
            target=_start_redirect_server,
            args=(args.host, redirect_port, direct_projector_url),
            daemon=True,
        )
        redirect_thread.start()
        print(
            f"Open the projector at: {short_url}  (redirects to the direct projector URL)",
            flush=True,
        )
    print("Press Ctrl+C to stop the server.", flush=True)

    # TensorBoard and the redirect server run in background threads; keep the
    # main thread alive until the user interrupts.
    try:
        while True:
            time.sleep(3600)
    except KeyboardInterrupt:
        print("\nStopping TensorBoard.", flush=True)
        return 0


if __name__ == "__main__":
    # Run the launcher and use main()'s return value as the process exit status.
    sys.exit(main())