from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Dict

import numpy as np

# ml_dtypes is an optional dependency that supplies the float8 scalar types.
# When it cannot be imported, both names are None: the fp8 entries are then
# never added to _FORMAT_INFO, and canonicalize_precision_format reports fp8
# formats as requiring ml_dtypes.
try:
    import ml_dtypes as _ml_dtypes
    _float8_e4m3fn = _ml_dtypes.float8_e4m3fn
    _float8_e5m2 = _ml_dtypes.float8_e5m2
except ImportError:
    _float8_e4m3fn = None
    _float8_e5m2 = None


# Every accepted (lower-case) spelling of a precision format, mapped to its
# canonical key. canonicalize_precision_format strips and lower-cases its input
# before looking it up here. Note that bare "fp8" and "fp8_e4m3" resolve to the
# e4m3fn variant, and "float" resolves to fp64.
_FORMAT_ALIASES = {
    alias: canonical
    for canonical, spellings in (
        ("fp8_e4m3fn", ("fp8", "fp8_e4m3fn", "fp8_e4m3", "float8_e4m3fn")),
        ("fp8_e5m2", ("fp8_e5m2", "float8_e5m2")),
        ("fp16", ("fp16", "float16", "half")),
        ("fp32", ("fp32", "float32", "single")),
        ("fp64", ("fp64", "float64", "double", "float")),
    )
    for alias in spellings
}


# Per-format metadata keyed by canonical format name:
#   dag_dtype     - dtype string returned by precision_format_to_dag_dtype
#   numpy_dtype   - numpy scalar type values are cast to
#   numpy_name    - numpy dtype name string
#   uint_dtype    - unsigned integer of the same width, for viewing raw bits
#   signed_dtype  - wider signed integer type (not referenced in this view)
#   sign_shift    - bit index of the sign bit
#   sign_mask     - mask that clears the sign bit (the magnitude bits)
_FORMAT_INFO = {
    key: {
        "dag_dtype": dtype_name,
        "numpy_dtype": scalar_type,
        "numpy_name": dtype_name,
        "uint_dtype": unsigned_type,
        "signed_dtype": signed_type,
        "sign_shift": width - 1,
        "sign_mask": (1 << (width - 1)) - 1,
    }
    for key, dtype_name, scalar_type, unsigned_type, signed_type, width in (
        ("fp16", "float16", np.float16, np.uint16, np.int32, 16),
        ("fp32", "float32", np.float32, np.uint32, np.int64, 32),
        ("fp64", "float64", np.float64, np.uint64, np.int64, 64),
    )
}

# Register the 8-bit float formats, but only those whose scalar type was
# provided by the optional ml_dtypes import (a None type is skipped).
_FORMAT_INFO.update(
    {
        f"fp8_{variant}": {
            "dag_dtype": f"float8_{variant}",
            "numpy_dtype": scalar_type,
            "numpy_name": f"float8_{variant}",
            "uint_dtype": np.uint8,
            "signed_dtype": np.int16,
            "sign_shift": 7,
            "sign_mask": 0x7F,
        }
        for variant, scalar_type in (
            ("e4m3fn", _float8_e4m3fn),
            ("e5m2", _float8_e5m2),
        )
        if scalar_type is not None
    }
)


def canonicalize_precision_format(value: Any, default: str = "fp32") -> str:
    """Resolve a precision-format spelling to its canonical key.

    ``value`` may be any alias listed in ``_FORMAT_ALIASES`` (matched after
    stripping whitespace and lower-casing), a ``numpy.dtype`` instance, or a
    dtype-like type object such as ``np.float32``, ``float`` or an ml_dtypes
    float8 scalar type. ``None`` selects ``default``.

    Args:
        value: Format spelling, dtype, dtype-like type, or None.
        default: Spelling used when ``value`` is None.

    Returns:
        A key of ``_FORMAT_INFO`` ("fp16", "fp32", "fp64", or an fp8 key when
        ml_dtypes is installed).

    Raises:
        ValueError: If the spelling is unknown, or names an fp8 format while
            ml_dtypes is not installed.
    """
    raw = default if value is None else value
    if isinstance(raw, (type, np.dtype)):
        # str(np.float32) is "<class 'numpy.float32'>", which matches no alias;
        # the dtype name ("float32", "float8_e4m3fn", ...) does.
        try:
            raw = np.dtype(raw).name
        except (TypeError, ValueError):
            pass  # leave as-is; its string form is rejected below
    token = str(raw).strip().lower()
    canonical = _FORMAT_ALIASES.get(token)
    if canonical is None:
        raise ValueError(f"Unsupported precision format: {value!r}")
    if canonical not in _FORMAT_INFO:
        # Alias is known but its fp8 entry was not registered (ml_dtypes missing).
        raise ValueError(
            f"Precision format {canonical!r} requires optional dependency ml_dtypes. "
            "Install ml-dtypes or choose fp16/fp32/fp64."
        )
    return canonical


def _read_field(container: Any, name: str, default: Any = None) -> Any:
    """Read ``name`` from a mapping or from an attribute-style object.

    Any ``collections.abc.Mapping`` (dict, MappingProxyType, ChainMap, ...) is
    read by key; previously only ``dict`` was, so other mappings fell through
    to ``getattr`` and silently yielded ``default``. Anything else is read as
    an attribute.

    Args:
        container: Mapping, object, or None.
        name: Key or attribute name to read.
        default: Value returned when ``container`` is None or lacks ``name``.

    Returns:
        The stored value, or ``default``.
    """
    if container is None:
        return default
    if isinstance(container, Mapping):
        return container.get(name, default)
    return getattr(container, name, default)


