
"""Utilities for working with trees of xarray.DataArray (including Datasets).

Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library;
it won't work as a leaf node since it implements Mapping, but also won't work
as an internal node since tree doesn't know how to re-create it properly.

To fix this, we reimplement a subset of `map_structure`, exposing its
constituent DataArrays as leaf nodes. This means it can be mapped over as a
generic container of DataArrays, while still preserving the result as a Dataset
where possible.

This is useful because in a few places we need to handle a general
Mapping[str, DataArray] (where the coordinates might not be compatible across
the constituent DataArrays) but also the special case of a Dataset nicely.

For example, in the result of map_structure(fn, dataset), if fn returns None
for some of the child DataArrays, they will be omitted from the returned
dataset. If any values other than DataArrays or None are returned, then we
don't attempt to return a Dataset and just return a plain dict of the results.
Similarly, if DataArrays are returned but with non-matching coordinates, it
will just return a plain dict of DataArrays.

Note that xarray datatypes are registered with `jax.tree_util` by
xarray_jax.py, but `jax.tree_util.tree_map` is distinct from
`xarray_tree.map_structure`, as the former exposes the underlying JAX/numpy
arrays as leaf nodes, while the latter exposes DataArrays as leaf nodes.
"""

from typing import Any, Callable

import xarray


def map_structure(func: Callable[..., Any], *structures: Any) -> Any:
  """Maps func through given structures with xarrays. See tree.map_structure.

  The first structure decides how traversal proceeds at each level:

  * xarray.Dataset: func is applied to each data variable, treated as a
    DataArray leaf; it is not recursed into. If every result is a DataArray or
    None, the non-None results are renamed to their variable names and merged
    back into a Dataset with an exact join. If any result is of another type,
    or the exact merge fails, a plain dict of results keyed by variable name is
    returned instead.
  * dict: mapped recursively per key of the first structure; returns a dict.
  * list, tuple (including namedtuples) and set: mapped recursively
    element-wise; the result has the same type as the first structure.
  * Anything else is a leaf: func is called on the corresponding leaves.

  Args:
    func: Callable applied to corresponding leaves of `structures`.
    *structures: One or more structures with matching layout. Only the first
      structure's type and keys drive the traversal; the remaining structures
      are indexed by the same keys, or zipped element-wise for sequences.

  Returns:
    A structure mirroring the first one, holding the results of func.

  Raises:
    TypeError: If func is not callable.
    ValueError: If no structures are given, or if list/tuple/set structures
      mapped together have different lengths.
  """
  if not callable(func):
    raise TypeError(f'func must be callable, got: {func}')
  if not structures:
    raise ValueError('Must provide at least one structure')

  first = structures[0]
  if isinstance(first, xarray.Dataset):
    # Data variables are leaves: func is called on them directly.
    data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
    if all(isinstance(a, (type(None), xarray.DataArray))
           for a in data.values()):
      # Variables for which func returned None are dropped from the Dataset.
      data_arrays = [v.rename(k) for k, v in data.items() if v is not None]
      try:
        return xarray.merge(data_arrays, join='exact')
      except ValueError:  # Exact join not possible; fall back to a dict.
        pass
    return data
  if isinstance(first, dict):
    return {k: map_structure(func, *[s[k] for s in structures])
            for k in first.keys()}
  if isinstance(first, (list, tuple, set)):
    # zip() would silently truncate to the shortest input; reject mismatches
    # instead, as tree.map_structure does.
    lengths = [len(s) for s in structures]
    if len(set(lengths)) > 1:
      raise ValueError(
          'Structures being mapped together must have the same length, got '
          f'lengths {lengths}')
    mapped = [map_structure(func, *s) for s in zip(*structures)]
    if isinstance(first, tuple) and hasattr(first, '_fields'):
      # namedtuple constructors take their fields positionally, not as a
      # single iterable argument.
      return type(first)(*mapped)
    return type(first)(mapped)
  return func(*structures)
