# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with trees of xarray.DataArray (including Datasets).

Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library;
it won't work as a leaf node since it implements Mapping, but also won't work
as an internal node since tree doesn't know how to re-create it properly.

To fix this, we reimplement a subset of `map_structure`, exposing its
constituent DataArrays as leaf nodes. This means it can be mapped over as a
generic container of DataArrays, while still preserving the result as a Dataset
where possible.

This is useful because in a few places we need to handle a general
Mapping[str, DataArray] (where the coordinates might not be compatible across
the constituent DataArrays) but also the special case of a Dataset nicely.

For the result e.g. of a tree.map_structure(fn, dataset), if fn returns None for
some of the child DataArrays, they will be omitted from the returned dataset. If
any values other than DataArrays or None are returned, then we don't attempt to
return a Dataset and just return a plain dict of the results. Similarly if
DataArrays are returned but with non-matching coordinates, it will just return a
plain dict of DataArrays.

Note xarray datatypes are registered with `jax.tree_util` by xarray_jax.py,
but `jax.tree_util.tree_map` is distinct from the `xarray_tree.map_structure`,
as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the
latter exposes DataArrays as leaf nodes.
"""

import xarray

from typing import Any, Callable


def map_structure(func: Callable[..., Any], *structures: Any) -> Any:
    """Maps func through given structures with xarrays. See tree.map_structure.

    The first structure decides how the arguments are traversed; the remaining
    structures are indexed with the same keys (dict / Dataset) or zipped
    positionally (list / tuple / set) against it.

    Handling by type of the first structure:

    * `dict`: recursed into per key of the first dict; the result is always a
      plain `dict`.
    * `list`, `tuple` (including namedtuples), `set`: recursed into per
      element and rebuilt as the same type. Elements are paired with `zip`, so
      if the structures have different lengths the extra trailing elements
      are silently dropped.
    * `xarray.Dataset`: `func` is called once per key of the first Dataset,
      with the corresponding entry of every structure. If every result is a
      `xarray.DataArray` or `None`, the `None`s are dropped and the rest are
      merged back into a Dataset using an exact join. If that merge raises
      `ValueError`, or any result is of another type, a plain dict of the raw
      results (including any `None`s) is returned instead.
    * anything else: treated as a leaf and `func(*structures)` is returned.

    Args:
      func: Callable applied to the leaves; it receives one positional
        argument per structure.
      *structures: One or more structures to map over in parallel.

    Returns:
      A structure mirroring the first one, holding the results of `func`.

    Raises:
      TypeError: If `func` is not callable.
      ValueError: If no structures are given.
    """
    if not callable(func):
        raise TypeError(f"func must be callable, got: {func}")
    if not structures:
        raise ValueError("Must provide at least one structure")

    first = structures[0]

    # Builtin containers are checked before xarray.Dataset. A Dataset is not
    # an instance of any of them, so the dispatch result is the same.
    if isinstance(first, dict):
        return {
            key: map_structure(func, *[s[key] for s in structures])
            for key in first
        }

    if isinstance(first, (list, tuple, set)):
        children = [map_structure(func, *group) for group in zip(*structures)]
        if isinstance(first, tuple) and hasattr(first, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable, so the children must be unpacked.
            return type(first)(*children)
        return type(first)(children)

    if isinstance(first, xarray.Dataset):
        data = {key: func(*[s[key] for s in structures]) for key in first}
        if all(isinstance(v, (type(None), xarray.DataArray)) for v in data.values()):
            # Name each DataArray after its key so merge() puts it back under
            # the same variable name; None results are omitted.
            data_arrays = [v.rename(k) for k, v in data.items() if v is not None]
            try:
                return xarray.merge(data_arrays, join="exact")
            except ValueError:  # Exact join not possible.
                pass
        # Not expressible as a Dataset: hand back the raw per-key results.
        return data

    return func(*structures)
