"""
HelioX backend management.

This module centralizes HelioX runtime configuration and keeps strong
references to created wrappers so they stay alive for the simulation.
"""

from __future__ import annotations

import os
import sys
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple


@dataclass
class BackendConfig:
    """Runtime configuration for backend execution modes.

    ``enable_heliox`` toggles the HelioX runtime, ``export_path`` is the
    default directory used when exporting the model, and ``device`` /
    ``permute_type`` are forwarded to the HelioX manager as its defaults.
    """

    enable_heliox: bool = False
    export_path: str = "heliox_export"
    device: str = "cpu"
    permute_type: int = 0

    def clone(self) -> "BackendConfig":
        """Return a shallow copy to avoid accidental mutation.

        ``dataclasses.replace`` copies every declared field, so fields added
        to this dataclass later are carried over automatically instead of
        being silently reset to their defaults.
        """
        return replace(self)


class HybridBackend:
    """
    Coordinate HelioX runtime.

    When ``enable_heliox`` is True we instantiate a HelioXManager and keep all
    wrappers that are registered through this object, so they are not garbage
    collected while the simulation runs. The actual ``setup_and_load_model``
    call is deferred until ``initialize`` is invoked so the caller can control
    when model export happens. When HelioX is disabled the ``wrap_*`` helpers
    return ``None`` and ``initialize``/``set_dt``/``finitialize``/``run`` do
    nothing.

    Optimizer parameters may be grouped per network (between
    ``begin_network_registration`` and ``end_network_registration``). Two or
    more networks with identical parameter layouts can then be registered
    together through ``configure_batch_optimizer``.
    """

    def __init__(self, config: Optional[BackendConfig] = None):
        # Work on a copy so later mutation of the caller's config cannot leak in.
        self.config = config.clone() if config is not None else BackendConfig()
        self.enable_heliox = self.config.enable_heliox

        self._manager = None
        self._initialized = False
        # Strong references keep HelioX wrappers alive for the simulation.
        self._obj_wrappers: List[Any] = []
        self._monitor_wrappers: List[Any] = []
        self._vecplay_wrappers: List[Any] = []
        # Optimizer bookkeeping; _optimizer_param_keys holds id() of registered
        # wrappers (safe because each entry keeps its wrapper alive).
        self._optimizer_params: List[Dict[str, Any]] = []
        self._optimizer_param_keys: Set[int] = set()
        self._optimizer_id: Optional[int] = None
        self._optimizer_ready: bool = False
        self._optimizer_type: str = "sgd"
        self._optimizer_hparams: Dict[str, float] = {
            "momentum": 0.9,
            "beta1": 0.9,
            "beta2": 0.999,
            "epsilon": 1e-8,
        }
        # Per-network registration state, keyed by id(network).
        self._registration_stack: List[Tuple[int, List[Dict[str, Any]]]] = []
        self._network_param_groups: Dict[int, List[Dict[str, Any]]] = {}
        self._network_param_order: List[int] = []
        # Holding the registered network objects guarantees their id() keys
        # above cannot be recycled by a new object while still registered.
        self._registered_networks: Dict[int, Any] = {}
        # Batch-mode configuration (see configure_batch_optimizer).
        self._batch_mode: bool = False
        self._batch_network_ids: List[int] = []
        self._batch_param_groups: List[List[Dict[str, Any]]] = []
        self._batch_param_count: int = 0
        self._batch_configured: bool = False

        if self.enable_heliox:
            # Lazy import so environments without heliox still work.
            try:
                from heliox_wrapper import HelioXManager  # type: ignore
            except ModuleNotFoundError:
                self._ensure_heliox_on_syspath()
                try:
                    from heliox_wrapper import HelioXManager  # type: ignore
                except ModuleNotFoundError as exc:
                    if exc.name != "heliox_wrapper":
                        # A dependency of heliox_wrapper is missing; report it as-is.
                        raise
                    raise ModuleNotFoundError(
                        "heliox_wrapper could not be imported; install it or set "
                        "HELIOX_HOME to a directory containing 'python_lib' or 'build'",
                        name="heliox_wrapper",
                    ) from exc

            self._manager = HelioXManager()
            self._manager.set_default_device(self.config.device)
            self._manager.set_default_permute_type(self.config.permute_type)

    @staticmethod
    def _ensure_heliox_on_syspath() -> None:
        """Prepend existing ``$HELIOX_HOME/{python_lib,build}`` dirs to ``sys.path``.

        Each directory is inserted at most once. Filesystem errors while
        probing are ignored so a bad ``HELIOX_HOME`` falls through to the
        regular import error.
        """
        env_home = os.environ.get("HELIOX_HOME")
        if not env_home:
            return
        base = Path(env_home)
        try:
            if not base.is_dir():
                return
        except OSError:
            return
        for sub in ("python_lib", "build"):
            try:
                path = (base / sub).resolve()
                if not path.is_dir():
                    continue
            except (OSError, RuntimeError):
                # RuntimeError: symlink loop during resolve() on older Pythons.
                continue
            path_str = str(path)
            if path_str not in sys.path:
                sys.path.insert(0, path_str)

    # ------------------------------------------------------------------
    # Wrapper helpers
    # ------------------------------------------------------------------

    def wrap_obj(self, obj: Any) -> Optional[Any]:
        """Create an ObjWrapper for ``obj`` when HelioX is enabled, else ``None``."""
        if not self.enable_heliox or self._manager is None:
            return None

        wrapper = self._manager.create_obj_wrapper(obj)
        self._obj_wrappers.append(wrapper)
        return wrapper

    def wrap_recorder(self, obj: Any, var_name: str = "v", array_index: int = 0) -> Optional[Any]:
        """Create a RecorderWrapper for ``obj`` when HelioX is enabled, else ``None``."""
        if not self.enable_heliox or self._manager is None:
            return None

        wrapper = self._manager.create_recorder(obj, var_name=var_name, array_index=array_index)
        self._monitor_wrappers.append(wrapper)
        return wrapper

    def wrap_vecplay(self, obj: Any, var_name: str = "amp") -> Optional[Any]:
        """Create a VecPlayWrapper for ``obj`` when HelioX is enabled, else ``None``."""
        if not self.enable_heliox or self._manager is None:
            return None

        wrapper = self._manager.create_vecplay_wrapper(obj, var_name=var_name)
        self._vecplay_wrappers.append(wrapper)
        return wrapper

    # ------------------------------------------------------------------
    # Batch registration helpers
    # ------------------------------------------------------------------

    def begin_network_registration(self, network: Any) -> None:
        """Mark the start of optimizer parameter registration for ``network``.

        Parameters registered via ``register_optimizer_param`` until the
        matching ``end_network_registration`` are attributed to this network.

        Raises:
            RuntimeError: if ``network`` is already registered with this backend.
        """
        if not self.enable_heliox:
            return

        network_key = id(network)
        if network_key in self._network_param_groups:
            raise RuntimeError("Network already registered with this backend")

        group: List[Dict[str, Any]] = []
        self._network_param_groups[network_key] = group
        self._registered_networks[network_key] = network
        self._registration_stack.append((network_key, group))
        if network_key not in self._network_param_order:
            self._network_param_order.append(network_key)
        self._batch_configured = False

    def end_network_registration(self, network: Any) -> None:
        """Finalize optimizer parameter registration for ``network``.

        Networks that registered no parameters are dropped from tracking.
        Does nothing when no registration is open.

        Raises:
            RuntimeError: if ``network`` is not the most recently opened
                registration. The open registration is left intact so it can
                still be closed correctly.
        """
        if not self.enable_heliox or not self._registration_stack:
            return

        network_key = id(network)
        stack_key, _ = self._registration_stack[-1]
        if stack_key != network_key:
            raise RuntimeError("Network registration stack mismatch detected")
        self._registration_stack.pop()

        if not self._network_param_groups.get(network_key):
            # No trainable parameters discovered; drop the placeholder entry.
            self._network_param_groups.pop(network_key, None)
            self._registered_networks.pop(network_key, None)
            if network_key in self._network_param_order:
                self._network_param_order.remove(network_key)
        self._batch_configured = False

    def configure_batch_optimizer(self, networks: List[Any]) -> None:
        """Configure HelioX optimizer for batch mode with multiple networks.

        Every network must share this backend (``net.backend is self``), must
        have completed parameter registration, and must expose the same number
        of optimizer parameters. Calling again with the same network list is a
        no-op while the configuration is still current.

        Raises:
            RuntimeError: HelioX disabled, already initialized, foreign
                backend, unregistered network, parameter count mismatch, or a
                different network set was configured earlier.
            ValueError: fewer than two networks or duplicates supplied.
        """
        if not self.enable_heliox or self._manager is None:
            raise RuntimeError("HelioX backend is not enabled")
        if not networks:
            raise ValueError("No networks provided for batch optimizer configuration")
        if len(networks) < 2:
            raise ValueError("Batch optimizer requires at least two networks")
        if self._initialized:
            raise RuntimeError("Batch optimizer must be configured before initialize()")

        network_ids = [id(net) for net in networks]
        if len(set(network_ids)) != len(network_ids):
            raise ValueError("Duplicate networks detected in batch configuration")

        for net in networks:
            if getattr(net, "backend", None) is not self:
                raise RuntimeError("All networks must share the same HybridBackend instance")

        if self._batch_mode and self._batch_network_ids == network_ids and self._batch_configured:
            return  # idempotent for the same configuration
        if self._batch_mode and self._batch_network_ids != network_ids:
            raise RuntimeError("Batch optimizer already configured for a different network set")

        param_groups: List[List[Dict[str, Any]]] = []
        for key in network_ids:
            group = self._network_param_groups.get(key)
            if group is None:
                raise RuntimeError(
                    "Network has not registered optimizer parameters; ensure build() completed"
                )
            param_groups.append(group)

        if not param_groups or not param_groups[0]:
            raise RuntimeError("No optimizer parameters available for batch configuration")

        param_count = len(param_groups[0])
        for idx, group in enumerate(param_groups):
            if len(group) != param_count:
                raise RuntimeError(
                    f"Batch parameter count mismatch for network index {idx} "
                    f"(expected {param_count}, got {len(group)})"
                )

        self._batch_mode = True
        self._batch_network_ids = network_ids
        self._batch_param_groups = param_groups
        self._batch_param_count = param_count
        self._batch_configured = True
        self._optimizer_ready = False  # force re-registration during initialization

    # ------------------------------------------------------------------
    # Execution helpers
    # ------------------------------------------------------------------

    def initialize(self, dt: float, v_init: float, export_path: Optional[str] = None) -> None:
        """
        Export the model and load it into HelioX.

        Must be called after the model has been constructed but before
        running HelioX simulations. Calling it multiple times is safe: only
        the first call has an effect. ``export_path`` overrides
        ``config.export_path`` and is created if missing. Pending optimizer
        parameters are registered right after the model is loaded.
        """
        if not self.enable_heliox or self._manager is None:
            return

        if self._initialized:
            return

        path = Path(export_path or self.config.export_path)
        path.mkdir(parents=True, exist_ok=True)

        self._manager.setup_and_load_model(str(path), dt=dt, v_init=v_init)
        self._initialized = True
        self._finalize_optimizer_registration()

    def set_dt(self, dt: float) -> None:
        """Set simulation dt on HelioX when enabled and initialized."""
        if self.enable_heliox and self._manager is not None and self._initialized:
            self._manager.set_dt(dt)

    def finitialize(self, v_init: float) -> None:
        """Run finitialize on HelioX when enabled and initialized."""
        if self.enable_heliox and self._manager is not None and self._initialized:
            self._manager.finitialize(v_init)

    def run(self, runtime: float) -> None:
        """Run simulation on HelioX when enabled and initialized."""
        if self.enable_heliox and self._manager is not None and self._initialized:
            self._manager.run(runtime)

    def register_optimizer_param(self, wrapper: Any, impedance: float) -> None:
        """Register a trainable parameter with HelioX optimizer.

        Ignored when HelioX is disabled, when ``wrapper``/``impedance`` is
        ``None``, or when the same wrapper object was already registered. If a
        network registration is open, the parameter is also attributed to that
        network. After ``initialize`` the parameter is pushed to HelioX
        immediately (non-batch mode only).
        """
        if (not self.enable_heliox or self._manager is None or wrapper is None
                or impedance is None):
            return

        key = id(wrapper)
        if key in self._optimizer_param_keys:
            return

        self._optimizer_param_keys.add(key)
        entry = {"wrapper": wrapper, "impedance": float(impedance)}
        self._optimizer_params.append(entry)
        if self._registration_stack:
            self._registration_stack[-1][1].append(entry)
        self._batch_configured = False

        if self._initialized:
            self._ensure_optimizer_created()
            if self._batch_mode:
                # Batch parameters are registered together at the next finalize.
                # NOTE(review): _finalize_optimizer_registration is only invoked
                # from initialize(), which returns early once initialized, so a
                # parameter added here leaves optimizer_ready False from now on.
                self._optimizer_ready = False
            else:
                self._register_single_optimizer_param(entry)
                if self._optimizer_id is not None:
                    self._optimizer_ready = True

    def get_manager(self):
        """Expose the underlying HelioXManager (``None`` when disabled)."""
        return self._manager

    @property
    def initialized(self) -> bool:
        """Whether ``initialize`` has loaded the model into HelioX."""
        return self._initialized

    @property
    def optimizer_ready(self) -> bool:
        """Whether ``optimizer_step`` will dispatch to HelioX."""
        return self._optimizer_ready

    @property
    def batch_configured(self) -> bool:
        """Whether the batch configuration is current (no registrations since)."""
        return self._batch_configured

    def optimizer_step(self, learning_rate: float, record_time: float, dt: float) -> bool:
        """Execute a HelioX optimizer step if available.

        Returns:
            True if the step was dispatched, False if HelioX or the optimizer
            is not ready.
        """
        if (not self.enable_heliox or self._manager is None or not self._initialized
                or not self._optimizer_ready or self._optimizer_id is None):
            return False

        self._manager.optimizer_step(self._optimizer_id, learning_rate, record_time, dt)
        return True

    def set_optimizer_config(self,
                             optimizer_type: Optional[str] = None,
                             momentum: Optional[float] = None,
                             beta1: Optional[float] = None,
                             beta2: Optional[float] = None,
                             epsilon: Optional[float] = None) -> None:
        """Configure HelioX optimizer type/hyperparameters before initialization.

        ``None`` (or an empty ``optimizer_type``) leaves a setting unchanged.
        If the optimizer already exists, the hyperparameters are re-applied in
        place, but a new ``optimizer_type`` takes effect only for an optimizer
        that has not been created yet.
        """
        if optimizer_type:
            self._optimizer_type = optimizer_type.lower()
        if momentum is not None:
            self._optimizer_hparams["momentum"] = float(momentum)
        if beta1 is not None:
            self._optimizer_hparams["beta1"] = float(beta1)
        if beta2 is not None:
            self._optimizer_hparams["beta2"] = float(beta2)
        if epsilon is not None:
            self._optimizer_hparams["epsilon"] = float(epsilon)
        # If optimizer is already created, configure it in-place.
        if self._optimizer_id is not None and self._manager is not None:
            self._configure_optimizer()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _ensure_optimizer_created(self) -> None:
        """Create and configure the HelioX optimizer once (manager must be set)."""
        if self._optimizer_id is None:
            self._optimizer_id = self._manager.create_optimizer(self._optimizer_type)
            self._configure_optimizer()

    def _configure_optimizer(self) -> None:
        """Push the stored hyperparameters to the existing HelioX optimizer."""
        if self._optimizer_id is None or self._manager is None:
            return
        params = self._optimizer_hparams
        self._manager.configure_optimizer(
            self._optimizer_id,
            params.get("momentum", 0.9),
            params.get("beta1", 0.9),
            params.get("beta2", 0.999),
            params.get("epsilon", 1e-8),
        )

    def _register_single_optimizer_param(self, entry: Dict[str, Any]) -> None:
        """Register one parameter's weight/gradient handles with the optimizer.

        Handle lookup failures and negative results are reported as warnings
        (best-effort) rather than raised.
        """
        wrapper = entry["wrapper"]
        impedance = entry["impedance"]
        try:
            weight_handle = wrapper.get_handle("w")
            grad_handle = wrapper.get_handle("acc_grad")
        except Exception as exc:  # pylint: disable=broad-except
            print(f"Warning: failed to obtain handles for optimizer param: {exc}")
            return

        result = self._manager.optimizer_add_param(self._optimizer_id, weight_handle, grad_handle, impedance)
        if result < 0:
            print("Warning: optimizer_add_param returned negative result")

    def _finalize_optimizer_registration(self) -> None:
        """Register all pending parameters (batch or individually) and mark ready."""
        if (not self.enable_heliox or self._manager is None
                or self._optimizer_ready or not self._optimizer_params):
            return
        self._ensure_optimizer_created()
        if self._batch_mode and self._batch_param_groups:
            self._register_batch_optimizer_params()
            has_params = self._batch_param_count > 0
        else:
            for entry in self._optimizer_params:
                self._register_single_optimizer_param(entry)
            has_params = bool(self._optimizer_params)
        self._optimizer_ready = self._optimizer_id is not None and has_params

    def _register_batch_optimizer_params(self) -> None:
        """Register each parameter index across all batch networks in one call.

        Raises:
            RuntimeError: if impedances differ across networks for the same
                parameter index, or a wrapper handle cannot be obtained.
        """
        if self._optimizer_id is None or not self._batch_param_groups:
            return

        batch_size = len(self._batch_param_groups)
        param_count = self._batch_param_count
        if param_count <= 0 or batch_size <= 0:
            return

        for param_idx in range(param_count):
            entries = [group[param_idx] for group in self._batch_param_groups]
            wrappers = [entry["wrapper"] for entry in entries]
            impedances = [entry["impedance"] for entry in entries]

            base_impedance = impedances[0]
            for imp in impedances[1:]:
                if abs(imp - base_impedance) > 1e-9:
                    raise RuntimeError("Impedance mismatch detected across batch networks")

            weight_handles: List[int] = []
            grad_handles: List[int] = []
            for wrapper in wrappers:
                try:
                    weight_handles.append(wrapper.get_handle("w"))
                    grad_handles.append(wrapper.get_handle("acc_grad"))
                except Exception as exc:  # pylint: disable=broad-except
                    raise RuntimeError("Failed to obtain optimizer handles for batch parameter") from exc

            result = self._manager.optimizer_add_param_batch(
                self._optimizer_id,
                weight_handles,
                grad_handles,
                base_impedance,
            )
            if result < 0:
                print("Warning: optimizer_add_param_batch returned negative result")