# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for xarray_jax."""

import chex
import jax
import jax.numpy as jnp
import numpy as np
import xarray

from absl.testing import absltest

from . import xarray_jax


class XarrayJaxTest(absltest.TestCase):
    def test_jax_array_wrapper_with_numpy_api(self):
        """Drives a JaxArrayWrapper through numpy API calls and checks its type."""
        # Being able to manipulate JAX arrays through the standard numpy API is a
        # side benefit of the xarray integration; none of the operations below
        # should convert the wrapped data to numpy along the way.
        wrapped = xarray_jax.JaxArrayWrapper(jnp.ones((3, 4), dtype=np.float32))
        shifted_up = wrapped + 2
        shifted_down = wrapped - 3
        wrapped = np.abs(shifted_up * shifted_down)
        wrapped = wrapped[:-1, 1:3]
        wrapped = np.concatenate([wrapped, wrapped + 1], axis=0)
        wrapped = np.transpose(wrapped, (1, 0))
        wrapped = np.reshape(wrapped, (-1,))
        wrapped = wrapped.astype(np.int32)
        self.assertIsInstance(wrapped, xarray_jax.JaxArrayWrapper)

        # Only an explicit conversion takes us back out of JAX-land:
        as_numpy = np.asarray(wrapped)
        self.assertIsInstance(as_numpy, np.ndarray)

    def test_jax_xarray_variable(self):
        """Runs a chain of Variable ops eagerly, under jit and under grad."""

        def ops_via_xarray(inputs):
            # The particular operations are arbitrary: they only exist to check
            # that the end result is still a JAX array, i.e. that nothing was
            # converted to numpy at any point.
            var = xarray_jax.Variable(("lat", "lon"), inputs)
            var = np.abs((var + 2) * (var - 3))
            var = var.isel({"lat": slice(0, -1), "lon": slice(1, 3)})
            var = xarray.Variable.concat([var, var + 1], dim="lat")
            var = var.transpose("lon", "lat").stack(channels=("lon", "lat"))
            total = var.sum()
            return xarray_jax.jax_data(total)

        ones = jnp.ones((3, 4), dtype=np.float32)

        # With concrete values the computation must not leave jax-land:
        self.assertIsInstance(ops_via_xarray(ones), jax.Array)

        # It must also be possible to JIT it and to differentiate through it.
        # Both involve passing jax tracers through the xarray computation:
        for transform in (jax.jit, jax.grad):
            transform(ops_via_xarray)(ones)

    def test_jax_xarray_data_array(self):
        """Runs a chain of DataArray ops eagerly, under jit and under grad."""

        def ops_via_xarray(inputs):
            x = xarray_jax.DataArray(
                dims=("lat", "lon"),
                data=inputs,
                coords={"lat": np.arange(3) * 10, "lon": np.arange(4) * 10},
            )
            x = np.abs((x + 2) * (x - 3))
            x = x.sel({"lat": slice(0, 20)})
            # Build the second operand from the function argument rather than
            # from a variable captured from the enclosing scope, so that under
            # jit/grad it carries tracers too and the function is self-contained.
            y = xarray_jax.DataArray(
                dims=("lat", "lon"),
                data=inputs,
                coords={"lat": np.arange(3, 6) * 10, "lon": np.arange(4) * 10},
            )
            x = xarray.concat([x, y], dim="lat")
            x = x.transpose("lon", "lat")
            x = x.stack(channels=("lon", "lat"))
            x = x.unstack()
            x = x.sum()
            return xarray_jax.jax_data(x)

        # Check it doesn't leave jax-land when passed concrete values:
        ones = jnp.ones((3, 4), dtype=np.float32)
        result = ops_via_xarray(ones)
        self.assertIsInstance(result, jax.Array)

        # And that it can be jitted and differentiated, which passes jax tracers
        # through the xarray computation:
        jax.jit(ops_via_xarray)(ones)
        jax.grad(ops_via_xarray)(ones)

    def test_jax_xarray_dataset(self):
        """Runs a chain of Dataset ops eagerly, under jit and under grad."""

        def ops_via_xarray(foo, bar):
            def make_dataset(lat):
                # Both datasets share everything except their 'lat' coordinate.
                return xarray_jax.Dataset(
                    data_vars={
                        "foo": (("lat", "lon"), foo),
                        "bar": (("time", "lat", "lon"), bar),
                    },
                    coords={"time": np.arange(2), "lat": lat, "lon": np.arange(4) * 10},
                )

            first = make_dataset(np.arange(3) * 10)
            first = np.abs((first + 2) * (first - 3))
            first = first.sel({"lat": slice(0, 20)})
            second = make_dataset(np.arange(3, 6) * 10)
            merged = xarray.concat([first, second], dim="lat")
            merged = merged.transpose("lon", "lat", "time")
            merged = merged.stack(channels=("lon", "lat"))
            total = (merged.foo + merged.bar).sum()
            return xarray_jax.jax_data(total)

        foo = jnp.ones((3, 4), dtype=np.float32)
        bar = jnp.ones((2, 3, 4), dtype=np.float32)

        # Concrete inputs must give back a JAX array:
        self.assertIsInstance(ops_via_xarray(foo, bar), jax.Array)

        # The same computation must also work under jit and grad:
        jax.jit(ops_via_xarray)(foo, bar)
        jax.grad(ops_via_xarray)(foo, bar)

    def test_jit_function_with_xarray_variable_arguments_and_return(self):
        """A jitted function accepts and returns Variables, preserving dims."""
        add_one = jax.jit(lambda v: v + 1)
        # Each case is (sub-test name, Variable constructor, array constructor).
        cases = (
            ("jax input", xarray_jax.Variable, jnp.ones),
            ("numpy input", xarray.Variable, np.ones),
        )
        for case_name, make_variable, make_ones in cases:
            with self.subTest(case_name):
                inputs = make_variable(("lat", "lon"), make_ones((3, 4), dtype=np.float32))
                _ = add_one(inputs)
                # The jitted function is run a second time to exercise the logic in
                # jax which checks whether the structure of the inputs (including
                # dimension names and coordinates) matches the previous call, and
                # hence whether a re-trace-and-compile is needed. That logic can run
                # into problems if the 'aux' structure returned by the registered
                # flatten function is not hashable/comparable.
                outputs = add_one(inputs)
                self.assertEqual(outputs.dims, inputs.dims)

    def test_jit_problem_if_convert_to_plain_numpy_array(self):
        """Reading .values inside a jitted function raises a tracer conversion error."""

        def to_numpy(data_array):
            # Accessing .values on a DataArray converts its values to numpy:
            return data_array.values

        inputs = xarray_jax.DataArray(dims=("foo",), data=jnp.ones((2,), dtype=np.float32))
        with self.assertRaises(jax.errors.TracerArrayConversionError):
            jax.jit(to_numpy)(inputs)

    def test_grad_function_with_xarray_variable_arguments(self):
        """jax.grad can differentiate a function that takes a Variable argument."""

        def scalar_sum(v):
            # grad still requires a JAX scalar as the function's output:
            return xarray_jax.jax_data(v.sum())

        variable = xarray_jax.Variable(("lat", "lon"), jnp.ones((3, 4), dtype=np.float32))
        jax.grad(scalar_sum)(variable)

    def test_jit_function_with_xarray_data_array_arguments_and_return(self):
        """A jitted function accepts and returns DataArrays, keeping dims and coords."""
        add_one = jax.jit(lambda v: v + 1)
        static_coords = {"lat": np.arange(3), "lon": np.arange(4) * 10}
        inputs = xarray_jax.DataArray(
            data=jnp.ones((3, 4), dtype=np.float32),
            dims=("lat", "lon"),
            coords=static_coords,
        )
        # Call twice; only the result of the second call is checked.
        _ = add_one(inputs)
        outputs = add_one(inputs)
        self.assertEqual(outputs.dims, inputs.dims)
        chex.assert_trees_all_equal(outputs.coords, inputs.coords)

    def test_jit_function_with_data_array_and_jax_coords(self):
        """jit handles a DataArray whose 'lon' coordinate is given as a jax_coord."""
        inputs = xarray_jax.DataArray(
            data=jnp.ones((3, 4), dtype=np.float32),
            dims=("lat", "lon"),
            coords={"lat": np.arange(3)},
            jax_coords={"lon": jnp.arange(4) * 10},
        )
        # The jax_coord 'lon' must retain jax data and must not have been created
        # as an index coordinate:
        self.assertIsInstance(inputs.coords["lon"].data, xarray_jax.JaxArrayWrapper)
        self.assertNotIn("lon", inputs.indexes)

        def add_lon_then_bump_lon(array):
            # The non-JAX coord arrives with numpy array data and an index:
            self.assertIsInstance(array.coords["lat"].data, np.ndarray)
            self.assertIn("lat", array.indexes)

            # The jax_coord arrives with JAX array data and no index:
            self.assertIsInstance(array.coords["lon"].data, xarray_jax.JaxArrayWrapper)
            self.assertNotIn("lon", array.indexes)

            # Use the jax coord in the computation:
            shifted = array + array.coords["lon"]

            # Return with an updated jax coord:
            bumped_lon = shifted.coords["lon"] + 1
            return xarray_jax.assign_jax_coords(shifted, lon=bumped_lon)

        jitted = jax.jit(add_lon_then_bump_lon)
        _ = jitted(inputs)
        outputs = jitted(inputs)

        # The jax_coord 'lon' must hold jax data in the output too:
        self.assertIsInstance(outputs.coords["lon"].data, xarray_jax.JaxArrayWrapper)
        self.assertNotIn("lon", outputs.indexes)

        self.assertEqual(outputs.dims, inputs.dims)
        chex.assert_trees_all_equal(outputs.coords["lat"], inputs.coords["lat"])

        # Check the computations involving the coordinate values worked:
        expected_lon = inputs.coords["lon"] + 1
        chex.assert_trees_all_equal(outputs.coords["lon"].data, expected_lon.data)
        expected_values = inputs + inputs.coords["lon"]
        chex.assert_trees_all_equal(outputs.data, expected_values.data)

    def test_jit_function_with_xarray_dataset_arguments_and_return(self):
        """A jitted function accepts and returns Datasets, keeping vars, dims, coords."""
        add_one = jax.jit(lambda v: v + 1)
        inputs = xarray_jax.Dataset(
            data_vars={
                "foo": (("lat", "lon"), jnp.ones((3, 4), dtype=np.float32)),
                "bar": (("time", "lat", "lon"), jnp.ones((2, 3, 4), dtype=np.float32)),
            },
            coords={
                "time": np.arange(2),
                "lat": np.arange(3) * 10,
                "lon": np.arange(4) * 10,
            },
        )
        # Call twice; only the result of the second call is checked.
        _ = add_one(inputs)
        outputs = add_one(inputs)

        self.assertEqual({"foo", "bar"}, outputs.data_vars.keys())
        self.assertEqual(inputs.foo.dims, outputs.foo.dims)
        self.assertEqual(inputs.bar.dims, outputs.bar.dims)
        chex.assert_trees_all_equal(outputs.coords, inputs.coords)

    def test_jit_function_with_dataset_and_jax_coords(self):
        """jit handles a Dataset whose 'lon' coordinate is given as a jax_coord."""
        foo = jnp.ones((3, 4), dtype=np.float32)
        bar = jnp.ones((2, 3, 4), dtype=np.float32)
        inputs = xarray_jax.Dataset(
            data_vars={"foo": (("lat", "lon"), foo), "bar": (("time", "lat", "lon"), bar)},
            coords={
                "time": np.arange(2),
                "lat": np.arange(3) * 10,
            },
            jax_coords={"lon": jnp.arange(4) * 10},
        )
        # Verify the jax_coord 'lon' retains jax data, and has not been created
        # as an index coordinate:
        self.assertIsInstance(inputs.coords["lon"].data, xarray_jax.JaxArrayWrapper)
        self.assertNotIn("lon", inputs.indexes)

        @jax.jit
        def fn(v):
            # The non-JAX coords are passed with numpy array data and an index:
            self.assertIsInstance(v.coords["lat"].data, np.ndarray)
            self.assertIn("lat", v.indexes)

            # The jax_coord is passed with JAX array data:
            self.assertIsInstance(v.coords["lon"].data, xarray_jax.JaxArrayWrapper)
            self.assertNotIn("lon", v.indexes)

            # Use the jax coord in the computation:
            v = v + v.coords["lon"]

            # Return with an updated jax coord:
            return xarray_jax.assign_jax_coords(v, lon=v.coords["lon"] + 1)

        _ = fn(inputs)
        outputs = fn(inputs)

        # Verify the jax_coord 'lon' has jax data in the output too:
        self.assertIsInstance(outputs.coords["lon"].data, xarray_jax.JaxArrayWrapper)
        self.assertNotIn("lon", outputs.indexes)

        self.assertEqual(outputs.dims, inputs.dims)
        chex.assert_trees_all_equal(outputs.coords["lat"], inputs.coords["lat"])
        # Check our computations with the coordinate values worked:
        chex.assert_trees_all_equal(
            (outputs.coords["lon"]).data,
            (inputs.coords["lon"] + 1).data,
        )
        # Compute the expected dataset once, rather than re-evaluating the sum
        # for the iteration and again for every data variable.
        expected = inputs + inputs.coords["lon"]
        outputs_dict = {key: outputs[key].data for key in outputs}
        expected_dict = {key: expected[key].data for key in expected}
        chex.assert_trees_all_equal(outputs_dict, expected_dict)

    def test_flatten_unflatten_variable(self):
        """Flattening then unflattening a Variable gives back an equal Variable."""
        original = xarray_jax.Variable(("lat", "lon"), jnp.ones((3, 4), dtype=np.float32))
        flattened = xarray_jax._flatten_variable(original)
        children, aux = flattened
        # The auxiliary info has to be hashable/comparable (important for jax.jit):
        hash(aux)
        self.assertEqual(aux, aux)
        restored = xarray_jax._unflatten_variable(aux, children)
        self.assertTrue(original.equals(restored))

    def test_flatten_unflatten_data_array(self):
        """Flattening then unflattening a DataArray gives back an equal DataArray."""
        original = xarray_jax.DataArray(
            data=jnp.ones((3, 4), dtype=np.float32),
            dims=("lat", "lon"),
            coords={"lat": np.arange(3)},
            jax_coords={"lon": np.arange(4) * 10},
        )
        flattened = xarray_jax._flatten_data_array(original)
        children, aux = flattened
        # The auxiliary info has to be hashable/comparable (important for jax.jit):
        hash(aux)
        self.assertEqual(aux, aux)
        restored = xarray_jax._unflatten_data_array(aux, children)
        self.assertTrue(original.equals(restored))

    def test_flatten_unflatten_dataset(self):
        """Flattening then unflattening a Dataset gives back an equal Dataset."""
        original = xarray_jax.Dataset(
            data_vars={
                "foo": (("lat", "lon"), jnp.ones((3, 4), dtype=np.float32)),
                "bar": (("time", "lat", "lon"), jnp.ones((2, 3, 4), dtype=np.float32)),
            },
            coords={"time": np.arange(2), "lat": np.arange(3) * 10},
            jax_coords={"lon": np.arange(4) * 10},
        )
        flattened = xarray_jax._flatten_dataset(original)
        children, aux = flattened
        # The auxiliary info has to be hashable/comparable (important for jax.jit):
        hash(aux)
        self.assertEqual(aux, aux)
        restored = xarray_jax._unflatten_dataset(aux, children)
        self.assertTrue(original.equals(restored))

    def test_flatten_unflatten_added_dim(self):
        """Unflattening leaves with an extra leading axis adds a 'new' dimension."""

        def prepend_new(dims):
            return ("new",) + dims

        original = xarray_jax.DataArray(
            data=jnp.ones((3, 4), dtype=np.float32),
            dims=("lat", "lon"),
            coords={"lat": np.arange(3), "lon": np.arange(4) * 10},
        )
        leaves, treedef = jax.tree_util.tree_flatten(original)
        expanded_leaves = [jnp.expand_dims(leaf, 0) for leaf in leaves]
        with xarray_jax.dims_change_on_unflatten(prepend_new):
            with_new_dim = jax.tree_util.tree_unflatten(treedef, expanded_leaves)

        self.assertEqual(("new", "lat", "lon"), with_new_dim.dims)
        # Selecting index 0 along the added dimension must recover the original:
        first_slice = with_new_dim.isel(new=0)
        xarray.testing.assert_identical(jax.device_get(original), jax.device_get(first_slice))

    def test_map_added_dim(self):
        """tree_map that adds a leading axis yields a DataArray with a 'new' dim."""

        def prepend_new(dims):
            return ("new",) + dims

        def add_leading_axis(leaf):
            return jnp.expand_dims(leaf, 0)

        original = xarray_jax.DataArray(
            data=jnp.ones((3, 4), dtype=np.float32),
            dims=("lat", "lon"),
            coords={"lat": np.arange(3), "lon": np.arange(4) * 10},
        )
        with xarray_jax.dims_change_on_unflatten(prepend_new):
            with_new_dim = jax.tree_util.tree_map(add_leading_axis, original)

        self.assertEqual(("new", "lat", "lon"), with_new_dim.dims)
        # Selecting index 0 along the added dimension must recover the original:
        first_slice = with_new_dim.isel(new=0)
        xarray.testing.assert_identical(jax.device_get(original), jax.device_get(first_slice))

    def test_map_remove_dim(self):
        """tree_map that squeezes the leading axis yields a Dataset without 'batch'."""

        def drop_first(dims):
            return dims[1:]

        def squeeze_leading_axis(leaf):
            return jnp.squeeze(leaf, 0)

        dataset = xarray_jax.Dataset(
            data_vars={
                "foo": (("batch", "lat", "lon"), jnp.ones((1, 3, 4), dtype=np.float32)),
                "bar": (
                    ("batch", "time", "lat", "lon"),
                    jnp.ones((1, 2, 3, 4), dtype=np.float32),
                ),
            },
            coords={
                "batch": np.array([123]),
                "time": np.arange(2),
                "lat": np.arange(3) * 10,
                "lon": np.arange(4) * 10,
            },
        )
        with xarray_jax.dims_change_on_unflatten(drop_first):
            with_removed_dim = jax.tree_util.tree_map(squeeze_leading_axis, dataset)

        self.assertEqual(("lat", "lon"), with_removed_dim["foo"].dims)
        self.assertEqual(("time", "lat", "lon"), with_removed_dim["bar"].dims)
        # 'batch' must be gone both as a dimension and as a coordinate:
        self.assertNotIn("batch", with_removed_dim.dims)
        self.assertNotIn("batch", with_removed_dim.coords)

        expected = dataset.isel(batch=0, drop=True)
        xarray.testing.assert_identical(
            jax.device_get(expected), jax.device_get(with_removed_dim)
        )

    def test_pmap(self):
        """xarray_jax.pmap maps a Dataset function over the 'device' dimension."""
        num_devices = jax.local_device_count()
        dataset = xarray_jax.Dataset({
            "foo": (("device", "lat", "lon"), jnp.zeros((num_devices, 3, 4), dtype=np.float32)),
            "bar": (
                ("device", "time", "lat", "lon"),
                jnp.zeros((num_devices, 2, 3, 4), dtype=np.float32),
            ),
        })

        def add_one(d):
            self.assertNotIn("device", d.dims)
            return d + 1

        pmapped = xarray_jax.pmap(add_one, dim="device")

        # The second dataset has a different argument structure; calling the same
        # pmapped function with it will recompile under the hood but should work.
        for current in (dataset, dataset.drop_vars("foo")):
            result = pmapped(current)
            xarray.testing.assert_identical(jax.device_get(current + 1), jax.device_get(result))

    def test_pmap_with_jax_coords(self):
        """xarray_jax.pmap also maps a jax_coord that has a leading device dim."""
        num_devices = jax.local_device_count()
        data_vars = {
            "foo": (("device", "lat", "lon"), jnp.zeros((num_devices, 3, 4), dtype=np.float32)),
            "bar": (
                ("device", "time", "lat", "lon"),
                jnp.zeros((num_devices, 2, 3, 4), dtype=np.float32),
            ),
        }
        # Currently any jax_coords need a leading device dimension to use
        # with pmap, same as for data_vars.
        # TODO(matthjw): have pmap automatically broadcast to all devices
        # where the device dimension not present.
        time_coord = xarray_jax.Variable(
            ("device", "time"), jnp.zeros((num_devices, 2), dtype=np.float32)
        )
        dataset = xarray_jax.Dataset(
            data_vars,
            coords={"lat": np.arange(3), "lon": np.arange(4)},
            jax_coords={"time": time_coord},
        )

        def add_one(d):
            self.assertNotIn("device", d.dims)
            self.assertNotIn("device", d.coords["time"].dims)

            # The jax_coord 'time' should arrive backed by a JAX array, but not
            # as an index coordinate.
            self.assertIsInstance(d.coords["time"].data, xarray_jax.JaxArrayWrapper)
            self.assertNotIn("time", d.indexes)

            return d + 1

        pmapped = xarray_jax.pmap(add_one, dim="device")

        # The second dataset has a different argument structure; calling the same
        # pmapped function with it will recompile under the hood but should work.
        for current in (dataset, dataset.drop_vars("foo")):
            result = pmapped(current)
            xarray.testing.assert_identical(jax.device_get(current + 1), jax.device_get(result))

    def test_pmap_with_tree_mix_of_xarray_and_jax_array(self):
        """xarray_jax.pmap handles a dict holding a DataArray and a plain JAX array."""
        num_devices = jax.local_device_count()
        inputs = {
            "foo": xarray_jax.DataArray(
                data=jnp.ones((num_devices, 3, 4), dtype=np.float32),
                dims=("device", "lat", "lon"),
            ),
            "bar": jnp.ones((num_devices, 2), dtype=np.float32),
        }

        def add_one_to_each(x):
            return x["foo"] + 1, x["bar"] + 1

        pmapped = xarray_jax.pmap(add_one_to_each, dim="device")
        result_foo, result_bar = pmapped(inputs)

        # The xarray entry is compared as xarray, the plain array via numpy:
        expected_foo = inputs["foo"] + 1
        xarray.testing.assert_identical(jax.device_get(expected_foo), jax.device_get(result_foo))
        expected_bar = inputs["bar"] + 1
        np.testing.assert_array_equal(jax.device_get(expected_bar), jax.device_get(result_bar))

    def test_pmap_complains_when_dim_not_first(self):
        """Calling a pmapped function raises ValueError if 'device' is not the first dim."""
        num_devices = jax.local_device_count()
        # 'device' is deliberately placed at index 1 rather than index 0.
        misordered = xarray_jax.DataArray(
            dims=("lat", "device", "lon"),
            data=jnp.ones((3, num_devices, 4), dtype=np.float32),
        )
        pmapped = xarray_jax.pmap(lambda x: x + 1, dim="device")

        with self.assertRaisesRegex(ValueError, "Expected dim device at index 0, found at 1"):
            pmapped(misordered)

    def test_apply_ufunc(self):
        """apply_ufunc with 'x' as core dim sums over it, leaving only 'y'."""

        def sum_last_axis(x):
            return jnp.sum(x, axis=-1)

        inputs = xarray_jax.DataArray(
            data=jnp.asarray([[1, 2], [3, 4]]),
            dims=("x", "y"),
            coords={"x": [0, 1], "y": [2, 3]},
        )
        result = xarray_jax.apply_ufunc(sum_last_axis, inputs, input_core_dims=[["x"]])

        expected_result = xarray_jax.DataArray(data=[4, 6], dims=("y",), coords={"y": [2, 3]})
        xarray.testing.assert_identical(expected_result, jax.device_get(result))

    def test_apply_ufunc_multiple_return_values(self):
        """apply_ufunc supports a function returning two arrays (min and max)."""

        def min_and_max(array):
            return jnp.min(array, axis=-1), jnp.max(array, axis=-1)

        inputs = xarray_jax.DataArray(
            data=jnp.asarray([[1, 4], [3, 2]]),
            dims=("x", "y"),
            coords={"x": [0, 1], "y": [2, 3]},
        )
        result = xarray_jax.apply_ufunc(
            min_and_max, inputs, input_core_dims=[["x"]], output_core_dims=[[], []]
        )

        # Expected values in output order: mins first, then maxes.
        expected_values = ([1, 2], [3, 4])
        for index, values in enumerate(expected_values):
            expected = xarray_jax.DataArray(data=values, dims=("y",), coords={"y": [2, 3]})
            xarray.testing.assert_identical(expected, jax.device_get(result[index]))


# Run the tests through absltest when this file is executed as a script.
if __name__ == "__main__":
    absltest.main()
