"""Wrappers that take care of casting."""

import contextlib
from typing import Any, Mapping, Tuple

import chex
import jax
import jax.numpy as jnp
import numpy as np
import xarray


# Type alias used in annotations for JAX pytrees (values traversed with
# `jax.tree.map` below). It is plain `Any`, so no structure is checked
# statically.
PyTree = Any



def infer_floating_dtype(data_vars: Mapping[str, chex.Array]) -> np.dtype:
  """Infers the single floating dtype used by the values of a mapping.

  Values whose dtype is not a floating type (as judged by
  `jnp.issubdtype(..., np.floating)`) are ignored.

  Args:
    data_vars: Mapping from variable name to an array-like exposing `.dtype`
      and `.shape`.

  Returns:
    The unique floating dtype found among the values of `data_vars`.

  Raises:
    ValueError: If the floating values do not share exactly one dtype. This
      includes the case where there are no floating values at all.
  """
  # Filter once; both the result and the error report use the same subset.
  # jnp.issubdtype (rather than np.issubdtype) also recognises bfloat16.
  floating_vars = {
      k: v for k, v in data_vars.items()
      if jnp.issubdtype(v.dtype, np.floating)}
  dtypes = {v.dtype for v in floating_vars.values()}
  if len(dtypes) != 1:
    dtypes_and_shapes = {
        k: (v.dtype, v.shape) for k, v in floating_vars.items()}
    raise ValueError(
        f'Did not find exactly one floating dtype {dtypes} in input variables: '
        f'{dtypes_and_shapes}')
  return next(iter(dtypes))


def _all_inputs_to_bfloat16(
    inputs: xarray.Dataset,
    targets: xarray.Dataset,
    forcings: xarray.Dataset,
    ) -> Tuple[xarray.Dataset,
               xarray.Dataset,
               xarray.Dataset]:
  """Returns bfloat16 copies of `inputs`, `targets` and `forcings`.

  `inputs` and `forcings` are converted with `Dataset.astype`, whereas
  `targets` is converted leaf by leaf through `jax.tree.map`, calling
  `.astype` on each leaf.
  """
  # NOTE(review): `targets` deliberately goes through `jax.tree.map` instead of
  # `Dataset.astype`; presumably this relies on xarray.Dataset being registered
  # as a JAX pytree elsewhere. Confirm before unifying the two code paths.
  target_dtype = jnp.bfloat16

  def _to_target_dtype(leaf):
    return leaf.astype(target_dtype)

  cast_inputs = inputs.astype(target_dtype)
  cast_targets = jax.tree.map(_to_target_dtype, targets)
  cast_forcings = forcings.astype(target_dtype)
  return cast_inputs, cast_targets, cast_forcings


def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype,
                  ) -> PyTree:
  """Casts every array leaf of dtype `input_dtype` in `inputs` to `output_dtype`.

  Leaves that are JAX arrays, NumPy arrays or NumPy scalars and whose dtype
  equals `input_dtype` are converted with `.astype(output_dtype)`. Any other
  leaf (e.g. a Python scalar, or an array of a different dtype) is returned
  unchanged.

  Args:
    inputs: Arbitrary pytree traversed with `jax.tree.map`.
    input_dtype: Dtype of the leaves to convert.
    output_dtype: Dtype to convert matching leaves to.

  Returns:
    A pytree with the same structure as `inputs`.
  """
  # np.generic covers NumPy scalars such as np.float32(1.5): they carry a dtype
  # and support astype, but are not np.ndarray instances, so without it they
  # would be silently left in `input_dtype`.
  array_types = (jnp.ndarray, np.ndarray, np.generic)

  def cast_fn(x):
    if isinstance(x, array_types) and x.dtype == input_dtype:
      return x.astype(output_dtype)
    # Non-array leaves have no dtype to compare; pass them through untouched.
    return x

  return jax.tree.map(cast_fn, inputs)