"""Utility functions for device selection and data type conversion.

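This module provides helpers used in neural operator training: converting
data type names (e.g. ``"fp32"``) to ``torch.dtype`` objects, and recursively
applying configuration overrides (e.g. from ``wandb.config``) to nested
dataclass configs.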
"""

from collections.abc import Mapping
from dataclasses import is_dataclass
from typing import Any

import torch
import wandb


# Friendly aliases commonly used in configs / CLI flags, mapped to the
# canonical ``torch`` attribute name of the dtype they denote.
_DTYPE_ALIASES = {
    "fp32": "float32",
    "f32": "float32",
    "fp16": "float16",
    "f16": "float16",
    "half": "float16",
    "bf16": "bfloat16",
    "float": "float32",
    "fp64": "float64",
    "f64": "float64",
    "double": "float64",
    "long": "int64",
    "int": "int32",
    "short": "int16",
    "byte": "uint8",
}


def get_dtype(name: str) -> torch.dtype:
    """Convert a string to the corresponding torch.dtype.

    This function converts various string representations of data types to their
    corresponding torch.dtype objects. It supports both exact torch names and
    common aliases used in configuration files. Matching is case-insensitive,
    ignores surrounding whitespace and an optional leading ``torch.`` prefix.
    If ``name`` is already a ``torch.dtype`` it is returned unchanged.

    Parameters
    ----------
    name : str
        String representation of the data type. Can be exact torch names like
        'float32' or aliases like 'fp32', 'f32', etc.

    Returns
    -------
    torch.dtype
        The corresponding torch.dtype object.

    Raises
    ------
    KeyError
        If the name does not resolve to a ``torch.dtype`` (including names of
        non-dtype torch attributes such as ``'zeros'`` or ``'nn'``).

    Examples
    --------
    >>> get_dtype('float32')
    torch.float32
    >>> get_dtype('fp32')
    torch.float32
    >>> get_dtype('torch.float32')
    torch.float32
    >>> get_dtype('half')
    torch.float16
    """
    # Allow configs that already hold a resolved dtype to pass straight through.
    if isinstance(name, torch.dtype):
        return name

    key = name.strip().lower()
    # Remove only a *leading* 'torch.' prefix, not every occurrence.
    prefix = "torch."
    if key.startswith(prefix):
        key = key[len(prefix):]

    # Aliases resolve to the same dtypes as torch's own aliases (torch.float,
    # torch.int, ...), so checking the alias table first is equivalent.
    candidate = getattr(torch, _DTYPE_ALIASES.get(key, key), None)
    # getattr(torch, ...) also finds functions/modules; only accept real dtypes.
    if isinstance(candidate, torch.dtype):
        return candidate
    raise KeyError(f"Unknown dtype string '{key}'.")


def deep_update(dc_obj, updates: Mapping[str, Any]) -> None:
    """Recursively overwrite fields of a (possibly nested) dataclass in place.

    Parameters
    ----------
    dc_obj : object
        Target object, typically a dataclass instance. Modified in place.
    updates : Mapping[str, Any]
        Attribute names mapped to new values. Any object exposing ``.items()``
        works (e.g. a plain ``dict`` or ``wandb.config``). Keys that do not
        name an existing attribute of ``dc_obj`` are silently ignored.

    Returns
    -------
    None
        ``dc_obj`` is updated in place.

    Notes
    -----
    When the current attribute value is a dataclass *instance* and the new
    value is a mapping, the update recurses into that nested dataclass, so
    nested fields not mentioned in the mapping keep their current values.
    Any other value replaces the attribute wholesale.
    """
    for key, val in updates.items():
        if not hasattr(dc_obj, key):
            # Ignore keys that don't exist in the target object.
            continue
        current_val = getattr(dc_obj, key)
        # Recurse only when the current value is a dataclass *instance* and the
        # new value is a mapping. is_dataclass() is also True for dataclass
        # classes; recursing into one would setattr on the class itself and
        # silently change attributes shared by every instance.
        if (
            is_dataclass(current_val)
            and not isinstance(current_val, type)
            and isinstance(val, Mapping)
        ):
            deep_update(current_val, val)
        else:
            setattr(dc_obj, key, val)
