"""
Precision policy utilities
精度策略工具：统一管理CUDA场景下的计算/存储精度与matmul精度。
"""

from contextlib import contextmanager
from typing import Literal

import jax
import jax.numpy as jnp


def compute_dtype() -> jnp.dtype:
    """Default compute dtype on CUDA: float32.
    CUDA默认计算精度：float32（TF32路径由matmul precision控制）。
    """
    return jnp.float32


def storage_dtype() -> jnp.dtype:
    """Default storage dtype to save memory: bfloat16.
    默认存储精度：bfloat16，用于占用大的中间量/密度的存储节省。
    """
    return jnp.bfloat16


@contextmanager
def matmul_precision(level: Literal["fastest", "high", "highest"] = "high"):
    """Context manager to control default matmul precision.
    控制矩阵乘精度：
      - "fastest": 速度优先（允许TF32/bfloat16路径）
      - "high":   建议默认（NVIDIA上通常走TF32, 速度与稳定折中）
      - "highest": 全float32精度（更稳但更慢）
    """
    with jax.default_matmul_precision(level):
        yield


