import torch
from typing import Union, Mapping


# Alias -> canonical torch attribute name. Aliases are resolved with getattr
# below instead of being written as ``torch.<name>`` directly, because several
# of these dtypes only exist in newer PyTorch releases (e.g. the float8 types
# and uint16/uint32/uint64). A direct attribute reference would raise
# AttributeError on older installs and break *every* lookup, even "float32".
_DTYPE_ALIAS_TO_ATTR: Mapping[str, str] = {
    # Floating-point
    "float32": "float32",
    "float": "float32",
    "fp32": "float32",
    "float64": "float64",
    "double": "float64",
    "float16": "float16",
    "half": "float16",
    "fp16": "float16",
    "bfloat16": "bfloat16",
    "bf16": "bfloat16",
    "float8_e4m3fn": "float8_e4m3fn",
    "float8_e5m2": "float8_e5m2",
    # Complex
    "complex32": "complex32",
    "chalf": "complex32",
    "complex64": "complex64",
    "cfloat": "complex64",
    "complex128": "complex128",
    "cdouble": "complex128",
    # Integer
    "int8": "int8",
    "int16": "int16",
    "short": "int16",
    "int32": "int32",
    "int": "int32",
    "int64": "int64",
    "long": "int64",
    # Unsigned integer
    "uint8": "uint8",
    "uint16": "uint16",
    "uint32": "uint32",
    "uint64": "uint64",
    # Boolean
    "bool": "bool",
    # Quantized
    "quint8": "quint8",
    "qint8": "qint8",
    "quint4x2": "quint4x2",
    "qint32": "qint32",
}

# Resolved once at import time (the table is constant). Only dtypes that the
# installed PyTorch build actually provides are kept.
_DTYPE_ALIASES: Mapping[str, torch.dtype] = {
    alias: getattr(torch, attr)
    for alias, attr in _DTYPE_ALIAS_TO_ATTR.items()
    if isinstance(getattr(torch, attr, None), torch.dtype)
}

# Prefix produced by ``str(dtype)``, e.g. ``str(torch.float32) == "torch.float32"``.
_TORCH_DTYPE_PREFIX = "torch."


def dtype_from_str(dtype_input: Union[str, torch.dtype]) -> torch.dtype:
    """
    Convert a textual dtype identifier or a torch.dtype into the corresponding torch.dtype.

    If the input is already a torch.dtype, it is returned unchanged.
    If the input is a string, surrounding whitespace is stripped and it is matched
    case-insensitively against known dtype names. An optional leading "torch."
    prefix is accepted, so ``dtype_from_str(str(torch.float16))`` round-trips.
    Otherwise, a TypeError is raised.

    Supported string identifiers include (case-insensitive):
      - Floating: "float32", "float", "fp32", "float64", "double",
                  "float16", "half", "fp16", "bfloat16", "bf16",
                  "float8_e4m3fn", "float8_e5m2"
      - Complex: "complex32", "chalf", "complex64", "cfloat",
                 "complex128", "cdouble"
      - Integer: "int8", "int16", "short", "int32", "int", "int64", "long"
      - Unsigned integer: "uint8", "uint16", "uint32", "uint64"
      - Boolean: "bool"
      - Quantized: "quint8", "qint8", "quint4x2", "qint32"

    Identifiers whose dtype does not exist in the installed PyTorch version
    (e.g. the float8 or uint16/32/64 types on older releases) are treated as
    unrecognized. They do not break lookups of other dtypes.

    Args:
        dtype_input (Union[str, torch.dtype]):
            Either a string naming the dtype, or a torch.dtype instance.

    Returns:
        torch.dtype: The matching PyTorch dtype object.

    Raises:
        TypeError: If `dtype_input` is not a str or torch.dtype.
        ValueError: If `dtype_input` is a string but not one of the supported identifiers.
    """
    # If already a torch.dtype, return it directly.
    if isinstance(dtype_input, torch.dtype):
        return dtype_input

    # Must be a string beyond this point.
    if not isinstance(dtype_input, str):
        raise TypeError(
            f"dtype_input must be a torch.dtype or str, but got {type(dtype_input)}"
        )

    # Normalize the key. Accept the "torch.<name>" form emitted by str(dtype).
    key = dtype_input.strip().lower()
    if key.startswith(_TORCH_DTYPE_PREFIX):
        key = key[len(_TORCH_DTYPE_PREFIX):]

    try:
        return _DTYPE_ALIASES[key]
    except KeyError:
        valid = ", ".join(sorted(_DTYPE_ALIASES))
        # `from None`: the internal KeyError is an implementation detail and
        # would only clutter the traceback users see.
        raise ValueError(
            f"Unrecognized dtype '{dtype_input}'. Valid options: {valid}."
        ) from None
