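"""Global state management for experiment config, RNG, and metrics.

Also holds the debug-mode flag and the default NumPy float dtype used by the
optional NumPy constructor patches installed by ``set_float_dtype``.
"""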

from typing import TYPE_CHECKING, Optional

import numpy as np

# Default floating-point dtype the patched NumPy constructors fall back to
# when the caller passes no dtype. Intended options: np.float16, np.float32,
# np.float64. Rebound at runtime by set_float_dtype().
FLOAT_DTYPE = np.float16

# Unpatched NumPy constructors, captured at import time so the wrappers below
# can delegate to them after set_float_dtype() replaces the public names on
# the numpy module.
# NOTE(review): reloading this module after set_float_dtype() has run would
# capture the patched wrappers here, making them call themselves — avoid
# importlib.reload() on this module.
_original_array, _original_zeros, _original_ones, _original_empty = (
    np.array,
    np.zeros,
    np.ones,
    np.empty,
)


def _patched_array(object, dtype=None, **kwargs):
    """Patched ``np.array`` that defaults floating-point results to FLOAT_DTYPE.

    The first parameter is named ``object`` (shadowing the builtin) on purpose:
    it mirrors ``np.array``'s own signature, so ``np.array(object=...)`` keeps
    working after patching.

    Args:
        object: Anything the original ``np.array`` accepts.
        dtype: Explicit dtype. When given it is honoured unchanged; when
            ``None``, a floating-point result is cast to FLOAT_DTYPE (read at
            call time, so set_float_dtype() changes take effect immediately).
        **kwargs: Forwarded unchanged to the original ``np.array``.

    Returns:
        The constructed array. Non-floating results (ints, bools, objects,
        ...) are returned exactly as the original ``np.array`` built them.
    """
    arr = _original_array(object, dtype=dtype, **kwargs)
    if dtype is None and np.issubdtype(arr.dtype, np.floating):
        # copy=False: skip a second allocation when arr already has
        # FLOAT_DTYPE, which also keeps a caller's ``copy=False`` meaningful
        # (the default astype would always hand back a fresh copy).
        return arr.astype(FLOAT_DTYPE, copy=False)
    return arr


def _patched_zeros(shape, dtype=None, order="C", **kwargs):
    """Patched ``np.zeros`` whose dtype defaults to FLOAT_DTYPE instead of float64.

    ``order`` is an explicit parameter so the positional form accepted by the
    real function, ``np.zeros(shape, dtype, order)``, keeps working once
    patched (the previous ``**kwargs``-only signature raised TypeError there).

    Args:
        shape: Shape of the new array.
        dtype: Explicit dtype; ``None`` means FLOAT_DTYPE, read at call time.
        order: Memory layout, ``"C"`` or ``"F"``; forwarded unchanged.
        **kwargs: Any remaining keyword arguments (e.g. ``like``), forwarded.

    Returns:
        A zero-filled array.
    """
    if dtype is None:
        dtype = FLOAT_DTYPE
    return _original_zeros(shape, dtype=dtype, order=order, **kwargs)


def _patched_ones(shape, dtype=None, order="C", **kwargs):
    """Patched ``np.ones`` whose dtype defaults to FLOAT_DTYPE instead of float64.

    ``order`` is an explicit parameter so the positional form accepted by the
    real function, ``np.ones(shape, dtype, order)``, keeps working once
    patched (the previous ``**kwargs``-only signature raised TypeError there).

    Args:
        shape: Shape of the new array.
        dtype: Explicit dtype; ``None`` means FLOAT_DTYPE, read at call time.
        order: Memory layout, ``"C"`` or ``"F"``; forwarded unchanged.
        **kwargs: Any remaining keyword arguments (e.g. ``like``), forwarded.

    Returns:
        An array filled with ones.
    """
    if dtype is None:
        dtype = FLOAT_DTYPE
    return _original_ones(shape, dtype=dtype, order=order, **kwargs)


def _patched_empty(shape, dtype=None, order="C", **kwargs):
    """Patched ``np.empty`` whose dtype defaults to FLOAT_DTYPE instead of float64.

    ``order`` is an explicit parameter so the positional form accepted by the
    real function, ``np.empty(shape, dtype, order)``, keeps working once
    patched (the previous ``**kwargs``-only signature raised TypeError there).

    Args:
        shape: Shape of the new array.
        dtype: Explicit dtype; ``None`` means FLOAT_DTYPE, read at call time.
        order: Memory layout, ``"C"`` or ``"F"``; forwarded unchanged.
        **kwargs: Any remaining keyword arguments (e.g. ``like``), forwarded.

    Returns:
        A new array whose contents are whatever the original ``np.empty``
        leaves in it (uninitialised).
    """
    if dtype is None:
        dtype = FLOAT_DTYPE
    return _original_empty(shape, dtype=dtype, order=order, **kwargs)


def set_float_dtype(dtype: type = FLOAT_DTYPE) -> None:
    """Patch NumPy so its constructors default to the given float dtype.

    Call this at startup to make ``np.array``, ``np.zeros``, ``np.ones`` and
    ``np.empty`` default to ``dtype`` instead of float64. The functions are
    replaced as attributes of the ``numpy`` module itself, so the change is
    process-wide. Names bound earlier via ``from numpy import array`` are not
    affected. Calling this repeatedly is safe: the wrappers delegate to the
    originals captured at import time, so patches never stack.

    Args:
        dtype: Any floating-point dtype specification accepted by
            ``np.dtype`` (e.g. ``np.float32`` or ``"float32"``). The default
            is bound when the function is defined, i.e. ``np.float16``.
            Calling with no argument therefore resets to float16 rather than
            keeping a previously set dtype.

    Raises:
        TypeError: If ``dtype`` is not a valid dtype specification or is not
            a floating-point type. Non-float types are rejected because the
            patched ``np.array`` would otherwise silently cast every float
            array to them.
    """
    global FLOAT_DTYPE
    # Validate before touching any global state.
    resolved = np.dtype(dtype)
    if not np.issubdtype(resolved, np.floating):
        raise TypeError(
            f"set_float_dtype() requires a floating-point dtype, got {resolved}"
        )
    # Store the scalar type (e.g. np.float32), matching the module default.
    FLOAT_DTYPE = resolved.type
    np.array = _patched_array
    np.zeros = _patched_zeros
    np.ones = _patched_ones
    np.empty = _patched_empty


if TYPE_CHECKING:
    # Imported only for static type checkers; presumably this avoids a
    # runtime import cycle with the experiments package — TODO confirm.
    from ..experiments import ExperimentConfig
else:
    # Runtime placeholder so the name exists; the annotations in this module
    # refer to the class only through the string "ExperimentConfig".
    ExperimentConfig = None


# Current experiment configuration; None until set_experiment_config() runs.
_current_config: Optional["ExperimentConfig"] = None
# Module RNG created by set_global_seed(); get_rng() raises ValueError while None.
_global_rng: Optional[np.random.Generator] = None
# Debug flag written by set_debug_mode() and read by is_debug_mode().
_debug_mode: bool = False


def set_debug_mode(debug: bool) -> None:
    """Store ``debug`` as the module-wide debug flag consulted for logging.

    Args:
        debug: New flag value; read it back with ``is_debug_mode()``.
    """
    global _debug_mode
    _debug_mode = debug


def is_debug_mode() -> bool:
    """Return the current module-wide debug flag.

    The name is only read here, so no ``global`` declaration is required.
    """
    return _debug_mode


def set_experiment_config(config: "ExperimentConfig") -> None:
    """Install ``config`` as the module-level experiment configuration.

    Any previously stored config is replaced. Retrieve it with
    ``get_experiment_config()``.
    """
    global _current_config
    _current_config = config


def get_experiment_config() -> Optional["ExperimentConfig"]:
    """Return the stored experiment configuration, or None if none was set.

    The name is only read here, so no ``global`` declaration is required.
    """
    return _current_config


def set_global_seed(seed: int) -> None:
    """Seed the module NumPy generator and PyTorch, and enable determinism.

    This does three things:

    * Creates a fresh ``np.random.Generator`` from ``seed`` (returned by
      ``get_rng()``).
    * Calls ``torch.manual_seed(seed)``.
    * Turns on ``torch.use_deterministic_algorithms(True)``.

    Python's ``random`` module and NumPy's legacy global ``np.random`` state
    are not seeded here.

    Args:
        seed: Seed passed to both ``np.random.default_rng`` and
            ``torch.manual_seed``.
    """
    global _global_rng
    import os

    # PyTorch's reproducibility notes require CUBLAS_WORKSPACE_CONFIG
    # (":4096:8" or ":16:8") for deterministic cuBLAS. setdefault keeps an
    # explicit user choice instead of clobbering it.
    # NOTE(review): the variable presumably only takes effect if set before
    # CUDA/cuBLAS initialises — confirm this is called early enough.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    _global_rng = np.random.default_rng(seed)
    import torch

    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)


def get_rng() -> np.random.Generator:
    """Return the module RNG created by ``set_global_seed``.

    Raises:
        ValueError: If ``set_global_seed`` has not been called yet.
    """
    rng = _global_rng
    if rng is not None:
        return rng
    raise ValueError(
        "Global RNG not initialized. Call set_global_seed(seed) first."
    )