def normalize_precision_model(precision_like: Any) -> Dict[str, str]:
    """Resolve the input/compute/output formats of a precision description.

    Each field is read from ``precision_like`` (a dict or an object with
    attributes, or None) and canonicalized. A missing or None field inherits
    the previous stage's canonical format: input falls back to "fp32",
    compute to input, and output to compute.

    Raises:
        ValueError: Propagated from canonicalize_precision_format.
    """
    resolved: Dict[str, str] = {}
    inherited = "fp32"
    for field in ("input_format", "compute_format", "output_format"):
        raw = _read_field(precision_like, field, inherited)
        inherited = canonicalize_precision_format(raw, default=inherited)
        resolved[field] = inherited
    return resolved


def precision_format_to_dag_dtype(value: Any, default: str = "fp32") -> str:
    """Return the ``dag_dtype`` string registered for a precision format.

    Raises:
        ValueError: Propagated from canonicalize_precision_format.
    """
    entry = _FORMAT_INFO[canonicalize_precision_format(value, default=default)]
    return str(entry["dag_dtype"])


def dag_dtype_to_precision_format(dtype: Any, default: str = "fp32") -> str:
    """Map a dtype name (any accepted alias) to its canonical precision format.

    ``dtype`` is stringified, stripped and lower-cased first. None, or a value
    that stringifies to blank, resolves ``default`` instead.

    Raises:
        ValueError: Propagated from canonicalize_precision_format.
    """
    token = ("" if dtype is None else str(dtype)).strip().lower()
    if token:
        return canonicalize_precision_format(token, default=default)
    return canonicalize_precision_format(default)


def dag_dtype_to_numpy_dtype(dtype: Any, default: str = "float32") -> np.dtype[Any]:
    """Return the numpy dtype for a dtype name (blank/None uses ``default``).

    Raises:
        ValueError: Propagated from canonicalize_precision_format.
    """
    fmt = dag_dtype_to_precision_format(dtype, default=default)
    scalar_type = _FORMAT_INFO[fmt]["numpy_dtype"]
    return np.dtype(scalar_type)


def precision_format_to_numpy_dtype(value: Any, default: str = "fp32") -> np.dtype[Any]:
    """Return the numpy dtype registered for a precision format.

    Raises:
        ValueError: Propagated from canonicalize_precision_format.
    """
    canonical = canonicalize_precision_format(value, default=default)
    scalar_type = _FORMAT_INFO[canonical]["numpy_dtype"]
    return np.dtype(scalar_type)


def precision_format_to_numpy_name(value: Any, default: str = "fp32") -> str:
    """Return the numpy dtype name string registered for a precision format.

    Raises:
        ValueError: Propagated from canonicalize_precision_format.
    """
    record = _FORMAT_INFO[canonicalize_precision_format(value, default=default)]
    return str(record["numpy_name"])


def quantize_array(values: Any, value_format: Any, default: str = "fp32") -> np.ndarray:
    """Cast ``values`` to the numpy dtype of ``value_format``.

    Uses ``np.asarray``, so an input that already is an ndarray of the target
    dtype is returned as-is rather than copied.

    Raises:
        ValueError: Propagated from canonicalize_precision_format.
    """
    target_dtype = precision_format_to_numpy_dtype(value_format, default=default)
    quantized = np.asarray(values, dtype=target_dtype)
    return quantized


def ulp_distance(a: Any, b: Any, value_format: Any, default: str = "fp32") -> np.ndarray:
    """Return the element-wise distance between ``a`` and ``b`` in ULPs.

    Both inputs are cast to the numpy dtype of ``value_format``. Their bit
    patterns are then mapped onto a monotonically increasing unsigned scale,
    so the result counts how many representable values separate the operands.

    ``-0.0`` and ``+0.0`` map to the same point on that scale. Equal zeros are
    therefore 0 ULP apart, and a step across zero is not inflated by one: the
    smallest negative and smallest positive subnormals are 2 ULP apart.

    Args:
        a: Array-like of values, broadcastable against ``b``.
        b: Array-like of values, broadcastable against ``a``.
        value_format: Precision format accepted by canonicalize_precision_format.
        default: Format used when ``value_format`` is None.

    Returns:
        float64 array of non-negative ULP distances with the broadcast shape of
        ``a`` and ``b``.

    Raises:
        ValueError: If the format is unsupported or needs the missing ml_dtypes.
    """
    fmt = canonicalize_precision_format(value_format, default=default)
    info = _FORMAT_INFO[fmt]
    uint_dtype = np.dtype(info["uint_dtype"])
    # Typed scalars keep every bit operation in the unsigned dtype, so 64-bit
    # formats never promote to float64/object.
    sign_bit = uint_dtype.type(1 << int(info["sign_shift"]))
    magnitude_mask = uint_dtype.type(int(info["sign_mask"]))

    def _ordinal(values: Any) -> np.ndarray:
        # Reinterpret the float bits as an unsigned integer of equal width.
        bits = np.asarray(values, dtype=info["numpy_dtype"]).view(uint_dtype)
        magnitude = bits & magnitude_mask
        # Negatives count down from sign_bit, positives count up from it, so
        # both zeros land on sign_bit. Since magnitude <= sign_bit - 1, neither
        # branch (np.where evaluates both) can wrap around.
        return np.where((bits & sign_bit) != 0, sign_bit - magnitude, sign_bit + magnitude)

    ordinal_a = _ordinal(a)
    ordinal_b = _ordinal(b)
    # max - min never underflows, unlike evaluating both a-b and b-a.
    dist = np.maximum(ordinal_a, ordinal_b) - np.minimum(ordinal_a, ordinal_b)
    return np.asarray(dist, dtype=np.float64)
