"""
Spatial encoder interface for TWNM.

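The actual implementation is proprietary. To plug in a custom encoder,
set the environment variable ``TWNM_SPATIAL_ENCODER_MODULE`` to the dotted
path of a module that exposes a ``SpatialEncoder`` class compatible with
the TWNM audio pipeline. If the variable is unset, a TorchScript export at
``assets/checkpoints/spatial_encoder/spatial_encoder.ts`` is used instead,
provided that file exists.

For internal/private development the reference implementation lives under
``private_impl/spatial_encoder_impl``. Add it to ``PYTHONPATH`` and also set
``TWNM_SPATIAL_ENCODER_MODULE`` to its module path. Adding it to
``PYTHONPATH`` alone does not select it.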
"""

from __future__ import annotations

import importlib
import os
from pathlib import Path
from typing import Optional, Type

import torch
import torch.nn as nn


ASSETS_ROOT = Path(__file__).resolve().parents[4] / "assets" / "checkpoints" / "spatial_encoder"
TORCHSCRIPT_PATH = ASSETS_ROOT / "spatial_encoder.ts"


class SpatialEncoderUnavailable(RuntimeError):
    """
    Signal that no usable spatial encoder could be located or loaded.

    Derives from ``RuntimeError``, so generic ``except RuntimeError``
    handlers also catch it.
    """

    pass


class TorchScriptSpatialEncoder(nn.Module):
    """
    Inference-only wrapper around a TorchScript export of the spatial encoder.

    The serialized module is loaded once at construction time. Every call is
    delegated to it unchanged.
    """

    def __init__(
        self,
        *args,
        ts_path: Optional[Path] = None,
        map_location: str = "cpu",
        **kwargs,
    ):
        """
        Load the TorchScript module from disk.

        Parameters
        ----------
        *args, **kwargs
            Accepted and ignored. Presumably kept so this class can be built
            with the same arguments as a custom ``SpatialEncoder``; TODO
            confirm against callers.
        ts_path : Path or str, optional
            Location of the TorchScript file. Falls back to
            ``TORCHSCRIPT_PATH`` when omitted or falsy.
        map_location : str
            Forwarded verbatim to ``torch.jit.load``.

        Raises
        ------
        SpatialEncoderUnavailable
            If the resolved path is not an existing regular file.
        """
        super().__init__()
        path = Path(ts_path) if ts_path else TORCHSCRIPT_PATH
        # ``is_file`` rather than ``exists``: a directory at this location
        # would otherwise pass the check and then fail inside
        # ``torch.jit.load`` with an unrelated error instead of the
        # documented exception.
        if not path.is_file():
            raise SpatialEncoderUnavailable(
                f"未找到 TorchScript 空间编码器：{path}。"
                "请提供自定义实现，或重新导出 TorchScript 文件。"
            )
        # The loaded ScriptModule is assigned as an attribute, which registers
        # it as a child module of this wrapper.
        self._module = torch.jit.load(str(path), map_location=map_location)

    def forward_as_encoder(self, audios: torch.Tensor) -> torch.Tensor:
        """
        Run the TorchScript module on ``audios`` and return its output.

        The expected input layout is defined by the exported module (not
        visible here). It is presumably a batch of audio signals; confirm
        against the export script.
        """
        return self._module(audios)

    def forward(self, audios: torch.Tensor) -> torch.Tensor:
        """
        Standard ``nn.Module`` entry point, identical to ``forward_as_encoder``.

        Without this method, calling the wrapper directly (``encoder(x)``)
        would raise ``NotImplementedError`` from ``nn.Module``.
        """
        return self.forward_as_encoder(audios)


def get_spatial_encoder_class() -> Type:
    """
    Resolve which spatial encoder class the pipeline should instantiate.

    Resolution order:

    1. If ``TWNM_SPATIAL_ENCODER_MODULE`` is set to a non-empty value, import
       that module and return its ``SpatialEncoder`` attribute.
    2. Otherwise, if the bundled TorchScript file (``TORCHSCRIPT_PATH``)
       exists, return ``TorchScriptSpatialEncoder``.

    Returns
    -------
    Type
        The encoder class itself (not an instance).

    Raises
    ------
    SpatialEncoderUnavailable
        If the configured module cannot be imported or lacks
        ``SpatialEncoder``, or if no module is configured and the
        TorchScript file is missing.
    """
    env_module = os.environ.get("TWNM_SPATIAL_ENCODER_MODULE")

    if not env_module:
        # No explicit override: the bundled TorchScript export is the only
        # remaining option.
        if TORCHSCRIPT_PATH.exists():
            return TorchScriptSpatialEncoder
        raise SpatialEncoderUnavailable(
            "当前仓库未开放空间编码器实现。请在环境变量 "
            "`TWNM_SPATIAL_ENCODER_MODULE` 中配置自定义 SpatialEncoder 的模块路径，"
            "或在内网环境下将 private_impl/spatial_encoder_impl 加入 PYTHONPATH。"
        )

    # An explicit override always wins. A broken override is reported rather
    # than silently falling back to the TorchScript export.
    try:
        return importlib.import_module(env_module).SpatialEncoder
    except (ImportError, AttributeError) as err:
        raise SpatialEncoderUnavailable(
            f"无法从 {env_module} 导入 SpatialEncoder，请检查环境变量 "
            "`TWNM_SPATIAL_ENCODER_MODULE` 是否正确指向自定义实现。"
        ) from err
