"""
Device utilities for JAX
GPU/CPU 设备工具

提供统一的 GPU 选择与数组放置，避免 host-device 来回拷贝。
"""

from __future__ import annotations

from typing import Any

import jax


def get_preferred_device():
    """Return the first CUDA or ROCm device if available, else JAX's default.

    The ``"cuda"`` backend is probed first, then ``"rocm"``, using
    ``jax.devices(platform)``. The first device of the first backend that
    reports a non-empty device list is returned. If neither backend yields
    a device, the first entry of ``jax.devices()`` is returned instead.

    Returns:
        A single JAX device object.
    """
    for backend in ("cuda", "rocm"):
        try:
            candidates = jax.devices(backend)
        except Exception:
            # Best-effort probe: any failure while querying this backend is
            # treated as "backend not present", so move on to the next one.
            continue
        if candidates:
            return candidates[0]
    # No accelerator backend produced a device; use JAX's default listing.
    return jax.devices()[0]


def device_put_tree(tree: Any):
    """Place ``tree`` on the preferred device and return the result.

    The target device is chosen by ``get_preferred_device()`` on every call,
    and ``tree`` is forwarded unchanged to ``jax.device_put``.

    Args:
        tree: The value to place on the device; passed straight through to
            ``jax.device_put``.

    Returns:
        Whatever ``jax.device_put`` returns for ``tree`` on that device.
    """
    return jax.device_put(tree, device=get_preferred_device())


